387 lines
16 KiB
Python
387 lines
16 KiB
Python
#!/usr/bin/env python3
"""
SQLite-to-MySQL data migration script.

Migrates the existing SQLite database into a MySQL database.
"""

import sys
import sqlite3
import pymysql
import pandas as pd
from pathlib import Path
from datetime import datetime, date
from loguru import logger

# Add the directory containing this script (the project root) to sys.path so
# the project-local `config` and `src` packages imported below can be found.
# This has to run before those imports.
current_dir = Path(__file__).parent
sys.path.insert(0, str(current_dir))

from config.mysql_config import MYSQL_CONFIG
from src.database.mysql_database_manager import MySQLDatabaseManager


class DataMigrator:
    """Migrate trading data from the local SQLite database into MySQL.

    The SQLite source is ``<script dir>/data/trading.db``; MySQL connection
    parameters come from ``mysql_db.connection_params`` of the
    ``MySQLDatabaseManager`` passed to each ``migrate_*`` step.

    MySQL assigns its own auto-increment ids, so references are translated
    on the way over instead of being copied verbatim: strategies by name,
    scan sessions by strategy + date, and pullback alerts through a
    SQLite->MySQL signal-id map built by :meth:`migrate_signals`.
    """

    # Number of newest signals migrate_signals() copies when no limit is given.
    DEFAULT_SIGNAL_LIMIT = 1000

    # Optional stock_signals columns, in INSERT order, each paired with the
    # value used when the SQLite cell is missing or NULL.
    _SIGNAL_OPTIONAL_COLUMNS = (
        ('signal_type', '两阳+阴+阳突破'),
        ('breakout_price', None),
        ('yin_high', None),
        ('breakout_amount', None),
        ('breakout_pct', None),
        ('ema20_price', None),
        ('yang1_entity_ratio', None),
        ('yang2_entity_ratio', None),
        ('final_yang_entity_ratio', None),
        ('turnover_ratio', None),
        ('above_ema20', None),
        ('new_high_confirmed', False),
        ('new_high_price', None),
        ('new_high_date', None),
        ('confirmation_date', None),
        ('confirmation_days', None),
        ('pullback_distance', None),
        ('k1_data', None),
        ('k2_data', None),
        ('k3_data', None),
        ('k4_data', None),
    )

    def __init__(self):
        self.sqlite_path = current_dir / "data" / "trading.db"
        self.mysql_config = MYSQL_CONFIG
        # SQLite stock_signals.id -> MySQL stock_signals.id; None until
        # migrate_signals() has run on this instance.
        self._signal_id_map = None

    def migrate_all(self, signal_limit=DEFAULT_SIGNAL_LIMIT):
        """Run every migration step in dependency order, then verify.

        Args:
            signal_limit: forwarded to :meth:`migrate_signals` as ``limit``;
                ``None`` migrates every signal.

        Raises:
            FileNotFoundError: if the SQLite database file does not exist.
            Exception: any error raised by a step is logged and re-raised.
        """
        logger.info("🚀 开始SQLite到MySQL数据迁移...")

        try:
            # sqlite3.connect() would silently create an empty database here
            # and each step would then fail with a confusing "no such table".
            if not Path(self.sqlite_path).exists():
                raise FileNotFoundError(f"SQLite数据库文件不存在: {self.sqlite_path}")

            logger.info("📊 初始化MySQL数据库...")
            mysql_db = MySQLDatabaseManager()

            # Order matters: sessions reference strategies, signals reference
            # both, and alerts are re-linked via the id map built for signals.
            self.migrate_strategies(mysql_db)
            self.migrate_scan_sessions(mysql_db)
            self.migrate_signals(mysql_db, limit=signal_limit)
            self.migrate_pullback_alerts(mysql_db)
            self.verify_migration(mysql_db)

            logger.info("🎉 数据迁移完成!")

        except Exception as e:
            logger.error(f"❌ 数据迁移失败: {e}")
            raise

    def migrate_strategies(self, mysql_db):
        """Copy the ``strategies`` table with ``INSERT IGNORE``.

        Rows MySQL rejects as duplicates (per its unique keys -- check the
        schema) are skipped rather than failing; the log reports only the
        rows actually inserted. A NULL description is stored as ``''``.
        """
        logger.info("📋 迁移策略数据...")

        try:
            strategies_df = self._read_sqlite("SELECT * FROM strategies")

            if strategies_df.empty:
                logger.info("无策略数据需要迁移")
                return

            inserted = 0
            with pymysql.connect(**mysql_db.connection_params) as mysql_conn:
                cursor = mysql_conn.cursor()

                for _, strategy in strategies_df.iterrows():
                    try:
                        cursor.execute("""
                            INSERT IGNORE INTO strategies (strategy_name, strategy_type, description)
                            VALUES (%s, %s, %s)
                        """, (
                            strategy['strategy_name'],
                            strategy['strategy_type'],
                            self._clean(strategy.get('description'), ''),
                        ))
                        # INSERT IGNORE reports 0 affected rows for a skipped duplicate.
                        inserted += cursor.rowcount
                    except Exception as e:
                        logger.warning(f"策略迁移警告: {e}")

                mysql_conn.commit()

            logger.info(f"✅ 迁移了 {inserted} 个策略")

        except Exception as e:
            logger.error(f"策略迁移失败: {e}")
            raise

    def migrate_scan_sessions(self, mysql_db):
        """Copy ``scan_sessions``, translating ``strategy_id`` via strategy name.

        Sessions whose strategy does not exist in MySQL are skipped with a
        warning. Missing or NULL counters default to 0, status to
        ``'completed'`` and ``created_at`` to the current time.
        """
        logger.info("📅 迁移扫描会话数据...")

        try:
            sessions_df = self._read_sqlite("""
                SELECT ss.*, s.strategy_name
                FROM scan_sessions ss
                JOIN strategies s ON ss.strategy_id = s.id
            """)

            if sessions_df.empty:
                logger.info("无扫描会话数据需要迁移")
                return

            migrated_count = 0
            with pymysql.connect(**mysql_db.connection_params) as mysql_conn:
                cursor = mysql_conn.cursor()
                strategy_mapping = self._fetch_strategy_mapping(cursor)

                for _, session in sessions_df.iterrows():
                    try:
                        mysql_strategy_id = strategy_mapping.get(session['strategy_name'])
                        if mysql_strategy_id is None:
                            logger.warning(f"未找到策略: {session['strategy_name']}")
                            continue

                        cursor.execute("""
                            INSERT INTO scan_sessions (
                                strategy_id, scan_date, total_scanned, total_signals,
                                data_source, scan_config, status, created_at
                            ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                        """, (
                            mysql_strategy_id,
                            self._clean(session['scan_date']),
                            self._clean(session.get('total_scanned'), 0),
                            self._clean(session.get('total_signals'), 0),
                            self._clean(session.get('data_source')),
                            self._clean(session.get('scan_config')),
                            self._clean(session.get('status'), 'completed'),
                            self._clean(session.get('created_at'), datetime.now()),
                        ))
                        migrated_count += 1

                    except Exception as e:
                        logger.warning(f"会话迁移警告: {e}")

                mysql_conn.commit()

            logger.info(f"✅ 迁移了 {migrated_count} 个扫描会话")

        except Exception as e:
            logger.error(f"扫描会话迁移失败: {e}")
            raise

    def migrate_signals(self, mysql_db, limit=DEFAULT_SIGNAL_LIMIT):
        """Copy ``stock_signals`` (newest first) and attach each to a scan session.

        A signal is attached to the MySQL session of the same strategy whose
        ``created_at`` date is nearest to the signal date, at most one day
        apart (ties go to the most recently created session). If there is no
        such session, a placeholder session (``data_source='迁移数据'``) is
        created. Later signals of that strategy and date reuse it, and its
        ``total_scanned``/``total_signals`` are set to the number of signals
        attached to it.

        Also records a SQLite->MySQL signal-id map on the instance, which
        :meth:`migrate_pullback_alerts` uses to re-link alerts.

        Args:
            mysql_db: object exposing ``connection_params`` for ``pymysql.connect``.
            limit: maximum number of newest signals to copy; ``None`` copies
                all of them. A warning is logged when rows are left behind.
        """
        logger.info("📈 迁移信号数据...")

        try:
            # Reset so alerts are re-linked against exactly the signals
            # copied by this call.
            self._signal_id_map = {}

            query = """
                SELECT ss.*, st.strategy_name
                FROM stock_signals ss
                JOIN strategies st ON ss.strategy_id = st.id
                ORDER BY ss.created_at DESC
            """
            params = None
            if limit is not None:
                query += " LIMIT ?"
                params = (int(limit),)
            signals_df = self._read_sqlite(query, params)

            if signals_df.empty:
                logger.info("无信号数据需要迁移")
                return

            if limit is not None and len(signals_df) >= limit:
                self._warn_if_signals_truncated(len(signals_df))

            insert_sql = self._build_insert_sql("stock_signals", self._signal_columns())
            migrated_count = 0

            with pymysql.connect(**mysql_db.connection_params) as mysql_conn:
                cursor = mysql_conn.cursor()
                strategy_mapping = self._fetch_strategy_mapping(cursor)
                sessions_by_strategy = self._fetch_sessions_by_strategy(cursor)
                # Placeholder sessions created below -> signals attached to each.
                created_sessions = {}

                for _, signal in signals_df.iterrows():
                    try:
                        mysql_strategy_id = strategy_mapping.get(signal['strategy_name'])
                        if mysql_strategy_id is None:
                            continue

                        signal_date = pd.to_datetime(signal['signal_date']).date()
                        candidates = sessions_by_strategy.setdefault(mysql_strategy_id, [])
                        mysql_session_id = self._nearest_session(candidates, signal_date)

                        if mysql_session_id is None:
                            cursor.execute("""
                                INSERT INTO scan_sessions (strategy_id, scan_date, total_scanned, total_signals, data_source)
                                VALUES (%s, %s, %s, %s, %s)
                            """, (mysql_strategy_id, signal_date, 1, 1, '迁移数据'))
                            mysql_session_id = cursor.lastrowid
                            # Register it so later signals of the same strategy
                            # and date reuse it instead of creating duplicates.
                            candidates.append((signal_date, mysql_session_id))
                            created_sessions[mysql_session_id] = 0

                        cursor.execute(
                            insert_sql,
                            self._signal_values(signal, mysql_session_id, mysql_strategy_id),
                        )

                        sqlite_signal_id = self._clean(signal.get('id'))
                        if sqlite_signal_id is not None:
                            self._signal_id_map[sqlite_signal_id] = cursor.lastrowid
                        if mysql_session_id in created_sessions:
                            created_sessions[mysql_session_id] += 1
                        migrated_count += 1

                    except Exception as e:
                        logger.warning(f"信号迁移警告: {signal.get('stock_code')} - {e}")

                # Placeholder sessions were inserted with a count of 1; store
                # how many signals actually ended up attached to each.
                for session_id, count in created_sessions.items():
                    cursor.execute(
                        "UPDATE scan_sessions SET total_scanned = %s, total_signals = %s WHERE id = %s",
                        (count, count, session_id),
                    )

                mysql_conn.commit()

            logger.info(f"✅ 迁移了 {migrated_count} 条信号")

        except Exception as e:
            logger.error(f"信号迁移失败: {e}")
            raise

    def migrate_pullback_alerts(self, mysql_db):
        """Copy ``pullback_alerts``; does nothing if the SQLite table is absent.

        If :meth:`migrate_signals` ran earlier on this instance, each
        alert's ``signal_id`` is translated to the MySQL id of the migrated
        signal. A SQLite id would point at an unrelated MySQL row, so alerts
        whose signal was not migrated get ``signal_id = NULL`` and a warning
        is logged. If signals were not migrated by this instance, ids are
        copied unchanged.
        """
        logger.info("⚠️ 迁移回踩提醒数据...")

        try:
            if not self._sqlite_table_exists('pullback_alerts'):
                logger.info("SQLite中无回踩提醒表,跳过迁移")
                return

            alerts_df = self._read_sqlite("SELECT * FROM pullback_alerts")

            if alerts_df.empty:
                logger.info("无回踩提醒数据需要迁移")
                return

            signal_id_map = getattr(self, '_signal_id_map', None)
            insert_sql = self._build_insert_sql("pullback_alerts", [
                'signal_id', 'stock_code', 'stock_name', 'timeframe',
                'original_signal_date', 'original_breakout_price', 'yin_high',
                'pullback_date', 'current_price', 'current_low',
                'pullback_pct', 'distance_to_yin_high', 'days_since_signal',
                'alert_sent', 'alert_sent_time', 'created_at',
            ])
            migrated_count = 0
            unlinked_count = 0

            with pymysql.connect(**mysql_db.connection_params) as mysql_conn:
                cursor = mysql_conn.cursor()

                for _, alert in alerts_df.iterrows():
                    try:
                        signal_id = self._clean(alert.get('signal_id'))
                        unlinked = False
                        if signal_id_map is not None and signal_id is not None:
                            signal_id = signal_id_map.get(signal_id)
                            unlinked = signal_id is None

                        cursor.execute(insert_sql, (
                            signal_id,
                            alert['stock_code'], alert['stock_name'], alert['timeframe'],
                            self._clean(alert.get('original_signal_date')),
                            self._clean(alert.get('original_breakout_price')),
                            self._clean(alert.get('yin_high')),
                            self._clean(alert['pullback_date']),
                            self._clean(alert.get('current_price')),
                            self._clean(alert.get('current_low')),
                            self._clean(alert.get('pullback_pct')),
                            self._clean(alert.get('distance_to_yin_high')),
                            self._clean(alert.get('days_since_signal')),
                            self._clean(alert.get('alert_sent'), True),
                            self._clean(alert.get('alert_sent_time')),
                            self._clean(alert.get('created_at'), datetime.now()),
                        ))

                        migrated_count += 1
                        unlinked_count += unlinked

                    except Exception as e:
                        logger.warning(f"回踩提醒迁移警告: {alert.get('stock_code')} - {e}")

                mysql_conn.commit()

            if unlinked_count:
                logger.warning(f"{unlinked_count} 条回踩提醒对应的信号未迁移, signal_id 已置空")
            logger.info(f"✅ 迁移了 {migrated_count} 条回踩提醒")

        except Exception as e:
            logger.error(f"回踩提醒迁移失败: {e}")
            raise

    def verify_migration(self, mysql_db):
        """Log row counts of the migrated MySQL tables and the two reporting views.

        Any error (e.g. a missing table or view) is logged and re-raised.
        """
        logger.info("🔍 验证迁移结果...")

        try:
            with pymysql.connect(**mysql_db.connection_params) as mysql_conn:
                cursor = mysql_conn.cursor()

                # Table names come from this fixed list, never from input,
                # so interpolating them into the SQL text is safe.
                for table in ('strategies', 'scan_sessions', 'stock_signals', 'pullback_alerts'):
                    cursor.execute(f"SELECT COUNT(*) FROM {table}")
                    count = cursor.fetchone()[0]
                    logger.info(f"📊 {table}: {count} 条记录")

                cursor.execute("SELECT COUNT(*) FROM latest_signals_view WHERE new_high_confirmed = 1")
                confirmed_signals = cursor.fetchone()[0]
                logger.info(f"🎯 确认信号: {confirmed_signals} 条")

                cursor.execute("SELECT COUNT(*) FROM strategy_stats_view")
                stats_count = cursor.fetchone()[0]
                logger.info(f"📈 策略统计: {stats_count} 条")

                logger.info("✅ 数据迁移验证完成")

        except Exception as e:
            logger.error(f"验证迁移结果失败: {e}")
            raise

    # ------------------------------------------------------------------ helpers

    def _read_sqlite(self, query, params=None):
        """Run *query* against the SQLite file and return the result as a DataFrame.

        ``sqlite3.Connection`` used as a context manager only commits or
        rolls back; it does not close, so the connection is closed explicitly.
        """
        conn = sqlite3.connect(self.sqlite_path)
        try:
            return pd.read_sql_query(query, conn, params=params)
        finally:
            conn.close()

    def _sqlite_table_exists(self, table):
        """Return True if the SQLite file contains a table named *table*."""
        conn = sqlite3.connect(self.sqlite_path)
        try:
            row = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name=?",
                (table,),
            ).fetchone()
        finally:
            conn.close()
        return row is not None

    def _warn_if_signals_truncated(self, copied):
        """Log a warning when SQLite holds more joinable signals than *copied*."""
        total_df = self._read_sqlite("""
            SELECT COUNT(*) AS total
            FROM stock_signals ss
            JOIN strategies st ON ss.strategy_id = st.id
        """)
        total = int(total_df.iloc[0, 0])
        if total > copied:
            logger.warning(f"仅迁移最新的 {copied} 条信号 (共 {total} 条), 使用 limit=None 可迁移全部")

    @staticmethod
    def _fetch_strategy_mapping(cursor):
        """Return ``{strategy_name: id}`` for every row of the ``strategies`` table."""
        cursor.execute("SELECT id, strategy_name FROM strategies")
        return {name: strategy_id for strategy_id, name in cursor.fetchall()}

    def _fetch_sessions_by_strategy(self, cursor):
        """Return ``{strategy_id: [(session_date, session_id), ...]}`` from MySQL.

        Each list is ordered newest ``created_at`` first. Sessions whose
        ``created_at`` is NULL (or not a date/datetime) are skipped, since
        they cannot be matched by date.
        """
        cursor.execute("SELECT id, strategy_id, created_at FROM scan_sessions ORDER BY created_at DESC")
        sessions = {}
        for session_id, strategy_id, created_at in cursor.fetchall():
            session_date = self._as_date(created_at)
            if session_date is not None:
                sessions.setdefault(strategy_id, []).append((session_date, session_id))
        return sessions

    @staticmethod
    def _as_date(value):
        """Return *value* as a ``date`` (a datetime's date part), or None otherwise."""
        if isinstance(value, datetime):  # checked first: datetime subclasses date
            return value.date()
        if isinstance(value, date):
            return value
        return None

    @staticmethod
    def _nearest_session(candidates, target, max_days=1):
        """Return the session id whose date is closest to *target*, or None.

        *candidates* is a list of ``(session_date, session_id)`` pairs. Only
        dates at most *max_days* away qualify; ties go to the earliest entry.
        """
        best_id = None
        best_gap = None
        for session_date, session_id in candidates:
            gap = abs((session_date - target).days)
            if gap <= max_days and (best_gap is None or gap < best_gap):
                best_id, best_gap = session_id, gap
        return best_id

    @staticmethod
    def _build_insert_sql(table, columns):
        """Return an ``INSERT`` statement for *table* with one ``%s`` placeholder per column.

        *table* and *columns* are spliced into the SQL text, so they must be
        fixed identifiers from this class, never external input.
        """
        column_list = ", ".join(columns)
        placeholders = ", ".join(["%s"] * len(columns))
        return f"INSERT INTO {table} ({column_list}) VALUES ({placeholders})"

    @classmethod
    def _signal_columns(cls):
        """Return the ``stock_signals`` column list, in the order of :meth:`_signal_values`."""
        return [
            'session_id', 'strategy_id', 'stock_code', 'stock_name', 'timeframe', 'signal_date',
            *(name for name, _ in cls._SIGNAL_OPTIONAL_COLUMNS),
            'created_at',
        ]

    def _signal_values(self, signal, session_id, strategy_id):
        """Return the INSERT parameters for one signal row, matching :meth:`_signal_columns`."""
        return (
            session_id,
            strategy_id,
            signal['stock_code'],
            signal['stock_name'],
            signal['timeframe'],
            self._clean(signal['signal_date']),
            *(self._clean(signal.get(name), default)
              for name, default in self._SIGNAL_OPTIONAL_COLUMNS),
            self._clean(signal.get('created_at'), datetime.now()),
        )

    @staticmethod
    def _clean(value, default=None):
        """Convert a pandas cell into a value pymysql can bind.

        Missing cells (``None``/``NaN``/``NaT``) become *default*. pymysql
        refuses to bind a float NaN, so without this any row with a NULL
        numeric cell would be rejected. numpy integer/float/bool scalars
        (which appear when a row Series has a numeric dtype) are unwrapped to
        the equivalent built-in Python values. Falsy values such as ``0`` or
        ``False`` are kept as-is.
        """
        import numpy as np  # always installed alongside pandas

        if value is None:
            return default
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return default
        if isinstance(value, (np.integer, np.floating, np.bool_)):
            return value.item()
        return value


