astock-agent/backend/app/engine/recommender.py
2026-04-07 20:51:00 +08:00

249 lines
9.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""推荐引擎
管理推荐状态,提供推荐查询接口。
将筛选结果持久化并管理历史推荐。
"""
import logging
from datetime import datetime
from app.engine.screener import run_screening
from app.data.models import Recommendation, MarketTemperature, SectorInfo
from app.db.database import get_db
from app.db import tables
from app.config import settings
logger = logging.getLogger(__name__)
# Latest screening result, cached in memory for this process.
# Written by refresh_recommendations(); read by get_latest_recommendations()
# and get_latest_sectors(), which fall back to the database while it is
# None (or otherwise falsy).
_latest_result: dict | None = None
# Strong references to in-flight background tasks (LLM enhancement). The event
# loop keeps only weak references to tasks, so a task nobody references can be
# garbage-collected before it finishes.
_background_tasks: set = set()


def _on_background_task_done(task) -> None:
    """Forget a finished background task and log any exception it raised."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("LLM 增强任务失败: %s", exc, exc_info=exc)


async def refresh_recommendations(trade_date: str | None = None, scan_session: str = "manual") -> dict:
    """Run a fresh screening pass, cache it in memory, persist it and return it.

    Args:
        trade_date: Passed straight through to ``run_screening`` (``None`` as-is).
        scan_session: Label stamped onto every recommendation's ``scan_session``.

    Returns:
        The dict returned by ``run_screening``. Its ``"recommendations"`` items
        are mutated in place (``scan_session`` and ``created_at`` are set).

    When ``settings.deepseek_api_key`` is set, ``enhance_recommendations`` is
    scheduled on the result as a background task and is not awaited.
    """
    global _latest_result
    result = await run_screening(trade_date)

    # One timestamp for the whole batch so every recommendation produced by
    # this scan carries the same created_at.
    now = datetime.now()
    for rec in result.get("recommendations", []):
        rec.scan_session = scan_session
        rec.created_at = now

    _latest_result = result
    # Persist (errors are logged and swallowed inside _save_to_db).
    await _save_to_db(result)

    # Background LLM enhancement; does not block the return.
    if settings.deepseek_api_key:
        import asyncio
        from app.llm.enhancer import enhance_recommendations

        task = asyncio.create_task(enhance_recommendations(result))
        # Keep a reference until done so the task cannot be GC'd mid-run, and
        # surface its failures in the log instead of losing them.
        _background_tasks.add(task)
        task.add_done_callback(_on_background_task_done)
    return result
async def get_latest_recommendations() -> dict:
    """Return the newest recommendation result.

    A non-empty in-memory result is returned as-is; otherwise today's
    recommendations are rebuilt from the database.
    """
    cached = _latest_result
    if not cached:
        # Nothing usable in memory (e.g. no refresh yet in this process).
        return await _load_today_from_db()
    return cached
async def get_latest_sectors() -> list[SectorInfo]:
    """Return hot-sector data without triggering a new scan.

    Uses the in-memory result's ``hot_sectors`` when it is non-empty;
    otherwise reads the most recently persisted sectors from the database.
    """
    cached_sectors = _latest_result.get("hot_sectors") if _latest_result else None
    if cached_sectors:
        return cached_sectors
    # Cache miss (no in-memory result, or it carries no sectors): use the DB.
    return await _load_sectors_from_db()
async def get_recommendation_history(days: int = 7) -> list[dict]:
    """Return raw recommendation rows created within roughly the last ``days`` days.

    The cutoff is the local date ``days`` days ago, formatted ``YYYY-MM-DD``
    and compared with ``created_at >= :start``. When ``created_at`` is stored
    as ISO-style text (assumed; not visible here), that means "from the start
    of that day". Rows are returned newest first as plain dicts of column
    values; ``reasons`` stays the stored JSON string.
    """
    from datetime import timedelta

    start = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
    async with get_db() as db:
        from sqlalchemy import text

        stmt = text(
            "SELECT * FROM recommendations WHERE created_at >= :start ORDER BY created_at DESC"
        )
        result = await db.execute(stmt, {"start": start})
        rows = result.fetchall()
        return [dict(row._mapping) for row in rows]
