# File: stock-ai-agent/backend/scripts/migrate_add_margin_leverage.py
# Last modified: 2026-03-03 16:51:17 +08:00 (125 lines, 4.2 KiB, Python)
#
# Note: this file intentionally contains non-ASCII text (Chinese comments and
# log messages, emoji). Repository viewers may flag these as "ambiguous
# Unicode characters"; that warning can be ignored here.
#!/usr/bin/env python3
"""
添加 margin 和 leverage 字段到 paper_orders 表
运行方式:
cd backend && python scripts/migrate_add_margin_leverage.py
"""
import sys
import os
# 添加父目录到路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.services.db_service import db_service
from sqlalchemy import text, inspect
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def check_column_exists(table_name, column_name):
    """Return whether *column_name* is currently a column of *table_name*.

    A fresh Inspector is created on every call, so the check reflects the
    live schema behind ``db_service.engine`` rather than a cached snapshot.
    """
    reflected_columns = inspect(db_service.engine).get_columns(table_name)
    return any(column['name'] == column_name for column in reflected_columns)
def create_tables_if_not_exist():
    """Create every table registered on the shared declarative ``Base``.

    ``MetaData.create_all`` checks for existing tables first, so tables that
    are already present are left untouched.
    """
    # Imported only for its side effect: presumably defining the PaperOrder
    # model registers ``paper_orders`` on ``Base.metadata`` (verify in
    # app.models.paper_trading). The alias itself is never referenced.
    from app.models.paper_trading import PaperOrder as POTable  # noqa: F401
    from app.models.database import Base

    target_engine = db_service.engine
    Base.metadata.create_all(bind=target_engine)
    logger.info("✅ 数据库表已创建")
def migrate():
    """Add the ``margin`` and ``leverage`` columns to ``paper_orders``.

    Steps:

    1. If ``paper_orders`` does not exist, create all tables declared on the
       ORM metadata and re-check; give up if the table is still missing.
    2. For each of ``margin`` / ``leverage`` that is not yet present, add the
       column and backfill existing rows inside one ``engine.begin()``
       transaction (``margin = ROUND(quantity / 20.0, 2)``,
       ``leverage = 20`` where NULL).
    3. Log the resulting types of ``quantity``, ``margin`` and ``leverage``.

    Columns that already exist are skipped, so the script can be re-run.

    Returns:
        True on success; False if the table could not be created or any
        step raised (the exception is logged together with its traceback).
    """
    try:
        # Make sure the table exists before trying to alter it.
        inspector = inspect(db_service.engine)
        if 'paper_orders' not in inspector.get_table_names():
            logger.info("⚠️ paper_orders 表不存在,先创建表...")
            create_tables_if_not_exist()
            # Use a fresh Inspector: reflection results are cached per
            # Inspector instance.
            inspector = inspect(db_service.engine)
            if 'paper_orders' not in inspector.get_table_names():
                logger.error("❌ 创建表失败")
                return False

        logger.info("开始迁移 paper_orders 表...")

        # Each column is added and backfilled inside a single transaction so
        # both steps commit together instead of in two separate commits.
        # NOTE(review): this is only atomic where the driver executes DDL
        # inside the transaction (e.g. PostgreSQL). MySQL commits DDL
        # implicitly, and the stdlib sqlite3 driver by default does not open
        # a transaction before DDL; there, a failed backfill can still leave
        # the column added but unfilled, and a re-run will skip it.
        if not check_column_exists('paper_orders', 'margin'):
            with db_service.engine.begin() as conn:
                logger.info(" 添加 margin 列...")
                conn.execute(text("""
                    ALTER TABLE paper_orders
                    ADD COLUMN margin FLOAT DEFAULT 50
                """))
                logger.info("✅ margin 列添加成功")
                # Existing rows end up with the column default (50) or NULL,
                # depending on the backend; both are recomputed. Presumably
                # margin = quantity / leverage at the default leverage of 20
                # set below -- confirm that quantity is the order notional.
                logger.info("🔄 为现有记录计算 margin...")
                conn.execute(text("""
                    UPDATE paper_orders
                    SET margin = ROUND(quantity / 20.0, 2)
                    WHERE margin IS NULL OR margin = 50
                """))
            logger.info("✅ 现有记录的 margin 已更新")
        else:
            logger.info("✓ margin 列已存在,跳过")

        if not check_column_exists('paper_orders', 'leverage'):
            with db_service.engine.begin() as conn:
                logger.info(" 添加 leverage 列...")
                conn.execute(text("""
                    ALTER TABLE paper_orders
                    ADD COLUMN leverage INTEGER DEFAULT 20
                """))
                logger.info("✅ leverage 列添加成功")
                # Only NULLs need filling; rows already holding 20 are left
                # as they are.
                logger.info("🔄 为现有记录设置 leverage...")
                conn.execute(text("""
                    UPDATE paper_orders
                    SET leverage = 20
                    WHERE leverage IS NULL
                """))
            logger.info("✅ 现有记录的 leverage 已更新")
        else:
            logger.info("✓ leverage 列已存在,跳过")

        logger.info("\n" + "=" * 50)
        logger.info("✅ 迁移完成!")
        logger.info("=" * 50)

        # Show the final types of the columns this migration is about.
        inspector = inspect(db_service.engine)
        logger.info("\n📊 paper_orders 表结构:")
        for col in inspector.get_columns('paper_orders'):
            if col['name'] in ('quantity', 'margin', 'leverage'):
                logger.info(f" - {col['name']}: {col['type']}")
        return True
    except Exception as e:
        # Top-level boundary of a CLI script: report and signal failure via
        # the return value. logger.exception also records the traceback.
        logger.exception(f"❌ 迁移失败: {e}")
        return False
if __name__ == "__main__":
    # Map migrate()'s boolean result onto a process exit status for the
    # calling shell: 0 on success, 1 on failure.
    exit_code = 0 if migrate() else 1
    sys.exit(exit_code)