81 lines
2.3 KiB
Python
81 lines
2.3 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
import os
import logging
from typing import Dict, Any, List, Optional, Union
from datetime import datetime
from urllib.parse import quote

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.pool import QueuePool

from cryptoai.utils.config_loader import ConfigLoader
|
||
|
||
# Configure logging for this module.
logger = logging.getLogger('db_manager')
logger.setLevel(logging.DEBUG)


def _build_database_url(cfg):
    """Return the ``mysql+pymysql`` connection URL described by *cfg*.

    *cfg* is indexed with the keys ``user``, ``password``, ``host``,
    ``port`` and ``db_name``.

    The user name and password are percent-encoded before being placed in
    the URL. Without this, a credential containing a URL delimiter such as
    ``@``, ``:``, ``/``, ``#`` or ``%`` would produce a URL that is parsed
    incorrectly. ``safe=''`` makes ``quote`` escape ``/`` as well.

    NOTE(review): this assumes the configured credentials are stored raw.
    If a deployment already stores them percent-encoded, they would now be
    encoded twice -- confirm against the config files.
    """
    user = quote(str(cfg['user']), safe='')
    password = quote(str(cfg['password']), safe='')
    return (
        f"mysql+pymysql://{user}:{password}"
        f"@{cfg['host']}:{cfg['port']}/{cfg['db_name']}?charset=utf8mb4"
    )


config_loader = ConfigLoader()
db_config = config_loader.get_database_config()
engine = create_engine(
    _build_database_url(db_config),
    echo=False,           # set to True to echo SQL statements (debugging)
    pool_size=5,          # connection pool size
    max_overflow=10,      # maximum number of overflow connections
    pool_timeout=30,      # connection timeout
    pool_recycle=1800,    # connection recycle time (seconds)
    pool_pre_ping=True,   # ping a connection before use to make sure it is valid
    connect_args={'charset': 'utf8mb4'},
)

# Create a thread-safe session factory.
SessionLocal = scoped_session(
    sessionmaker(autocommit=False, autoflush=False, bind=engine)
)
|
||
|
||
def init_db():
    """Create the tables known to ``Base.metadata`` on the module engine.

    ``Base`` and the model classes are imported inside the function to
    avoid circular imports. The model names are not used directly;
    presumably importing them is what registers their tables on
    ``Base.metadata`` before ``create_all`` runs (confirm against
    ``cryptoai.models``).

    ``create_all`` is called with ``checkfirst=True``. Afterwards every
    table name present in ``Base.metadata`` is logged; that list is the
    full metadata, not only the tables created by this call.

    Raises:
        Exception: any error raised while importing the models or creating
            the tables is logged together with its traceback and then
            re-raised unchanged.
    """
    try:
        # Import Base and all models here (avoids circular imports).
        from cryptoai.models.base import Base
        from cryptoai.models import (  # noqa: F401
            User, Token, AnalysisResult, UserQuestion,
            AStock, AnalysisHistory, SubscriptionOrder, UserSubscription
        )

        Base.metadata.create_all(bind=engine, checkfirst=True)
        logger.info("数据库初始化成功")

        # Log the list of tables registered on the metadata.
        tables = list(Base.metadata.tables.keys())
        logger.info(f"已创建的数据表: {tables}")

    except Exception as e:
        # logger.exception logs at ERROR level like logger.error did, but
        # also records the traceback.
        logger.exception(f"初始化数据库失败: {e}")
        # Bare raise re-raises the active exception as-is.
        raise
|
||
|
||
def get_db():
    """Yield a session obtained from ``SessionLocal`` and close it afterwards.

    This is a generator function: it yields exactly one session and, when
    the generator is resumed or closed, closes that session in a
    ``finally`` clause. It never commits or rolls back explicitly.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        # Same truthiness guard as before; close only a truthy session.
        if session:
            session.close()
|
||
|
||
def get_db_context():
    """Yield a session, committing on success and rolling back on error.

    This is a plain generator function (no ``contextlib.contextmanager``
    decorator is applied here), so the caller is responsible for driving
    it. It yields exactly one session obtained from ``SessionLocal``.

    * If the consumer resumes the generator normally, the session is
      committed.
    * If an ``Exception`` is thrown into the generator, or the commit
      fails, the session is rolled back and the exception is re-raised.
    * In every case the session is closed afterwards.

    Raises:
        Exception: whatever ``SessionLocal()`` raises (propagated
            untouched), or the error raised by the consumer or by the
            commit, re-raised after the rollback.
    """
    # Acquire the session before entering the try block. If creation fails
    # there is nothing to roll back or close, and the original error must
    # propagate instead of an UnboundLocalError from the cleanup code.
    db = SessionLocal()
    try:
        yield db
        db.commit()
    except Exception:
        db.rollback()
        # Bare raise re-raises the active exception as-is.
        raise
    finally:
        db.close()
|
||
|
||
|
||
|
||
|
||
|