import datetime
import logging
import threading
import time
from collections import defaultdict, deque

from fastapi import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

from app.models.database import engine

logger = logging.getLogger(__name__)


class DBStats:
    """Database statistics."""

    def __init__(self):
        self.slow_queries = deque(maxlen=100)  # keep at most 100 slow-query records
        self.endpoint_stats = defaultdict(lambda: {"count": 0, "total_time": 0, "max_time": 0})
        self.hourly_requests = defaultdict(int)
        self.lock = threading.Lock()

    def record_slow_query(self, method, path, duration):
        """Record a slow query."""
        with self.lock:
            self.slow_queries.append({
                "method": method,
                "path": path,
                "duration": duration,
                "timestamp": datetime.datetime.now().isoformat()
            })

    def record_request(self, method, path, duration):
        """Record per-request statistics."""
        with self.lock:
            key = f"{method} {path}"
            self.endpoint_stats[key]["count"] += 1
            self.endpoint_stats[key]["total_time"] += duration
            self.endpoint_stats[key]["max_time"] = max(self.endpoint_stats[key]["max_time"], duration)
            # Track requests per hour
            hour = datetime.datetime.now().strftime("%Y-%m-%d %H:00")
            self.hourly_requests[hour] += 1

    def get_stats(self):
        """Return collected statistics."""
        with self.lock:
            # Compute average response time per endpoint
            avg_times = {}
            for endpoint, stats in self.endpoint_stats.items():
                if stats["count"] > 0:
                    avg_times[endpoint] = stats["total_time"] / stats["count"]
            # Top 10 endpoints by average response time
            sorted_endpoints = sorted(
                avg_times.items(),
                key=lambda x: x[1],
                reverse=True
            )[:10]
            return {
                "slow_queries": list(self.slow_queries),
                "top_slow_endpoints": sorted_endpoints,
                "hourly_requests": dict(self.hourly_requests)
            }


class DBConnectionMonitor:
    """Database connection monitoring utilities."""

    # Shared statistics object
    stats = DBStats()

    @staticmethod
    def get_connection_stats():
        """Return the current connection pool status."""
        pool = engine.pool
        return {
            "pool_size": pool.size(),
            "checkedin": pool.checkedin(),
            "checkedout": pool.checkedout(),
            "overflow": pool.overflow(),
        }

    @staticmethod
    def log_connection_stats():
        """Log the current connection pool status."""
        stats = DBConnectionMonitor.get_connection_stats()
        logger.info(f"DB Connection Pool Stats: {stats}")
        return stats

    @staticmethod
    def get_all_stats():
        """Return all collected statistics."""
        return {
            "connection_pool": DBConnectionMonitor.get_connection_stats(),
            "performance_stats": DBConnectionMonitor.stats.get_stats()
        }


class DBMonitorMiddleware(BaseHTTPMiddleware):
    """Database connection monitoring middleware; periodically logs connection pool status."""

    def __init__(self, app: FastAPI, log_interval: int = 100, slow_query_threshold: float = 1.0):
        super().__init__(app)
        self.log_interval = log_interval
        self.request_count = 0
        self.slow_query_threshold = slow_query_threshold

    async def dispatch(self, request: Request, call_next):
        self.request_count += 1
        # Log connection pool status every N requests
        if self.request_count % self.log_interval == 0:
            DBConnectionMonitor.log_connection_stats()

        # Measure request processing time
        start_time = time.time()
        response = await call_next(request)
        process_time = time.time() - start_time

        # Record request statistics
        method = request.method
        path = request.url.path
        DBConnectionMonitor.stats.record_request(method, path, process_time)

        # If processing time exceeds the threshold, record it as a slow request
        if process_time > self.slow_query_threshold:
            logger.warning(f"Slow request: {method} {path} took {process_time:.2f}s")
            DBConnectionMonitor.stats.record_slow_query(method, path, process_time)

        return response


def setup_db_monitor(app: FastAPI, log_interval: int = 100, slow_query_threshold: float = 1.0):
    """Set up database monitoring on the application."""
    app.add_middleware(
        DBMonitorMiddleware,
        log_interval=log_interval,
        slow_query_threshold=slow_query_threshold
    )

    @app.on_event("startup")
    async def startup_db_monitor():
        logger.info("Starting DB connection monitoring")
        DBConnectionMonitor.log_connection_stats()

    @app.on_event("shutdown")
    async def shutdown_db_monitor():
        logger.info("Final DB connection stats before shutdown")
        DBConnectionMonitor.log_connection_stats()
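

# --- Usage sketch (illustrative, not part of this module) ---
# A minimal example of wiring the monitor into an application. The import path
# app.core.db_monitor and the "/admin/db-stats" endpoint below are hypothetical
# names used only for illustration:
#
#   from fastapi import FastAPI
#   from app.core.db_monitor import setup_db_monitor, DBConnectionMonitor
#
#   app = FastAPI()
#   setup_db_monitor(app, log_interval=50, slow_query_threshold=0.5)
#
#   @app.get("/admin/db-stats")
#   async def db_stats():
#       # Combines connection pool status with request/slow-request statistics
#       return DBConnectionMonitor.get_all_stats()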