astock-agent/backend/app/llm/tool_executor.py
2026-06-01 21:29:26 +08:00

414 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""LLM 工具执行器
根据工具名调用现有数据层,返回 JSON 字符串供 LLM 使用。
"""
import json
import logging
import math
from app.db.error_logger import log_error
logger = logging.getLogger(__name__)
_chat_user_context: dict | None = None
def set_chat_user_context(user: dict | None) -> None:
global _chat_user_context
_chat_user_context = user
async def execute_tool(name: str, arguments: dict) -> str:
    """Run the LLM tool *name* with *arguments* and return its JSON string result.

    Args:
        name: Tool name requested by the LLM.
        arguments: Parsed tool-call arguments. ``None`` is treated as ``{}``.

    Returns:
        The tool's JSON string. An unknown tool, a non-object ``arguments``,
        a missing required argument, or an exception raised by the tool is
        returned as a JSON object ``{"error": "..."}`` rather than raised.
    """
    # Each handler only reads its arguments and calls an ``async def``, which
    # merely creates a coroutine; the tool body runs later at ``await``. So a
    # KeyError raised while building the coroutine can only mean a required
    # argument is missing.
    handlers = {
        "get_strategy_board": lambda a: _get_strategy_board(),
        "get_market_temperature": lambda a: _get_market_temperature(),
        "get_hot_sectors": lambda a: _get_hot_sectors(a.get("limit", 10)),
        "get_latest_recommendations": lambda a: _get_latest_recommendations(),
        "get_user_watchlist_snapshot": lambda a: _get_user_watchlist_snapshot(),
        "get_stock_kline": lambda a: _get_stock_kline(a["ts_code"], a.get("days", 60)),
        "get_stock_capital_flow": lambda a: _get_stock_capital_flow(
            a["ts_code"], a.get("days", 10)
        ),
        "search_stock": lambda a: _search_stock(a["keyword"]),
        "get_stock_technical_signal": lambda a: _get_stock_technical_signal(a["ts_code"]),
        "diagnose_stock": lambda a: _diagnose_stock(a["ts_code"], a.get("mode", "entry")),
        "get_sector_performance": lambda a: _get_sector_performance(a["sector_name"]),
        "get_realtime_indices": lambda a: _get_realtime_indices(),
    }
    handler = handlers.get(name)
    if handler is None:
        return json.dumps({"error": f"未知工具: {name}"}, ensure_ascii=False)

    if arguments is None:
        arguments = {}
    elif not isinstance(arguments, dict):
        return json.dumps(
            {"error": f"工具参数格式错误: {name} 需要 JSON 对象"}, ensure_ascii=False
        )

    try:
        pending = handler(arguments)
    except KeyError as e:
        missing = e.args[0] if e.args else ""
        logger.warning(f"工具参数缺失 {name}: {missing}")
        return json.dumps({"error": f"缺少必需参数: {missing}"}, ensure_ascii=False)

    try:
        return await pending
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"工具执行失败 {name}: {e}")
        try:
            await log_error(
                "llm_tool",
                f"工具执行失败 {name}: {e}",
                detail=f"arguments={json.dumps(arguments, ensure_ascii=False, default=str)}",
            )
        except Exception:
            # Persisting the error is best-effort; the LLM must still get a reply.
            logger.exception(f"记录工具错误失败 {name}")
        return json.dumps({"error": str(e)}, ensure_ascii=False)
def _clean_for_json(obj):
"""清理 float NaN/Inf 为 None"""
if isinstance(obj, dict):
return {k: _clean_for_json(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_clean_for_json(v) for v in obj]
if isinstance(obj, float) and (math.isnan(obj) or math.isinf(obj)):
return None
return obj
async def _get_strategy_board() -> str:
    """Summarise today's market snapshot and active strategy as a JSON object.

    Replaces the removed ``strategy_board`` module. Everything is derived from
    ``get_latest_recommendations()``: market breadth from ``market_temp``, the
    strategy fields from ``strategy_profile``, and counts of recommendations
    whose ``action_plan`` is 「可操作」 or 「重点关注」.
    """
    from app.engine.recommender import get_latest_recommendations

    latest = await get_latest_recommendations()
    temp = latest.get("market_temp")
    rec_list = latest.get("recommendations", [])
    profile = latest.get("strategy_profile") or {}

    actionable_total = sum(1 for rec in rec_list if rec.action_plan == "可操作")
    watch_total = sum(1 for rec in rec_list if rec.action_plan == "重点关注")

    def _temp_field(attr):
        # Without a temperature snapshot every market figure falls back to 0.
        return getattr(temp, attr) if temp else 0

    board = {
        "market_temperature": _temp_field("temperature"),
        "up_count": _temp_field("up_count"),
        "down_count": _temp_field("down_count"),
        "limit_up_count": _temp_field("limit_up_count"),
        "strategy_name": profile.get("name", "未知"),
        "market_stance": profile.get("market_stance", ""),
        "decision_note": profile.get("decision_note", ""),
        "actionable_count": actionable_total,
        "watch_count": watch_total,
        "total_recommendations": len(rec_list),
    }
    return json.dumps(_clean_for_json(board), ensure_ascii=False, default=str)
async def _get_market_temperature() -> str:
    """Return the latest market-temperature snapshot as a JSON object.

    Returns ``{"error": "暂无市场温度数据"}`` when the latest recommendation
    result carries no ``market_temp``.
    """
    from app.engine.recommender import get_latest_recommendations

    result = await get_latest_recommendations()
    mt = result.get("market_temp")
    if not mt:
        return json.dumps({"error": "暂无市场温度数据"}, ensure_ascii=False)
    # Scrub NaN/Inf like the other tools so the payload stays valid JSON.
    return json.dumps(_clean_for_json(mt.model_dump()), ensure_ascii=False, default=str)
async def _get_hot_sectors(limit: int) -> str:
    """Return up to *limit* of the latest hot sectors as a JSON list.

    *limit* may come straight from the LLM's tool arguments, so it can arrive
    as a string or float (e.g. ``"5"`` or ``5.0``). It is coerced to a
    non-negative int, falling back to 10 when it cannot be converted.
    """
    from app.engine.recommender import get_latest_sectors

    try:
        limit = max(int(limit), 0)
    except (TypeError, ValueError, OverflowError):
        limit = 10
    sectors = await get_latest_sectors()
    data = [s.model_dump() for s in sectors[:limit]]
    # Scrub NaN/Inf so the payload stays valid JSON.
    return json.dumps(_clean_for_json(data), ensure_ascii=False, default=str)
async def _get_latest_recommendations() -> str:
    """Serialise the latest recommendation pool to a JSON list.

    Every item omits ``created_at`` and has ``llm_analysis`` set to an empty
    string (presumably to keep the tool payload short — confirm).
    """
    from app.engine.recommender import get_latest_recommendations

    latest = await get_latest_recommendations()
    items = [
        {**rec.model_dump(exclude={"created_at"}), "llm_analysis": ""}
        for rec in latest.get("recommendations", [])
    ]
    return json.dumps(items, ensure_ascii=False, default=str)
async def _get_user_watchlist_snapshot() -> str:
    """Summarise the current chat user's active watchlist as a JSON object.

    The user id is read from ``_chat_user_context["id"]``; without one an error
    object is returned. Each active watchlist row is joined with its most
    recent analysis. The payload holds:

    - the total row count;
    - per-group counts for focus / candidate / holding / observe;
    - up to 8 rows whose latest conclusion is 「可操作」 or 「重点关注」;
    - the first 20 rows.

    NOTE(review): ``_chat_user_context`` is a module-level global, so
    concurrent chat sessions in one process may observe each other's user —
    confirm how callers serialise tool execution.
    """
    from sqlalchemy import text
    from app.db.database import get_db

    context = _chat_user_context or {}
    user_id = context.get("id")
    if not user_id:
        return json.dumps({"error": "当前会话缺少用户上下文"}, ensure_ascii=False)

    # Latest analysis per watchlist row via a correlated subquery. Rows are
    # ordered focus, candidate, holding, then the rest, most recently updated first.
    query = text(
        "SELECT w.id, w.ts_code, w.name, w.note, w.watch_group, w.cost_price, w.updated_at, "
        "a.conclusion, a.advice, a.trigger_condition, a.risk_note, a.summary, "
        "a.analysis_mode, a.created_at AS analysis_created_at "
        "FROM user_watchlists w "
        "LEFT JOIN watchlist_analyses a ON a.id = ("
        " SELECT id FROM watchlist_analyses "
        " WHERE watchlist_id = w.id ORDER BY created_at DESC, id DESC LIMIT 1"
        ") "
        "WHERE w.user_id = :uid AND COALESCE(w.is_active, 1) = 1 "
        "ORDER BY CASE w.watch_group "
        " WHEN 'focus' THEN 1 "
        " WHEN 'candidate' THEN 2 "
        " WHEN 'holding' THEN 3 "
        " ELSE 4 END, w.updated_at DESC, w.id DESC"
    )
    async with get_db() as db:
        result = await db.execute(query, {"uid": user_id})
        watch_rows = [dict(row._mapping) for row in result.fetchall()]

    group_counts = dict.fromkeys(("focus", "candidate", "holding", "observe"), 0)
    for row in watch_rows:
        group = row.get("watch_group") or "observe"
        # Groups outside the four known buckets are not counted.
        if group in group_counts:
            group_counts[group] += 1

    priority_rows = [
        row for row in watch_rows if row.get("conclusion") in {"可操作", "重点关注"}
    ]

    snapshot = {
        "count": len(watch_rows),
        "group_counts": group_counts,
        "high_priority": priority_rows[:8],
        "items": watch_rows[:20],
    }
    return json.dumps(_clean_for_json(snapshot), ensure_ascii=False, default=str)
async def _get_stock_kline(ts_code: str, days: int) -> str:
    """Return recent daily bars of *ts_code* with technical indicators as JSON.

    *days* bars are requested and the indicators are computed over all of
    them. Only the last 20 rows are returned, to save LLM tokens.
    """
    from app.data.tushare_client import tushare_client
    from app.analysis.technical import add_all_indicators

    # NOTE(review): get_stock_daily is called synchronously inside this
    # coroutine; if it performs network I/O it blocks the event loop.
    bars = tushare_client.get_stock_daily(ts_code, days=days)
    if bars.empty:
        return json.dumps({"error": f"未找到 {ts_code} 的K线数据"}, ensure_ascii=False)

    chronological = bars.sort_values("trade_date").reset_index(drop=True)
    enriched = add_all_indicators(chronological)
    latest_rows = _clean_for_json(enriched.tail(20).to_dict(orient="records"))
    return json.dumps(latest_rows, ensure_ascii=False, default=str)
async def _get_stock_capital_flow(ts_code: str, days: int) -> str:
    """Return per-day money-flow figures for *ts_code* as a JSON list.

    Rows are sorted by ``trade_date`` ascending. Each item has:

    - ``trade_date``;
    - ``main_net_inflow``: (buy_elg - sell_elg) + (buy_lg - sell_lg), using the
      ``*_amount`` columns;
    - ``net_mf_amount``: taken from the column of the same name.

    Both figures are rounded to 2 decimals. Missing, non-numeric, NaN or
    infinite amounts count as 0.
    """
    from app.data.tushare_client import tushare_client

    df = tushare_client.get_stock_moneyflow(ts_code, days=days)
    if df.empty:
        return json.dumps({"error": f"未找到 {ts_code} 的资金流向数据"}, ensure_ascii=False)

    def _amount(row, column: str) -> float:
        # ``value or 0`` lets NaN through (NaN is truthy). That turned the whole
        # sum into NaN and emitted a non-JSON ``NaN`` token, so normalise here.
        value = row.get(column)
        if value is None:
            return 0.0
        try:
            value = float(value)
        except (TypeError, ValueError):
            return 0.0
        return value if math.isfinite(value) else 0.0

    records = []
    for _, row in df.sort_values("trade_date").iterrows():
        main_net = (
            _amount(row, "buy_elg_amount") - _amount(row, "sell_elg_amount")
            + _amount(row, "buy_lg_amount") - _amount(row, "sell_lg_amount")
        )
        records.append({
            "trade_date": row["trade_date"],
            "main_net_inflow": round(main_net, 2),
            "net_mf_amount": round(_amount(row, "net_mf_amount"), 2),
        })
    return json.dumps(records, ensure_ascii=False, default=str)
async def _search_stock(keyword: str) -> str:
    """Find up to 10 stocks whose name, ts_code or symbol contains *keyword*.

    Matching is a literal (non-regex), case-insensitive substring test, so
    inputs such as ``*ST`` or ``600519.sh`` work. Returns a JSON list of
    ``{"ts_code", "name", "industry"}`` objects. The list is empty for a blank
    keyword, an empty stock list, or no matches.
    """
    from app.data.tushare_client import tushare_client

    # The LLM may send a bare number (e.g. 600519) or a padded string.
    keyword = str(keyword if keyword is not None else "").strip()
    if not keyword:
        # "" is a substring of everything; don't return arbitrary stocks.
        return json.dumps([], ensure_ascii=False)

    basic = tushare_client.get_stock_basic()
    if basic.empty:
        return json.dumps([], ensure_ascii=False)

    def _contains(column: str):
        # regex=False: names like "*ST..." are invalid regular expressions.
        return basic[column].str.contains(keyword, case=False, regex=False, na=False)

    matches = basic[_contains("name") | _contains("ts_code") | _contains("symbol")].head(10)
    data = matches[["ts_code", "name", "industry"]].to_dict(orient="records")
    return json.dumps(data, ensure_ascii=False, default=str)
async def _get_stock_technical_signal(ts_code: str) -> str:
    """Return the technical-signal details of *ts_code* as a JSON object.

    The model produced by ``generate_signals`` is dumped and NaN/Inf values
    are replaced with ``None``.
    """
    from app.analysis.signals import generate_signals

    # NOTE(review): generate_signals is called synchronously inside this
    # coroutine; if it fetches data over the network it blocks the event loop.
    return json.dumps(
        _clean_for_json(generate_signals(ts_code).model_dump()),
        ensure_ascii=False,
        default=str,
    )
def _diagnose_find_rec(recommendations_json: str, ts_code: str) -> dict | None:
    """Return the item for *ts_code* from a recommendations JSON list, or None."""
    try:
        recs = json.loads(recommendations_json)
    except (TypeError, ValueError):
        return None
    if not isinstance(recs, list):
        return None
    return next(
        (item for item in recs if isinstance(item, dict) and item.get("ts_code") == ts_code),
        None,
    )


async def _diagnose_load_history(ts_code: str) -> list[dict]:
    """Best-effort load of the 3 newest stored diagnoses of *ts_code* (newest first)."""
    from sqlalchemy import text
    from app.db.database import get_db

    try:
        async with get_db() as db:
            rows = (await db.execute(
                text(
                    "SELECT name, diagnosis_mode, diagnosis, created_at "
                    "FROM stock_diagnoses WHERE ts_code = :ts_code "
                    "ORDER BY created_at DESC, id DESC LIMIT 3"
                ),
                {"ts_code": ts_code},
            )).fetchall()
    except Exception as e:
        # History is optional prompt context. A DB failure must not abort the
        # diagnosis, but it should not disappear silently either.
        logger.warning(f"读取历史会诊失败 {ts_code}: {e}")
        return []
    return [dict(row._mapping) for row in rows]


async def _diagnose_save(ts_code: str, name: str, mode: str, diagnosis: str) -> bool:
    """Insert one diagnosis into ``stock_diagnoses``.

    Returns True on success. Failures are logged as warnings and reported as
    False instead of being raised.
    """
    from app.db.database import get_db
    from app.db import tables

    try:
        async with get_db() as db:
            await db.execute(
                tables.stock_diagnoses_table.insert().values(
                    ts_code=ts_code,
                    name=name,
                    diagnosis_mode=mode,
                    diagnosis=diagnosis,
                )
            )
            await db.commit()
    except Exception as e:
        logger.warning(f"保存聊天会诊结果失败 {ts_code}: {e}")
        return False
    return True


async def _diagnose_stock(ts_code: str, mode: str = "entry") -> str:
    """Generate and store a structured LLM diagnosis (个股会诊) of one stock.

    Gathers the following context for *ts_code* and asks the LLM for a
    Markdown diagnosis:

    - the strategy board and the latest recommendation pool;
    - 8 hot sectors;
    - 80 days of K-line data and 15 days of capital flow;
    - technical signals;
    - the 3 newest stored diagnoses.

    The result is then saved to ``stock_diagnoses``.

    Args:
        ts_code: Stock code to diagnose.
        mode: ``entry`` / ``holding`` / ``review`` / ``tracking``; any other
            value falls back to ``entry``.

    Returns:
        JSON with ``ts_code``, ``name``, ``mode``, ``diagnosis`` and ``saved``.
        ``saved`` is True only if the insert actually succeeded; previously it
        was always True. Returns ``{"error": ...}`` if the LLM returns no content.
    """
    from app.llm.client import chat_completion

    mode_map = {
        "entry": "建仓前诊断",
        "holding": "持仓复核",
        "review": "回撤复盘",
        "tracking": "继续跟踪",
    }
    mode = mode if mode in mode_map else "entry"

    strategy_board = await _get_strategy_board()
    latest_recommendations = await _get_latest_recommendations()
    hot_sectors = await _get_hot_sectors(8)
    kline = await _get_stock_kline(ts_code, 80)
    capital_flow = await _get_stock_capital_flow(ts_code, 15)
    technical_signal = await _get_stock_technical_signal(ts_code)

    latest_rec = _diagnose_find_rec(latest_recommendations, ts_code)
    recent_diagnoses = await _diagnose_load_history(ts_code)

    # Display name: newest stored diagnosis wins, then the recommendation
    # pool, else the code itself.
    stock_name = ts_code
    if latest_rec and latest_rec.get("name"):
        stock_name = latest_rec["name"]
    if recent_diagnoses and recent_diagnoses[0].get("name"):
        stock_name = recent_diagnoses[0]["name"]

    prompt = f"""请在 A 股作战台语境下,对 {ts_code} 做一次系统化个股会诊。
