266 lines
8.4 KiB
Python
266 lines
8.4 KiB
Python
"""
|
||
股票池管理器
|
||
负责根据不同规则获取和管理股票池
|
||
"""
|
||
|
||
from typing import List, Dict, Any, Optional
|
||
import pandas as pd
|
||
from loguru import logger
|
||
from abc import ABC, abstractmethod
|
||
from src.data.tushare_fetcher import TushareFetcher
|
||
|
||
|
||
class StockPoolRule(ABC):
|
||
"""股票池规则抽象基类"""
|
||
|
||
@abstractmethod
|
||
def get_stocks(self, fetcher: TushareFetcher, **kwargs) -> List[str]:
|
||
"""
|
||
获取股票列表
|
||
|
||
Args:
|
||
fetcher: 数据获取器
|
||
**kwargs: 规则参数
|
||
|
||
Returns:
|
||
股票代码列表
|
||
"""
|
||
pass
|
||
|
||
@abstractmethod
|
||
def get_rule_name(self) -> str:
|
||
"""获取规则名称"""
|
||
pass
|
||
|
||
|
||
class TushareHotStocksRule(StockPoolRule):
|
||
"""同花顺热榜股票池规则"""
|
||
|
||
def get_stocks(self, fetcher: TushareFetcher, limit: int = None, **kwargs) -> List[str]:
|
||
"""获取同花顺热榜股票"""
|
||
try:
|
||
# 如果没有指定limit,获取所有可用的热榜股票(默认最大2000)
|
||
if limit is None:
|
||
limit = 2000
|
||
hot_stocks = fetcher.get_hot_stocks_ths(limit=limit)
|
||
if not hot_stocks.empty and 'stock_code' in hot_stocks.columns:
|
||
stocks = hot_stocks['stock_code'].tolist()
|
||
logger.info(f"✅ 同花顺热榜获取成功: {len(stocks)}只股票")
|
||
return stocks
|
||
else:
|
||
logger.warning("同花顺热榜数据为空")
|
||
return []
|
||
except Exception as e:
|
||
logger.error(f"获取同花顺热榜失败: {e}")
|
||
return []
|
||
|
||
def get_rule_name(self) -> str:
|
||
return "同花顺热榜"
|
||
|
||
|
||
class CombinedHotStocksRule(StockPoolRule):
|
||
"""合并热门股票池规则(同花顺+东财)"""
|
||
|
||
def get_stocks(self, fetcher: TushareFetcher, limit_per_source: int = 30, final_limit: int = 50, **kwargs) -> List[str]:
|
||
"""获取合并热门股票"""
|
||
try:
|
||
combined_stocks = fetcher.get_combined_hot_stocks(
|
||
limit_per_source=limit_per_source,
|
||
final_limit=final_limit
|
||
)
|
||
if not combined_stocks.empty and 'stock_code' in combined_stocks.columns:
|
||
stocks = combined_stocks['stock_code'].tolist()
|
||
logger.info(f"✅ 合并热门股票获取成功: {len(stocks)}只股票")
|
||
return stocks
|
||
else:
|
||
logger.warning("合并热门股票数据为空")
|
||
return []
|
||
except Exception as e:
|
||
logger.error(f"获取合并热门股票失败: {e}")
|
||
return []
|
||
|
||
def get_rule_name(self) -> str:
|
||
return "合并热门股票"
|
||
|
||
|
||
class LeadingStocksRule(StockPoolRule):
|
||
"""龙头牛股股票池规则"""
|
||
|
||
def get_stocks(self, fetcher: TushareFetcher, top_boards: int = 8, stocks_per_board: int = 3, min_score: float = 60.0, **kwargs) -> List[str]:
|
||
"""获取龙头牛股"""
|
||
try:
|
||
result = fetcher.get_leading_stocks_from_hot_boards(
|
||
top_boards=top_boards,
|
||
stocks_per_board=stocks_per_board,
|
||
min_score=min_score
|
||
)
|
||
|
||
if 'error' not in result and not result['top_leading_stocks'].empty:
|
||
stocks = result['top_leading_stocks']['stock_code'].tolist()
|
||
logger.info(f"✅ 龙头牛股获取成功: {len(stocks)}只股票")
|
||
return stocks
|
||
else:
|
||
logger.warning("龙头牛股数据为空")
|
||
return []
|
||
except Exception as e:
|
||
logger.error(f"获取龙头牛股失败: {e}")
|
||
return []
|
||
|
||
def get_rule_name(self) -> str:
|
||
return "龙头牛股"
|
||
|
||
|
||
class CustomStockListRule(StockPoolRule):
|
||
"""自定义股票列表规则"""
|
||
|
||
def __init__(self, stock_list: List[str]):
|
||
self.stock_list = stock_list
|
||
|
||
def get_stocks(self, fetcher: TushareFetcher, **kwargs) -> List[str]:
|
||
"""返回自定义股票列表"""
|
||
logger.info(f"✅ 使用自定义股票列表: {len(self.stock_list)}只股票")
|
||
return self.stock_list.copy()
|
||
|
||
def get_rule_name(self) -> str:
|
||
return "自定义股票列表"
|
||
|
||
|
||
class StockPoolManager:
|
||
"""股票池管理器"""
|
||
|
||
def __init__(self, fetcher: TushareFetcher):
|
||
"""
|
||
初始化股票池管理器
|
||
|
||
Args:
|
||
fetcher: TuShare数据获取器
|
||
"""
|
||
self.fetcher = fetcher
|
||
self.rules: Dict[str, StockPoolRule] = {}
|
||
self._register_default_rules()
|
||
|
||
def _register_default_rules(self):
|
||
"""注册默认规则"""
|
||
self.register_rule("tushare_hot", TushareHotStocksRule())
|
||
self.register_rule("combined_hot", CombinedHotStocksRule())
|
||
self.register_rule("leading_stocks", LeadingStocksRule())
|
||
|
||
def register_rule(self, rule_name: str, rule: StockPoolRule):
|
||
"""
|
||
注册股票池规则
|
||
|
||
Args:
|
||
rule_name: 规则名称
|
||
rule: 规则实例
|
||
"""
|
||
self.rules[rule_name] = rule
|
||
logger.info(f"注册股票池规则: {rule_name} - {rule.get_rule_name()}")
|
||
|
||
def get_stock_pool(self, rule_name: str, **kwargs) -> Dict[str, Any]:
|
||
"""
|
||
根据规则获取股票池
|
||
|
||
Args:
|
||
rule_name: 规则名称
|
||
**kwargs: 规则参数
|
||
|
||
Returns:
|
||
包含股票列表和元信息的字典
|
||
"""
|
||
if rule_name not in self.rules:
|
||
logger.error(f"未找到股票池规则: {rule_name}")
|
||
return {
|
||
'stocks': [],
|
||
'rule_name': rule_name,
|
||
'rule_display_name': '未知规则',
|
||
'total_count': 0,
|
||
'success': False,
|
||
'error': f'未找到规则: {rule_name}'
|
||
}
|
||
|
||
rule = self.rules[rule_name]
|
||
|
||
try:
|
||
logger.info(f"🔍 执行股票池规则: {rule.get_rule_name()}")
|
||
stocks = rule.get_stocks(self.fetcher, **kwargs)
|
||
|
||
return {
|
||
'stocks': stocks,
|
||
'rule_name': rule_name,
|
||
'rule_display_name': rule.get_rule_name(),
|
||
'total_count': len(stocks),
|
||
'success': True,
|
||
'parameters': kwargs
|
||
}
|
||
|
||
except Exception as e:
|
||
logger.error(f"执行股票池规则失败 {rule.get_rule_name()}: {e}")
|
||
return {
|
||
'stocks': [],
|
||
'rule_name': rule_name,
|
||
'rule_display_name': rule.get_rule_name(),
|
||
'total_count': 0,
|
||
'success': False,
|
||
'error': str(e)
|
||
}
|
||
|
||
def get_available_rules(self) -> Dict[str, str]:
|
||
"""
|
||
获取可用的规则列表
|
||
|
||
Returns:
|
||
规则名称到显示名称的映射
|
||
"""
|
||
return {name: rule.get_rule_name() for name, rule in self.rules.items()}
|
||
|
||
def create_custom_rule(self, rule_name: str, stock_list: List[str]):
|
||
"""
|
||
创建自定义股票列表规则
|
||
|
||
Args:
|
||
rule_name: 规则名称
|
||
stock_list: 股票代码列表
|
||
"""
|
||
custom_rule = CustomStockListRule(stock_list)
|
||
self.register_rule(rule_name, custom_rule)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
# 测试股票池管理器
|
||
from loguru import logger
|
||
import sys
|
||
|
||
logger.remove()
|
||
logger.add(sys.stdout, level="INFO", format="{time:HH:mm:ss} | {level} | {message}")
|
||
|
||
# 初始化
|
||
fetcher = TushareFetcher()
|
||
pool_manager = StockPoolManager(fetcher)
|
||
|
||
print("=" * 60)
|
||
print("📊 股票池管理器测试")
|
||
print("=" * 60)
|
||
|
||
# 显示可用规则
|
||
print("可用规则:")
|
||
for rule_id, rule_name in pool_manager.get_available_rules().items():
|
||
print(f" {rule_id}: {rule_name}")
|
||
|
||
# 测试同花顺热榜
|
||
print(f"\n测试同花顺热榜:")
|
||
result = pool_manager.get_stock_pool("tushare_hot", limit=10)
|
||
if result['success']:
|
||
print(f"✅ 获取成功: {result['total_count']}只股票")
|
||
print(f"前5只: {result['stocks'][:5]}")
|
||
else:
|
||
print(f"❌ 获取失败: {result['error']}")
|
||
|
||
# 测试自定义规则
|
||
print(f"\n测试自定义规则:")
|
||
custom_stocks = ["000001.SZ", "000002.SZ", "600000.SH"]
|
||
pool_manager.create_custom_rule("my_custom", custom_stocks)
|
||
|
||
result = pool_manager.get_stock_pool("my_custom")
|
||
if result['success']:
|
||
print(f"✅ 自定义规则: {result['total_count']}只股票")
|
||
print(f"股票: {result['stocks']}") |