astock-agent/backend/app/api/sectors.py
2026-04-16 16:40:56 +08:00

219 lines
7.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""板块分析 API"""
import logging
import math

from fastapi import APIRouter

from app.config import is_market_session
from app.data.tushare_client import tushare_client
from app.data.eastmoney_client import get_sector_realtime_ranking
from app.data.cache import cache
from app.engine.recommender import get_latest_sectors
# Module-level logger and the router that serves every /api/sectors endpoint.
logger = logging.getLogger(name=__name__)
router = APIRouter(
    prefix="/api/sectors",
    tags=["sectors"],
)
def _match_sector_name(em_name: str, ts_name: str) -> bool:
"""东方财富板块名与 Tushare 板块名模糊匹配
东方财富用"酿酒行业"Tushare 可能叫"白酒"
东方财富用"汽车整车"Tushare 可能叫"汽车"
用包含匹配(短名在长名中)或尾部去掉"行业"后完全匹配。
"""
# 去掉常见后缀再做比较
em_clean = em_name.rstrip("行业").rstrip("板块").rstrip("概念").strip()
ts_clean = ts_name.rstrip("行业").rstrip("板块").rstrip("概念").strip()
if em_clean == ts_clean:
return True
# 短名包含在长名中
short, long = (em_clean, ts_clean) if len(em_clean) <= len(ts_clean) else (ts_clean, em_clean)
return short in long
def _clear_realtime_fields(sector: dict) -> None:
    """Mark *sector* as having no realtime overlay (daily Tushare values only)."""
    sector["realtime_pct_change"] = None
    sector["realtime_limit_up_count"] = None
    sector["is_realtime"] = False


def _realtime_sort_key(sector: dict) -> float:
    """Sort key: the realtime pct change when present, else the daily pct change.

    An explicit ``is None`` check is used so that a genuine 0.0% realtime move is
    not silently replaced by the previous day's value.
    """
    realtime = sector.get("realtime_pct_change")
    if realtime is not None:
        return realtime
    return sector.get("pct_change") or 0


async def _enrich_sectors_realtime(sectors_data: list[dict]) -> list[dict]:
    """Overlay EastMoney realtime sector data onto daily sector records.

    Outside the market session, or when the EastMoney ranking cannot be fetched
    or is empty, every record gets ``realtime_pct_change=None``,
    ``realtime_limit_up_count=None`` and ``is_realtime=False``.

    Otherwise each record is matched to an EastMoney sector: first by exact
    name, then via ``_match_sector_name``. Matched records get
    ``realtime_pct_change`` / ``is_realtime=True``, and
    ``leading_stocks_realtime`` when EastMoney reports a leading stock. The list
    is then re-sorted in place by realtime pct change, falling back to the daily
    ``pct_change``.

    A single EastMoney request replaces the earlier approach of fetching
    thousands of constituent quotes from Tencent.

    Args:
        sectors_data: Sector dicts; each must contain ``"sector_name"``. They
            are mutated in place.

    Returns:
        The same list object, mutated (and re-sorted when realtime data was
        applied).
    """
    if not is_market_session():
        for sector in sectors_data:
            _clear_realtime_fields(sector)
        return sectors_data

    # Fetch the EastMoney realtime sector ranking (a single HTTP request).
    try:
        em_sectors = await get_sector_realtime_ranking()
    except Exception:
        # Best-effort enrichment: fall back to daily data, but keep the traceback
        # in the log so upstream failures remain diagnosable.
        logger.warning("东方财富板块实时数据获取失败,回退到日级数据", exc_info=True)
        em_sectors = None

    if not em_sectors:
        for sector in sectors_data:
            _clear_realtime_fields(sector)
        return sectors_data

    # Exact-name lookup table; on duplicate names the last entry wins.
    em_name_map = {em["sector_name"]: em for em in em_sectors}
    matched = 0
    for sector in sectors_data:
        ts_name = sector["sector_name"]
        # Exact name first, then the first fuzzy match in EastMoney's order.
        em_data = em_name_map.get(ts_name) or next(
            (em for em in em_sectors if _match_sector_name(em["sector_name"], ts_name)),
            None,
        )
        if not em_data:
            _clear_realtime_fields(sector)
            continue

        matched += 1
        sector["realtime_pct_change"] = em_data["pct_change"]
        sector["is_realtime"] = True
        # EastMoney does not supply a limit-up count; the Tushare daily value
        # remains available under "limit_up_count".
        sector["realtime_limit_up_count"] = None
        # EastMoney reports the sector's leading stock directly.
        if em_data.get("leading_stock_name"):
            sector["leading_stocks_realtime"] = [
                {
                    "ts_code": em_data.get("leading_stock_code", ""),
                    "name": em_data.get("leading_stock_name", ""),
                    "pct_chg": em_data.get("leading_stock_pct", 0),
                    "amount": 0,
                }
            ]

    logger.info("板块实时数据: %s/%s 匹配成功", matched, len(sectors_data))
    # During the session, rank by realtime performance.
    sectors_data.sort(key=_realtime_sort_key, reverse=True)
    return sectors_data
@router.get("/hot")
async def get_hot_sectors(limit: int = 10):
    """Return the hot-sector ranking, enriched with realtime data during market hours.

    Args:
        limit: Maximum number of sectors to return, taken from the head of
            ``get_latest_sectors()``. Negative values are treated as 0.

    Returns:
        A list of sector dicts: the daily analysis fields plus the realtime
        fields added by ``_enrich_sectors_realtime``.
    """
    # A negative slice bound would drop items from the tail instead of limiting
    # (limit=-1 would return every sector except the last one).
    limit = max(limit, 0)
    sectors = await get_latest_sectors()
    sectors_data = [
        {
            "sector_code": s.sector_code,
            "sector_name": s.sector_name,
            "pct_change": s.pct_change,
            "capital_inflow": s.capital_inflow,
            "limit_up_count": s.limit_up_count,
            "days_continuous": s.days_continuous,
            "heat_score": s.heat_score,
            "stage": s.stage,
            # Extended analysis fields
            "member_count": s.member_count,
            "leading_stocks": s.leading_stocks,
            "pct_trend": s.pct_trend,
            "turnover_avg": s.turnover_avg,
            "main_force_ratio": s.main_force_ratio,
        }
        for s in sectors[:limit]
    ]
    return await _enrich_sectors_realtime(sectors_data)
def _safe_float(value: object, default: float = 0.0) -> float:
    """Convert a table cell to a finite float, or return *default*.

    ``None``, empty strings, unparseable values, NaN and +/-inf all map to
    *default*. NaN/inf would otherwise reach the response and fail JSON encoding.
    """
    try:
        number = float(value)
    except (TypeError, ValueError):
        return default
    return number if math.isfinite(number) else default


def _load_sector_name_map() -> dict[str, str]:
    """Map THS sector index ``ts_code`` -> ``name`` (empty if the list is empty)."""
    index_list = tushare_client.get_ths_index_list()
    if index_list.empty:
        return {}
    return dict(zip(index_list["ts_code"], index_list["name"]))


def _collect_moneyflow(trade_dates: list, name_map: dict[str, str]) -> list[dict]:
    """Flatten per-day sector money-flow rows into one record per (sector, day)."""
    records = []
    for td in trade_dates:
        df = tushare_client.get_sector_moneyflow(td)
        if df.empty:
            continue
        for _, row in df.iterrows():
            code = row.get("ts_code", "")
            # Prefer the moneyflow "industry" label, falling back to name_map and
            # then to the code. Missing cells may arrive as NaN (truthy and not a
            # str), so only a non-empty string is accepted.
            industry = row.get("industry", "")
            if not isinstance(industry, str) or not industry:
                industry = name_map.get(code, code)
            records.append({
                "sector_code": code,
                "sector_name": industry,
                "trade_date": td,
                "net_amount": round(_safe_float(row.get("net_amount", 0)), 2),
            })
    return records


def _collect_pct_changes(sector_codes, wanted_dates: set, days: int) -> dict[str, dict[str, float]]:
    """Fetch sector daily bars and map ``code -> {trade_date: pct_change}`` for *wanted_dates*."""
    pct_map: dict[str, dict[str, float]] = {}
    for code in sector_codes:
        # Extra 10 bars of history as slack around non-trading days.
        df_daily = tushare_client.get_ths_daily(code, days=days + 10)
        if df_daily.empty:
            continue
        for _, r in df_daily.iterrows():
            if r["trade_date"] in wanted_dates:
                pct_map.setdefault(code, {})[r["trade_date"]] = _safe_float(r.get("pct_change", 0))
    return pct_map


def _group_by_sector(records: list[dict], pct_map: dict[str, dict[str, float]]) -> list[dict]:
    """Group (sector, day) records into one entry per sector with a ``daily_data`` series."""
    grouped: dict[str, dict] = {}
    for rec in records:
        code = rec["sector_code"]
        if code not in grouped:
            grouped[code] = {
                "sector_code": code,
                "sector_name": rec["sector_name"],
                "daily_data": [],
            }
        pct = pct_map.get(code, {}).get(rec["trade_date"], 0)
        grouped[code]["daily_data"].append({
            "trade_date": rec["trade_date"],
            "pct_change": round(pct, 2),
            "net_amount": rec["net_amount"],
        })
    return list(grouped.values())


@router.get("/rotation")
async def get_sector_rotation(days: int = 5):
    """Return sector rotation data for the last *days* trade days (heatmap feed).

    For each sector seen in the sector money-flow data, builds a ``daily_data``
    series of ``{trade_date, pct_change, net_amount}``. It returns the top 20
    sectors as ``{"trade_date", "dates", "sectors"}``. Results are cached for
    300 seconds per *days* value.

    Args:
        days: Number of most recent trade days to include. Values below 1 are
            treated as 1.

    NOTE(review): the tushare_client calls below are synchronous and run on the
    event loop; if they perform network I/O, consider ``asyncio.to_thread``.
    """
    # days <= 0 previously sliced as past_dates[-0:] / past_dates[n:], i.e.
    # (almost) the entire calendar, issuing one Tushare request per day.
    days = max(days, 1)
    cache_key = f"sector_rotation:{days}"
    cached = cache.get(cache_key)
    if cached is not None:
        return cached

    trade_date = tushare_client.get_latest_trade_date()
    # Sort defensively: taking the last N entries assumes ascending order.
    past_dates = sorted(d for d in tushare_client.get_trade_dates() if d <= trade_date)
    recent_dates = past_dates[-days:]

    records = _collect_moneyflow(recent_dates, _load_sector_name_map())
    sector_codes = {rec["sector_code"] for rec in records}
    pct_map = _collect_pct_changes(sector_codes, set(recent_dates), days)

    # NOTE(review): ranks by each sector's best single-day pct change in the
    # window (not the latest day's change); confirm which the heatmap intends.
    sorted_sectors = sorted(
        _group_by_sector(records, pct_map),
        key=lambda x: max((d["pct_change"] for d in x["daily_data"]), default=0),
        reverse=True,
    )[:20]

    result = {"trade_date": trade_date, "dates": recent_dates, "sectors": sorted_sectors}
    # Cache for 300 seconds (5 minutes).
    cache.set(cache_key, result, ttl=300)
    return result