348 lines
14 KiB
Python
348 lines
14 KiB
Python
"""推荐引擎
|
|
|
|
管理推荐状态,提供推荐查询接口。
|
|
将筛选结果持久化并管理历史推荐。
|
|
"""
|
|
|
|
import logging
|
|
from datetime import datetime
|
|
from app.engine.screener import run_screening
|
|
from app.data.models import Recommendation, MarketTemperature, SectorInfo
|
|
from app.db.database import get_db
|
|
from app.db import tables
|
|
from app.config import settings
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# 内存中的最新推荐结果
|
|
_latest_result: dict | None = None
|
|
|
|
|
|
# Strong references to fire-and-forget background tasks. The event loop keeps
# only weak references to tasks, so a task nobody references can be garbage
# collected before it finishes (see the asyncio.create_task documentation).
_background_tasks: set = set()


def _on_background_task_done(task) -> None:
    """Drop the finished task's strong reference and log any exception it raised."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        # Without this, the failure would only surface as an asyncio
        # "Task exception was never retrieved" warning, if at all.
        logger.error("AI 深度分析任务失败", exc_info=exc)


async def refresh_recommendations(trade_date: str | None = None, scan_session: str = "manual") -> dict:
    """Run a fresh screening pass, cache it in memory and persist it.

    Args:
        trade_date: Passed through unchanged to ``run_screening``.
        scan_session: Label stamped onto every recommendation's
            ``scan_session`` attribute.

    Returns:
        The dict returned by ``run_screening`` after its recommendations have
        been tagged with ``scan_session`` and ``created_at``.

    Side effects:
        Replaces the module-level ``_latest_result`` cache and calls
        ``_save_to_db``. When ``settings.deepseek_api_key`` is set, it also
        schedules ``analyze_recommendations(result)`` as a background task
        that does not block the return.
    """
    global _latest_result

    result = await run_screening(trade_date)

    # Tag each recommendation with the scan session and one shared timestamp,
    # so the whole batch carries a consistent creation time.
    created_at = datetime.now()
    for rec in result.get("recommendations", []):
        rec.scan_session = scan_session
        rec.created_at = created_at

    _latest_result = result

    # Persist to the database.
    await _save_to_db(result)

    # Optional AI deep analysis runs in the background (does not block the return).
    if settings.deepseek_api_key:
        import asyncio

        from app.llm.analysis_agent import analyze_recommendations

        task = asyncio.create_task(analyze_recommendations(result))
        # Keep a strong reference until the task is done; the callback drops
        # it again and logs failures.
        _background_tasks.add(task)
        task.add_done_callback(_on_background_task_done)

    return result
|
|
|
|
|
|
async def get_latest_recommendations() -> dict:
    """Return the most recent screening result.

    Serves the in-memory ``_latest_result`` when it is non-empty; otherwise
    falls back to ``_load_today_from_db()``.
    """
    cached = _latest_result
    if not cached:
        # Nothing cached in this process yet; read from the database instead.
        return await _load_today_from_db()
    return cached
|
|
|
|
|
|
async def get_latest_sectors() -> list[SectorInfo]:
    """Return the latest sector-heat list without triggering a new scan.

    Uses ``hot_sectors`` from the in-memory result when present and
    non-empty; otherwise loads the newest snapshot via
    ``_load_sectors_from_db()``.
    """
    cached_sectors = _latest_result.get("hot_sectors") if _latest_result else None
    if cached_sectors:
        return cached_sectors
    # Nothing usable in memory; fall back to the persisted snapshot.
    return await _load_sectors_from_db()
|
|
|
|
|
|
def _history_row_to_dict(r) -> dict:
    """Convert one ``recommendations`` row mapping into a history-API dict."""
    import json

    # SQLite stores created_at as text such as "YYYY-MM-DD HH:MM:SS"; str()
    # yields the same layout for datetime objects, so [:10] is the date.
    ca = r["created_at"]
    created_at_str = str(ca) if ca else None

    # These columns are read with .get(); presumably they may be missing in
    # older schemas. Only a missing value (None) falls back to the neutral 50.
    # A genuine score of 0 is kept (the previous `or 50` turned it into 50).
    position = r.get("position_score")
    valuation = r.get("valuation_score")
    score = r["score"] or 0

    return {
        "ts_code": r["ts_code"],
        "name": r["name"],
        "sector": r["sector"] or "",
        "score": score,
        "level": _score_to_level_static(score),
        "signal": r["signal"] or "HOLD",
        "market_temp_score": r["market_temp_score"] or 0,
        "sector_score": r["sector_score"] or 0,
        "capital_score": r["capital_score"] or 0,
        "technical_score": r["technical_score"] or 0,
        "position_score": 50 if position is None else position,
        "valuation_score": 50 if valuation is None else valuation,
        "entry_price": r["entry_price"],
        "target_price": r["target_price"],
        "stop_loss": r["stop_loss"],
        "reasons": json.loads(r["reasons"]) if r["reasons"] else [],
        "risk_note": "",
        "strategy": r.get("strategy") or "trend_breakout",
        "entry_signal_type": r.get("entry_signal_type") or "none",
        "llm_analysis": r.get("llm_analysis") or "",
        "llm_score": r.get("llm_score"),
        "scan_session": r["scan_session"] or "",
        "created_at": created_at_str,
    }


async def get_recommendation_history(days: int = 7) -> list[dict]:
    """Return recommendation history for the last ``days`` days, grouped by date.

    Rows are selected when ``created_at`` is on or after the local date
    ``days`` days ago (compared against a "YYYY-MM-DD" string). For each
    (date, ts_code) pair only the row with the highest id is kept.

    Returns:
        A list of dicts with keys ``date``, ``count``, ``buy_count``,
        ``avg_score`` and ``recommendations``, sorted by ``date`` string
        descending. Within a group, recommendations keep the SQL order
        (created_at desc, score desc). Rows without created_at are grouped
        under ``"unknown"``.
    """
    from datetime import timedelta

    from sqlalchemy import text

    start = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")

    # All rows since `start`, de-duplicated per day and ts_code (latest id wins).
    stmt = text(
        "SELECT * FROM recommendations "
        "WHERE created_at >= :start "
        "AND id IN ("
        " SELECT MAX(id) FROM recommendations "
        " WHERE created_at >= :start "
        " GROUP BY date(created_at), ts_code"
        ") "
        "ORDER BY created_at DESC, score DESC"
    )
    async with get_db() as db:
        result = await db.execute(stmt, {"start": start})
        rows = result.fetchall()

    # Group by date; dict insertion keeps each group's SQL ordering.
    grouped: dict[str, list[dict]] = {}
    for row in rows:
        rec_dict = _history_row_to_dict(row._mapping)
        created_at_str = rec_dict["created_at"]
        date_str = created_at_str[:10] if created_at_str else "unknown"
        grouped.setdefault(date_str, []).append(rec_dict)

    # Every group holds at least one record, so the average is always defined.
    result_list = []
    for date_str in sorted(grouped, reverse=True):
        recs = grouped[date_str]
        result_list.append({
            "date": date_str,
            "count": len(recs),
            "buy_count": sum(1 for r in recs if r["signal"] == "BUY"),
            "avg_score": round(sum(r["score"] for r in recs) / len(recs), 1),
            "recommendations": recs,
        })

    return result_list
|
|
|
|
|
|
def _score_to_level_static(score: float) -> str:
|
|
"""根据评分确定推荐等级"""
|
|
if score >= 75:
|
|
return "强烈推荐"
|
|
elif score >= 60:
|
|
return "推荐"
|
|
elif score >= 45:
|
|
return "关注"
|
|
else:
|
|
return "观望"
|
|
|
|
|
|
async def _save_market_temp(db, mt) -> None:
    """Write one market-temperature row with INSERT OR REPLACE."""
    from sqlalchemy import text

    # INSERT OR REPLACE lets a repeated scan overwrite the earlier row.
    # NOTE(review): this relies on a unique key, presumably trade_date; confirm
    # against the market_temperature schema.
    stmt = text(
        "INSERT OR REPLACE INTO market_temperature "
        "(trade_date, up_count, down_count, limit_up_count, limit_down_count, "
        "max_streak, broken_rate, temperature) "
        "VALUES (:td, :up, :down, :lu, :ld, :ms, :br, :temp)"
    )
    await db.execute(stmt, {
        "td": mt.trade_date,
        "up": mt.up_count,
        "down": mt.down_count,
        "lu": mt.limit_up_count,
        "ld": mt.limit_down_count,
        "ms": mt.max_streak,
        "br": mt.broken_rate,
        "temp": mt.temperature,
    })


async def _save_sector_heat(db, sectors, trade_date_val) -> None:
    """Insert ``sectors`` under ``trade_date_val``, replacing that date's old rows."""
    from sqlalchemy import text

    # Clear rows for the same trade date first so rescans do not duplicate them.
    # With an empty trade date nothing is cleared.
    if trade_date_val:
        await db.execute(
            text("DELETE FROM sector_heat WHERE trade_date = :td"),
            {"td": trade_date_val},
        )
    for sector in sectors:
        await db.execute(tables.sector_heat_table.insert().values(
            sector_code=sector.sector_code,
            sector_name=sector.sector_name,
            pct_change=sector.pct_change,
            capital_inflow=sector.capital_inflow,
            limit_up_count=sector.limit_up_count,
            heat_score=sector.heat_score,
            trade_date=trade_date_val,
        ))


