150 lines
5.6 KiB
Python
150 lines
5.6 KiB
Python
"""板块分析 API"""
|
||
|
||
from fastapi import APIRouter
|
||
|
||
from app.analysis.sector_realtime import enrich_sectors_with_realtime, get_today_realtime_sector_board
|
||
from app.config import should_prefer_realtime_today, today_trade_date
|
||
from app.data.tushare_client import tushare_client
|
||
from app.data.cache import cache
|
||
from app.engine.recommender import get_latest_sectors
|
||
|
||
# Router for the sector-analysis endpoints: every path declared on it below is
# prefixed with /api/sectors and grouped under the "sectors" tag.
router = APIRouter(prefix="/api/sectors", tags=["sectors"])
|
||
|
||
|
||
def _sector_to_dict(s, trade_date, data_mode) -> dict:
    """Serialize one sector result object into a row of the /hot response.

    ``trade_date`` is written to both ``trade_date`` and ``structure_trade_date``.
    ``data_mode`` replaces the object's own ``data_mode`` so that every row of
    one response reports the same mode.
    """
    return {
        "sector_code": s.sector_code,
        "sector_name": s.sector_name,
        "pct_change": s.pct_change,
        "capital_inflow": s.capital_inflow,
        "limit_up_count": s.limit_up_count,
        "days_continuous": s.days_continuous,
        "heat_score": s.heat_score,
        "stage": s.stage,
        # Enhanced analysis fields
        "member_count": s.member_count,
        "leading_stocks": s.leading_stocks,
        "pct_trend": s.pct_trend,
        "turnover_avg": s.turnover_avg,
        "main_force_ratio": s.main_force_ratio,
        "trade_date": trade_date,
        "realtime_pct_change": s.realtime_pct_change,
        "realtime_limit_up_count": s.realtime_limit_up_count,
        "realtime_amount": s.realtime_amount,
        "realtime_turnover_rate": s.realtime_turnover_rate,
        "realtime_up_count": s.realtime_up_count,
        "realtime_down_count": s.realtime_down_count,
        "leading_stocks_realtime": s.leading_stocks_realtime,
        "is_realtime": s.is_realtime,
        "data_mode": data_mode,
        "structure_trade_date": trade_date,
    }


@router.get("/hot")
async def get_hot_sectors(limit: int = 10):
    """Return the hot-sector ranking, supplemented with realtime data when available.

    The latest stored snapshot comes from ``get_latest_sectors()``. If
    ``should_prefer_realtime_today`` says so, or the snapshot's trade date is
    not ``today_trade_date()``, the realtime sector board is tried first. The
    snapshot is enriched via ``enrich_sectors_with_realtime`` instead when:

    - the board returns nothing, or
    - the board is not consulted at all.

    Args:
        limit: Maximum number of sectors to return. Negative values are
            treated as 0 (empty result).

    Returns:
        A list of at most ``limit`` dicts, one per sector. Every row carries
        the same ``data_mode`` and ``structure_trade_date``. ``data_mode`` is
        the mode of the first realtime row, or ``"daily_snapshot"`` if no
        returned row is realtime.
    """
    # A negative limit would make ``sectors[:limit]`` drop rows from the end
    # instead of capping the count.
    limit = max(limit, 0)

    sectors = await get_latest_sectors()
    snapshot_trade_date = sectors[0].trade_date if sectors else ""

    realtime_sectors = None
    if should_prefer_realtime_today(snapshot_trade_date) or snapshot_trade_date != today_trade_date():
        # Always ask the realtime board for at least 20 rows.
        realtime_sectors = await get_today_realtime_sector_board(limit=max(limit, 20))

    if realtime_sectors:
        sectors = realtime_sectors
    else:
        sectors = await enrich_sectors_with_realtime(sectors)

    trade_date = sectors[0].trade_date if sectors else ""
    top = sectors[:limit]

    # Derive the response-wide mode from a row that is actually realtime.
    # sectors[0] is not necessarily one, and reading its mode could report a
    # non-realtime mode even though realtime rows are present.
    first_realtime = next((s for s in top if s.is_realtime), None)
    mode = first_realtime.data_mode if first_realtime is not None else "daily_snapshot"

    return [_sector_to_dict(s, trade_date, mode) for s in top]
|
||
|
||
|
||
# Maximum number of sectors returned by the rotation (heat map) endpoint.
_ROTATION_TOP_N = 20


def _to_float(value) -> float:
    """Coerce a raw cell value to a finite float.

    None, NaN, +/-inf and unparseable values all map to 0.0. The old
    ``float(x or 0)`` idiom let pandas' NaN through, because NaN is truthy.
    NaN/inf are not valid JSON and cannot be serialized into the response.
    """
    import math

    try:
        number = float(value)
    except (TypeError, ValueError):
        return 0.0
    return number if math.isfinite(number) else 0.0


