astock-agent/backend/app/llm/tool_executor.py
2026-04-08 22:39:51 +08:00

149 lines
6.0 KiB
Python

"""LLM 工具执行器
根据工具名调用现有数据层,返回 JSON 字符串供 LLM 使用。
"""
import json
import logging
import math
# Module-level logger; execute_tool reports tool failures through it.
logger = logging.getLogger(__name__)
async def execute_tool(name: str, arguments: dict) -> str:
    """Execute one LLM tool call and return its result as a JSON string.

    Args:
        name: Tool name as declared to the LLM.
        arguments: Parsed tool-call arguments; ``None`` is treated as ``{}``.
            Integer options (``limit``/``days``) may arrive as numeric
            strings and are coerced to ``int``; an explicit ``null`` falls
            back to the default.

    Returns:
        The tool's JSON output, or a JSON object ``{"error": ...}`` when the
        tool is unknown, a required argument is missing or malformed, or the
        tool itself raises. Those failures are reported, not raised.
    """
    args = arguments or {}

    def int_arg(key: str, default: int) -> int:
        # LLMs sometimes send numbers as strings ("10") or explicit nulls.
        value = args.get(key)
        return default if value is None else int(value)

    # Each handler only reads/validates arguments and *creates* the coroutine;
    # the tool itself runs at the `await` below. This keeps a KeyError raised
    # here (= missing argument) separate from one raised inside a tool.
    handlers = {
        "get_market_temperature": lambda: _get_market_temperature(),
        "get_hot_sectors": lambda: _get_hot_sectors(int_arg("limit", 10)),
        "get_latest_recommendations": lambda: _get_latest_recommendations(),
        "get_stock_kline": lambda: _get_stock_kline(
            args["ts_code"], int_arg("days", 60)
        ),
        "get_stock_capital_flow": lambda: _get_stock_capital_flow(
            args["ts_code"], int_arg("days", 10)
        ),
        "search_stock": lambda: _search_stock(args["keyword"]),
        "get_stock_technical_signal": lambda: _get_stock_technical_signal(
            args["ts_code"]
        ),
        "get_sector_performance": lambda: _get_sector_performance(
            args["sector_name"]
        ),
    }

    handler = handlers.get(name)
    if handler is None:
        return json.dumps({"error": f"未知工具: {name}"}, ensure_ascii=False)

    try:
        coro = handler()
    except KeyError as e:
        logger.warning("工具 %s 缺少参数: %s", name, e)
        return json.dumps({"error": f"缺少必需参数: {e.args[0]}"}, ensure_ascii=False)
    except (TypeError, ValueError, AttributeError) as e:
        logger.warning("工具 %s 参数无效: %s", name, e)
        return json.dumps({"error": f"参数无效: {e}"}, ensure_ascii=False)

    try:
        return await coro
    except Exception as e:
        # Boundary: report any tool failure back to the LLM instead of failing
        # the whole turn; logger.exception keeps the traceback for debugging.
        logger.exception("工具执行失败 %s: %s", name, e)
        return json.dumps({"error": str(e)}, ensure_ascii=False)
