"""Recommendation engine.

Manages recommendation state and provides the recommendation query interface.
Screening results are persisted to the database, and historical
recommendations are served from there.
"""
import asyncio
import json
import logging
from datetime import datetime, timedelta

from sqlalchemy import text

from app.engine.screener import run_screening
from app.data.models import Recommendation, MarketTemperature, SectorInfo
from app.db.database import get_db
from app.db import tables
from app.config import settings

logger = logging.getLogger(__name__)

# Latest screening result, cached in process memory.
# NOTE(review): this never expires, so a long-running process keeps serving
# the previous refresh (possibly from an earlier day) until the next refresh —
# confirm that is intended.
_latest_result: dict | None = None

# Strong references to in-flight background AI-analysis tasks. The asyncio
# event loop only keeps weak references to tasks, so a task nobody references
# may be garbage-collected before it completes.
_background_tasks: set[asyncio.Task] = set()

# Value used for position/valuation sub-scores when the column is NULL or
# absent from the row.
_NEUTRAL_SUB_SCORE = 50


def _on_analysis_done(task: asyncio.Task) -> None:
    """Release the finished task's reference and log any exception it raised."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("AI 深度分析任务失败", exc_info=exc)


async def refresh_recommendations(trade_date: str | None = None, scan_session: str = "manual") -> dict:
    """Run a fresh screening, cache it in memory and persist it.

    Args:
        trade_date: Passed straight through to ``run_screening``.
        scan_session: Label stamped onto every recommendation.

    Returns:
        The dict returned by ``run_screening``, with ``scan_session`` and
        ``created_at`` set on each item of ``result["recommendations"]``.
    """
    global _latest_result
    result = await run_screening(trade_date)

    # Stamp every recommendation with its scan session and creation time.
    for rec in result.get("recommendations", []):
        rec.scan_session = scan_session
        rec.created_at = datetime.now()

    _latest_result = result

    # Persist; _save_to_db logs and swallows its own errors (best effort).
    await _save_to_db(result)

    # Fire-and-forget AI deep analysis so the caller is not blocked.
    if settings.deepseek_api_key:
        # Imported lazily: the LLM module is only loaded when a key is configured.
        from app.llm.analysis_agent import analyze_recommendations

        task = asyncio.create_task(analyze_recommendations(result))
        # Keep a strong reference until the task finishes (see _background_tasks).
        _background_tasks.add(task)
        task.add_done_callback(_on_analysis_done)

    return result


async def get_latest_recommendations() -> dict:
    """Return the in-memory latest result, or today's recommendations from the DB."""
    if _latest_result:
        return _latest_result
    # Nothing in memory (e.g. after a restart): fall back to today's DB rows.
    return await _load_today_from_db()


async def get_latest_sectors() -> list[SectorInfo]:
    """Return the latest sector-heat data (read-only cache; never triggers a scan)."""
    if _latest_result and _latest_result.get("hot_sectors"):
        return _latest_result["hot_sectors"]
    # Not in memory: load the most recent trade date from the DB.
    return await _load_sectors_from_db()


def _or_default(value, default):
    """Return *value*, or *default* when it is ``None``.

    Unlike ``value or default`` this keeps legitimate falsy values such as 0.
    """
    return default if value is None else value


def _row_fields(r) -> dict:
    """Map a ``recommendations`` row mapping to Recommendation field values.

    Shared by the history and today loaders so both apply identical defaults.
    ``r.get`` is used for columns that may be absent from the row
    (presumably columns added in later schema versions — verify).
    """
    return {
        "ts_code": r["ts_code"],
        "name": r["name"],
        "sector": r["sector"] or "",
        "score": r["score"] or 0,
        "market_temp_score": r["market_temp_score"] or 0,
        "sector_score": r["sector_score"] or 0,
        "capital_score": r["capital_score"] or 0,
        "technical_score": r["technical_score"] or 0,
        # Only NULL/missing falls back to neutral; a real 0 score is kept.
        "position_score": _or_default(r.get("position_score"), _NEUTRAL_SUB_SCORE),
        "valuation_score": _or_default(r.get("valuation_score"), _NEUTRAL_SUB_SCORE),
        "signal": r["signal"] or "HOLD",
        "entry_price": r["entry_price"],
        "target_price": r["target_price"],
        "stop_loss": r["stop_loss"],
        # reasons is stored as a JSON-encoded list (see _save_to_db).
        "reasons": json.loads(r["reasons"]) if r["reasons"] else [],
        "llm_analysis": r.get("llm_analysis") or "",
        "strategy": r.get("strategy") or "trend_breakout",
        "entry_signal_type": r.get("entry_signal_type") or "none",
        "llm_score": r.get("llm_score"),
        "scan_session": r["scan_session"] or "",
    }