def _recent_trade_dates(latest_trade_date, days: int) -> list:
    """Return the last ``days`` trading dates up to ``latest_trade_date``, oldest first."""
    # Sort explicitly: taking the tail is only correct on an ascending list,
    # and the order produced by get_trade_dates() is not visible here.
    past_dates = sorted(d for d in tushare_client.get_trade_dates() if d <= latest_trade_date)
    return past_dates[-days:]


def _sector_name_map() -> dict:
    """Map THS index ts_code -> name, used as a fallback sector name."""
    index_list = tushare_client.get_ths_index_list()
    if index_list.empty:
        return {}
    return dict(zip(index_list["ts_code"], index_list["name"]))


def _collect_moneyflow(recent_dates, name_map) -> list[dict]:
    """Flatten the sector moneyflow rows of each date in ``recent_dates`` into dicts."""
    records = []
    for td in recent_dates:
        df = tushare_client.get_sector_moneyflow(td)
        if df.empty:
            continue
        for _, row in df.iterrows():
            code = row.get("ts_code", "")
            industry = row.get("industry", "")
            # Prefer the moneyflow row's own industry label. A missing value may
            # arrive as NaN (truthy but not a str), so check the type before
            # falling back to the index-list name, or to the code itself.
            if isinstance(industry, str) and industry:
                sector_name = industry
            else:
                sector_name = name_map.get(code, code)
            records.append({
                "sector_code": code,
                "sector_name": sector_name,
                "trade_date": td,
                "net_amount": round(_to_float(row.get("net_amount", 0)), 2),
            })
    return records


def _collect_pct_changes(sector_codes, recent_dates, history_days: int) -> dict[str, dict[str, float]]:
    """Return ``{sector_code: {trade_date: pct_change}}``, restricted to ``recent_dates``."""
    wanted = set(recent_dates)
    pct_map: dict[str, dict[str, float]] = {}
    for code in sector_codes:
        df_daily = tushare_client.get_ths_daily(code, days=history_days)
        if df_daily.empty:
            continue
        for _, r in df_daily.iterrows():
            if r["trade_date"] in wanted:
                pct_map.setdefault(code, {})[r["trade_date"]] = _to_float(r.get("pct_change", 0))
    return pct_map


@router.get("/rotation")
async def get_sector_rotation(days: int = 5):
    """Return sector rotation data for the last N trading days (for a heat map).

    Results are cached per ``days`` value for 300 seconds.

    Args:
        days: Number of most recent trading days to include, up to and
            including the latest trade date. Values below 1 are treated as 1.

    Returns:
        A dict with three keys:

        - ``"trade_date"``: the latest trade date.
        - ``"dates"``: the selected dates, oldest first.
        - ``"sectors"``: at most 20 entries. Each is
          ``{"sector_code", "sector_name", "daily_data": [{"trade_date",
          "pct_change", "net_amount"}, ...]}``.
    """
    # A non-positive ``days`` made ``past_dates[-days:]`` return all or most of
    # the trade calendar, triggering one moneyflow request per historical date.
    days = max(days, 1)

    cache_key = f"sector_rotation:{days}"
    cached = cache.get(cache_key)
    if cached is not None:
        return cached

    # NOTE(review): the tushare_client calls below are invoked synchronously
    # inside an async endpoint. If they perform network I/O they block the
    # event loop; consider running this section in a threadpool.
    trade_date = tushare_client.get_latest_trade_date()
    recent_dates = _recent_trade_dates(trade_date, days)
    name_map = _sector_name_map()
    all_sectors = _collect_moneyflow(recent_dates, name_map)

    # dict.fromkeys dedupes while keeping first-seen order (deterministic, unlike set).
    sector_codes = list(dict.fromkeys(s["sector_code"] for s in all_sectors))
    # Request 10 extra days of history; presumably this keeps `days` trading
    # dates covered across weekends/holidays.
    sector_pct_map = _collect_pct_changes(sector_codes, recent_dates, days + 10)

    # Group the per-day rows by sector, in first-seen order.
    sector_map: dict[str, dict] = {}
    for s in all_sectors:
        code = s["sector_code"]
        if code not in sector_map:
            sector_map[code] = {
                "sector_code": code,
                "sector_name": s["sector_name"],
                "daily_data": [],
            }
        pct = sector_pct_map.get(code, {}).get(s["trade_date"], 0)
        sector_map[code]["daily_data"].append({
            "trade_date": s["trade_date"],
            "pct_change": round(pct, 2),
            "net_amount": s["net_amount"],
        })

    # Rank sectors by their best single-day pct_change in the window and keep
    # the top N. NOTE(review): an earlier comment said "sort by the latest
    # day's gain", but the code has always used the max over the window —
    # confirm which ordering the heat map expects.
    sorted_sectors = sorted(
        sector_map.values(),
        key=lambda x: max((d["pct_change"] for d in x["daily_data"]), default=0),
        reverse=True,
    )[:_ROTATION_TOP_N]

    result = {"trade_date": trade_date, "dates": recent_dates, "sectors": sorted_sectors}

    # Cache for 300 seconds (5 minutes).
    cache.set(cache_key, result, ttl=300)
    return result
|