async def _save_recommendation_rows(db, recommendations) -> None:
    """Replace today's recommendation rows with ``recommendations``."""
    import json

    from sqlalchemy import text

    # Rows are matched on date(created_at) against today's date from
    # datetime.now(), so every batch saved earlier today is removed, whatever
    # its scan_session.
    today_str = datetime.now().strftime("%Y-%m-%d")
    await db.execute(
        text("DELETE FROM recommendations WHERE date(created_at) = :today"),
        {"today": today_str},
    )
    for rec in recommendations:
        await db.execute(tables.recommendations_table.insert().values(
            ts_code=rec.ts_code,
            name=rec.name,
            sector=rec.sector,
            score=rec.score,
            market_temp_score=rec.market_temp_score,
            sector_score=rec.sector_score,
            capital_score=rec.capital_score,
            technical_score=rec.technical_score,
            position_score=rec.position_score,
            valuation_score=rec.valuation_score,
            signal=rec.signal,
            entry_price=rec.entry_price,
            target_price=rec.target_price,
            stop_loss=rec.stop_loss,
            reasons=json.dumps(rec.reasons, ensure_ascii=False),
            llm_analysis=rec.llm_analysis,
            strategy=rec.strategy,
            entry_signal_type=rec.entry_signal_type,
            llm_score=rec.llm_score,
            scan_session=rec.scan_session,
        ))


