from sqlalchemy import create_engine, event from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker, scoped_session from app.core.config import settings import pymysql import logging import time import threading from contextlib import contextmanager # 设置日志记录器 logger = logging.getLogger(__name__) # 注册 MySQL Python SQL Driver pymysql.install_as_MySQLdb() # 创建数据库引擎 engine = create_engine( settings.SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, # 自动处理断开的连接 pool_recycle=3600, # 连接回收时间(1小时) pool_size=10, # 连接池大小 max_overflow=20, # 允许的最大连接数超出pool_size的数量 pool_timeout=30, # 获取连接的超时时间(秒) echo=settings.DEBUG, # 在调试模式下打印SQL语句 ) # 创建线程安全的会话工厂 SessionLocal = scoped_session( sessionmaker(autocommit=False, autoflush=False, bind=engine) ) # 声明基类 Base = declarative_base() # 会话跟踪 active_sessions = {} session_lock = threading.Lock() # 添加事件监听器,记录长时间运行的查询 @event.listens_for(engine, "before_cursor_execute") def before_cursor_execute(conn, cursor, statement, parameters, context, executemany): """在执行SQL查询前记录时间""" conn.info.setdefault('query_start_time', []).append(time.time()) if settings.DEBUG: # 安全地记录SQL语句,避免敏感信息泄露 safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白 logger.debug(f"开始执行查询: {safe_statement}...") @event.listens_for(engine, "after_cursor_execute") def after_cursor_execute(conn, cursor, statement, parameters, context, executemany): """在执行SQL查询后计算耗时并记录慢查询""" total = time.time() - conn.info['query_start_time'].pop(-1) # 记录慢查询 if total > 0.5: # 记录超过0.5秒的查询 # 安全地记录SQL语句,避免敏感信息泄露 safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白 logger.warning(f"慢查询 ({total:.2f}s): {safe_statement}...") # 依赖项 def get_db(): """获取数据库会话""" session_id = threading.get_ident() session = SessionLocal() # 记录活跃会话 with session_lock: active_sessions[session_id] = { "created_at": time.time(), "thread_id": session_id } try: yield session finally: # 关闭会话并从活跃会话中移除 session.close() with session_lock: if session_id in active_sessions: del active_sessions[session_id] @contextmanager def get_db_context(): """上下文管理器版本的get_db,用于非依赖项场景""" session = SessionLocal() session_id = threading.get_ident() # 记录活跃会话 with session_lock: active_sessions[session_id] = { "created_at": time.time(), "thread_id": session_id } try: yield session session.commit() except Exception as e: session.rollback() raise e finally: session.close() with session_lock: if session_id in active_sessions: del active_sessions[session_id] def get_active_sessions_count(): """获取当前活跃会话数量""" with session_lock: return len(active_sessions) def get_long_running_sessions(threshold_seconds=60): """获取长时间运行的会话""" current_time = time.time() with session_lock: return { thread_id: info for thread_id, info in active_sessions.items() if current_time - info["created_at"] > threshold_seconds }