诊断模式: {mode_map[mode]}
今日作战结论:
{strategy_board}
最新推荐池中该股记录:
{json.dumps(latest_rec, ensure_ascii=False, default=str) if latest_rec else "不在最新推荐池"}
热门板块:
{hot_sectors}
资金流:
{capital_flow}
K线与价格行为:
{kline}
技术信号:
{technical_signal}
最近诊断:
{json.dumps(recent_diagnoses, ensure_ascii=False, default=str)}
输出要求:
- 先给明确结论,只能是「可操作 / 重点关注 / 观察 / 回避」
- 明确当前动作、触发条件、失效条件、仓位边界、下一步观察点
- 分析顺序必须是:资金与主线 > 量价与价格行为 > 位置与边界 > 技术指标备注
- A 股优先看资金顺势、主线板块、量价承接、价格行为和位置;技术指标只做节奏与风控确认
- RSI、MACD、KDJ 的超买超卖不能单独决定买卖,也不能放在核心依据第一位
- 输出必须包含以下小节:当前结论、资金与主线、量价与价格行为、位置与边界、执行动作、风险清单、技术指标备注
- 不写传统研报,不堆原始数据,不承诺收益
- 用 Markdown 输出,保持简洁"""

    resp = await chat_completion([
        {
            "role": "system",
            "content": (
                "你是 A 股投研作战台的个股会诊智能体。"
                "你必须优先分析资金流向、主线板块、量价关系、价格行为和位置边界,"
                "技术指标只能作为最后的节奏与风控备注。"
                "输出可执行但带风险边界的会诊结论。"
            ),
        },
        {"role": "user", "content": prompt},
    ])
    if not resp or not resp.content:
        return json.dumps({"error": "个股会诊生成失败: LLM 未返回内容"}, ensure_ascii=False)
    diagnosis = resp.content.strip()

    saved = await _diagnose_save(ts_code, stock_name, mode, diagnosis)
    return json.dumps({
        "ts_code": ts_code,
        "name": stock_name,
        "mode": mode,
        "diagnosis": diagnosis,
        "saved": saved,
    }, ensure_ascii=False, default=str)
async def _get_sector_performance(sector_name: str) -> str:
    """Look up the latest performance of sectors matching *sector_name*.

    A sector matches when either name contains the other (a loose fuzzy
    match). Returns ``{"matched": True, "sectors": [...]}`` with full sector
    records. If nothing matches, returns ``{"matched": False,
    "available_sectors": [...]}`` with a summary of the first 10 sectors.
    """
    from app.engine.recommender import get_latest_sectors

    sectors = await get_latest_sectors()
    query = str(sector_name if sector_name is not None else "").strip()

    def _is_match(candidate) -> bool:
        # "" is a substring of every string, so an empty query or an unnamed
        # sector must never count as a match. A None name would otherwise raise.
        return bool(query) and bool(candidate) and (query in candidate or candidate in query)

    matched = [s for s in sectors if _is_match(s.sector_name)]
    if not matched:
        # No match: return an overview of the hot sectors instead.
        overview = [
            {
                "sector_name": s.sector_name,
                "pct_change": s.pct_change,
                "capital_inflow": s.capital_inflow,
                "limit_up_count": s.limit_up_count,
                "heat_score": s.heat_score,
                "stage": s.stage,
            }
            for s in sectors[:10]
        ]
        return json.dumps(
            {"matched": False, "available_sectors": _clean_for_json(overview)},
            ensure_ascii=False,
            default=str,
        )
    data = _clean_for_json([s.model_dump() for s in matched])
    return json.dumps({"matched": True, "sectors": data}, ensure_ascii=False, default=str)
async def _get_realtime_indices() -> str:
    """Return index quotes from ``get_index_realtime`` (Tencent client) as JSON.

    ``is_realtime`` / ``mode`` only report whether ``is_market_session()`` was
    true when the call started. Quotes are fetched the same way either way.
    An empty result or any exception during fetching or formatting is returned
    as an error object. Exceptions are also logged and recorded via
    ``log_error``.
    """
    from app.data.tencent_client import get_index_realtime
    from app.config import is_market_session

    trading_now = is_market_session()
    try:
        quotes = await get_index_realtime()
        if not quotes:
            return json.dumps({"error": "获取指数数据失败"}, ensure_ascii=False)

        def _rounded(quote, field):
            return round(quote.get(field, 0), 2)

        indices = [
            {
                "ts_code": code,
                "name": quote.get("name", code),
                "price": _rounded(quote, "price"),
                "pct_chg": _rounded(quote, "pct_chg"),
                "volume": quote.get("volume", 0),
                "amount": quote.get("amount", 0),
                "high": _rounded(quote, "high"),
                "low": _rounded(quote, "low"),
                "open": _rounded(quote, "open"),
                "pre_close": _rounded(quote, "pre_close"),
            }
            for code, quote in quotes.items()
        ]
        summary = {
            "is_realtime": trading_now,
            "mode": "盘中实时" if trading_now else "盘后收盘",
            "indices": indices,
        }
        return json.dumps(summary, ensure_ascii=False, default=str)
    except Exception as e:
        logger.error(f"获取实时指数失败: {e}")
        await log_error("llm_tool", f"获取实时指数失败: {e}")
        return json.dumps({"error": f"获取指数数据失败: {e}"}, ensure_ascii=False)