def _clean_for_json(obj):
"""清理 float NaN/Inf 为 None"""
if isinstance(obj, dict):
return {k: _clean_for_json(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_clean_for_json(v) for v in obj]
if isinstance(obj, float) and (math.isnan(obj) or math.isinf(obj)):
return None
return obj
async def _get_market_temperature() -> str:
    """Return the latest market-temperature snapshot as a JSON string.

    Reads the ``"market_temp"`` entry of ``get_latest_recommendations()``'s
    result. When it is missing or falsy, returns
    ``{"error": "暂无市场温度数据"}``. NaN/Inf values are emitted as ``null``.
    """
    from app.engine.recommender import get_latest_recommendations

    result = await get_latest_recommendations()
    mt = result.get("market_temp")
    if not mt:
        return json.dumps({"error": "暂无市场温度数据"}, ensure_ascii=False)
    # Clean NaN/Inf like the other tools so the output stays valid JSON.
    data = _clean_for_json(mt.model_dump())
    return json.dumps(data, ensure_ascii=False, default=str)
async def _get_hot_sectors(limit: int) -> str:
    """Return the first ``limit`` entries of ``get_latest_sectors()`` as JSON.

    ``limit`` is coerced to ``int`` and negative values are treated as 0.
    ``None`` keeps the original meaning of "no limit". NaN/Inf values are
    emitted as ``null``.
    """
    from app.engine.recommender import get_latest_sectors

    if limit is not None:
        # Accept numeric strings from the LLM, and stop a negative value from
        # turning the slice into "all but the last N".
        limit = max(int(limit), 0)
    sectors = await get_latest_sectors()
    data = _clean_for_json([s.model_dump() for s in sectors[:limit]])
    return json.dumps(data, ensure_ascii=False, default=str)
async def _get_latest_recommendations() -> str:
    """Return the latest stock recommendations as a JSON array.

    Each recommendation model is dumped without its ``created_at`` field.
    A missing or ``None`` ``"recommendations"`` entry yields ``[]``.
    NaN/Inf values are emitted as ``null``.
    """
    from app.engine.recommender import get_latest_recommendations

    result = await get_latest_recommendations()
    recs = result.get("recommendations") or []
    data = _clean_for_json([r.model_dump(exclude={"created_at"}) for r in recs])
    return json.dumps(data, ensure_ascii=False, default=str)
async def _get_stock_kline(ts_code: str, days: int) -> str:
    """Return recent daily K-line rows, with technical indicators, as JSON.

    Fetches ``days`` daily bars for ``ts_code``, sorts them by ``trade_date``,
    adds indicator columns via ``add_all_indicators`` and serialises only the
    last 20 rows to keep the LLM prompt small. NaN/Inf values become ``null``.
    Returns a JSON ``{"error": ...}`` object when no data is found.
    """
    from app.data.tushare_client import tushare_client
    from app.analysis.technical import add_all_indicators

    # NOTE(review): this client call is not awaited, so it appears to block
    # the event loop while fetching — confirm whether that matters here.
    daily = tushare_client.get_stock_daily(ts_code, days=days)
    if daily.empty:
        return json.dumps({"error": f"未找到 {ts_code} 的K线数据"}, ensure_ascii=False)

    ordered = daily.sort_values("trade_date").reset_index(drop=True)
    enriched = add_all_indicators(ordered)
    # Only the last 20 rows after sorting are sent, to save tokens.
    latest_rows = enriched.tail(20).to_dict(orient="records")
    return json.dumps(_clean_for_json(latest_rows), ensure_ascii=False, default=str)
async def _get_stock_capital_flow(ts_code: str, days: int) -> str:
    """Return per-day capital-flow figures for ``ts_code`` as a JSON array.

    Each element has:

    - ``trade_date``;
    - ``main_net_inflow``: buy minus sell of the ``*_elg_amount`` and
      ``*_lg_amount`` columns (extra-large and large orders in Tushare's
      moneyflow schema);
    - ``net_mf_amount``.

    Both amounts are rounded to 2 decimals. Missing, NaN or infinite amounts
    count as 0, so the output is always valid JSON. Returns a JSON
    ``{"error": ...}`` object when no data is found.
    """
    from app.data.tushare_client import tushare_client

    df = tushare_client.get_stock_moneyflow(ts_code, days=days)
    if df.empty:
        return json.dumps({"error": f"未找到 {ts_code} 的资金流向数据"}, ensure_ascii=False)

    def amount(row, column: str) -> float:
        # `value or 0` does not catch NaN (NaN is truthy), which let NaN reach
        # json.dumps and produce invalid JSON; check finiteness explicitly.
        value = row.get(column)
        if value is None:
            return 0.0
        value = float(value)
        return value if math.isfinite(value) else 0.0

    records = []
    for _, row in df.sort_values("trade_date").iterrows():
        main_net = (
            amount(row, "buy_elg_amount") - amount(row, "sell_elg_amount")
            + amount(row, "buy_lg_amount") - amount(row, "sell_lg_amount")
        )
        records.append({
            "trade_date": row["trade_date"],
            "main_net_inflow": round(main_net, 2),
            "net_mf_amount": round(amount(row, "net_mf_amount"), 2),
        })
    return json.dumps(records, ensure_ascii=False, default=str)
async def _search_stock(keyword: str) -> str:
from app.data.tushare_client import tushare_client
basic = tushare_client.get_stock_basic()
if basic.empty:
return json.dumps([], ensure_ascii=False)
matches = basic[
basic["name"].str.contains(keyword, na=False) |
basic["ts_code"].str.contains(keyword, na=False) |
basic["symbol"].str.contains(keyword, na=False)
].head(10)
data = matches[["ts_code", "name", "industry"]].to_dict(orient="records")
return json.dumps(data, ensure_ascii=False, default=str)
async def _get_stock_technical_signal(ts_code: str) -> str:
    """Return the technical-signal details for one stock as a JSON string.

    Calls ``generate_signals(ts_code)``, dumps the resulting model and
    replaces NaN/Inf floats with ``null`` before serialising.
    """
    from app.analysis.signals import generate_signals

    payload = generate_signals(ts_code).model_dump()
    return json.dumps(_clean_for_json(payload), ensure_ascii=False, default=str)
async def _get_sector_performance(sector_name: str) -> str:
    """Look up sectors by fuzzy name in ``get_latest_sectors()``.

    A sector matches when its name contains the query or the query contains
    its name (e.g. a query with an extra suffix still matches).

    On a match, returns ``{"matched": true, "sectors": [...]}`` with the full
    model dumps. Otherwise returns
    ``{"matched": false, "available_sectors": [...]}`` with a summary of the
    first 10 sectors, so the LLM can retry with a valid name. A blank query
    never matches. NaN/Inf values are emitted as ``null``.
    """
    from app.engine.recommender import get_latest_sectors

    sectors = await get_latest_sectors()
    query = str(sector_name).strip() if sector_name is not None else ""

    def is_match(name) -> bool:
        # An empty string is a substring of everything, so an empty query or
        # an empty sector name must not count as a match.
        if not query or not name:
            return False
        return query in name or name in query

    matched = [s for s in sectors if is_match(s.sector_name)]
    if not matched:
        # No match: return an overview of the leading sectors instead.
        overview = [
            {
                "sector_name": s.sector_name,
                "pct_change": s.pct_change,
                "capital_inflow": s.capital_inflow,
                "limit_up_count": s.limit_up_count,
                "heat_score": s.heat_score,
                "stage": s.stage,
            }
            for s in sectors[:10]
        ]
        return json.dumps(
            {"matched": False, "available_sectors": _clean_for_json(overview)},
            ensure_ascii=False,
            default=str,
        )
    data = _clean_for_json([s.model_dump() for s in matched])
    return json.dumps({"matched": True, "sectors": data}, ensure_ascii=False, default=str)