137 lines
4.5 KiB
Python
137 lines
4.5 KiB
Python
from sqlalchemy import create_engine, event
|
||
from sqlalchemy.ext.declarative import declarative_base
|
||
from sqlalchemy.orm import sessionmaker, scoped_session
|
||
from app.core.config import settings
|
||
import pymysql
|
||
import logging
|
||
import time
|
||
import threading
|
||
from contextlib import contextmanager
|
||
|
||
# 设置日志记录器
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# 注册 MySQL Python SQL Driver
|
||
pymysql.install_as_MySQLdb()
|
||
|
||
# 创建数据库引擎
|
||
engine = create_engine(
|
||
settings.SQLALCHEMY_DATABASE_URL,
|
||
pool_pre_ping=True, # 自动处理断开的连接
|
||
pool_recycle=3600, # 连接回收时间(1小时)
|
||
pool_size=10, # 连接池大小
|
||
max_overflow=20, # 允许的最大连接数超出pool_size的数量
|
||
pool_timeout=30, # 获取连接的超时时间(秒)
|
||
echo=settings.DEBUG, # 在调试模式下打印SQL语句
|
||
)
|
||
|
||
# 创建线程安全的会话工厂
|
||
SessionLocal = scoped_session(
|
||
sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||
)
|
||
|
||
# 声明基类
|
||
Base = declarative_base()
|
||
|
||
# 会话跟踪
|
||
active_sessions = {}
|
||
session_lock = threading.Lock()
|
||
|
||
# 导入监控工具(延迟导入以避免循环依赖)
|
||
def get_db_monitor():
|
||
from app.core.db_monitor import DBConnectionMonitor
|
||
return DBConnectionMonitor
|
||
|
||
# 添加事件监听器,记录长时间运行的查询
|
||
@event.listens_for(engine, "before_cursor_execute")
|
||
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||
"""在执行SQL查询前记录时间"""
|
||
conn.info.setdefault('query_start_time', []).append(time.time())
|
||
conn.info.setdefault('query_statement', []).append(statement)
|
||
conn.info.setdefault('query_parameters', []).append(parameters)
|
||
|
||
if settings.DEBUG:
|
||
# 安全地记录SQL语句,避免敏感信息泄露
|
||
safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白
|
||
logger.debug(f"开始执行查询: {safe_statement}...")
|
||
|
||
@event.listens_for(engine, "after_cursor_execute")
|
||
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||
"""在执行SQL查询后计算耗时并记录慢查询"""
|
||
total = time.time() - conn.info['query_start_time'].pop(-1)
|
||
statement = conn.info['query_statement'].pop(-1)
|
||
parameters = conn.info['query_parameters'].pop(-1)
|
||
|
||
# 记录慢查询
|
||
if total > 0.5: # 记录超过0.5秒的查询
|
||
# 安全地记录SQL语句,避免敏感信息泄露
|
||
safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白
|
||
logger.warning(f"慢查询 ({total:.2f}s): {safe_statement}...")
|
||
|
||
# 记录到监控系统
|
||
try:
|
||
monitor = get_db_monitor()
|
||
monitor.record_slow_query(total, statement, parameters)
|
||
except Exception as e:
|
||
logger.error(f"记录慢查询失败: {str(e)}")
|
||
|
||
# 依赖项
|
||
def get_db():
|
||
"""获取数据库会话"""
|
||
session_id = threading.get_ident()
|
||
session = SessionLocal()
|
||
|
||
# 记录活跃会话
|
||
with session_lock:
|
||
active_sessions[session_id] = {
|
||
"created_at": time.time(),
|
||
"thread_id": session_id
|
||
}
|
||
|
||
try:
|
||
yield session
|
||
finally:
|
||
# 关闭会话并从活跃会话中移除
|
||
session.close()
|
||
with session_lock:
|
||
if session_id in active_sessions:
|
||
del active_sessions[session_id]
|
||
|
||
@contextmanager
|
||
def get_db_context():
|
||
"""上下文管理器版本的get_db,用于非依赖项场景"""
|
||
session = SessionLocal()
|
||
session_id = threading.get_ident()
|
||
|
||
# 记录活跃会话
|
||
with session_lock:
|
||
active_sessions[session_id] = {
|
||
"created_at": time.time(),
|
||
"thread_id": session_id
|
||
}
|
||
|
||
try:
|
||
yield session
|
||
session.commit()
|
||
except Exception as e:
|
||
session.rollback()
|
||
raise e
|
||
finally:
|
||
session.close()
|
||
with session_lock:
|
||
if session_id in active_sessions:
|
||
del active_sessions[session_id]
|
||
|
||
def get_active_sessions_count():
|
||
"""获取当前活跃会话数量"""
|
||
with session_lock:
|
||
return len(active_sessions)
|
||
|
||
def get_long_running_sessions(threshold_seconds=60):
|
||
"""获取长时间运行的会话"""
|
||
current_time = time.time()
|
||
with session_lock:
|
||
return {
|
||
thread_id: info for thread_id, info in active_sessions.items()
|
||
if current_time - info["created_at"] > threshold_seconds
|
||
} |