150 lines
4.0 KiB
Python
150 lines
4.0 KiB
Python
"""
|
|
数据库迁移脚本 - 添加移动止损字段
|
|
用于为已有的 paper_trading 表添加新字段
|
|
"""
|
|
import sqlite3
|
|
import os
|
|
from pathlib import Path
|
|
|
|
def migrate_database():
|
|
"""执行数据库迁移"""
|
|
|
|
# 查找数据库文件
|
|
db_paths = [
|
|
"stock_agent.db",
|
|
"backend/stock_agent.db",
|
|
"../stock_agent.db"
|
|
]
|
|
|
|
db_path = None
|
|
for path in db_paths:
|
|
if os.path.exists(path):
|
|
db_path = path
|
|
break
|
|
|
|
if not db_path:
|
|
print("❌ 未找到数据库文件 stock_agent.db")
|
|
print("请确保在项目根目录或 backend 目录下运行此脚本")
|
|
return False
|
|
|
|
print(f"📁 找到数据库文件: {db_path}")
|
|
|
|
# 连接数据库
|
|
try:
|
|
conn = sqlite3.connect(db_path)
|
|
cursor = conn.cursor()
|
|
|
|
# 检查表是否存在
|
|
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='paper_orders'")
|
|
if not cursor.fetchone():
|
|
print("⚠️ paper_orders 表不存在,将在首次启动时自动创建")
|
|
return False
|
|
|
|
# 检查字段是否已存在
|
|
cursor.execute("PRAGMA table_info(paper_orders)")
|
|
columns = [column[1] for column in cursor.fetchall()]
|
|
|
|
# 需要添加的新字段
|
|
new_columns = {
|
|
'trailing_stop_triggered': 'INTEGER DEFAULT 0',
|
|
'trailing_stop_base_profit': 'REAL DEFAULT 0'
|
|
}
|
|
|
|
columns_to_add = []
|
|
for col_name, col_type in new_columns.items():
|
|
if col_name not in columns:
|
|
columns_to_add.append((col_name, col_type))
|
|
|
|
if not columns_to_add:
|
|
print("✅ 数据库已是最新版本,无需迁移")
|
|
return True
|
|
|
|
# 执行迁移
|
|
print(f"📝 开始迁移,将添加 {len(columns_to_add)} 个新字段...")
|
|
|
|
for col_name, col_type in columns_to_add:
|
|
try:
|
|
sql = f"ALTER TABLE paper_orders ADD COLUMN {col_name} {col_type}"
|
|
cursor.execute(sql)
|
|
print(f" ✅ 添加字段: {col_name}")
|
|
except sqlite3.OperationalError as e:
|
|
print(f" ⚠️ 添加字段 {col_name} 失败: {e}")
|
|
|
|
# 提交更改
|
|
conn.commit()
|
|
print("✅ 数据库迁移完成!")
|
|
return True
|
|
|
|
except sqlite3.Error as e:
|
|
print(f"❌ 数据库错误: {e}")
|
|
return False
|
|
finally:
|
|
if conn:
|
|
conn.close()
|
|
|
|
|
|
def verify_migration():
|
|
"""验证迁移结果"""
|
|
db_paths = [
|
|
"stock_agent.db",
|
|
"backend/stock_agent.db",
|
|
"../stock_agent.db"
|
|
]
|
|
|
|
db_path = None
|
|
for path in db_paths:
|
|
if os.path.exists(path):
|
|
db_path = path
|
|
break
|
|
|
|
if not db_path:
|
|
return
|
|
|
|
try:
|
|
conn = sqlite3.connect(db_path)
|
|
cursor = conn.cursor()
|
|
|
|
cursor.execute("PRAGMA table_info(paper_orders)")
|
|
columns = {column[1]: column[2] for column in cursor.fetchall()}
|
|
|
|
print("\n📋 paper_orders 表字段列表:")
|
|
print("-" * 60)
|
|
|
|
required_fields = ['trailing_stop_triggered', 'trailing_stop_base_profit']
|
|
all_present = True
|
|
|
|
for field in required_fields:
|
|
if field in columns:
|
|
print(f" ✅ {field}: {columns[field]}")
|
|
else:
|
|
print(f" ❌ {field}: 缺失")
|
|
all_present = False
|
|
|
|
if all_present:
|
|
print("\n✅ 所有必需字段都已存在!")
|
|
else:
|
|
print("\n⚠️ 部分字段缺失,请检查迁移结果")
|
|
|
|
except sqlite3.Error as e:
|
|
print(f"❌ 验证失败: {e}")
|
|
finally:
|
|
if conn:
|
|
conn.close()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
print("=" * 60)
|
|
print("🔄 Stock Agent 数据库迁移工具")
|
|
print("=" * 60)
|
|
print()
|
|
|
|
success = migrate_database()
|
|
|
|
if success:
|
|
verify_migration()
|
|
else:
|
|
print("\n💡 如果数据库文件不存在,启动服务时会自动创建")
|
|
|
|
print()
|
|
print("=" * 60)
|