trading.ai/test_all_a_shares.py
2025-09-18 20:45:01 +08:00

93 lines
3.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
测试所有A股股票扫描功能
"""
import sys
from pathlib import Path
# Put this script's directory (the project root) at the front of sys.path so
# the `src.*` imports below resolve when the file is run directly.
current_dir = Path(__file__).parent
sys.path[:0] = [str(current_dir)]
from loguru import logger
from src.strategy.kline_pattern_strategy import KLinePatternStrategy
from src.data.data_fetcher import ADataFetcher
from src.utils.notification import NotificationManager
from src.database.database_manager import DatabaseManager
from src.utils.config_loader import ConfigLoader
def test_all_a_shares_scan(max_stocks: int = 5) -> None:
    """Smoke-test the full A-share market scan end to end.

    Loads the project config and builds the data fetcher, notification
    manager, database manager and ``KLinePatternStrategy``. It then logs a
    small sample of the filtered A-share list and runs ``scan_market`` in
    all-A-share mode, capped at ``max_stocks`` symbols. Finally it reports
    how many stocks were scanned and how many signals were found.

    Args:
        max_stocks: Upper bound passed to ``scan_market`` (defaults to 5,
            the original hard-coded test limit). Because it has a default,
            pytest does not treat it as a fixture.

    Raises:
        Exception: Any error during initialisation or scanning is logged
            and then re-raised. Pytest therefore reports the failure, and
            the script exits non-zero, instead of silently "passing".
    """
    # Replace every existing loguru sink with a compact stdout sink.
    # NOTE: this is process-global and affects any other loguru users.
    logger.remove()
    logger.add(sys.stdout, level="INFO", format="{time:HH:mm:ss} | {level} | {message}")

    print("🧪 测试全A股扫描功能")
    print("=" * 50)

    try:
        # Initialise components.
        logger.info("初始化组件...")
        config_loader = ConfigLoader()
        config = config_loader.load_config()
        data_fetcher = ADataFetcher()
        notification_manager = NotificationManager(config.get('notification', {}))
        db_manager = DatabaseManager()

        # Build the strategy from the `strategy.kline_pattern` config section.
        kline_config = config.get('strategy', {}).get('kline_pattern', {})
        strategy = KLinePatternStrategy(
            data_fetcher=data_fetcher,
            notification_manager=notification_manager,
            config=kline_config,
            db_manager=db_manager,
        )

        # Fetch the filtered A-share list and log a few sample rows.
        logger.info("测试获取过滤后的A股列表...")
        filtered_stocks = data_fetcher.get_filtered_a_share_list()
        logger.info(f"过滤后的A股数量: {len(filtered_stocks)}")
        if not filtered_stocks.empty:
            sample_stocks = filtered_stocks.head(5)
            logger.info("A股样例:")
            for _, stock in sample_stocks.iterrows():
                logger.info(f" {stock['full_stock_code']} - {stock['short_name']} ({stock['exchange']})")

        # Scan the whole A-share universe, capped at `max_stocks` symbols.
        logger.info(f"开始测试全A股扫描限制{max_stocks}只...")
        results = strategy.scan_market(
            max_stocks=max_stocks,
            use_all_a_shares=True,
        )

        # Tally signals. The loop below assumes `results` maps stock code ->
        # mapping whose values are sized collections of signals.
        total_signals = 0
        for stock_code, stock_results in results.items():
            stock_signals = sum(len(signals) for signals in stock_results.values())
            total_signals += stock_signals
            logger.info(f"股票 {stock_code}: {stock_signals}个信号")

        logger.info("✅ 全A股扫描测试完成")
        logger.info(f"📊 扫描股票数: {len(results)}")
        logger.info(f"📈 发现信号数: {total_signals}")

        print("\n" + "=" * 50)
        print("🎯 测试结果:")
        print(f" - 可用A股数量: {len(filtered_stocks)}")
        # Report the number actually scanned, not a hard-coded "5".
        print(f" - 扫描股票数量: {len(results)}只 (测试上限{max_stocks}只)")
        print(f" - 发现信号数量: {total_signals}")
        print(" - 功能状态: ✅ 正常工作")
        print("=" * 50)
    except Exception as e:
        # Log, then re-raise. Swallowing the error made this test always
        # pass under pytest and the script always exit with status 0.
        # The re-raised traceback replaces the old traceback.print_exc().
        logger.error(f"❌ 测试失败: {e}")
        raise
if __name__ == "__main__":
    # Run the smoke test when executed as a script. The function returns
    # None, and SystemExit(None) exits with status 0; any exception it
    # raises propagates before SystemExit is constructed.
    raise SystemExit(test_all_a_shares_scan())