astock-agent/backend/app/api/sectors.py
2026-04-23 23:24:54 +08:00

155 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""板块分析 API"""
from fastapi import APIRouter
from app.analysis.sector_realtime import enrich_sectors_with_realtime, get_today_realtime_sector_board
from app.config import should_prefer_realtime_today, today_trade_date
from app.data.tushare_client import tushare_client
from app.data.cache import cache
from app.engine.recommender import get_latest_sectors
# Every route declared in this module is served under /api/sectors and grouped under the "sectors" tag.
router = APIRouter(tags=["sectors"], prefix="/api/sectors")
def _serialize_sector(sector, trade_date: str) -> dict:
    """Flatten one sector object into a response row for ``/api/sectors/hot``.

    ``trade_date`` is written into every row in place of the sector's own
    ``trade_date``, so all rows in one response report the same date.
    ``data_mode`` is copied as-is here; the caller normalizes it across rows.
    """
    return {
        "sector_code": sector.sector_code,
        "sector_name": sector.sector_name,
        "board_type": sector.board_type,
        "theme_id": sector.theme_id,
        "theme_name": sector.theme_name,
        "theme_aliases": sector.theme_aliases,
        "pct_change": sector.pct_change,
        "capital_inflow": sector.capital_inflow,
        "limit_up_count": sector.limit_up_count,
        "days_continuous": sector.days_continuous,
        "heat_score": sector.heat_score,
        "stage": sector.stage,
        # Enhanced analysis fields
        "member_count": sector.member_count,
        "leading_stocks": sector.leading_stocks,
        "pct_trend": sector.pct_trend,
        "turnover_avg": sector.turnover_avg,
        "main_force_ratio": sector.main_force_ratio,
        "trade_date": trade_date,
        # Realtime fields
        "realtime_pct_change": sector.realtime_pct_change,
        "realtime_limit_up_count": sector.realtime_limit_up_count,
        "realtime_amount": sector.realtime_amount,
        "realtime_turnover_rate": sector.realtime_turnover_rate,
        "realtime_up_count": sector.realtime_up_count,
        "realtime_down_count": sector.realtime_down_count,
        "leading_stocks_realtime": sector.leading_stocks_realtime,
        "is_realtime": sector.is_realtime,
        "data_mode": sector.data_mode,
        "source": sector.source,
    }


@router.get("/hot")
async def get_hot_sectors(limit: int = 10):
    """Return today's main-theme sector ranking.

    Intraday, the stored snapshot is replaced by (or enriched with) realtime
    data, and every returned row is normalized to a single data mode and
    structure trade date.

    Args:
        limit: Maximum number of rows to return. Negative values are treated
            as 0; slicing with a negative bound would otherwise drop rows from
            the end and return nearly everything.

    Returns:
        A list of dicts, one per sector, in ranking order. Every row carries
        the same ``data_mode`` and ``structure_trade_date`` values.
    """
    limit = max(limit, 0)

    sectors = await get_latest_sectors()
    snapshot_trade_date = sectors[0].trade_date if sectors else ""

    # Try the live sector board when the config asks to prefer realtime data
    # for this snapshot, or when the stored snapshot is not from today's
    # trade date.
    wants_realtime_board = (
        should_prefer_realtime_today(snapshot_trade_date)
        or snapshot_trade_date != today_trade_date()
    )
    realtime_sectors = None
    if wants_realtime_board:
        # Request at least 20 rows, or more when the caller asked for more.
        realtime_sectors = await get_today_realtime_sector_board(limit=max(limit, 20))

    if realtime_sectors:
        sectors = realtime_sectors
    else:
        # No live board was needed, or it came back empty: overlay realtime
        # fields onto the stored snapshot instead.
        sectors = await enrich_sectors_with_realtime(sectors)

    trade_date = sectors[0].trade_date if sectors else ""
    sectors_data = [_serialize_sector(s, trade_date) for s in sectors[:limit]]

    # Use one data mode for the whole response. If any returned row is
    # realtime, use the mode of the first realtime row. Reading sectors[0]
    # could pick a non-realtime mode when only later rows are realtime.
    mode = next(
        (row["data_mode"] for row in sectors_data if row.get("is_realtime")),
        "daily_snapshot",
    )
    for row in sectors_data:
        row["data_mode"] = mode
        row["structure_trade_date"] = trade_date
    return sectors_data
# Maximum number of sectors kept in the rotation heat map.
_ROTATION_TOP_N = 20
# Lifetime, in seconds, of a cached rotation payload (5 minutes).
_ROTATION_CACHE_TTL_SECONDS = 300
# Extra days of daily bars requested beyond the window size.
# NOTE(review): presumably a margin for calendar gaps; confirm against get_ths_daily.
_ROTATION_DAILY_MARGIN = 10


def _to_float(value, default: float = 0.0) -> float:
    """Coerce ``value`` to a finite float, returning ``default`` otherwise.

    Values read from pandas rows may be ``None`` or ``NaN``. ``NaN`` is truthy,
    so the ``float(x or 0)`` idiom lets it through. Strict JSON encoders such
    as Starlette's ``JSONResponse`` (``allow_nan=False``) then refuse to
    serialize the response.
    """
    import math

    try:
        number = float(value)
    except (TypeError, ValueError):
        return default
    return number if math.isfinite(number) else default


def _recent_trade_dates(trade_dates, latest_trade_date, days: int) -> list:
    """Return the last ``days`` trading dates up to ``latest_trade_date``, oldest first.

    A non-positive ``days`` yields an empty list. Slicing with ``[-0:]`` would
    return the entire calendar, and a negative value nearly all of it. Dates
    are sorted here because the ordering of ``get_trade_dates()`` is not
    visible from this module.
    """
    if days <= 0:
        return []
    past = sorted({d for d in trade_dates if d <= latest_trade_date})
    return past[-days:]


def _load_sector_name_map() -> dict:
    """Map each THS sector index ``ts_code`` to its display ``name``."""
    index_list = tushare_client.get_ths_index_list()
    if index_list.empty:
        return {}
    return dict(zip(index_list["ts_code"], index_list["name"]))


def _collect_moneyflow_rows(trade_dates: list, name_map: dict) -> list[dict]:
    """Flatten per-day sector money-flow data into one record per (sector, day)."""
    rows: list[dict] = []
    for td in trade_dates:
        df = tushare_client.get_sector_moneyflow(td)
        if df.empty:
            continue
        for _, row in df.iterrows():
            code = row.get("ts_code", "")
            # Prefer the moneyflow row's own industry name, then the index
            # list, then the code itself. The isinstance check keeps a missing
            # (NaN) industry from being used as the name.
            industry = row.get("industry", "")
            name = industry if isinstance(industry, str) and industry else name_map.get(code, code)
            rows.append({
                "sector_code": code,
                "sector_name": name,
                "trade_date": td,
                "net_amount": round(_to_float(row.get("net_amount", 0)), 2),
            })
    return rows


def _collect_pct_changes(codes, trade_dates: list, lookback_days: int) -> dict[str, dict[str, float]]:
    """Return ``{sector_code: {trade_date: pct_change}}`` restricted to ``trade_dates``."""
    wanted = set(trade_dates)
    pct_map: dict[str, dict[str, float]] = {}
    for code in codes:
        df_daily = tushare_client.get_ths_daily(code, days=lookback_days)
        if df_daily.empty:
            continue
        for _, bar in df_daily.iterrows():
            if bar["trade_date"] in wanted:
                pct_map.setdefault(code, {})[bar["trade_date"]] = _to_float(bar.get("pct_change", 0))
    return pct_map


def _group_by_sector(rows: list[dict], pct_map: dict[str, dict[str, float]]) -> list[dict]:
    """Group flat records into one entry per sector with a ``daily_data`` series.

    The first record seen for a sector supplies its display name. Days with no
    ``pct_change`` data report 0. Sectors keep first-seen order.
    """
    grouped: dict[str, dict] = {}
    for rec in rows:
        code = rec["sector_code"]
        if code not in grouped:
            grouped[code] = {
                "sector_code": code,
                "sector_name": rec["sector_name"],
                "daily_data": [],
            }
        pct = pct_map.get(code, {}).get(rec["trade_date"], 0)
        grouped[code]["daily_data"].append({
            "trade_date": rec["trade_date"],
            "pct_change": round(pct, 2),
            "net_amount": rec["net_amount"],
        })
    return list(grouped.values())


@router.get("/rotation")
async def get_sector_rotation(days: int = 5):
    """Return the last N trading days of sector data for the rotation heat map.

    Args:
        days: Number of most recent trading days to include. Non-positive
            values produce an empty ``dates`` list and no sectors.

    Returns:
        ``{"trade_date": <latest trade date>, "dates": [...oldest first...],
        "sectors": [{"sector_code", "sector_name", "daily_data": [...]}]}``
        with at most ``_ROTATION_TOP_N`` sectors. Results are cached per
        ``days`` value for ``_ROTATION_CACHE_TTL_SECONDS`` seconds.
    """
    cache_key = f"sector_rotation:{days}"
    cached = cache.get(cache_key)
    if cached is not None:
        return cached

    # NOTE(review): the tushare_client calls below are synchronous (not
    # awaited), so they run on the event-loop thread. If they perform network
    # I/O, consider offloading this work (e.g. asyncio.to_thread).
    trade_date = tushare_client.get_latest_trade_date()
    # NOTE(review): get_trade_dates() is iterated directly, so it is assumed
    # to yield date values comparable with trade_date. Confirm.
    recent_dates = _recent_trade_dates(tushare_client.get_trade_dates(), trade_date, days)

    name_map = _load_sector_name_map()
    rows = _collect_moneyflow_rows(recent_dates, name_map)

    # Daily bars supply the per-day pct_change for each sector seen in the moneyflow data.
    sector_codes = {rec["sector_code"] for rec in rows}
    pct_map = _collect_pct_changes(sector_codes, recent_dates, days + _ROTATION_DAILY_MARGIN)

    # Rank sectors by their strongest single-day gain within the window and keep the top N.
    # NOTE(review): this ranks by the peak gain across the window, not by the
    # most recent day's gain. Confirm which ordering the heat map intends.
    ranked = sorted(
        _group_by_sector(rows, pct_map),
        key=lambda sector: max((d["pct_change"] for d in sector["daily_data"]), default=0),
        reverse=True,
    )[:_ROTATION_TOP_N]

    result = {"trade_date": trade_date, "dates": recent_dates, "sectors": ranked}
    cache.set(cache_key, result, ttl=_ROTATION_CACHE_TTL_SECONDS)
    return result