astock-agent/backend/app/data/market_breadth_client.py
2026-04-28 13:15:11 +08:00

138 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""市场广度客户端。
职责:
1. 上涨/下跌/平盘家数:优先用东方财富全市场分页行情聚合
2. 涨停/跌停家数:优先用东方财富专门池接口
3. 接口不可用时回退到实时行情阈值估算
该模块只负责“市场广度”数据,不掺杂温度分数计算。
"""
from __future__ import annotations
import logging
import httpx
from app.config import today_trade_date
from app.data.cache import cache
from app.data.eastmoney_client import SECTOR_HEADERS, SECTOR_LIST_URL, _parse_eastmoney_json, get_a_share_realtime_ranking
from app.data.models import MarketBreadth
from app.db.error_logger import log_error
logger = logging.getLogger(__name__)
# Eastmoney push2ex "topic" pool endpoints: the limit-up pool and the
# limit-down pool. The number of items returned is used as the count.
ZT_POOL_URL = "https://push2ex.eastmoney.com/getTopicZTPool"
DT_POOL_URL = "https://push2ex.eastmoney.com/getTopicDTPool"
# Minimum number of realtime quotes required before quote-aggregated breadth is
# marked reliable; below this, get_market_breadth returns a reliable=False
# snapshot without up/down/flat counts.
# NOTE(review): presumably chosen to sit just under the full A-share listing
# count — confirm it still fits the current universe size.
MIN_RELIABLE_SAMPLE_COUNT = 4500
async def get_market_breadth() -> MarketBreadth:
    """Return a market-breadth snapshot for the current trade date.

    Advancer / decliner / unchanged counts are aggregated from the full-market
    realtime quote list returned by ``get_a_share_realtime_ranking``.
    Limit-up / limit-down counts come from ``_get_limit_counts`` (dedicated
    Eastmoney pools first, quote-threshold estimate as fallback).

    If fewer than ``MIN_RELIABLE_SAMPLE_COUNT`` quotes are available (including
    when the quote call returns nothing), a ``reliable=False`` snapshot that
    only carries the sample size is returned instead.

    Results are cached per trade date: reliable snapshots for 60s, unreliable
    ones for 30s.
    """
    # Resolve the trade date once so the cache key and the snapshot always
    # agree, even if this call straddles a date rollover.
    trade_date = today_trade_date()
    cache_key = f"market_breadth:{trade_date}"
    cached = cache.get(cache_key)
    if cached is not None:
        return cached

    # Normalise a falsy result (e.g. None) to an empty list so the
    # low-sample branch below can still take len() of it.
    quotes = (await get_a_share_realtime_ranking(page_size=6000)) or []
    if quotes and len(quotes) >= MIN_RELIABLE_SAMPLE_COUNT:
        up_count = down_count = flat_count = 0
        for quote in quotes:
            pct_chg = _pct_chg_of(quote)
            if pct_chg > 0:
                up_count += 1
            elif pct_chg < 0:
                down_count += 1
            elif pct_chg == 0:
                # Explicit check (not `else`) so NaN is counted nowhere.
                flat_count += 1
        limit_up_count, limit_down_count, limit_source = await _get_limit_counts(quotes)
        breadth = MarketBreadth(
            trade_date=trade_date,
            up_count=up_count,
            down_count=down_count,
            flat_count=flat_count,
            limit_up_count=limit_up_count,
            limit_down_count=limit_down_count,
            total_count=len(quotes),
            sample_count=len(quotes),
            source=f"eastmoney_quotes+{limit_source}",
            reliable=True,
            # Only the dedicated pools are considered exact; the quote-based
            # threshold estimate is flagged as unreliable.
            limit_counts_reliable=(limit_source == "eastmoney_pool"),
        )
        cache.set(cache_key, breadth, ttl=60)
        return breadth

    logger.warning(
        "市场广度实时样本不足quotes=%s,小于可靠阈值 %s,回退到基线口径",
        len(quotes),
        MIN_RELIABLE_SAMPLE_COUNT,
    )
    breadth = MarketBreadth(
        trade_date=trade_date,
        total_count=len(quotes),
        sample_count=len(quotes),
        source="snapshot",
        reliable=False,
        limit_counts_reliable=False,
    )
    # Shorter TTL so an insufficient sample is retried sooner.
    cache.set(cache_key, breadth, ttl=30)
    return breadth


async def _get_limit_counts(quotes: list[dict]) -> tuple[int, int, str]:
    """Return ``(limit_up_count, limit_down_count, source_tag)``.

    Prefers the dedicated Eastmoney pool endpoints (tag ``"eastmoney_pool"``).
    If they fail, it estimates from the quotes by comparing each quote's
    ``pct_chg`` with the board threshold from ``_limit_threshold`` (tag
    ``"eastmoney_quote_estimate"``).
    """
    pool = await _get_limit_counts_from_pool()
    if pool is not None:
        return pool[0], pool[1], "eastmoney_pool"

    limit_up_count = limit_down_count = 0
    for quote in quotes:
        pct_chg = _pct_chg_of(quote)
        threshold = _limit_threshold(quote.get("ts_code", ""))
        if pct_chg >= threshold:
            limit_up_count += 1
        if pct_chg <= -threshold:
            limit_down_count += 1
    return limit_up_count, limit_down_count, "eastmoney_quote_estimate"


def _pct_chg_of(quote: dict) -> float:
    """Return ``quote["pct_chg"]`` as a float, treating unusable values as 0.

    A missing key already counted as 0 (unchanged). A ``None`` or non-numeric
    value is treated the same way instead of making the comparisons raise
    ``TypeError``.
    NOTE(review): such values presumably come from suspended or unquoted
    stocks — confirm against the quote client.
    """
    value = quote.get("pct_chg", 0)
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0
async def _get_limit_counts_from_pool() -> tuple[int, int] | None:
    """Fetch today's limit-up / limit-down counts from Eastmoney's pool endpoints.

    Returns:
        ``(limit_up_count, limit_down_count)`` on success. These are the
        number of items in each pool response, cached for 60s per trade date.
        ``None`` if any request or parse step raises. The failure is logged via
        ``logger.warning`` and ``log_error`` (without notification) so the
        caller can fall back to a quote-based estimate.
    """
    # Resolve once so cache key, request date and error detail all agree.
    trade_date = today_trade_date()
    cache_key = f"market_limit_pool:{trade_date}"
    cached = cache.get(cache_key)
    if cached is not None:
        return cached
    params = {
        "ut": "7eea3edcaed734bea9cbfc24409ed989",
        "dpt": "wz.ztzt",
        "Pageindex": "0",
        # The count is len() of a single page, so the page must be able to hold
        # the entire pool. 1000 can truncate on extreme sell-off days; 10000
        # exceeds the whole A-share universe, so one page is always complete.
        "pagesize": "10000",
        "date": trade_date,
    }
    try:
        async with httpx.AsyncClient(timeout=10, follow_redirects=True) as client:
            zt_resp = await client.get(ZT_POOL_URL, params=params, headers=SECTOR_HEADERS)
            zt_data = _parse_eastmoney_json(zt_resp, "涨停池")
            dt_resp = await client.get(DT_POOL_URL, params=params, headers=SECTOR_HEADERS)
            dt_data = _parse_eastmoney_json(dt_resp, "跌停池")
        zt_items = _extract_pool_items(zt_data)
        dt_items = _extract_pool_items(dt_data)
        result = (len(zt_items), len(dt_items))
        cache.set(cache_key, result, ttl=60)
        return result
    except Exception as e:
        # Deliberate best-effort boundary: any failure means "pool unavailable",
        # and the caller falls back to the estimate instead of crashing.
        logger.warning("东方财富涨跌停池获取失败: %s", e)
        await log_error(
            "market_breadth",
            f"东方财富涨跌停池获取失败: {e}",
            detail=f"trade_date={trade_date}",
            notify=False,
        )
        return None
def _extract_pool_items(data: dict) -> list[dict]:
payload = data.get("data") or {}
if isinstance(payload, dict):
if isinstance(payload.get("pool"), list):
return payload["pool"]
if isinstance(payload.get("diff"), list):
return payload["diff"]
return []
def _limit_threshold(ts_code: str) -> float:
code = ts_code.split(".")[0] if ts_code else ""
if code.startswith(("300", "301", "688")):
return 19.8
return 9.8