"""Market breadth client.

Responsibilities:
1. Advancing / declining / unchanged counts: prefer aggregating Eastmoney's
   paginated whole-market quotes.
2. Limit-up / limit-down counts: prefer Eastmoney's dedicated pool endpoints.
3. Fall back to a threshold estimate over real-time quotes when those
   endpoints are unavailable.

This module is only responsible for "market breadth" data and does not mix in
any temperature-score calculation.
"""

from __future__ import annotations

import asyncio
import logging

import httpx

from app.config import today_trade_date
from app.data.cache import cache
from app.data.eastmoney_client import (
    SECTOR_HEADERS,
    SECTOR_LIST_URL,
    _parse_eastmoney_json,
    get_a_share_realtime_ranking,
)
from app.data.models import MarketBreadth
from app.db.error_logger import log_error

logger = logging.getLogger(__name__)

ZT_POOL_URL = "https://push2ex.eastmoney.com/getTopicZTPool"
DT_POOL_URL = "https://push2ex.eastmoney.com/getTopicDTPool"

# Below this many quotes the snapshot is reported as unreliable.
MIN_RELIABLE_SAMPLE_COUNT = 4500

# In-flight loader shared by concurrent callers so only one fetch runs at a time.
_market_breadth_task: asyncio.Task | None = None


async def get_market_breadth() -> MarketBreadth:
    """Return the market breadth snapshot for the current trade date.

    Lookup order: the cache entry for today's trade date, then an already
    running loader task (joined rather than duplicated), and finally a newly
    started loader task.

    Returns:
        The cached or freshly loaded ``MarketBreadth``.

    Raises:
        Whatever the loader task raises; the exception propagates to every
        caller awaiting that task.
    """
    global _market_breadth_task

    cache_key = f"market_breadth:{today_trade_date()}"
    cached = cache.get(cache_key)
    if cached is not None:
        return cached

    in_flight = _market_breadth_task
    if in_flight is not None and not in_flight.done():
        return await in_flight

    task = asyncio.create_task(_load_market_breadth(cache_key))
    _market_breadth_task = task
    try:
        return await task
    finally:
        # Only release the slot if it still holds *our* finished task, so a
        # loader started by another caller in the meantime is never dropped.
        if _market_breadth_task is task and task.done():
            _market_breadth_task = None


async def _load_market_breadth(cache_key: str) -> MarketBreadth:
    """Fetch quotes, build the breadth snapshot and store it in the cache.

    With at least ``MIN_RELIABLE_SAMPLE_COUNT`` quotes the snapshot carries
    up/down/flat and limit counts and is cached with ``ttl=60``. Otherwise an
    unreliable placeholder (counts left at the model's defaults) is cached
    with the shorter ``ttl=30``.

    Args:
        cache_key: Key under which the resulting snapshot is cached.

    Returns:
        The ``MarketBreadth`` that was written to the cache.
    """
    fetched = await get_a_share_realtime_ranking(page_size=6000)
    # Normalise a falsy/None result so the fallback path can safely call len().
    quotes = fetched or []
    sample_count = len(quotes)

    if sample_count >= MIN_RELIABLE_SAMPLE_COUNT:
        up_count = down_count = flat_count = 0
        for quote in quotes:
            pct_chg = quote.get("pct_chg", 0)
            if pct_chg > 0:
                up_count += 1
            elif pct_chg < 0:
                down_count += 1
            elif pct_chg == 0:
                # Explicit equality check: a NaN change is counted nowhere.
                flat_count += 1

        limit_up_count, limit_down_count, limit_source = await _get_limit_counts(quotes)
        breadth = MarketBreadth(
            trade_date=today_trade_date(),
            up_count=up_count,
            down_count=down_count,
            flat_count=flat_count,
            limit_up_count=limit_up_count,
            limit_down_count=limit_down_count,
            total_count=sample_count,
            sample_count=sample_count,
            source=f"eastmoney_quotes+{limit_source}",
            reliable=True,
            # Only the dedicated pool endpoint yields exact limit counts.
            limit_counts_reliable=(limit_source == "eastmoney_pool"),
        )
        cache.set(cache_key, breadth, ttl=60)
        return breadth

    logger.warning(
        "市场广度实时样本不足,quotes=%s,小于可靠阈值 %s,回退到基线口径",
        sample_count,
        MIN_RELIABLE_SAMPLE_COUNT,
    )
    breadth = MarketBreadth(
        trade_date=today_trade_date(),
        total_count=sample_count,
        sample_count=sample_count,
        source="snapshot",
        reliable=False,
        limit_counts_reliable=False,
    )
    cache.set(cache_key, breadth, ttl=30)
    return breadth


