409 lines
16 KiB
Python
409 lines
16 KiB
Python
"""LLM 工具执行器
|
||
|
||
根据工具名调用现有数据层,返回 JSON 字符串供 LLM 使用。
|
||
"""
|
||
|
||
import json
|
||
import logging
|
||
import math
|
||
|
||
from app.db.error_logger import log_error
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
_chat_user_context: dict | None = None
|
||
|
||
|
||
def set_chat_user_context(user: dict | None) -> None:
|
||
global _chat_user_context
|
||
_chat_user_context = user
|
||
|
||
|
||
async def execute_tool(name: str, arguments: dict) -> str:
    """Execute one LLM tool call and return its result as a JSON string.

    Args:
        name: Tool name chosen by the LLM.
        arguments: Tool arguments decoded from the LLM call. ``None`` or a
            non-dict value is treated as "no arguments".

    Returns:
        The tool's JSON payload. For an unknown tool, a missing required
        argument, or any exception raised while running the tool, a JSON
        object ``{"error": "..."}`` is returned instead of raising.
    """
    args = arguments if isinstance(arguments, dict) else {}

    def required(key: str):
        # LLMs occasionally omit required fields; report a readable message
        # instead of a bare KeyError repr such as "'ts_code'".
        value = args.get(key)
        if value is None or value == "":
            raise ValueError(f"缺少必填参数: {key}")
        return value

    def int_arg(key: str, default: int) -> int:
        # LLMs sometimes send numbers as strings ("10") or floats (10.0);
        # normalise so slicing and data-layer calls receive a real int.
        value = args.get(key)
        return default if value is None else int(value)

    # Lambdas defer argument extraction and coroutine creation until the tool
    # is selected, so only the chosen tool's arguments are validated.
    handlers = {
        "get_strategy_board": lambda: _get_strategy_board(),
        "get_market_temperature": lambda: _get_market_temperature(),
        "get_hot_sectors": lambda: _get_hot_sectors(int_arg("limit", 10)),
        "get_latest_recommendations": lambda: _get_latest_recommendations(),
        "get_user_watchlist_snapshot": lambda: _get_user_watchlist_snapshot(),
        "get_stock_kline": lambda: _get_stock_kline(
            required("ts_code"), int_arg("days", 60)
        ),
        "get_stock_capital_flow": lambda: _get_stock_capital_flow(
            required("ts_code"), int_arg("days", 10)
        ),
        "search_stock": lambda: _search_stock(required("keyword")),
        "get_stock_technical_signal": lambda: _get_stock_technical_signal(
            required("ts_code")
        ),
        "diagnose_stock": lambda: _diagnose_stock(
            required("ts_code"), args.get("mode", "entry")
        ),
        "get_sector_performance": lambda: _get_sector_performance(
            required("sector_name")
        ),
        "get_realtime_indices": lambda: _get_realtime_indices(),
    }

    handler = handlers.get(name)
    if handler is None:
        return json.dumps({"error": f"未知工具: {name}"}, ensure_ascii=False)

    try:
        return await handler()
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("工具执行失败 %s: %s", name, e)
        try:
            await log_error(
                "llm_tool",
                f"工具执行失败 {name}: {e}",
                detail=f"arguments={json.dumps(arguments, ensure_ascii=False, default=str)}",
            )
        except Exception:
            # The caller must always get a JSON answer; a failing error sink
            # must not escape this handler.
            logger.exception("记录工具错误失败 %s", name)
        return json.dumps({"error": str(e)}, ensure_ascii=False)
