"""Recommendation engine.

Manages recommendation state and provides the recommendation query interface.
Persists screening results to the database and manages recommendation history.
"""
import asyncio
import json
import logging
from datetime import datetime, timedelta

from sqlalchemy import text

from app.engine.screener import run_screening
from app.data.models import Recommendation, MarketTemperature, SectorInfo
from app.db.database import get_db
from app.db import tables

logger = logging.getLogger(__name__)

# Latest recommendation result held in memory (set by refresh_recommendations).
_latest_result: dict | None = None

# Scan lock: prevents two scans from being triggered at the same time.
_scan_lock = asyncio.Lock()
_scan_running = False

# Latest tracking row (highest tracking id) per recommendation, restricted to
# recommendations that still exist. _save_to_db deletes and re-inserts a
# same-day recommendation when a stock is rescanned, which can leave orphaned
# tracking rows behind; joining on `recommendations` keeps every tracking-based
# statistic over the same population as the "tracked" count.
_LATEST_TRACKING_SQL = (
    "SELECT t.recommendation_id, t.pct_from_entry AS latest_pct "
    "FROM recommendation_tracking t "
    "INNER JOIN recommendations r ON r.id = t.recommendation_id "
    "INNER JOIN ("
    " SELECT recommendation_id, MAX(id) AS max_id "
    " FROM recommendation_tracking GROUP BY recommendation_id"
    ") latest ON t.id = latest.max_id"
)


def _empty_result() -> dict:
    """Return a fresh, empty recommendation result with the standard keys."""
    return {"market_temp": None, "hot_sectors": [], "capital_filtered": [], "recommendations": []}


async def refresh_recommendations(trade_date: str | None = None, scan_session: str = "manual") -> dict:
    """Re-run screening and refresh the recommendation list.

    Guarded by ``_scan_lock``: if a scan is already in progress this call is
    skipped and the cached in-memory result (or an empty result) is returned.

    Args:
        trade_date: Passed unchanged to ``run_screening``.
        scan_session: Label stamped onto every recommendation in the result.

    Returns:
        The screening result dict, which is also cached in ``_latest_result``,
        persisted via ``_save_to_db`` and followed by a tracking update.
    """
    global _latest_result, _scan_running

    # There is no await between this check and acquiring the lock below, so on
    # a single event loop the check-then-acquire cannot be interleaved.
    if _scan_lock.locked():
        logger.warning("扫描已在执行中,跳过本次触发")
        return _latest_result or _empty_result()

    async with _scan_lock:
        _scan_running = True
        try:
            result = await run_screening(trade_date)

            # Tag every recommendation with the scan session and a timestamp.
            for rec in result.get("recommendations", []):
                rec.scan_session = scan_session
                rec.created_at = datetime.now()

            _latest_result = result

            # Persist to the database.
            await _save_to_db(result)

            # Update tracking of earlier recommendations (their follow-up performance).
            await _update_tracking()

            return result
        finally:
            _scan_running = False


