#!/usr/bin/env python3
"""
SQLite -> MySQL data migration script.

Copies the existing SQLite trading database (``data/trading.db`` next to this
script by default) into the MySQL database described by
``config.mysql_config.MYSQL_CONFIG``.
"""

import sys
import sqlite3
from contextlib import closing
from datetime import datetime, date, timedelta
from pathlib import Path

import numpy as np
import pandas as pd
import pymysql
from loguru import logger

# Add the project root to sys.path so the local `config` / `src` packages resolve.
current_dir = Path(__file__).parent
sys.path.insert(0, str(current_dir))

from config.mysql_config import MYSQL_CONFIG  # noqa: E402
from src.database.mysql_database_manager import MySQLDatabaseManager  # noqa: E402

# Historical cap on how many of the most recent signals migrate_signals()
# copies. Kept as the default for backward compatibility; construct
# DataMigrator(signal_limit=None) to migrate every signal.
DEFAULT_SIGNAL_LIMIT = 1000

# signal_type written when the SQLite row has none.
DEFAULT_SIGNAL_TYPE = '两阳+阴+阳突破'


def _to_db_value(value):
    """Convert one pandas/numpy cell into a value pymysql can send.

    ``pd.read_sql_query`` turns SQL NULL in numeric columns into NaN, which
    pymysql refuses to encode, and numpy scalars (e.g. ``np.bool_``) are not
    among pymysql's native encoder types. NULL-like values (None/NaN/NaT)
    become ``None``; numpy scalars become the equivalent Python builtin;
    anything else is returned unchanged.
    """
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return None
    if isinstance(value, np.generic):
        return value.item()
    return value


def _row_value(row, column, default=None):
    """Return ``row[column]`` converted with ``_to_db_value``.

    *default* is used when the column is missing from *row* or its value is
    NULL-like.
    """
    value = _to_db_value(row.get(column))
    return default if value is None else value