|
||
|
||
|
||
def _clean_for_json(obj):
|
||
"""清理 float NaN/Inf 为 None"""
|
||
if isinstance(obj, dict):
|
||
return {k: _clean_for_json(v) for k, v in obj.items()}
|
||
if isinstance(obj, list):
|
||
return [_clean_for_json(v) for v in obj]
|
||
if isinstance(obj, float) and (math.isnan(obj) or math.isinf(obj)):
|
||
return None
|
||
return obj
|
||
|
||
|
||
async def _get_strategy_board() -> str:
    """Return a compact JSON summary of today's strategy board.

    The board is built without the LLM step (``include_llm=False``), and each
    list field is truncated to its first few entries to keep the payload small.
    """
    from app.llm.strategy_board import build_strategy_board

    board = await build_strategy_board(include_llm=False) or {}

    def head(key: str, n: int):
        # `or []` also covers keys that are present but set to None, which
        # would otherwise raise TypeError when sliced.
        return (board.get(key) or [])[:n]

    payload = {
        "trade_date": board.get("trade_date", ""),
        "market_regime": board.get("market_regime", ""),
        "risk_level": board.get("risk_level", ""),
        "action_bias": board.get("action_bias", ""),
        "position_suggestion": board.get("position_suggestion", ""),
        "summary": board.get("summary", ""),
        "recommended_mode": board.get("recommended_mode", ""),
        "watch_sectors": head("watch_sectors", 5),
        "strategy_focus": head("strategy_focus", 4),
        "avoid_rules": head("avoid_rules", 4),
        "iteration_notes": head("iteration_notes", 3),
        "metrics": board.get("metrics", {}),
        "generated_by": board.get("generated_by", "rules"),
    }
    return json.dumps(_clean_for_json(payload), ensure_ascii=False, default=str)
|
||
|
||
|
||
async def _get_market_temperature() -> str:
    """Return the latest market-temperature snapshot as JSON.

    Reads ``market_temp`` from the latest recommendation result; returns an
    error object when it is absent. Non-finite floats are emitted as ``null``
    (consistent with the other tools) so the output is valid JSON.
    """
    from app.engine.recommender import get_latest_recommendations

    result = await get_latest_recommendations() or {}
    mt = result.get("market_temp")
    if not mt:
        return json.dumps({"error": "暂无市场温度数据"}, ensure_ascii=False)
    return json.dumps(_clean_for_json(mt.model_dump()), ensure_ascii=False, default=str)
|
||
|
||
|
||
async def _get_hot_sectors(limit: int) -> str:
    """Return the first ``limit`` sectors from ``get_latest_sectors()`` as JSON.

    Sectors keep the order returned by the recommender. Non-finite floats are
    emitted as ``null`` so the output is valid JSON.
    """
    from app.engine.recommender import get_latest_sectors

    sectors = await get_latest_sectors() or []
    data = _clean_for_json([s.model_dump() for s in sectors[:limit]])
    return json.dumps(data, ensure_ascii=False, default=str)
|
||
|
||
|
||
async def _get_latest_recommendations() -> str:
    """Return the latest recommendation pool as a JSON list.

    Each entry is the model dump without ``created_at`` and with
    ``llm_analysis`` blanked. Non-finite floats are emitted as ``null``.
    """
    from app.engine.recommender import get_latest_recommendations

    result = await get_latest_recommendations() or {}
    recs = result.get("recommendations") or []
    # llm_analysis is blanked -- presumably to keep the tool output compact
    # for the LLM (TODO confirm intent).
    data = [
        {**rec.model_dump(exclude={"created_at"}), "llm_analysis": ""}
        for rec in recs
    ]
    return json.dumps(_clean_for_json(data), ensure_ascii=False, default=str)