async def _update_tracking():
    """Record the latest close for recent, still-active recommendations.

    Selects up to 50 of the most recent recommendations that have a positive
    entry price and no "closed" tracking row, prices them with the close for
    tushare's latest trade date, and inserts at most one tracking row per
    (recommendation, trade date). A row is marked "closed" when the close has
    reached the target price or fallen to the stop loss. Errors are logged and
    swallowed (best-effort).
    """
    try:
        # Imported lazily, as in the original module layout.
        from app.data.tushare_client import tushare_client

        trade_date = tushare_client.get_latest_trade_date()

        async with get_db() as db:
            # Active recommendations: have an entry price and are not yet closed.
            result = await db.execute(
                text(
                    "SELECT id, ts_code, entry_price, target_price, stop_loss "
                    "FROM recommendations "
                    "WHERE entry_price IS NOT NULL "
                    "AND entry_price > 0 "
                    "AND id NOT IN (SELECT DISTINCT recommendation_id FROM recommendation_tracking WHERE status = 'closed') "
                    "AND date(created_at) <= date(:today) "
                    "ORDER BY created_at DESC LIMIT 50"
                ),
                {"today": datetime.now().strftime("%Y-%m-%d")},
            )
            rows = result.fetchall()
            if not rows:
                return

            # Closing prices on trade_date, restricted to the codes being tracked.
            codes = {r[1] for r in rows}
            daily_all = tushare_client.get_daily_all(trade_date)
            price_map = {}
            if daily_all is not None and not daily_all.empty:
                for ts_code, close in zip(daily_all["ts_code"], daily_all["close"]):
                    if ts_code in codes:
                        price_map[ts_code] = close

            tracked = 0
            for rec_id, ts_code, entry_price, target_price, stop_loss in rows:
                current_price = price_map.get(ts_code)
                if current_price is None or entry_price is None or entry_price <= 0:
                    continue

                pct = round((current_price - entry_price) / entry_price * 100, 2)
                # bool(): a missing/zero target or stop would otherwise store
                # None/0 (or a numpy bool) in the boolean columns.
                hit_target = bool(target_price and current_price >= target_price)
                hit_stop = bool(stop_loss and current_price <= stop_loss)
                status = "closed" if (hit_target or hit_stop) else "active"

                # Skip if this recommendation was already tracked for trade_date.
                exists = await db.execute(
                    text(
                        "SELECT id FROM recommendation_tracking "
                        "WHERE recommendation_id = :rid AND track_date = :td"
                    ),
                    {"rid": rec_id, "td": trade_date},
                )
                if exists.fetchone():
                    continue

                await db.execute(
                    tables.recommendation_tracking_table.insert().values(
                        recommendation_id=rec_id,
                        track_date=trade_date,
                        current_price=current_price,
                        pct_from_entry=pct,
                        hit_target=hit_target,
                        hit_stop_loss=hit_stop,
                        status=status,
                    )
                )
                tracked += 1

            await db.commit()
            if tracked > 0:
                logger.info(f"已更新 {tracked} 条推荐跟踪记录")

    except Exception as e:
        logger.exception(f"更新推荐跟踪失败: {e}")


