astock-agent/backend/app/engine/recommender.py
2026-04-14 20:51:38 +08:00

348 lines
14 KiB
Python

"""推荐引擎
管理推荐状态,提供推荐查询接口。
将筛选结果持久化并管理历史推荐。
"""
import logging
from datetime import datetime
from app.engine.screener import run_screening
from app.data.models import Recommendation, MarketTemperature, SectorInfo
from app.db.database import get_db
from app.db import tables
from app.config import settings
logger = logging.getLogger(__name__)
# In-memory copy of the latest screening result (the dict returned by
# run_screening). Written by refresh_recommendations(); read first by
# get_latest_recommendations() and get_latest_sectors(), which fall back to
# the database while it is still None (or empty).
_latest_result: dict | None = None
# Strong references to fire-and-forget background tasks. The asyncio event
# loop keeps only weak references to tasks, so an unreferenced task may be
# garbage-collected before it finishes; each task removes itself when done.
_background_tasks: set = set()


def _on_analysis_task_done(task) -> None:
    """Release a finished AI-analysis task and log any exception it raised."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        # Retrieving the exception here also prevents asyncio's
        # "Task exception was never retrieved" warning.
        logger.error("AI 深度分析任务失败: %s", exc, exc_info=exc)


async def refresh_recommendations(trade_date: str | None = None, scan_session: str = "manual") -> dict:
    """Run a fresh screening and make it the latest recommendation result.

    Every recommendation in the result is stamped with ``scan_session`` and a
    creation time, the result is cached in memory and persisted to the
    database, and, when a DeepSeek API key is configured, an AI deep-analysis
    task is scheduled in the background (this call does not wait for it).

    Args:
        trade_date: Passed unchanged to ``run_screening``.
        scan_session: Label recorded on every recommendation
            (defaults to ``"manual"``).

    Returns:
        The screening result dict produced by ``run_screening``.
    """
    global _latest_result
    result = await run_screening(trade_date)

    # Stamp every recommendation of this batch with the session label and a
    # single shared creation time.
    now = datetime.now()
    for rec in result.get("recommendations", []):
        rec.scan_session = scan_session
        rec.created_at = now
    _latest_result = result

    # Persist to the database.
    await _save_to_db(result)

    # Schedule the AI deep analysis in the background; the return does not
    # wait for it. Keep a reference so the task cannot be collected early.
    if settings.deepseek_api_key:
        import asyncio
        from app.llm.analysis_agent import analyze_recommendations
        task = asyncio.create_task(analyze_recommendations(result))
        _background_tasks.add(task)
        task.add_done_callback(_on_analysis_task_done)
    return result
async def get_latest_recommendations() -> dict:
    """Return the most recent recommendation result.

    The in-memory cache filled by the last refresh wins; when it is still
    empty (nothing refreshed yet in this process) the latest persisted batch
    is read back from the database instead.
    """
    cached = _latest_result
    if not cached:
        # Nothing cached yet: fall back to the database copy.
        return await _load_today_from_db()
    return cached
async def get_latest_sectors() -> list[SectorInfo]:
    """Return the latest sector-heat list without triggering a new scan.

    Uses ``hot_sectors`` from the cached result when it is non-empty;
    otherwise the most recent sector data is loaded from the database.
    """
    cached_sectors = (_latest_result or {}).get("hot_sectors")
    if cached_sectors:
        return cached_sectors
    # No usable in-memory data: read the persisted sectors.
    return await _load_sectors_from_db()
async def get_recommendation_history(days: int = 7) -> list[dict]:
    """Return recommendations of the last ``days`` days, grouped by day.

    For every (calendar day, ts_code) pair only the newest row (largest id)
    is kept. Days are returned newest first; within a day rows keep the
    query order (created_at DESC, score DESC).

    Args:
        days: Size of the look-back window in days.

    Returns:
        One dict per day with keys ``date``, ``count``, ``buy_count``,
        ``avg_score`` and ``recommendations`` (a list of plain dicts).
        Rows without a created_at are grouped under ``"unknown"``.
    """
    import json
    from datetime import timedelta
    from sqlalchemy import text

    cutoff = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
    # Deduplicate by (day, ts_code): only the MAX(id) row of each group survives.
    history_query = text(
        "SELECT * FROM recommendations "
        "WHERE created_at >= :start "
        "AND id IN ("
        " SELECT MAX(id) FROM recommendations "
        " WHERE created_at >= :start "
        " GROUP BY date(created_at), ts_code"
        ") "
        "ORDER BY created_at DESC, score DESC"
    )
    async with get_db() as db:
        fetched = await db.execute(history_query, {"start": cutoff})
        rows = fetched.fetchall()

    by_day: dict[str, list[dict]] = {}
    for row in rows:
        m = row._mapping
        # In SQLite created_at is a "YYYY-MM-DD HH:MM:SS" string, so its
        # first 10 characters are the calendar day.
        raw_created = m["created_at"]
        if raw_created:
            created_text = str(raw_created)
            day = created_text[:10]
        else:
            created_text = None
            day = "unknown"
        score = m["score"] or 0
        by_day.setdefault(day, []).append({
            "ts_code": m["ts_code"],
            "name": m["name"],
            "sector": m["sector"] or "",
            "score": score,
            "level": _score_to_level_static(score),
            "signal": m["signal"] or "HOLD",
            "market_temp_score": m["market_temp_score"] or 0,
            "sector_score": m["sector_score"] or 0,
            "capital_score": m["capital_score"] or 0,
            "technical_score": m["technical_score"] or 0,
            "position_score": m.get("position_score") or 50,
            "valuation_score": m.get("valuation_score") or 50,
            "entry_price": m["entry_price"],
            "target_price": m["target_price"],
            "stop_loss": m["stop_loss"],
            "reasons": json.loads(m["reasons"]) if m["reasons"] else [],
            "risk_note": "",
            "strategy": m.get("strategy") or "trend_breakout",
            "entry_signal_type": m.get("entry_signal_type") or "none",
            "llm_analysis": m.get("llm_analysis") or "",
            "llm_score": m.get("llm_score"),
            "scan_session": m["scan_session"] or "",
            "created_at": created_text,
        })

    # Build the per-day summaries, newest day first.
    summaries = []
    for day in sorted(by_day, reverse=True):
        day_recs = by_day[day]
        summaries.append({
            "date": day,
            "count": len(day_recs),
            "buy_count": sum(rec["signal"] == "BUY" for rec in day_recs),
            "avg_score": round(sum(rec["score"] for rec in day_recs) / len(day_recs), 1) if day_recs else 0,
            "recommendations": day_recs,
        })
    return summaries
# (minimum score, level label) pairs, checked from the highest floor down.
_LEVEL_THRESHOLDS = (
    (75, "强烈推荐"),
    (60, "推荐"),
    (45, "关注"),
)


def _score_to_level_static(score: float) -> str:
    """Map a composite score to its recommendation level label.

    >= 75 -> "强烈推荐" (strongly recommended), >= 60 -> "推荐"
    (recommended), >= 45 -> "关注" (watch), anything lower -> "观望"
    (wait and see).
    """
    for floor, label in _LEVEL_THRESHOLDS:
        if score >= floor:
            return label
    return "观望"
async def _save_market_temp(db, mt) -> None:
    """Write one market_temperature row via INSERT OR REPLACE.

    Intended so that re-scanning the same day updates the existing row
    instead of adding a new one — this assumes a UNIQUE/PRIMARY KEY on
    trade_date (not visible here; confirm in the schema).
    """
    from sqlalchemy import text
    stmt = text(
        "INSERT OR REPLACE INTO market_temperature "
        "(trade_date, up_count, down_count, limit_up_count, limit_down_count, "
        "max_streak, broken_rate, temperature) "
        "VALUES (:td, :up, :down, :lu, :ld, :ms, :br, :temp)"
    )
    await db.execute(stmt, {
        "td": mt.trade_date,
        "up": mt.up_count,
        "down": mt.down_count,
        "lu": mt.limit_up_count,
        "ld": mt.limit_down_count,
        "ms": mt.max_streak,
        "br": mt.broken_rate,
        "temp": mt.temperature,
    })


async def _save_hot_sectors(db, sectors, trade_date: str) -> None:
    """Replace all sector_heat rows of ``trade_date`` with ``sectors``."""
    from sqlalchemy import text
    # Delete the old rows for this trade date first so repeated scans do not
    # duplicate them. This also runs for the empty trade date used when the
    # result has no market temperature, so undated rows are replaced instead
    # of accumulating on every such scan.
    await db.execute(
        text("DELETE FROM sector_heat WHERE trade_date = :td"),
        {"td": trade_date},
    )
    for sector in sectors:
        await db.execute(
            tables.sector_heat_table.insert().values(
                sector_code=sector.sector_code,
                sector_name=sector.sector_name,
                pct_change=sector.pct_change,
                capital_inflow=sector.capital_inflow,
                limit_up_count=sector.limit_up_count,
                heat_score=sector.heat_score,
                trade_date=trade_date,
            )
        )


async def _save_recommendations(db, recommendations) -> None:
    """Replace today's rows in the recommendations table with ``recommendations``."""
    import json
    from sqlalchemy import text
    # Clear every row created today (local date from datetime.now()) so a
    # re-scan on the same day replaces the earlier batch.
    # NOTE(review): created_at is not set explicitly in the insert below, so
    # it comes from the table's column default; if that default is SQLite
    # CURRENT_TIMESTAMP (UTC), date(created_at) can differ from this local
    # date around midnight — confirm against the table definition.
    today_str = datetime.now().strftime("%Y-%m-%d")
    await db.execute(
        text("DELETE FROM recommendations WHERE date(created_at) = :today"),
        {"today": today_str},
    )
    for rec in recommendations:
        await db.execute(
            tables.recommendations_table.insert().values(
                ts_code=rec.ts_code,
                name=rec.name,
                sector=rec.sector,
                score=rec.score,
                market_temp_score=rec.market_temp_score,
                sector_score=rec.sector_score,
                capital_score=rec.capital_score,
                technical_score=rec.technical_score,
                position_score=rec.position_score,
                valuation_score=rec.valuation_score,
                signal=rec.signal,
                entry_price=rec.entry_price,
                target_price=rec.target_price,
                stop_loss=rec.stop_loss,
                reasons=json.dumps(rec.reasons, ensure_ascii=False),
                llm_analysis=rec.llm_analysis,
                strategy=rec.strategy,
                entry_signal_type=rec.entry_signal_type,
                llm_score=rec.llm_score,
                scan_session=rec.scan_session,
            )
        )