|
||
|
||
|
||
async def _get_user_watchlist_snapshot() -> str:
    """Summarise the current chat user's active watchlist as JSON.

    The user id comes from the module-level chat user context; without one an
    error object is returned. Each active watchlist row is joined with its most
    recent analysis (newest ``created_at``, then highest ``id``). The payload
    holds the total row count, per-group counts, up to 8 rows whose latest
    conclusion is "可操作" or "重点关注", and the first 20 rows.
    """
    from sqlalchemy import text

    from app.db.database import get_db

    # NOTE(review): _chat_user_context is a process-wide global; see its
    # definition for the concurrency caveat.
    context = _chat_user_context or {}
    user_id = context.get("id")
    if not user_id:
        return json.dumps({"error": "当前会话缺少用户上下文"}, ensure_ascii=False)

    query = text(
        "SELECT w.id, w.ts_code, w.name, w.note, w.watch_group, w.cost_price, w.updated_at, "
        "a.conclusion, a.advice, a.trigger_condition, a.risk_note, a.summary, "
        "a.analysis_mode, a.created_at AS analysis_created_at "
        "FROM user_watchlists w "
        "LEFT JOIN watchlist_analyses a ON a.id = ("
        " SELECT id FROM watchlist_analyses "
        " WHERE watchlist_id = w.id ORDER BY created_at DESC, id DESC LIMIT 1"
        ") "
        "WHERE w.user_id = :uid AND COALESCE(w.is_active, 1) = 1 "
        "ORDER BY CASE w.watch_group "
        " WHEN 'focus' THEN 1 "
        " WHEN 'candidate' THEN 2 "
        " WHEN 'holding' THEN 3 "
        " ELSE 4 END, w.updated_at DESC, w.id DESC"
    )
    async with get_db() as db:
        result = await db.execute(query, {"uid": user_id})
        rows = result.fetchall()

    watch_items = [dict(r._mapping) for r in rows]

    # Only the four known groups are counted; a NULL/empty group counts as
    # "observe", and any other group value is ignored.
    group_counts = dict.fromkeys(("focus", "candidate", "holding", "observe"), 0)
    for entry in watch_items:
        group = entry.get("watch_group") or "observe"
        if group in group_counts:
            group_counts[group] += 1

    priority = [
        entry for entry in watch_items
        if entry.get("conclusion") in ("可操作", "重点关注")
    ]

    snapshot = {
        "count": len(watch_items),
        "group_counts": group_counts,
        "high_priority": priority[:8],
        "items": watch_items[:20],
    }
    return json.dumps(_clean_for_json(snapshot), ensure_ascii=False, default=str)
|
||
|
||
|
||
async def _get_stock_kline(ts_code: str, days: int) -> str:
    """Return the last 20 daily bars of ``ts_code`` with indicators, as JSON.

    ``days`` bars are fetched, sorted by ``trade_date`` and passed through
    ``add_all_indicators`` over the whole window. Only the final 20 rows are
    serialised, to save LLM tokens. Non-finite values become ``null``.
    """
    from app.data.tushare_client import tushare_client
    from app.analysis.technical import add_all_indicators

    daily = tushare_client.get_stock_daily(ts_code, days=days)
    if daily.empty:
        return json.dumps({"error": f"未找到 {ts_code} 的K线数据"}, ensure_ascii=False)

    ordered = daily.sort_values("trade_date").reset_index(drop=True)
    enriched = add_all_indicators(ordered)
    latest_rows = _clean_for_json(enriched.tail(20).to_dict(orient="records"))
    return json.dumps(latest_rows, ensure_ascii=False, default=str)
|
||
|
||
|
||
async def _get_stock_capital_flow(ts_code: str, days: int) -> str:
    """Return daily capital-flow figures for ``ts_code`` as a JSON list.

    For each trade date (ascending), ``main_net_inflow`` is extra-large plus
    large order buy amount minus sell amount, and ``net_mf_amount`` is copied
    from the source. Both are rounded to 2 decimals. Missing, ``None`` or NaN
    amounts count as 0.
    """
    from app.data.tushare_client import tushare_client

    df = tushare_client.get_stock_moneyflow(ts_code, days=days)
    if df.empty:
        return json.dumps({"error": f"未找到 {ts_code} 的资金流向数据"}, ensure_ascii=False)

    def amount(row: dict, column: str) -> float:
        # The previous `value or 0` idiom let NaN through (NaN is truthy),
        # which produced invalid JSON "NaN" tokens in the tool output.
        try:
            value = float(row.get(column))
        except (TypeError, ValueError):
            return 0.0
        return value if math.isfinite(value) else 0.0

    records = []
    for row in df.sort_values("trade_date").to_dict(orient="records"):
        main_net = (
            amount(row, "buy_elg_amount") - amount(row, "sell_elg_amount")
            + amount(row, "buy_lg_amount") - amount(row, "sell_lg_amount")
        )
        records.append({
            "trade_date": row["trade_date"],
            "main_net_inflow": round(main_net, 2),
            "net_mf_amount": round(amount(row, "net_mf_amount"), 2),
        })
    return json.dumps(records, ensure_ascii=False, default=str)