async def get_performance_stats() -> dict:
    """Return win-rate statistics for past recommendations.

    Every tracking-based figure (winning, avg_return, hit counts, details) only
    considers tracking rows whose recommendation still exists, so ``winning``
    is always a subset of ``tracked`` and ``win_rate`` stays within 0-100.
    A recommendation's result is its latest tracking row (highest id).

    Returns:
        Dict with total_recommendations, tracked, winning, win_rate (percent,
        one decimal), avg_return (percent, two decimals), hit_target_count,
        hit_stop_count and details (up to 20 most recent tracked
        recommendations). All-zero values are returned if any query fails.
    """
    try:
        async with get_db() as db:
            # Total number of recommendations.
            result = await db.execute(
                text("SELECT COUNT(DISTINCT id) FROM recommendations")
            )
            total = result.scalar() or 0

            # Recommendations that have at least one tracking row.
            result = await db.execute(
                text(
                    "SELECT COUNT(DISTINCT r.id) FROM recommendations r "
                    "INNER JOIN recommendation_tracking t ON t.recommendation_id = r.id"
                )
            )
            tracked = result.scalar() or 0

            # Winners: latest tracked pct is positive (positive = profit, negative = loss).
            result = await db.execute(
                text(f"SELECT COUNT(*) FROM ({_LATEST_TRACKING_SQL}) WHERE latest_pct > 0")
            )
            winning = result.scalar() or 0

            # Average return based on each recommendation's latest tracked pct.
            result = await db.execute(
                text(f"SELECT AVG(latest_pct) FROM ({_LATEST_TRACKING_SQL})")
            )
            avg_return = result.scalar()
            avg_return = round(float(avg_return), 2) if avg_return else 0

            # Recommendations that reached their target price.
            result = await db.execute(
                text(
                    "SELECT COUNT(DISTINCT t.recommendation_id) FROM recommendation_tracking t "
                    "INNER JOIN recommendations r ON r.id = t.recommendation_id "
                    "WHERE t.hit_target = 1"
                )
            )
            hit_target_count = result.scalar() or 0

            # Recommendations that triggered their stop loss.
            result = await db.execute(
                text(
                    "SELECT COUNT(DISTINCT t.recommendation_id) FROM recommendation_tracking t "
                    "INNER JOIN recommendations r ON r.id = t.recommendation_id "
                    "WHERE t.hit_stop_loss = 1"
                )
            )
            hit_stop_count = result.scalar() or 0

            # Details of the most recently tracked recommendations.
            result = await db.execute(
                text(
                    "SELECT r.ts_code, r.name, r.signal, r.entry_price, "
                    "  r.target_price, r.stop_loss, r.entry_signal_type, r.score, "
                    "  t.pct_from_entry, t.current_price, t.track_date, t.hit_target, t.hit_stop_loss, "
                    "  r.created_at "
                    "FROM recommendations r "
                    "INNER JOIN recommendation_tracking t ON t.recommendation_id = r.id "
                    "INNER JOIN ("
                    "  SELECT recommendation_id, MAX(id) as max_id "
                    "  FROM recommendation_tracking GROUP BY recommendation_id"
                    ") latest ON t.id = latest.max_id "
                    "ORDER BY r.created_at DESC LIMIT 20"
                )
            )
            details = []
            for row in result.fetchall():
                r = row._mapping
                details.append({
                    "ts_code": r["ts_code"],
                    "name": r["name"],
                    "signal": r["signal"],
                    "entry_signal_type": r["entry_signal_type"],
                    "score": r["score"],
                    "entry_price": r["entry_price"],
                    "target_price": r["target_price"],
                    "stop_loss": r["stop_loss"],
                    "current_price": r["current_price"],
                    "pct_from_entry": r["pct_from_entry"],
                    "track_date": r["track_date"],
                    "hit_target": bool(r["hit_target"]),
                    "hit_stop_loss": bool(r["hit_stop_loss"]),
                    "created_at": str(r["created_at"])[:10] if r["created_at"] else "",
                })

            win_rate = round(winning / tracked * 100, 1) if tracked > 0 else 0

            return {
                "total_recommendations": total,
                "tracked": tracked,
                "winning": winning,
                "win_rate": win_rate,
                "avg_return": avg_return,
                "hit_target_count": hit_target_count,
                "hit_stop_count": hit_stop_count,
                "details": details,
            }
    except Exception as e:
        logger.exception(f"获取胜率统计失败: {e}")
        return {
            "total_recommendations": 0, "tracked": 0, "winning": 0,
            "win_rate": 0, "avg_return": 0, "hit_target_count": 0,
            "hit_stop_count": 0, "details": [],
        }


async def get_latest_recommendations() -> dict:
    """Return the latest recommendation result.

    Uses the in-memory result when present; otherwise loads the most recent
    day's recommendations from the database.
    """
    if _latest_result:
        return _latest_result
    return await _load_today_from_db()


async def get_latest_sectors() -> list[SectorInfo]:
    """Return the latest sector heat data (read-only; never triggers a scan)."""
    if _latest_result and _latest_result.get("hot_sectors"):
        return _latest_result["hot_sectors"]
    # Nothing in memory: fall back to the database.
    return await _load_sectors_from_db()


