"""Tushare Pro API 客户端封装 负责所有 Tushare 数据获取,内置缓存和限流。 """ import time import logging import tushare as ts import pandas as pd from datetime import datetime, timedelta from app.config import settings from app.data.cache import cache from app.db.error_logger import log_error_background logger = logging.getLogger(__name__) class TushareClient: def __init__(self): self.token = settings.tushare_token self.pro: ts.pro_api | None = None self._last_request_time: float = 0 self._initialized = False def _ensure_init(self): if not self._initialized: if not self.token: raise ValueError("Tushare token 未配置,请设置 ASTOCK_TUSHARE_TOKEN") # 直接把 token 传给 pro_api,避免 tushare 尝试在用户主目录写 tk.csv self.pro = ts.pro_api(self.token) self._initialized = True def _rate_limit(self): elapsed = time.time() - self._last_request_time if elapsed < settings.tushare_request_delay: time.sleep(settings.tushare_request_delay - elapsed) self._last_request_time = time.time() def _fetch_with_retry(self, func, *args, **kwargs) -> pd.DataFrame: self._ensure_init() for attempt in range(settings.tushare_max_retry): try: self._rate_limit() result = func(*args, **kwargs) if result is not None and not result.empty: return result return pd.DataFrame() except Exception as e: logger.warning(f"Tushare 请求失败 (attempt {attempt + 1}): {e}") if attempt < settings.tushare_max_retry - 1: time.sleep((2 ** attempt) * 1) else: logger.error(f"Tushare 请求最终失败: {e}") log_error_background( "tushare", f"Tushare 请求最终失败: {e}", ) return pd.DataFrame() return pd.DataFrame() def _cached_fetch(self, cache_key: str, ttl: int, func, *args, **kwargs) -> pd.DataFrame: cached = cache.get(cache_key) if cached is not None: return cached result = self._fetch_with_retry(func, *args, **kwargs) if not result.empty: # Tushare 返回降序(最新在前),排序为升序(最旧在前)以便分析 if "trade_date" in result.columns: result = result.sort_values("trade_date").reset_index(drop=True) cache.set(cache_key, result, ttl) return result # ── 交易日历 ── def get_trade_dates(self, start: str = None, end: str = None) -> list[str]: if not start: start = (datetime.now() - timedelta(days=60)).strftime("%Y%m%d") if not end: end = datetime.now().strftime("%Y%m%d") key = f"trade_cal:{start}:{end}" df = self._cached_fetch( key, settings.cache_ttl_static, lambda: self.pro.trade_cal( exchange="SSE", start_date=start, end_date=end, fields="cal_date,is_open" ) ) if df.empty: return [] return df[df["is_open"] == 1]["cal_date"].sort_values().tolist() def get_latest_trade_date(self) -> str: """获取有数据的最新交易日。 盘中时 daily 数据尚未更新,需要回退到上一个交易日。 策略:如果当天是交易日且已收盘(>15:30),用当天; 否则用前一个交易日。 """ dates = self.get_trade_dates() today = datetime.now().strftime("%Y%m%d") now_hour = datetime.now().hour now_minute = datetime.now().minute past_dates = [d for d in dates if d <= today] if not past_dates: return today # 如果当天是交易日但还没收盘(<15:30),daily 数据不可用,回退 if past_dates[-1] == today and (now_hour < 15 or (now_hour == 15 and now_minute < 30)): return past_dates[-2] if len(past_dates) >= 2 else today return past_dates[-1] # ── 全市场日线(用于计算涨跌家数)── def get_daily_all(self, trade_date: str) -> pd.DataFrame: key = f"daily_all:{trade_date}" return self._cached_fetch( key, settings.cache_ttl_daily, lambda: self.pro.daily( trade_date=trade_date, fields="ts_code,trade_date,open,high,low,close,pre_close,change,pct_chg,vol,amount" ) ) # ── 涨跌停数据 ── def get_limit_list(self, trade_date: str) -> pd.DataFrame: key = f"limit_list:{trade_date}" return self._cached_fetch( key, settings.cache_ttl_daily, lambda: self.pro.limit_list_d( trade_date=trade_date, fields="ts_code,trade_date,name,close,pct_chg,fd_amount,first_time,last_time,open_times,up_stat,limit_times,limit" ) ) # ── 指数日线 ── def get_index_daily(self, ts_code: str = "000001.SH", days: int = 30) -> pd.DataFrame: end = datetime.now().strftime("%Y%m%d") start = (datetime.now() - timedelta(days=days + 15)).strftime("%Y%m%d") key = f"index_daily:{ts_code}:{end}" return self._cached_fetch( key, settings.cache_ttl_daily, lambda: self.pro.index_daily( ts_code=ts_code, start_date=start, end_date=end, fields="ts_code,trade_date,close,open,high,low,vol,amount" ) ) # ── 板块数据 ── def get_ths_index_list(self, index_type: str = "N") -> pd.DataFrame: """获取同花顺板块列表,N=概念板块""" key = f"ths_index_list:{index_type}" return self._cached_fetch( key, settings.cache_ttl_static, lambda: self.pro.ths_index( exchange="A", type=index_type, fields="ts_code,name,count,list_date" ) ) def get_ths_daily(self, ts_code: str, days: int = 5) -> pd.DataFrame: end = datetime.now().strftime("%Y%m%d") start = (datetime.now() - timedelta(days=days + 10)).strftime("%Y%m%d") key = f"ths_daily:{ts_code}:{end}:{days}" return self._cached_fetch( key, settings.cache_ttl_sector, lambda: self.pro.ths_daily( ts_code=ts_code, start_date=start, end_date=end, fields="ts_code,trade_date,close,open,high,low,pct_change,vol,turnover_rate" ) ) def get_ths_members(self, ts_code: str) -> pd.DataFrame: """获取板块成分股 返回列: ts_code(板块代码), con_code(成分股代码), con_name(成分股名称) """ key = f"ths_member:{ts_code}" return self._cached_fetch( key, settings.cache_ttl_static, lambda: self.pro.ths_member( ts_code=ts_code, fields="ts_code,con_code,con_name" ) ) # ── 板块资金流向 ── def get_sector_moneyflow(self, trade_date: str) -> pd.DataFrame: key = f"sector_mf:{trade_date}" return self._cached_fetch( key, settings.cache_ttl_sector, lambda: self.pro.moneyflow_ind_ths( trade_date=trade_date, fields="ts_code,trade_date,industry,net_amount,buy_elg_amount,sell_elg_amount,buy_lg_amount,sell_lg_amount" ) ) # ── 个股资金流向 ── def get_stock_moneyflow(self, ts_code: str, days: int = 5) -> pd.DataFrame: end = self.get_latest_trade_date() start_dt = datetime.strptime(end, "%Y%m%d") - timedelta(days=days + 10) start = start_dt.strftime("%Y%m%d") key = f"stock_mf:{ts_code}:{end}:{days}" return self._cached_fetch( key, settings.cache_ttl_daily, lambda: self.pro.moneyflow( ts_code=ts_code, start_date=start, end_date=end, fields="ts_code,trade_date,buy_elg_amount,sell_elg_amount,buy_lg_amount,sell_lg_amount,buy_md_amount,sell_md_amount,buy_sm_amount,sell_sm_amount,net_mf_amount" ) ) def get_moneyflow_batch(self, trade_date: str) -> pd.DataFrame: """批量获取当日全市场资金流向""" key = f"mf_batch:{trade_date}" return self._cached_fetch( key, settings.cache_ttl_daily, lambda: self.pro.moneyflow( trade_date=trade_date, fields="ts_code,trade_date,buy_elg_amount,sell_elg_amount,buy_lg_amount,sell_lg_amount,buy_md_amount,sell_md_amount,buy_sm_amount,sell_sm_amount,net_mf_amount" ) ) # ── 日线基础指标(PE/PB/换手率/市值)── def get_daily_basic(self, trade_date: str) -> pd.DataFrame: key = f"daily_basic:{trade_date}" return self._cached_fetch( key, settings.cache_ttl_daily, lambda: self.pro.daily_basic( trade_date=trade_date, fields="ts_code,trade_date,turnover_rate,volume_ratio,pe,pb,total_mv,circ_mv" ) ) # ── 个股日线(历史K线)── def get_stock_daily(self, ts_code: str, days: int = 120) -> pd.DataFrame: end = datetime.now().strftime("%Y%m%d") start = (datetime.now() - timedelta(days=days + 30)).strftime("%Y%m%d") key = f"stock_daily:{ts_code}:{end}:{days}" return self._cached_fetch( key, settings.cache_ttl_daily, lambda: self.pro.daily( ts_code=ts_code, start_date=start, end_date=end, fields="ts_code,trade_date,open,high,low,close,pre_close,change,pct_chg,vol,amount" ) ) # ── 股票基本信息 ── def get_stock_basic(self) -> pd.DataFrame: key = "stock_basic" return self._cached_fetch( key, settings.cache_ttl_static, lambda: self.pro.stock_basic( exchange="", list_status="L", fields="ts_code,symbol,name,area,industry,market,list_date" ) ) # ── 新闻快讯 ── def get_news( self, source: str, start_time: datetime, end_time: datetime, limit: int = 30, ) -> pd.DataFrame: """获取 Tushare 新闻快讯。 Tushare news 接口不同账号权限不一,失败时返回空 DataFrame, 不阻断其他新闻源或推荐流程。该接口常见限频为 1 次/分钟或 2 次/小时,因此这里不复用通用重试逻辑,避免失败重试继续消耗配额。 """ cache_key = f"news:{source}:{start_time:%Y%m%d%H}:{end_time:%Y%m%d%H}:{limit}" cached = cache.get(cache_key) if cached is not None: return cached try: self._ensure_init() self._rate_limit() result = self.pro.news( src=source, start_date=start_time.strftime("%Y-%m-%d %H:%M:%S"), end_date=end_time.strftime("%Y-%m-%d %H:%M:%S"), limit=limit, ) if result is None or result.empty: return pd.DataFrame() cache.set(cache_key, result, 600) return result except Exception as e: logger.warning("Tushare 新闻请求失败 source=%s: %s", source, e) log_error_background("tushare_news", f"Tushare 新闻请求失败 source={source}: {e}") return pd.DataFrame() tushare_client = TushareClient()