|
||
|
||
|
||
async def _search_stock(keyword: str) -> str:
    """Find up to 10 stocks whose name, ts_code or symbol contains ``keyword``.

    Matching is a literal, case-insensitive substring test, so "600519.sh"
    finds "600519.SH" and names like "*ST..." are handled. Returns a JSON list
    of ``{ts_code, name, industry}``. The list is empty when the keyword is
    blank or the stock universe is unavailable.
    """
    from app.data.tushare_client import tushare_client

    keyword = str(keyword or "").strip()
    if not keyword:
        return json.dumps([], ensure_ascii=False)

    basic = tushare_client.get_stock_basic()
    if basic.empty:
        return json.dumps([], ensure_ascii=False)

    def contains(column: str):
        # regex=False: pandas treats the pattern as a regex by default, so a
        # keyword containing "*" (common in "*ST" names) raised re.error.
        return basic[column].str.contains(keyword, case=False, regex=False, na=False)

    matches = basic[contains("name") | contains("ts_code") | contains("symbol")].head(10)
    data = matches[["ts_code", "name", "industry"]].to_dict(orient="records")
    return json.dumps(data, ensure_ascii=False, default=str)
|
||
|
||
|
||
async def _get_stock_technical_signal(ts_code: str) -> str:
    """Return the technical-signal details for ``ts_code`` as JSON.

    Non-finite floats in the signal are emitted as ``null``.
    """
    from app.analysis.signals import generate_signals

    signal_payload = generate_signals(ts_code).model_dump()
    return json.dumps(_clean_for_json(signal_payload), ensure_ascii=False, default=str)