async def _save_to_db(result: dict) -> None:
    """Persist a screening result to the database (best-effort).

    Writes the market temperature (if any), the sector-heat snapshot and the
    recommendations inside one ``get_db()`` context, then commits once.

    Any exception is logged together with its traceback and swallowed, so a
    storage failure never propagates to the caller.
    """
    recommendations = result.get("recommendations", [])
    try:
        async with get_db() as db:
            mt = result.get("market_temp")
            if mt:
                await _save_market_temp(db, mt)

            # Sector rows reuse the market-temperature trade date ("" if absent).
            trade_date_val = mt.trade_date if mt else ""
            await _save_sector_heat(db, result.get("hot_sectors", []), trade_date_val)

            await _save_recommendation_rows(db, recommendations)

            await db.commit()
            logger.info("已保存 %d 条推荐到数据库", len(recommendations))
    except Exception as e:
        # logger.exception keeps the traceback that logger.error(f"...") lost.
        logger.exception("保存推荐到数据库失败: %s", e)
|
|
|
|
|
|
def _row_to_recommendation(r) -> "Recommendation":
    """Build a ``Recommendation`` from one ``recommendations`` row mapping."""
    import json

    # These columns are read with .get(); presumably they may be missing in
    # older schemas. Only a missing value (None) falls back to the neutral 50.
    # A genuine score of 0 is kept (the previous `or 50` turned it into 50).
    position = r.get("position_score")
    valuation = r.get("valuation_score")

    return Recommendation(
        ts_code=r["ts_code"],
        name=r["name"],
        sector=r["sector"] or "",
        score=r["score"] or 0,
        market_temp_score=r["market_temp_score"] or 0,
        sector_score=r["sector_score"] or 0,
        capital_score=r["capital_score"] or 0,
        technical_score=r["technical_score"] or 0,
        position_score=50 if position is None else position,
        valuation_score=50 if valuation is None else valuation,
        signal=r["signal"] or "HOLD",
        entry_price=r["entry_price"],
        target_price=r["target_price"],
        stop_loss=r["stop_loss"],
        reasons=json.loads(r["reasons"]) if r["reasons"] else [],
        llm_analysis=r.get("llm_analysis") or "",
        strategy=r.get("strategy") or "trend_breakout",
        entry_signal_type=r.get("entry_signal_type") or "none",
        llm_score=r.get("llm_score"),
        scan_session=r["scan_session"] or "",
    )


