astock-agent/backend/app/api/sectors.py
2026-04-15 08:58:21 +08:00

194 lines
6.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""板块分析 API"""
import logging
import math

from fastapi import APIRouter

from app.config import is_market_session
from app.data.tushare_client import tushare_client
from app.data.tencent_client import get_realtime_quotes_batch
from app.engine.recommender import get_latest_sectors
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router for the sector-analysis endpoints: every path declared below is
# mounted under /api/sectors and tagged "sectors".
router = APIRouter(
    prefix="/api/sectors",
    tags=["sectors"],
)
def _reset_realtime_fields(sectors_data: list[dict], *, is_realtime: bool) -> list[dict]:
    """Blank out the realtime fields of every sector and set ``is_realtime``.

    Used on every fallback path so each sector dict always carries the same
    keys (``realtime_pct_change``, ``realtime_limit_up_count``,
    ``leading_stocks_realtime``, ``is_realtime``) whether or not realtime data
    was obtained. Mutates and returns ``sectors_data``.
    """
    for sector in sectors_data:
        sector["realtime_pct_change"] = None
        sector["realtime_limit_up_count"] = None
        sector["leading_stocks_realtime"] = None
        sector["is_realtime"] = is_realtime
    return sectors_data


def _load_sector_members(
    sectors_data: list[dict],
) -> tuple[dict[str, list[str]], list[str]]:
    """Look up the constituent stock codes of each sector.

    Returns:
        A ``{sector_code: [member_code, ...]}`` map, and a de-duplicated,
        order-preserving list of all member codes (a stock may belong to
        several sectors but only needs to be quoted once). A sector whose
        lookup fails gets an empty member list.
    """
    sector_members: dict[str, list[str]] = {}
    for sector in sectors_data:
        code = sector["sector_code"]
        try:
            df = tushare_client.get_ths_members(code)
            members = df["con_code"].tolist() if not df.empty else []
        except Exception:
            # Best effort: one failing sector must not break the whole listing,
            # but leave a trace instead of swallowing the error silently.
            logger.debug("获取板块成分股失败: %s", code, exc_info=True)
            members = []
        sector_members[code] = members
    all_codes = list(
        dict.fromkeys(c for members in sector_members.values() for c in members)
    )
    return sector_members, all_codes


async def _enrich_sectors_realtime(sectors_data: list[dict]) -> list[dict]:
    """Enrich sectors with realtime Tencent quote data during market hours.

    For each sector this sets ``realtime_pct_change`` (mean member change,
    2 d.p.), ``realtime_limit_up_count`` (members within 0.5% of their
    limit-up price), ``leading_stocks_realtime`` (top-3 members by change) and
    ``is_realtime``. Outside market hours, or when no member codes or quotes
    can be obtained, the realtime fields are ``None`` and ``is_realtime`` is
    False. On success the list is re-sorted by realtime change, with sectors
    lacking realtime data placed last.

    Args:
        sectors_data: sector dicts, each with at least a ``sector_code`` key.

    Returns:
        The same list, mutated in place.
    """
    if not is_market_session():
        return _reset_realtime_fields(sectors_data, is_realtime=False)

    # NOTE(review): get_ths_members is called synchronously once per sector from
    # inside a coroutine; if it does network I/O it blocks the event loop.
    sector_members, all_codes = _load_sector_members(sectors_data)
    if not all_codes:
        # No constituents resolved, so no realtime data exists for any sector.
        return _reset_realtime_fields(sectors_data, is_realtime=False)

    try:
        quotes = await get_realtime_quotes_batch(all_codes)
    except Exception:
        logger.warning("获取板块实时行情失败,回退到日级数据", exc_info=True)
        return _reset_realtime_fields(sectors_data, is_realtime=False)

    for sector in sectors_data:
        members = sector_members.get(sector["sector_code"], [])
        member_quotes = [quotes[c] for c in members if c in quotes]
        if member_quotes:
            pct_changes = [q.pct_chg for q in member_quotes]
            sector["realtime_pct_change"] = round(sum(pct_changes) / len(pct_changes), 2)
            # Members trading at (or within 0.5% of) their limit-up price.
            sector["realtime_limit_up_count"] = sum(
                1 for q in member_quotes
                if q.limit_up and q.price >= q.limit_up * 0.995
            )
            # Intraday leaders: top 3 members by realtime change.
            top_quotes = sorted(member_quotes, key=lambda q: q.pct_chg, reverse=True)[:3]
            sector["leading_stocks_realtime"] = [
                {
                    "ts_code": q.ts_code,
                    "name": q.name or q.ts_code,
                    "pct_chg": round(q.pct_chg, 2),
                    "amount": round(q.amount, 0),
                }
                for q in top_quotes
            ]
        else:
            sector["realtime_pct_change"] = None
            sector["realtime_limit_up_count"] = None
            sector["leading_stocks_realtime"] = None
        sector["is_realtime"] = True

    # Re-rank by realtime change. Sectors without realtime data go last (the
    # sort is stable, so they keep their incoming order) instead of being
    # ranked as if they were flat at 0%.
    sectors_data.sort(
        key=lambda s: (
            s["realtime_pct_change"] is not None,
            s["realtime_pct_change"] or 0,
        ),
        reverse=True,
    )
    return sectors_data
@router.get("/hot")
async def get_hot_sectors(limit: int = 10):
    """Return the hot-sector ranking, enriched with realtime data intraday.

    Takes the first ``limit`` sectors returned by ``get_latest_sectors``,
    converts them to dicts, and passes them through
    ``_enrich_sectors_realtime`` (which adds realtime fields during market
    hours).

    Args:
        limit: maximum number of sectors to return; negative values are
            treated as 0 (empty result).
    """
    # Clamp: a negative slice bound would drop sectors from the END of the
    # ranking (e.g. sectors[:-3]) instead of limiting the count.
    limit = max(limit, 0)
    sectors = await get_latest_sectors()
    sectors_data = [
        {
            "sector_code": s.sector_code,
            "sector_name": s.sector_name,
            "pct_change": s.pct_change,
            "capital_inflow": s.capital_inflow,
            "limit_up_count": s.limit_up_count,
            "days_continuous": s.days_continuous,
            "heat_score": s.heat_score,
            "stage": s.stage,
            # Enhanced analysis fields
            "member_count": s.member_count,
            "leading_stocks": s.leading_stocks,
            "pct_trend": s.pct_trend,
            "turnover_avg": s.turnover_avg,
            "main_force_ratio": s.main_force_ratio,
        }
        for s in sectors[:limit]
    ]
    return await _enrich_sectors_realtime(sectors_data)
def _safe_float(value) -> float:
    """Convert a raw tushare cell to ``float``.

    ``None``, empty strings, unparseable values, NaN and +/-inf all map to
    0.0. NaN is truthy, so ``float(x or 0)`` would let it through, and strict
    JSON has no NaN, so it must not reach the response payload.
    """
    try:
        number = float(value)
    except (TypeError, ValueError):
        return 0.0
    return number if math.isfinite(number) else 0.0


def _rotation_trade_dates(latest_trade_date, days: int) -> list:
    """Return the last ``days`` trading dates up to and including ``latest_trade_date``.

    Returns an empty list when ``days <= 0``. Assumes
    ``tushare_client.get_trade_dates()`` yields dates comparable with
    ``latest_trade_date`` in ascending order (TODO confirm against the client).
    """
    # Guard first: past_dates[-0:] would be the WHOLE calendar, and a negative
    # ``days`` would slice from the front instead of the end.
    if days <= 0:
        return []
    past_dates = [d for d in tushare_client.get_trade_dates() if d <= latest_trade_date]
    # Slicing already returns everything when fewer than ``days`` dates exist.
    return past_dates[-days:]


def _rotation_name_map() -> dict:
    """Map sector index codes (``ts_code``) to display names from the THS index list."""
    index_list = tushare_client.get_ths_index_list()
    if index_list.empty:
        return {}
    return dict(zip(index_list["ts_code"], index_list["name"]))


def _rotation_moneyflow(recent_dates: list, name_map: dict) -> list[dict]:
    """Collect one record per sector money-flow row for each date in ``recent_dates``."""
    records: list[dict] = []
    for td in recent_dates:
        df = tushare_client.get_sector_moneyflow(td)
        if df.empty:
            continue
        for _, row in df.iterrows():
            code = row.get("ts_code", "")
            industry = row.get("industry", "")
            # Prefer the money-flow "industry" label, then the index name, then
            # the raw code. The isinstance check also rejects NaN (truthy float).
            if isinstance(industry, str) and industry:
                sector_name = industry
            else:
                sector_name = name_map.get(code, code)
            records.append({
                "sector_code": code,
                "sector_name": sector_name,
                "trade_date": td,
                "net_amount": round(_safe_float(row.get("net_amount", 0)), 2),
            })
    return records


def _rotation_pct_changes(sector_codes: list, recent_dates: list, days: int) -> dict[str, dict]:
    """Return ``{sector_code: {trade_date: pct_change}}`` from daily bars within ``recent_dates``."""
    wanted_dates = set(recent_dates)  # O(1) membership instead of a list scan per bar
    pct_map: dict[str, dict] = {}
    for code in sector_codes:
        # Request days + 10 bars; presumably a buffer for non-trading days — confirm.
        df_daily = tushare_client.get_ths_daily(code, days=days + 10)
        if df_daily.empty:
            continue
        for _, bar in df_daily.iterrows():
            if bar["trade_date"] in wanted_dates:
                pct_map.setdefault(code, {})[bar["trade_date"]] = _safe_float(
                    bar.get("pct_change", 0)
                )
    return pct_map


@router.get("/rotation")
async def get_sector_rotation(days: int = 5):
    """Return sector rotation data for the last ``days`` trading days (heatmap feed).

    Response shape: ``{"trade_date": ..., "dates": [...], "sectors": [...]}``.
    Each sector has ``sector_code``, ``sector_name`` and ``daily_data``, a list
    of ``{"trade_date", "pct_change", "net_amount"}`` entries. At most 20
    sectors are returned. ``days <= 0`` yields empty ``dates`` and ``sectors``.
    """
    # NOTE(review): every tushare call here is synchronous inside an async
    # endpoint; if they perform network I/O they block the event loop.
    trade_date = tushare_client.get_latest_trade_date()
    recent_dates = _rotation_trade_dates(trade_date, days)
    if not recent_dates:
        return {"trade_date": trade_date, "dates": [], "sectors": []}

    records = _rotation_moneyflow(recent_dates, _rotation_name_map())
    sector_codes = list(dict.fromkeys(r["sector_code"] for r in records))
    pct_map = _rotation_pct_changes(sector_codes, recent_dates, days)

    # Group per-day records by sector; the first record seen supplies the name.
    sector_map: dict[str, dict] = {}
    for rec in records:
        code = rec["sector_code"]
        if code not in sector_map:
            sector_map[code] = {
                "sector_code": code,
                "sector_name": rec["sector_name"],
                "daily_data": [],
            }
        pct = pct_map.get(code, {}).get(rec["trade_date"], 0)
        sector_map[code]["daily_data"].append({
            "trade_date": rec["trade_date"],
            "pct_change": round(pct, 2),
            "net_amount": rec["net_amount"],
        })

    # Rank by each sector's largest single-day change within the window (not
    # only the latest day) and keep the top 20.
    sorted_sectors = sorted(
        sector_map.values(),
        key=lambda x: max((d["pct_change"] for d in x["daily_data"]), default=0),
        reverse=True,
    )[:20]
    return {"trade_date": trade_date, "dates": recent_dates, "sectors": sorted_sectors}