def main():
    """Command-line entry point: run the full migration and print a summary.

    If the migration fails, the error is logged with its traceback and the
    process exits with status 1.
    """
    # No separate "is pymysql installed?" check: pymysql is imported at module
    # level, so a missing package stops the script before main() runs.
    # migrate_all() logs the start-of-migration message itself.
    try:
        migrator = DataMigrator()
        migrator.migrate_all()
    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception(f"❌ 迁移失败: {e}")
        sys.exit(1)
    else:
        # Only reached after a successful migration, so a problem printing the
        # summary is not misreported as a migration failure.
        print("\n" + "=" * 70)
        print("🎉 MySQL数据库迁移完成!")
        print("=" * 70)

        print("\n✅ 迁移内容:")
        print(" - 策略配置")
        print(" - 扫描会话")
        print(" - 股票信号(包含创新高回踩确认字段)")
        print(" - 回踩提醒")
        print(" - 数据库视图")

        print("\n🌐 MySQL配置:")
        print(f" - 主机: {MYSQL_CONFIG.host}")
        print(f" - 端口: {MYSQL_CONFIG.port}")
        print(f" - 数据库: {MYSQL_CONFIG.database}")

        print("\n📝 下一步:")
        print(" 1. 更新系统配置使用MySQL数据库")
        print(" 2. 测试Web界面和API功能")
        print(" 3. 验证所有功能正常工作")


if __name__ == "__main__":
    # Run the migration only when executed as a script, not when imported.
    main()