586 lines
24 KiB
Python
586 lines
24 KiB
Python
"""推荐引擎
|
||
|
||
管理推荐状态,提供推荐查询接口。
|
||
将筛选结果持久化并管理历史推荐。
|
||
"""
|
||
|
||
import logging
|
||
import json
|
||
import asyncio
|
||
from datetime import datetime
|
||
from app.engine.screener import run_screening
|
||
from app.data.models import Recommendation, MarketTemperature, SectorInfo
|
||
from app.db.database import get_db
|
||
from app.db import tables
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# 内存中的最新推荐结果
|
||
_latest_result: dict | None = None
|
||
|
||
# 扫描锁:防止同时触发两次扫描
|
||
_scan_lock = asyncio.Lock()
|
||
_scan_running = False
|
||
|
||
|
||
async def refresh_recommendations(trade_date: str | None = None, scan_session: str = "manual") -> dict:
    """Run a fresh screening, cache the result in memory and persist it.

    Only one scan runs at a time. If another scan already holds ``_scan_lock``,
    this call does not wait: it logs a warning and immediately returns the last
    cached result (or an empty skeleton when nothing is cached yet).

    Args:
        trade_date: Forwarded unchanged to ``run_screening``. ``None`` is passed
            through as-is (presumably the screener then picks its own default
            date; confirm in the screener).
        scan_session: Label stamped on every recommendation of this scan,
            e.g. ``"manual"``.

    Returns:
        The screening result dict as produced by ``run_screening``, with every
        entry of ``result["recommendations"]`` stamped with ``scan_session``
        and a shared ``created_at``.
    """
    global _latest_result, _scan_running

    # There is no await between this check and the ``async with`` below, so on
    # a single event loop no other coroutine can take the lock in between.
    if _scan_lock.locked():
        logger.warning("扫描已在执行中,跳过本次触发")
        return _latest_result or {"market_temp": None, "hot_sectors": [], "recommendations": []}

    async with _scan_lock:
        _scan_running = True
        try:
            result = await run_screening(trade_date)

            # Stamp all recommendations of this scan with the same session label
            # and one shared timestamp, so they belong visibly to the same run.
            scanned_at = datetime.now()
            for rec in result.get("recommendations", []):
                rec.scan_session = scan_session
                rec.created_at = scanned_at

            _latest_result = result

            # Persist the result to the database.
            await _save_to_db(result)

            # Update follow-up tracking of earlier recommendations.
            await _update_tracking()

            return result
        finally:
            _scan_running = False
|
||
|
||
|
||
async def _update_tracking():
    """Record the latest close price for recent, still-open recommendations.

    Selects up to 50 of the newest recommendations that:
    - have a positive ``entry_price``,
    - were created on or before today, and
    - have no ``closed`` tracking row yet.

    For each one it inserts a ``recommendation_tracking`` row for the latest
    trade date, at most one row per recommendation per trade date. The row holds:
    - the close price,
    - the percentage move from entry,
    - whether the target price / stop loss was hit.

    A row is marked ``closed`` once the target or the stop is hit, which removes
    the recommendation from future runs.

    Best-effort: any exception is logged (with traceback) and swallowed, so a
    tracking failure never aborts the caller.
    """
    try:
        from sqlalchemy import text

        from app.data.tushare_client import tushare_client

        # NOTE(review): the tushare_client calls appear to be synchronous and
        # run on the event-loop thread, blocking it while they fetch; confirm
        # whether they should be moved off-loop (e.g. asyncio.to_thread).
        trade_date = tushare_client.get_latest_trade_date()

        async with get_db() as db:
            # Active recommendations: have an entry price and are not closed yet.
            result = await db.execute(
                text(
                    "SELECT id, ts_code, entry_price, target_price, stop_loss "
                    "FROM recommendations "
                    "WHERE entry_price IS NOT NULL "
                    "AND entry_price > 0 "
                    "AND id NOT IN (SELECT DISTINCT recommendation_id FROM recommendation_tracking WHERE status = 'closed') "
                    "AND date(created_at) <= date(:today) "
                    "ORDER BY created_at DESC LIMIT 50"
                ),
                {"today": datetime.now().strftime("%Y-%m-%d")},
            )
            rows = result.fetchall()

            if not rows:
                return

            # Close prices for just the codes we need. A set makes the membership
            # test O(1) while scanning the whole-market daily table; zipping the
            # two columns avoids building a Series per row (iterrows).
            codes = {r[1] for r in rows}
            daily_all = tushare_client.get_daily_all(trade_date)
            price_map = {}
            if not daily_all.empty:
                for code, close in zip(daily_all["ts_code"], daily_all["close"]):
                    if code in codes:
                        price_map[code] = close

            # Recommendations already tracked for this trade date, fetched once
            # instead of issuing one lookup query per recommendation.
            done = await db.execute(
                text(
                    "SELECT recommendation_id FROM recommendation_tracking "
                    "WHERE track_date = :td"
                ),
                {"td": trade_date},
            )
            already_tracked = {row[0] for row in done.fetchall()}

            tracked = 0
            for rec_id, ts_code, entry_price, target_price, stop_loss in rows:
                if rec_id in already_tracked:
                    continue

                current_price = price_map.get(ts_code)
                if current_price is None or entry_price is None or entry_price <= 0:
                    continue

                pct = round((current_price - entry_price) / entry_price * 100, 2)
                # bool(): a missing target/stop must store False, not NULL or 0,
                # so that the ``hit_target = 1`` style filters stay well-defined.
                hit_target = bool(target_price and current_price >= target_price)
                hit_stop = bool(stop_loss and current_price <= stop_loss)
                status = "closed" if (hit_target or hit_stop) else "active"

                await db.execute(
                    tables.recommendation_tracking_table.insert().values(
                        recommendation_id=rec_id,
                        track_date=trade_date,
                        current_price=current_price,
                        pct_from_entry=pct,
                        hit_target=hit_target,
                        hit_stop_loss=hit_stop,
                        status=status,
                    )
                )
                tracked += 1

            await db.commit()
            if tracked > 0:
                logger.info(f"已更新 {tracked} 条推荐跟踪记录")
    except Exception as e:
        logger.exception(f"更新推荐跟踪失败: {e}")
