167 lines
5.9 KiB
Python
167 lines
5.9 KiB
Python
"""板块分析 API"""
|
||
|
||
from fastapi import APIRouter
|
||
|
||
from app.data.tushare_client import tushare_client
|
||
from app.data.cache import cache
|
||
from app.engine.recommender import get_latest_sectors
|
||
|
||
# Router for the sector-analysis endpoints; every route registered on it is
# served under the /api/sectors prefix and tagged "sectors".
router = APIRouter(prefix="/api/sectors", tags=["sectors"])
|
||
|
||
|
||
@router.get("/hot")
async def get_hot_sectors(limit: int = 10):
    """Return the latest main-theme sector ranking as a list of dicts.

    Per the original author's note, page requests only read the scan
    conclusions already stored in the database; no external real-time quotes
    are fetched while serving this GET request.

    Args:
        limit: Number of leading records from ``get_latest_sectors()`` to
            include in the response.

    Returns:
        One dict per sector. ``trade_date`` and ``structure_trade_date`` carry
        the trade date of the first record (``""`` when there are no records),
        while ``data_mode`` and ``data_status`` are replaced by a single value
        shared by every row.
    """
    latest = await get_latest_sectors()
    snapshot_date = latest[0].trade_date if latest else ""

    # Attributes copied verbatim from each record, listed in output order.
    # "trade_date" sits between the two groups and comes from the first record.
    leading_fields = (
        "sector_code",
        "sector_name",
        "board_type",
        "theme_id",
        "theme_name",
        "theme_aliases",
        "pct_change",
        "capital_inflow",
        "limit_up_count",
        "days_continuous",
        "heat_score",
        "stage",
        # Enhanced-analysis fields.
        "member_count",
        "leading_stocks",
        "pct_trend",
        "turnover_avg",
        "main_force_ratio",
    )
    trailing_fields = (
        "realtime_pct_change",
        "realtime_limit_up_count",
        "realtime_amount",
        "realtime_turnover_rate",
        "realtime_up_count",
        "realtime_down_count",
        "leading_stocks_realtime",
        "is_realtime",
        "data_mode",
        "source",
        "data_status",
        "source_detail",
        "catalyst_score",
        "catalyst_count",
        "catalyst_reasons",
    )

    rows = []
    for sector in latest[:limit]:
        row = {name: getattr(sector, name) for name in leading_fields}
        row["trade_date"] = snapshot_date
        row.update((name, getattr(sector, name)) for name in trailing_fields)
        rows.append(row)

    # The response is labelled with the first record's data mode only when at
    # least one returned row is flagged as real-time.
    has_realtime = any(row["is_realtime"] for row in rows)
    if has_realtime and latest:
        mode = latest[0].data_mode
    else:
        mode = "daily_snapshot"

    # Must be computed before the per-row statuses are overwritten below.
    status = _derive_status(rows)

    for row in rows:
        row.update(
            data_mode=mode,
            data_status=status,
            structure_trade_date=snapshot_date,
        )
    return rows
|
||
|
||
|
||
def _derive_status(sectors: list[dict]) -> str:
|
||
statuses = {str(s.get("data_status") or "fresh") for s in sectors}
|
||
if not statuses:
|
||
return "snapshot"
|
||
if "fresh" in statuses:
|
||
return "fresh" if len(statuses) == 1 else "mixed"
|
||
if "stale" in statuses:
|
||
return "stale"
|
||
if "fallback" in statuses:
|
||
return "fallback"
|
||
return next(iter(statuses))
|
||
|
||
|
||
def _finite_or_zero(value) -> float:
    """Convert *value* to ``float``, mapping missing/NaN/infinite values to 0.0.

    NaN is truthy, so a plain ``float(value or 0)`` lets it through; NaN and
    infinity are not valid JSON numbers and would break response encoding.
    """
    import math  # Local import keeps this block self-contained.

    number = float(value or 0)
    return number if math.isfinite(number) else 0.0