async def _save_to_db(result: dict):
    """Persist a screening result to the database.

    Writes the market temperature (if present), the hot sectors and the
    recommendations through one ``get_db()`` session and commits once at the
    end. Best-effort: any exception is logged with its traceback and
    swallowed, so callers never see a persistence error.
    """
    try:
        recommendations = result.get("recommendations", [])
        async with get_db() as db:
            mt = result.get("market_temp")
            if mt:
                await _save_market_temp(db, mt)
            # Sector rows are tagged with the market-temperature trade date;
            # "" when the result carries no market temperature.
            trade_date_val = mt.trade_date if mt else ""
            await _save_hot_sectors(db, result.get("hot_sectors", []), trade_date_val)
            await _save_recommendations(db, recommendations)
            await db.commit()
            logger.info("已保存 %d 条推荐到数据库", len(recommendations))
    except Exception as e:
        # logger.exception keeps the traceback that logger.error(f"...") lost.
        logger.exception("保存推荐到数据库失败: %s", e)
def _recommendation_from_row(r) -> Recommendation:
    """Build a Recommendation from one recommendations-table row mapping.

    NULL columns fall back to defaults: 0 for the core scores, 50 for
    position/valuation scores, "HOLD" for signal, "trend_breakout" for
    strategy, "none" for entry_signal_type and "" for text fields.
    ``reasons`` is stored as a JSON string and decoded here.
    """
    import json
    return Recommendation(
        ts_code=r["ts_code"],
        name=r["name"],
        sector=r["sector"] or "",
        score=r["score"] or 0,
        market_temp_score=r["market_temp_score"] or 0,
        sector_score=r["sector_score"] or 0,
        capital_score=r["capital_score"] or 0,
        technical_score=r["technical_score"] or 0,
        position_score=r.get("position_score") or 50,
        valuation_score=r.get("valuation_score") or 50,
        signal=r["signal"] or "HOLD",
        entry_price=r["entry_price"],
        target_price=r["target_price"],
        stop_loss=r["stop_loss"],
        reasons=json.loads(r["reasons"]) if r["reasons"] else [],
        llm_analysis=r.get("llm_analysis") or "",
        strategy=r.get("strategy") or "trend_breakout",
        entry_signal_type=r.get("entry_signal_type") or "none",
        llm_score=r.get("llm_score"),
        scan_session=r["scan_session"] or "",
    )