|
||
|
||
|
||
async def get_performance_stats() -> dict:
    """Summarize how past recommendations performed.

    "Latest tracking" means, per recommendation, its ``recommendation_tracking``
    row with the highest ``id``.

    All tracking-based figures are computed over tracking rows whose
    recommendation still exists. Orphaned tracking rows are ignored. Such rows
    can be left behind when same-day recommendations are deleted and
    re-inserted. Restricting every figure to this one population keeps
    ``winning <= tracked``, so ``win_rate`` stays within 0-100.

    Returns:
        dict with:
        - ``total_recommendations``: distinct recommendation ids.
        - ``tracked``: recommendations that have at least one tracking row.
        - ``winning``: tracked recommendations whose latest ``pct_from_entry``
          is > 0.
        - ``win_rate``: percent, 1 decimal.
        - ``avg_return``: mean latest ``pct_from_entry``, 2 decimals.
        - ``hit_target_count`` / ``hit_stop_count``: recommendations with any
          tracking row flagged ``hit_target`` / ``hit_stop_loss``.
        - ``details``: up to 20 most recently created tracked recommendations,
          with their latest tracking data.

        On any error, an all-zero dict with the same keys is returned.
    """
    try:
        from sqlalchemy import text

        # Join that keeps only the latest (highest-id) tracking row per
        # recommendation; expects the tracking table aliased as ``t``.
        latest_join = (
            "INNER JOIN ("
            " SELECT recommendation_id, MAX(id) as max_id "
            " FROM recommendation_tracking GROUP BY recommendation_id"
            ") latest ON t.id = latest.max_id "
        )

        async with get_db() as db:
            # Total number of recommendations.
            result = await db.execute(
                text("SELECT COUNT(DISTINCT id) FROM recommendations")
            )
            total = result.scalar() or 0

            # tracked / winning / average return over one shared population:
            # the latest tracking row of each still-existing recommendation.
            result = await db.execute(
                text(
                    "SELECT COUNT(*), "
                    " SUM(CASE WHEN t.pct_from_entry > 0 THEN 1 ELSE 0 END), "
                    " AVG(t.pct_from_entry) "
                    "FROM recommendation_tracking t "
                    "INNER JOIN recommendations r ON r.id = t.recommendation_id "
                    + latest_join
                )
            )
            summary = result.fetchone()
            tracked = (summary[0] if summary else 0) or 0
            winning = (summary[1] if summary else 0) or 0
            avg_pct = summary[2] if summary else None
            avg_return = round(float(avg_pct), 2) if avg_pct else 0

            # Recommendations that ever hit their target / stop loss, also
            # limited to tracking rows whose recommendation still exists.
            result = await db.execute(
                text(
                    "SELECT "
                    " COUNT(DISTINCT CASE WHEN t.hit_target = 1 THEN t.recommendation_id END), "
                    " COUNT(DISTINCT CASE WHEN t.hit_stop_loss = 1 THEN t.recommendation_id END) "
                    "FROM recommendation_tracking t "
                    "INNER JOIN recommendations r ON r.id = t.recommendation_id"
                )
            )
            hits = result.fetchone()
            hit_target_count = (hits[0] if hits else 0) or 0
            hit_stop_count = (hits[1] if hits else 0) or 0

            # Details of the most recently created tracked recommendations.
            result = await db.execute(
                text(
                    "SELECT r.ts_code, r.name, r.signal, r.entry_price, "
                    " r.target_price, r.stop_loss, r.entry_signal_type, r.score, "
                    " t.pct_from_entry, t.current_price, t.track_date, t.hit_target, t.hit_stop_loss, "
                    " r.created_at "
                    "FROM recommendations r "
                    "INNER JOIN recommendation_tracking t ON t.recommendation_id = r.id "
                    + latest_join
                    + "ORDER BY r.created_at DESC LIMIT 20"
                )
            )
            details = []
            for row in result.fetchall():
                r = row._mapping
                details.append({
                    "ts_code": r["ts_code"],
                    "name": r["name"],
                    "signal": r["signal"],
                    "entry_signal_type": r["entry_signal_type"],
                    "score": r["score"],
                    "entry_price": r["entry_price"],
                    "target_price": r["target_price"],
                    "stop_loss": r["stop_loss"],
                    "current_price": r["current_price"],
                    "pct_from_entry": r["pct_from_entry"],
                    "track_date": r["track_date"],
                    "hit_target": bool(r["hit_target"]),
                    "hit_stop_loss": bool(r["hit_stop_loss"]),
                    "created_at": str(r["created_at"])[:10] if r["created_at"] else "",
                })

            win_rate = round(winning / tracked * 100, 1) if tracked > 0 else 0

            return {
                "total_recommendations": total,
                "tracked": tracked,
                "winning": winning,
                "win_rate": win_rate,
                "avg_return": avg_return,
                "hit_target_count": hit_target_count,
                "hit_stop_count": hit_stop_count,
                "details": details,
            }
    except Exception as e:
        logger.exception(f"获取胜率统计失败: {e}")
        return {
            "total_recommendations": 0, "tracked": 0, "winning": 0,
            "win_rate": 0, "avg_return": 0, "hit_target_count": 0,
            "hit_stop_count": 0, "details": [],
        }
