""" Tushare 数据封装 提供 A 股板块、个股行情数据获取接口(使用同花顺系列接口) """ import time import tushare as ts import pandas as pd from typing import Dict, List, Optional from datetime import datetime, timedelta from app.utils.logger import logger class TushareClient: """Tushare 数据客户端(同花顺系列接口)""" # 缓存数据,避免频繁请求 _cache = {} _cache_time = {} _last_request_time = 0 def __init__(self, token: str): """ 初始化客户端 Args: token: Tushare token """ self.token = token ts.set_token(token) self.pro = ts.pro_api() self.cache_ttl = 300 # 缓存5分钟 self.request_delay = 0.5 # 请求间隔(秒)- tushare 有频率限制 def _get_cached(self, key: str, fetch_func) -> pd.DataFrame: """获取缓存数据,支持重试""" now = datetime.now() # 检查缓存 if key in self._cache: cache_time = self._cache_time.get(key) if cache_time and (now - cache_time).seconds < self.cache_ttl: logger.debug(f"使用缓存数据: {key}") return self._cache[key] # 请求限流 elapsed = now.timestamp() - self._last_request_time if elapsed < self.request_delay: time.sleep(self.request_delay - elapsed) # 重试逻辑 max_retries = 3 for attempt in range(max_retries): try: self._last_request_time = time.time() df = fetch_func() if df is not None and not df.empty: self._cache[key] = df self._cache_time[key] = now logger.debug(f"获取数据成功: {key}") return df except Exception as e: error_msg = str(e) # 指数退避重试 if attempt < max_retries - 1: wait_time = (2 ** attempt) * 2 logger.warning( f"获取数据失败 {key} (尝试 {attempt + 1}/{max_retries}): {e}," f"等待 {wait_time}秒后重试..." ) time.sleep(wait_time) continue logger.error(f"获取数据失败 {key}: {e}") break return pd.DataFrame() def get_concept_sectors(self) -> pd.DataFrame: """ 获取概念板块列表 使用 ths_index 接口,type="N" 代表概念板块 Returns: 概念板块列表 """ def fetch(): # ths_index - 获取同花顺概念指数列表 return self.pro.ths_index(type='N') return self._get_cached('concept_sectors', fetch) def get_sector_daily(self, ts_code: str, start_date: str = None, end_date: str = None) -> pd.DataFrame: """ 获取板块日线行情 Args: ts_code: 板块指数代码(如 885823.TI) start_date: 开始日期 (YYYYMMDD) end_date: 结束日期 (YYYYMMDD) Returns: 板块日线数据 """ if not start_date: start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d') if not end_date: end_date = datetime.now().strftime('%Y%m%d') def fetch(): # ths_daily - 获取板块指数历史行情 return self.pro.ths_daily( ts_code=ts_code, start_date=start_date, end_date=end_date ) return self._get_cached(f'sector_daily_{ts_code}_{end_date}', fetch) def get_sector_members(self, ts_code: str) -> pd.DataFrame: """ 获取板块成分股 Args: ts_code: 板块指数代码(如 885823.TI) Returns: 成分股列表 """ def fetch(): # ths_member - 获取板块成分股 return self.pro.ths_member(ts_code=ts_code) return self._get_cached(f'sector_members_{ts_code}', fetch) def get_stock_daily(self, ts_code: str, start_date: str = None, end_date: str = None) -> pd.DataFrame: """ 获取个股日线行情 Args: ts_code: 股票代码(如 000001.SZ) start_date: 开始日期 (YYYYMMDD) end_date: 结束日期 (YYYYMMDD) Returns: 日线数据 """ if not start_date: start_date = (datetime.now() - timedelta(days=30)).strftime('%Y%m%d') if not end_date: end_date = datetime.now().strftime('%Y%m%d') def fetch(): # daily - 获取日线行情 return self.pro.daily( ts_code=ts_code, start_date=start_date, end_date=end_date ) return self._get_cached(f'stock_daily_{ts_code}_{end_date}', fetch) def get_stock_daily_basic(self, ts_codes: List[str], trade_date: str = None) -> pd.DataFrame: """ 获取个股每日指标(包含换手率、量比等) Args: ts_codes: 股票代码列表 trade_date: 交易日期 (YYYYMMDD) Returns: 每日指标数据 """ if not ts_codes: return pd.DataFrame() from datetime import datetime, timedelta if not trade_date: trade_date = datetime.now().strftime('%Y%m%d') def fetch(): # daily_basic - 获取每日指标 codes_str = ','.join(ts_codes[:300]) # 限制单次查询数量 # 尝试获取最近3天的数据(以防当天数据未更新) all_data = [] for i in range(3): try_date = (datetime.now() - timedelta(days=i)).strftime('%Y%m%d') df = self.pro.daily_basic( ts_code=codes_str, trade_date=try_date, fields='ts_code,trade_date,turnover_rate,volume_ratio,pe,pb' ) if not df.empty: all_data.append(df) # 如果找到数据就不再尝试更早的日期 break if all_data: return pd.concat(all_data, ignore_index=True) return pd.DataFrame() return self._get_cached(f'stock_daily_basic_{trade_date}', fetch) def get_stock_basic(self) -> pd.DataFrame: """ 获取股票基本信息列表 Returns: 股票基本信息 """ def fetch(): # stock_basic - 获取股票基本信息 return self.pro.stock_basic( exchange='', list_status='L', fields='ts_code,symbol,name,area,industry,list_date' ) return self._get_cached('stock_basic', fetch) def get_realtime_data(self, ts_codes: List[str]) -> pd.DataFrame: """ 获取实时行情数据(使用最新的日线数据) 注意:tushare 不提供真正的实时数据,这里返回最新的日线数据 注意:amount 字段单位是千元,需要 * 1000 转换为元 Args: ts_codes: 股票代码列表 Returns: 实时行情数据(amount 单位为千元) """ if not ts_codes: return pd.DataFrame() # 获取今天的日期 today = datetime.now().strftime('%Y%m%d') yesterday = (datetime.now() - timedelta(days=10)).strftime('%Y%m%d') def fetch(): # 使用 daily 接口获取最近数据 codes_str = ','.join(ts_codes[:100]) # 限制单次查询数量 df = self.pro.daily( ts_code=codes_str, start_date=yesterday, end_date=today ) # 只返回每个股票的最新一天数据 if not df.empty: df = df.sort_values('trade_date').groupby('ts_code').tail(1) return df return self._get_cached(f'realtime_{today}', fetch) def get_hot_sectors(self, threshold: float = 2.0) -> pd.DataFrame: """ 获取异动板块(一次性获取所有板块的最新行情) Args: threshold: 涨跌幅阈值(%) Returns: 异动板块数据 """ try: # 1. 获取所有概念板块 sectors_df = self.get_concept_sectors() if sectors_df.empty: logger.warning("获取概念板块列表失败") return pd.DataFrame() logger.info(f"获取到 {len(sectors_df)} 个概念板块") # 2. 获取今天的日期 today = datetime.now().strftime('%Y%m%d') yesterday = (datetime.now() - timedelta(days=10)).strftime('%Y%m%d') # 3. 批量获取板块行情(为了效率,限制数量) hot_sectors = [] max_sectors = 100 # 最多检查100个板块 for idx, row in sectors_df.head(max_sectors).iterrows(): ts_code = row['ts_code'] name = row.get('name', '') try: # 获取板块最新行情 daily_df = self.pro.ths_daily( ts_code=ts_code, start_date=yesterday, end_date=today ) if daily_df.empty: continue # 获取最新一天的数据 latest = daily_df.sort_values('trade_date').iloc[-1] # 检查涨跌幅 - 注意列名是 pct_change 不是 pct_chg change_pct = float(latest.get('pct_change', 0)) if change_pct >= threshold: hot_sectors.append({ 'ts_code': ts_code, 'name': name, 'change_pct': change_pct, 'change': float(latest.get('change', 0)), # 涨跌额 'close': float(latest.get('close', 0)), 'amount': float(latest.get('amount', 0)), # 成交额(元) 'volume': float(latest.get('vol', 0)), # 成交量(手) 'turnover_rate': float(latest.get('turnover_rate', 0)), # 换手率 'trade_date': str(latest.get('trade_date', '')) }) except Exception as e: logger.debug(f"获取板块 {name} 行情失败: {e}") continue result_df = pd.DataFrame(hot_sectors) if not result_df.empty: result_df = result_df.sort_values('change_pct', ascending=False) return result_df except Exception as e: logger.error(f"获取异动板块失败: {e}") return pd.DataFrame() # 全局单例 _tushare_client: Optional[TushareClient] = None def get_tushare_client(token: str = None) -> Optional[TushareClient]: """获取 Tushare 客户端单例""" global _tushare_client if _tushare_client is None: if not token: return None _tushare_client = TushareClient(token) return _tushare_client