async def get_recommendation_history(days: int = 7) -> list[dict]:
    """Return recommendations from the last *days* days, grouped by date.

    Within each calendar day only the newest row (highest id) per ts_code is
    kept. Groups are ordered newest date first, and each group carries
    ``count``, ``buy_count`` and ``avg_score`` summaries.
    """
    start = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")

    async with get_db() as db:
        # Dedupe by (day, ts_code), keeping the latest row of each pair.
        stmt = text(
            "SELECT * FROM recommendations "
            "WHERE created_at >= :start "
            "AND id IN ("
            "  SELECT MAX(id) FROM recommendations "
            "  WHERE created_at >= :start "
            "  GROUP BY date(created_at), ts_code"
            ") "
            "ORDER BY created_at DESC, score DESC"
        )
        result = await db.execute(stmt, {"start": start})
        rows = result.fetchall()

    grouped: dict[str, list[dict]] = {}
    for row in rows:
        r = row._mapping
        # In SQLite created_at comes back as "YYYY-MM-DD HH:MM:SS...".
        ca = r["created_at"]
        if ca:
            date_str = str(ca)[:10]  # first 10 chars are the date part
            created_at_str = str(ca)
        else:
            date_str = "unknown"
            created_at_str = None

        fields = _row_fields(r)
        rec_dict = {
            **fields,
            "level": _score_to_level_static(fields["score"]),
            "risk_note": "",
            "created_at": created_at_str,
        }
        grouped.setdefault(date_str, []).append(rec_dict)

    # Flatten into a list ordered by date, newest first.
    result_list = []
    for date_str in sorted(grouped, reverse=True):
        recs = grouped[date_str]
        buy_count = sum(1 for rec in recs if rec["signal"] == "BUY")
        avg_score = round(sum(rec["score"] for rec in recs) / len(recs), 1) if recs else 0
        result_list.append({
            "date": date_str,
            "count": len(recs),
            "buy_count": buy_count,
            "avg_score": avg_score,
            "recommendations": recs,
        })
    return result_list


def _score_to_level_static(score: float) -> str:
    """Map a composite score to its recommendation level label."""
    if score >= 75:
        return "强烈推荐"
    elif score >= 60:
        return "推荐"
    elif score >= 45:
        return "关注"
    else:
        return "观望"


async def _save_to_db(result: dict):
    """Persist market temperature, sector heat and recommendations (best effort).

    Errors are logged with traceback and swallowed so a DB failure never
    breaks the refresh path.
    """
    try:
        async with get_db() as db:
            # Market temperature.
            mt = result.get("market_temp")
            if mt:
                # INSERT OR REPLACE so a repeated scan updates the existing row.
                stmt = text(
                    "INSERT OR REPLACE INTO market_temperature "
                    "(trade_date, up_count, down_count, limit_up_count, limit_down_count, "
                    "max_streak, broken_rate, temperature) "
                    "VALUES (:td, :up, :down, :lu, :ld, :ms, :br, :temp)"
                )
                await db.execute(stmt, {
                    "td": mt.trade_date, "up": mt.up_count, "down": mt.down_count,
                    "lu": mt.limit_up_count, "ld": mt.limit_down_count,
                    "ms": mt.max_streak, "br": mt.broken_rate, "temp": mt.temperature,
                })

            # Sector heat: clear rows for the same trade_date first to avoid duplicates.
            # NOTE(review): without market_temp, sectors are written with
            # trade_date "" and never cleaned up by this DELETE.
            trade_date_val = mt.trade_date if mt else ""
            if trade_date_val:
                await db.execute(
                    text("DELETE FROM sector_heat WHERE trade_date = :td"),
                    {"td": trade_date_val},
                )
            for sector in result.get("hot_sectors", []):
                stmt = tables.sector_heat_table.insert().values(
                    sector_code=sector.sector_code,
                    sector_name=sector.sector_name,
                    pct_change=sector.pct_change,
                    capital_inflow=sector.capital_inflow,
                    limit_up_count=sector.limit_up_count,
                    heat_score=sector.heat_score,
                    trade_date=trade_date_val,
                )
                await db.execute(stmt)

            # Recommendations: clear today's rows first to avoid duplicates.
            # NOTE(review): created_at is not passed to the insert below, so it
            # comes from the table default; if that default is UTC (e.g. SQLite
            # CURRENT_TIMESTAMP) it may not match this local-date comparison.
            today_str = datetime.now().strftime("%Y-%m-%d")
            await db.execute(
                text("DELETE FROM recommendations WHERE date(created_at) = :today"),
                {"today": today_str},
            )
            for rec in result.get("recommendations", []):
                stmt = tables.recommendations_table.insert().values(
                    ts_code=rec.ts_code,
                    name=rec.name,
                    sector=rec.sector,
                    score=rec.score,
                    market_temp_score=rec.market_temp_score,
                    sector_score=rec.sector_score,
                    capital_score=rec.capital_score,
                    technical_score=rec.technical_score,
                    position_score=rec.position_score,
                    valuation_score=rec.valuation_score,
                    signal=rec.signal,
                    entry_price=rec.entry_price,
                    target_price=rec.target_price,
                    stop_loss=rec.stop_loss,
                    reasons=json.dumps(rec.reasons, ensure_ascii=False),
                    llm_analysis=rec.llm_analysis,
                    strategy=rec.strategy,
                    entry_signal_type=rec.entry_signal_type,
                    llm_score=rec.llm_score,
                    scan_session=rec.scan_session,
                )
                await db.execute(stmt)

            await db.commit()
            logger.info(f"已保存 {len(result.get('recommendations', []))} 条推荐到数据库")
    except Exception as e:
        logger.exception(f"保存推荐到数据库失败: {e}")


