""" A股数据获取模块 使用Tushare Pro API获取A股市场数据 """ import tushare as ts import pandas as pd import re from typing import List, Optional, Union from datetime import datetime, date, timedelta import time from loguru import logger from functools import wraps from src.utils.config_loader import config_loader def retry_on_failure(retries: int = 3, delay: float = 1.0): """ 重试装饰器,用于网络请求失败时自动重试 Args: retries: 重试次数 delay: 重试间隔(秒) """ def decorator(func): @wraps(func) def wrapper(*args, **kwargs): last_exception = None for attempt in range(retries + 1): try: return func(*args, **kwargs) except Exception as e: last_exception = e if attempt < retries: logger.warning(f"{func.__name__} 第{attempt + 1}次调用失败: {e}, {delay}秒后重试...") time.sleep(delay) else: logger.error(f"{func.__name__} 已重试{retries}次仍然失败: {e}") raise last_exception return wrapper return decorator class TushareFetcher: """Tushare数据获取器""" def __init__(self, token: str = None): """ 初始化数据获取器 Args: token: Tushare Pro token,如果为None则从配置文件读取 """ # 如果没有传入token,从配置文件读取 if token is None: token = config_loader.get_tushare_token() if token: logger.info("✅ 从配置文件读取TuShare token成功") else: logger.warning("⚠️ 配置文件中未找到TuShare token") self.token = token if token: try: ts.set_token(token) self.pro = ts.pro_api() # 验证token是否有效 test_data = self.pro.trade_cal(exchange='SSE', cal_date='20240101', limit=1) if not test_data.empty: logger.info("✅ Tushare Pro客户端初始化完成,token验证成功") else: logger.warning("⚠️ Tushare Pro token可能无效或权限不足") except Exception as e: logger.error(f"❌ Tushare Pro初始化失败: {e}") logger.warning("将回退到无Pro权限模式") self.pro = None else: logger.warning("未提供Tushare token,将使用免费接口") self.pro = None # 股票名称缓存机制 self._stock_name_cache = {} self._stock_list_cache = None self._cache_timestamp = None self._cache_duration = 3600 # 缓存1小时 # 分钟线接口调用频率控制(每分钟最多2次) self._minute_data_call_times = [] self._minute_data_max_calls_per_minute = 2 def clear_caches(self): """清除所有缓存""" self._stock_name_cache.clear() self._stock_list_cache = None self._cache_timestamp = None logger.info("🔄 已清除所有股票数据缓存") def get_stock_list(self, use_hot_stocks: bool = True, hot_limit: int = 100) -> pd.DataFrame: """ 获取股票列表 Args: use_hot_stocks: 是否优先使用热门股票,默认True hot_limit: 热门股票数量限制,默认100 Returns: 股票列表DataFrame """ import time current_time = time.time() # 检查缓存是否有效 if (self._stock_list_cache is not None and self._cache_timestamp is not None and current_time - self._cache_timestamp < self._cache_duration): logger.debug(f"🔄 使用缓存的股票列表数据 ({len(self._stock_list_cache)} 只股票)") return self._stock_list_cache.copy() logger.info("📊 重新获取股票列表数据...") # 优先使用热门股票 if use_hot_stocks: try: logger.info(f"🔥 优先获取同花顺热榜股票 (前{hot_limit}只)...") hot_stocks = self.get_combined_hot_stocks( limit_per_source=hot_limit, final_limit=hot_limit ) if not hot_stocks.empty and 'stock_code' in hot_stocks.columns: # 转换为标准格式 stock_list = hot_stocks.copy() # 确保有必要的列 if 'full_stock_code' not in stock_list.columns: stock_list['full_stock_code'] = stock_list['stock_code'] # 添加缺失的列 for col in ['area', 'industry', 'exchange', 'list_date']: if col not in stock_list.columns: stock_list[col] = '' logger.info(f"✅ 获取热门股票成功,共{len(stock_list)}只股票") if 'source' in stock_list.columns: source_counts = stock_list['source'].value_counts().to_dict() source_detail = " | ".join([f"{k}: {v}只" for k, v in source_counts.items()]) logger.info(f"📊 数据源分布: {source_detail}") # 更新缓存 self._stock_list_cache = stock_list.copy() self._cache_timestamp = current_time return stock_list else: logger.warning("热门股票数据为空,回退到全量股票列表") except Exception as e: logger.warning(f"获取热门股票失败: 
{e},回退到全量股票列表") try: # 回退方案:获取全量股票列表 logger.info("📊 获取全量A股股票列表...") if self.pro: # 使用Pro接口获取股票基本信息 stock_list = self.pro.stock_basic( exchange='', list_status='L', fields='ts_code,symbol,name,area,industry,market,list_date' ) # 统一列名 stock_list.rename(columns={ 'ts_code': 'full_stock_code', 'symbol': 'stock_code', 'name': 'short_name', 'market': 'exchange' }, inplace=True) else: # 需要Pro权限才能获取股票列表 logger.error("获取股票列表需要TuShare Pro权限,请提供有效token") return pd.DataFrame() logger.info(f"获取全量股票列表成功,共{len(stock_list)}只股票") # 更新缓存 self._stock_list_cache = stock_list.copy() self._cache_timestamp = current_time return stock_list except Exception as e: logger.error(f"获取股票列表失败: {e}") return pd.DataFrame() def get_filtered_a_share_list(self, exclude_st: bool = True, exclude_bj: bool = True, min_market_cap: float = 2000000000) -> pd.DataFrame: """ 获取过滤后的A股股票列表 Args: exclude_st: 是否排除ST股票 exclude_bj: 是否排除北交所股票 min_market_cap: 最小市值要求(元),默认20亿 Returns: 过滤后的股票列表DataFrame """ try: # 获取完整股票列表 all_stocks = self.get_stock_list() if all_stocks.empty: return pd.DataFrame() filtered_stocks = all_stocks.copy() original_count = len(filtered_stocks) # 排除北交所股票 if exclude_bj: before_count = len(filtered_stocks) if 'exchange' in filtered_stocks.columns: filtered_stocks = filtered_stocks[filtered_stocks['exchange'] != 'BJ'] else: # 根据股票代码判断 filtered_stocks = filtered_stocks[~filtered_stocks['stock_code'].str.startswith(('8', '43', '83'))] bj_excluded = before_count - len(filtered_stocks) logger.info(f"排除北交所股票: {bj_excluded}只") # 排除ST股票 if exclude_st: before_count = len(filtered_stocks) # 使用(?:...)非捕获组避免警告 st_pattern = r'(?:\*?ST|PT|退|暂停)' filtered_stocks = filtered_stocks[~filtered_stocks['short_name'].str.contains(st_pattern, na=False, case=False, regex=True)] st_excluded = before_count - len(filtered_stocks) logger.info(f"排除ST等风险股票: {st_excluded}只") # 基于市值筛选 if min_market_cap > 0 and self.pro: before_count = len(filtered_stocks) filtered_stocks = self._filter_by_market_cap(filtered_stocks, min_market_cap) cap_excluded = before_count - len(filtered_stocks) logger.info(f"排除小市值股票: {cap_excluded}只") final_count = len(filtered_stocks) excluded_count = original_count - final_count if not filtered_stocks.empty: logger.info(f"✅ 获取过滤后A股列表成功") logger.info(f"📊 原始股票: {original_count}只 | 过滤后: {final_count}只 | 排除: {excluded_count}只") return filtered_stocks except Exception as e: logger.error(f"获取过滤A股列表失败: {e}") return pd.DataFrame() def _filter_by_market_cap(self, stock_df: pd.DataFrame, min_market_cap: float) -> pd.DataFrame: """ 基于市值筛选股票 Args: stock_df: 股票列表DataFrame min_market_cap: 最小市值要求(元) Returns: 过滤后的股票DataFrame """ if stock_df.empty or not self.pro: return stock_df try: logger.info(f"开始基于市值筛选股票,阈值: {min_market_cap/100000000:.0f}亿元") # 获取股票基本信息包含市值 ts_codes = stock_df['full_stock_code'].tolist() if 'full_stock_code' in stock_df.columns else [] if not ts_codes: return stock_df # 分批获取市值数据(避免API限制) batch_size = 50 valid_stocks = [] for i in range(0, len(ts_codes), batch_size): batch_codes = ts_codes[i:i+batch_size] try: # 获取每日基本面数据(包含市值) trade_date = datetime.now().strftime('%Y%m%d') daily_basic = self.pro.daily_basic( ts_code=','.join(batch_codes), trade_date=trade_date, fields='ts_code,total_mv' ) if not daily_basic.empty: # 市值单位是万元,转换为元 daily_basic['market_cap'] = daily_basic['total_mv'] * 10000 # 筛选符合市值要求的股票 valid_codes = daily_basic[daily_basic['market_cap'] >= min_market_cap]['ts_code'].tolist() # 添加到结果中 batch_stocks = stock_df[stock_df['full_stock_code'].isin(valid_codes)] valid_stocks.append(batch_stocks) time.sleep(0.2) # 
API限制 except Exception as e: logger.debug(f"获取批次市值数据失败: {e}") # 如果获取失败,保留原数据 batch_stocks = stock_df[stock_df['full_stock_code'].isin(batch_codes)] valid_stocks.append(batch_stocks) if valid_stocks: result_df = pd.concat(valid_stocks, ignore_index=True) logger.info(f"✅ 市值筛选完成: {len(result_df)}/{len(stock_df)} 只股票符合要求") return result_df else: return stock_df except Exception as e: logger.error(f"市值筛选失败: {e}") return stock_df def get_realtime_data(self, stock_codes: Union[str, List[str]]) -> pd.DataFrame: """ 获取实时行情数据 Args: stock_codes: 股票代码或代码列表 Returns: 实时行情DataFrame """ try: if isinstance(stock_codes, str): stock_codes = [stock_codes] if self.pro: # 转换为tushare格式的代码 ts_codes = [] for code in stock_codes: if '.' in code: ts_codes.append(code) else: # 根据代码判断交易所 if code.startswith(('60', '68', '90')): ts_codes.append(f"{code}.SH") else: ts_codes.append(f"{code}.SZ") # 使用Pro接口获取最新行情数据 today = pd.Timestamp.now().strftime('%Y%m%d') realtime_data = self.pro.daily( ts_code=','.join(ts_codes), trade_date=today ) else: logger.error("获取实时行情需要TuShare Pro权限,请提供有效token") return pd.DataFrame() logger.info(f"获取实时数据成功,股票数量: {len(stock_codes)}") return realtime_data except Exception as e: logger.error(f"获取实时数据失败: {e}") return pd.DataFrame() def get_historical_data( self, stock_code: str, start_date: Union[str, date], end_date: Union[str, date], period: str = "daily" ) -> pd.DataFrame: """ 获取历史行情数据 Args: stock_code: 股票代码 start_date: 开始日期 end_date: 结束日期 period: 数据周期 ('daily', 'weekly', 'monthly', '60min', '30min', '15min', '5min', '1min') Returns: 历史行情DataFrame """ try: # 转换日期格式 if isinstance(start_date, date): start_date = start_date.strftime("%Y%m%d") else: start_date = start_date.replace('-', '') if isinstance(end_date, date): end_date = end_date.strftime("%Y%m%d") else: end_date = end_date.replace('-', '') # 转换为tushare格式的代码 if '.' 
not in stock_code: if stock_code.startswith(('60', '68', '90')): ts_code = f"{stock_code}.SH" else: ts_code = f"{stock_code}.SZ" else: ts_code = stock_code if self.pro: # 使用Pro接口 if period == 'daily': hist_data = self.pro.daily( ts_code=ts_code, start_date=start_date, end_date=end_date ) elif period == 'weekly': hist_data = self.pro.weekly( ts_code=ts_code, start_date=start_date, end_date=end_date ) elif period == 'monthly': hist_data = self.pro.monthly( ts_code=ts_code, start_date=start_date, end_date=end_date ) elif period in ('60min', '30min', '15min', '5min', '1min'): # 使用分钟线数据接口 # 注意:分钟线数据需要特殊权限,且有调用限制:每分钟最多2次 # 频率限制检查 current_time = time.time() # 清理60秒前的调用记录 self._minute_data_call_times = [t for t in self._minute_data_call_times if current_time - t < 60] # 检查是否超过限制 if len(self._minute_data_call_times) >= self._minute_data_max_calls_per_minute: wait_time = 60 - (current_time - self._minute_data_call_times[0]) logger.warning(f"分钟线接口频率限制:每分钟最多{self._minute_data_max_calls_per_minute}次,需等待 {wait_time:.1f} 秒") time.sleep(wait_time + 1) # 多等1秒确保安全 # 清理过期记录 current_time = time.time() self._minute_data_call_times = [t for t in self._minute_data_call_times if current_time - t < 60] # 调用接口 hist_data = self.pro.stk_mins( ts_code=ts_code, start_date=start_date, end_date=end_date, freq=period ) # 记录调用时间 self._minute_data_call_times.append(time.time()) else: hist_data = self.pro.daily( ts_code=ts_code, start_date=start_date, end_date=end_date ) if not hist_data.empty: # 统一列名映射 field_mapping = { 'vol': 'volume', # TuShare返回vol,策略期望volume 'ts_code': 'stock_code' # 统一股票代码字段名 } for old_name, new_name in field_mapping.items(): if old_name in hist_data.columns: hist_data.rename(columns={old_name: new_name}, inplace=True) # 转换日期格式并保持双格式兼容 if 'trade_date' in hist_data.columns: hist_data['trade_date'] = pd.to_datetime(hist_data['trade_date']).dt.strftime('%Y-%m-%d') hist_data['date'] = hist_data['trade_date'] # 同时提供date字段 elif 'date' in hist_data.columns: hist_data['date'] = pd.to_datetime(hist_data['date']).dt.strftime('%Y-%m-%d') hist_data['trade_date'] = hist_data['date'] # 同时提供trade_date字段 # 按日期升序排列 hist_data = hist_data.sort_values('date') else: # 需要Pro权限才能获取历史数据 logger.error("获取历史数据需要TuShare Pro权限,请提供有效token") return pd.DataFrame() if not hist_data.empty: logger.info(f"获取{stock_code}历史数据成功,数据量: {len(hist_data)}") else: logger.warning(f"获取{stock_code}历史数据为空") return hist_data except Exception as e: logger.error(f"获取{stock_code}历史数据失败: {e}") return pd.DataFrame() def get_index_data(self, index_code: str = "000001.SH") -> pd.DataFrame: """ 获取指数数据 Args: index_code: 指数代码 Returns: 指数数据DataFrame """ try: if self.pro: # 转换指数代码格式 if index_code == "000001.SH": ts_code = "000001.SH" # 上证指数 elif index_code == "399001.SZ": ts_code = "399001.SZ" # 深证成指 else: ts_code = index_code index_data = self.pro.index_daily( ts_code=ts_code, trade_date=datetime.now().strftime('%Y%m%d') ) else: # 需要Pro权限才能获取指数数据 logger.error("获取指数数据需要TuShare Pro权限,请提供有效token") return pd.DataFrame() logger.info(f"获取指数{index_code}数据成功") return index_data except Exception as e: logger.error(f"获取指数数据失败: {e}") return pd.DataFrame() def get_financial_data(self, stock_code: str) -> pd.DataFrame: """ 获取财务数据 Args: stock_code: 股票代码 Returns: 财务数据DataFrame """ try: # 转换为tushare格式的代码 if '.' 
not in stock_code: if stock_code.startswith(('60', '68', '90')): ts_code = f"{stock_code}.SH" else: ts_code = f"{stock_code}.SZ" else: ts_code = stock_code if self.pro: # 获取财务数据 financial_data = self.pro.income(ts_code=ts_code, period='20231231') else: # 免费接口的财务数据 financial_data = pd.DataFrame() logger.info(f"获取{stock_code}财务数据成功") return financial_data except Exception as e: logger.error(f"获取财务数据失败: {e}") return pd.DataFrame() def search_stocks(self, keyword: str) -> pd.DataFrame: """ 搜索股票 Args: keyword: 搜索关键词 Returns: 搜索结果DataFrame """ try: # 获取完整股票列表 all_stocks = self.get_stock_list() if all_stocks.empty: return pd.DataFrame() # 在股票代码和名称中搜索关键词 keyword = str(keyword).strip() if not keyword: return pd.DataFrame() # 支持按代码或名称模糊搜索 mask = ( all_stocks['stock_code'].str.contains(keyword, case=False, na=False) | all_stocks['short_name'].str.contains(keyword, case=False, na=False) ) results = all_stocks[mask].copy() logger.info(f"搜索股票'{keyword}'成功,找到{len(results)}个结果") return results except Exception as e: logger.error(f"搜索股票失败: {e}") return pd.DataFrame() @retry_on_failure(retries=2, delay=1.0) def get_hot_stocks_ths(self, limit: int = 100, trade_date: str = None) -> pd.DataFrame: """ 获取同花顺热门股票 Args: limit: 返回数量限制 trade_date: 交易日期,格式YYYYMMDD,默认为当日 Returns: 热门股票DataFrame """ try: if not self.pro: logger.error("需要Tushare Pro权限才能获取同花顺热股数据") return self._get_fallback_hot_stocks(limit) # 使用TuShare Pro的同花顺热榜接口 params = { 'market': '热股', 'is_new': 'Y' } if trade_date: params['trade_date'] = trade_date df = self.pro.ths_hot(**params) if df.empty: logger.warning("同花顺热股数据为空,使用备用方法") return self._get_fallback_hot_stocks(limit) # 数据清洗和去重 logger.info(f"原始数据: {len(df)} 条") # 1. 去除重复股票代码,保留第一个(排名最好的) if 'ts_code' in df.columns: df = df.drop_duplicates(subset=['ts_code'], keep='first') logger.info(f"去重后数据: {len(df)} 条") # 2. 过滤退市和ST股票 if 'ts_name' in df.columns: before_filter = len(df) # 过滤包含以下关键词的股票:退市、ST、*ST、PT # 使用(?:...)非捕获组避免警告 filter_pattern = r'(?:退市|^\*?ST|^ST|^PT|暂停)' df = df[~df['ts_name'].str.contains(filter_pattern, na=False, case=False, regex=True)] filtered_count = before_filter - len(df) if filtered_count > 0: logger.info(f"过滤退市/ST股票: {filtered_count} 只") # 3. 按rank排序,处理排名异常 if 'rank' in df.columns: df = df.sort_values('rank') # 重新分配连续排名 df['original_rank'] = df['rank'] # 保留原始排名 df['rank'] = range(1, len(df) + 1) # 重新分配连续排名 # 4. 统一列名格式(先改名,方便后续处理) if 'ts_code' in df.columns: df.rename(columns={'ts_code': 'stock_code', 'ts_name': 'short_name'}, inplace=True) # 5. 
二次验证:通过TuShare API验证股票状态,过滤ST和退市股票 # 为了确保最终有足够数量,验证更多股票 if self.pro and 'stock_code' in df.columns: # 计算需要验证的数量:考虑约10%的过滤率,验证limit*1.15的股票 verify_count = min(int(limit * 1.15), len(df)) df_to_verify = df.head(verify_count) valid_stocks = [] verified_count = 0 for idx, row in df_to_verify.iterrows(): verified_count += 1 try: stock_code = row['stock_code'] stock_info = self.pro.stock_basic(ts_code=stock_code, fields='ts_code,name,list_status') if not stock_info.empty: info = stock_info.iloc[0] real_name = info['name'] list_status = info['list_status'] # 检查股票状态和名称 if list_status == 'L': # 只保留正常上市的股票 # 检查真实名称是否包含ST/退市等 st_pattern = r'(退市|^\*?ST|^ST|^PT|暂停)' if not re.search(st_pattern, real_name): valid_stocks.append(idx) else: logger.debug(f"二次过滤: {stock_code} - {real_name}") else: logger.debug(f"二次过滤: {stock_code} - 状态:{list_status}") except Exception as e: logger.debug(f"验证股票{row['stock_code']}失败: {e}") continue # 如果已经有足够的有效股票,可以提前结束 if len(valid_stocks) >= limit: break if valid_stocks: df = df.loc[valid_stocks] filtered_count = verified_count - len(valid_stocks) logger.info(f"二次验证: 验证{verified_count}只,过滤{filtered_count}只,剩余{len(df)}只") # 6. 最终确保返回limit数量的股票 if len(df) < limit: logger.warning(f"过滤后股票不足: 需要{limit}只,实际{len(df)}只") df = df.head(limit) # 7. 添加数据源标识和有用字段 df['source'] = '同花顺热股' # 保留有用的原始字段 useful_cols = ['stock_code', 'short_name', 'rank', 'source'] for col in ['pct_change', 'current_price', 'hot', 'concept', 'original_rank']: if col in df.columns: useful_cols.append(col) df = df[useful_cols] logger.info(f"获取同花顺热门股票成功,共{len(df)}只股票") return df except Exception as e: logger.error(f"获取同花顺热门股票失败: {e}") return self._get_fallback_hot_stocks(limit) @retry_on_failure(retries=2, delay=1.0) def get_popular_stocks_east(self, limit: int = 100, trade_date: str = None) -> pd.DataFrame: """ 获取东方财富人气股票 Args: limit: 返回数量限制 trade_date: 交易日期,格式YYYYMMDD,默认为当日 Returns: 人气股票DataFrame """ try: if not self.pro: logger.error("需要Tushare Pro权限才能获取东财人气股数据") return self._get_fallback_hot_stocks(limit) # 使用TuShare Pro的东方财富热榜接口 params = { 'market': 'A股市场', 'hot_type': '人气榜', 'is_new': 'Y' } if trade_date: params['trade_date'] = trade_date df = self.pro.dc_hot(**params) if df.empty: logger.warning("东财人气股数据为空,使用备用方法") return self._get_fallback_hot_stocks(limit) # 数据处理和标准化 # 1. 过滤退市和ST股票 if 'ts_name' in df.columns: before_filter = len(df) # 使用(?:...)非捕获组避免警告 filter_pattern = r'(?:退市|^\*?ST|^ST|^PT|暂停)' df = df[~df['ts_name'].str.contains(filter_pattern, na=False, case=False, regex=True)] filtered_count = before_filter - len(df) if filtered_count > 0: logger.info(f"过滤退市/ST股票: {filtered_count} 只") # 2. 限制数量 df = df.head(limit) # 3. 统一列名格式 if 'ts_code' in df.columns: df.rename(columns={'ts_code': 'stock_code', 'ts_name': 'short_name'}, inplace=True) # 4. 添加数据源标识 df['source'] = '东财人气榜' if 'rank' not in df.columns: df['rank'] = range(1, len(df) + 1) logger.info(f"获取东财人气股票成功,共{len(df)}只股票") return df except Exception as e: logger.error(f"获取东财人气股票失败: {e}") return self._get_fallback_hot_stocks(limit) def get_combined_hot_stocks(self, limit_per_source: int = 100, final_limit: int = 150) -> pd.DataFrame: """ 获取合并的热门股票(同花顺+东财) Args: limit_per_source: 每个数据源的股票数量 final_limit: 最终返回的股票数量 Returns: 合并去重后的热门股票DataFrame """ try: all_stocks = [] # 1. 获取同花顺热股 ths_stocks = self.get_hot_stocks_ths(limit_per_source) if not ths_stocks.empty: all_stocks.append(ths_stocks) logger.info(f"同花顺热股: {len(ths_stocks)}只") # 2.
获取东财人气股 east_stocks = self.get_popular_stocks_east(limit_per_source) if not east_stocks.empty: all_stocks.append(east_stocks) logger.info(f"东财人气股: {len(east_stocks)}只") if not all_stocks: logger.warning("所有热门股票数据源都失败,使用备用股票池") return self._get_fallback_hot_stocks(final_limit) # 3. 合并数据(过滤空数据框) non_empty_stocks = [df for df in all_stocks if not df.empty] if not non_empty_stocks: logger.warning("没有有效的热门股票数据,使用备用股票池") return self._get_fallback_hot_stocks(final_limit) combined_df = pd.concat(non_empty_stocks, ignore_index=True) # 4. 去重(优先保留排名靠前的) combined_df = combined_df.sort_values(['rank', 'source']) unique_stocks = combined_df.drop_duplicates(subset=['stock_code'], keep='first') # 5. 限制最终数量 result = unique_stocks.head(final_limit) # 6. 重新排序并添加最终排名 result = result.reset_index(drop=True) result['final_rank'] = range(1, len(result) + 1) logger.info(f"合并热门股票成功: 原始{len(combined_df)}只,去重后{len(unique_stocks)}只,最终{len(result)}只") return result except Exception as e: logger.error(f"获取合并热门股票失败: {e}") return self._get_fallback_hot_stocks(final_limit) def get_stock_name(self, stock_code: str) -> str: """ 获取股票中文名称 Args: stock_code: 股票代码 Returns: 股票中文名称 """ try: # 从缓存中查找 if stock_code in self._stock_name_cache: return self._stock_name_cache[stock_code] # 直接通过TuShare Pro查询单个股票信息 if self.pro: # 将股票代码转换为TuShare格式 ts_code = stock_code if '.' not in stock_code: # 如果没有后缀,根据代码前缀添加 if stock_code.startswith('6'): ts_code = f"{stock_code}.SH" elif stock_code.startswith(('0', '3')): ts_code = f"{stock_code}.SZ" # 查询股票基本信息 stock_info = self.pro.stock_basic(ts_code=ts_code, fields='ts_code,name') if not stock_info.empty and 'name' in stock_info.columns: stock_name = stock_info.iloc[0]['name'] # 添加到缓存 self._stock_name_cache[stock_code] = stock_name return stock_name # 如果Pro查询失败,返回股票代码本身 self._stock_name_cache[stock_code] = stock_code return stock_code except Exception as e: logger.debug(f"获取股票{stock_code}名称失败: {e}") return stock_code def _get_fallback_hot_stocks(self, limit: int = 100) -> pd.DataFrame: """ 备用热门股票获取方法(当主要接口失败时使用) Args: limit: 返回股票数量 Returns: 热门股票DataFrame """ try: logger.info("使用备用方法获取热门股票(基于成交量排序)") if self.pro: # 获取当日成交量排行 trade_date = datetime.now().strftime('%Y%m%d') daily_data = self.pro.daily(trade_date=trade_date) if not daily_data.empty: # 按成交量排序 hot_stocks = daily_data.sort_values('vol', ascending=False).head(limit) # 统一列名 hot_stocks.rename(columns={'ts_code': 'stock_code'}, inplace=True) # 添加股票名称 stock_list = self.get_stock_list() if not stock_list.empty: hot_stocks = hot_stocks.merge( stock_list[['stock_code', 'short_name']], left_on='stock_code', right_on='stock_code', how='left' ) # 添加必要字段 hot_stocks['source'] = '成交量排序' hot_stocks['rank'] = range(1, len(hot_stocks) + 1) logger.info(f"备用方法获取热门股票成功,共{len(hot_stocks)}只股票") return hot_stocks # 如果Pro接口也失败,返回预设的股票池 return self._get_default_stock_pool(limit) except Exception as e: logger.error(f"备用热门股票获取失败: {e}") return self._get_default_stock_pool(limit) def _get_default_stock_pool(self, limit: int = 100) -> pd.DataFrame: """ 默认股票池(当所有数据获取方法都失败时使用) Args: limit: 返回股票数量 Returns: 默认股票池DataFrame """ # 预设一些主要的大盘股和活跃股票 default_stocks = [ {'stock_code': '000001.SZ', 'short_name': '平安银行'}, {'stock_code': '000002.SZ', 'short_name': '万科A'}, {'stock_code': '000858.SZ', 'short_name': '五粮液'}, {'stock_code': '600000.SH', 'short_name': '浦发银行'}, {'stock_code': '600036.SH', 'short_name': '招商银行'}, {'stock_code': '600519.SH', 'short_name': '贵州茅台'}, {'stock_code': '600887.SH', 'short_name': '伊利股份'}, {'stock_code': '002415.SZ', 'short_name': '海康威视'}, {'stock_code': 
'300014.SZ', 'short_name': '亿纬锂能'}, {'stock_code': '300059.SZ', 'short_name': '东方财富'}, {'stock_code': '300750.SZ', 'short_name': '宁德时代'}, {'stock_code': '000876.SZ', 'short_name': '新希望'}, {'stock_code': '002594.SZ', 'short_name': 'BYD'}, {'stock_code': '000895.SZ', 'short_name': '双汇发展'}, {'stock_code': '600031.SH', 'short_name': '三一重工'}, {'stock_code': '601318.SH', 'short_name': '中国平安'}, {'stock_code': '601166.SH', 'short_name': '兴业银行'}, {'stock_code': '600009.SH', 'short_name': '上海机场'}, {'stock_code': '600276.SH', 'short_name': '恒瑞医药'}, {'stock_code': '000063.SZ', 'short_name': '中兴通讯'}, ] df = pd.DataFrame(default_stocks[:limit]) df['source'] = '默认股票池' df['rank'] = range(1, len(df) + 1) logger.warning(f"使用默认股票池,包含{len(df)}只股票") return df def get_market_overview(self) -> dict: """ 获取市场概况 Returns: 市场概况字典 """ try: # 获取主要指数数据 sh_index = self.get_index_data("000001.SH") # 上证指数 sz_index = self.get_index_data("399001.SZ") # 深证成指 cyb_index = self.get_index_data("399006.SZ") # 创业板指 overview = { "update_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "shanghai": sh_index.iloc[0].to_dict() if not sh_index.empty else {}, "shenzhen": sz_index.iloc[0].to_dict() if not sz_index.empty else {}, "chinext": cyb_index.iloc[0].to_dict() if not cyb_index.empty else {} } logger.info("获取市场概况成功") return overview except Exception as e: logger.error(f"获取市场概况失败: {e}") return {} def get_sector_money_flow(self, trade_date: str = None) -> pd.DataFrame: """ 获取板块资金流向数据 使用同花顺行业资金流向接口 Args: trade_date: 交易日期,格式YYYYMMDD,默认为当日 Returns: 板块资金流向DataFrame,包含净流入金额等信息 """ try: if not self.pro: logger.error("需要Tushare Pro权限才能获取板块资金流向数据") return pd.DataFrame() # 如果未指定日期,使用最近交易日 if trade_date is None: trade_date = datetime.now().strftime('%Y%m%d') # 获取同花顺行业资金流向数据 df = self.pro.moneyflow_ind_ths(trade_date=trade_date) if df.empty: # 如果当日无数据,尝试获取前一交易日 prev_date = (datetime.strptime(trade_date, '%Y%m%d') - timedelta(days=1)).strftime('%Y%m%d') df = self.pro.moneyflow_ind_ths(trade_date=prev_date) if not df.empty: # 重命名列以保持兼容性 if 'industry' in df.columns: df['name'] = df['industry'] # 使用正确的净流入字段 if 'net_amount' in df.columns: df['net_amount'] = df['net_amount'] else: logger.warning("数据中未找到net_amount字段") return pd.DataFrame() # 按净流入金额排序(从大到小) df = df.sort_values('net_amount', ascending=False) # 添加排名 df['rank'] = range(1, len(df) + 1) # 确保有涨跌幅字段 if 'pct_change' not in df.columns: df['pct_change'] = 0 logger.info(f"获取板块资金流向成功,共{len(df)}个板块") else: logger.warning(f"未获取到{trade_date}的板块资金流向数据") return df except Exception as e: logger.error(f"获取板块资金流向失败: {e}") return pd.DataFrame() def get_concept_money_flow(self, trade_date: str = None) -> pd.DataFrame: """ 获取概念板块资金流向数据 使用同花顺概念板块资金流向接口 Args: trade_date: 交易日期,格式YYYYMMDD,默认为当日 Returns: 概念板块资金流向DataFrame """ try: if not self.pro: logger.error("需要Tushare Pro权限才能获取概念资金流向数据") return pd.DataFrame() if trade_date is None: trade_date = datetime.now().strftime('%Y%m%d') # 获取同花顺概念板块资金流向数据 df = self.pro.moneyflow_cnt_ths(trade_date=trade_date) if df.empty: # 如果当日无数据,尝试获取前一交易日 prev_date = (datetime.strptime(trade_date, '%Y%m%d') - timedelta(days=1)).strftime('%Y%m%d') df = self.pro.moneyflow_cnt_ths(trade_date=prev_date) if not df.empty: # 按净流入金额排序(从大到小) df = df.sort_values('net_amount', ascending=False) # 添加排名 df['rank'] = range(1, len(df) + 1) logger.info(f"获取概念资金流向成功,共{len(df)}个概念") else: logger.warning(f"未获取到{trade_date}的概念资金流向数据") return df except Exception as e: logger.error(f"获取概念资金流向失败: {e}") return pd.DataFrame() def get_strongest_concept_boards(self, trade_date: str = None, ts_code: str = 
None) -> pd.DataFrame: """ 获取最强板块统计(涨停股票最多的概念板块) Args: trade_date: 交易日期,格式YYYYMMDD,默认为当日 ts_code: 板块代码,可选 Returns: 最强板块统计DataFrame,包含以下字段: - ts_code: 板块代码 - name: 板块名称 - trade_date: 交易日期 - days: 上榜天数 - up_stat: 连板高度 - cons_nums: 连板家数 - up_nums: 涨停家数 - pct_chg: 涨跌幅% - rank: 板块热点排名 """ try: if not self.pro: logger.error("需要Tushare Pro权限才能获取最强板块数据") return pd.DataFrame() # 设置查询参数 params = {} if trade_date: params['trade_date'] = trade_date if ts_code: params['ts_code'] = ts_code # 调用Tushare接口 df = self.pro.limit_cpt_list(**params) if df.empty: logger.warning(f"未获取到最强板块数据: {trade_date or '当日'}") return pd.DataFrame() # 按照涨停家数和涨跌幅排序 df = df.sort_values(['up_nums', 'pct_chg'], ascending=[False, False]) logger.info(f"获取最强板块数据成功: {len(df)}个板块,日期: {trade_date or '当日'}") return df except Exception as e: logger.error(f"获取最强板块统计失败: {e}") return pd.DataFrame() def get_concept_constituent_stocks(self, ts_code: str) -> pd.DataFrame: """ 获取概念板块的成分股票 Args: ts_code: 概念板块代码 Returns: 成分股票DataFrame,包含股票代码、名称等信息 """ try: if not self.pro: logger.error("需要Tushare Pro权限才能获取概念成分股") return pd.DataFrame() # 获取概念成分股 df = self.pro.concept_detail(id=ts_code) if df.empty: logger.warning(f"未获取到概念板块 {ts_code} 的成分股") return pd.DataFrame() logger.info(f"获取概念板块 {ts_code} 成分股成功: {len(df)}只股票") return df except Exception as e: logger.error(f"获取概念成分股失败: {e}") return pd.DataFrame() def get_strongest_concept_stocks(self, trade_date: str = None, top_boards: int = 5) -> dict: """ 获取最强板块中的股票(综合方法) Args: trade_date: 交易日期,格式YYYYMMDD,默认为当日 top_boards: 选择前N个最强板块,默认5个 Returns: 包含最强板块及其成分股的字典 """ try: result = { 'trade_date': trade_date or datetime.now().strftime('%Y%m%d'), 'strongest_boards': pd.DataFrame(), 'stocks_by_board': {}, 'all_stocks': pd.DataFrame(), 'update_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S') } # 1. 获取最强板块列表 strongest_boards = self.get_strongest_concept_boards(trade_date) if strongest_boards.empty: logger.warning("未获取到最强板块数据") return result # 取前N个最强板块 top_boards_data = strongest_boards.head(top_boards) result['strongest_boards'] = top_boards_data # 2. 获取每个强势板块的成分股 all_stocks_list = [] for _, board in top_boards_data.iterrows(): board_code = board['ts_code'] board_name = board['name'] # 获取该板块的成分股 stocks = self.get_concept_constituent_stocks(board_code) if not stocks.empty: # 添加板块信息到股票数据 stocks['board_code'] = board_code stocks['board_name'] = board_name stocks['board_up_nums'] = board['up_nums'] stocks['board_pct_chg'] = board['pct_chg'] result['stocks_by_board'][board_name] = stocks all_stocks_list.append(stocks) logger.info(f"板块 {board_name} 包含 {len(stocks)} 只股票") # 避免频繁调用API time.sleep(0.1) # 3. 
合并所有股票数据 if all_stocks_list: result['all_stocks'] = pd.concat(all_stocks_list, ignore_index=True) # 去重(一只股票可能属于多个概念) unique_stocks = result['all_stocks'].drop_duplicates(subset=['ts_code']) logger.info(f"最强板块股票获取完成:") logger.info(f" - 强势板块数量: {len(top_boards_data)}") logger.info(f" - 包含股票总数: {len(result['all_stocks'])}") logger.info(f" - 去重后股票数: {len(unique_stocks)}") return result except Exception as e: logger.error(f"获取最强板块股票失败: {e}") return { 'trade_date': trade_date or datetime.now().strftime('%Y%m%d'), 'strongest_boards': pd.DataFrame(), 'stocks_by_board': {}, 'all_stocks': pd.DataFrame(), 'update_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'error': str(e) } def get_leading_stocks_from_board(self, board_code: str, board_name: str = None, top_n: int = 3, min_price: float = 5.0, max_price: float = 300.0) -> pd.DataFrame: """ 从单个板块中筛选龙头股 Args: board_code: 板块代码 board_name: 板块名称 top_n: 返回龙头股数量 min_price: 最低股价过滤 max_price: 最高股价过滤 Returns: 龙头股DataFrame,包含评分和排名 """ try: if not self.pro: logger.error("需要Tushare Pro权限才能筛选龙头股") return pd.DataFrame() # 1. 获取板块成分股 constituent_stocks = self.get_concept_constituent_stocks(board_code) if constituent_stocks.empty: logger.warning(f"板块 {board_code} 没有成分股数据") return pd.DataFrame() # 2. 获取股票基本信息和当日行情 trade_date = datetime.now().strftime('%Y%m%d') stock_codes = constituent_stocks['ts_code'].tolist() leading_candidates = [] for stock_code in stock_codes[:50]: # 限制查询数量避免超限 try: # 获取基本信息 basic_info = self.pro.stock_basic(ts_code=stock_code) if basic_info.empty: continue # 获取当日行情 daily_data = self.pro.daily(ts_code=stock_code, trade_date=trade_date) if daily_data.empty: # 尝试获取最近一个交易日数据 recent_data = self.pro.daily(ts_code=stock_code, limit=1) if not recent_data.empty: daily_data = recent_data else: continue stock_info = daily_data.iloc[0] basic = basic_info.iloc[0] # 价格过滤 current_price = stock_info['close'] if current_price < min_price or current_price > max_price: continue # 计算评分指标 candidate = { 'ts_code': stock_code, 'stock_code': stock_code, 'name': basic['name'], 'board_code': board_code, 'board_name': board_name or board_code, 'close': current_price, 'pct_chg': stock_info.get('pct_chg', 0), 'vol': stock_info.get('vol', 0), 'amount': stock_info.get('amount', 0), 'turnover_rate': stock_info.get('turnover_rate', 0), 'total_mv': stock_info.get('total_mv', 0), # 总市值 'circ_mv': stock_info.get('circ_mv', 0), # 流通市值 } # 获取近5日数据计算连涨天数 recent_5d = self.pro.daily(ts_code=stock_code, limit=5) if not recent_5d.empty: # 计算连续上涨天数 consecutive_up = 0 for _, row in recent_5d.iterrows(): if row['pct_chg'] > 0: consecutive_up += 1 else: break candidate['consecutive_up_days'] = consecutive_up # 计算5日平均成交额 candidate['avg_amount_5d'] = recent_5d['amount'].mean() else: candidate['consecutive_up_days'] = 0 candidate['avg_amount_5d'] = candidate['amount'] leading_candidates.append(candidate) time.sleep(0.1) # 避免调用过于频繁 except Exception as e: logger.debug(f"处理股票 {stock_code} 时出错: {e}") continue if not leading_candidates: logger.warning(f"板块 {board_code} 没有找到符合条件的龙头股") return pd.DataFrame() # 3. 计算龙头股评分 df = pd.DataFrame(leading_candidates) df = self._calculate_leading_score(df) # 4. 
排序并返回前N个 df = df.sort_values('leading_score', ascending=False).head(top_n) df['rank'] = range(1, len(df) + 1) logger.info(f"板块 {board_name or board_code} 筛选出 {len(df)} 只龙头股") return df except Exception as e: logger.error(f"从板块 {board_code} 筛选龙头股失败: {e}") return pd.DataFrame() def _calculate_leading_score(self, df: pd.DataFrame) -> pd.DataFrame: """ 计算龙头股评分 Args: df: 候选股票DataFrame Returns: 包含评分的DataFrame """ try: if df.empty: return df df = df.copy() # 标准化各项指标到0-100分 # 1. 涨幅得分 (30%) df['pct_chg_score'] = self._normalize_score(df['pct_chg'], weight=30) # 2. 成交额得分 (25%) df['amount_score'] = self._normalize_score(df['avg_amount_5d'], weight=25) # 3. 连续上涨天数得分 (20%) df['consecutive_score'] = df['consecutive_up_days'] * 4 # 每天4分,最高20分 df['consecutive_score'] = df['consecutive_score'].clip(upper=20) # 4. 换手率得分 (15%) - 适中的换手率更好 optimal_turnover = 8 # 最优换手率8% df['turnover_score'] = 15 - abs(df['turnover_rate'] - optimal_turnover) df['turnover_score'] = df['turnover_score'].clip(lower=0, upper=15) # 5. 市值得分 (10%) - 流通市值适中更好 # 50-500亿为最佳区间 df['mv_score'] = df['circ_mv'].apply(lambda x: self._get_mv_score(x)) # 综合评分 df['leading_score'] = ( df['pct_chg_score'] + df['amount_score'] + df['consecutive_score'] + df['turnover_score'] + df['mv_score'] ) # 添加评级 df['leading_grade'] = df['leading_score'].apply(self._get_leading_grade) return df except Exception as e: logger.error(f"计算龙头股评分失败: {e}") return df def _normalize_score(self, series: pd.Series, weight: int = 100) -> pd.Series: """ 将数据标准化为指定权重的得分 Args: series: 原始数据序列 weight: 权重分数 Returns: 标准化后的得分序列 """ if series.empty or series.max() == series.min(): return pd.Series([0] * len(series), index=series.index) # Min-Max标准化到0-1,然后乘以权重 normalized = (series - series.min()) / (series.max() - series.min()) return normalized * weight def _get_mv_score(self, market_value: float) -> float: """ 根据流通市值计算得分 Args: market_value: 流通市值(万元) Returns: 市值得分 """ if market_value <= 0: return 0 # 转换为亿元 mv_billion = market_value / 10000 if 50 <= mv_billion <= 500: # 最佳区间:50-500亿 return 10 elif 20 <= mv_billion < 50: # 较小但可接受:20-50亿 return 8 elif 500 < mv_billion <= 1000: # 较大但可接受:500-1000亿 return 6 elif 10 <= mv_billion < 20: # 偏小:10-20亿 return 4 elif mv_billion > 1000: # 太大:>1000亿 return 2 else: # 太小:<10亿 return 1 def _get_leading_grade(self, score: float) -> str: """ 根据评分获取评级 Args: score: 综合评分 Returns: 评级字符串 """ if score >= 80: return "A+ 超级龙头" elif score >= 70: return "A 优质龙头" elif score >= 60: return "B+ 潜力龙头" elif score >= 50: return "B 一般龙头" else: return "C 弱势股票" def get_leading_stocks_from_hot_boards(self, top_boards: int = 10, stocks_per_board: int = 2, min_score: float = 50.0) -> dict: """ 从热门板块中筛选龙头牛股(主要接口) Args: top_boards: 分析前N个热门板块 stocks_per_board: 每个板块选择的龙头股数量 min_score: 最低评分要求 Returns: 包含所有龙头股信息的字典 """ try: result = { 'scan_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'boards_analyzed': 0, 'total_leading_stocks': 0, 'leading_stocks_by_board': {}, 'all_leading_stocks': pd.DataFrame(), 'top_leading_stocks': pd.DataFrame() } # 1. 获取最强板块 logger.info(f"正在获取前 {top_boards} 个最强板块...") strongest_boards = self.get_strongest_concept_boards() if strongest_boards.empty: logger.warning("未能获取到强势板块数据") return result # 取前N个板块 target_boards = strongest_boards.head(top_boards) result['boards_analyzed'] = len(target_boards) # 2. 
逐个分析每个板块的龙头股 all_leading_stocks = [] for idx, (_, board) in enumerate(target_boards.iterrows(), 1): board_code = board['ts_code'] board_name = board['name'] board_up_nums = board.get('up_nums', 0) board_pct_chg = board.get('pct_chg', 0) logger.info(f"[{idx}/{len(target_boards)}] 分析板块: {board_name} (涨停{board_up_nums}只, 涨幅{board_pct_chg:.2f}%)") # 从该板块筛选龙头股 board_leaders = self.get_leading_stocks_from_board( board_code, board_name, top_n=stocks_per_board ) if not board_leaders.empty: # 过滤低分股票 qualified_leaders = board_leaders[board_leaders['leading_score'] >= min_score] if not qualified_leaders.empty: # 添加板块信息 qualified_leaders['board_up_nums'] = board_up_nums qualified_leaders['board_pct_chg'] = board_pct_chg qualified_leaders['board_rank'] = idx result['leading_stocks_by_board'][board_name] = qualified_leaders all_leading_stocks.append(qualified_leaders) logger.info(f" ✅ 找到 {len(qualified_leaders)} 只龙头股") else: logger.info(f" ❌ 无符合评分要求的龙头股") else: logger.info(f" ❌ 板块数据获取失败") # 避免API限制 time.sleep(0.5) # 3. 汇总所有龙头股 if all_leading_stocks: all_df = pd.concat(all_leading_stocks, ignore_index=True) result['all_leading_stocks'] = all_df result['total_leading_stocks'] = len(all_df) # 4. 获取综合排名前N的超级龙头 top_leaders = all_df.nlargest(20, 'leading_score') top_leaders['overall_rank'] = range(1, len(top_leaders) + 1) result['top_leading_stocks'] = top_leaders logger.info(f"🎯 筛选完成! 共分析 {result['boards_analyzed']} 个板块,发现 {result['total_leading_stocks']} 只龙头股") logger.info(f"📈 TOP10 超级龙头:") for _, stock in top_leaders.head(10).iterrows(): logger.info(f" {stock['overall_rank']}. {stock['stock_code']} {stock['name']} | {stock['board_name']} | 评分:{stock['leading_score']:.1f} | {stock['leading_grade']}") return result except Exception as e: logger.error(f"筛选热门板块龙头股失败: {e}") result['error'] = str(e) return result def get_top_money_flow_sectors(self, trade_date: str = None, top_n: int = 10) -> dict: """ 获取当日资金净流入最多的板块 Args: trade_date: 交易日期,格式YYYYMMDD,默认为当日 top_n: 返回前N个板块,默认10个 Returns: 包含行业板块和概念板块的字典 """ try: result = { 'trade_date': trade_date or datetime.now().strftime('%Y%m%d'), 'sectors': pd.DataFrame(), 'concepts': pd.DataFrame(), 'update_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S') } # 获取行业板块资金流向 sectors_df = self.get_sector_money_flow(trade_date) if not sectors_df.empty: result['sectors'] = sectors_df.head(top_n) logger.info(f"获取行业板块TOP{top_n}成功") # 获取概念板块资金流向 concepts_df = self.get_concept_money_flow(trade_date) if not concepts_df.empty: result['concepts'] = concepts_df.head(top_n) logger.info(f"获取概念板块TOP{top_n}成功") return result except Exception as e: logger.error(f"获取TOP资金流向板块失败: {e}") return { 'trade_date': trade_date or datetime.now().strftime('%Y%m%d'), 'sectors': pd.DataFrame(), 'concepts': pd.DataFrame(), 'update_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 'error': str(e) } # ADataFetcher别名已移除,请直接使用TushareFetcher if __name__ == "__main__": # 测试代码 token = "0ed6419a00d8923dc19c0b58fc92d94c9a0696949ab91a13aa58a0cc" fetcher = TushareFetcher(token) # 测试获取股票列表 print("测试获取股票列表...") stock_list = fetcher.get_stock_list() print(f"股票数量: {len(stock_list)}") if not stock_list.empty: print(stock_list.head()) # 测试搜索功能 print("\n测试搜索功能...") search_results = fetcher.search_stocks("平安") if not search_results.empty: print(search_results.head())
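    # --- 补充用法示例(示意):以下调用仅作演示,假设已配置有效的 TuShare Pro token,
    # 且账号具备 daily / ths_hot / dc_hot 等接口的调用权限;实际可用性以 TuShare 服务端返回为准 ---

    # 测试获取历史日线数据(日期会被自动转换为 YYYYMMDD 格式)
    print("\n测试获取历史日线数据...")
    hist = fetcher.get_historical_data("600519", "2024-01-01", "2024-03-01", period="daily")
    if not hist.empty:
        print(hist[['date', 'open', 'close', 'volume']].tail())

    # 测试获取过滤后的A股列表(默认排除ST、北交所股票,市值下限20亿元)
    print("\n测试获取过滤后A股列表...")
    filtered = fetcher.get_filtered_a_share_list()
    print(f"过滤后股票数量: {len(filtered)}")

    # 测试获取合并热门股票(同花顺热榜 + 东财人气榜,去重后取前80只)
    print("\n测试获取合并热门股票...")
    hot = fetcher.get_combined_hot_stocks(limit_per_source=50, final_limit=80)
    if not hot.empty:
        cols = [c for c in ['final_rank', 'stock_code', 'short_name', 'source'] if c in hot.columns]
        print(hot[cols].head(10))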