async def _load_today_from_db() -> dict:
    """Load the most recent persisted recommendation batch.

    Despite the name, this does not filter on today's date. It returns the
    newest market_temperature row (by trade_date) and the recommendations of
    the latest calendar day present in the table (by created_at). Only the
    highest-id row per ts_code is kept, and rows are ordered by score
    descending.

    Returns:
        A dict shaped like a screening result, with keys ``market_temp``,
        ``hot_sectors`` (always empty here), ``capital_filtered`` (always
        empty) and ``recommendations``. On any error the error is logged
        with its traceback and the same shape is returned, with
        ``market_temp`` None and all lists empty.
    """
    try:
        from sqlalchemy import text
        async with get_db() as db:
            # Latest market temperature by trade_date.
            mt_result = await db.execute(
                text("SELECT * FROM market_temperature ORDER BY trade_date DESC LIMIT 1")
            )
            mt_row = mt_result.fetchone()
            market_temp = None
            if mt_row:
                m = mt_row._mapping
                market_temp = MarketTemperature(
                    trade_date=m["trade_date"],
                    up_count=m["up_count"],
                    down_count=m["down_count"],
                    limit_up_count=m["limit_up_count"],
                    limit_down_count=m["limit_down_count"],
                    max_streak=m["max_streak"],
                    broken_rate=m["broken_rate"],
                    temperature=m["temperature"],
                )
            # Recommendations of the most recent day that has data,
            # deduplicated to the newest row (MAX(id)) per ts_code.
            rec_result = await db.execute(
                text("SELECT * FROM recommendations "
                     "WHERE date(created_at) = (SELECT date(created_at) FROM recommendations ORDER BY created_at DESC LIMIT 1) "
                     "AND id IN (SELECT MAX(id) FROM recommendations "
                     " WHERE date(created_at) = (SELECT date(created_at) FROM recommendations ORDER BY created_at DESC LIMIT 1) "
                     " GROUP BY ts_code) "
                     "ORDER BY score DESC")
            )
            rows = rec_result.fetchall()
        recommendations = [_recommendation_from_row(row._mapping) for row in rows]
        return {
            "market_temp": market_temp,
            "hot_sectors": [],
            "capital_filtered": [],
            "recommendations": recommendations,
        }
    except Exception as e:
        # logger.exception keeps the traceback that logger.error(f"...") lost.
        logger.exception("从数据库加载推荐失败: %s", e)
        return {"market_temp": None, "hot_sectors": [], "capital_filtered": [], "recommendations": []}