async def _load_today_from_db() -> dict:
    """Rebuild a screening-result dict from today's persisted recommendations.

    ``hot_sectors`` and ``capital_filtered`` are always empty here. On any
    error an empty result is returned after logging.
    """
    today = datetime.now().strftime("%Y-%m-%d")
    try:
        async with get_db() as db:
            # Most recently created market-temperature row (not restricted to today).
            result = await db.execute(
                text("SELECT * FROM market_temperature ORDER BY created_at DESC LIMIT 1")
            )
            mt_row = result.fetchone()
            market_temp = None
            if mt_row:
                m = mt_row._mapping
                market_temp = MarketTemperature(
                    trade_date=m["trade_date"],
                    up_count=m["up_count"],
                    down_count=m["down_count"],
                    limit_up_count=m["limit_up_count"],
                    limit_down_count=m["limit_down_count"],
                    max_streak=m["max_streak"],
                    broken_rate=m["broken_rate"],
                    temperature=m["temperature"],
                )

            # Today's recommendations, deduped by ts_code (latest row wins).
            result = await db.execute(
                text("SELECT * FROM recommendations WHERE date(created_at) = :today "
                     "AND id IN (SELECT MAX(id) FROM recommendations WHERE date(created_at) = :today GROUP BY ts_code) "
                     "ORDER BY score DESC"),
                {"today": today}
            )
            rows = result.fetchall()
            recommendations = [Recommendation(**_row_fields(row._mapping)) for row in rows]

            return {
                "market_temp": market_temp,
                "hot_sectors": [],
                "capital_filtered": [],
                "recommendations": recommendations,
            }
    except Exception as e:
        logger.exception(f"从数据库加载推荐失败: {e}")
        return {"market_temp": None, "hot_sectors": [], "capital_filtered": [], "recommendations": []}


async def _load_sectors_from_db() -> list[SectorInfo]:
    """Load sector heat for the latest trade_date, one row per sector, hottest first.

    Returns an empty list (after logging) on any error.
    """
    try:
        async with get_db() as db:
            result = await db.execute(
                text(
                    "SELECT * FROM sector_heat "
                    "WHERE trade_date = (SELECT MAX(trade_date) FROM sector_heat) "
                    "AND id IN ("
                    "  SELECT MAX(id) FROM sector_heat "
                    "  WHERE trade_date = (SELECT MAX(trade_date) FROM sector_heat) "
                    "  GROUP BY sector_code"
                    ") "
                    "ORDER BY heat_score DESC"
                )
            )
            rows = result.fetchall()
            sectors = []
            for row in rows:
                r = row._mapping
                sectors.append(SectorInfo(
                    sector_code=r["sector_code"],
                    sector_name=r["sector_name"],
                    pct_change=r["pct_change"] or 0,
                    capital_inflow=r["capital_inflow"] or 0,
                    limit_up_count=r["limit_up_count"] or 0,
                    days_continuous=0,  # not persisted in sector_heat
                    heat_score=r["heat_score"] or 0,
                ))
            return sectors
    except Exception as e:
        logger.exception(f"从数据库加载板块数据失败: {e}")
        return []