"""LLM tool executor.

Dispatches a tool name chosen by the LLM to the existing data layer and
returns a JSON string the LLM can consume.
"""
import contextvars
import json
import logging
import math

from app.db.error_logger import log_error

logger = logging.getLogger(__name__)

# Sentinel meaning "no user has been set in the current async context".
_UNSET = object()

# Last user passed to set_chat_user_context(). Kept only as a fallback for
# callers whose context does not propagate to the tool call; the ContextVar
# below is the primary source.
_chat_user_context: dict | None = None

# Per-context chat user. A plain module global is shared by every concurrent
# chat session, so one user's tool call could read another user's watchlist;
# a ContextVar keeps each request/task's value separate.
_chat_user_context_var: contextvars.ContextVar = contextvars.ContextVar(
    "chat_user_context", default=_UNSET
)


def set_chat_user_context(user: dict | None) -> None:
    """Record the user whose chat session is issuing tool calls.

    Args:
        user: User dict (its ``"id"`` is used by per-user tools), or None to
            clear it.
    """
    global _chat_user_context
    _chat_user_context = user
    _chat_user_context_var.set(user)


def _current_chat_user() -> dict | None:
    """Return the chat user for the current context, falling back to the global."""
    user = _chat_user_context_var.get()
    if user is _UNSET:
        return _chat_user_context
    return user


async def execute_tool(name: str, arguments: dict) -> str:
    """Run tool *name* with *arguments* and return its result as a JSON string.

    Unknown tools yield ``{"error": ...}``. Any exception raised by a tool is
    logged, recorded through ``log_error`` and also returned as
    ``{"error": ...}``, so the LLM always receives parseable JSON.
    """
    arguments = arguments or {}
    try:
        if name == "get_strategy_board":
            return await _get_strategy_board()
        elif name == "get_market_temperature":
            return await _get_market_temperature()
        elif name == "get_hot_sectors":
            return await _get_hot_sectors(_as_int(arguments.get("limit"), 10))
        elif name == "get_latest_recommendations":
            return await _get_latest_recommendations()
        elif name == "get_user_watchlist_snapshot":
            return await _get_user_watchlist_snapshot()
        elif name == "get_stock_kline":
            return await _get_stock_kline(
                arguments["ts_code"], _as_int(arguments.get("days"), 60)
            )
        elif name == "get_stock_capital_flow":
            return await _get_stock_capital_flow(
                arguments["ts_code"], _as_int(arguments.get("days"), 10)
            )
        elif name == "search_stock":
            return await _search_stock(arguments["keyword"])
        elif name == "get_stock_technical_signal":
            return await _get_stock_technical_signal(arguments["ts_code"])
        elif name == "diagnose_stock":
            return await _diagnose_stock(
                arguments["ts_code"], arguments.get("mode", "entry")
            )
        elif name == "get_sector_performance":
            return await _get_sector_performance(arguments["sector_name"])
        elif name == "get_realtime_indices":
            return await _get_realtime_indices()
        else:
            return json.dumps({"error": f"未知工具: {name}"}, ensure_ascii=False)
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"工具执行失败 {name}: {e}")
        await log_error(
            "llm_tool",
            f"工具执行失败 {name}: {e}",
            detail=f"arguments={json.dumps(arguments, ensure_ascii=False, default=str)}",
        )
        return json.dumps({"error": str(e)}, ensure_ascii=False)


