"""LLM tool executor.

Dispatches a tool call by name to the existing data layer and returns a JSON
string for the LLM to consume.
"""
import json
import logging
import math
from collections.abc import Awaitable, Callable

from app.db.error_logger import log_error

logger = logging.getLogger(__name__)

# User dict for the current chat session; user-scoped tools read its "id" key
# (see _get_user_watchlist_snapshot).
# NOTE(review): this is a plain module-level global. If several chat sessions
# are served concurrently on one event loop, one session's
# set_chat_user_context() may overwrite another's before its tools run, which
# would expose the wrong user's watchlist. Consider contextvars.ContextVar once
# the callers' task/thread layout is confirmed.
_chat_user_context: dict | None = None


def set_chat_user_context(user: dict | None) -> None:
    """Set (or clear, with ``None``) the user context used by user-scoped tools."""
    global _chat_user_context
    _chat_user_context = user


def _require_arg(arguments: dict, key: str):
    """Return ``arguments[key]``, raising ValueError if it is missing or None.

    A clear message naming the argument gives the LLM something actionable,
    unlike the bare ``KeyError`` text (e.g. ``'ts_code'``).
    """
    value = arguments.get(key)
    if value is None:
        raise ValueError(f"缺少必需参数: {key}")
    return value


def _coerce_int(value, default: int) -> int:
    """Convert an LLM-supplied numeric argument to int, else return *default*.

    Models sometimes send numbers as strings ("20") or null, which would
    otherwise break slicing or the data-layer calls.
    """
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


def _round2(value) -> float:
    """Round a quote field to 2 decimals, treating None/empty as 0."""
    return round(float(value or 0), 2)


# Tool name -> function building the handler coroutine from the argument dict.
# The lambdas resolve the handler names at call time, so the handlers may be
# defined further down the module.
_TOOL_HANDLERS: dict[str, Callable[[dict], Awaitable[str]]] = {
    "get_strategy_board": lambda args: _get_strategy_board(),
    "get_market_temperature": lambda args: _get_market_temperature(),
    "get_hot_sectors": lambda args: _get_hot_sectors(args.get("limit", 10)),
    "get_latest_recommendations": lambda args: _get_latest_recommendations(),
    "get_user_watchlist_snapshot": lambda args: _get_user_watchlist_snapshot(),
    "get_stock_kline": lambda args: _get_stock_kline(
        _require_arg(args, "ts_code"), args.get("days", 60)
    ),
    "get_stock_capital_flow": lambda args: _get_stock_capital_flow(
        _require_arg(args, "ts_code"), args.get("days", 10)
    ),
    "search_stock": lambda args: _search_stock(_require_arg(args, "keyword")),
    "get_stock_technical_signal": lambda args: _get_stock_technical_signal(
        _require_arg(args, "ts_code")
    ),
    "get_sector_performance": lambda args: _get_sector_performance(
        _require_arg(args, "sector_name")
    ),
    "get_realtime_indices": lambda args: _get_realtime_indices(),
}


async def execute_tool(name: str, arguments: dict) -> str:
    """Execute a tool call and return its result as a JSON string.

    Args:
        name: Tool name as emitted by the LLM.
        arguments: Parsed tool arguments; ``None`` is treated as ``{}``.

    Returns:
        The tool's JSON payload, or a JSON object ``{"error": ...}`` for an
        unknown tool, a missing required argument, or any exception raised by
        the tool. Such exceptions are logged (and persisted via ``log_error``
        on a best-effort basis) instead of being propagated.
    """
    args = arguments or {}
    handler = _TOOL_HANDLERS.get(name)
    if handler is None:
        return json.dumps({"error": f"未知工具: {name}"}, ensure_ascii=False)
    try:
        return await handler(args)
    except Exception as e:
        logger.exception(f"工具执行失败 {name}: {e}")
        try:
            await log_error(
                "llm_tool",
                f"工具执行失败 {name}: {e}",
                detail=f"arguments={json.dumps(arguments, ensure_ascii=False, default=str)}",
            )
        except Exception:
            # Persisting the error is best-effort: a failing error logger must
            # not turn a tool failure into a crash of the whole chat turn.
            logger.exception("写入错误日志失败")
        return json.dumps({"error": str(e)}, ensure_ascii=False)


