149 lines
6.0 KiB
Python
149 lines
6.0 KiB
Python
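"""LLM 工具执行器 (LLM tool executor).

根据工具名调用现有数据层,返回 JSON 字符串供 LLM 使用。
Dispatches a tool call by name to the existing data layer and returns a JSON
string for the LLM to consume.
"""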
|
|
|
|
import json
|
|
import logging
|
|
import math
|
|
|
|
# Module-level logger; execute_tool uses it to record tool failures.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def execute_tool(name: str, arguments: dict) -> str:
    """Execute an LLM tool call and return the result as a JSON string.

    Args:
        name: Tool name chosen by the LLM.
        arguments: Tool arguments parsed from the LLM call. ``None`` is
            treated as an empty dict, since some clients send no arguments
            object for parameterless tools.

    Returns:
        The tool's JSON payload. On failure it is a JSON object
        ``{"error": ...}``. This covers an unknown tool name, a missing
        required argument, and any exception raised by the tool itself.
        The function never raises, so the LLM always gets a readable reply.
    """
    args = arguments or {}

    # tool name -> (required argument names, zero-arg factory of the coroutine).
    # The lambdas defer both argument lookup and the sibling-function call
    # until after the required arguments have been validated.
    tools = {
        "get_market_temperature": ((), lambda: _get_market_temperature()),
        "get_hot_sectors": ((), lambda: _get_hot_sectors(args.get("limit", 10))),
        "get_latest_recommendations": ((), lambda: _get_latest_recommendations()),
        "get_stock_kline": (
            ("ts_code",),
            lambda: _get_stock_kline(args["ts_code"], args.get("days", 60)),
        ),
        "get_stock_capital_flow": (
            ("ts_code",),
            lambda: _get_stock_capital_flow(args["ts_code"], args.get("days", 10)),
        ),
        "search_stock": (("keyword",), lambda: _search_stock(args["keyword"])),
        "get_stock_technical_signal": (
            ("ts_code",),
            lambda: _get_stock_technical_signal(args["ts_code"]),
        ),
        "get_sector_performance": (
            ("sector_name",),
            lambda: _get_sector_performance(args["sector_name"]),
        ),
    }

    entry = tools.get(name)
    if entry is None:
        return json.dumps({"error": f"未知工具: {name}"}, ensure_ascii=False)

    required, make_call = entry
    # Report missing arguments by name. Otherwise the bare KeyError would
    # surface as the unhelpful error text "'ts_code'".
    missing = [key for key in required if key not in args]
    if missing:
        return json.dumps(
            {"error": f"缺少必需参数: {', '.join(missing)}"}, ensure_ascii=False
        )

    try:
        return await make_call()
    except Exception as e:
        # Tool boundary: log with the traceback, then hand the message back
        # to the LLM instead of propagating.
        logger.exception("工具执行失败 %s: %s", name, e)
        return json.dumps({"error": str(e)}, ensure_ascii=False)
|
|
|
|
|
|
def _clean_for_json(obj):
    """Recursively replace non-finite floats (NaN, +Inf, -Inf) with ``None``.

    ``json.dumps`` emits such floats as the bare tokens ``NaN`` / ``Infinity``.
    Those tokens are not valid JSON and strict parsers reject them.

    - Dicts are rebuilt with the same keys.
    - Lists and tuples are returned as lists. ``json.dumps`` serialises both
      as arrays anyway, and tuples previously slipped through uncleaned.
    - Every other value is returned unchanged.
    """
    if isinstance(obj, dict):
        return {key: _clean_for_json(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_clean_for_json(item) for item in obj]
    if isinstance(obj, float) and not math.isfinite(obj):
        return None
    return obj
|
|
|
|
|
|
async def _get_market_temperature() -> str:
    """Return the latest market-temperature snapshot as a JSON string.

    Reads the ``"market_temp"`` entry of ``get_latest_recommendations()``.
    When that entry is missing or falsy, returns ``{"error": ...}`` instead.
    NaN/Inf values are converted to ``null``, as the kline and signal tools
    do, so the payload is always valid JSON.
    """
    from app.engine.recommender import get_latest_recommendations

    result = await get_latest_recommendations()
    market_temp = result.get("market_temp")
    if not market_temp:
        return json.dumps({"error": "暂无市场温度数据"}, ensure_ascii=False)
    data = _clean_for_json(market_temp.model_dump())
    return json.dumps(data, ensure_ascii=False, default=str)
|
|
|
|
|
|
async def _get_hot_sectors(limit: int) -> str:
    """Return the first ``limit`` entries of ``get_latest_sectors()`` as a JSON array.

    Args:
        limit: Maximum number of sectors to return. It comes straight from
            LLM tool arguments, so it is coerced with ``int()``; this accepts
            e.g. ``"5"`` or ``5.0``. It is also clamped at 0, because a
            negative slice bound would silently drop the *last* sectors.
            ``None`` keeps the old "no limit" behaviour. A non-numeric value
            raises ``ValueError``/``TypeError``, which ``execute_tool``
            reports to the LLM.

    Each element is the sector's ``model_dump()``, with NaN/Inf set to ``null``.
    """
    from app.engine.recommender import get_latest_sectors

    sectors = await get_latest_sectors()
    count = None if limit is None else max(int(limit), 0)
    data = _clean_for_json([s.model_dump() for s in sectors[:count]])
    return json.dumps(data, ensure_ascii=False, default=str)
|
|
|
|
|
|
async def _get_latest_recommendations() -> str:
    """Return the latest recommendations as a JSON array.

    Each element is the recommendation's ``model_dump()`` without the
    ``created_at`` field. NaN/Inf values are set to ``null`` so the output
    is valid JSON. A missing or ``None`` ``"recommendations"`` entry yields ``[]``.
    """
    from app.engine.recommender import get_latest_recommendations

    result = await get_latest_recommendations()
    recs = result.get("recommendations") or []
    data = _clean_for_json([r.model_dump(exclude={"created_at"}) for r in recs])
    return json.dumps(data, ensure_ascii=False, default=str)
|
|
|
|
|
|
async def _get_stock_kline(ts_code: str, days: int) -> str:
    """Fetch daily bars for ``ts_code``, add indicators, and return the tail as JSON.

    Steps:
    - ``days`` is forwarded to ``tushare_client.get_stock_daily``.
    - Rows are sorted by ``trade_date`` and passed through ``add_all_indicators``.
    - Only the last 20 rows are serialised, to keep the LLM prompt small.
    - NaN/Inf become ``null``.

    An empty fetch yields a JSON ``{"error": ...}`` object.
    """
    from app.data.tushare_client import tushare_client
    from app.analysis.technical import add_all_indicators

    daily = tushare_client.get_stock_daily(ts_code, days=days)
    if daily.empty:
        return json.dumps({"error": f"未找到 {ts_code} 的K线数据"}, ensure_ascii=False)

    enriched = add_all_indicators(
        daily.sort_values("trade_date").reset_index(drop=True)
    )
    # Trim to the newest 20 rows only after the indicators are computed,
    # so they are calculated over the whole fetched window.
    latest_rows = enriched.tail(20).to_dict(orient="records")
    return json.dumps(_clean_for_json(latest_rows), ensure_ascii=False, default=str)
|
|
|
|
|
|
def _to_amount(value) -> float:
    """Coerce a money-flow cell to ``float``, mapping missing/non-finite values to 0.0.

    pandas stores missing numeric cells as NaN. NaN is truthy, so the
    ``value or 0`` idiom does not catch it. NaN would then propagate into the
    sums and be emitted as the invalid JSON token ``NaN``.
    """
    if value is None:
        return 0.0
    number = float(value)
    return number if math.isfinite(number) else 0.0


async def _get_stock_capital_flow(ts_code: str, days: int) -> str:
    """Return per-day capital-flow figures for ``ts_code`` as a JSON array.

    ``days`` is forwarded to ``tushare_client.get_stock_moneyflow``. Rows are
    ordered by ``trade_date``. Each output record contains:

    - ``trade_date``
    - ``main_net_inflow``: the sum of buy-minus-sell amounts over the
      ``elg`` and ``lg`` columns, rounded to 2 decimals.
    - ``net_mf_amount``: rounded to 2 decimals.

    Missing (None/NaN) amounts count as 0. An empty fetch yields a JSON
    ``{"error": ...}`` object.
    """
    from app.data.tushare_client import tushare_client

    df = tushare_client.get_stock_moneyflow(ts_code, days=days)
    if df.empty:
        return json.dumps({"error": f"未找到 {ts_code} 的资金流向数据"}, ensure_ascii=False)

    df = df.sort_values("trade_date")
    records = []
    for _, row in df.iterrows():
        main_net = (
            _to_amount(row.get("buy_elg_amount")) - _to_amount(row.get("sell_elg_amount"))
            + _to_amount(row.get("buy_lg_amount")) - _to_amount(row.get("sell_lg_amount"))
        )
        records.append({
            "trade_date": row["trade_date"],
            "main_net_inflow": round(main_net, 2),
            "net_mf_amount": round(_to_amount(row.get("net_mf_amount")), 2),
        })

    return json.dumps(records, ensure_ascii=False, default=str)
|
|
|
|
|
|
def _match_stocks(basic, keyword, limit: int = 10) -> list:
    """Return up to ``limit`` rows of ``basic`` whose name/ts_code/symbol contain ``keyword``.

    Matching is a case-insensitive *literal* substring test. Stock names
    such as ``*ST...`` contain regex metacharacters, and pandas
    ``str.contains`` treats its pattern as a regex by default, which could
    raise ``re.error`` or mis-match. The keyword is converted with ``str()``
    (so a numeric symbol like ``600519`` works) and stripped. A blank keyword
    matches nothing, rather than every row.

    Returns a list of dicts with the keys ``ts_code``, ``name`` and ``industry``.
    """
    needle = "" if keyword is None else str(keyword).strip()
    if not needle:
        return []

    def _contains(column):
        return basic[column].str.contains(needle, case=False, regex=False, na=False)

    hit = _contains("name") | _contains("ts_code") | _contains("symbol")
    rows = basic.loc[hit, ["ts_code", "name", "industry"]].head(limit)
    return rows.to_dict(orient="records")


async def _search_stock(keyword: str) -> str:
    """Search listed stocks by name, ts_code or symbol and return up to 10 matches as JSON.

    Returns ``[]`` when the stock-basic table is empty or the keyword is blank.
    """
    from app.data.tushare_client import tushare_client

    basic = tushare_client.get_stock_basic()
    if basic.empty:
        return json.dumps([], ensure_ascii=False)
    data = _match_stocks(basic, keyword)
    return json.dumps(data, ensure_ascii=False, default=str)
|
|
|
|
|
|
async def _get_stock_technical_signal(ts_code: str) -> str:
    """Return the detailed technical-signal result for one stock as a JSON string.

    Delegates to ``generate_signals(ts_code)``. Serialises its
    ``model_dump()``, with NaN/Inf replaced by ``null``.
    """
    from app.analysis.signals import generate_signals

    payload = generate_signals(ts_code).model_dump()
    return json.dumps(_clean_for_json(payload), ensure_ascii=False, default=str)
|
|
|
|
|
|
def _match_sectors(sectors, sector_name) -> list:
    """Fuzzy-match ``sector_name`` against each sector's ``sector_name`` attribute.

    A sector matches when either name is a substring of the other. The query
    is converted with ``str()`` and stripped first. A blank query matches
    nothing, and sectors with an empty/None name are skipped. With a plain
    ``in`` test both would match everything, because ``"" in s`` is always
    True; a ``None`` name would also raise ``TypeError``.
    """
    query = "" if sector_name is None else str(sector_name).strip()
    if not query:
        return []
    return [
        s for s in sectors
        if s.sector_name and (query in s.sector_name or s.sector_name in query)
    ]


async def _get_sector_performance(sector_name: str) -> str:
    """Return performance data for sectors matching ``sector_name`` as JSON.

    - On a match: ``{"matched": true, "sectors": [...]}``, holding each
      matched sector's full ``model_dump()``.
    - Otherwise: ``{"matched": false, "available_sectors": [...]}``, a
      summary of the first 10 sectors from ``get_latest_sectors()``.

    NaN/Inf values become ``null`` in both shapes so the output is valid JSON.
    """
    from app.engine.recommender import get_latest_sectors

    sectors = await get_latest_sectors()
    matched = _match_sectors(sectors, sector_name)
    if not matched:
        # No match: return an overview of the (hot) sectors instead.
        overview = [
            {
                "sector_name": s.sector_name,
                "pct_change": s.pct_change,
                "capital_inflow": s.capital_inflow,
                "limit_up_count": s.limit_up_count,
                "heat_score": s.heat_score,
                "stage": s.stage,
            }
            for s in sectors[:10]
        ]
        return json.dumps(
            {"matched": False, "available_sectors": _clean_for_json(overview)},
            ensure_ascii=False,
            default=str,
        )

    data = _clean_for_json([s.model_dump() for s in matched])
    return json.dumps({"matched": True, "sectors": data}, ensure_ascii=False, default=str)
|