284 lines
11 KiB
Python
284 lines
11 KiB
Python
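"""LLM tool executor (LLM 工具执行器).

Dispatches a tool call by name to the existing data layer and returns the
result as a JSON string for the LLM to consume
(根据工具名调用现有数据层,返回 JSON 字符串供 LLM 使用。).
"""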
|
|
|
|
import json
|
|
import logging
|
|
import math
|
|
|
|
from app.db.error_logger import log_error
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_chat_user_context: dict | None = None
|
|
|
|
|
|
def set_chat_user_context(user: dict | None) -> None:
|
|
global _chat_user_context
|
|
_chat_user_context = user
|
|
|
|
|
|
def _int_arg(arguments: dict, key: str, default: int) -> int:
    """Read an optional integer tool argument supplied by the LLM.

    LLMs sometimes send numbers as strings ("10") or floats (10.0), or send an
    explicit null. Those are normalised here so downstream slicing and date
    maths receive an int. Returns *default* when the key is absent or null. A
    value that cannot be converted raises ValueError/TypeError, which
    execute_tool reports as a tool error.
    """
    value = arguments.get(key)
    if value is None:
        return default
    return int(value)


# Tool name -> handler. Each handler takes the arguments dict and returns the
# coroutine that produces the JSON string. The lambdas look up the _get_*
# functions only when called, so those may be defined later in the module.
_TOOL_HANDLERS = {
    "get_strategy_board": lambda args: _get_strategy_board(),
    "get_market_temperature": lambda args: _get_market_temperature(),
    "get_hot_sectors": lambda args: _get_hot_sectors(_int_arg(args, "limit", 10)),
    "get_latest_recommendations": lambda args: _get_latest_recommendations(),
    "get_user_watchlist_snapshot": lambda args: _get_user_watchlist_snapshot(),
    "get_stock_kline": lambda args: _get_stock_kline(
        args["ts_code"], _int_arg(args, "days", 60)
    ),
    "get_stock_capital_flow": lambda args: _get_stock_capital_flow(
        args["ts_code"], _int_arg(args, "days", 10)
    ),
    "search_stock": lambda args: _search_stock(args["keyword"]),
    "get_stock_technical_signal": lambda args: _get_stock_technical_signal(args["ts_code"]),
    "get_sector_performance": lambda args: _get_sector_performance(args["sector_name"]),
    "get_realtime_indices": lambda args: _get_realtime_indices(),
}


async def execute_tool(name: str, arguments: dict) -> str:
    """Execute the tool called *name* and return its result as a JSON string.

    Args:
        name: Tool name as requested by the LLM.
        arguments: Tool arguments. ``None`` is treated as an empty dict.

    Returns:
        The tool's JSON output. An unknown tool yields ``{"error": "未知工具: …"}``.
        Any exception raised while running the tool (including a missing
        required argument) is logged and returned as ``{"error": str(exc)}``
        instead of propagating.
    """
    arguments = arguments or {}
    handler = _TOOL_HANDLERS.get(name)
    if handler is None:
        return json.dumps({"error": f"未知工具: {name}"}, ensure_ascii=False)

    try:
        # Argument extraction happens inside handler(), so KeyError/ValueError
        # from bad arguments is reported like any other tool failure.
        return await handler(arguments)
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("工具执行失败 %s: %s", name, e)
        await log_error(
            "llm_tool",
            f"工具执行失败 {name}: {e}",
            detail=f"arguments={json.dumps(arguments, ensure_ascii=False, default=str)}",
        )
        return json.dumps({"error": str(e)}, ensure_ascii=False)
|
|
|
|
|
|
def _clean_for_json(obj):
    """Recursively replace non-finite floats (NaN, +Inf, -Inf) with None.

    ``json.dumps`` would otherwise emit the bare tokens ``NaN``/``Infinity``,
    which are not valid JSON. Dict values are cleaned (keys are kept as-is).
    Lists and tuples are rebuilt as lists, which JSON encodes identically.
    Any other value is returned unchanged.
    """
    if isinstance(obj, dict):
        return {k: _clean_for_json(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_clean_for_json(v) for v in obj]
    if isinstance(obj, float) and not math.isfinite(obj):
        return None
    return obj
|
|
|
|
|
|
async def _get_strategy_board() -> str:
    """Return a trimmed strategy board, built without the LLM step, as JSON.

    Scalar fields default to "" and list fields default to [] and are capped
    to a fixed number of items. ``metrics`` defaults to {} and
    ``generated_by`` to "rules". Non-finite floats are replaced via
    _clean_for_json.
    """
    from app.llm.strategy_board import build_strategy_board

    board = await build_strategy_board(include_llm=False)

    scalar_keys = (
        "trade_date",
        "market_regime",
        "risk_level",
        "action_bias",
        "position_suggestion",
        "summary",
        "recommended_mode",
    )
    # (key, maximum number of items kept)
    capped_list_keys = (
        ("watch_sectors", 5),
        ("strategy_focus", 4),
        ("avoid_rules", 4),
        ("iteration_notes", 3),
    )

    payload = {key: board.get(key, "") for key in scalar_keys}
    for key, cap in capped_list_keys:
        payload[key] = board.get(key, [])[:cap]
    payload["metrics"] = board.get("metrics", {})
    payload["generated_by"] = board.get("generated_by", "rules")

    return json.dumps(_clean_for_json(payload), ensure_ascii=False, default=str)
|
|
|
|
|
|
async def _get_market_temperature() -> str:
    """Return the market-temperature part of the latest recommendations as JSON.

    Returns ``{"error": "暂无市场温度数据"}`` when no market temperature is
    available. Non-finite floats are replaced with null (via _clean_for_json),
    as the other tools in this module do, so the output stays valid JSON.
    """
    from app.engine.recommender import get_latest_recommendations

    result = await get_latest_recommendations()
    market_temp = result.get("market_temp")
    if not market_temp:
        return json.dumps({"error": "暂无市场温度数据"}, ensure_ascii=False)
    return json.dumps(
        _clean_for_json(market_temp.model_dump()), ensure_ascii=False, default=str
    )
|
|
|
|
|
|
async def _get_hot_sectors(limit: int) -> str:
    """Return the first *limit* sectors from get_latest_sectors as JSON.

    *limit* can come straight from the LLM, so it is coerced with int()
    (accepting "10" or 10.0) and negative values are clamped to 0. Without the
    clamp, a slice like ``[:-3]`` would silently return "all but the last 3".
    Non-finite floats are replaced with null, as in the other tools.
    """
    from app.engine.recommender import get_latest_sectors

    count = max(0, int(limit))
    sectors = await get_latest_sectors()
    data = [sector.model_dump() for sector in sectors[:count]]
    return json.dumps(_clean_for_json(data), ensure_ascii=False, default=str)
|
|
|
|
|
|
async def _get_latest_recommendations() -> str:
    """Return the latest stock recommendations as a JSON list.

    Each recommendation is dumped without ``created_at``, and its
    ``llm_analysis`` is set to "" (presumably so earlier LLM output is not fed
    back into the conversation — TODO confirm). Non-finite floats are replaced
    with null, as in the other tools, so the output stays valid JSON.
    """
    from app.engine.recommender import get_latest_recommendations

    result = await get_latest_recommendations()
    data = [
        {**rec.model_dump(exclude={"created_at"}), "llm_analysis": ""}
        for rec in result.get("recommendations", [])
    ]
    return json.dumps(_clean_for_json(data), ensure_ascii=False, default=str)
|
|
|
|
|
|
async def _get_user_watchlist_snapshot() -> str:
    """Summarise the current chat user's watchlist as JSON.

    The user comes from the module-level chat context (set via
    set_chat_user_context). Without a user id, an error object is returned.
    Otherwise the query selects the user's watchlist rows where
    ``COALESCE(is_active, 1) = 1``. Each row is joined with its most recent
    analysis (latest created_at, then id). Rows are ordered focus, candidate,
    holding, then everything else, and within a group by updated_at / id
    descending.

    The payload has four keys:
        count: total number of rows.
        group_counts: rows per group, with a NULL group counted as "observe".
        high_priority: up to 8 rows whose latest conclusion is "可操作" or
            "重点关注".
        items: the first 20 rows.
    """
    from sqlalchemy import text

    from app.db.database import get_db

    current_user = _chat_user_context or {}
    user_id = current_user.get("id")
    if not user_id:
        return json.dumps({"error": "当前会话缺少用户上下文"}, ensure_ascii=False)

    snapshot_sql = text(
        "SELECT w.id, w.ts_code, w.name, w.note, w.watch_group, w.cost_price, w.updated_at, "
        "a.conclusion, a.advice, a.trigger_condition, a.risk_note, a.summary, "
        "a.analysis_mode, a.created_at AS analysis_created_at "
        "FROM user_watchlists w "
        "LEFT JOIN watchlist_analyses a ON a.id = ("
        " SELECT id FROM watchlist_analyses "
        " WHERE watchlist_id = w.id ORDER BY created_at DESC, id DESC LIMIT 1"
        ") "
        "WHERE w.user_id = :uid AND COALESCE(w.is_active, 1) = 1 "
        "ORDER BY CASE w.watch_group "
        " WHEN 'focus' THEN 1 "
        " WHEN 'candidate' THEN 2 "
        " WHEN 'holding' THEN 3 "
        " ELSE 4 END, w.updated_at DESC, w.id DESC"
    )

    async with get_db() as db:
        result = await db.execute(snapshot_sql, {"uid": user_id})
        rows = result.fetchall()

    entries = [dict(row._mapping) for row in rows]

    # Groups outside these four are not counted here (they still count in "count").
    group_counts = {
        group: sum(1 for entry in entries if (entry.get("watch_group") or "observe") == group)
        for group in ("focus", "candidate", "holding", "observe")
    }

    high_priority = [
        entry for entry in entries
        if entry.get("conclusion") in {"可操作", "重点关注"}
    ][:8]

    return json.dumps(
        _clean_for_json({
            "count": len(entries),
            "group_counts": group_counts,
            "high_priority": high_priority,
            "items": entries[:20],
        }),
        ensure_ascii=False,
        default=str,
    )
|
|
|
|
|
|
async def _get_stock_kline(ts_code: str, days: int) -> str:
    """Return recent daily bars with technical indicators for *ts_code* as JSON.

    Fetches *days* of daily data, sorts it by trade_date ascending, and adds
    indicators via add_all_indicators. Only the last 20 rows are returned,
    with non-finite floats as null. Returns an error object when no data is
    found.
    """
    from app.data.tushare_client import tushare_client
    from app.analysis.technical import add_all_indicators

    daily = tushare_client.get_stock_daily(ts_code, days=days)
    if daily.empty:
        return json.dumps({"error": f"未找到 {ts_code} 的K线数据"}, ensure_ascii=False)

    ordered = daily.sort_values("trade_date").reset_index(drop=True)
    enriched = add_all_indicators(ordered)

    # Only the most recent 20 rows are sent, to save tokens.
    recent_rows = _clean_for_json(enriched.tail(20).to_dict(orient="records"))
    return json.dumps(recent_rows, ensure_ascii=False, default=str)
|
|
|
|
|
|
def _flow_amount(row, key: str):
    """Read money-flow column *key* from *row*; a missing key, None or NaN count as 0.

    ``row.get(key, 0) or 0`` alone is not enough: pandas represents missing
    numbers as float NaN, which is truthy. That NaN would leak into the sums
    and end up as a bare ``NaN`` token (invalid JSON) in the output.
    """
    value = row.get(key)
    if value is None or (isinstance(value, float) and math.isnan(value)):
        return 0
    return value


async def _get_stock_capital_flow(ts_code: str, days: int) -> str:
    """Return per-day capital-flow figures for *ts_code* as a JSON list.

    Each record, in ascending trade_date order, holds:
        trade_date: the row's trade date.
        main_net_inflow: (buy - sell) summed over the *_elg_amount and
            *_lg_amount columns (presumably tushare's extra-large / large order
            tiers — confirm), rounded to 2 decimals.
        net_mf_amount: the row's net_mf_amount, rounded to 2 decimals.

    Missing or NaN amounts count as 0. Returns an error object when no data is
    found.
    """
    from app.data.tushare_client import tushare_client

    df = tushare_client.get_stock_moneyflow(ts_code, days=days)
    if df.empty:
        return json.dumps({"error": f"未找到 {ts_code} 的资金流向数据"}, ensure_ascii=False)

    df = df.sort_values("trade_date")
    records = []
    for _, row in df.iterrows():
        main_net = (
            _flow_amount(row, "buy_elg_amount") - _flow_amount(row, "sell_elg_amount")
            + _flow_amount(row, "buy_lg_amount") - _flow_amount(row, "sell_lg_amount")
        )
        records.append({
            "trade_date": row["trade_date"],
            "main_net_inflow": round(main_net, 2),
            "net_mf_amount": round(float(_flow_amount(row, "net_mf_amount")), 2),
        })
    return json.dumps(records, ensure_ascii=False, default=str)
|
|
|
|
|
|
def _match_stocks(basic, keyword: str) -> list[dict]:
    """Return up to 10 rows of *basic* whose name, ts_code or symbol contains *keyword*.

    Matching is a plain, case-insensitive substring test (``regex=False``).
    Names like "*ST…" and codes like "600519.SH" contain regex
    metacharacters, which previously raised ``re.error`` or matched too
    broadly. A blank keyword matches nothing.

    Returns a list of {"ts_code", "name", "industry"} dicts.
    """
    keyword = (keyword or "").strip()
    if not keyword:
        return []

    def contains(column: str):
        return basic[column].str.contains(keyword, case=False, regex=False, na=False)

    matches = basic[contains("name") | contains("ts_code") | contains("symbol")].head(10)
    return matches[["ts_code", "name", "industry"]].to_dict(orient="records")


async def _search_stock(keyword: str) -> str:
    """Search the stock list by name, ts_code or symbol; return up to 10 hits as JSON."""
    from app.data.tushare_client import tushare_client

    basic = tushare_client.get_stock_basic()
    if basic.empty:
        return json.dumps([], ensure_ascii=False)
    return json.dumps(_match_stocks(basic, keyword), ensure_ascii=False, default=str)
|
|
|
|
|
|
async def _get_stock_technical_signal(ts_code: str) -> str:
    """Return the technical-signal details for a single stock as JSON.

    Delegates to generate_signals and serialises its model dump, with
    non-finite floats replaced by null.
    """
    from app.analysis.signals import generate_signals

    return json.dumps(
        _clean_for_json(generate_signals(ts_code).model_dump()),
        ensure_ascii=False,
        default=str,
    )
|
|
|
|
|
|
async def _get_sector_performance(sector_name: str) -> str:
    """Return performance data for the sector(s) matching *sector_name* as JSON.

    Matching is fuzzy: a sector matches when its name contains the (stripped)
    query or the query contains its name.

    Returns:
        On a match: ``{"matched": true, "sectors": [...]}`` with full model
        dumps. Otherwise, including for a blank query:
        ``{"matched": false, "available_sectors": [...]}`` with a compact
        summary of the first 10 sectors from get_latest_sectors, so the caller
        can retry with a valid name. Non-finite floats become null in both
        branches.
    """
    from app.engine.recommender import get_latest_sectors

    query = (sector_name or "").strip()
    sectors = await get_latest_sectors()

    # An empty string is a substring of every name, so a blank query (or a
    # sector with a blank name) would otherwise match indiscriminately.
    matched = [
        s for s in sectors
        if query and s.sector_name
        and (query in s.sector_name or s.sector_name in query)
    ]

    if not matched:
        overview = [
            {
                "sector_name": s.sector_name,
                "pct_change": s.pct_change,
                "capital_inflow": s.capital_inflow,
                "limit_up_count": s.limit_up_count,
                "heat_score": s.heat_score,
                "stage": s.stage,
            }
            for s in sectors[:10]
        ]
        return json.dumps(
            {"matched": False, "available_sectors": _clean_for_json(overview)},
            ensure_ascii=False,
            default=str,
        )

    data = _clean_for_json([s.model_dump() for s in matched])
    return json.dumps({"matched": True, "sectors": data}, ensure_ascii=False, default=str)
|
|
|
|
|
|
async def _get_realtime_indices() -> str:
    """Return index quotes as JSON (Tencent real-time data during market hours).

    The payload is ``{"is_realtime", "mode", "indices": [...]}``. Price-like
    fields are rounded to 2 decimals. If no data comes back, or anything raises
    while fetching or formatting, the failure is logged and an error object is
    returned instead of raising.
    """
    from app.data.tencent_client import get_index_realtime
    from app.config import is_market_session

    in_session = is_market_session()

    def _two_dp(quote, key):
        # Missing fields default to 0 before rounding.
        return round(quote.get(key, 0), 2)

    try:
        quotes = await get_index_realtime()
        if not quotes:
            return json.dumps({"error": "获取指数数据失败"}, ensure_ascii=False)

        indices = [
            {
                "ts_code": code,
                "name": quote.get("name", code),
                "price": _two_dp(quote, "price"),
                "pct_chg": _two_dp(quote, "pct_chg"),
                "volume": quote.get("volume", 0),
                "amount": quote.get("amount", 0),
                "high": _two_dp(quote, "high"),
                "low": _two_dp(quote, "low"),
                "open": _two_dp(quote, "open"),
                "pre_close": _two_dp(quote, "pre_close"),
            }
            for code, quote in quotes.items()
        ]

        return json.dumps(
            {
                "is_realtime": in_session,
                "mode": "盘中实时" if in_session else "盘后收盘",
                "indices": indices,
            },
            ensure_ascii=False,
            default=str,
        )
    except Exception as exc:
        logger.error(f"获取实时指数失败: {exc}")
        await log_error("llm_tool", f"获取实时指数失败: {exc}")
        return json.dumps({"error": f"获取指数数据失败: {exc}"}, ensure_ascii=False)
|