from sqlalchemy import create_engine, event, text from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker, scoped_session from app.core.config import settings import pymysql import logging import time import threading from contextlib import contextmanager # 设置日志记录器 logger = logging.getLogger(__name__) # 注册 MySQL Python SQL Driver pymysql.install_as_MySQLdb() # 创建数据库引擎 engine = create_engine( settings.SQLALCHEMY_DATABASE_URL, pool_pre_ping=True, # 自动处理断开的连接 pool_recycle=1800, # 连接回收时间(30分钟) pool_size=10, # 连接池大小 max_overflow=20, # 允许的最大连接数超出pool_size的数量 pool_timeout=30, # 获取连接的超时时间(秒) echo_pool=settings.DEBUG, # 在调试模式下记录连接池事件 # 添加以下参数以增强连接稳定性 pool_reset_on_return='commit', # 在连接返回池时重置连接状态 isolation_level='READ COMMITTED', # 设置事务隔离级别 connect_args={ 'connect_timeout': 10, # 连接超时时间(秒) 'read_timeout': 30, # 读取超时时间(秒) 'write_timeout': 30 # 写入超时时间(秒) } ) # 创建线程安全的会话工厂 SessionLocal = scoped_session( sessionmaker(autocommit=False, autoflush=False, bind=engine) ) # 声明基类 Base = declarative_base() # 会话跟踪 active_sessions = {} session_lock = threading.Lock() # 导入监控工具(延迟导入以避免循环依赖) def get_db_monitor(): from app.core.db_monitor import DBConnectionMonitor return DBConnectionMonitor # 添加连接池事件监听器 @event.listens_for(engine, "connect") def connect(dbapi_connection, connection_record): """连接创建时的回调""" logger.debug("数据库连接已创建") @event.listens_for(engine, "checkout") def checkout(dbapi_connection, connection_record, connection_proxy): """连接从池中取出时的回调""" # 记录连接被取出的时间 connection_record.info['checkout_time'] = time.time() # 验证连接是否有效 try: # 执行简单查询测试连接 cursor = dbapi_connection.cursor() cursor.execute("SELECT 1") cursor.fetchone() # 确保实际获取结果 cursor.close() except Exception as e: # 如果连接无效,断开它并尝试重新连接 logger.warning(f"检测到无效的数据库连接: {str(e)}") # 标记连接为无效 connection_record.invalidate() # 强制处理连接池 connection_proxy._pool.dispose() raise @event.listens_for(engine, "checkin") def checkin(dbapi_connection, connection_record): """连接归还到池中时的回调""" checkout_time = connection_record.info.get('checkout_time') if checkout_time: # 计算连接被使用的时间 used_time = time.time() - checkout_time if used_time > 10: # 记录使用时间超过10秒的连接 logger.warning(f"数据库连接使用时间较长: {used_time:.2f}秒") # 添加事件监听器,记录长时间运行的查询 @event.listens_for(engine, "before_cursor_execute") def before_cursor_execute(conn, cursor, statement, parameters, context, executemany): """在执行SQL查询前记录时间""" conn.info.setdefault('query_start_time', []).append(time.time()) conn.info.setdefault('query_statement', []).append(statement) conn.info.setdefault('query_parameters', []).append(parameters) if settings.DEBUG: # 安全地记录SQL语句,避免敏感信息泄露 safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白 logger.debug(f"开始执行查询: {safe_statement}...") @event.listens_for(engine, "after_cursor_execute") def after_cursor_execute(conn, cursor, statement, parameters, context, executemany): """在执行SQL查询后计算耗时并记录慢查询""" total = time.time() - conn.info['query_start_time'].pop(-1) statement = conn.info['query_statement'].pop(-1) parameters = conn.info['query_parameters'].pop(-1) # 记录慢查询 if total > 0.5: # 记录超过0.5秒的查询 # 安全地记录SQL语句,避免敏感信息泄露 safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白 logger.warning(f"慢查询 ({total:.2f}s): {safe_statement}...") # 记录到监控系统 try: monitor = get_db_monitor() monitor.record_slow_query(total, statement, parameters) except Exception as e: logger.error(f"记录慢查询失败: {str(e)}") # 依赖项 def get_db(): """获取数据库会话""" session_id = threading.get_ident() db = None retry_count = 0 max_retries = 3 while retry_count < max_retries: try: db = SessionLocal() # 测试连接是否有效 db.execute(text("SELECT 1")) break except Exception as e: retry_count += 1 if db: db.close() if retry_count >= max_retries: logger.error(f"无法建立数据库连接,已重试 {retry_count} 次: {str(e)}") raise logger.warning(f"数据库连接失败,正在重试 ({retry_count}/{max_retries}): {str(e)}") time.sleep(0.5) # 短暂延迟后重试 # 记录活跃会话 with session_lock: active_sessions[session_id] = { "created_at": time.time(), "thread_id": session_id } try: yield db finally: # 关闭会话并从活跃会话中移除 if db: db.close() with session_lock: if session_id in active_sessions: del active_sessions[session_id] @contextmanager def get_db_context(): """上下文管理器版本的get_db,用于非依赖项场景""" session = None retry_count = 0 max_retries = 3 session_id = threading.get_ident() while retry_count < max_retries: try: session = SessionLocal() # 测试连接是否有效 session.execute(text("SELECT 1")) break except Exception as e: retry_count += 1 if session: session.close() if retry_count >= max_retries: logger.error(f"无法建立数据库连接,已重试 {retry_count} 次: {str(e)}") raise logger.warning(f"数据库连接失败,正在重试 ({retry_count}/{max_retries}): {str(e)}") time.sleep(0.5) # 短暂延迟后重试 # 记录活跃会话 with session_lock: active_sessions[session_id] = { "created_at": time.time(), "thread_id": session_id } try: yield session session.commit() except Exception as e: if session: session.rollback() raise e finally: if session: session.close() with session_lock: if session_id in active_sessions: del active_sessions[session_id] def get_active_sessions_count(): """获取当前活跃会话数量""" with session_lock: return len(active_sessions) def get_long_running_sessions(threshold_seconds=60): """获取长时间运行的会话""" current_time = time.time() with session_lock: return { thread_id: info for thread_id, info in active_sessions.items() if current_time - info["created_at"] > threshold_seconds }