class DataMigrator:
    """Copies strategies, scan sessions, signals and pullback alerts from
    SQLite into MySQL, then logs row counts for a quick sanity check."""

    def __init__(self, sqlite_path=None, signal_limit=DEFAULT_SIGNAL_LIMIT):
        """
        Args:
            sqlite_path: source SQLite file; defaults to ``data/trading.db``
                next to this script.
            signal_limit: maximum number of most recent signals to migrate,
                or ``None`` to migrate all of them.
        """
        self.sqlite_path = (
            Path(sqlite_path) if sqlite_path is not None
            else current_dir / "data" / "trading.db"
        )
        self.mysql_config = MYSQL_CONFIG
        self.signal_limit = signal_limit

    # ------------------------------------------------------------------ helpers

    def _sqlite_connection(self):
        """Open the source SQLite database as a context manager that closes it.

        A bare ``sqlite3.Connection`` used in ``with`` only commits/rolls back
        on exit and leaves the connection open, hence ``contextlib.closing``.
        """
        return closing(sqlite3.connect(self.sqlite_path))

    @staticmethod
    def _mysql_connection(mysql_db):
        """Open a MySQL connection from the manager's ``connection_params``."""
        return pymysql.connect(**mysql_db.connection_params)

    @staticmethod
    def _strategy_ids(cursor):
        """Return ``{strategy_name: mysql_strategy_id}`` for every MySQL strategy."""
        cursor.execute("SELECT id, strategy_name FROM strategies")
        return {name: strategy_id for strategy_id, name in cursor.fetchall()}

    @staticmethod
    def _find_session(sessions_by_day, strategy_id, signal_date):
        """Return the session of *strategy_id* nearest to *signal_date*.

        Tries the same day first, then one day before, then one day after;
        returns ``None`` when no session falls within that window.
        """
        for offset in (0, -1, 1):
            session_id = sessions_by_day.get(
                (strategy_id, signal_date + timedelta(days=offset)))
            if session_id is not None:
                return session_id
        return None

    @staticmethod
    def _create_placeholder_session(cursor, strategy_id, signal_date):
        """Insert a minimal scan_sessions row for an orphan signal; return its id."""
        cursor.execute("""
            INSERT INTO scan_sessions
            (strategy_id, scan_date, total_scanned, total_signals, data_source)
            VALUES (%s, %s, %s, %s, %s)
        """, (strategy_id, signal_date, 1, 1, '迁移数据'))
        return cursor.lastrowid

    @staticmethod
    def _insert_signal(cursor, signal, session_id, strategy_id):
        """Insert one SQLite signal row into MySQL ``stock_signals``."""
        cursor.execute("""
            INSERT INTO stock_signals (
                session_id, strategy_id, stock_code, stock_name, timeframe,
                signal_date, signal_type, breakout_price, yin_high,
                breakout_amount, breakout_pct, ema20_price,
                yang1_entity_ratio, yang2_entity_ratio, final_yang_entity_ratio,
                turnover_ratio, above_ema20, new_high_confirmed, new_high_price,
                new_high_date, confirmation_date, confirmation_days,
                pullback_distance, k1_data, k2_data, k3_data, k4_data, created_at
            ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
                      %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
        """, (
            session_id,
            strategy_id,
            _to_db_value(signal['stock_code']),
            _to_db_value(signal['stock_name']),
            _to_db_value(signal['timeframe']),
            _to_db_value(signal['signal_date']),
            _row_value(signal, 'signal_type', DEFAULT_SIGNAL_TYPE),
            _row_value(signal, 'breakout_price'),
            _row_value(signal, 'yin_high'),
            _row_value(signal, 'breakout_amount'),
            _row_value(signal, 'breakout_pct'),
            _row_value(signal, 'ema20_price'),
            _row_value(signal, 'yang1_entity_ratio'),
            _row_value(signal, 'yang2_entity_ratio'),
            _row_value(signal, 'final_yang_entity_ratio'),
            _row_value(signal, 'turnover_ratio'),
            _row_value(signal, 'above_ema20'),
            _row_value(signal, 'new_high_confirmed', False),
            _row_value(signal, 'new_high_price'),
            _row_value(signal, 'new_high_date'),
            _row_value(signal, 'confirmation_date'),
            _row_value(signal, 'confirmation_days'),
            _row_value(signal, 'pullback_distance'),
            _row_value(signal, 'k1_data'),
            _row_value(signal, 'k2_data'),
            _row_value(signal, 'k3_data'),
            _row_value(signal, 'k4_data'),
            _row_value(signal, 'created_at', datetime.now()),
        ))

    # --------------------------------------------------------------- migration

    def migrate_all(self):
        """Migrate every table in dependency order, then verify the result.

        Raises:
            FileNotFoundError: the SQLite source file does not exist
                (``sqlite3.connect`` would otherwise silently create an empty
                database and fail later with "no such table").
            Exception: any failure from a migration step is logged and re-raised.
        """
        logger.info("🚀 开始SQLite到MySQL数据迁移...")

        try:
            if not self.sqlite_path.exists():
                raise FileNotFoundError(f"SQLite数据库不存在: {self.sqlite_path}")

            logger.info("📊 初始化MySQL数据库...")
            mysql_db = MySQLDatabaseManager()

            # Order matters: sessions reference strategies, signals reference both.
            self.migrate_strategies(mysql_db)
            self.migrate_scan_sessions(mysql_db)
            self.migrate_signals(mysql_db)
            self.migrate_pullback_alerts(mysql_db)
            self.verify_migration(mysql_db)

            logger.info("🎉 数据迁移完成!")

        except Exception as e:
            logger.error(f"❌ 数据迁移失败: {e}")
            raise

    def migrate_strategies(self, mysql_db):
        """Copy strategies with INSERT IGNORE.

        Rows that collide with an existing MySQL unique key are skipped and
        not counted as migrated.
        """
        logger.info("📋 迁移策略数据...")

        try:
            with self._sqlite_connection() as sqlite_conn:
                strategies_df = pd.read_sql_query("SELECT * FROM strategies", sqlite_conn)

            if strategies_df.empty:
                logger.info("无策略数据需要迁移")
                return

            migrated_count = 0
            with self._mysql_connection(mysql_db) as mysql_conn:
                cursor = mysql_conn.cursor()
                for _, strategy in strategies_df.iterrows():
                    try:
                        cursor.execute("""
                            INSERT IGNORE INTO strategies (strategy_name, strategy_type, description)
                            VALUES (%s, %s, %s)
                        """, (
                            _to_db_value(strategy['strategy_name']),
                            _to_db_value(strategy['strategy_type']),
                            _row_value(strategy, 'description', ''),
                        ))
                        # Affected rows is 0 when INSERT IGNORE skipped the row.
                        migrated_count += cursor.rowcount
                    except Exception as e:
                        logger.warning(f"策略迁移警告: {e}")

                mysql_conn.commit()

            logger.info(f"✅ 迁移了 {migrated_count}/{len(strategies_df)} 个策略")

        except Exception as e:
            logger.error(f"策略迁移失败: {e}")
            raise

    def migrate_scan_sessions(self, mysql_db):
        """Copy scan sessions, re-mapping strategy ids by strategy name.

        NOTE(review): uses plain INSERT, so re-running the migration inserts
        the sessions again unless a MySQL unique key rejects them.
        """
        logger.info("📅 迁移扫描会话数据...")

        try:
            with self._sqlite_connection() as sqlite_conn:
                sessions_df = pd.read_sql_query("""
                    SELECT ss.*, s.strategy_name
                    FROM scan_sessions ss
                    JOIN strategies s ON ss.strategy_id = s.id
                """, sqlite_conn)

            if sessions_df.empty:
                logger.info("无扫描会话数据需要迁移")
                return

            migrated_count = 0
            with self._mysql_connection(mysql_db) as mysql_conn:
                cursor = mysql_conn.cursor()
                # Strategy ids are not assumed to match between the two
                # databases, so resolve them through the strategy name.
                strategy_ids = self._strategy_ids(cursor)

                for _, session in sessions_df.iterrows():
                    try:
                        mysql_strategy_id = strategy_ids.get(session['strategy_name'])
                        if mysql_strategy_id is None:
                            logger.warning(f"未找到策略: {session['strategy_name']}")
                            continue

                        cursor.execute("""
                            INSERT INTO scan_sessions (
                                strategy_id, scan_date, total_scanned, total_signals,
                                data_source, scan_config, status, created_at
                            ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                        """, (
                            mysql_strategy_id,
                            _to_db_value(session['scan_date']),
                            _row_value(session, 'total_scanned', 0),
                            _row_value(session, 'total_signals', 0),
                            _row_value(session, 'data_source'),
                            _row_value(session, 'scan_config'),
                            _row_value(session, 'status', 'completed'),
                            _row_value(session, 'created_at', datetime.now()),
                        ))
                        migrated_count += 1
                    except Exception as e:
                        logger.warning(f"会话迁移警告: {e}")

                mysql_conn.commit()

            logger.info(f"✅ 迁移了 {migrated_count} 个扫描会话")

        except Exception as e:
            logger.error(f"扫描会话迁移失败: {e}")
            raise

    def migrate_signals(self, mysql_db):
        """Copy stock signals, most recent first, up to ``self.signal_limit``.

        Each signal is attached to a MySQL scan session of the same strategy
        whose ``created_at`` date equals the signal date (or is one day
        before/after). When none exists, a placeholder session is created and
        reused for later signals of that strategy and date.

        NOTE(review): uses plain INSERT, so re-running the migration inserts
        the signals again unless a MySQL unique key rejects them.
        """
        logger.info("📈 迁移信号数据...")

        try:
            query = """
                SELECT ss.*, st.strategy_name
                FROM stock_signals ss
                JOIN strategies st ON ss.strategy_id = st.id
                ORDER BY ss.created_at DESC
            """
            params = None
            if self.signal_limit is not None:
                query += " LIMIT ?"
                params = (int(self.signal_limit),)

            with self._sqlite_connection() as sqlite_conn:
                signals_df = pd.read_sql_query(query, sqlite_conn, params=params)

            if signals_df.empty:
                logger.info("无信号数据需要迁移")
                return

            if self.signal_limit is not None and len(signals_df) >= self.signal_limit:
                logger.warning(
                    f"信号数量达到迁移上限 {self.signal_limit},更早的信号可能未被迁移")

            with self._mysql_connection(mysql_db) as mysql_conn:
                cursor = mysql_conn.cursor()
                strategy_ids = self._strategy_ids(cursor)

                # (mysql_strategy_id, created_at date) -> session id
                cursor.execute(
                    "SELECT id, strategy_id, created_at FROM scan_sessions ORDER BY created_at DESC")
                sessions_by_day = {}
                for session_id, strategy_id, created_at in cursor.fetchall():
                    if created_at is None:
                        continue  # no date to match signals against
                    session_day = created_at.date() if isinstance(created_at, datetime) else created_at
                    sessions_by_day[(strategy_id, session_day)] = session_id

                # Ids of placeholder sessions created by this run.
                placeholder_ids = set()
                migrated_count = 0

                for _, signal in signals_df.iterrows():
                    try:
                        mysql_strategy_id = strategy_ids.get(signal['strategy_name'])
                        if mysql_strategy_id is None:
                            continue

                        signal_date = pd.to_datetime(signal['signal_date']).date()
                        mysql_session_id = self._find_session(
                            sessions_by_day, mysql_strategy_id, signal_date)

                        created_session = mysql_session_id is None
                        if created_session:
                            mysql_session_id = self._create_placeholder_session(
                                cursor, mysql_strategy_id, signal_date)
                            # Register it so later signals for the same strategy/day
                            # reuse it instead of each creating its own session.
                            sessions_by_day[(mysql_strategy_id, signal_date)] = mysql_session_id
                            placeholder_ids.add(mysql_session_id)

                        self._insert_signal(cursor, signal, mysql_session_id, mysql_strategy_id)

                        if mysql_session_id in placeholder_ids and not created_session:
                            # Placeholders start with total_signals=1; count each extra signal.
                            cursor.execute(
                                "UPDATE scan_sessions SET total_signals = total_signals + 1 WHERE id = %s",
                                (mysql_session_id,))

                        migrated_count += 1
                    except Exception as e:
                        logger.warning(f"信号迁移警告: {signal.get('stock_code')} - {e}")

                mysql_conn.commit()

            logger.info(f"✅ 迁移了 {migrated_count} 条信号")

        except Exception as e:
            logger.error(f"信号迁移失败: {e}")
            raise

    def migrate_pullback_alerts(self, mysql_db):
        """Copy pullback alerts; skipped when SQLite has no pullback_alerts table.

        NOTE(review): ``signal_id`` is copied verbatim from SQLite, while
        migrate_signals() lets MySQL assign new signal ids, so the link may
        not point at the matching MySQL signal — confirm before relying on it.
        """
        logger.info("⚠️ 迁移回踩提醒数据...")

        try:
            with self._sqlite_connection() as sqlite_conn:
                table_check = sqlite_conn.execute(
                    "SELECT name FROM sqlite_master WHERE type='table' AND name='pullback_alerts'")
                if not table_check.fetchone():
                    logger.info("SQLite中无回踩提醒表,跳过迁移")
                    return

                alerts_df = pd.read_sql_query("SELECT * FROM pullback_alerts", sqlite_conn)

            if alerts_df.empty:
                logger.info("无回踩提醒数据需要迁移")
                return

            migrated_count = 0
            with self._mysql_connection(mysql_db) as mysql_conn:
                cursor = mysql_conn.cursor()

                for _, alert in alerts_df.iterrows():
                    try:
                        cursor.execute("""
                            INSERT INTO pullback_alerts (
                                signal_id, stock_code, stock_name, timeframe,
                                original_signal_date, original_breakout_price, yin_high,
                                pullback_date, current_price, current_low, pullback_pct,
                                distance_to_yin_high, days_since_signal, alert_sent,
                                alert_sent_time, created_at
                            ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                        """, (
                            _row_value(alert, 'signal_id'),
                            _to_db_value(alert['stock_code']),
                            _to_db_value(alert['stock_name']),
                            _to_db_value(alert['timeframe']),
                            _row_value(alert, 'original_signal_date'),
                            _row_value(alert, 'original_breakout_price'),
                            _row_value(alert, 'yin_high'),
                            _to_db_value(alert['pullback_date']),
                            _row_value(alert, 'current_price'),
                            _row_value(alert, 'current_low'),
                            _row_value(alert, 'pullback_pct'),
                            _row_value(alert, 'distance_to_yin_high'),
                            _row_value(alert, 'days_since_signal'),
                            _row_value(alert, 'alert_sent', True),
                            _row_value(alert, 'alert_sent_time'),
                            _row_value(alert, 'created_at', datetime.now()),
                        ))
                        migrated_count += 1
                    except Exception as e:
                        logger.warning(f"回踩提醒迁移警告: {alert.get('stock_code')} - {e}")

                mysql_conn.commit()

            logger.info(f"✅ 迁移了 {migrated_count} 条回踩提醒")

        except Exception as e:
            logger.error(f"回踩提醒迁移失败: {e}")
            raise

    def verify_migration(self, mysql_db):
        """Log MySQL row counts per table plus counts from the two views."""
        logger.info("🔍 验证迁移结果...")

        try:
            with self._mysql_connection(mysql_db) as mysql_conn:
                cursor = mysql_conn.cursor()

                # Table names come from this fixed list, never from input.
                tables = ['strategies', 'scan_sessions', 'stock_signals', 'pullback_alerts']
                for table in tables:
                    cursor.execute(f"SELECT COUNT(*) FROM {table}")
                    count = cursor.fetchone()[0]
                    logger.info(f"📊 {table}: {count} 条记录")

                cursor.execute("SELECT COUNT(*) FROM latest_signals_view WHERE new_high_confirmed = 1")
                confirmed_signals = cursor.fetchone()[0]
                logger.info(f"🎯 确认信号: {confirmed_signals} 条")

                cursor.execute("SELECT COUNT(*) FROM strategy_stats_view")
                stats_count = cursor.fetchone()[0]
                logger.info(f"📈 策略统计: {stats_count} 条")

            logger.info("✅ 数据迁移验证完成")

        except Exception as e:
            logger.error(f"验证迁移结果失败: {e}")
            raise


def main():
    """Command-line entry point: run the full migration and print a summary.

    Exits the process with status 1 if the migration fails.
    """
    logger.info("🚀 开始SQLite到MySQL数据迁移...")

    try:
        migrator = DataMigrator()
        migrator.migrate_all()

        print("\n" + "="*70)
        print("🎉 MySQL数据库迁移完成!")
        print("="*70)

        print("\n✅ 迁移内容:")
        print("   - 策略配置")
        print("   - 扫描会话")
        print("   - 股票信号(包含创新高回踩确认字段)")
        print("   - 回踩提醒")
        print("   - 数据库视图")

        print("\n🌐 MySQL配置:")
        print(f"   - 主机: {MYSQL_CONFIG.host}")
        print(f"   - 端口: {MYSQL_CONFIG.port}")
        print(f"   - 数据库: {MYSQL_CONFIG.database}")

        print("\n📝 下一步:")
        print("   1. 更新系统配置使用MySQL数据库")
        print("   2. 测试Web界面和API功能")
        print("   3. 验证所有功能正常工作")

    except Exception as e:
        logger.error(f"❌ 迁移失败: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()