trading.ai/src/data/stock_pool_manager.py
2025-11-02 10:41:17 +08:00

266 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
股票池管理器
负责根据不同规则获取和管理股票池
"""
from typing import List, Dict, Any, Optional
import pandas as pd
from loguru import logger
from abc import ABC, abstractmethod
from src.data.tushare_fetcher import TushareFetcher
class StockPoolRule(ABC):
"""股票池规则抽象基类"""
@abstractmethod
def get_stocks(self, fetcher: TushareFetcher, **kwargs) -> List[str]:
"""
获取股票列表
Args:
fetcher: 数据获取器
**kwargs: 规则参数
Returns:
股票代码列表
"""
pass
@abstractmethod
def get_rule_name(self) -> str:
"""获取规则名称"""
pass
class TushareHotStocksRule(StockPoolRule):
"""同花顺热榜股票池规则"""
def get_stocks(self, fetcher: TushareFetcher, limit: int = None, **kwargs) -> List[str]:
"""获取同花顺热榜股票"""
try:
# 如果没有指定limit获取所有可用的热榜股票默认最大2000
if limit is None:
limit = 2000
hot_stocks = fetcher.get_hot_stocks_ths(limit=limit)
if not hot_stocks.empty and 'stock_code' in hot_stocks.columns:
stocks = hot_stocks['stock_code'].tolist()
logger.info(f"✅ 同花顺热榜获取成功: {len(stocks)}只股票")
return stocks
else:
logger.warning("同花顺热榜数据为空")
return []
except Exception as e:
logger.error(f"获取同花顺热榜失败: {e}")
return []
def get_rule_name(self) -> str:
return "同花顺热榜"
class CombinedHotStocksRule(StockPoolRule):
"""合并热门股票池规则(同花顺+东财)"""
def get_stocks(self, fetcher: TushareFetcher, limit_per_source: int = 30, final_limit: int = 50, **kwargs) -> List[str]:
"""获取合并热门股票"""
try:
combined_stocks = fetcher.get_combined_hot_stocks(
limit_per_source=limit_per_source,
final_limit=final_limit
)
if not combined_stocks.empty and 'stock_code' in combined_stocks.columns:
stocks = combined_stocks['stock_code'].tolist()
logger.info(f"✅ 合并热门股票获取成功: {len(stocks)}只股票")
return stocks
else:
logger.warning("合并热门股票数据为空")
return []
except Exception as e:
logger.error(f"获取合并热门股票失败: {e}")
return []
def get_rule_name(self) -> str:
return "合并热门股票"
class LeadingStocksRule(StockPoolRule):
"""龙头牛股股票池规则"""
def get_stocks(self, fetcher: TushareFetcher, top_boards: int = 8, stocks_per_board: int = 3, min_score: float = 60.0, **kwargs) -> List[str]:
"""获取龙头牛股"""
try:
result = fetcher.get_leading_stocks_from_hot_boards(
top_boards=top_boards,
stocks_per_board=stocks_per_board,
min_score=min_score
)
if 'error' not in result and not result['top_leading_stocks'].empty:
stocks = result['top_leading_stocks']['stock_code'].tolist()
logger.info(f"✅ 龙头牛股获取成功: {len(stocks)}只股票")
return stocks
else:
logger.warning("龙头牛股数据为空")
return []
except Exception as e:
logger.error(f"获取龙头牛股失败: {e}")
return []
def get_rule_name(self) -> str:
return "龙头牛股"
class CustomStockListRule(StockPoolRule):
"""自定义股票列表规则"""
def __init__(self, stock_list: List[str]):
self.stock_list = stock_list
def get_stocks(self, fetcher: TushareFetcher, **kwargs) -> List[str]:
"""返回自定义股票列表"""
logger.info(f"✅ 使用自定义股票列表: {len(self.stock_list)}只股票")
return self.stock_list.copy()
def get_rule_name(self) -> str:
return "自定义股票列表"
class StockPoolManager:
"""股票池管理器"""
def __init__(self, fetcher: TushareFetcher):
"""
初始化股票池管理器
Args:
fetcher: TuShare数据获取器
"""
self.fetcher = fetcher
self.rules: Dict[str, StockPoolRule] = {}
self._register_default_rules()
def _register_default_rules(self):
"""注册默认规则"""
self.register_rule("tushare_hot", TushareHotStocksRule())
self.register_rule("combined_hot", CombinedHotStocksRule())
self.register_rule("leading_stocks", LeadingStocksRule())
def register_rule(self, rule_name: str, rule: StockPoolRule):
"""
注册股票池规则
Args:
rule_name: 规则名称
rule: 规则实例
"""
self.rules[rule_name] = rule
logger.info(f"注册股票池规则: {rule_name} - {rule.get_rule_name()}")
def get_stock_pool(self, rule_name: str, **kwargs) -> Dict[str, Any]:
"""
根据规则获取股票池
Args:
rule_name: 规则名称
**kwargs: 规则参数
Returns:
包含股票列表和元信息的字典
"""
if rule_name not in self.rules:
logger.error(f"未找到股票池规则: {rule_name}")
return {
'stocks': [],
'rule_name': rule_name,
'rule_display_name': '未知规则',
'total_count': 0,
'success': False,
'error': f'未找到规则: {rule_name}'
}
rule = self.rules[rule_name]
try:
logger.info(f"🔍 执行股票池规则: {rule.get_rule_name()}")
stocks = rule.get_stocks(self.fetcher, **kwargs)
return {
'stocks': stocks,
'rule_name': rule_name,
'rule_display_name': rule.get_rule_name(),
'total_count': len(stocks),
'success': True,
'parameters': kwargs
}
except Exception as e:
logger.error(f"执行股票池规则失败 {rule.get_rule_name()}: {e}")
return {
'stocks': [],
'rule_name': rule_name,
'rule_display_name': rule.get_rule_name(),
'total_count': 0,
'success': False,
'error': str(e)
}
def get_available_rules(self) -> Dict[str, str]:
"""
获取可用的规则列表
Returns:
规则名称到显示名称的映射
"""
return {name: rule.get_rule_name() for name, rule in self.rules.items()}
def create_custom_rule(self, rule_name: str, stock_list: List[str]):
"""
创建自定义股票列表规则
Args:
rule_name: 规则名称
stock_list: 股票代码列表
"""
custom_rule = CustomStockListRule(stock_list)
self.register_rule(rule_name, custom_rule)
if __name__ == "__main__":
# 测试股票池管理器
from loguru import logger
import sys
logger.remove()
logger.add(sys.stdout, level="INFO", format="{time:HH:mm:ss} | {level} | {message}")
# 初始化
fetcher = TushareFetcher()
pool_manager = StockPoolManager(fetcher)
print("=" * 60)
print("📊 股票池管理器测试")
print("=" * 60)
# 显示可用规则
print("可用规则:")
for rule_id, rule_name in pool_manager.get_available_rules().items():
print(f" {rule_id}: {rule_name}")
# 测试同花顺热榜
print(f"\n测试同花顺热榜:")
result = pool_manager.get_stock_pool("tushare_hot", limit=10)
if result['success']:
print(f"✅ 获取成功: {result['total_count']}只股票")
print(f"前5只: {result['stocks'][:5]}")
else:
print(f"❌ 获取失败: {result['error']}")
# 测试自定义规则
print(f"\n测试自定义规则:")
custom_stocks = ["000001.SZ", "000002.SZ", "600000.SH"]
pool_manager.create_custom_rule("my_custom", custom_stocks)
result = pool_manager.get_stock_pool("my_custom")
if result['success']:
print(f"✅ 自定义规则: {result['total_count']}只股票")
print(f"股票: {result['stocks']}")