|
||
|
||
|
||
async def get_latest_recommendations() -> dict:
    """Return the newest recommendation result.

    Serves the in-memory result when one is cached. Otherwise, it falls back
    to rebuilding a result from the most recent recommendations stored in the
    database.
    """
    cached = _latest_result
    if not cached:
        # Nothing cached (the module starts with None): read from the DB.
        return await _load_today_from_db()
    return cached
|
||
|
||
|
||
async def get_latest_sectors() -> list[SectorInfo]:
    """Return the latest sector-heat list without ever triggering a scan.

    Uses ``hot_sectors`` from the in-memory result when it is present and
    non-empty. Otherwise, it loads the newest trade date's sectors from the
    database.
    """
    sectors = (_latest_result or {}).get("hot_sectors")
    if sectors:
        return sectors
    # Cache missing or has no sectors: fall back to the persisted copy.
    return await _load_sectors_from_db()
|
||
|
||
|
||
async def get_recommendation_history(days: int = 7) -> list[dict]:
    """Return recommendations created since ``days`` days ago, grouped by date.

    For each calendar day, only the newest row (highest ``id``) per ``ts_code``
    is kept.

    Returns:
        A list with one dict per day, newest day first. Each dict has keys
        ``date``, ``count``, ``buy_count``, ``avg_score`` and
        ``recommendations``. Within a day, recommendations keep the query
        order (``created_at`` descending, then ``score`` descending).
    """
    from datetime import timedelta
    import json

    since = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")

    async with get_db() as db:
        from sqlalchemy import text

        # Newest row per (day, ts_code) within the window.
        query = text(
            "SELECT * FROM recommendations "
            "WHERE created_at >= :start "
            "AND id IN ("
            " SELECT MAX(id) FROM recommendations "
            " WHERE created_at >= :start "
            " GROUP BY date(created_at), ts_code"
            ") "
            "ORDER BY created_at DESC, score DESC"
        )
        fetched = (await db.execute(query, {"start": since})).fetchall()

        by_day: dict[str, list[dict]] = {}
        for record in fetched:
            m = record._mapping

            # created_at is read as text "YYYY-MM-DD HH:MM:SS" (SQLite); the
            # first 10 characters are the calendar date used for grouping.
            stamp = m["created_at"]
            day = str(stamp)[:10] if stamp else "unknown"
            stamp_text = str(stamp) if stamp else None

            score = m["score"] or 0
            item = {
                "ts_code": m["ts_code"],
                "name": m["name"],
                "sector": m["sector"] or "",
                "score": score,
                "level": _score_to_level_static(score),
                "signal": m["signal"] or "HOLD",
                "market_temp_score": m["market_temp_score"] or 0,
                "sector_score": m["sector_score"] or 0,
                "capital_score": m["capital_score"] or 0,
                "technical_score": m["technical_score"] or 0,
                "supply_demand_score": m.get("supply_demand_score") or 0,
                "price_action_score": m.get("price_action_score") or 0,
                "position_score": m.get("position_score") or 50,
                "valuation_score": m.get("valuation_score") or 50,
                "entry_price": m["entry_price"],
                "target_price": m["target_price"],
                "stop_loss": m["stop_loss"],
                "reasons": json.loads(m["reasons"]) if m["reasons"] else [],
                "risk_note": "",
                "strategy": m.get("strategy") or "trend_breakout",
                "entry_signal_type": m.get("entry_signal_type") or "none",
                "llm_analysis": m.get("llm_analysis") or "",
                "llm_score": m.get("llm_score"),
                "scan_session": m["scan_session"] or "",
                "created_at": stamp_text,
            }
            by_day.setdefault(day, []).append(item)

        # Summarize each day, newest day first.
        history = []
        for day in sorted(by_day, reverse=True):
            items = by_day[day]
            scores = [item["score"] for item in items]
            history.append({
                "date": day,
                "count": len(items),
                "buy_count": sum(item["signal"] == "BUY" for item in items),
                "avg_score": round(sum(scores) / len(scores), 1) if scores else 0,
                "recommendations": items,
            })

        return history
