125 lines
4.2 KiB
Python
125 lines
4.2 KiB
Python
#!/usr/bin/env python3
"""Add the ``margin`` and ``leverage`` columns to the ``paper_orders`` table.

Usage::

    cd backend && python scripts/migrate_add_margin_leverage.py
"""

import logging
import os
import sys

# Put the backend root (the parent of this script's directory) first on the
# import path so that the ``app`` package resolves when run as a script.
sys.path.insert(
    0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)

from sqlalchemy import inspect, text  # noqa: E402

from app.services.db_service import db_service  # noqa: E402

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def check_column_exists(table_name, column_name):
    """Return True if *column_name* is one of the columns of *table_name*.

    A new SQLAlchemy inspector is built on every call, so no reflection
    cache is carried over from earlier calls (e.g. before an ALTER TABLE).
    """
    existing_names = {
        column_info["name"]
        for column_info in inspect(db_service.engine).get_columns(table_name)
    }
    return column_name in existing_names


def create_tables_if_not_exist():
    """Create every table registered on ``Base.metadata`` that is missing.

    ``MetaData.create_all`` checks for existing tables by default, so tables
    that are already present are left alone.
    """
    # The alias is never referenced below; the import is kept for its side
    # effect of loading the model module, which presumably registers the
    # paper_orders table on Base.metadata before create_all runs. Keep it.
    from app.models.paper_trading import PaperOrder as POTable  # noqa: F401
    from app.models.database import Base

    Base.metadata.create_all(bind=db_service.engine)
    logger.info("✅ 数据库表已创建")


# Leverage assumed for orders that predate this migration. The margin
# backfill divides quantity by this same value, so keep them in one place.
_DEFAULT_LEVERAGE = 20
# Placeholder DEFAULT applied when the margin column is first added; existing
# rows are immediately overwritten with the computed margin.
_DEFAULT_MARGIN = 50


def _add_column_with_backfill(column, column_ddl, backfill_msg, backfill_sql):
    """Add *column* to paper_orders and backfill existing rows, unless present.

    The ALTER TABLE and the backfill UPDATE share one transaction. On backends
    where the driver runs DDL inside the transaction (e.g. PostgreSQL), a
    failed backfill therefore also rolls back the new column. Otherwise a
    re-run would find the column, skip it, and leave placeholder values.
    Backends that commit DDL implicitly (e.g. MySQL) still apply the two
    steps separately.

    Args:
        column: Column name, used for the existence check and log messages.
        column_ddl: Column definition appended to ``ADD COLUMN``.
        backfill_msg: Log line emitted before the backfill runs.
        backfill_sql: UPDATE statement that fills the column for old rows.
    """
    if check_column_exists('paper_orders', column):
        logger.info(f"✓ {column} 列已存在,跳过")
        return

    logger.info(f"➕ 添加 {column} 列...")
    with db_service.engine.begin() as conn:
        conn.execute(text(f"ALTER TABLE paper_orders ADD COLUMN {column_ddl}"))
        logger.info(backfill_msg)
        conn.execute(text(backfill_sql))
    # Logged only after the transaction has committed successfully.
    logger.info(f"✅ {column} 列添加成功")
    logger.info(f"✅ 现有记录的 {column} 已更新")


def _log_paper_orders_schema():
    """Log the types of the quantity/margin/leverage columns of paper_orders."""
    logger.info("\n📊 paper_orders 表结构:")
    for col in inspect(db_service.engine).get_columns('paper_orders'):
        if col['name'] in ('quantity', 'margin', 'leverage'):
            logger.info(f" - {col['name']}: {col['type']}")


def migrate():
    """Run the paper_orders margin/leverage migration.

    Steps:
      1. Ensure ``paper_orders`` exists, creating all registered tables if not.
      2. Add ``margin`` (FLOAT, default 50) and set existing rows to
         ``ROUND(quantity / 20.0, 2)``, i.e. quantity over the default leverage.
      3. Add ``leverage`` (INTEGER, default 20) and fill any NULLs with 20.
      4. Log the resulting column types.

    Columns that already exist are skipped, so the script can be re-run.

    Returns:
        True on success; False if the table could not be created or any step
        raised (the error is logged together with its traceback).
    """
    try:
        inspector = inspect(db_service.engine)
        if 'paper_orders' not in inspector.get_table_names():
            logger.info("⚠️ paper_orders 表不存在,先创建表...")
            create_tables_if_not_exist()
            # Use a fresh inspector: the old one may have cached table names.
            refreshed = inspect(db_service.engine)
            if 'paper_orders' not in refreshed.get_table_names():
                logger.error("❌ 创建表失败")
                return False

        logger.info("开始迁移 paper_orders 表...")

        # The SQL below is built only from the module constants above,
        # never from external input.
        _add_column_with_backfill(
            column='margin',
            column_ddl=f"margin FLOAT DEFAULT {_DEFAULT_MARGIN}",
            backfill_msg="🔄 为现有记录计算 margin...",
            # Old rows just received the placeholder default; derive the real
            # margin from quantity. float() keeps the division non-integer.
            backfill_sql=f"""
                UPDATE paper_orders
                SET margin = ROUND(quantity / {float(_DEFAULT_LEVERAGE)}, 2)
                WHERE margin IS NULL OR margin = {_DEFAULT_MARGIN}
            """,
        )
        _add_column_with_backfill(
            column='leverage',
            column_ddl=f"leverage INTEGER DEFAULT {_DEFAULT_LEVERAGE}",
            backfill_msg="🔄 为现有记录设置 leverage...",
            # Rows already holding the DEFAULT need no change; only rows left
            # NULL need the default leverage written explicitly.
            backfill_sql=f"""
                UPDATE paper_orders
                SET leverage = {_DEFAULT_LEVERAGE}
                WHERE leverage IS NULL
            """,
        )

        logger.info("\n" + "=" * 50)
        logger.info("✅ 迁移完成!")
        logger.info("=" * 50)

        _log_paper_orders_schema()
        return True

    except Exception as e:
        # Top-level boundary of the script: log with traceback, report failure.
        logger.exception(f"❌ 迁移失败: {e}")
        return False


if __name__ == "__main__":
    # Translate the boolean result into a process exit status (0 = success).
    exit_code = 0 if migrate() else 1
    sys.exit(exit_code)