async def _load_today_from_db() -> dict:
    """Load the most recently persisted screening result from the database.

    Despite the name, this is not limited to today. It reads:
      * the market_temperature row with the greatest trade_date, and
      * the recommendations from the latest calendar date that has any,
        de-duplicated per ts_code (highest id wins) and ordered by score
        descending.

    Returns:
        A dict with keys ``market_temp`` (or None), ``hot_sectors`` and
        ``capital_filtered`` (always empty lists here) and
        ``recommendations``. If the database read fails, the error is logged
        with its traceback and the same shape is returned with empty values.
    """
    try:
        async with get_db() as db:
            from sqlalchemy import text

            # Market temperature: newest row by trade_date.
            result = await db.execute(
                text("SELECT * FROM market_temperature ORDER BY trade_date DESC LIMIT 1")
            )
            mt_row = result.fetchone()
            market_temp = None
            if mt_row:
                m = mt_row._mapping
                market_temp = MarketTemperature(
                    trade_date=m["trade_date"],
                    up_count=m["up_count"],
                    down_count=m["down_count"],
                    limit_up_count=m["limit_up_count"],
                    limit_down_count=m["limit_down_count"],
                    max_streak=m["max_streak"],
                    broken_rate=m["broken_rate"],
                    temperature=m["temperature"],
                )

            # Recommendations: latest date that has data, de-duplicated by ts_code.
            result = await db.execute(
                text("SELECT * FROM recommendations "
                     "WHERE date(created_at) = (SELECT date(created_at) FROM recommendations ORDER BY created_at DESC LIMIT 1) "
                     "AND id IN (SELECT MAX(id) FROM recommendations "
                     " WHERE date(created_at) = (SELECT date(created_at) FROM recommendations ORDER BY created_at DESC LIMIT 1) "
                     " GROUP BY ts_code) "
                     "ORDER BY score DESC")
            )
            recommendations = [_row_to_recommendation(row._mapping) for row in result.fetchall()]

            return {
                "market_temp": market_temp,
                "hot_sectors": [],
                "capital_filtered": [],
                "recommendations": recommendations,
            }
    except Exception as e:
        # logger.exception keeps the traceback that logger.error(f"...") lost.
        logger.exception("从数据库加载推荐失败: %s", e)
        return {"market_temp": None, "hot_sectors": [], "capital_filtered": [], "recommendations": []}
|
|
|
|
|
|
async def _load_sectors_from_db() -> list[SectorInfo]:
    """Load the newest sector-heat snapshot from the database.

    Reads the rows for the greatest ``trade_date`` in ``sector_heat``, keeps
    only the highest-id row per ``sector_code``, and orders them by
    ``heat_score`` descending. ``days_continuous`` is not read from the table
    and is always reported as 0; NULL numeric columns become 0.

    Returns:
        The list of ``SectorInfo``. If the read fails, the error is logged
        with its traceback and an empty list is returned.
    """
    try:
        async with get_db() as db:
            from sqlalchemy import text

            result = await db.execute(
                text(
                    "SELECT * FROM sector_heat "
                    "WHERE trade_date = (SELECT MAX(trade_date) FROM sector_heat) "
                    "AND id IN ("
                    " SELECT MAX(id) FROM sector_heat "
                    " WHERE trade_date = (SELECT MAX(trade_date) FROM sector_heat) "
                    " GROUP BY sector_code"
                    ") "
                    "ORDER BY heat_score DESC"
                )
            )
            return [
                SectorInfo(
                    sector_code=r["sector_code"],
                    sector_name=r["sector_name"],
                    pct_change=r["pct_change"] or 0,
                    capital_inflow=r["capital_inflow"] or 0,
                    limit_up_count=r["limit_up_count"] or 0,
                    days_continuous=0,
                    heat_score=r["heat_score"] or 0,
                )
                for r in (row._mapping for row in result.fetchall())
            ]
    except Exception as e:
        # logger.exception keeps the traceback that logger.error(f"...") lost.
        logger.exception("从数据库加载板块数据失败: %s", e)
        return []
|