|
||
|
||
|
||
async def _diagnose_stock(ts_code: str, mode: str = "entry") -> str:
    """Generate a systematic single-stock diagnosis for the chat agent.

    Gathers context from the sibling tools: strategy board, latest
    recommendations, hot sectors, K-line, capital flow and technical signal.
    Adds up to 3 previous diagnoses from ``stock_diagnoses``, asks the LLM for
    a Markdown diagnosis, then stores the result in ``stock_diagnoses``.

    Args:
        ts_code: Stock code to diagnose.
        mode: One of "entry", "holding", "review", "tracking". Any other value
            falls back to "entry".

    Returns:
        JSON with ``ts_code``, ``name``, ``mode``, ``diagnosis`` and ``saved``.
        ``saved`` reflects whether the row was actually persisted. An error
        object is returned when the LLM produced no content.
    """
    from sqlalchemy import text

    from app.db.database import get_db
    from app.db import tables
    from app.llm.client import chat_completion

    mode_map = {
        "entry": "建仓前诊断",
        "holding": "持仓复核",
        "review": "回撤复盘",
        "tracking": "继续跟踪",
    }
    mode = mode if mode in mode_map else "entry"

    strategy_board = await _get_strategy_board()
    latest_recommendations = await _get_latest_recommendations()
    hot_sectors = await _get_hot_sectors(8)
    kline = await _get_stock_kline(ts_code, 80)
    capital_flow = await _get_stock_capital_flow(ts_code, 15)
    technical_signal = await _get_stock_technical_signal(ts_code)

    stock_name = ts_code

    # Locate this stock in the recommendation pool; the tool may also return
    # an error object, hence the type guards.
    latest_rec = None
    try:
        recs = json.loads(latest_recommendations)
    except ValueError as e:
        logger.warning("解析最新推荐失败 %s: %s", ts_code, e)
        recs = []
    if isinstance(recs, list):
        latest_rec = next(
            (item for item in recs
             if isinstance(item, dict) and item.get("ts_code") == ts_code),
            None,
        )
    if latest_rec and latest_rec.get("name"):
        stock_name = latest_rec["name"]

    # Previous diagnoses are optional context: a DB failure must not block the
    # diagnosis, but it should be visible in the logs.
    recent_diagnoses = []
    try:
        async with get_db() as db:
            rows = (await db.execute(
                text(
                    "SELECT name, diagnosis_mode, diagnosis, created_at "
                    "FROM stock_diagnoses WHERE ts_code = :ts_code "
                    "ORDER BY created_at DESC, id DESC LIMIT 3"
                ),
                {"ts_code": ts_code},
            )).fetchall()
        recent_diagnoses = [dict(row._mapping) for row in rows]
        if recent_diagnoses and recent_diagnoses[0].get("name"):
            stock_name = recent_diagnoses[0]["name"]
    except Exception as e:
        logger.warning("读取历史会诊失败 %s: %s", ts_code, e)
        recent_diagnoses = []

    latest_rec_text = (
        json.dumps(latest_rec, ensure_ascii=False, default=str)
        if latest_rec else "不在最新推荐池"
    )
    recent_text = json.dumps(recent_diagnoses, ensure_ascii=False, default=str)

    prompt = f"""请在 A 股作战台语境下,对 {ts_code} 做一次系统化个股会诊。

诊断模式: {mode_map[mode]}

今日作战结论:
{strategy_board}

最新推荐池中该股记录:
{latest_rec_text}

热门板块:
{hot_sectors}

资金流:
{capital_flow}

K线与价格行为:
{kline}

技术信号:
{technical_signal}

最近诊断:
{recent_text}

输出要求:
- 先给明确结论,只能是「可操作 / 重点关注 / 观察 / 回避」
- 明确当前动作、触发条件、失效条件、仓位边界、下一步观察点
- 分析顺序必须是:资金与主线 > 量价与价格行为 > 位置与边界 > 技术指标备注
- A 股优先看资金顺势、主线板块、量价承接、价格行为和位置;技术指标只做节奏与风控确认
- RSI、MACD、KDJ 的超买超卖不能单独决定买卖,也不能放在核心依据第一位
- 输出必须包含以下小节:当前结论、资金与主线、量价与价格行为、位置与边界、执行动作、风险清单、技术指标备注
- 不写传统研报,不堆原始数据,不承诺收益
- 用 Markdown 输出,保持简洁"""

    resp = await chat_completion([
        {
            "role": "system",
            "content": (
                "你是 A 股投研作战台的个股会诊智能体。"
                "你必须优先分析资金流向、主线板块、量价关系、价格行为和位置边界,"
                "技术指标只能作为最后的节奏与风控备注。"
                "输出可执行但带风险边界的会诊结论。"
            ),
        },
        {"role": "user", "content": prompt},
    ])
    # Whitespace-only content is treated the same as no content.
    diagnosis = (resp.content or "").strip() if resp else ""
    if not diagnosis:
        return json.dumps({"error": "个股会诊生成失败,LLM 未返回内容"}, ensure_ascii=False)

    # Persisting is best-effort, but the result must report the real outcome;
    # previously "saved" was True even when the insert failed.
    saved = False
    try:
        async with get_db() as db:
            await db.execute(
                tables.stock_diagnoses_table.insert().values(
                    ts_code=ts_code,
                    name=stock_name,
                    diagnosis_mode=mode,
                    diagnosis=diagnosis,
                )
            )
            await db.commit()
            saved = True
    except Exception as e:
        logger.warning(f"保存聊天会诊结果失败 {ts_code}: {e}")

    return json.dumps({
        "ts_code": ts_code,
        "name": stock_name,
        "mode": mode,
        "diagnosis": diagnosis,
        "saved": saved,
    }, ensure_ascii=False, default=str)
