222 lines
7.6 KiB
Python
222 lines
7.6 KiB
Python
from sqlalchemy import create_engine, event, text
|
||
from sqlalchemy.ext.declarative import declarative_base
|
||
from sqlalchemy.orm import sessionmaker, scoped_session
|
||
from app.core.config import settings
|
||
import pymysql
|
||
import logging
|
||
import time
|
||
import threading
|
||
from contextlib import contextmanager
|
||
|
||
# 设置日志记录器
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# 注册 MySQL Python SQL Driver
|
||
pymysql.install_as_MySQLdb()
|
||
|
||
# 创建数据库引擎
|
||
engine = create_engine(
|
||
settings.SQLALCHEMY_DATABASE_URL,
|
||
pool_pre_ping=True, # 自动处理断开的连接
|
||
pool_recycle=1800, # 连接回收时间(30分钟)
|
||
pool_size=10, # 连接池大小
|
||
max_overflow=20, # 允许的最大连接数超出pool_size的数量
|
||
pool_timeout=30, # 获取连接的超时时间(秒)
|
||
echo_pool=settings.DEBUG, # 在调试模式下记录连接池事件
|
||
# 添加以下参数以增强连接稳定性
|
||
pool_reset_on_return='commit', # 在连接返回池时重置连接状态
|
||
isolation_level='READ COMMITTED', # 设置事务隔离级别
|
||
connect_args={
|
||
'connect_timeout': 10, # 连接超时时间(秒)
|
||
'read_timeout': 30, # 读取超时时间(秒)
|
||
'write_timeout': 30 # 写入超时时间(秒)
|
||
}
|
||
)
|
||
|
||
# 创建线程安全的会话工厂
|
||
SessionLocal = scoped_session(
|
||
sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||
)
|
||
|
||
# 声明基类
|
||
Base = declarative_base()
|
||
|
||
# 会话跟踪
|
||
active_sessions = {}
|
||
session_lock = threading.Lock()
|
||
|
||
# 导入监控工具(延迟导入以避免循环依赖)
|
||
def get_db_monitor():
|
||
from app.core.db_monitor import DBConnectionMonitor
|
||
return DBConnectionMonitor
|
||
|
||
# 添加连接池事件监听器
|
||
@event.listens_for(engine, "connect")
|
||
def connect(dbapi_connection, connection_record):
|
||
"""连接创建时的回调"""
|
||
logger.debug("数据库连接已创建")
|
||
|
||
@event.listens_for(engine, "checkout")
|
||
def checkout(dbapi_connection, connection_record, connection_proxy):
|
||
"""连接从池中取出时的回调"""
|
||
# 记录连接被取出的时间
|
||
connection_record.info['checkout_time'] = time.time()
|
||
|
||
# 验证连接是否有效
|
||
try:
|
||
# 执行简单查询测试连接
|
||
cursor = dbapi_connection.cursor()
|
||
cursor.execute("SELECT 1")
|
||
cursor.fetchone() # 确保实际获取结果
|
||
cursor.close()
|
||
except Exception as e:
|
||
# 如果连接无效,断开它并尝试重新连接
|
||
logger.warning(f"检测到无效的数据库连接: {str(e)}")
|
||
# 标记连接为无效
|
||
connection_record.invalidate()
|
||
# 强制处理连接池
|
||
connection_proxy._pool.dispose()
|
||
raise
|
||
|
||
@event.listens_for(engine, "checkin")
|
||
def checkin(dbapi_connection, connection_record):
|
||
"""连接归还到池中时的回调"""
|
||
checkout_time = connection_record.info.get('checkout_time')
|
||
if checkout_time:
|
||
# 计算连接被使用的时间
|
||
used_time = time.time() - checkout_time
|
||
if used_time > 10: # 记录使用时间超过10秒的连接
|
||
logger.warning(f"数据库连接使用时间较长: {used_time:.2f}秒")
|
||
|
||
# 添加事件监听器,记录长时间运行的查询
|
||
@event.listens_for(engine, "before_cursor_execute")
|
||
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||
"""在执行SQL查询前记录时间"""
|
||
conn.info.setdefault('query_start_time', []).append(time.time())
|
||
conn.info.setdefault('query_statement', []).append(statement)
|
||
conn.info.setdefault('query_parameters', []).append(parameters)
|
||
|
||
if settings.DEBUG:
|
||
# 安全地记录SQL语句,避免敏感信息泄露
|
||
safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白
|
||
logger.debug(f"开始执行查询: {safe_statement}...")
|
||
|
||
@event.listens_for(engine, "after_cursor_execute")
|
||
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||
"""在执行SQL查询后计算耗时并记录慢查询"""
|
||
total = time.time() - conn.info['query_start_time'].pop(-1)
|
||
statement = conn.info['query_statement'].pop(-1)
|
||
parameters = conn.info['query_parameters'].pop(-1)
|
||
|
||
# 记录慢查询
|
||
if total > 0.5: # 记录超过0.5秒的查询
|
||
# 安全地记录SQL语句,避免敏感信息泄露
|
||
safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白
|
||
logger.warning(f"慢查询 ({total:.2f}s): {safe_statement}...")
|
||
|
||
# 记录到监控系统
|
||
try:
|
||
monitor = get_db_monitor()
|
||
monitor.record_slow_query(total, statement, parameters)
|
||
except Exception as e:
|
||
logger.error(f"记录慢查询失败: {str(e)}")
|
||
|
||
# 依赖项
|
||
def get_db():
|
||
"""获取数据库会话"""
|
||
session_id = threading.get_ident()
|
||
db = None
|
||
retry_count = 0
|
||
max_retries = 3
|
||
|
||
while retry_count < max_retries:
|
||
try:
|
||
db = SessionLocal()
|
||
# 测试连接是否有效
|
||
db.execute(text("SELECT 1"))
|
||
break
|
||
except Exception as e:
|
||
retry_count += 1
|
||
if db:
|
||
db.close()
|
||
if retry_count >= max_retries:
|
||
logger.error(f"无法建立数据库连接,已重试 {retry_count} 次: {str(e)}")
|
||
raise
|
||
logger.warning(f"数据库连接失败,正在重试 ({retry_count}/{max_retries}): {str(e)}")
|
||
time.sleep(0.5) # 短暂延迟后重试
|
||
|
||
# 记录活跃会话
|
||
with session_lock:
|
||
active_sessions[session_id] = {
|
||
"created_at": time.time(),
|
||
"thread_id": session_id
|
||
}
|
||
|
||
try:
|
||
yield db
|
||
finally:
|
||
# 关闭会话并从活跃会话中移除
|
||
if db:
|
||
db.close()
|
||
with session_lock:
|
||
if session_id in active_sessions:
|
||
del active_sessions[session_id]
|
||
|
||
@contextmanager
|
||
def get_db_context():
|
||
"""上下文管理器版本的get_db,用于非依赖项场景"""
|
||
session = None
|
||
retry_count = 0
|
||
max_retries = 3
|
||
session_id = threading.get_ident()
|
||
|
||
while retry_count < max_retries:
|
||
try:
|
||
session = SessionLocal()
|
||
# 测试连接是否有效
|
||
session.execute(text("SELECT 1"))
|
||
break
|
||
except Exception as e:
|
||
retry_count += 1
|
||
if session:
|
||
session.close()
|
||
if retry_count >= max_retries:
|
||
logger.error(f"无法建立数据库连接,已重试 {retry_count} 次: {str(e)}")
|
||
raise
|
||
logger.warning(f"数据库连接失败,正在重试 ({retry_count}/{max_retries}): {str(e)}")
|
||
time.sleep(0.5) # 短暂延迟后重试
|
||
|
||
# 记录活跃会话
|
||
with session_lock:
|
||
active_sessions[session_id] = {
|
||
"created_at": time.time(),
|
||
"thread_id": session_id
|
||
}
|
||
|
||
try:
|
||
yield session
|
||
session.commit()
|
||
except Exception as e:
|
||
if session:
|
||
session.rollback()
|
||
raise e
|
||
finally:
|
||
if session:
|
||
session.close()
|
||
with session_lock:
|
||
if session_id in active_sessions:
|
||
del active_sessions[session_id]
|
||
|
||
def get_active_sessions_count():
|
||
"""获取当前活跃会话数量"""
|
||
with session_lock:
|
||
return len(active_sessions)
|
||
|
||
def get_long_running_sessions(threshold_seconds=60):
|
||
"""获取长时间运行的会话"""
|
||
current_time = time.time()
|
||
with session_lock:
|
||
return {
|
||
thread_id: info for thread_id, info in active_sessions.items()
|
||
if current_time - info["created_at"] > threshold_seconds
|
||
} |