trading.ai/migrate_to_mysql.py
2025-09-23 16:12:18 +08:00

387 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
SQLite到MySQL数据迁移脚本
将现有的SQLite数据库迁移到MySQL数据库
"""
import sys
import sqlite3
from contextlib import closing
from datetime import date, datetime, timedelta
from pathlib import Path

import pymysql
import pandas as pd
from loguru import logger
# Prepend this file's directory (the project root) to the import search path
# so the project-local `config` and `src` packages imported below resolve.
current_dir = Path(__file__).parent
sys.path[:0] = [str(current_dir)]
from config.mysql_config import MYSQL_CONFIG
from src.database.mysql_database_manager import MySQLDatabaseManager
class DataMigrator:
    """Copy the trading data from the local SQLite database into MySQL.

    Tables are migrated in dependency order (strategies -> scan sessions ->
    signals -> pullback alerts). MySQL assigns new auto-increment ids, so
    cross-table references are re-resolved instead of copied verbatim:
    strategies by name, sessions by (strategy, date), and signals through an
    SQLite-id -> MySQL-id map recorded while signals are inserted.

    NOTE(review): ``with pymysql.connect(...) as conn`` is used as in the
    original script; this presumably requires PyMySQL >= 1.0, where the
    context manager yields the connection itself and closes it on exit
    without committing (hence the explicit ``commit()`` calls) - verify
    against the installed version.
    """

    def __init__(self, signal_limit=1000):
        """Configure the migration source and target.

        Args:
            signal_limit: Maximum number of most-recent signals copied by
                ``migrate_signals``. The default keeps the historical limit
                of 1000; pass ``None`` to migrate every signal.
        """
        self.sqlite_path = current_dir / "data" / "trading.db"
        self.mysql_config = MYSQL_CONFIG
        self.signal_limit = signal_limit
        # SQLite stock_signals.id -> MySQL stock_signals.id, filled by
        # migrate_signals() and consumed by migrate_pullback_alerts().
        self._signal_id_map = {}

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _clean_value(value):
        """Convert a pandas/numpy cell value into a type pymysql can bind.

        ``None``/NaN/NaT become ``None`` (SQL NULL), ``pd.Timestamp`` becomes
        ``datetime``, and numpy scalars become the matching Python built-in.
        Any other value is returned unchanged.
        """
        if value is None:
            return None
        if pd.api.types.is_scalar(value) and pd.isna(value):
            # pymysql cannot bind NaN/NaT; they mean NULL in the source.
            return None
        if isinstance(value, pd.Timestamp):
            return value.to_pydatetime()
        if not isinstance(value, (str, bytes)) and hasattr(value, "item"):
            # numpy scalar (np.int64, np.float64, np.bool_, ...)
            return value.item()
        return value

    @classmethod
    def _row_value(cls, row, column, default=None):
        """Return ``row[column]`` cleaned for MySQL.

        Falls back to *default* when the column is missing OR its value is
        NULL/NaN (``Series.get`` alone only covers the missing case).
        """
        value = cls._clean_value(row.get(column))
        return default if value is None else value

    @staticmethod
    def _find_session(session_index, strategy_id, signal_date, tolerance_days=1):
        """Return the id of the session nearest to *signal_date*, or ``None``.

        *session_index* maps ``(strategy_id, date)`` to a session id. The
        exact date is tried first, then dates further away up to
        *tolerance_days*; at equal distance the earlier day wins.
        """
        offsets = sorted(
            range(-tolerance_days, tolerance_days + 1), key=lambda d: (abs(d), d)
        )
        for offset in offsets:
            session_id = session_index.get(
                (strategy_id, signal_date + timedelta(days=offset))
            )
            if session_id is not None:
                return session_id
        return None

    @staticmethod
    def _fetch_strategy_mapping(cursor):
        """Return ``{strategy_name: id}`` for strategies already in MySQL."""
        cursor.execute("SELECT id, strategy_name FROM strategies")
        return {name: strategy_id for strategy_id, name in cursor.fetchall()}

    @staticmethod
    def _fetch_session_index(cursor):
        """Return ``{(strategy_id, session_date): session_id}`` for MySQL sessions.

        The session date is ``scan_date``, falling back to ``created_at``
        when ``scan_date`` is NULL; rows with neither (or an unparseable
        value) are skipped. For duplicate keys the lowest id wins.
        """
        cursor.execute(
            "SELECT id, strategy_id, scan_date, created_at FROM scan_sessions ORDER BY id"
        )
        index = {}
        for session_id, strategy_id, scan_date, created_at in cursor.fetchall():
            raw_day = scan_date if scan_date is not None else created_at
            if raw_day is None:
                continue
            try:
                session_day = pd.to_datetime(raw_day).date()
            except (TypeError, ValueError):
                continue
            index.setdefault((strategy_id, session_day), session_id)
        return index

    def _read_sqlite(self, query, params=None):
        """Run *query* against the SQLite source and return a DataFrame.

        The connection is explicitly closed: ``sqlite3.Connection``'s own
        context manager only commits/rolls back and leaves it open.
        """
        with closing(sqlite3.connect(self.sqlite_path)) as conn:
            return pd.read_sql_query(query, conn, params=params)

    def _sqlite_table_exists(self, table_name):
        """Return True if *table_name* exists in the SQLite source."""
        with closing(sqlite3.connect(self.sqlite_path)) as conn:
            row = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
                (table_name,),
            ).fetchone()
        return row is not None

    # ------------------------------------------------------------------
    # Migration steps
    # ------------------------------------------------------------------
    def migrate_all(self):
        """Migrate every table, then log verification counts.

        Raises:
            FileNotFoundError: if the SQLite source file does not exist.
            Exception: any error from a step is logged and re-raised.
        """
        logger.info("🚀 开始SQLite到MySQL数据迁移...")
        try:
            # sqlite3.connect() creates an empty database for a missing path,
            # which would later surface as a confusing "no such table" error.
            if not self.sqlite_path.exists():
                raise FileNotFoundError(f"SQLite数据库不存在: {self.sqlite_path}")
            logger.info("📊 初始化MySQL数据库...")
            mysql_db = MySQLDatabaseManager()
            # Order matters: each step resolves ids created by the previous one.
            self.migrate_strategies(mysql_db)
            self.migrate_scan_sessions(mysql_db)
            self.migrate_signals(mysql_db)
            self.migrate_pullback_alerts(mysql_db)
            self.verify_migration(mysql_db)
            logger.info("🎉 数据迁移完成!")
        except Exception as e:
            logger.error(f"❌ 数据迁移失败: {e}")
            raise

    def migrate_strategies(self, mysql_db):
        """Copy strategies into MySQL.

        Uses ``INSERT IGNORE``, so rows colliding with an existing unique key
        are skipped; only rows actually inserted are counted. Per-row
        failures are logged as warnings and do not abort the step.
        """
        logger.info("📋 迁移策略数据...")
        try:
            strategies_df = self._read_sqlite("SELECT * FROM strategies")
            if strategies_df.empty:
                logger.info("无策略数据需要迁移")
                return
            migrated_count = 0
            with pymysql.connect(**mysql_db.connection_params) as mysql_conn:
                with mysql_conn.cursor() as cursor:
                    for _, strategy in strategies_df.iterrows():
                        try:
                            # execute() returns the affected-row count, which is
                            # 0 when INSERT IGNORE skipped the row.
                            migrated_count += cursor.execute(
                                """
                                INSERT IGNORE INTO strategies (strategy_name, strategy_type, description)
                                VALUES (%s, %s, %s)
                                """,
                                (
                                    self._clean_value(strategy["strategy_name"]),
                                    self._clean_value(strategy["strategy_type"]),
                                    self._row_value(strategy, "description", ""),
                                ),
                            )
                        except Exception as e:
                            logger.warning(f"策略迁移警告: {e}")
                mysql_conn.commit()
            logger.info(f"✅ 迁移了 {migrated_count} 个策略")
        except Exception as e:
            logger.error(f"策略迁移失败: {e}")
            raise

    def migrate_scan_sessions(self, mysql_db):
        """Copy scan sessions, re-linking each to its MySQL strategy by name.

        Sessions whose strategy name is not present in MySQL are skipped
        with a warning; per-row failures are logged and do not abort.
        """
        logger.info("📅 迁移扫描会话数据...")
        try:
            sessions_df = self._read_sqlite("""
                SELECT ss.*, s.strategy_name
                FROM scan_sessions ss
                JOIN strategies s ON ss.strategy_id = s.id
            """)
            if sessions_df.empty:
                logger.info("无扫描会话数据需要迁移")
                return
            migrated_count = 0
            with pymysql.connect(**mysql_db.connection_params) as mysql_conn:
                with mysql_conn.cursor() as cursor:
                    strategy_mapping = self._fetch_strategy_mapping(cursor)
                    for _, session in sessions_df.iterrows():
                        try:
                            mysql_strategy_id = strategy_mapping.get(session["strategy_name"])
                            if mysql_strategy_id is None:
                                logger.warning(f"未找到策略: {session['strategy_name']}")
                                continue
                            cursor.execute(
                                """
                                INSERT INTO scan_sessions (
                                    strategy_id, scan_date, total_scanned, total_signals,
                                    data_source, scan_config, status, created_at
                                ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                                """,
                                (
                                    mysql_strategy_id,
                                    self._row_value(session, "scan_date"),
                                    self._row_value(session, "total_scanned", 0),
                                    self._row_value(session, "total_signals", 0),
                                    self._row_value(session, "data_source"),
                                    self._row_value(session, "scan_config"),
                                    self._row_value(session, "status", "completed"),
                                    self._row_value(session, "created_at", datetime.now()),
                                ),
                            )
                            migrated_count += 1
                        except Exception as e:
                            logger.warning(f"会话迁移警告: {e}")
                mysql_conn.commit()
            logger.info(f"✅ 迁移了 {migrated_count} 个扫描会话")
        except Exception as e:
            logger.error(f"扫描会话迁移失败: {e}")
            raise

    def migrate_signals(self, mysql_db):
        """Copy stock signals (newest first, up to ``self.signal_limit``).

        Each signal is attached to the MySQL session of the same strategy
        whose date is nearest to the signal date (within one day). If none
        exists, a placeholder session tagged '迁移数据' is created and reused
        for later signals of that strategy/date; its counters are set to the
        number of signals attached. The SQLite -> MySQL id of every inserted
        signal is recorded in ``self._signal_id_map``.
        """
        logger.info("📈 迁移信号数据...")
        try:
            limit = self.signal_limit
            query = """
                SELECT ss.*, st.strategy_name
                FROM stock_signals ss
                JOIN strategies st ON ss.strategy_id = st.id
                ORDER BY ss.created_at DESC
            """
            params = None
            if limit is not None:
                query += " LIMIT ?"
                params = (int(limit),)
            signals_df = self._read_sqlite(query, params)
            if signals_df.empty:
                logger.info("无信号数据需要迁移")
                return
            if limit is not None and len(signals_df) >= limit:
                logger.warning(f"⚠️ 仅迁移最新的 {limit} 条信号,更早的信号可能未迁移(可通过 signal_limit 调整)")

            self._signal_id_map = {}
            placeholder_counts = {}  # placeholder session id -> attached signals
            migrated_count = 0
            with pymysql.connect(**mysql_db.connection_params) as mysql_conn:
                with mysql_conn.cursor() as cursor:
                    strategy_mapping = self._fetch_strategy_mapping(cursor)
                    session_index = self._fetch_session_index(cursor)
                    for _, signal in signals_df.iterrows():
                        try:
                            mysql_strategy_id = strategy_mapping.get(signal["strategy_name"])
                            if mysql_strategy_id is None:
                                continue
                            signal_date = pd.to_datetime(signal["signal_date"]).date()
                            mysql_session_id = self._find_session(
                                session_index, mysql_strategy_id, signal_date
                            )
                            if mysql_session_id is None:
                                cursor.execute(
                                    """
                                    INSERT INTO scan_sessions (strategy_id, scan_date, total_scanned, total_signals, data_source)
                                    VALUES (%s, %s, %s, %s, %s)
                                    """,
                                    (mysql_strategy_id, signal_date, 1, 1, "迁移数据"),
                                )
                                mysql_session_id = cursor.lastrowid
                                # Cache it so later signals of this strategy/date
                                # reuse it instead of each creating a new session.
                                session_index[(mysql_strategy_id, signal_date)] = mysql_session_id
                                placeholder_counts[mysql_session_id] = 0

                            v = self._row_value
                            cursor.execute(
                                """
                                INSERT INTO stock_signals (
                                    session_id, strategy_id, stock_code, stock_name, timeframe,
                                    signal_date, signal_type, breakout_price, yin_high, breakout_amount,
                                    breakout_pct, ema20_price, yang1_entity_ratio, yang2_entity_ratio,
                                    final_yang_entity_ratio, turnover_ratio, above_ema20,
                                    new_high_confirmed, new_high_price, new_high_date, confirmation_date,
                                    confirmation_days, pullback_distance,
                                    k1_data, k2_data, k3_data, k4_data, created_at
                                ) VALUES (
                                    %s, %s, %s, %s, %s, %s, %s,
                                    %s, %s, %s, %s, %s, %s, %s,
                                    %s, %s, %s, %s, %s, %s, %s,
                                    %s, %s, %s, %s, %s, %s, %s
                                )
                                """,
                                (
                                    mysql_session_id,
                                    mysql_strategy_id,
                                    self._clean_value(signal["stock_code"]),
                                    self._clean_value(signal["stock_name"]),
                                    self._clean_value(signal["timeframe"]),
                                    self._clean_value(signal["signal_date"]),
                                    v(signal, "signal_type", "两阳+阴+阳突破"),
                                    v(signal, "breakout_price"),
                                    v(signal, "yin_high"),
                                    v(signal, "breakout_amount"),
                                    v(signal, "breakout_pct"),
                                    v(signal, "ema20_price"),
                                    v(signal, "yang1_entity_ratio"),
                                    v(signal, "yang2_entity_ratio"),
                                    v(signal, "final_yang_entity_ratio"),
                                    v(signal, "turnover_ratio"),
                                    v(signal, "above_ema20"),
                                    v(signal, "new_high_confirmed", False),
                                    v(signal, "new_high_price"),
                                    v(signal, "new_high_date"),
                                    v(signal, "confirmation_date"),
                                    v(signal, "confirmation_days"),
                                    v(signal, "pullback_distance"),
                                    v(signal, "k1_data"),
                                    v(signal, "k2_data"),
                                    v(signal, "k3_data"),
                                    v(signal, "k4_data"),
                                    v(signal, "created_at", datetime.now()),
                                ),
                            )
                            migrated_count += 1
                            if mysql_session_id in placeholder_counts:
                                placeholder_counts[mysql_session_id] += 1
                            sqlite_signal_id = v(signal, "id")
                            if sqlite_signal_id is not None:
                                self._signal_id_map[sqlite_signal_id] = cursor.lastrowid
                        except Exception as e:
                            logger.warning(f"信号迁移警告: {signal.get('stock_code')} - {e}")

                    # Placeholders were created with counters 1/1; make them
                    # reflect how many signals were actually attached.
                    for session_id, signal_count in placeholder_counts.items():
                        if signal_count != 1:
                            cursor.execute(
                                "UPDATE scan_sessions SET total_scanned = %s, total_signals = %s WHERE id = %s",
                                (signal_count, signal_count, session_id),
                            )
                mysql_conn.commit()
            logger.info(f"✅ 迁移了 {migrated_count} 条信号")
        except Exception as e:
            logger.error(f"信号迁移失败: {e}")
            raise

    def migrate_pullback_alerts(self, mysql_db):
        """Copy pullback alerts, if the SQLite source has that table.

        ``signal_id`` is re-pointed through ``self._signal_id_map``
        (presumably it references ``stock_signals.id``, whose ids change in
        MySQL). Alerts whose signal was not migrated in this run get a NULL
        ``signal_id`` and are counted in a warning.
        """
        logger.info("⚠️ 迁移回踩提醒数据...")
        try:
            if not self._sqlite_table_exists("pullback_alerts"):
                logger.info("SQLite中无回踩提醒表跳过迁移")
                return
            alerts_df = self._read_sqlite("SELECT * FROM pullback_alerts")
            if alerts_df.empty:
                logger.info("无回踩提醒数据需要迁移")
                return
            migrated_count = 0
            unmapped_count = 0
            with pymysql.connect(**mysql_db.connection_params) as mysql_conn:
                with mysql_conn.cursor() as cursor:
                    for _, alert in alerts_df.iterrows():
                        try:
                            v = self._row_value
                            sqlite_signal_id = v(alert, "signal_id")
                            mysql_signal_id = self._signal_id_map.get(sqlite_signal_id)
                            cursor.execute(
                                """
                                INSERT INTO pullback_alerts (
                                    signal_id, stock_code, stock_name, timeframe,
                                    original_signal_date, original_breakout_price, yin_high,
                                    pullback_date, current_price, current_low,
                                    pullback_pct, distance_to_yin_high, days_since_signal,
                                    alert_sent, alert_sent_time, created_at
                                ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                                """,
                                (
                                    mysql_signal_id,
                                    self._clean_value(alert["stock_code"]),
                                    self._clean_value(alert["stock_name"]),
                                    self._clean_value(alert["timeframe"]),
                                    v(alert, "original_signal_date"),
                                    v(alert, "original_breakout_price"),
                                    v(alert, "yin_high"),
                                    self._clean_value(alert["pullback_date"]),
                                    v(alert, "current_price"),
                                    v(alert, "current_low"),
                                    v(alert, "pullback_pct"),
                                    v(alert, "distance_to_yin_high"),
                                    v(alert, "days_since_signal"),
                                    v(alert, "alert_sent", True),
                                    v(alert, "alert_sent_time"),
                                    v(alert, "created_at", datetime.now()),
                                ),
                            )
                            migrated_count += 1
                            if sqlite_signal_id is not None and mysql_signal_id is None:
                                unmapped_count += 1
                        except Exception as e:
                            logger.warning(f"回踩提醒迁移警告: {alert.get('stock_code')} - {e}")
                mysql_conn.commit()
            if unmapped_count:
                logger.warning(f"⚠️ {unmapped_count} 条回踩提醒的原信号未在本次迁移中找到,signal_id 已置空")
            logger.info(f"✅ 迁移了 {migrated_count} 条回踩提醒")
        except Exception as e:
            logger.error(f"回踩提醒迁移失败: {e}")
            raise

    def verify_migration(self, mysql_db):
        """Log row counts of the migrated tables and the two reporting views."""
        logger.info("🔍 验证迁移结果...")
        try:
            with pymysql.connect(**mysql_db.connection_params) as mysql_conn:
                with mysql_conn.cursor() as cursor:
                    # Table names are fixed literals, so interpolating them is safe.
                    for table in ("strategies", "scan_sessions", "stock_signals", "pullback_alerts"):
                        cursor.execute(f"SELECT COUNT(*) FROM {table}")
                        count = cursor.fetchone()[0]
                        logger.info(f"📊 {table}: {count} 条记录")
                    cursor.execute("SELECT COUNT(*) FROM latest_signals_view WHERE new_high_confirmed = 1")
                    confirmed_signals = cursor.fetchone()[0]
                    logger.info(f"🎯 确认信号: {confirmed_signals}")
                    cursor.execute("SELECT COUNT(*) FROM strategy_stats_view")
                    stats_count = cursor.fetchone()[0]
                    logger.info(f"📈 策略统计: {stats_count}")
            logger.info("✅ 数据迁移验证完成")
        except Exception as e:
            logger.error(f"验证迁移结果失败: {e}")
            raise
def main():
    """Script entry point: run the full migration and print a summary.

    Exits the process with status 1 if the migration fails.
    """
    # pymysql is imported at module level, so a missing installation already
    # fails at import time, before main() is reached; no runtime check here.
    try:
        # migrate_all() logs its own start/finish messages.
        migrator = DataMigrator()
        migrator.migrate_all()
        print("\n" + "="*70)
        print("🎉 MySQL数据库迁移完成!")
        print("="*70)
        print("\n✅ 迁移内容:")
        print(" - 策略配置")
        print(" - 扫描会话")
        print(" - 股票信号(包含创新高回踩确认字段)")
        print(" - 回踩提醒")
        print(" - 数据库视图")
        print("\n🌐 MySQL配置:")
        print(f" - 主机: {MYSQL_CONFIG.host}")
        print(f" - 端口: {MYSQL_CONFIG.port}")
        print(f" - 数据库: {MYSQL_CONFIG.database}")
        print("\n📝 下一步:")
        print(" 1. 更新系统配置使用MySQL数据库")
        print(" 2. 测试Web界面和API功能")
        print(" 3. 验证所有功能正常工作")
    except Exception as e:
        # logger.exception also records the traceback, which the per-step
        # error logs do not.
        logger.exception(f"❌ 迁移失败: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()