def _clean_for_json(obj):
    """Recursively replace float NaN/Inf with None so the output is valid JSON.

    ``json.dumps`` would otherwise emit the non-standard tokens ``NaN`` /
    ``Infinity``. Dicts and lists are rebuilt; tuples come back as lists
    (which is how ``json`` serializes them anyway); other values pass through.
    """
    if isinstance(obj, dict):
        return {k: _clean_for_json(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_clean_for_json(v) for v in obj]
    if isinstance(obj, float) and not math.isfinite(obj):
        return None
    return obj


async def _get_strategy_board() -> str:
    """Return a compact, rules-only (``include_llm=False``) strategy board."""
    from app.llm.strategy_board import build_strategy_board

    board = await build_strategy_board(include_llm=False)

    def _head(key: str, n: int) -> list:
        # `or []` also covers a key that is present but None, which would make
        # the slice raise TypeError.
        return (board.get(key) or [])[:n]

    payload = {
        "trade_date": board.get("trade_date", ""),
        "market_regime": board.get("market_regime", ""),
        "risk_level": board.get("risk_level", ""),
        "action_bias": board.get("action_bias", ""),
        "position_suggestion": board.get("position_suggestion", ""),
        "summary": board.get("summary", ""),
        "recommended_mode": board.get("recommended_mode", ""),
        "watch_sectors": _head("watch_sectors", 5),
        "strategy_focus": _head("strategy_focus", 4),
        "avoid_rules": _head("avoid_rules", 4),
        "iteration_notes": _head("iteration_notes", 3),
        "metrics": board.get("metrics", {}),
        "generated_by": board.get("generated_by", "rules"),
    }
    return json.dumps(_clean_for_json(payload), ensure_ascii=False, default=str)


async def _get_market_temperature() -> str:
    """Return the market-temperature record attached to the latest recommendations."""
    from app.engine.recommender import get_latest_recommendations

    result = await get_latest_recommendations()
    mt = result.get("market_temp")
    if not mt:
        return json.dumps({"error": "暂无市场温度数据"}, ensure_ascii=False)
    return json.dumps(_clean_for_json(mt.model_dump()), ensure_ascii=False, default=str)


async def _get_hot_sectors(limit: int) -> str:
    """Return the first *limit* sectors from ``get_latest_sectors()``.

    A non-numeric *limit* falls back to 10; a negative one yields an empty list
    rather than slicing from the end.
    """
    from app.engine.recommender import get_latest_sectors

    limit = max(_coerce_int(limit, 10), 0)
    sectors = await get_latest_sectors()
    data = [s.model_dump() for s in sectors[:limit]]
    return json.dumps(_clean_for_json(data), ensure_ascii=False, default=str)


async def _get_latest_recommendations() -> str:
    """Return the latest recommendations without ``created_at``.

    The stored ``llm_analysis`` text is blanked; presumably to keep the tool
    payload small — confirm.
    """
    from app.engine.recommender import get_latest_recommendations

    result = await get_latest_recommendations()
    data = [
        {**rec.model_dump(exclude={"created_at"}), "llm_analysis": ""}
        for rec in result.get("recommendations", [])
    ]
    return json.dumps(_clean_for_json(data), ensure_ascii=False, default=str)


async def _get_user_watchlist_snapshot() -> str:
    """Return the current chat user's active watchlist with each entry's latest analysis.

    Entries are ordered focus / candidate / holding / other, then by most
    recent update. The payload holds the total count, per-group counts, up to
    8 entries whose latest conclusion is "可操作" or "重点关注", and the first
    20 entries. Returns an error JSON when no user context has been set.
    """
    from sqlalchemy import text
    from app.db.database import get_db

    user_id = (_chat_user_context or {}).get("id")
    if not user_id:
        return json.dumps({"error": "当前会话缺少用户上下文"}, ensure_ascii=False)

    async with get_db() as db:
        rows = (await db.execute(
            text(
                "SELECT w.id, w.ts_code, w.name, w.note, w.watch_group, w.cost_price, w.updated_at, "
                "a.conclusion, a.advice, a.trigger_condition, a.risk_note, a.summary, "
                "a.analysis_mode, a.created_at AS analysis_created_at "
                "FROM user_watchlists w "
                "LEFT JOIN watchlist_analyses a ON a.id = ("
                " SELECT id FROM watchlist_analyses "
                " WHERE watchlist_id = w.id ORDER BY created_at DESC, id DESC LIMIT 1"
                ") "
                "WHERE w.user_id = :uid AND COALESCE(w.is_active, 1) = 1 "
                "ORDER BY CASE w.watch_group "
                " WHEN 'focus' THEN 1 "
                " WHEN 'candidate' THEN 2 "
                " WHEN 'holding' THEN 3 "
                " ELSE 4 END, w.updated_at DESC, w.id DESC"
            ),
            {"uid": user_id},
        )).fetchall()

    items = [dict(row._mapping) for row in rows]
    grouped = {"focus": 0, "candidate": 0, "holding": 0, "observe": 0}
    for item in items:
        # A missing group counts as "observe"; unknown group names are not counted.
        key = item.get("watch_group") or "observe"
        if key in grouped:
            grouped[key] += 1

    actionable = [
        item for item in items
        if item.get("conclusion") in {"可操作", "重点关注"}
    ][:8]
    payload = {
        "count": len(items),
        "group_counts": grouped,
        "high_priority": actionable,
        "items": items[:20],
    }
    return json.dumps(_clean_for_json(payload), ensure_ascii=False, default=str)


async def _get_stock_kline(ts_code: str, days: int) -> str:
    """Return the last 20 daily bars (with technical indicators) for *ts_code*.

    *days* (default 60 when non-numeric) is passed to the data source; only
    the last 20 rows are returned to save tokens.
    """
    from app.data.tushare_client import tushare_client
    from app.analysis.technical import add_all_indicators

    days = _coerce_int(days, 60)
    # NOTE(review): tushare_client is called synchronously here; if it performs
    # network I/O it blocks the event loop — consider asyncio.to_thread.
    df = tushare_client.get_stock_daily(ts_code, days=days)
    if df.empty:
        return json.dumps({"error": f"未找到 {ts_code} 的K线数据"}, ensure_ascii=False)
    df = df.sort_values("trade_date").reset_index(drop=True)
    df = add_all_indicators(df)
    # Only return the most recent 20 rows to save tokens.
    records = _clean_for_json(df.tail(20).to_dict(orient="records"))
    return json.dumps(records, ensure_ascii=False, default=str)


async def _get_stock_capital_flow(ts_code: str, days: int) -> str:
    """Return per-day capital flow for *ts_code*, oldest first.

    ``main_net_inflow`` is (extra-large + large buy amounts) minus
    (extra-large + large sell amounts); ``net_mf_amount`` is passed through.
    Both are rounded to 2 decimals; NaN values become null.
    """
    from app.data.tushare_client import tushare_client

    days = _coerce_int(days, 10)
    df = tushare_client.get_stock_moneyflow(ts_code, days=days)
    if df.empty:
        return json.dumps({"error": f"未找到 {ts_code} 的资金流向数据"}, ensure_ascii=False)
    df = df.sort_values("trade_date")

    records = []
    for _, row in df.iterrows():
        main_net = (
            (row.get("buy_elg_amount", 0) or 0)
            - (row.get("sell_elg_amount", 0) or 0)
            + (row.get("buy_lg_amount", 0) or 0)
            - (row.get("sell_lg_amount", 0) or 0)
        )
        records.append({
            "trade_date": row["trade_date"],
            "main_net_inflow": round(main_net, 2),
            "net_mf_amount": round(float(row.get("net_mf_amount", 0) or 0), 2),
        })
    # NaN survives `or 0` (NaN is truthy), so clean before serializing.
    return json.dumps(_clean_for_json(records), ensure_ascii=False, default=str)


async def _search_stock(keyword: str) -> str:
    """Return up to 10 stocks whose name, ts_code or symbol contains *keyword*.

    Matching is a plain substring test; a blank keyword returns an empty list.
    """
    from app.data.tushare_client import tushare_client

    keyword = str(keyword).strip()
    if not keyword:
        return json.dumps([], ensure_ascii=False)

    basic = tushare_client.get_stock_basic()
    if basic.empty:
        return json.dumps([], ensure_ascii=False)

    # regex=False: names like "*ST康美" contain regex metacharacters that would
    # otherwise raise re.error or match unintended rows.
    mask = (
        basic["name"].str.contains(keyword, regex=False, na=False)
        | basic["ts_code"].str.contains(keyword, regex=False, na=False)
        | basic["symbol"].str.contains(keyword, regex=False, na=False)
    )
    matches = basic[mask].head(10)
    data = matches[["ts_code", "name", "industry"]].to_dict(orient="records")
    return json.dumps(data, ensure_ascii=False, default=str)


async def _get_stock_technical_signal(ts_code: str) -> str:
    """Return the technical-signal details produced by ``generate_signals``."""
    from app.analysis.signals import generate_signals

    signal = generate_signals(ts_code)
    data = _clean_for_json(signal.model_dump())
    return json.dumps(data, ensure_ascii=False, default=str)


async def _get_sector_performance(sector_name: str) -> str:
    """Return sectors whose name fuzzily matches *sector_name*.

    A sector matches when either name contains the other. When nothing matches
    (including a blank query), an overview of the first 10 sectors is
    returned with ``"matched": false``.
    """
    from app.engine.recommender import get_latest_sectors

    sectors = await get_latest_sectors()
    needle = str(sector_name).strip()
    # Guard against empty strings: "" is a substring of everything.
    matched = [
        s for s in sectors
        if needle and s.sector_name
        and (needle in s.sector_name or s.sector_name in needle)
    ]
    if not matched:
        overview = [
            {
                "sector_name": s.sector_name,
                "pct_change": s.pct_change,
                "capital_inflow": s.capital_inflow,
                "limit_up_count": s.limit_up_count,
                "heat_score": s.heat_score,
                "stage": s.stage,
            }
            for s in sectors[:10]
        ]
        return json.dumps(
            {"matched": False, "available_sectors": _clean_for_json(overview)},
            ensure_ascii=False,
            default=str,
        )
    data = _clean_for_json([s.model_dump() for s in matched])
    return json.dumps({"matched": True, "sectors": data}, ensure_ascii=False, default=str)


async def _get_realtime_indices() -> str:
    """Return index quotes from the Tencent client.

    ``is_realtime``/``mode`` report whether ``is_market_session()`` is true
    ("盘中实时") or not ("盘后收盘"). Errors are logged and returned as an
    error JSON.
    """
    from app.data.tencent_client import get_index_realtime
    from app.config import is_market_session

    is_trading = is_market_session()
    try:
        index_data = await get_index_realtime()
        if not index_data:
            return json.dumps({"error": "获取指数数据失败"}, ensure_ascii=False)

        results = [
            {
                "ts_code": ts_code,
                "name": data.get("name", ts_code),
                "price": _round2(data.get("price", 0)),
                "pct_chg": _round2(data.get("pct_chg", 0)),
                "volume": data.get("volume", 0),
                "amount": data.get("amount", 0),
                "high": _round2(data.get("high", 0)),
                "low": _round2(data.get("low", 0)),
                "open": _round2(data.get("open", 0)),
                "pre_close": _round2(data.get("pre_close", 0)),
            }
            for ts_code, data in index_data.items()
        ]
        payload = {
            "is_realtime": is_trading,
            "mode": "盘中实时" if is_trading else "盘后收盘",
            "indices": results,
        }
        return json.dumps(_clean_for_json(payload), ensure_ascii=False, default=str)
    except Exception as e:
        logger.exception(f"获取实时指数失败: {e}")
        await log_error("llm_tool", f"获取实时指数失败: {e}")
        return json.dumps({"error": f"获取指数数据失败: {e}"}, ensure_ascii=False)