async def get_recommendation_history(days: int = 7) -> list[dict]:
    """Return recommendations from the last ``days`` days, grouped by date.

    Within each calendar day only the newest row per ts_code is kept.

    Returns:
        One dict per date (newest first) with ``date``, ``count``,
        ``buy_count``, ``avg_score`` and the ``recommendations`` list.
    """
    start = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")

    async with get_db() as db:
        # All recommendations since `start`, de-duplicated per (day, ts_code)
        # by keeping the highest id.
        stmt = text(
            "SELECT * FROM recommendations "
            "WHERE created_at >= :start "
            "AND id IN ("
            "  SELECT MAX(id) FROM recommendations "
            "  WHERE created_at >= :start "
            "  GROUP BY date(created_at), ts_code"
            ") "
            "ORDER BY created_at DESC, score DESC"
        )
        result = await db.execute(stmt, {"start": start})
        rows = result.fetchall()

    # Group by date.
    grouped: dict[str, list[dict]] = {}
    for row in rows:
        r = row._mapping
        # SQLite stores created_at as a "YYYY-MM-DD HH:MM:SS" string; the first
        # 10 characters are the date.
        ca = r["created_at"]
        if ca:
            date_str = str(ca)[:10]
            created_at_str = str(ca)
        else:
            date_str = "unknown"
            created_at_str = None

        rec_dict = {
            "ts_code": r["ts_code"],
            "name": r["name"],
            "sector": r["sector"] or "",
            "score": r["score"] or 0,
            "level": _score_to_level_static(r["score"] or 0),
            "signal": r["signal"] or "HOLD",
            "market_temp_score": r["market_temp_score"] or 0,
            "sector_score": r["sector_score"] or 0,
            "capital_score": r["capital_score"] or 0,
            "technical_score": r["technical_score"] or 0,
            "supply_demand_score": r.get("supply_demand_score") or 0,
            "price_action_score": r.get("price_action_score") or 0,
            "position_score": r.get("position_score") or 50,
            "valuation_score": r.get("valuation_score") or 50,
            "entry_price": r["entry_price"],
            "target_price": r["target_price"],
            "stop_loss": r["stop_loss"],
            "reasons": json.loads(r["reasons"]) if r["reasons"] else [],
            "risk_note": "",
            "strategy": r.get("strategy") or "trend_breakout",
            "entry_signal_type": r.get("entry_signal_type") or "none",
            "llm_analysis": r.get("llm_analysis") or "",
            "llm_score": r.get("llm_score"),
            "scan_session": r["scan_session"] or "",
            "created_at": created_at_str,
        }
        grouped.setdefault(date_str, []).append(rec_dict)

    # Convert to a list, newest date first.
    result_list = []
    for date_str in sorted(grouped.keys(), reverse=True):
        recs = grouped[date_str]
        buy_count = sum(1 for r in recs if r["signal"] == "BUY")
        avg_score = round(sum(r["score"] for r in recs) / len(recs), 1) if recs else 0
        result_list.append({
            "date": date_str,
            "count": len(recs),
            "buy_count": buy_count,
            "avg_score": avg_score,
            "recommendations": recs,
        })

    return result_list


def _score_to_level_static(score: float) -> str:
    """Map a score to its recommendation level label (thresholds 75/60/45)."""
    if score >= 75:
        return "强烈推荐"
    elif score >= 60:
        return "推荐"
    elif score >= 45:
        return "关注"
    else:
        return "观望"


