trading.ai/generate_test_data.py
2025-09-18 20:45:01 +08:00

104 lines
3.4 KiB
Python

#!/usr/bin/env python3
"""
生成测试数据用于Web界面展示
"""
import sys
from pathlib import Path
from datetime import datetime, date, timedelta
import random
# Add the project root directory (the folder containing this script) to the
# front of sys.path. This must run before the `src.*` imports that follow.
current_dir = Path(__file__).parent
sys.path.insert(0, str(current_dir))
from loguru import logger
from src.database.database_manager import DatabaseManager
from src.utils.config_loader import ConfigLoader
from src.data.data_fetcher import ADataFetcher
from src.utils.notification import NotificationManager
from src.strategy.kline_pattern_strategy import KLinePatternStrategy
def _log_latest_signals(db_manager: "DatabaseManager") -> None:
    """Read back the most recent signals and log a short sample of them.

    Args:
        db_manager: Database manager whose ``get_latest_signals`` is queried
            with ``limit=10``. The result is used through ``len()``,
            ``.empty``, ``.head()`` and ``.iterrows()`` -- presumably a
            pandas DataFrame; confirm against ``DatabaseManager``.
    """
    latest_signals = db_manager.get_latest_signals(limit=10)
    logger.info(f"📊 数据库中共有 {len(latest_signals)} 条最新信号")

    if latest_signals.empty:
        return

    logger.info("📋 信号示例:")
    # Only the first three rows are shown as examples.
    for _, signal in latest_signals.head(3).iterrows():
        logger.info(f" {signal['stock_code']}({signal['stock_name']}) - {signal['breakout_price']:.2f}")


def generate_test_data() -> bool:
    """Generate test data for display in the web interface.

    Reconfigures the loguru logger to write INFO-level output to stdout,
    builds the data fetcher, notification manager, database manager and
    K-line pattern strategy, then analyses a fixed list of test stocks.
    Each stock gets its own scan session. Afterwards the latest signals are
    read back from the database and a few of them are logged.

    A failure while analysing a single stock is logged and the loop moves on
    to the next stock (best effort). Any other failure is logged together
    with its traceback and reported through the return value instead of
    being silently swallowed.

    Returns:
        ``True`` if the run reached the end, ``False`` if an unexpected
        error aborted it.
    """
    logger.remove()
    logger.add(sys.stdout, level="INFO", format="{time:HH:mm:ss} | {level} | {message}")

    print("🧪 生成测试数据")
    print("=" * 40)

    try:
        # Initialise components.
        logger.info("初始化组件...")
        config_loader = ConfigLoader()
        config = config_loader.load_config()
        data_fetcher = ADataFetcher()
        notification_manager = NotificationManager(config.get('notification', {}))
        db_manager = DatabaseManager()

        # Initialise the strategy.
        kline_config = config.get('strategy', {}).get('kline_pattern', {})
        strategy = KLinePatternStrategy(
            data_fetcher=data_fetcher,
            notification_manager=notification_manager,
            config=kline_config,
            db_manager=db_manager,
        )

        # Stocks to analyse.
        test_stocks = ["000001.SZ", "000002.SZ", "600000.SH"]
        logger.info(f"开始分析 {len(test_stocks)} 只股票...")

        total_signals = 0
        for i, stock_code in enumerate(test_stocks, 1):
            logger.info(f"[{i}/{len(test_stocks)}] 分析 {stock_code}...")
            try:
                # Create a scan session for this stock.
                session_id = db_manager.create_scan_session(
                    strategy_id=strategy.strategy_id,
                    data_source="测试数据生成",
                )

                # Analyse the stock.
                stock_results = strategy.analyze_stock(stock_code, session_id=session_id, days=60)

                # Count signals: sum the length of every value in the result mapping.
                stock_signals = sum(len(signals) for signals in stock_results.values())
                total_signals += stock_signals

                # Update the session statistics.
                # NOTE(review): the positional arguments look like
                # (session, stocks scanned, signals found) -- confirm against
                # DatabaseManager.update_scan_session_stats.
                db_manager.update_scan_session_stats(session_id, 1, stock_signals)
                logger.info(f" 发现 {stock_signals} 个信号")
            except Exception as e:
                # Best effort: one failing stock must not abort the whole run.
                logger.error(f" 分析失败: {e}")

        logger.info(f"✅ 数据生成完成!总共生成 {total_signals} 个信号")

        # Verify the data by reading it back.
        _log_latest_signals(db_manager)

        print("\n" + "=" * 40)
        print("🌐 现在可以访问Web界面查看数据:")
        print(" http://localhost:8080")
        print("=" * 40)
    except Exception as e:
        logger.error(f"❌ 生成测试数据失败: {e}")
        import traceback
        traceback.print_exc()
        # Tell the caller the run failed instead of returning as if it succeeded.
        return False

    return True


if __name__ == "__main__":
    # Exit non-zero on failure so shells and scripts can detect it.
    sys.exit(0 if generate_test_data() else 1)