astock-agent/backend/app/api/sectors.py
2026-05-14 11:10:17 +08:00

167 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""板块分析 API"""
import math

from fastapi import APIRouter

from app.data.tushare_client import tushare_client
from app.data.cache import cache
from app.engine.recommender import get_latest_sectors
# Router for this module: every endpoint below is mounted under the
# /api/sectors prefix and grouped under the "sectors" tag.
router = APIRouter(prefix="/api/sectors", tags=["sectors"])
@router.get("/hot")
async def get_hot_sectors(limit: int = 10):
    """Return the latest mainline theme / sector ranking.

    Page visits only read the scan conclusions already stored in the
    database; this GET handler does not pull external realtime quotes.

    Args:
        limit: Maximum number of sectors to return. Non-positive values
            yield an empty list (a negative slice bound would otherwise
            mean "drop the last N" and return almost every sector).

    Returns:
        One dict per sector, in the order returned by
        ``get_latest_sectors()``. ``data_mode``, ``data_status`` and
        ``structure_trade_date`` are overwritten with a single value shared
        by all rows.
    """
    sectors = await get_latest_sectors()
    # Every row is labelled with the first record's trade date.
    trade_date = sectors[0].trade_date if sectors else ""
    top_sectors = sectors[: max(limit, 0)]
    sectors_data = [
        {
            "sector_code": s.sector_code,
            "sector_name": s.sector_name,
            "board_type": s.board_type,
            "theme_id": s.theme_id,
            "theme_name": s.theme_name,
            "theme_aliases": s.theme_aliases,
            "pct_change": s.pct_change,
            "capital_inflow": s.capital_inflow,
            "limit_up_count": s.limit_up_count,
            "days_continuous": s.days_continuous,
            "heat_score": s.heat_score,
            "stage": s.stage,
            # Enhanced analysis fields
            "member_count": s.member_count,
            "leading_stocks": s.leading_stocks,
            "pct_trend": s.pct_trend,
            "turnover_avg": s.turnover_avg,
            "main_force_ratio": s.main_force_ratio,
            "trade_date": trade_date,
            "realtime_pct_change": s.realtime_pct_change,
            "realtime_limit_up_count": s.realtime_limit_up_count,
            "realtime_amount": s.realtime_amount,
            "realtime_turnover_rate": s.realtime_turnover_rate,
            "realtime_up_count": s.realtime_up_count,
            "realtime_down_count": s.realtime_down_count,
            "leading_stocks_realtime": s.leading_stocks_realtime,
            "is_realtime": s.is_realtime,
            "data_mode": s.data_mode,
            "source": s.source,
            "data_status": s.data_status,
            "source_detail": s.source_detail,
            "catalyst_score": s.catalyst_score,
            "catalyst_count": s.catalyst_count,
            "catalyst_reasons": s.catalyst_reasons,
        }
        for s in top_sectors
    ]
    # If any returned row is realtime, label the whole response with the mode
    # of a row that actually IS realtime (sectors[0] need not be one);
    # otherwise the response is a daily snapshot.
    mode = next(
        (row["data_mode"] for row in sectors_data if row["is_realtime"]),
        "daily_snapshot",
    )
    status = _derive_status(sectors_data)
    for row in sectors_data:
        row["data_mode"] = mode
        row["data_status"] = status
        row["structure_trade_date"] = trade_date
    return sectors_data
def _derive_status(sectors: list[dict]) -> str:
statuses = {str(s.get("data_status") or "fresh") for s in sectors}
if not statuses:
return "snapshot"
if "fresh" in statuses:
return "fresh" if len(statuses) == 1 else "mixed"
if "stale" in statuses:
return "stale"
if "fallback" in statuses:
return "fallback"
return next(iter(statuses))
@router.get("/rotation")
async def get_sector_rotation(days: int = 5):
    """Return sector rotation data for the last ``days`` trading days (heatmap).

    Results are cached for 300 seconds per ``days`` value.

    Args:
        days: Size of the trading-day window. Values ``<= 0`` produce an
            empty window rather than the whole trading calendar.

    Returns:
        ``{"trade_date": ..., "dates": [...], "sectors": [...]}`` where
        ``dates`` is the window and each of the (at most 20) sector entries
        has ``sector_code``, ``sector_name`` and ``daily_data`` -- a list of
        ``{"trade_date", "pct_change", "net_amount"}`` dicts.
    """
    cache_key = f"sector_rotation:{days}"
    cached = cache.get(cache_key)
    if cached is not None:
        return cached

    # NOTE(review): the tushare_client calls below are not awaited, so they run
    # synchronously on the event-loop thread. If they perform network I/O they
    # block other requests for the whole duration (one call per window date
    # plus one per sector) -- consider moving this work off the loop.
    trade_date = tushare_client.get_latest_trade_date()
    recent_dates = _recent_trade_dates(tushare_client.get_trade_dates(), trade_date, days)
    name_map = _build_name_map(tushare_client.get_ths_index_list())

    moneyflow_records = _collect_moneyflow(recent_dates, name_map)
    # ``days + 10`` daily bars per sector -- presumably headroom so the bars
    # still cover every date in the window; TODO confirm.
    pct_map = _collect_pct_changes(
        {rec["sector_code"] for rec in moneyflow_records},
        set(recent_dates),
        days + 10,
    )
    sectors = _group_by_sector(moneyflow_records, pct_map)

    # Rank by the largest single-day pct change inside the window; keep top 20.
    # NOTE(review): this ranks by the window maximum, not by the most recent
    # day's change -- confirm which ordering the heatmap intends.
    top_sectors = sorted(
        sectors,
        key=lambda sector: max((d["pct_change"] for d in sector["daily_data"]), default=0),
        reverse=True,
    )[:20]

    result = {"trade_date": trade_date, "dates": recent_dates, "sectors": top_sectors}
    cache.set(cache_key, result, ttl=300)  # TTL 300 s (5 minutes)
    return result


def _safe_float(value) -> float:
    """Convert ``value`` to float, mapping ``None``/falsy values and NaN to 0.0.

    NaN is truthy, so the ``value or 0`` idiom lets it through, and JSON has
    no NaN literal; coercing it here keeps the payload serializable.
    """
    if not value:
        return 0.0
    number = float(value)
    return 0.0 if math.isnan(number) else number


def _recent_trade_dates(trade_dates, today, days: int) -> list:
    """Return the last ``days`` entries of ``trade_dates`` that are ``<= today``.

    Keeps the tail of the calendar, so this assumes ``trade_dates`` is sorted
    ascending -- TODO confirm. ``days <= 0`` returns ``[]``; without the guard
    ``past[-0:]`` would be the whole list.
    """
    if days <= 0:
        return []
    past_dates = [d for d in trade_dates if d <= today]
    return past_dates[-days:]


def _build_name_map(index_list) -> dict:
    """Map THS index ``ts_code`` -> ``name``; empty when ``index_list`` has no rows.

    ``index_list`` is assumed to be a pandas DataFrame (``.empty`` and
    ``.iterrows()`` are used).
    """
    if index_list.empty:
        return {}
    return {row["ts_code"]: row["name"] for _, row in index_list.iterrows()}


def _sector_display_name(row, code, name_map: dict):
    """Prefer the row's ``industry`` text; otherwise fall back to ``name_map``/``code``.

    Non-string values (e.g. a NaN from pandas) are treated as missing.
    """
    industry = row.get("industry", "")
    if isinstance(industry, str) and industry:
        return industry
    return name_map.get(code, code)


def _collect_moneyflow(dates: list, name_map: dict) -> list[dict]:
    """Fetch sector money flow for each date as flat per-(sector, date) records.

    Dates with no rows are skipped. Each record holds ``sector_code``,
    ``sector_name``, ``trade_date`` and ``net_amount`` (rounded to 2 places).
    """
    records = []
    for td in dates:
        df = tushare_client.get_sector_moneyflow(td)
        if df.empty:
            continue
        for _, row in df.iterrows():
            code = row.get("ts_code", "")
            records.append({
                "sector_code": code,
                "sector_name": _sector_display_name(row, code, name_map),
                "trade_date": td,
                "net_amount": round(_safe_float(row.get("net_amount")), 2),
            })
    return records


def _collect_pct_changes(codes, wanted_dates: set, lookback: int) -> dict[str, dict[str, float]]:
    """Return ``{sector_code: {trade_date: pct_change}}`` limited to ``wanted_dates``.

    ``lookback`` is passed as ``days`` to ``get_ths_daily``. Sectors with no
    matching daily rows are omitted from the mapping.
    """
    pct_map: dict[str, dict[str, float]] = {}
    for code in codes:
        df_daily = tushare_client.get_ths_daily(code, days=lookback)
        if df_daily.empty:
            continue
        for _, r in df_daily.iterrows():
            if r["trade_date"] in wanted_dates:
                pct_map.setdefault(code, {})[r["trade_date"]] = _safe_float(r.get("pct_change"))
    return pct_map


def _group_by_sector(records: list[dict], pct_map: dict) -> list[dict]:
    """Group flat money-flow records into one heatmap row per sector.

    The sector name comes from the first record seen for a code;
    ``daily_data`` keeps record order and a missing pct change becomes 0.
    """
    grouped: dict[str, dict] = {}
    for rec in records:
        code = rec["sector_code"]
        if code not in grouped:
            grouped[code] = {
                "sector_code": code,
                "sector_name": rec["sector_name"],
                "daily_data": [],
            }
        pct = pct_map.get(code, {}).get(rec["trade_date"], 0)
        grouped[code]["daily_data"].append({
            "trade_date": rec["trade_date"],
            "pct_change": round(pct, 2),
            "net_amount": rec["net_amount"],
        })
    return list(grouped.values())