async def _get_limit_counts(quotes: list[dict]) -> tuple[int, int, str]:
    """Return ``(limit_up_count, limit_down_count, source)``.

    Prefers the dedicated pool endpoints (source ``"eastmoney_pool"``). When
    they are unavailable, estimates the counts from ``quotes`` by comparing
    each ``pct_chg`` against the per-board threshold (source
    ``"eastmoney_quote_estimate"``).

    Args:
        quotes: Real-time quote dicts; ``pct_chg`` and ``ts_code`` are read.
    """
    pool = await _get_limit_counts_from_pool()
    if pool is not None:
        return pool[0], pool[1], "eastmoney_pool"

    limit_up_count = 0
    limit_down_count = 0
    for quote in quotes:
        pct_chg = quote.get("pct_chg", 0)
        threshold = _limit_threshold(quote.get("ts_code", ""))
        if pct_chg >= threshold:
            limit_up_count += 1
        elif pct_chg <= -threshold:
            limit_down_count += 1
    return limit_up_count, limit_down_count, "eastmoney_quote_estimate"


async def _get_limit_counts_from_pool() -> tuple[int, int] | None:
    """Fetch limit-up / limit-down counts from Eastmoney's pool endpoints.

    A successful result is cached per trade date with ``ttl=60``.

    Returns:
        ``(limit_up_count, limit_down_count)``, or ``None`` when either
        request or its parsing fails. Failures are logged and recorded via
        ``log_error`` instead of being raised, so the caller can fall back
        to the quote-based estimate.
    """
    cache_key = f"market_limit_pool:{today_trade_date()}"
    cached = cache.get(cache_key)
    if cached is not None:
        return cached

    params = {
        "ut": "7eea3edcaed734bea9cbfc24409ed989",
        "dpt": "wz.ztzt",
        "Pageindex": "0",
        "pagesize": "1000",
        "date": today_trade_date(),
    }
    try:
        async with httpx.AsyncClient(timeout=10, follow_redirects=True) as client:
            zt_resp = await client.get(ZT_POOL_URL, params=params, headers=SECTOR_HEADERS)
            zt_data = _parse_eastmoney_json(zt_resp, "涨停池")
            dt_resp = await client.get(DT_POOL_URL, params=params, headers=SECTOR_HEADERS)
            dt_data = _parse_eastmoney_json(dt_resp, "跌停池")
        zt_items = _extract_pool_items(zt_data)
        dt_items = _extract_pool_items(dt_data)
        result = (len(zt_items), len(dt_items))
        cache.set(cache_key, result, ttl=60)
        return result
    except Exception as e:
        # Deliberate best-effort boundary: any failure here (network, parsing,
        # unexpected payload type) must degrade to the estimate, not crash.
        logger.warning("东方财富涨跌停池获取失败: %s", e)
        await log_error(
            "market_breadth",
            f"东方财富涨跌停池获取失败: {e}",
            detail=f"trade_date={today_trade_date()}",
            notify=False,
        )
        return None


def _extract_pool_items(data: dict) -> list[dict]:
    """Return the item list from a pool response, or ``[]`` if none is found.

    Looks under ``data["data"]`` for a list stored as ``"pool"`` and then as
    ``"diff"``. ``data`` itself is intentionally not type-checked: a non-dict
    payload raises here, which the caller treats as a failed pool fetch.
    """
    payload = data.get("data") or {}
    if isinstance(payload, dict):
        for key in ("pool", "diff"):
            items = payload.get(key)
            if isinstance(items, list):
                return items
    return []


def _limit_threshold(ts_code: str) -> float:
    """Return the absolute ``pct_chg`` at which a stock counts as limit-hit.

    Thresholds sit 0.2 points inside the nominal band to absorb rounding:
    29.8 for Beijing Stock Exchange listings (±30%), 19.8 for ChiNext and
    STAR Market codes (±20%), and 9.8 otherwise (±10%).

    NOTE(review): ST stocks (±5%) cannot be recognised from the code alone
    and are still treated as ±10% boards here.

    Args:
        ts_code: Security code, assumed to look like ``"300750.SZ"``; an
            empty or missing code gets the default 9.8.
    """
    if not ts_code:
        return 9.8

    code, _, exchange = ts_code.partition(".")
    # Beijing Stock Exchange: identified by suffix, or by its code ranges
    # when the suffix is absent.
    if exchange.upper() == "BJ" or code.startswith(("43", "83", "87", "88", "92")):
        return 29.8
    # ChiNext (300/301) and STAR Market (688, plus 689 for CDRs).
    if code.startswith(("300", "301", "688", "689")):
        return 19.8
    return 9.8