stock-ai-agent/backend/migrate_db.py
2026-02-19 19:32:46 +08:00

150 lines
4.0 KiB
Python

"""
数据库迁移脚本 - 添加移动止损字段
用于为已有的 paper_trading 表添加新字段
"""
import sqlite3
import os
from pathlib import Path
def migrate_database():
"""执行数据库迁移"""
# 查找数据库文件
db_paths = [
"stock_agent.db",
"backend/stock_agent.db",
"../stock_agent.db"
]
db_path = None
for path in db_paths:
if os.path.exists(path):
db_path = path
break
if not db_path:
print("❌ 未找到数据库文件 stock_agent.db")
print("请确保在项目根目录或 backend 目录下运行此脚本")
return False
print(f"📁 找到数据库文件: {db_path}")
# 连接数据库
try:
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
# 检查表是否存在
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='paper_orders'")
if not cursor.fetchone():
print("⚠️ paper_orders 表不存在,将在首次启动时自动创建")
return False
# 检查字段是否已存在
cursor.execute("PRAGMA table_info(paper_orders)")
columns = [column[1] for column in cursor.fetchall()]
# 需要添加的新字段
new_columns = {
'trailing_stop_triggered': 'INTEGER DEFAULT 0',
'trailing_stop_base_profit': 'REAL DEFAULT 0'
}
columns_to_add = []
for col_name, col_type in new_columns.items():
if col_name not in columns:
columns_to_add.append((col_name, col_type))
if not columns_to_add:
print("✅ 数据库已是最新版本,无需迁移")
return True
# 执行迁移
print(f"📝 开始迁移,将添加 {len(columns_to_add)} 个新字段...")
for col_name, col_type in columns_to_add:
try:
sql = f"ALTER TABLE paper_orders ADD COLUMN {col_name} {col_type}"
cursor.execute(sql)
print(f" ✅ 添加字段: {col_name}")
except sqlite3.OperationalError as e:
print(f" ⚠️ 添加字段 {col_name} 失败: {e}")
# 提交更改
conn.commit()
print("✅ 数据库迁移完成!")
return True
except sqlite3.Error as e:
print(f"❌ 数据库错误: {e}")
return False
finally:
if conn:
conn.close()
def verify_migration():
"""验证迁移结果"""
db_paths = [
"stock_agent.db",
"backend/stock_agent.db",
"../stock_agent.db"
]
db_path = None
for path in db_paths:
if os.path.exists(path):
db_path = path
break
if not db_path:
return
try:
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute("PRAGMA table_info(paper_orders)")
columns = {column[1]: column[2] for column in cursor.fetchall()}
print("\n📋 paper_orders 表字段列表:")
print("-" * 60)
required_fields = ['trailing_stop_triggered', 'trailing_stop_base_profit']
all_present = True
for field in required_fields:
if field in columns:
print(f"{field}: {columns[field]}")
else:
print(f"{field}: 缺失")
all_present = False
if all_present:
print("\n✅ 所有必需字段都已存在!")
else:
print("\n⚠️ 部分字段缺失,请检查迁移结果")
except sqlite3.Error as e:
print(f"❌ 验证失败: {e}")
finally:
if conn:
conn.close()
if __name__ == "__main__":
print("=" * 60)
print("🔄 Stock Agent 数据库迁移工具")
print("=" * 60)
print()
success = migrate_database()
if success:
verify_migration()
else:
print("\n💡 如果数据库文件不存在,启动服务时会自动创建")
print()
print("=" * 60)