astock-agent/backend/app/api/sectors.py
2026-06-02 11:25:08 +08:00

337 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""板块分析 API"""
import json
import math

from fastapi import APIRouter
from sqlalchemy import text

from app.data.tushare_client import tushare_client
from app.data.cache import cache
from app.engine.recommender import get_latest_sectors
from app.db.database import get_db
# Routes registered on this router share the /api/sectors prefix and the "sectors" OpenAPI tag.
router = APIRouter(prefix="/api/sectors", tags=["sectors"])
@router.get("/hot")
async def get_hot_sectors(limit: int = 10):
    """Return the latest hot-theme (main-line) ranking.

    Serving this page only reads scan conclusions already stored in the
    database; no external real-time quotes are fetched inside the GET request.

    Args:
        limit: Maximum number of sectors to return. Negative values are
            treated as 0; without the clamp ``sectors[:-n]`` would silently
            drop the last ``n`` sectors instead.

    Returns:
        A list of per-sector dicts. ``data_mode``, ``data_status`` and
        ``structure_trade_date`` are overwritten with list-wide values.
    """
    sectors = await get_latest_sectors()
    # Every row reports the trade date of the top-ranked sector.
    trade_date = sectors[0].trade_date if sectors else ""
    sectors_data = [
        {
            "sector_code": s.sector_code,
            "sector_name": s.sector_name,
            "board_type": s.board_type,
            "theme_id": s.theme_id,
            "theme_name": s.theme_name,
            "theme_aliases": s.theme_aliases,
            "pct_change": s.pct_change,
            "capital_inflow": s.capital_inflow,
            "limit_up_count": s.limit_up_count,
            "days_continuous": s.days_continuous,
            "heat_score": s.heat_score,
            "stage": s.stage,
            # Enhanced analysis fields
            "member_count": s.member_count,
            "leading_stocks": s.leading_stocks,
            "pct_trend": s.pct_trend,
            "turnover_avg": s.turnover_avg,
            "main_force_ratio": s.main_force_ratio,
            "trade_date": trade_date,
            "realtime_pct_change": s.realtime_pct_change,
            "realtime_limit_up_count": s.realtime_limit_up_count,
            "realtime_amount": s.realtime_amount,
            "realtime_turnover_rate": s.realtime_turnover_rate,
            "realtime_up_count": s.realtime_up_count,
            "realtime_down_count": s.realtime_down_count,
            "leading_stocks_realtime": s.leading_stocks_realtime,
            "is_realtime": s.is_realtime,
            "data_mode": s.data_mode,
            "source": s.source,
            "data_status": s.data_status,
            "source_detail": s.source_detail,
            "catalyst_score": s.catalyst_score,
            "catalyst_count": s.catalyst_count,
            "catalyst_reasons": s.catalyst_reasons,
        }
        for s in sectors[: max(limit, 0)]
    ]

    # Take the mode from the first sector that actually carries real-time
    # data; using sectors[0] blindly could report a non-real-time mode even
    # though real-time data is present further down the list.
    first_realtime = next((item for item in sectors_data if item.get("is_realtime")), None)
    mode = first_realtime["data_mode"] if first_realtime is not None else "daily_snapshot"
    status = _derive_status(sectors_data)
    for item in sectors_data:
        item["data_mode"] = mode
        item["data_status"] = status
        item["structure_trade_date"] = trade_date
    return sectors_data
def _derive_status(sectors: list[dict]) -> str:
statuses = {str(s.get("data_status") or "fresh") for s in sectors}
if not statuses:
return "snapshot"
if "fresh" in statuses:
return "fresh" if len(statuses) == 1 else "mixed"
if "stale" in statuses:
return "stale"
if "fallback" in statuses:
return "fallback"
return next(iter(statuses))
def _safe_json(raw: str | None):
if not raw:
return {}
try:
return json.loads(raw)
except Exception:
return {}
def _member_theme_matches(value: str, target: str, aliases: list[str]) -> bool:
value_norm = (value or "").strip().lower()
target_norm = (target or "").strip().lower()
alias_norms = {(alias or "").strip().lower() for alias in aliases if alias}
if not value_norm or not target_norm:
return False
return value_norm == target_norm or value_norm in alias_norms or target_norm in value_norm or value_norm in target_norm
def _sector_identity_matches(value: str, target: str, aliases: list[str]) -> bool:
value_norm = (value or "").strip().lower()
target_norm = (target or "").strip().lower()
alias_norms = {(alias or "").strip().lower() for alias in aliases if alias}
if not value_norm or not target_norm:
return False
return value_norm == target_norm or target_norm in alias_norms or target_norm in value_norm or value_norm in target_norm
@router.get("/{sector_name}/detail")
async def get_sector_detail(sector_name: str):
    """Return sector details plus research observations for its candidate stocks.

    Only reads data persisted by the most recent scan; no external quotes are
    fetched while serving the page.

    Args:
        sector_name: Sector (or theme alias) name from the URL path. An exact
            (case/whitespace-insensitive) sector-name match wins; otherwise
            the first ranked sector matching by alias or substring is used.

    Returns:
        A dict with ``sector`` (payload, or None when nothing matched),
        ``scan_session``, ``members``, ``action_counts``, ``top_candidates``
        (first 12 members in final_score order) and ``summary``.
    """
    sectors = await get_latest_sectors()
    matched = _match_sector(sectors, sector_name)
    aliases = [matched.theme_name, *(matched.theme_aliases or [])] if matched else []
    target_name = matched.sector_name if matched else sector_name

    async with get_db() as db:
        latest_session = (
            await db.execute(
                text("SELECT scan_session FROM research_observations ORDER BY created_at DESC LIMIT 1")
            )
        ).scalar()
        rows = []
        if latest_session:
            # NOTE(review): theme filtering happens in Python after LIMIT 300,
            # so a sector whose members rank below the top 300 of the session
            # can come back with an incomplete member list.
            result = await db.execute(
                text(
                    "SELECT id, scan_session, scan_mode, ts_code, name, theme_name, stock_role, "
                    "action_plan, final_score, catalyst_score, theme_money_score, stock_money_score, "
                    "emotion_role_score, timing_score, entry_signal_type, elimination_reason, detail_json, created_at "
                    "FROM research_observations "
                    "WHERE scan_session = :session "
                    "ORDER BY final_score DESC LIMIT 300"
                ),
                {"session": latest_session},
            )
            rows = result.fetchall()

    members = []
    for row in rows:
        mapping = row._mapping
        theme = mapping["theme_name"] or ""
        if _member_theme_matches(theme, target_name, aliases):
            members.append(_member_from_row(mapping, theme))

    # Tag the sector's representative stocks. `or []` guards against both
    # leader lists being None; an empty code is never a meaningful leader.
    leaders = (matched.leading_stocks_realtime or matched.leading_stocks) if matched else []
    leader_codes = {str(item.get("ts_code") or "") for item in leaders or []}
    leader_codes.discard("")
    for item in members:
        if item["ts_code"] in leader_codes and "龙头" not in item["stock_role"]:
            item["stock_role"] = f"代表股/{item['stock_role'] or '候选'}"

    # The three standard plans are always present (possibly 0); any other
    # action_plan value found in the data is counted under its own key.
    action_counts = {"可操作": 0, "重点关注": 0, "观察": 0}
    for item in members:
        plan = item["action_plan"]
        action_counts[plan] = action_counts.get(plan, 0) + 1

    return {
        "sector": _sector_payload(matched) if matched else None,
        "scan_session": latest_session or "",
        "members": members,
        "action_counts": action_counts,
        "top_candidates": members[:12],
        "summary": {
            "member_count": len(members),
            "actionable_count": action_counts.get("可操作", 0),
            "watch_count": action_counts.get("重点关注", 0),
            "observe_count": action_counts.get("观察", 0),
            "avg_score": round(sum(item["final_score"] for item in members) / len(members), 1) if members else 0,
        },
    }


def _match_sector(sectors, sector_name: str):
    """Pick the sector a detail request refers to, or None if nothing matches."""
    # Exact name first: fuzzy matching alone would resolve "半导体" to a
    # higher-ranked "半导体设备" because of the substring rule.
    wanted = (sector_name or "").strip().lower()
    if wanted:
        for s in sectors:
            if (s.sector_name or "").strip().lower() == wanted:
                return s
    # Then alias / substring matching in ranking order. (A former second pass
    # with swapped arguments and no aliases was removed: that predicate is
    # symmetric without aliases, so it could never find anything new.)
    return next(
        (
            s for s in sectors
            if _sector_identity_matches(s.sector_name, sector_name, [s.theme_name, *(s.theme_aliases or [])])
        ),
        None,
    )


def _member_from_row(r, theme: str) -> dict:
    """Flatten one research_observations row (plus its detail_json) into a member dict."""
    detail = _safe_json(r["detail_json"])
    if not isinstance(detail, dict):
        detail = {}
    # `.get(key, {})` would return None for a key stored as JSON null, so
    # validate the nested objects explicitly.
    stock_detail = detail.get("stock")
    stock_detail = stock_detail if isinstance(stock_detail, dict) else {}
    decision_detail = detail.get("decision")
    decision_detail = decision_detail if isinstance(decision_detail, dict) else {}
    return {
        "ts_code": r["ts_code"],
        "name": r["name"],
        "theme_name": theme,
        "stock_role": r["stock_role"] or "",
        "action_plan": r["action_plan"] or "观察",
        "final_score": r["final_score"] or 0,
        "catalyst_score": r["catalyst_score"] or 0,
        "theme_money_score": r["theme_money_score"] or 0,
        "stock_money_score": r["stock_money_score"] or 0,
        "emotion_role_score": r["emotion_role_score"] or 0,
        "timing_score": r["timing_score"] or 0,
        "entry_signal_type": r["entry_signal_type"] or "none",
        "elimination_reason": r["elimination_reason"] or "",
        "recall_tags": stock_detail.get("recall_tags", []),
        "main_net_inflow": stock_detail.get("main_net_inflow", 0),
        "inflow_ratio": stock_detail.get("inflow_ratio", 0),
        "turnover_rate": stock_detail.get("turnover_rate", 0),
        "volume_ratio": stock_detail.get("volume_ratio"),
        "trigger_condition": decision_detail.get("trigger_condition", ""),
        "invalidation_condition": decision_detail.get("invalidation_condition", ""),
        "created_at": str(r["created_at"] or ""),
    }


def _sector_payload(matched) -> dict:
    """Serialize the matched sector object for the detail response."""
    return {
        "sector_code": matched.sector_code,
        "sector_name": matched.sector_name,
        "board_type": matched.board_type,
        "theme_id": matched.theme_id,
        "theme_name": matched.theme_name,
        "theme_aliases": matched.theme_aliases,
        "pct_change": matched.pct_change,
        "capital_inflow": matched.capital_inflow,
        "limit_up_count": matched.limit_up_count,
        "days_continuous": matched.days_continuous,
        "heat_score": matched.heat_score,
        "stage": matched.stage,
        "member_count": matched.member_count,
        "leading_stocks": matched.leading_stocks,
        "leading_stocks_realtime": matched.leading_stocks_realtime,
        "turnover_avg": matched.turnover_avg,
        "main_force_ratio": matched.main_force_ratio,
        "trade_date": matched.trade_date,
        "realtime_pct_change": matched.realtime_pct_change,
        "realtime_limit_up_count": matched.realtime_limit_up_count,
        "realtime_amount": matched.realtime_amount,
        "realtime_turnover_rate": matched.realtime_turnover_rate,
        "realtime_up_count": matched.realtime_up_count,
        "realtime_down_count": matched.realtime_down_count,
        "is_realtime": matched.is_realtime,
        "data_mode": matched.data_mode,
        "source": matched.source,
        "data_status": matched.data_status,
        "source_detail": matched.source_detail,
        "catalyst_score": matched.catalyst_score,
        "catalyst_count": matched.catalyst_count,
        "catalyst_reasons": matched.catalyst_reasons,
    }
@router.get("/rotation")
async def get_sector_rotation(days: int = 5):
    """Return per-sector daily money flow and % change over the last N trading days (heat-map data).

    Args:
        days: Number of most recent trading days to include. Values below 1
            are treated as 1; previously ``past_dates[-0:]`` (days=0) or a
            negative slice returned almost the entire trade calendar and
            issued one upstream request per day.

    Returns:
        ``{"trade_date", "dates", "sectors"}`` where ``sectors`` holds at most
        20 entries, each with ``sector_code``, ``sector_name`` and
        ``daily_data`` (``trade_date``, ``pct_change``, ``net_amount``).
        Results are cached for 300 seconds per ``days`` value.
    """
    days = max(days, 1)
    cache_key = f"sector_rotation:{days}"
    cached = cache.get(cache_key)
    if cached is not None:
        return cached

    # NOTE(review): the tushare_client calls below are synchronous and run
    # inside this async handler, i.e. on the event loop; confirm they are
    # fast or cached, otherwise consider moving them off-loop.
    trade_date = tushare_client.get_latest_trade_date()
    # Assumes get_trade_dates() yields comparable date strings in ascending
    # order — TODO confirm against tushare_client.
    past_dates = [d for d in tushare_client.get_trade_dates() if d <= trade_date]
    recent_dates = past_dates[-days:]  # fewer than `days` available -> all of them

    name_map = _sector_name_map()
    all_sectors = _collect_moneyflow(recent_dates, name_map)
    sector_pct_map = _collect_pct_changes(
        {s["sector_code"] for s in all_sectors}, recent_dates, days
    )
    sector_map = _group_by_sector(all_sectors, sector_pct_map)

    # Rank by the largest single-day % change inside the window (not just the
    # latest day) and keep the top 20.
    sorted_sectors = sorted(
        sector_map.values(),
        key=lambda x: max((d["pct_change"] for d in x["daily_data"]), default=0),
        reverse=True,
    )[:20]

    result = {"trade_date": trade_date, "dates": recent_dates, "sectors": sorted_sectors}
    cache.set(cache_key, result, ttl=300)  # 5 minutes
    return result


def _finite_float(value, default: float = 0.0) -> float:
    """Convert ``value`` to float, mapping None/unparseable/NaN/inf to ``default``.

    pandas represents missing cells as NaN, which is truthy (so ``x or 0``
    does not catch it) and is rejected by strict JSON serialization.
    """
    try:
        number = float(value)
    except (TypeError, ValueError):
        return default
    return number if math.isfinite(number) else default


def _sector_name_map() -> dict:
    """Map THS index ts_code -> display name (empty when the list is unavailable)."""
    index_list = tushare_client.get_ths_index_list()
    if index_list.empty:
        return {}
    return dict(zip(index_list["ts_code"], index_list["name"]))


def _collect_moneyflow(recent_dates: list, name_map: dict) -> list[dict]:
    """Fetch sector money flow for each date as flat {code, name, date, net_amount} rows."""
    rows = []
    for td in recent_dates:
        df = tushare_client.get_sector_moneyflow(td)
        if df.empty:
            continue
        for _, row in df.iterrows():
            code = row.get("ts_code", "")
            # Prefer the moneyflow "industry" field; fall back to the index
            # name map. Only real non-empty strings count (NaN is truthy).
            industry = row.get("industry", "")
            if not isinstance(industry, str) or not industry:
                industry = name_map.get(code, code)
            rows.append({
                "sector_code": code,
                "sector_name": industry,
                "trade_date": td,
                "net_amount": round(_finite_float(row.get("net_amount", 0)), 2),
            })
    return rows


def _collect_pct_changes(codes, recent_dates: list, days: int) -> dict[str, dict[str, float]]:
    """Return {sector_code: {trade_date: pct_change}} for dates inside the window."""
    wanted_dates = set(recent_dates)
    pct_map: dict[str, dict[str, float]] = {}
    for code in codes:
        df_daily = tushare_client.get_ths_daily(code, days=days + 10)
        if df_daily.empty:
            continue
        for _, r in df_daily.iterrows():
            if r["trade_date"] in wanted_dates:
                pct_map.setdefault(code, {})[r["trade_date"]] = _finite_float(r.get("pct_change", 0))
    return pct_map


def _group_by_sector(all_sectors: list[dict], pct_map: dict[str, dict[str, float]]) -> dict[str, dict]:
    """Group flat money-flow rows per sector, attaching each day's % change (0 if unknown)."""
    grouped: dict[str, dict] = {}
    for s in all_sectors:
        code = s["sector_code"]
        if code not in grouped:
            grouped[code] = {
                "sector_code": code,
                "sector_name": s["sector_name"],
                "daily_data": [],
            }
        pct = pct_map.get(code, {}).get(s["trade_date"], 0)
        grouped[code]["daily_data"].append({
            "trade_date": s["trade_date"],
            "pct_change": round(pct, 2),
            "net_amount": s["net_amount"],
        })
    return grouped