|
||
|
||
|
||
def _score_to_level_static(score: float) -> str:
|
||
"""根据评分确定推荐等级"""
|
||
if score >= 75:
|
||
return "强烈推荐"
|
||
elif score >= 60:
|
||
return "推荐"
|
||
elif score >= 45:
|
||
return "关注"
|
||
else:
|
||
return "观望"
|
||
|
||
|
||
async def _save_to_db(result: dict):
    """Persist one screening result to the database.

    Persists the market temperature, the sector heat and the recommendations.
    All writes share one session and are committed together at the end.
    Re-running a scan on the same day replaces that day's rows rather than
    duplicating them.

    Best-effort: any exception is logged (with traceback) and swallowed.
    """
    try:
        async with get_db() as db:
            mt = result.get("market_temp")
            if mt:
                await _save_market_temp(db, mt)

            # Sector heat is keyed by the market-temperature trade date; without
            # one, rows are inserted with an empty trade_date (unchanged behavior).
            trade_date_val = mt.trade_date if mt else ""
            await _save_sectors(db, result.get("hot_sectors", []), trade_date_val)

            recommendations = result.get("recommendations", [])
            await _save_recommendations(db, recommendations)

            await db.commit()
            logger.info(f"已保存 {len(recommendations)} 条推荐到数据库")
    except Exception as e:
        logger.exception(f"保存推荐到数据库失败: {e}")


async def _save_market_temp(db, mt) -> None:
    """Upsert the market-temperature row for ``mt.trade_date`` (no commit)."""
    from sqlalchemy import text

    # INSERT OR REPLACE so a repeated scan overwrites that day's row.
    stmt = text(
        "INSERT OR REPLACE INTO market_temperature "
        "(trade_date, up_count, down_count, limit_up_count, limit_down_count, "
        "max_streak, broken_rate, temperature) "
        "VALUES (:td, :up, :down, :lu, :ld, :ms, :br, :temp)"
    )
    await db.execute(stmt, {
        "td": mt.trade_date,
        "up": mt.up_count,
        "down": mt.down_count,
        "lu": mt.limit_up_count,
        "ld": mt.limit_down_count,
        "ms": mt.max_streak,
        "br": mt.broken_rate,
        "temp": mt.temperature,
    })