|
||
|
||
|
||
async def _get_sector_performance(sector_name: str) -> str:
    """Return performance data for sectors whose name fuzzily matches ``sector_name``.

    A sector matches when either name contains the other. The comparison is
    case-insensitive and ignores surrounding whitespace. If nothing matches,
    including a blank query, the response has ``matched: false`` and an
    overview of the first 10 sectors under ``available_sectors``. Non-finite
    floats are emitted as ``null``.
    """
    from app.engine.recommender import get_latest_sectors

    sectors = await get_latest_sectors() or []
    query = str(sector_name or "").strip().casefold()

    def is_match(candidate) -> bool:
        # "" is a substring of every string, so a blank query or a sector with
        # an empty name used to "match" every request; None names raised.
        name = str(candidate or "").strip().casefold()
        return bool(query and name) and (query in name or name in query)

    matched = [s for s in sectors if is_match(s.sector_name)]
    if not matched:
        overview = [
            {
                "sector_name": s.sector_name,
                "pct_change": s.pct_change,
                "capital_inflow": s.capital_inflow,
                "limit_up_count": s.limit_up_count,
                "heat_score": s.heat_score,
                "stage": s.stage,
            }
            for s in sectors[:10]
        ]
        return json.dumps(
            {"matched": False, "available_sectors": _clean_for_json(overview)},
            ensure_ascii=False, default=str,
        )

    data = _clean_for_json([s.model_dump() for s in matched])
    return json.dumps({"matched": True, "sectors": data}, ensure_ascii=False, default=str)
|
||
|
||
|
||
async def _get_realtime_indices() -> str:
    """Return index quotes from the Tencent realtime feed as JSON.

    ``is_realtime``/``mode`` report whether ``is_market_session()`` says the
    market is in session. Each index carries name, prices rounded to 2
    decimals, volume and amount. Any failure is logged and returned as an
    error object.
    """
    from app.data.tencent_client import get_index_realtime
    from app.config import is_market_session

    is_trading = is_market_session()

    def r2(value):
        # A missing, None or non-numeric field previously made round() fail
        # the whole tool. Default to 0, matching the missing-key default.
        # (Feed format is not visible here -- assumption, TODO confirm.)
        try:
            return round(float(value), 2)
        except (TypeError, ValueError):
            return 0

    try:
        index_data = await get_index_realtime()
        if not index_data:
            return json.dumps({"error": "获取指数数据失败"}, ensure_ascii=False)

        results = [
            {
                "ts_code": ts_code,
                "name": data.get("name", ts_code),
                "price": r2(data.get("price")),
                "pct_chg": r2(data.get("pct_chg")),
                "volume": data.get("volume", 0),
                "amount": data.get("amount", 0),
                "high": r2(data.get("high")),
                "low": r2(data.get("low")),
                "open": r2(data.get("open")),
                "pre_close": r2(data.get("pre_close")),
            }
            for ts_code, data in index_data.items()
        ]

        payload = {
            "is_realtime": is_trading,
            "mode": "盘中实时" if is_trading else "盘后收盘",
            "indices": results,
        }
        # round() keeps NaN/Inf as-is; map them to null for valid JSON.
        return json.dumps(_clean_for_json(payload), ensure_ascii=False, default=str)
    except Exception as e:
        logger.error(f"获取实时指数失败: {e}")
        await log_error("llm_tool", f"获取实时指数失败: {e}")
        return json.dumps({"error": f"获取指数数据失败: {e}"}, ensure_ascii=False)
|