async def _load_sectors_from_db() -> list[SectorInfo]:
    """Load the hot-sector list for the most recent trade date in sector_heat.

    Only the highest-id row per sector_code on that date is kept, and the
    result is ordered by heat_score descending. NULL numeric columns become
    0. ``days_continuous`` is not read from the row and is always 0 here.

    Returns:
        The sectors, or an empty list if the table has no rows or the query
        fails (the failure is logged with its traceback).
    """
    try:
        from sqlalchemy import text
        async with get_db() as db:
            result = await db.execute(
                text(
                    "SELECT * FROM sector_heat "
                    "WHERE trade_date = (SELECT MAX(trade_date) FROM sector_heat) "
                    "AND id IN ("
                    " SELECT MAX(id) FROM sector_heat "
                    " WHERE trade_date = (SELECT MAX(trade_date) FROM sector_heat) "
                    " GROUP BY sector_code"
                    ") "
                    "ORDER BY heat_score DESC"
                )
            )
            rows = result.fetchall()
        mappings = [row._mapping for row in rows]
        return [
            SectorInfo(
                sector_code=r["sector_code"],
                sector_name=r["sector_name"],
                pct_change=r["pct_change"] or 0,
                capital_inflow=r["capital_inflow"] or 0,
                limit_up_count=r["limit_up_count"] or 0,
                days_continuous=0,
                heat_score=r["heat_score"] or 0,
            )
            for r in mappings
        ]
    except Exception as e:
        # logger.exception keeps the traceback that logger.error(f"...") lost.
        logger.exception("从数据库加载板块数据失败: %s", e)
        return []