async def _save_to_db(result: dict):
    """Persist a screening result (market temperature, sectors, recommendations).

    Rows for the same trade date (sectors) or the same day and ts_code
    (recommendations) are replaced so repeated scans do not duplicate data.
    Errors are logged and swallowed (best-effort).
    """
    try:
        async with get_db() as db:
            # Market temperature; INSERT OR REPLACE lets a rescan update the row.
            mt = result.get("market_temp")
            if mt:
                stmt = text(
                    "INSERT OR REPLACE INTO market_temperature "
                    "(trade_date, up_count, down_count, limit_up_count, limit_down_count, "
                    "max_streak, broken_rate, temperature) "
                    "VALUES (:td, :up, :down, :lu, :ld, :ms, :br, :temp)"
                )
                await db.execute(stmt, {
                    "td": mt.trade_date,
                    "up": mt.up_count,
                    "down": mt.down_count,
                    "lu": mt.limit_up_count,
                    "ld": mt.limit_down_count,
                    "ms": mt.max_streak,
                    "br": mt.broken_rate,
                    "temp": mt.temperature,
                })

            # Sector heat: clear existing rows for the same trade_date first to
            # avoid duplicates.
            trade_date_val = mt.trade_date if mt else ""
            if trade_date_val:
                await db.execute(
                    text("DELETE FROM sector_heat WHERE trade_date = :td"),
                    {"td": trade_date_val},
                )
            for sector in result.get("hot_sectors", []):
                stmt = tables.sector_heat_table.insert().values(
                    sector_code=sector.sector_code,
                    sector_name=sector.sector_name,
                    pct_change=sector.pct_change,
                    capital_inflow=sector.capital_inflow,
                    limit_up_count=sector.limit_up_count,
                    heat_score=sector.heat_score,
                    stage=sector.stage,
                    days_continuous=sector.days_continuous,
                    member_count=sector.member_count,
                    leading_stocks=json.dumps(sector.leading_stocks, ensure_ascii=False),
                    pct_trend=json.dumps(sector.pct_trend, ensure_ascii=False),
                    turnover_avg=sector.turnover_avg,
                    main_force_ratio=sector.main_force_ratio,
                    trade_date=trade_date_val,
                )
                await db.execute(stmt)

            # Recommendations: delete today's earlier row for the same ts_code
            # so several scans on one day do not create duplicates.
            # NOTE(review): this also orphans any tracking rows pointing at the
            # deleted id; get_performance_stats filters those out.
            today_str = datetime.now().strftime("%Y-%m-%d")
            for rec in result.get("recommendations", []):
                await db.execute(
                    text("DELETE FROM recommendations WHERE date(created_at) = :today AND ts_code = :code"),
                    {"today": today_str, "code": rec.ts_code},
                )
                stmt = tables.recommendations_table.insert().values(
                    ts_code=rec.ts_code,
                    name=rec.name,
                    sector=rec.sector,
                    score=rec.score,
                    market_temp_score=rec.market_temp_score,
                    sector_score=rec.sector_score,
                    capital_score=rec.capital_score,
                    technical_score=rec.technical_score,
                    supply_demand_score=rec.supply_demand_score,
                    price_action_score=rec.price_action_score,
                    position_score=rec.position_score,
                    valuation_score=rec.valuation_score,
                    signal=rec.signal,
                    entry_price=rec.entry_price,
                    target_price=rec.target_price,
                    stop_loss=rec.stop_loss,
                    reasons=json.dumps(rec.reasons, ensure_ascii=False),
                    llm_analysis=rec.llm_analysis,
                    strategy=rec.strategy,
                    entry_signal_type=rec.entry_signal_type,
                    llm_score=rec.llm_score,
                    scan_session=rec.scan_session,
                )
                await db.execute(stmt)

            await db.commit()
            logger.info(f"已保存 {len(result.get('recommendations', []))} 条推荐到数据库")
    except Exception as e:
        logger.exception(f"保存推荐到数据库失败: {e}")


