""" 股票池管理器 负责根据不同规则获取和管理股票池 """ from typing import List, Dict, Any, Optional import pandas as pd from loguru import logger from abc import ABC, abstractmethod from src.data.tushare_fetcher import TushareFetcher class StockPoolRule(ABC): """股票池规则抽象基类""" @abstractmethod def get_stocks(self, fetcher: TushareFetcher, **kwargs) -> List[str]: """ 获取股票列表 Args: fetcher: 数据获取器 **kwargs: 规则参数 Returns: 股票代码列表 """ pass @abstractmethod def get_rule_name(self) -> str: """获取规则名称""" pass class TushareHotStocksRule(StockPoolRule): """同花顺热榜股票池规则""" def get_stocks(self, fetcher: TushareFetcher, limit: int = None, **kwargs) -> List[str]: """获取同花顺热榜股票""" try: # 如果没有指定limit,获取所有可用的热榜股票(默认最大2000) if limit is None: limit = 2000 hot_stocks = fetcher.get_hot_stocks_ths(limit=limit) if not hot_stocks.empty and 'stock_code' in hot_stocks.columns: stocks = hot_stocks['stock_code'].tolist() logger.info(f"✅ 同花顺热榜获取成功: {len(stocks)}只股票") return stocks else: logger.warning("同花顺热榜数据为空") return [] except Exception as e: logger.error(f"获取同花顺热榜失败: {e}") return [] def get_rule_name(self) -> str: return "同花顺热榜" class CombinedHotStocksRule(StockPoolRule): """合并热门股票池规则(同花顺+东财)""" def get_stocks(self, fetcher: TushareFetcher, limit_per_source: int = 30, final_limit: int = 50, **kwargs) -> List[str]: """获取合并热门股票""" try: combined_stocks = fetcher.get_combined_hot_stocks( limit_per_source=limit_per_source, final_limit=final_limit ) if not combined_stocks.empty and 'stock_code' in combined_stocks.columns: stocks = combined_stocks['stock_code'].tolist() logger.info(f"✅ 合并热门股票获取成功: {len(stocks)}只股票") return stocks else: logger.warning("合并热门股票数据为空") return [] except Exception as e: logger.error(f"获取合并热门股票失败: {e}") return [] def get_rule_name(self) -> str: return "合并热门股票" class LeadingStocksRule(StockPoolRule): """龙头牛股股票池规则""" def get_stocks(self, fetcher: TushareFetcher, top_boards: int = 8, stocks_per_board: int = 3, min_score: float = 60.0, **kwargs) -> List[str]: """获取龙头牛股""" try: result = fetcher.get_leading_stocks_from_hot_boards( top_boards=top_boards, stocks_per_board=stocks_per_board, min_score=min_score ) if 'error' not in result and not result['top_leading_stocks'].empty: stocks = result['top_leading_stocks']['stock_code'].tolist() logger.info(f"✅ 龙头牛股获取成功: {len(stocks)}只股票") return stocks else: logger.warning("龙头牛股数据为空") return [] except Exception as e: logger.error(f"获取龙头牛股失败: {e}") return [] def get_rule_name(self) -> str: return "龙头牛股" class CustomStockListRule(StockPoolRule): """自定义股票列表规则""" def __init__(self, stock_list: List[str]): self.stock_list = stock_list def get_stocks(self, fetcher: TushareFetcher, **kwargs) -> List[str]: """返回自定义股票列表""" logger.info(f"✅ 使用自定义股票列表: {len(self.stock_list)}只股票") return self.stock_list.copy() def get_rule_name(self) -> str: return "自定义股票列表" class StockPoolManager: """股票池管理器""" def __init__(self, fetcher: TushareFetcher): """ 初始化股票池管理器 Args: fetcher: TuShare数据获取器 """ self.fetcher = fetcher self.rules: Dict[str, StockPoolRule] = {} self._register_default_rules() def _register_default_rules(self): """注册默认规则""" self.register_rule("tushare_hot", TushareHotStocksRule()) self.register_rule("combined_hot", CombinedHotStocksRule()) self.register_rule("leading_stocks", LeadingStocksRule()) def register_rule(self, rule_name: str, rule: StockPoolRule): """ 注册股票池规则 Args: rule_name: 规则名称 rule: 规则实例 """ self.rules[rule_name] = rule logger.info(f"注册股票池规则: {rule_name} - {rule.get_rule_name()}") def get_stock_pool(self, rule_name: str, **kwargs) -> Dict[str, Any]: """ 根据规则获取股票池 Args: rule_name: 规则名称 **kwargs: 规则参数 Returns: 包含股票列表和元信息的字典 """ if rule_name not in self.rules: logger.error(f"未找到股票池规则: {rule_name}") return { 'stocks': [], 'rule_name': rule_name, 'rule_display_name': '未知规则', 'total_count': 0, 'success': False, 'error': f'未找到规则: {rule_name}' } rule = self.rules[rule_name] try: logger.info(f"🔍 执行股票池规则: {rule.get_rule_name()}") stocks = rule.get_stocks(self.fetcher, **kwargs) return { 'stocks': stocks, 'rule_name': rule_name, 'rule_display_name': rule.get_rule_name(), 'total_count': len(stocks), 'success': True, 'parameters': kwargs } except Exception as e: logger.error(f"执行股票池规则失败 {rule.get_rule_name()}: {e}") return { 'stocks': [], 'rule_name': rule_name, 'rule_display_name': rule.get_rule_name(), 'total_count': 0, 'success': False, 'error': str(e) } def get_available_rules(self) -> Dict[str, str]: """ 获取可用的规则列表 Returns: 规则名称到显示名称的映射 """ return {name: rule.get_rule_name() for name, rule in self.rules.items()} def create_custom_rule(self, rule_name: str, stock_list: List[str]): """ 创建自定义股票列表规则 Args: rule_name: 规则名称 stock_list: 股票代码列表 """ custom_rule = CustomStockListRule(stock_list) self.register_rule(rule_name, custom_rule) if __name__ == "__main__": # 测试股票池管理器 from loguru import logger import sys logger.remove() logger.add(sys.stdout, level="INFO", format="{time:HH:mm:ss} | {level} | {message}") # 初始化 fetcher = TushareFetcher() pool_manager = StockPoolManager(fetcher) print("=" * 60) print("📊 股票池管理器测试") print("=" * 60) # 显示可用规则 print("可用规则:") for rule_id, rule_name in pool_manager.get_available_rules().items(): print(f" {rule_id}: {rule_name}") # 测试同花顺热榜 print(f"\n测试同花顺热榜:") result = pool_manager.get_stock_pool("tushare_hot", limit=10) if result['success']: print(f"✅ 获取成功: {result['total_count']}只股票") print(f"前5只: {result['stocks'][:5]}") else: print(f"❌ 获取失败: {result['error']}") # 测试自定义规则 print(f"\n测试自定义规则:") custom_stocks = ["000001.SZ", "000002.SZ", "600000.SH"] pool_manager.create_custom_rule("my_custom", custom_stocks) result = pool_manager.get_stock_pool("my_custom") if result['success']: print(f"✅ 自定义规则: {result['total_count']}只股票") print(f"股票: {result['stocks']}")