@router.get("/rotation")
async def get_sector_rotation(days: int = 5):
    """Return sector rotation data for the last N trade dates (for the heatmap).

    Args:
        days: Number of most recent trade dates to include. Values below 1 are
            treated as 1; previously ``days=0`` selected the whole calendar
            because ``past_dates[-0:]`` is the full list.

    Returns:
        A dict with ``trade_date`` (latest trade date), ``dates`` (the selected
        trade dates) and ``sectors`` (at most 20 entries, each holding
        ``sector_code``, ``sector_name`` and a ``daily_data`` list of
        ``trade_date`` / ``pct_change`` / ``net_amount`` dicts). The result is
        cached per ``days`` value for 300 seconds.
    """
    days = max(days, 1)

    # Serve from cache when possible.
    cache_key = f"sector_rotation:{days}"
    cached = cache.get(cache_key)
    if cached is not None:
        return cached

    trade_date = tushare_client.get_latest_trade_date()

    # Trade calendar, restricted to dates up to the latest trade date.
    # NOTE(review): assumes get_trade_dates() yields comparable date values in
    # ascending order - confirm against the client.
    trade_dates = tushare_client.get_trade_dates()
    past_dates = [d for d in trade_dates if d <= trade_date]
    # A slice longer than the list simply returns the whole list.
    recent_dates = past_dates[-days:]
    recent_date_set = set(recent_dates)  # O(1) membership in the loops below

    # Sector index list, used as a fallback code -> name mapping.
    index_list = tushare_client.get_ths_index_list()
    name_map = {}
    if not index_list.empty:
        name_map = dict(zip(index_list["ts_code"], index_list["name"]))

    all_sectors = []
    for td in recent_dates:
        df = tushare_client.get_sector_moneyflow(td)
        if df.empty:
            continue
        for _, row in df.iterrows():
            code = row.get("ts_code", "")
            # Prefer the industry field from the moneyflow data and fall back
            # to name_map; a non-string value (e.g. NaN) counts as missing.
            industry = row.get("industry", "")
            if not isinstance(industry, str):
                industry = ""
            all_sectors.append({
                "sector_code": code,
                "sector_name": industry or name_map.get(code, code),
                "trade_date": td,
                "net_amount": round(_finite_or_zero(row.get("net_amount", 0)), 2),
            })

    # Fetch sector daily bars to fill in the percentage change per date.
    # dict.fromkeys de-duplicates codes while keeping a stable request order.
    sector_codes = list(dict.fromkeys(s["sector_code"] for s in all_sectors))
    sector_pct_map: dict[str, dict[str, float]] = {}
    for code in sector_codes:
        df_daily = tushare_client.get_ths_daily(code, days=days + 10)
        if df_daily.empty:
            continue
        for _, r in df_daily.iterrows():
            day = r["trade_date"]
            if day in recent_date_set:
                sector_pct_map.setdefault(code, {})[day] = _finite_or_zero(
                    r.get("pct_change", 0)
                )

    # Group the per-date records by sector.
    sector_map: dict[str, dict] = {}
    for s in all_sectors:
        code = s["sector_code"]
        entry = sector_map.setdefault(code, {
            "sector_code": code,
            "sector_name": s["sector_name"],
            "daily_data": [],
        })
        pct = sector_pct_map.get(code, {}).get(s["trade_date"], 0)
        entry["daily_data"].append({
            "trade_date": s["trade_date"],
            "pct_change": round(pct, 2),
            "net_amount": s["net_amount"],
        })

    # Rank by each sector's best single-day pct_change in the window and keep
    # the top 20.
    # NOTE(review): the original comment said "sort by the most recent day's
    # change"; the max-over-window behaviour is kept as-is - confirm intent.
    sorted_sectors = sorted(
        sector_map.values(),
        key=lambda x: max((d["pct_change"] for d in x["daily_data"]), default=0),
        reverse=True,
    )[:20]

    result = {"trade_date": trade_date, "dates": recent_dates, "sectors": sorted_sectors}

    # Cache for 300 seconds (5 minutes).
    cache.set(cache_key, result, ttl=300)

    return result
|