async def _load_today_from_db() -> dict:
    """Load the most recent day's recommendations from the database.

    Despite the name, this loads the latest calendar day that has any
    recommendation rows (not necessarily today). Only rows scoring >= 60 are
    kept, de-duplicated per ts_code (highest id), ordered by score descending.
    The market temperature is the row with the latest trade_date. Returns an
    empty result if loading fails.
    """
    try:
        async with get_db() as db:
            # Market temperature for the latest trade_date.
            result = await db.execute(
                text("SELECT * FROM market_temperature ORDER BY trade_date DESC LIMIT 1")
            )
            mt_row = result.fetchone()
            market_temp = None
            if mt_row:
                m = mt_row._mapping
                market_temp = MarketTemperature(
                    trade_date=m["trade_date"],
                    up_count=m["up_count"],
                    down_count=m["down_count"],
                    limit_up_count=m["limit_up_count"],
                    limit_down_count=m["limit_down_count"],
                    max_streak=m["max_streak"],
                    broken_rate=m["broken_rate"],
                    temperature=m["temperature"],
                )

            # Recommendations from the most recent day with data, one per
            # ts_code, score >= 60 only.
            result = await db.execute(
                text("SELECT * FROM recommendations "
                     "WHERE date(created_at) = (SELECT date(created_at) FROM recommendations ORDER BY created_at DESC LIMIT 1) "
                     "AND score >= 60 "
                     "AND id IN (SELECT MAX(id) FROM recommendations "
                     "  WHERE date(created_at) = (SELECT date(created_at) FROM recommendations ORDER BY created_at DESC LIMIT 1) "
                     "  GROUP BY ts_code) "
                     "ORDER BY score DESC")
            )
            rows = result.fetchall()
            recommendations = []
            for row in rows:
                r = row._mapping
                recommendations.append(Recommendation(
                    ts_code=r["ts_code"],
                    name=r["name"],
                    sector=r["sector"] or "",
                    score=r["score"] or 0,
                    market_temp_score=r["market_temp_score"] or 0,
                    sector_score=r["sector_score"] or 0,
                    capital_score=r["capital_score"] or 0,
                    technical_score=r["technical_score"] or 0,
                    supply_demand_score=r.get("supply_demand_score") or 0,
                    price_action_score=r.get("price_action_score") or 0,
                    position_score=r.get("position_score") or 50,
                    valuation_score=r.get("valuation_score") or 50,
                    signal=r["signal"] or "HOLD",
                    entry_price=r["entry_price"],
                    target_price=r["target_price"],
                    stop_loss=r["stop_loss"],
                    reasons=json.loads(r["reasons"]) if r["reasons"] else [],
                    llm_analysis=r.get("llm_analysis") or "",
                    strategy=r.get("strategy") or "trend_breakout",
                    entry_signal_type=r.get("entry_signal_type") or "none",
                    llm_score=r.get("llm_score"),
                    scan_session=r["scan_session"] or "",
                ))

            return {
                "market_temp": market_temp,
                "hot_sectors": [],
                "capital_filtered": [],
                "recommendations": recommendations,
            }
    except Exception as e:
        logger.exception(f"从数据库加载推荐失败: {e}")
        return _empty_result()


async def _load_sectors_from_db() -> list[SectorInfo]:
    """Load sector heat rows for the latest trade_date from the database.

    One row per sector_code (highest id), ordered by heat_score descending.
    Returns an empty list if loading fails.
    """
    try:
        async with get_db() as db:
            result = await db.execute(
                text(
                    "SELECT * FROM sector_heat "
                    "WHERE trade_date = (SELECT MAX(trade_date) FROM sector_heat) "
                    "AND id IN ("
                    "  SELECT MAX(id) FROM sector_heat "
                    "  WHERE trade_date = (SELECT MAX(trade_date) FROM sector_heat) "
                    "  GROUP BY sector_code"
                    ") "
                    "ORDER BY heat_score DESC"
                )
            )
            rows = result.fetchall()
            sectors = []
            for row in rows:
                r = row._mapping
                # JSON columns fall back to an empty list when NULL/empty.
                leading_stocks = json.loads(r.get("leading_stocks") or "[]")
                pct_trend = json.loads(r.get("pct_trend") or "[]")
                sectors.append(SectorInfo(
                    sector_code=r["sector_code"],
                    sector_name=r["sector_name"],
                    pct_change=r["pct_change"] or 0,
                    capital_inflow=r["capital_inflow"] or 0,
                    limit_up_count=r["limit_up_count"] or 0,
                    days_continuous=r.get("days_continuous") or 0,
                    heat_score=r["heat_score"] or 0,
                    stage=r.get("stage") or "mid",
                    member_count=r.get("member_count") or 0,
                    leading_stocks=leading_stocks,
                    pct_trend=pct_trend,
                    turnover_avg=r.get("turnover_avg") or 0,
                    main_force_ratio=r.get("main_force_ratio") or 0,
                ))
            return sectors
    except Exception as e:
        logger.exception(f"从数据库加载板块数据失败: {e}")
        return []