trading.ai/test_database_integration.py
2025-09-18 20:45:01 +08:00

256 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
测试数据库集成和策略存储功能
"""
import sys
from pathlib import Path
from datetime import datetime, date
# 添加src目录到路径
current_dir = Path(__file__).parent
src_dir = current_dir / "src"
sys.path.insert(0, str(src_dir))
from loguru import logger
from src.database.database_manager import DatabaseManager
from src.utils.config_loader import ConfigLoader
from src.data.data_fetcher import ADataFetcher
from src.utils.notification import NotificationManager
from src.strategy.kline_pattern_strategy import KLinePatternStrategy
def test_database_operations():
"""测试数据库基本操作"""
logger.info("🗄️ 测试数据库基本操作...")
try:
# 初始化数据库管理器
db_manager = DatabaseManager()
# 测试策略统计
strategy_stats = db_manager.get_strategy_stats()
logger.info(f"📊 策略统计记录数: {len(strategy_stats)}")
# 测试最新信号
latest_signals = db_manager.get_latest_signals(limit=10)
logger.info(f"📈 最新信号记录数: {len(latest_signals)}")
# 测试日期范围查询
start_date = date.today()
signals_by_date = db_manager.get_signals_by_date_range(start_date)
logger.info(f"🗓️ 今日信号记录数: {len(signals_by_date)}")
# 测试回踩提醒
pullback_alerts = db_manager.get_pullback_alerts(days=7)
logger.info(f"⚠️ 最近7天回踩提醒: {len(pullback_alerts)}")
logger.info("✅ 数据库基本操作测试完成")
return True
except Exception as e:
logger.error(f"❌ 数据库操作测试失败: {e}")
return False
def test_strategy_integration():
"""测试策略与数据库集成"""
logger.info("🔄 测试策略与数据库集成...")
try:
# 初始化组件
config_loader = ConfigLoader()
config = config_loader.load_config()
data_fetcher = ADataFetcher()
notification_manager = NotificationManager(config.get('notification', {}))
db_manager = DatabaseManager()
# 初始化策略(自动创建数据库记录)
kline_config = config.get('strategy', {}).get('kline_pattern', {})
strategy = KLinePatternStrategy(
data_fetcher=data_fetcher,
notification_manager=notification_manager,
config=kline_config,
db_manager=db_manager
)
logger.info(f"📋 策略ID: {strategy.strategy_id}")
logger.info(f"📝 策略名称: {strategy.strategy_name}")
# 测试分析单只股票(会自动保存到数据库)
test_stock = "000001.SZ"
logger.info(f"🔍 测试分析股票: {test_stock}")
stock_results = strategy.analyze_stock(test_stock, days=30)
total_signals = sum(len(signals) for signals in stock_results.values())
logger.info(f"📊 分析结果: {total_signals} 个信号")
# 验证数据库中的记录
latest_signals = db_manager.get_latest_signals(strategy_name=strategy.strategy_name, limit=10)
logger.info(f"💾 数据库中最新信号数: {len(latest_signals)}")
logger.info("✅ 策略与数据库集成测试完成")
return True
except Exception as e:
logger.error(f"❌ 策略集成测试失败: {e}")
import traceback
traceback.print_exc()
return False
def test_scan_market_with_database():
"""测试市场扫描与数据库存储"""
logger.info("🌍 测试市场扫描与数据库存储...")
try:
# 初始化组件
config_loader = ConfigLoader()
config = config_loader.load_config()
data_fetcher = ADataFetcher()
notification_manager = NotificationManager(config.get('notification', {}))
db_manager = DatabaseManager()
# 初始化策略
kline_config = config.get('strategy', {}).get('kline_pattern', {})
strategy = KLinePatternStrategy(
data_fetcher=data_fetcher,
notification_manager=notification_manager,
config=kline_config,
db_manager=db_manager
)
# 小规模市场扫描测试限制5只股票
logger.info("🔍 开始小规模市场扫描测试...")
test_stocks = ["000001.SZ", "000002.SZ", "600000.SH", "600036.SH", "000858.SZ"]
results = strategy.scan_market(
stock_list=test_stocks,
max_stocks=5,
use_hot_stocks=False
)
total_signals = sum(
sum(len(signals) for signals in stock_results.values())
for stock_results in results.values()
)
logger.info(f"📊 扫描完成: 发现 {total_signals} 个信号")
# 验证数据库存储
recent_signals = db_manager.get_latest_signals(
strategy_name=strategy.strategy_name,
limit=50
)
logger.info(f"💾 数据库中存储的信号数: {len(recent_signals)}")
# 显示最新的几个信号
if not recent_signals.empty:
logger.info("📋 最新信号示例:")
for i, signal in recent_signals.head(3).iterrows():
logger.info(f" {signal['stock_code']}({signal['stock_name']}) - {signal['breakout_price']:.2f}")
logger.info("✅ 市场扫描与数据库存储测试完成")
return True
except Exception as e:
logger.error(f"❌ 市场扫描测试失败: {e}")
import traceback
traceback.print_exc()
return False
def test_database_queries():
"""测试数据库查询功能"""
logger.info("🔍 测试数据库查询功能...")
try:
db_manager = DatabaseManager()
# 测试策略统计
strategy_stats = db_manager.get_strategy_stats()
if not strategy_stats.empty:
logger.info("📊 策略统计:")
for _, stat in strategy_stats.iterrows():
logger.info(f" {stat['strategy_name']}: {stat['total_signals']}个信号, {stat['unique_stocks']}只股票")
# 测试按日期查询
today = date.today()
today_signals = db_manager.get_signals_by_date_range(today, today)
logger.info(f"📅 今日信号数: {len(today_signals)}")
# 测试获取策略ID
strategy_id = db_manager.get_strategy_id("K线形态策略")
logger.info(f"🆔 K线形态策略ID: {strategy_id}")
logger.info("✅ 数据库查询功能测试完成")
return True
except Exception as e:
logger.error(f"❌ 数据库查询测试失败: {e}")
return False
def main():
"""主测试函数"""
logger.remove()
logger.add(sys.stdout, level="INFO", format="{time:HH:mm:ss} | {level} | {message}")
print("=" * 70)
print("🧪 A股量化交易系统 - 数据库集成测试")
print("=" * 70)
test_results = []
# 运行测试
tests = [
("数据库基本操作", test_database_operations),
("策略与数据库集成", test_strategy_integration),
("数据库查询功能", test_database_queries),
("市场扫描与存储", test_scan_market_with_database),
]
for test_name, test_func in tests:
logger.info(f"\n🚀 开始测试: {test_name}")
try:
result = test_func()
test_results.append((test_name, result))
if result:
logger.info(f"{test_name} 测试通过")
else:
logger.error(f"{test_name} 测试失败")
except Exception as e:
logger.error(f"{test_name} 测试异常: {e}")
test_results.append((test_name, False))
# 输出测试结果
print("\n" + "=" * 70)
print("📊 测试结果汇总:")
print("=" * 70)
passed = 0
total = len(test_results)
for test_name, result in test_results:
status = "✅ 通过" if result else "❌ 失败"
print(f" {test_name}: {status}")
if result:
passed += 1
print(f"\n🎯 总计: {passed}/{total} 个测试通过")
if passed == total:
print("🎉 所有测试都通过了!数据库集成功能正常工作。")
print("🌐 现在可以启动Web界面查看数据:")
print(" cd web && python app.py")
else:
print("⚠️ 部分测试失败,请检查错误信息并修复问题。")
print("=" * 70)
if __name__ == "__main__":
main()