def _clean_for_json(obj):
    """Recursively replace float NaN/Inf with None so the output is valid JSON.

    ``json.dumps`` would otherwise emit the non-standard tokens ``NaN`` /
    ``Infinity``. Dicts and lists/tuples are walked; tuples come back as
    lists, which is how ``json`` serializes them anyway.
    """
    if isinstance(obj, dict):
        return {k: _clean_for_json(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_clean_for_json(v) for v in obj]
    if isinstance(obj, float) and (math.isnan(obj) or math.isinf(obj)):
        return None
    return obj


def _to_json(obj) -> str:
    """Serialize *obj* for the LLM: NaN/Inf become null, unknown types use str()."""
    return json.dumps(_clean_for_json(obj), ensure_ascii=False, default=str)


def _as_int(value, default: int) -> int:
    """Coerce an LLM-supplied count to a positive int, else return *default*.

    LLMs sometimes send numbers as strings ("5") or null; both previously
    broke slicing / data-source calls.
    """
    try:
        number = int(value)
    except (TypeError, ValueError):
        return default
    return number if number > 0 else default


def _num(value) -> float:
    """Coerce a possibly-missing numeric cell to float; None/NaN/Inf/garbage -> 0.0.

    ``value or 0`` is not enough for pandas data: NaN is truthy, so it would
    pass through and poison any arithmetic it takes part in.
    """
    if value is None:
        return 0.0
    try:
        number = float(value)
    except (TypeError, ValueError):
        return 0.0
    return number if math.isfinite(number) else 0.0


async def _get_strategy_board() -> str:
    """Return today's market overview plus strategy status.

    Replaces the removed ``strategy_board`` module. Counts recommendations
    whose ``action_plan`` is "可操作" (actionable) or "重点关注" (watch).
    """
    from app.engine.recommender import get_latest_recommendations

    result = await get_latest_recommendations()
    mt = result.get("market_temp")
    recs = result.get("recommendations", [])
    strategy = result.get("strategy_profile") or {}
    actionable = [r for r in recs if r.action_plan == "可操作"]
    watch = [r for r in recs if r.action_plan == "重点关注"]
    payload = {
        "market_temperature": mt.temperature if mt else 0,
        "up_count": mt.up_count if mt else 0,
        "down_count": mt.down_count if mt else 0,
        "limit_up_count": mt.limit_up_count if mt else 0,
        "strategy_name": strategy.get("name", "未知"),
        "market_stance": strategy.get("market_stance", ""),
        "decision_note": strategy.get("decision_note", ""),
        "actionable_count": len(actionable),
        "watch_count": len(watch),
        "total_recommendations": len(recs),
    }
    return _to_json(payload)


async def _get_market_temperature() -> str:
    """Return the latest market-temperature model dump, or an error if absent."""
    from app.engine.recommender import get_latest_recommendations

    result = await get_latest_recommendations()
    mt = result.get("market_temp")
    if not mt:
        return json.dumps({"error": "暂无市场温度数据"}, ensure_ascii=False)
    return _to_json(mt.model_dump())


async def _get_hot_sectors(limit: int) -> str:
    """Return the first *limit* sectors from the latest sector ranking."""
    from app.engine.recommender import get_latest_sectors

    sectors = await get_latest_sectors()
    return _to_json([s.model_dump() for s in sectors[:limit]])


async def _get_latest_recommendations() -> str:
    """Return the latest recommendations with ``created_at`` dropped.

    ``llm_analysis`` is blanked, presumably to keep the LLM payload short.
    """
    from app.engine.recommender import get_latest_recommendations

    result = await get_latest_recommendations()
    recs = result.get("recommendations", [])
    data = []
    for rec in recs:
        item = rec.model_dump(exclude={"created_at"})
        item["llm_analysis"] = ""
        data.append(item)
    return _to_json(data)


async def _get_user_watchlist_snapshot() -> str:
    """Return the current chat user's active watchlist with each item's latest analysis.

    Items are ordered by group (focus, candidate, holding, others), then by
    most recently updated. The payload holds:
      - per-group counts,
      - up to 8 items concluded as "可操作"/"重点关注",
      - the first 20 items.
    """
    from sqlalchemy import text
    from app.db.database import get_db

    user_id = (_current_chat_user() or {}).get("id")
    if not user_id:
        return json.dumps({"error": "当前会话缺少用户上下文"}, ensure_ascii=False)

    async with get_db() as db:
        rows = (await db.execute(
            text(
                "SELECT w.id, w.ts_code, w.name, w.note, w.watch_group, w.cost_price, w.updated_at, "
                "a.conclusion, a.advice, a.trigger_condition, a.risk_note, a.summary, "
                "a.analysis_mode, a.created_at AS analysis_created_at "
                "FROM user_watchlists w "
                "LEFT JOIN watchlist_analyses a ON a.id = ("
                " SELECT id FROM watchlist_analyses "
                " WHERE watchlist_id = w.id ORDER BY created_at DESC, id DESC LIMIT 1"
                ") "
                "WHERE w.user_id = :uid AND COALESCE(w.is_active, 1) = 1 "
                "ORDER BY CASE w.watch_group "
                " WHEN 'focus' THEN 1 "
                " WHEN 'candidate' THEN 2 "
                " WHEN 'holding' THEN 3 "
                " ELSE 4 END, w.updated_at DESC, w.id DESC"
            ),
            {"uid": user_id},
        )).fetchall()

    items = [dict(row._mapping) for row in rows]
    grouped = {"focus": 0, "candidate": 0, "holding": 0, "observe": 0}
    for item in items:
        # A missing group counts as "observe"; unknown groups are not counted.
        key = item.get("watch_group") or "observe"
        if key in grouped:
            grouped[key] += 1

    actionable = [
        item for item in items
        if item.get("conclusion") in {"可操作", "重点关注"}
    ][:8]
    payload = {
        "count": len(items),
        "group_counts": grouped,
        "high_priority": actionable,
        "items": items[:20],
    }
    return _to_json(payload)


async def _get_stock_kline(ts_code: str, days: int) -> str:
    """Return the last 20 daily bars (with indicators) out of *days* fetched."""
    from app.data.tushare_client import tushare_client
    from app.analysis.technical import add_all_indicators

    df = tushare_client.get_stock_daily(ts_code, days=days)
    if df.empty:
        return json.dumps({"error": f"未找到 {ts_code} 的K线数据"}, ensure_ascii=False)
    df = df.sort_values("trade_date").reset_index(drop=True)
    df = add_all_indicators(df)
    # Indicators use the full window, but only the last 20 rows are returned to save tokens.
    return _to_json(df.tail(20).to_dict(orient="records"))


async def _get_stock_capital_flow(ts_code: str, days: int) -> str:
    """Return per-day main-force net inflow and ``net_mf_amount`` for *ts_code*.

    ``main_net_inflow`` = (extra-large + large buy) - (extra-large + large sell)
    amounts. Missing/NaN cells count as 0.
    """
    from app.data.tushare_client import tushare_client

    df = tushare_client.get_stock_moneyflow(ts_code, days=days)
    if df.empty:
        return json.dumps({"error": f"未找到 {ts_code} 的资金流向数据"}, ensure_ascii=False)
    df = df.sort_values("trade_date")
    records = []
    for _, row in df.iterrows():
        main_net = (
            _num(row.get("buy_elg_amount", 0))
            - _num(row.get("sell_elg_amount", 0))
            + _num(row.get("buy_lg_amount", 0))
            - _num(row.get("sell_lg_amount", 0))
        )
        records.append({
            "trade_date": row["trade_date"],
            "main_net_inflow": round(main_net, 2),
            "net_mf_amount": round(_num(row.get("net_mf_amount", 0)), 2),
        })
    return _to_json(records)


async def _search_stock(keyword: str) -> str:
    """Return up to 10 stocks whose name, ts_code or symbol contains *keyword*."""
    from app.data.tushare_client import tushare_client

    keyword = str(keyword or "").strip()
    if not keyword:
        return json.dumps([], ensure_ascii=False)
    basic = tushare_client.get_stock_basic()
    if basic.empty:
        return json.dumps([], ensure_ascii=False)

    def _hit(column: str):
        # regex=False: the keyword comes from the LLM/user and may contain
        # regex metacharacters such as "(" or "*"; case=False lets
        # "000001.sz" match "000001.SZ".
        return basic[column].str.contains(keyword, case=False, regex=False, na=False)

    matches = basic[_hit("name") | _hit("ts_code") | _hit("symbol")].head(10)
    data = matches[["ts_code", "name", "industry"]].to_dict(orient="records")
    return _to_json(data)


async def _get_stock_technical_signal(ts_code: str) -> str:
    """Return the technical-signal model for *ts_code* as JSON."""
    from app.analysis.signals import generate_signals

    signal = generate_signals(ts_code)
    return _to_json(signal.model_dump())


def _find_recommendation(recommendations_json: str, ts_code: str) -> dict | None:
    """Return the recommendation dict for *ts_code* from a JSON list, or None."""
    try:
        recs = json.loads(recommendations_json)
    except ValueError:
        return None
    if not isinstance(recs, list):
        return None
    return next(
        (item for item in recs if isinstance(item, dict) and item.get("ts_code") == ts_code),
        None,
    )


async def _load_recent_diagnoses(ts_code: str) -> list[dict]:
    """Return the 3 most recent stored diagnoses for *ts_code*; [] on DB failure."""
    from sqlalchemy import text
    from app.db.database import get_db

    try:
        async with get_db() as db:
            rows = (await db.execute(
                text(
                    "SELECT name, diagnosis_mode, diagnosis, created_at "
                    "FROM stock_diagnoses WHERE ts_code = :ts_code "
                    "ORDER BY created_at DESC, id DESC LIMIT 3"
                ),
                {"ts_code": ts_code},
            )).fetchall()
            return [dict(row._mapping) for row in rows]
    except Exception as e:
        # Best effort: history only enriches the prompt, so don't fail the diagnosis.
        logger.warning(f"读取历史会诊失败 {ts_code}: {e}")
        return []


async def _save_diagnosis(ts_code: str, name: str, mode: str, diagnosis: str) -> bool:
    """Insert a diagnosis into ``stock_diagnoses``; return whether it was committed."""
    from app.db import tables
    from app.db.database import get_db

    try:
        async with get_db() as db:
            await db.execute(
                tables.stock_diagnoses_table.insert().values(
                    ts_code=ts_code,
                    name=name,
                    diagnosis_mode=mode,
                    diagnosis=diagnosis,
                )
            )
            await db.commit()
    except Exception as e:
        logger.warning(f"保存聊天会诊结果失败 {ts_code}: {e}")
        return False
    return True


async def _diagnose_stock(ts_code: str, mode: str = "entry") -> str:
    """Generate a structured single-stock diagnosis for the chat agent.

    Gathers the strategy board, latest recommendations, hot sectors, K-line,
    capital flow, technical signal and recent diagnoses into one prompt. It
    asks the LLM for a Markdown diagnosis, stores it, and returns it.

    Args:
        ts_code: Stock code.
        mode: One of "entry", "holding", "review", "tracking"; anything else
            falls back to "entry".

    Returns:
        JSON with ts_code, name, mode, diagnosis, and ``saved`` (whether the
        diagnosis was actually persisted). Returns an error JSON if the LLM
        produced no content.
    """
    from app.llm.client import chat_completion

    mode_map = {
        "entry": "建仓前诊断",
        "holding": "持仓复核",
        "review": "回撤复盘",
        "tracking": "继续跟踪",
    }
    mode = mode if mode in mode_map else "entry"

    strategy_board = await _get_strategy_board()
    latest_recommendations = await _get_latest_recommendations()
    hot_sectors = await _get_hot_sectors(8)
    kline = await _get_stock_kline(ts_code, 80)
    capital_flow = await _get_stock_capital_flow(ts_code, 15)
    technical_signal = await _get_stock_technical_signal(ts_code)

    # Display name: prefer the latest stored diagnosis, then the recommendation, else ts_code.
    stock_name = ts_code
    latest_rec = _find_recommendation(latest_recommendations, ts_code)
    if latest_rec and latest_rec.get("name"):
        stock_name = latest_rec["name"]

    recent_diagnoses = await _load_recent_diagnoses(ts_code)
    if recent_diagnoses and recent_diagnoses[0].get("name"):
        stock_name = recent_diagnoses[0]["name"]

    latest_rec_text = (
        json.dumps(latest_rec, ensure_ascii=False, default=str)
        if latest_rec else "不在最新推荐池"
    )
    recent_text = json.dumps(recent_diagnoses, ensure_ascii=False, default=str)

    prompt = f"""请在 A 股作战台语境下,对 {ts_code} 做一次系统化个股会诊。

诊断模式: {mode_map[mode]}

今日作战结论:
{strategy_board}

最新推荐池中该股记录:
{latest_rec_text}

热门板块:
{hot_sectors}

资金流:
{capital_flow}

K线与价格行为:
{kline}

技术信号:
{technical_signal}

最近诊断:
{recent_text}

输出要求:
- 先给明确结论,只能是「可操作 / 重点关注 / 观察 / 回避」
- 明确当前动作、触发条件、失效条件、仓位边界、下一步观察点
- 分析顺序必须是:资金与主线 > 量价与价格行为 > 位置与边界 > 技术指标备注
- A 股优先看资金顺势、主线板块、量价承接、价格行为和位置;技术指标只做节奏与风控确认
- RSI、MACD、KDJ 的超买超卖不能单独决定买卖,也不能放在核心依据第一位
- 输出必须包含以下小节:当前结论、资金与主线、量价与价格行为、位置与边界、执行动作、风险清单、技术指标备注
- 不写传统研报,不堆原始数据,不承诺收益
- 用 Markdown 输出,保持简洁"""

    resp = await chat_completion([
        {
            "role": "system",
            "content": (
                "你是 A 股投研作战台的个股会诊智能体。"
                "你必须优先分析资金流向、主线板块、量价关系、价格行为和位置边界,"
                "技术指标只能作为最后的节奏与风控备注。"
                "输出可执行但带风险边界的会诊结论。"
            ),
        },
        {"role": "user", "content": prompt},
    ])

    # Check after stripping so a whitespace-only reply is treated as empty.
    diagnosis = (resp.content or "").strip() if resp else ""
    if not diagnosis:
        return json.dumps({"error": "个股会诊生成失败,LLM 未返回内容"}, ensure_ascii=False)

    saved = await _save_diagnosis(ts_code, stock_name, mode, diagnosis)
    return _to_json({
        "ts_code": ts_code,
        "name": stock_name,
        "mode": mode,
        "diagnosis": diagnosis,
        "saved": saved,
    })


async def _get_sector_performance(sector_name: str) -> str:
    """Return sectors whose name fuzzily matches *sector_name*.

    A match means either name contains the other. With no match, returns
    ``matched: False`` plus a summary of the top 10 sectors.
    """
    from app.engine.recommender import get_latest_sectors

    sectors = await get_latest_sectors()
    query = (sector_name or "").strip()
    # Empty strings would "contain" / be contained in everything, so skip them.
    matched = [
        s for s in sectors
        if query and s.sector_name
        and (query in s.sector_name or s.sector_name in query)
    ]
    if not matched:
        data = [
            {
                "sector_name": s.sector_name,
                "pct_change": s.pct_change,
                "capital_inflow": s.capital_inflow,
                "limit_up_count": s.limit_up_count,
                "heat_score": s.heat_score,
                "stage": s.stage,
            }
            for s in sectors[:10]
        ]
        return _to_json({"matched": False, "available_sectors": data})
    return _to_json({"matched": True, "sectors": [s.model_dump() for s in matched]})


async def _get_realtime_indices() -> str:
    """Return index quotes from the Tencent client.

    ``is_realtime``/``mode`` report whether ``is_market_session()`` says the
    market is open. Fetch failures are logged and returned as an error JSON.
    """
    from app.data.tencent_client import get_index_realtime
    from app.config import is_market_session

    is_trading = is_market_session()

    try:
        index_data = await get_index_realtime()
        if not index_data:
            return json.dumps({"error": "获取指数数据失败"}, ensure_ascii=False)

        # _num: a None/NaN field must not make round() fail the whole response.
        results = [
            {
                "ts_code": ts_code,
                "name": data.get("name", ts_code),
                "price": round(_num(data.get("price", 0)), 2),
                "pct_chg": round(_num(data.get("pct_chg", 0)), 2),
                "volume": data.get("volume", 0),
                "amount": data.get("amount", 0),
                "high": round(_num(data.get("high", 0)), 2),
                "low": round(_num(data.get("low", 0)), 2),
                "open": round(_num(data.get("open", 0)), 2),
                "pre_close": round(_num(data.get("pre_close", 0)), 2),
            }
            for ts_code, data in index_data.items()
        ]

        return _to_json({
            "is_realtime": is_trading,
            "mode": "盘中实时" if is_trading else "盘后收盘",
            "indices": results,
        })
    except Exception as e:
        logger.exception(f"获取实时指数失败: {e}")
        await log_error("llm_tool", f"获取实时指数失败: {e}")
        return json.dumps({"error": f"获取指数数据失败: {e}"}, ensure_ascii=False)