async def _save_to_db(result: dict):
    """Persist a screening result to the database (best effort).

    Writes, in one session with a single commit at the end:

    1. the market-temperature snapshot (``result["market_temp"]``), if present;
    2. the hot sectors (``result["hot_sectors"]``), replacing rows already
       stored for the same trade date;
    3. every recommendation (``result["recommendations"]``).

    Any exception is logged (with traceback) and swallowed so a database
    problem never breaks the caller. If it happens before the commit, this
    function commits nothing.
    """
    try:
        async with get_db() as db:
            market_temp = result.get("market_temp")
            await _insert_market_temperature(db, market_temp)

            # Sector rows are keyed by the market-temperature trade date; with
            # no snapshot they are stored under "" and nothing is cleared first.
            trade_date_val = market_temp.trade_date if market_temp else ""
            await _replace_sector_heat(db, result.get("hot_sectors", []), trade_date_val)

            recommendations = result.get("recommendations", [])
            await _insert_recommendations(db, recommendations)

            await db.commit()
            logger.info("已保存 %d 条推荐到数据库", len(recommendations))
    except Exception as e:
        logger.exception("保存推荐到数据库失败: %s", e)


async def _insert_market_temperature(db, mt) -> None:
    """Insert one market_temperature row from ``mt``; no-op when ``mt`` is falsy.

    A failed insert (most likely the row already exists and hits a UNIQUE
    constraint) is logged at DEBUG level and otherwise ignored, so the rest of
    the save continues.
    """
    if not mt:
        return
    stmt = tables.market_temperature_table.insert().values(
        trade_date=mt.trade_date,
        up_count=mt.up_count,
        down_count=mt.down_count,
        limit_up_count=mt.limit_up_count,
        limit_down_count=mt.limit_down_count,
        max_streak=mt.max_streak,
        broken_rate=mt.broken_rate,
        temperature=mt.temperature,
    )
    try:
        await db.execute(stmt)
    except Exception as e:
        # NOTE(review): continuing after a failed statement relies on the
        # backend keeping the transaction usable (SQLite does, PostgreSQL does
        # not). A SAVEPOINT (db.begin_nested()) around this insert would make it
        # backend-independent — confirm which database is used.
        logger.debug("市场温度写入跳过(可能已存在): %s", e)


async def _replace_sector_heat(db, sectors, trade_date_val) -> None:
    """Store ``sectors`` in sector_heat under ``trade_date_val``.

    When ``trade_date_val`` is non-empty, existing rows for that date are
    deleted first so re-running a scan on the same day does not duplicate them.
    """
    from sqlalchemy import text

    if trade_date_val:
        await db.execute(
            text("DELETE FROM sector_heat WHERE trade_date = :td"),
            {"td": trade_date_val},
        )
    for sector in sectors:
        stmt = tables.sector_heat_table.insert().values(
            sector_code=sector.sector_code,
            sector_name=sector.sector_name,
            pct_change=sector.pct_change,
            capital_inflow=sector.capital_inflow,
            limit_up_count=sector.limit_up_count,
            heat_score=sector.heat_score,
            trade_date=trade_date_val,
        )
        await db.execute(stmt)


async def _insert_recommendations(db, recommendations) -> None:
    """Insert one recommendations row per item; ``reasons`` is stored as JSON text."""
    import json

    # NOTE(review): created_at is not written here, so its value comes from the
    # table's default (defined elsewhere). _load_today_from_db filters on
    # date(created_at) against local time — confirm the default uses that clock.
    for rec in recommendations:
        stmt = tables.recommendations_table.insert().values(
            ts_code=rec.ts_code,
            name=rec.name,
            sector=rec.sector,
            score=rec.score,
            market_temp_score=rec.market_temp_score,
            sector_score=rec.sector_score,
            capital_score=rec.capital_score,
            technical_score=rec.technical_score,
            position_score=rec.position_score,
            valuation_score=rec.valuation_score,
            signal=rec.signal,
            entry_price=rec.entry_price,
            target_price=rec.target_price,
            stop_loss=rec.stop_loss,
            reasons=json.dumps(rec.reasons, ensure_ascii=False),
            llm_analysis=rec.llm_analysis,
            scan_session=rec.scan_session,
        )
        await db.execute(stmt)
def _parse_reasons(raw) -> list:
    """Decode the JSON-encoded ``reasons`` column; empty or malformed values give ``[]``.

    A malformed value is logged rather than raised so one bad row does not
    make the whole load fail.
    """
    import json

    if not raw:
        return []
    try:
        return json.loads(raw)
    except (TypeError, ValueError) as e:
        logger.warning("推荐理由 JSON 解析失败: %s", e)
        return []