async def _save_sectors(db, sectors, trade_date: str) -> None:
    """Replace the sector-heat rows stored for ``trade_date`` with ``sectors`` (no commit)."""
    from sqlalchemy import text

    # Clear the day's old rows first so repeated scans do not pile up duplicates.
    if trade_date:
        await db.execute(
            text("DELETE FROM sector_heat WHERE trade_date = :td"),
            {"td": trade_date},
        )
    for sector in sectors:
        stmt = tables.sector_heat_table.insert().values(
            sector_code=sector.sector_code,
            sector_name=sector.sector_name,
            pct_change=sector.pct_change,
            capital_inflow=sector.capital_inflow,
            limit_up_count=sector.limit_up_count,
            heat_score=sector.heat_score,
            stage=sector.stage,
            days_continuous=sector.days_continuous,
            member_count=sector.member_count,
            leading_stocks=json.dumps(sector.leading_stocks, ensure_ascii=False),
            pct_trend=json.dumps(sector.pct_trend, ensure_ascii=False),
            turnover_avg=sector.turnover_avg,
            main_force_ratio=sector.main_force_ratio,
            trade_date=trade_date,
        )
        await db.execute(stmt)


async def _save_recommendations(db, recommendations) -> None:
    """Insert ``recommendations``, replacing rows saved today for the same ts_code (no commit)."""
    from sqlalchemy import text

    today_str = datetime.now().strftime("%Y-%m-%d")
    for rec in recommendations:
        params = {"today": today_str, "code": rec.ts_code}

        # Remove tracking rows of the recommendation about to be replaced;
        # otherwise they would be left pointing at a deleted id and distort
        # the performance statistics.
        await db.execute(
            text(
                "DELETE FROM recommendation_tracking WHERE recommendation_id IN ("
                "SELECT id FROM recommendations WHERE date(created_at) = :today AND ts_code = :code)"
            ),
            params,
        )
        # One row per stock per day: drop today's earlier row for this code.
        await db.execute(
            text("DELETE FROM recommendations WHERE date(created_at) = :today AND ts_code = :code"),
            params,
        )
        stmt = tables.recommendations_table.insert().values(
            ts_code=rec.ts_code,
            name=rec.name,
            sector=rec.sector,
            score=rec.score,
            market_temp_score=rec.market_temp_score,
            sector_score=rec.sector_score,
            capital_score=rec.capital_score,
            technical_score=rec.technical_score,
            supply_demand_score=rec.supply_demand_score,
            price_action_score=rec.price_action_score,
            position_score=rec.position_score,
            valuation_score=rec.valuation_score,
            signal=rec.signal,
            entry_price=rec.entry_price,
            target_price=rec.target_price,
            stop_loss=rec.stop_loss,
            reasons=json.dumps(rec.reasons, ensure_ascii=False),
            llm_analysis=rec.llm_analysis,
            strategy=rec.strategy,
            entry_signal_type=rec.entry_signal_type,
            llm_score=rec.llm_score,
            scan_session=rec.scan_session,
        )
        await db.execute(stmt)
