deliveryman-api/app/models/database.py
2025-03-10 21:15:01 +08:00

137 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from sqlalchemy import create_engine, event
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session
from app.core.config import settings
import pymysql
import logging
import time
import threading
from contextlib import contextmanager
# 设置日志记录器
logger = logging.getLogger(__name__)
# 注册 MySQL Python SQL Driver
pymysql.install_as_MySQLdb()
# 创建数据库引擎
engine = create_engine(
settings.SQLALCHEMY_DATABASE_URL,
pool_pre_ping=True, # 自动处理断开的连接
pool_recycle=3600, # 连接回收时间1小时
pool_size=10, # 连接池大小
max_overflow=20, # 允许的最大连接数超出pool_size的数量
pool_timeout=30, # 获取连接的超时时间(秒)
echo=settings.DEBUG, # 在调试模式下打印SQL语句
)
# 创建线程安全的会话工厂
SessionLocal = scoped_session(
sessionmaker(autocommit=False, autoflush=False, bind=engine)
)
# 声明基类
Base = declarative_base()
# 会话跟踪
active_sessions = {}
session_lock = threading.Lock()
# 导入监控工具(延迟导入以避免循环依赖)
def get_db_monitor():
from app.core.db_monitor import DBConnectionMonitor
return DBConnectionMonitor
# 添加事件监听器,记录长时间运行的查询
@event.listens_for(engine, "before_cursor_execute")
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
"""在执行SQL查询前记录时间"""
conn.info.setdefault('query_start_time', []).append(time.time())
conn.info.setdefault('query_statement', []).append(statement)
conn.info.setdefault('query_parameters', []).append(parameters)
if settings.DEBUG:
# 安全地记录SQL语句避免敏感信息泄露
safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白
logger.debug(f"开始执行查询: {safe_statement}...")
@event.listens_for(engine, "after_cursor_execute")
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
"""在执行SQL查询后计算耗时并记录慢查询"""
total = time.time() - conn.info['query_start_time'].pop(-1)
statement = conn.info['query_statement'].pop(-1)
parameters = conn.info['query_parameters'].pop(-1)
# 记录慢查询
if total > 0.5: # 记录超过0.5秒的查询
# 安全地记录SQL语句避免敏感信息泄露
safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白
logger.warning(f"慢查询 ({total:.2f}s): {safe_statement}...")
# 记录到监控系统
try:
monitor = get_db_monitor()
monitor.record_slow_query(total, statement, parameters)
except Exception as e:
logger.error(f"记录慢查询失败: {str(e)}")
# 依赖项
def get_db():
"""获取数据库会话"""
session_id = threading.get_ident()
session = SessionLocal()
# 记录活跃会话
with session_lock:
active_sessions[session_id] = {
"created_at": time.time(),
"thread_id": session_id
}
try:
yield session
finally:
# 关闭会话并从活跃会话中移除
session.close()
with session_lock:
if session_id in active_sessions:
del active_sessions[session_id]
@contextmanager
def get_db_context():
"""上下文管理器版本的get_db用于非依赖项场景"""
session = SessionLocal()
session_id = threading.get_ident()
# 记录活跃会话
with session_lock:
active_sessions[session_id] = {
"created_at": time.time(),
"thread_id": session_id
}
try:
yield session
session.commit()
except Exception as e:
session.rollback()
raise e
finally:
session.close()
with session_lock:
if session_id in active_sessions:
del active_sessions[session_id]
def get_active_sessions_count():
"""获取当前活跃会话数量"""
with session_lock:
return len(active_sessions)
def get_long_running_sessions(threshold_seconds=60):
"""获取长时间运行的会话"""
current_time = time.time()
with session_lock:
return {
thread_id: info for thread_id, info in active_sessions.items()
if current_time - info["created_at"] > threshold_seconds
}