def _row_to_recommendation(r) -> "Recommendation":
    """Build a ``Recommendation`` from one recommendations row mapping.

    NULL columns fall back to defaults. ``position_score`` and
    ``valuation_score`` default to 50 only when NULL or absent: a stored 0 is
    kept (``value or 50`` would have silently turned it into 50).
    """
    position_score = r.get("position_score")
    valuation_score = r.get("valuation_score")
    return Recommendation(
        ts_code=r["ts_code"],
        name=r["name"],
        sector=r["sector"] or "",
        score=r["score"] or 0,
        market_temp_score=r["market_temp_score"] or 0,
        sector_score=r["sector_score"] or 0,
        capital_score=r["capital_score"] or 0,
        technical_score=r["technical_score"] or 0,
        position_score=50 if position_score is None else position_score,
        valuation_score=50 if valuation_score is None else valuation_score,
        signal=r["signal"] or "HOLD",
        entry_price=r["entry_price"],
        target_price=r["target_price"],
        stop_loss=r["stop_loss"],
        reasons=_parse_reasons(r["reasons"]),
        llm_analysis=r.get("llm_analysis") or "",
        scan_session=r["scan_session"] or "",
    )


async def _load_today_from_db() -> dict:
    """Rebuild a screening-result dict from the database.

    Returns a dict with keys:

    - ``market_temp``: the most recently created market_temperature row (not
      restricted to today), or None;
    - ``hot_sectors`` / ``capital_filtered``: always empty here;
    - ``recommendations``: rows whose ``date(created_at)`` equals today's
      local date, highest score first.

    Any error is logged (with traceback) and an all-empty result is returned.
    """
    # NOTE(review): `today` is local time, while created_at is not set by the
    # INSERT in _save_to_db and so comes from the table default. If that
    # default is UTC, rows saved shortly after local midnight will not match.
    today = datetime.now().strftime("%Y-%m-%d")
    try:
        async with get_db() as db:
            from sqlalchemy import text

            # Latest market temperature snapshot.
            result = await db.execute(
                text("SELECT * FROM market_temperature ORDER BY created_at DESC LIMIT 1")
            )
            mt_row = result.fetchone()
            market_temp = None
            if mt_row:
                m = mt_row._mapping
                market_temp = MarketTemperature(
                    trade_date=m["trade_date"],
                    up_count=m["up_count"],
                    down_count=m["down_count"],
                    limit_up_count=m["limit_up_count"],
                    limit_down_count=m["limit_down_count"],
                    max_streak=m["max_streak"],
                    broken_rate=m["broken_rate"],
                    temperature=m["temperature"],
                )

            # Today's recommendations.
            result = await db.execute(
                text("SELECT * FROM recommendations WHERE date(created_at) = :today ORDER BY score DESC"),
                {"today": today},
            )
            recommendations = [_row_to_recommendation(row._mapping) for row in result.fetchall()]
            return {
                "market_temp": market_temp,
                "hot_sectors": [],
                "capital_filtered": [],
                "recommendations": recommendations,
            }
    except Exception as e:
        logger.exception("从数据库加载推荐失败: %s", e)
        return {"market_temp": None, "hot_sectors": [], "capital_filtered": [], "recommendations": []}
async def _load_sectors_from_db() -> list[SectorInfo]:
    """Load sector-heat rows for the most recent trade date in sector_heat.

    For each sector_code only the row with the largest id on that date is
    used, ordered by heat_score descending. NULL numeric columns become 0.
    ``days_continuous`` is not read from the table, so it is always 0.
    Any error is logged and an empty list is returned.
    """
    query = (
        "SELECT * FROM sector_heat "
        "WHERE trade_date = (SELECT MAX(trade_date) FROM sector_heat) "
        "AND id IN ("
        " SELECT MAX(id) FROM sector_heat "
        " WHERE trade_date = (SELECT MAX(trade_date) FROM sector_heat) "
        " GROUP BY sector_code"
        ") "
        "ORDER BY heat_score DESC"
    )
    try:
        async with get_db() as db:
            from sqlalchemy import text

            fetched = (await db.execute(text(query))).fetchall()
            return [
                SectorInfo(
                    sector_code=col["sector_code"],
                    sector_name=col["sector_name"],
                    pct_change=col["pct_change"] or 0,
                    capital_inflow=col["capital_inflow"] or 0,
                    limit_up_count=col["limit_up_count"] or 0,
                    days_continuous=0,
                    heat_score=col["heat_score"] or 0,
                )
                for col in (entry._mapping for entry in fetched)
            ]
    except Exception as e:
        logger.error(f"从数据库加载板块数据失败: {e}")
        return []