astock-agent/backend/app/engine/recommender.py
2026-04-16 23:49:04 +08:00

592 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""推荐引擎
管理推荐状态,提供推荐查询接口。
将筛选结果持久化并管理历史推荐。
"""
import logging
import json
import asyncio
from datetime import datetime, timedelta
from app.engine.screener import run_screening
from app.data.models import Recommendation, MarketTemperature, SectorInfo
from app.db.database import get_db
from app.db import tables
logger = logging.getLogger(__name__)
# Most recent screening result, cached in memory by refresh_recommendations();
# stays None until a scan in this process has completed run_screening().
_latest_result: dict | None = None
# Scan lock: prevents two scans from being triggered at the same time.
_scan_lock = asyncio.Lock()
# True while a scan holds _scan_lock (reset in a finally clause). Not read in
# this section — presumably inspected elsewhere; verify before removing.
_scan_running: bool = False
async def refresh_recommendations(trade_date: str | None = None, scan_session: str = "manual") -> dict:
    """Run a screening pass and refresh the recommendation list.

    Guarded by ``_scan_lock`` so only one scan runs at a time. A call that
    arrives while a scan is already in progress does not wait; it returns the
    cached result instead.

    Args:
        trade_date: Forwarded unchanged to ``run_screening`` (``None`` is
            presumably "latest trade date" — confirm in the screener).
        scan_session: Label stamped onto every recommendation of this scan.

    Returns:
        The screening result dict, which is also cached in ``_latest_result``.
        It is persisted via ``_save_to_db``, and tracking is updated afterwards.
        If a scan is already running: the cached result, or an empty result
        with the same keys as the other empty fallbacks in this module.
    """
    global _latest_result, _scan_running
    if _scan_lock.locked():
        logger.warning("扫描已在执行中,跳过本次触发")
        return _latest_result or {
            "market_temp": None,
            "hot_sectors": [],
            "capital_filtered": [],
            "recommendations": [],
        }
    # There is no await between the locked() check and the acquire below, so
    # on a single event loop no other task can slip in between them.
    async with _scan_lock:
        _scan_running = True
        try:
            result = await run_screening(trade_date)
            # One timestamp for the whole scan so all of its rows agree.
            scanned_at = datetime.now()
            for rec in result.get("recommendations", []):
                rec.scan_session = scan_session
                rec.created_at = scanned_at
            _latest_result = result
            # Persist to the database.
            await _save_to_db(result)
            # Update follow-up tracking of earlier recommendations.
            await _update_tracking()
            return result
        finally:
            _scan_running = False
async def _update_tracking():
    """Record the latest close for open recommendations and close finished ones.

    Selects up to 50 of the most recently created recommendations that:
    - have a positive ``entry_price``, and
    - have no ``recommendation_tracking`` row with status ``closed``.

    For each one, it looks up the close price on the latest trade date and
    inserts at most one tracking row per recommendation per trade date. The
    row's status is ``closed`` when the close reaches ``target_price`` or falls
    to ``stop_loss``, otherwise ``active``.

    Best-effort: any failure is logged with its traceback and swallowed, so a
    tracking problem never fails the scan that triggered it.
    """
    try:
        import math
        from sqlalchemy import text
        from app.data.tushare_client import tushare_client

        # NOTE(review): tushare_client calls look synchronous; if they do
        # network I/O they block the event loop (consider asyncio.to_thread).
        trade_date = tushare_client.get_latest_trade_date()
        async with get_db() as db:
            # Active recommendations: a valid entry_price and not yet closed.
            # The IS NOT NULL filter keeps NOT IN from matching nothing if a
            # tracking row ever has a NULL recommendation_id.
            result = await db.execute(
                text(
                    "SELECT id, ts_code, entry_price, target_price, stop_loss "
                    "FROM recommendations "
                    "WHERE entry_price IS NOT NULL "
                    "AND entry_price > 0 "
                    "AND id NOT IN (SELECT DISTINCT recommendation_id FROM recommendation_tracking "
                    "WHERE status = 'closed' AND recommendation_id IS NOT NULL) "
                    "AND date(created_at) <= date(:today) "
                    "ORDER BY created_at DESC LIMIT 50"
                ),
                {"today": datetime.now().strftime("%Y-%m-%d")},
            )
            rows = result.fetchall()
            if not rows:
                return

            # Close prices for just these stocks on the trade date. A set gives
            # O(1) membership, and isin() filters the frame without iterrows().
            codes = {row[1] for row in rows}
            daily_all = tushare_client.get_daily_all(trade_date)
            price_map = {}
            if not daily_all.empty:
                subset = daily_all[daily_all["ts_code"].isin(codes)]
                price_map = dict(zip(subset["ts_code"], subset["close"]))

            tracked = 0
            for rec_id, ts_code, entry_price, target_price, stop_loss in rows:
                current_price = price_map.get(ts_code)
                # Missing or NaN close (e.g. no bar that day): nothing to record.
                if current_price is None or math.isnan(current_price):
                    continue
                if entry_price is None or entry_price <= 0:
                    continue
                current_price = float(current_price)
                pct = round((current_price - entry_price) / entry_price * 100, 2)
                # bool() keeps these strict booleans even when target/stop is NULL or 0.
                hit_target = bool(target_price) and current_price >= target_price
                hit_stop = bool(stop_loss) and current_price <= stop_loss
                status = "closed" if (hit_target or hit_stop) else "active"

                # Skip if this recommendation was already tracked for this trade date.
                exists = await db.execute(
                    text(
                        "SELECT id FROM recommendation_tracking "
                        "WHERE recommendation_id = :rid AND track_date = :td"
                    ),
                    {"rid": rec_id, "td": trade_date},
                )
                if exists.fetchone():
                    continue

                await db.execute(
                    tables.recommendation_tracking_table.insert().values(
                        recommendation_id=rec_id,
                        track_date=trade_date,
                        current_price=current_price,
                        pct_from_entry=pct,
                        hit_target=hit_target,
                        hit_stop_loss=hit_stop,
                        status=status,
                    )
                )
                tracked += 1
            await db.commit()
        if tracked > 0:
            logger.info(f"已更新 {tracked} 条推荐跟踪记录")
    except Exception as e:
        logger.exception("更新推荐跟踪失败: %s", e)
async def get_performance_stats() -> dict:
    """Summarise how past recommendations have performed.

    Each tracked recommendation is judged by its latest tracking row, i.e. the
    one with the highest ``recommendation_tracking.id``. Only tracking rows
    whose recommendation still exists are counted, so every ratio shares the
    same denominator.

    Returns:
        dict with keys:
        - ``total_recommendations``
        - ``tracked``
        - ``winning`` (latest ``pct_from_entry`` > 0)
        - ``win_rate`` (percent of tracked)
        - ``avg_return`` (mean latest ``pct_from_entry``)
        - ``hit_target_count``
        - ``hit_stop_count``
        - ``details``: up to 20 latest-tracked recommendations, newest first.

        On any database error every counter is 0 and ``details`` is empty.
    """
    try:
        from sqlalchemy import text

        # Latest tracking row per recommendation, restricted to recommendations
        # that still exist. _save_to_db deletes and re-inserts a stock's
        # same-day row on each rescan, which (unless the schema cascades the
        # delete) orphans earlier tracking rows. Counting those orphans in the
        # numerators but not in `tracked` could push win_rate above 100%.
        latest_tracking_sql = (
            "SELECT t.recommendation_id, t.pct_from_entry "
            "FROM recommendation_tracking t "
            "INNER JOIN recommendations r ON r.id = t.recommendation_id "
            "INNER JOIN ("
            " SELECT recommendation_id, MAX(id) as max_id "
            " FROM recommendation_tracking GROUP BY recommendation_id"
            ") latest ON t.id = latest.max_id"
        )

        async with get_db() as db:
            # Total recommendations ever saved.
            result = await db.execute(
                text("SELECT COUNT(DISTINCT id) FROM recommendations")
            )
            total = result.scalar() or 0

            # Tracked count, winners and average return, all from latest rows.
            result = await db.execute(
                text(
                    "SELECT COUNT(*), "
                    "SUM(CASE WHEN pct_from_entry > 0 THEN 1 ELSE 0 END), "
                    "AVG(pct_from_entry) "
                    "FROM (" + latest_tracking_sql + ") latest_rows"
                )
            )
            tracked_raw, winning_raw, avg_raw = result.fetchone()
            tracked = tracked_raw or 0
            winning = winning_raw or 0
            avg_return = round(float(avg_raw), 2) if avg_raw else 0

            # Recommendations that ever hit their target / stop-loss.
            # COUNT(DISTINCT ...) skips the NULLs produced by the CASE.
            result = await db.execute(
                text(
                    "SELECT "
                    "COUNT(DISTINCT CASE WHEN t.hit_target = 1 THEN t.recommendation_id END), "
                    "COUNT(DISTINCT CASE WHEN t.hit_stop_loss = 1 THEN t.recommendation_id END) "
                    "FROM recommendation_tracking t "
                    "INNER JOIN recommendations r ON r.id = t.recommendation_id"
                )
            )
            hit_target_raw, hit_stop_raw = result.fetchone()
            hit_target_count = hit_target_raw or 0
            hit_stop_count = hit_stop_raw or 0

            # Most recently created recommendations with their latest tracking row.
            result = await db.execute(
                text(
                    "SELECT r.ts_code, r.name, r.signal, r.entry_price, "
                    " r.target_price, r.stop_loss, r.entry_signal_type, r.score, "
                    " t.pct_from_entry, t.current_price, t.track_date, t.hit_target, t.hit_stop_loss, "
                    " r.created_at "
                    "FROM recommendations r "
                    "INNER JOIN recommendation_tracking t ON t.recommendation_id = r.id "
                    "INNER JOIN ("
                    " SELECT recommendation_id, MAX(id) as max_id "
                    " FROM recommendation_tracking GROUP BY recommendation_id"
                    ") latest ON t.id = latest.max_id "
                    "ORDER BY r.created_at DESC LIMIT 20"
                )
            )
            details = []
            for row in result.fetchall():
                r = row._mapping
                details.append({
                    "ts_code": r["ts_code"],
                    "name": r["name"],
                    "signal": r["signal"],
                    "entry_signal_type": r["entry_signal_type"],
                    "score": r["score"],
                    "entry_price": r["entry_price"],
                    "target_price": r["target_price"],
                    "stop_loss": r["stop_loss"],
                    "current_price": r["current_price"],
                    "pct_from_entry": r["pct_from_entry"],
                    "track_date": r["track_date"],
                    "hit_target": bool(r["hit_target"]),
                    "hit_stop_loss": bool(r["hit_stop_loss"]),
                    # Keep only the YYYY-MM-DD part of the timestamp.
                    "created_at": str(r["created_at"])[:10] if r["created_at"] else "",
                })

        win_rate = round(winning / tracked * 100, 1) if tracked > 0 else 0
        return {
            "total_recommendations": total,
            "tracked": tracked,
            "winning": winning,
            "win_rate": win_rate,
            "avg_return": avg_return,
            "hit_target_count": hit_target_count,
            "hit_stop_count": hit_stop_count,
            "details": details,
        }
    except Exception as e:
        logger.exception("获取胜率统计失败: %s", e)
        return {
            "total_recommendations": 0, "tracked": 0, "winning": 0,
            "win_rate": 0, "avg_return": 0, "hit_target_count": 0,
            "hit_stop_count": 0, "details": [],
        }
async def get_latest_recommendations() -> dict:
    """Return the newest recommendation result.

    Serves the in-memory result of the last scan when one is cached. Otherwise
    it rebuilds the result from the database via ``_load_today_from_db``.
    """
    cached = _latest_result
    if not cached:
        # Nothing cached in this process yet: fall back to persisted rows.
        return await _load_today_from_db()
    return cached
async def get_latest_sectors() -> list[SectorInfo]:
    """Return the newest sector-heat list without triggering a scan.

    Uses ``hot_sectors`` from the cached scan result when it is non-empty.
    Otherwise it reads the most recent snapshot via ``_load_sectors_from_db``.
    """
    cached_sectors = (_latest_result or {}).get("hot_sectors")
    if cached_sectors:
        return cached_sectors
    # No cached scan, or it produced no sectors: read from the database.
    return await _load_sectors_from_db()
async def get_recommendation_history(days: int = 7) -> list[dict]:
    """Return saved recommendations from the last ``days`` days, grouped by date.

    Only rows with ``score >= 60`` are returned. For each (calendar day,
    ts_code) pair only the row with the highest id is kept.

    Args:
        days: Look-back window in days, counted back from now.

    Returns:
        One dict per day, newest date first, with keys ``date``, ``count``,
        ``buy_count``, ``avg_score`` and ``recommendations``. The
        ``recommendations`` list holds per-stock dicts in query order: newest
        ``created_at`` first, then highest score.
    """
    from collections import defaultdict
    from sqlalchemy import text

    def _or_default(value, default):
        # Substitute the default only for a missing value: a stored 0 is a
        # real score and must not be turned into the neutral default.
        return default if value is None else value

    start = (datetime.now() - timedelta(days=days)).strftime("%Y-%m-%d")
    async with get_db() as db:
        # Keep only the latest row (highest id) per stock per calendar day.
        stmt = text(
            "SELECT * FROM recommendations "
            "WHERE created_at >= :start "
            "AND score >= 60 "
            "AND id IN ("
            " SELECT MAX(id) FROM recommendations "
            " WHERE created_at >= :start "
            " GROUP BY date(created_at), ts_code"
            ") "
            "ORDER BY created_at DESC, score DESC"
        )
        result = await db.execute(stmt, {"start": start})
        rows = result.fetchall()

    # Group by calendar day.
    grouped: dict[str, list[dict]] = defaultdict(list)
    for row in rows:
        r = row._mapping
        # In SQLite created_at is the string "YYYY-MM-DD HH:MM:SS"; its first
        # 10 characters are the date.
        ca = r["created_at"]
        date_str = str(ca)[:10] if ca else "unknown"
        created_at_str = str(ca) if ca else None
        grouped[date_str].append({
            "ts_code": r["ts_code"],
            "name": r["name"],
            "sector": r["sector"] or "",
            "score": r["score"] or 0,
            "level": _score_to_level_static(r["score"] or 0),
            "signal": r["signal"] or "HOLD",
            "market_temp_score": r["market_temp_score"] or 0,
            "sector_score": r["sector_score"] or 0,
            "capital_score": r["capital_score"] or 0,
            "technical_score": r["technical_score"] or 0,
            "supply_demand_score": r.get("supply_demand_score") or 0,
            "price_action_score": r.get("price_action_score") or 0,
            "position_score": _or_default(r.get("position_score"), 50),
            "valuation_score": _or_default(r.get("valuation_score"), 50),
            "entry_price": r["entry_price"],
            "target_price": r["target_price"],
            "stop_loss": r["stop_loss"],
            "reasons": json.loads(r["reasons"]) if r["reasons"] else [],
            "risk_note": "",
            "strategy": r.get("strategy") or "trend_breakout",
            "entry_signal_type": r.get("entry_signal_type") or "none",
            "llm_analysis": r.get("llm_analysis") or "",
            "llm_score": r.get("llm_score"),
            "scan_session": r["scan_session"] or "",
            "created_at": created_at_str,
        })

    # Build the per-day summaries, newest date first.
    result_list = []
    for date_str in sorted(grouped, reverse=True):
        recs = grouped[date_str]
        result_list.append({
            "date": date_str,
            "count": len(recs),
            "buy_count": sum(1 for rec in recs if rec["signal"] == "BUY"),
            "avg_score": round(sum(rec["score"] for rec in recs) / len(recs), 1),
            "recommendations": recs,
        })
    return result_list
def _score_to_level_static(score: float) -> str:
"""根据评分确定推荐等级"""
if score >= 75:
return "强烈推荐"
elif score >= 60:
return "推荐"
elif score >= 45:
return "关注"
else:
return "观望"
async def _save_market_temperature(db, mt) -> None:
    """Write the market-temperature row for ``mt.trade_date`` (INSERT OR REPLACE)."""
    from sqlalchemy import text

    # INSERT OR REPLACE so a repeated scan updates the row. This relies on a
    # unique constraint covering trade_date — presumably in the schema; verify.
    stmt = text(
        "INSERT OR REPLACE INTO market_temperature "
        "(trade_date, up_count, down_count, limit_up_count, limit_down_count, "
        "max_streak, broken_rate, temperature) "
        "VALUES (:td, :up, :down, :lu, :ld, :ms, :br, :temp)"
    )
    await db.execute(stmt, {
        "td": mt.trade_date,
        "up": mt.up_count,
        "down": mt.down_count,
        "lu": mt.limit_up_count,
        "ld": mt.limit_down_count,
        "ms": mt.max_streak,
        "br": mt.broken_rate,
        "temp": mt.temperature,
    })


async def _save_sector_heat(db, trade_date: str, sectors) -> None:
    """Replace the sector-heat snapshot for ``trade_date`` with ``sectors``."""
    from sqlalchemy import text

    # Clear the old rows for the same trade_date first to avoid duplicates.
    if trade_date:
        await db.execute(
            text("DELETE FROM sector_heat WHERE trade_date = :td"),
            {"td": trade_date},
        )
    # NOTE(review): with no trade date the rows are still inserted with
    # trade_date="" and are never removed by the DELETE above, so they
    # accumulate across such scans.
    for sector in sectors:
        stmt = tables.sector_heat_table.insert().values(
            sector_code=sector.sector_code,
            sector_name=sector.sector_name,
            pct_change=sector.pct_change,
            capital_inflow=sector.capital_inflow,
            limit_up_count=sector.limit_up_count,
            heat_score=sector.heat_score,
            stage=sector.stage,
            days_continuous=sector.days_continuous,
            member_count=sector.member_count,
            leading_stocks=json.dumps(sector.leading_stocks, ensure_ascii=False),
            pct_trend=json.dumps(sector.pct_trend, ensure_ascii=False),
            turnover_avg=sector.turnover_avg,
            main_force_ratio=sector.main_force_ratio,
            trade_date=trade_date,
        )
        await db.execute(stmt)


async def _save_recommendation_rows(db, recommendations) -> int:
    """Insert recommendations scoring >= 60, replacing today's row per ts_code.

    Returns:
        Number of rows inserted.
    """
    from sqlalchemy import text

    now_dt = datetime.now()
    # Derive the date from the same timestamp written to created_at. With two
    # separate now() calls, a scan straddling midnight could delete one day's
    # rows and insert under the next day, leaving duplicates.
    today_str = now_dt.strftime("%Y-%m-%d")
    saved_count = 0
    for rec in recommendations:
        if rec.score < 60:
            continue
        # Delete this stock's earlier row from today so repeated scans on the
        # same day do not produce duplicates.
        await db.execute(
            text("DELETE FROM recommendations WHERE date(created_at) = :today AND ts_code = :code"),
            {"today": today_str, "code": rec.ts_code},
        )
        stmt = tables.recommendations_table.insert().values(
            ts_code=rec.ts_code,
            name=rec.name,
            sector=rec.sector,
            score=rec.score,
            market_temp_score=rec.market_temp_score,
            sector_score=rec.sector_score,
            capital_score=rec.capital_score,
            technical_score=rec.technical_score,
            supply_demand_score=rec.supply_demand_score,
            price_action_score=rec.price_action_score,
            position_score=rec.position_score,
            valuation_score=rec.valuation_score,
            signal=rec.signal,
            entry_price=rec.entry_price,
            target_price=rec.target_price,
            stop_loss=rec.stop_loss,
            reasons=json.dumps(rec.reasons, ensure_ascii=False),
            llm_analysis=rec.llm_analysis,
            strategy=rec.strategy,
            entry_signal_type=rec.entry_signal_type,
            llm_score=rec.llm_score,
            scan_session=rec.scan_session,
            created_at=now_dt,
        )
        await db.execute(stmt)
        saved_count += 1
    return saved_count


async def _save_to_db(result: dict):
    """Persist one screening result: market temperature, sector heat, recommendations.

    All writes share one session and are committed together at the end.
    Best-effort: any error is logged with its traceback and swallowed, so the
    caller is never interrupted by a persistence failure.
    """
    try:
        async with get_db() as db:
            mt = result.get("market_temp")
            if mt:
                await _save_market_temperature(db, mt)
            trade_date_val = mt.trade_date if mt else ""
            await _save_sector_heat(db, trade_date_val, result.get("hot_sectors", []))
            saved_count = await _save_recommendation_rows(db, result.get("recommendations", []))
            await db.commit()
            logger.info(f"已保存 {saved_count} 条推荐到数据库(共 {len(result.get('recommendations', []))} 条,过滤掉 <60 分)")
    except Exception as e:
        logger.exception("保存推荐到数据库失败: %s", e)
async def _load_today_from_db() -> dict:
    """Rebuild a screening result from the database.

    Despite the name, this loads the most recent calendar day that has any
    saved recommendation, which is not necessarily today.

    Returns:
        dict with:
        - ``market_temp``: the row with the greatest trade_date, or None.
        - ``hot_sectors``: always [] (not loaded here).
        - ``capital_filtered``: always [].
        - ``recommendations``: that day's rows with score >= 60, one per
          ts_code (highest id), sorted by score descending.

        On any error the same keys are returned with empty values.
    """
    def _or_default(value, default):
        # Substitute the default only for a missing value: a stored 0 is a
        # real score and must not be turned into the neutral default.
        return default if value is None else value

    try:
        async with get_db() as db:
            from sqlalchemy import text

            # Market temperature for the latest trade_date.
            result = await db.execute(
                text("SELECT * FROM market_temperature ORDER BY trade_date DESC LIMIT 1")
            )
            mt_row = result.fetchone()
            market_temp = None
            if mt_row:
                m = mt_row._mapping
                market_temp = MarketTemperature(
                    trade_date=m["trade_date"],
                    up_count=m["up_count"],
                    down_count=m["down_count"],
                    limit_up_count=m["limit_up_count"],
                    limit_down_count=m["limit_down_count"],
                    max_streak=m["max_streak"],
                    broken_rate=m["broken_rate"],
                    temperature=m["temperature"],
                )

            # Recommendations from the latest day with data, one per ts_code
            # (highest id), score >= 60 only.
            result = await db.execute(
                text("SELECT * FROM recommendations "
                     "WHERE date(created_at) = (SELECT date(created_at) FROM recommendations ORDER BY created_at DESC LIMIT 1) "
                     "AND score >= 60 "
                     "AND id IN (SELECT MAX(id) FROM recommendations "
                     " WHERE date(created_at) = (SELECT date(created_at) FROM recommendations ORDER BY created_at DESC LIMIT 1) "
                     " GROUP BY ts_code) "
                     "ORDER BY score DESC")
            )
            rows = result.fetchall()

        recommendations = []
        for row in rows:
            r = row._mapping
            recommendations.append(Recommendation(
                ts_code=r["ts_code"],
                name=r["name"],
                sector=r["sector"] or "",
                score=r["score"] or 0,
                market_temp_score=r["market_temp_score"] or 0,
                sector_score=r["sector_score"] or 0,
                capital_score=r["capital_score"] or 0,
                technical_score=r["technical_score"] or 0,
                supply_demand_score=r.get("supply_demand_score") or 0,
                price_action_score=r.get("price_action_score") or 0,
                position_score=_or_default(r.get("position_score"), 50),
                valuation_score=_or_default(r.get("valuation_score"), 50),
                signal=r["signal"] or "HOLD",
                entry_price=r["entry_price"],
                target_price=r["target_price"],
                stop_loss=r["stop_loss"],
                reasons=json.loads(r["reasons"]) if r["reasons"] else [],
                llm_analysis=r.get("llm_analysis") or "",
                strategy=r.get("strategy") or "trend_breakout",
                entry_signal_type=r.get("entry_signal_type") or "none",
                llm_score=r.get("llm_score"),
                scan_session=r["scan_session"] or "",
            ))
        return {
            "market_temp": market_temp,
            "hot_sectors": [],
            "capital_filtered": [],
            "recommendations": recommendations,
        }
    except Exception as e:
        logger.exception("从数据库加载推荐失败: %s", e)
        return {"market_temp": None, "hot_sectors": [], "capital_filtered": [], "recommendations": []}
def _loads_json_or_empty(raw, field: str):
    """Decode a JSON text column; return [] when it is NULL, empty or malformed."""
    if not raw:
        return []
    try:
        return json.loads(raw)
    except (TypeError, ValueError):
        # A single bad value should not discard the whole sector list.
        logger.warning("Malformed JSON in sector_heat.%s; using []", field)
        return []


async def _load_sectors_from_db() -> list[SectorInfo]:
    """Load the sector-heat snapshot for the latest trade_date.

    Keeps one row per sector_code (the highest id), sorted by heat_score
    descending. Returns [] on any database error; malformed JSON in
    ``leading_stocks`` or ``pct_trend`` falls back to [] for that field only.
    """
    try:
        async with get_db() as db:
            from sqlalchemy import text
            result = await db.execute(
                text(
                    "SELECT * FROM sector_heat "
                    "WHERE trade_date = (SELECT MAX(trade_date) FROM sector_heat) "
                    "AND id IN ("
                    " SELECT MAX(id) FROM sector_heat "
                    " WHERE trade_date = (SELECT MAX(trade_date) FROM sector_heat) "
                    " GROUP BY sector_code"
                    ") "
                    "ORDER BY heat_score DESC"
                )
            )
            rows = result.fetchall()

        sectors = []
        for row in rows:
            r = row._mapping
            sectors.append(SectorInfo(
                sector_code=r["sector_code"],
                sector_name=r["sector_name"],
                pct_change=r["pct_change"] or 0,
                capital_inflow=r["capital_inflow"] or 0,
                limit_up_count=r["limit_up_count"] or 0,
                days_continuous=r.get("days_continuous") or 0,
                heat_score=r["heat_score"] or 0,
                stage=r.get("stage") or "mid",
                member_count=r.get("member_count") or 0,
                leading_stocks=_loads_json_or_empty(r.get("leading_stocks"), "leading_stocks"),
                pct_trend=_loads_json_or_empty(r.get("pct_trend"), "pct_trend"),
                turnover_avg=r.get("turnover_avg") or 0,
                main_force_ratio=r.get("main_force_ratio") or 0,
            ))
        return sectors
    except Exception as e:
        logger.exception("从数据库加载板块数据失败: %s", e)
        return []