|
||
|
||
|
||
async def _load_today_from_db() -> dict:
    """Rebuild a recommendation result from the database.

    Despite the name, this loads the most recent *stored* day, not
    necessarily the current calendar day:

    - Market temperature: the row with the greatest ``trade_date``.
    - Recommendations: rows created on the same date as the newest
      recommendation. Only the newest row (highest ``id``) per ``ts_code`` is
      kept, those scoring below 60 are dropped, and the rest are ordered by
      score descending.

    Returns:
        dict with keys:
        - ``market_temp``: a ``MarketTemperature``, or None.
        - ``hot_sectors`` and ``capital_filtered``: always empty here.
        - ``recommendations``: a list of ``Recommendation``.

        On error, the same shape with None/empty values is returned.
    """
    try:
        async with get_db() as db:
            from sqlalchemy import text

            # Market temperature of the latest stored trade date.
            result = await db.execute(
                text("SELECT * FROM market_temperature ORDER BY trade_date DESC LIMIT 1")
            )
            mt_row = result.fetchone()
            market_temp = None
            if mt_row:
                m = mt_row._mapping
                market_temp = MarketTemperature(
                    trade_date=m["trade_date"],
                    up_count=m["up_count"],
                    down_count=m["down_count"],
                    limit_up_count=m["limit_up_count"],
                    limit_down_count=m["limit_down_count"],
                    max_streak=m["max_streak"],
                    broken_rate=m["broken_rate"],
                    temperature=m["temperature"],
                )

            # Recommendations of the newest stored date: newest row per ts_code,
            # then only those scoring >= 60.
            result = await db.execute(
                text("SELECT * FROM recommendations "
                     "WHERE date(created_at) = (SELECT date(created_at) FROM recommendations ORDER BY created_at DESC LIMIT 1) "
                     "AND score >= 60 "
                     "AND id IN (SELECT MAX(id) FROM recommendations "
                     " WHERE date(created_at) = (SELECT date(created_at) FROM recommendations ORDER BY created_at DESC LIMIT 1) "
                     " GROUP BY ts_code) "
                     "ORDER BY score DESC")
            )
            rows = result.fetchall()
            recommendations = []
            for row in rows:
                r = row._mapping
                recommendations.append(Recommendation(
                    ts_code=r["ts_code"],
                    name=r["name"],
                    sector=r["sector"] or "",
                    score=r["score"] or 0,
                    market_temp_score=r["market_temp_score"] or 0,
                    sector_score=r["sector_score"] or 0,
                    capital_score=r["capital_score"] or 0,
                    technical_score=r["technical_score"] or 0,
                    supply_demand_score=r.get("supply_demand_score") or 0,
                    price_action_score=r.get("price_action_score") or 0,
                    position_score=r.get("position_score") or 50,
                    valuation_score=r.get("valuation_score") or 50,
                    signal=r["signal"] or "HOLD",
                    entry_price=r["entry_price"],
                    target_price=r["target_price"],
                    stop_loss=r["stop_loss"],
                    reasons=json.loads(r["reasons"]) if r["reasons"] else [],
                    llm_analysis=r.get("llm_analysis") or "",
                    strategy=r.get("strategy") or "trend_breakout",
                    entry_signal_type=r.get("entry_signal_type") or "none",
                    llm_score=r.get("llm_score"),
                    scan_session=r["scan_session"] or "",
                ))

            return {
                "market_temp": market_temp,
                "hot_sectors": [],
                "capital_filtered": [],
                "recommendations": recommendations,
            }
    except Exception as e:
        logger.exception(f"从数据库加载推荐失败: {e}")
        return {"market_temp": None, "hot_sectors": [], "capital_filtered": [], "recommendations": []}
|
||
|
||
|
||
async def _load_sectors_from_db() -> list[SectorInfo]:
    """Read the sector-heat rows of the newest stored trade date.

    Keeps one row per ``sector_code`` (the one with the highest ``id``) and
    returns them ordered by ``heat_score`` descending. Any error is logged and
    results in an empty list.
    """

    def _to_sector(m):
        # JSON list columns may be NULL/empty; treat those as an empty list.
        leaders = json.loads(m.get("leading_stocks") or "[]")
        trend = json.loads(m.get("pct_trend") or "[]")
        return SectorInfo(
            sector_code=m["sector_code"],
            sector_name=m["sector_name"],
            pct_change=m["pct_change"] or 0,
            capital_inflow=m["capital_inflow"] or 0,
            limit_up_count=m["limit_up_count"] or 0,
            days_continuous=m.get("days_continuous") or 0,
            heat_score=m["heat_score"] or 0,
            stage=m.get("stage") or "mid",
            member_count=m.get("member_count") or 0,
            leading_stocks=leaders,
            pct_trend=trend,
            turnover_avg=m.get("turnover_avg") or 0,
            main_force_ratio=m.get("main_force_ratio") or 0,
        )

    try:
        async with get_db() as db:
            from sqlalchemy import text

            query = text(
                "SELECT * FROM sector_heat "
                "WHERE trade_date = (SELECT MAX(trade_date) FROM sector_heat) "
                "AND id IN ("
                " SELECT MAX(id) FROM sector_heat "
                " WHERE trade_date = (SELECT MAX(trade_date) FROM sector_heat) "
                " GROUP BY sector_code"
                ") "
                "ORDER BY heat_score DESC"
            )
            fetched = (await db.execute(query)).fetchall()
            return [_to_sector(record._mapping) for record in fetched]
    except Exception as e:
        logger.error(f"从数据库加载板块数据失败: {e}")
        return []
|