deliveryman-api/app/models/database.py

from sqlalchemy import create_engine, event, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session
from app.core.config import settings
import pymysql
import logging
import time
import threading
from contextlib import contextmanager

# Set up the module logger
logger = logging.getLogger(__name__)

# Register PyMySQL as the MySQL driver (MySQLdb-compatible)
pymysql.install_as_MySQLdb()

# Create the database engine
engine = create_engine(
    settings.SQLALCHEMY_DATABASE_URL,
    pool_pre_ping=True,                # transparently handle dropped connections
    pool_recycle=1800,                 # recycle connections after 30 minutes
    pool_size=10,                      # base connection pool size
    max_overflow=20,                   # extra connections allowed beyond pool_size
    pool_timeout=30,                   # seconds to wait when acquiring a connection
    echo_pool=settings.DEBUG,          # log pool events in debug mode
    # Additional settings to improve connection stability
    pool_reset_on_return='commit',     # reset connection state when returned to the pool
    isolation_level='READ COMMITTED',  # transaction isolation level
    connect_args={
        'connect_timeout': 10,         # connection timeout (seconds)
        'read_timeout': 30,            # read timeout (seconds)
        'write_timeout': 30            # write timeout (seconds)
    }
)
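
# Note: 'read_timeout' and 'write_timeout' in connect_args are PyMySQL driver
# options, so SQLALCHEMY_DATABASE_URL is assumed to resolve to the PyMySQL
# driver, either via an explicit driver in the URL or via
# pymysql.install_as_MySQLdb() above. An illustrative placeholder URL (not the
# project's real connection string):
#
#   mysql+pymysql://user:password@127.0.0.1:3306/deliveryman?charset=utf8mb4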

# Thread-safe session factory
SessionLocal = scoped_session(
    sessionmaker(autocommit=False, autoflush=False, bind=engine)
)

# Declarative base class for ORM models
Base = declarative_base()

# Active session tracking
active_sessions = {}
session_lock = threading.Lock()

# Import the monitoring helper lazily to avoid a circular import
def get_db_monitor():
    from app.core.db_monitor import DBConnectionMonitor
    return DBConnectionMonitor

# Connection pool event listeners
@event.listens_for(engine, "connect")
def connect(dbapi_connection, connection_record):
    """Callback invoked when a new connection is created."""
    logger.debug("Database connection created")

@event.listens_for(engine, "checkout")
def checkout(dbapi_connection, connection_record, connection_proxy):
    """Callback invoked when a connection is checked out of the pool."""
    # Record when the connection was checked out
    connection_record.info['checkout_time'] = time.time()
    # Verify that the connection is still valid
    try:
        # Run a trivial query to test the connection
        cursor = dbapi_connection.cursor()
        cursor.execute("SELECT 1")
        cursor.fetchone()  # make sure the result is actually fetched
        cursor.close()
    except Exception as e:
        # Invalid connection: invalidate it and force a reconnect
        logger.warning(f"Invalid database connection detected: {str(e)}")
        # Mark the connection record as invalid
        connection_record.invalidate()
        # Forcibly dispose of the connection pool
        connection_proxy._pool.dispose()
        raise
@event.listens_for(engine, "checkin")
def checkin(dbapi_connection, connection_record):
"""连接归还到池中时的回调"""
checkout_time = connection_record.info.get('checkout_time')
if checkout_time:
# 计算连接被使用的时间
used_time = time.time() - checkout_time
if used_time > 10: # 记录使用时间超过10秒的连接
logger.warning(f"数据库连接使用时间较长: {used_time:.2f}")

# Event listeners for logging long-running queries
@event.listens_for(engine, "before_cursor_execute")
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    """Record the start time before a SQL statement is executed."""
    conn.info.setdefault('query_start_time', []).append(time.time())
    conn.info.setdefault('query_statement', []).append(statement)
    conn.info.setdefault('query_parameters', []).append(parameters)
    if settings.DEBUG:
        # Log the SQL safely: collapse whitespace and truncate to avoid leaking sensitive data
        safe_statement = ' '.join(statement.split())[:200]
        logger.debug(f"Executing query: {safe_statement}...")
@event.listens_for(engine, "after_cursor_execute")
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
"""在执行SQL查询后计算耗时并记录慢查询"""
total = time.time() - conn.info['query_start_time'].pop(-1)
statement = conn.info['query_statement'].pop(-1)
parameters = conn.info['query_parameters'].pop(-1)
# 记录慢查询
if total > 0.5: # 记录超过0.5秒的查询
# 安全地记录SQL语句避免敏感信息泄露
safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白
logger.warning(f"慢查询 ({total:.2f}s): {safe_statement}...")
# 记录到监控系统
try:
monitor = get_db_monitor()
monitor.record_slow_query(total, statement, parameters)
except Exception as e:
logger.error(f"记录慢查询失败: {str(e)}")

# Dependency
def get_db():
    """Yield a database session."""
    session_id = threading.get_ident()
    db = None
    retry_count = 0
    max_retries = 3
    while retry_count < max_retries:
        try:
            db = SessionLocal()
            # Verify the connection is usable
            db.execute(text("SELECT 1"))
            break
        except Exception as e:
            retry_count += 1
            if db:
                db.close()
            if retry_count >= max_retries:
                logger.error(f"Unable to establish a database connection after {retry_count} retries: {str(e)}")
                raise
            logger.warning(f"Database connection failed, retrying ({retry_count}/{max_retries}): {str(e)}")
            time.sleep(0.5)  # brief delay before retrying
    # Track the active session
    with session_lock:
        active_sessions[session_id] = {
            "created_at": time.time(),
            "thread_id": session_id
        }
    try:
        yield db
    finally:
        # Close the session and remove it from the active-session map
        if db:
            db.close()
        with session_lock:
            if session_id in active_sessions:
                del active_sessions[session_id]
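
# Minimal usage sketch for get_db, assuming it is consumed as a FastAPI-style
# generator dependency (the yield-based shape suggests this); `router` and
# `Order` below are illustrative placeholders, not names defined in this project:
#
#   from fastapi import Depends
#   from sqlalchemy.orm import Session
#
#   @router.get("/orders")
#   def list_orders(db: Session = Depends(get_db)):
#       return db.query(Order).limit(20).all()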

@contextmanager
def get_db_context():
    """Context-manager version of get_db, for use outside dependency injection."""
    session = None
    retry_count = 0
    max_retries = 3
    session_id = threading.get_ident()
    while retry_count < max_retries:
        try:
            session = SessionLocal()
            # Verify the connection is usable
            session.execute(text("SELECT 1"))
            break
        except Exception as e:
            retry_count += 1
            if session:
                session.close()
            if retry_count >= max_retries:
                logger.error(f"Unable to establish a database connection after {retry_count} retries: {str(e)}")
                raise
            logger.warning(f"Database connection failed, retrying ({retry_count}/{max_retries}): {str(e)}")
            time.sleep(0.5)  # brief delay before retrying
    # Track the active session
    with session_lock:
        active_sessions[session_id] = {
            "created_at": time.time(),
            "thread_id": session_id
        }
    try:
        yield session
        session.commit()
    except Exception:
        if session:
            session.rollback()
        raise
    finally:
        if session:
            session.close()
        with session_lock:
            if session_id in active_sessions:
                del active_sessions[session_id]
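
# Minimal usage sketch for get_db_context, for code paths outside dependency
# injection (background jobs, scripts). The context manager commits on success
# and rolls back on any exception; the SQL below is only an illustration:
#
#   with get_db_context() as db:
#       db.execute(text("SELECT 1"))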

def get_active_sessions_count():
    """Return the number of currently active sessions."""
    with session_lock:
        return len(active_sessions)

def get_long_running_sessions(threshold_seconds=60):
    """Return sessions that have been active longer than the given threshold."""
    current_time = time.time()
    with session_lock:
        return {
            thread_id: info for thread_id, info in active_sessions.items()
            if current_time - info["created_at"] > threshold_seconds
        }
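
# Sketch of how the session and pool metrics exposed by this module could be
# collected in one place (e.g. for a health/metrics endpoint; the endpoint
# itself is an assumption, not part of this module). engine.pool.size() and
# engine.pool.checkedout() are standard SQLAlchemy QueuePool methods:
#
#   def collect_db_stats():
#       return {
#           "active_sessions": get_active_sessions_count(),
#           "long_running_sessions": len(get_long_running_sessions(threshold_seconds=60)),
#           "pool_size": engine.pool.size(),
#           "pool_checked_out": engine.pool.checkedout(),
#       }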