256 lines
8.4 KiB
Python
256 lines
8.4 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
测试数据库集成和策略存储功能
|
||
"""
|
||
|
||
import sys
|
||
from pathlib import Path
|
||
from datetime import datetime, date
|
||
|
||
# 添加src目录到路径
|
||
current_dir = Path(__file__).parent
|
||
src_dir = current_dir / "src"
|
||
sys.path.insert(0, str(src_dir))
|
||
|
||
from loguru import logger
|
||
from src.database.database_manager import DatabaseManager
|
||
from src.utils.config_loader import ConfigLoader
|
||
from src.data.data_fetcher import ADataFetcher
|
||
from src.utils.notification import NotificationManager
|
||
from src.strategy.kline_pattern_strategy import KLinePatternStrategy
|
||
|
||
|
||
def test_database_operations():
|
||
"""测试数据库基本操作"""
|
||
logger.info("🗄️ 测试数据库基本操作...")
|
||
|
||
try:
|
||
# 初始化数据库管理器
|
||
db_manager = DatabaseManager()
|
||
|
||
# 测试策略统计
|
||
strategy_stats = db_manager.get_strategy_stats()
|
||
logger.info(f"📊 策略统计记录数: {len(strategy_stats)}")
|
||
|
||
# 测试最新信号
|
||
latest_signals = db_manager.get_latest_signals(limit=10)
|
||
logger.info(f"📈 最新信号记录数: {len(latest_signals)}")
|
||
|
||
# 测试日期范围查询
|
||
start_date = date.today()
|
||
signals_by_date = db_manager.get_signals_by_date_range(start_date)
|
||
logger.info(f"🗓️ 今日信号记录数: {len(signals_by_date)}")
|
||
|
||
# 测试回踩提醒
|
||
pullback_alerts = db_manager.get_pullback_alerts(days=7)
|
||
logger.info(f"⚠️ 最近7天回踩提醒: {len(pullback_alerts)}")
|
||
|
||
logger.info("✅ 数据库基本操作测试完成")
|
||
return True
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ 数据库操作测试失败: {e}")
|
||
return False
|
||
|
||
|
||
def test_strategy_integration():
|
||
"""测试策略与数据库集成"""
|
||
logger.info("🔄 测试策略与数据库集成...")
|
||
|
||
try:
|
||
# 初始化组件
|
||
config_loader = ConfigLoader()
|
||
config = config_loader.load_config()
|
||
|
||
data_fetcher = ADataFetcher()
|
||
notification_manager = NotificationManager(config.get('notification', {}))
|
||
db_manager = DatabaseManager()
|
||
|
||
# 初始化策略(自动创建数据库记录)
|
||
kline_config = config.get('strategy', {}).get('kline_pattern', {})
|
||
strategy = KLinePatternStrategy(
|
||
data_fetcher=data_fetcher,
|
||
notification_manager=notification_manager,
|
||
config=kline_config,
|
||
db_manager=db_manager
|
||
)
|
||
|
||
logger.info(f"📋 策略ID: {strategy.strategy_id}")
|
||
logger.info(f"📝 策略名称: {strategy.strategy_name}")
|
||
|
||
# 测试分析单只股票(会自动保存到数据库)
|
||
test_stock = "000001.SZ"
|
||
logger.info(f"🔍 测试分析股票: {test_stock}")
|
||
|
||
stock_results = strategy.analyze_stock(test_stock, days=30)
|
||
total_signals = sum(len(signals) for signals in stock_results.values())
|
||
|
||
logger.info(f"📊 分析结果: {total_signals} 个信号")
|
||
|
||
# 验证数据库中的记录
|
||
latest_signals = db_manager.get_latest_signals(strategy_name=strategy.strategy_name, limit=10)
|
||
logger.info(f"💾 数据库中最新信号数: {len(latest_signals)}")
|
||
|
||
logger.info("✅ 策略与数据库集成测试完成")
|
||
return True
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ 策略集成测试失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return False
|
||
|
||
|
||
def test_scan_market_with_database():
|
||
"""测试市场扫描与数据库存储"""
|
||
logger.info("🌍 测试市场扫描与数据库存储...")
|
||
|
||
try:
|
||
# 初始化组件
|
||
config_loader = ConfigLoader()
|
||
config = config_loader.load_config()
|
||
|
||
data_fetcher = ADataFetcher()
|
||
notification_manager = NotificationManager(config.get('notification', {}))
|
||
db_manager = DatabaseManager()
|
||
|
||
# 初始化策略
|
||
kline_config = config.get('strategy', {}).get('kline_pattern', {})
|
||
strategy = KLinePatternStrategy(
|
||
data_fetcher=data_fetcher,
|
||
notification_manager=notification_manager,
|
||
config=kline_config,
|
||
db_manager=db_manager
|
||
)
|
||
|
||
# 小规模市场扫描测试(限制5只股票)
|
||
logger.info("🔍 开始小规模市场扫描测试...")
|
||
test_stocks = ["000001.SZ", "000002.SZ", "600000.SH", "600036.SH", "000858.SZ"]
|
||
|
||
results = strategy.scan_market(
|
||
stock_list=test_stocks,
|
||
max_stocks=5,
|
||
use_hot_stocks=False
|
||
)
|
||
|
||
total_signals = sum(
|
||
sum(len(signals) for signals in stock_results.values())
|
||
for stock_results in results.values()
|
||
)
|
||
|
||
logger.info(f"📊 扫描完成: 发现 {total_signals} 个信号")
|
||
|
||
# 验证数据库存储
|
||
recent_signals = db_manager.get_latest_signals(
|
||
strategy_name=strategy.strategy_name,
|
||
limit=50
|
||
)
|
||
logger.info(f"💾 数据库中存储的信号数: {len(recent_signals)}")
|
||
|
||
# 显示最新的几个信号
|
||
if not recent_signals.empty:
|
||
logger.info("📋 最新信号示例:")
|
||
for i, signal in recent_signals.head(3).iterrows():
|
||
logger.info(f" {signal['stock_code']}({signal['stock_name']}) - {signal['breakout_price']:.2f}元")
|
||
|
||
logger.info("✅ 市场扫描与数据库存储测试完成")
|
||
return True
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ 市场扫描测试失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return False
|
||
|
||
|
||
def test_database_queries():
|
||
"""测试数据库查询功能"""
|
||
logger.info("🔍 测试数据库查询功能...")
|
||
|
||
try:
|
||
db_manager = DatabaseManager()
|
||
|
||
# 测试策略统计
|
||
strategy_stats = db_manager.get_strategy_stats()
|
||
if not strategy_stats.empty:
|
||
logger.info("📊 策略统计:")
|
||
for _, stat in strategy_stats.iterrows():
|
||
logger.info(f" {stat['strategy_name']}: {stat['total_signals']}个信号, {stat['unique_stocks']}只股票")
|
||
|
||
# 测试按日期查询
|
||
today = date.today()
|
||
today_signals = db_manager.get_signals_by_date_range(today, today)
|
||
logger.info(f"📅 今日信号数: {len(today_signals)}")
|
||
|
||
# 测试获取策略ID
|
||
strategy_id = db_manager.get_strategy_id("K线形态策略")
|
||
logger.info(f"🆔 K线形态策略ID: {strategy_id}")
|
||
|
||
logger.info("✅ 数据库查询功能测试完成")
|
||
return True
|
||
|
||
except Exception as e:
|
||
logger.error(f"❌ 数据库查询测试失败: {e}")
|
||
return False
|
||
|
||
|
||
def main():
|
||
"""主测试函数"""
|
||
logger.remove()
|
||
logger.add(sys.stdout, level="INFO", format="{time:HH:mm:ss} | {level} | {message}")
|
||
|
||
print("=" * 70)
|
||
print("🧪 A股量化交易系统 - 数据库集成测试")
|
||
print("=" * 70)
|
||
|
||
test_results = []
|
||
|
||
# 运行测试
|
||
tests = [
|
||
("数据库基本操作", test_database_operations),
|
||
("策略与数据库集成", test_strategy_integration),
|
||
("数据库查询功能", test_database_queries),
|
||
("市场扫描与存储", test_scan_market_with_database),
|
||
]
|
||
|
||
for test_name, test_func in tests:
|
||
logger.info(f"\n🚀 开始测试: {test_name}")
|
||
try:
|
||
result = test_func()
|
||
test_results.append((test_name, result))
|
||
if result:
|
||
logger.info(f"✅ {test_name} 测试通过")
|
||
else:
|
||
logger.error(f"❌ {test_name} 测试失败")
|
||
except Exception as e:
|
||
logger.error(f"❌ {test_name} 测试异常: {e}")
|
||
test_results.append((test_name, False))
|
||
|
||
# 输出测试结果
|
||
print("\n" + "=" * 70)
|
||
print("📊 测试结果汇总:")
|
||
print("=" * 70)
|
||
|
||
passed = 0
|
||
total = len(test_results)
|
||
|
||
for test_name, result in test_results:
|
||
status = "✅ 通过" if result else "❌ 失败"
|
||
print(f" {test_name}: {status}")
|
||
if result:
|
||
passed += 1
|
||
|
||
print(f"\n🎯 总计: {passed}/{total} 个测试通过")
|
||
|
||
if passed == total:
|
||
print("🎉 所有测试都通过了!数据库集成功能正常工作。")
|
||
print("🌐 现在可以启动Web界面查看数据:")
|
||
print(" cd web && python app.py")
|
||
else:
|
||
print("⚠️ 部分测试失败,请检查错误信息并修复问题。")
|
||
|
||
print("=" * 70)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main() |