"""Database connection-pool and request/query performance monitoring for FastAPI.

Provides:

* ``DBStats`` - lock-guarded in-memory counters for slow queries, per-endpoint
  timings, hourly request counts and table / query-pattern frequencies.
* ``DBConnectionMonitor`` - static facade over one shared ``DBStats`` instance,
  plus helpers that read the status of ``engine.pool``.
* ``DBMonitorMiddleware`` / ``setup_db_monitor`` - wire the above into an app.
"""
import contextvars
import datetime
import json
import logging
import re
import threading
import time
from collections import Counter, defaultdict, deque

from fastapi import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

from app.models.database import engine

logger = logging.getLogger(__name__)

# A table reference: optional schema prefix, optional identifier quoting in the
# "x" / `x` / [x] styles. Group 1 captures the bare table name.
_TABLE_REF = r'(?:[`"\[]?\w+[`"\]]?\.)?[`"\[]?(\w+)[`"\]]?'

# Regexes recognising the target table of common SQL statements.
# SELECT uses DOTALL so a column list spread over several lines can still be
# matched through to its FROM clause.
SELECT_PATTERN = re.compile(r'SELECT\s+.*?\s+FROM\s+' + _TABLE_REF, re.IGNORECASE | re.DOTALL)
INSERT_PATTERN = re.compile(r'INSERT\s+INTO\s+' + _TABLE_REF, re.IGNORECASE)
UPDATE_PATTERN = re.compile(r'UPDATE\s+' + _TABLE_REF + r'\s+SET', re.IGNORECASE)
DELETE_PATTERN = re.compile(r'DELETE\s+FROM\s+' + _TABLE_REF, re.IGNORECASE)

# Statement keyword -> pattern; insertion order is the fallback search order.
_PATTERNS_BY_TYPE = {
    "SELECT": SELECT_PATTERN,
    "INSERT": INSERT_PATTERN,
    "UPDATE": UPDATE_PATTERN,
    "DELETE": DELETE_PATTERN,
}


def _query_type(sql):
    """Return the leading keyword of *sql* (SELECT/INSERT/UPDATE/DELETE), else "unknown"."""
    # Only the first six non-blank characters matter; avoid upper-casing the
    # whole (possibly large) statement.
    head = sql.lstrip()[:6].upper()
    return head if head in _PATTERNS_BY_TYPE else "unknown"


def extract_table_name(sql):
    """Extract the primary table name from a SQL statement.

    The pattern for the statement's own leading keyword is tried first. For
    example, ``DELETE FROM a WHERE id IN (SELECT x FROM b)`` yields ``a``
    rather than the subquery's ``b``. The remaining patterns are then tried in
    SELECT/INSERT/UPDATE/DELETE order as a fallback (e.g. for ``WITH ...``).

    Args:
        sql: SQL text. ``None`` or empty yields ``"unknown"``.

    Returns:
        The table name captured by the first matching pattern, else ``"unknown"``.
    """
    if not sql:
        return "unknown"
    preferred = _PATTERNS_BY_TYPE.get(_query_type(sql))
    candidates = [preferred] if preferred is not None else []
    candidates.extend(p for p in _PATTERNS_BY_TYPE.values() if p is not preferred)
    for pattern in candidates:
        match = pattern.search(sql)
        if match:
            return match.group(1)
    return "unknown"


class DBStats:
    """In-memory database/request performance statistics.

    Every mutation of the shared containers happens while holding ``self.lock``.
    ``get_stats`` builds its result under the same lock.
    """

    def __init__(self):
        # Most recent 100 slow-query / slow-request records; the oldest are dropped.
        self.slow_queries = deque(maxlen=100)
        # "METHOD path" -> {"count", "total_time", "max_time"}; times are the
        # durations passed in by callers (the middleware below passes seconds).
        self.endpoint_stats = defaultdict(lambda: {"count": 0, "total_time": 0, "max_time": 0})
        # "YYYY-MM-DD HH:00" (server local time) -> request count. One key is
        # added per hour in which a request is recorded; keys are never pruned.
        self.hourly_requests = defaultdict(int)
        # Table name -> number of recorded slow SQL statements targeting it.
        self.table_access_count = Counter()
        # "QUERYTYPE table" -> number of recorded slow SQL statements of that shape.
        self.query_patterns = Counter()
        self.lock = threading.Lock()

    def record_slow_query(self, method, path, duration, sql=None, params=None):
        """Append a slow-query (or slow-request) record and update SQL counters.

        Args:
            method: HTTP method of the request this entry is attributed to.
            path: URL path of that request.
            duration: Elapsed time, as supplied by the caller.
            sql: Optional SQL text. When given:
                - the statement type and table are parsed from it;
                - the per-table and per-pattern counters are incremented;
                - a whitespace-collapsed copy, truncated to 500 characters,
                  is stored under ``"sql"``.
            params: Optional bound parameters. A dict/list/tuple is stored
                JSON-encoded under ``"params"`` if the encoding is at most 500
                characters long. Otherwise it is omitted.
        """
        # Parsing and serialisation need no shared state, so they run outside
        # the lock to keep the critical section short.
        query_type = "unknown"
        table_name = "unknown"
        if sql:
            query_type = _query_type(sql)
            table_name = extract_table_name(sql)

        query_info = {
            "method": method,
            "path": path,
            "duration": duration,
            "timestamp": datetime.datetime.now().isoformat(),
            "query_type": query_type,
            "table": table_name,
        }
        # NOTE(review): SQL text and parameters are stored unconditionally (no
        # debug-mode check) and are returned by get_stats(). Parameters may
        # contain sensitive values.
        if sql:
            query_info["sql"] = " ".join(sql.split())[:500]
        if params and isinstance(params, (dict, list, tuple)):
            try:
                param_str = json.dumps(params)
            except (TypeError, ValueError):
                # Values json cannot encode (e.g. datetime, Decimal, bytes) are
                # skipped so that recording statistics never raises for them.
                logger.debug("Skipping non-JSON-serialisable params for slow query record")
            else:
                if len(param_str) <= 500:
                    query_info["params"] = param_str

        with self.lock:
            # Only real SQL statements feed the table / pattern counters, so
            # slow-request entries (no SQL) do not inflate "unknown".
            if sql:
                self.table_access_count[table_name] += 1
                self.query_patterns[f"{query_type} {table_name}"] += 1
            self.slow_queries.append(query_info)

    def record_request(self, method, path, duration):
        """Accumulate count, total and max duration for ``METHOD path``, plus the hourly count."""
        key = f"{method} {path}"
        hour = datetime.datetime.now().strftime("%Y-%m-%d %H:00")
        with self.lock:
            entry = self.endpoint_stats[key]
            entry["count"] += 1
            entry["total_time"] += duration
            entry["max_time"] = max(entry["max_time"], duration)
            self.hourly_requests[hour] += 1

    def get_stats(self):
        """Return a snapshot of the collected statistics.

        Returns:
            dict with keys:
                - ``slow_queries``: list of records, oldest first.
                - ``top_slow_endpoints``: up to 10 ``(endpoint, avg_time)``
                  pairs, slowest first.
                - ``hourly_requests``: dict copy.
                - ``top_accessed_tables``: ``Counter.most_common(10)``.
                - ``top_query_patterns``: ``Counter.most_common(10)``.
        """
        with self.lock:
            avg_times = {
                endpoint: entry["total_time"] / entry["count"]
                for endpoint, entry in self.endpoint_stats.items()
                if entry["count"] > 0
            }
            top_slow_endpoints = sorted(avg_times.items(), key=lambda item: item[1], reverse=True)[:10]
            return {
                "slow_queries": list(self.slow_queries),
                "top_slow_endpoints": top_slow_endpoints,
                "hourly_requests": dict(self.hourly_requests),
                "top_accessed_tables": self.table_access_count.most_common(10),
                "top_query_patterns": self.query_patterns.most_common(10),
            }


# The (method, path) of the request currently being handled, held in a ContextVar.
# A thread-local is wrong under asyncio: all requests share the event-loop thread,
# so concurrent requests would overwrite each other's context. A ContextVar value
# is per-task, and follows into worker threads wherever the framework copies the
# context (Starlette's threadpool helper does - TODO confirm for your version).
_request_ctx = contextvars.ContextVar("db_monitor_request_context", default=("UNKNOWN", "UNKNOWN"))

# Legacy thread-local mirror, still written so existing importers of
# ``request_context`` keep working. Under asyncio it may hold another request's
# values; use DBConnectionMonitor.get_request_context() instead.
request_context = threading.local()


def _pool_metric(pool, method_name):
    """Call ``pool.<method_name>()`` if the pool class defines it, else return None."""
    method = getattr(pool, method_name, None)
    return method() if callable(method) else None


class DBConnectionMonitor:
    """Static helpers for connection-pool status and the shared ``DBStats``."""

    # Process-wide statistics shared by the middleware and any query hooks.
    stats = DBStats()

    @staticmethod
    def set_request_context(request):
        """Make *request*'s method and URL path the current request context.

        Returns:
            The ``contextvars.Token`` for this change. Pass it to
            ``_request_ctx.reset`` to restore the previous context.
        """
        method = request.method
        path = request.url.path
        request_context.method = method
        request_context.path = path
        return _request_ctx.set((method, path))

    @staticmethod
    def get_request_context():
        """Return ``(method, path)`` of the current request, or ``("UNKNOWN", "UNKNOWN")``."""
        return _request_ctx.get()

    @staticmethod
    def record_slow_query(duration, sql=None, params=None):
        """Record a slow SQL statement, attributed to the current request context."""
        method, path = DBConnectionMonitor.get_request_context()
        DBConnectionMonitor.stats.record_slow_query(method, path, duration, sql, params)

    @staticmethod
    def get_connection_stats():
        """Return the current counters of ``engine.pool``.

        Keys are ``pool_size``, ``checkedin``, ``checkedout`` and ``overflow``.
        A value is None when the pool class lacks the corresponding method.
        For example, SQLAlchemy's ``NullPool`` has no ``size()``.
        """
        pool = engine.pool
        return {
            "pool_size": _pool_metric(pool, "size"),
            "checkedin": _pool_metric(pool, "checkedin"),
            "checkedout": _pool_metric(pool, "checkedout"),
            "overflow": _pool_metric(pool, "overflow"),
        }

    @staticmethod
    def log_connection_stats():
        """Log the current pool counters at INFO level and return them."""
        stats = DBConnectionMonitor.get_connection_stats()
        logger.info("DB Connection Pool Stats: %s", stats)
        return stats

    @staticmethod
    def get_all_stats():
        """Return pool counters and performance statistics in one dict."""
        return {
            "connection_pool": DBConnectionMonitor.get_connection_stats(),
            "performance_stats": DBConnectionMonitor.stats.get_stats(),
        }


class DBMonitorMiddleware(BaseHTTPMiddleware):
    """Time each request, record it in ``DBConnectionMonitor.stats``, and
    periodically log connection-pool status.

    Args:
        app: The wrapped application.
        log_interval: Log pool status every N requests. A value <= 0
            disables the periodic log.
        slow_query_threshold: Requests slower than this many seconds are
            logged as warnings and stored as slow entries.
    """

    def __init__(self, app: FastAPI, log_interval: int = 100, slow_query_threshold: float = 1.0):
        super().__init__(app)
        self.log_interval = log_interval
        self.request_count = 0
        self.slow_query_threshold = slow_query_threshold

    async def dispatch(self, request: Request, call_next):
        self.request_count += 1
        token = DBConnectionMonitor.set_request_context(request)
        try:
            # Guard against log_interval <= 0, which would otherwise raise
            # ZeroDivisionError on every request.
            if self.log_interval > 0 and self.request_count % self.log_interval == 0:
                DBConnectionMonitor.log_connection_stats()

            # NOTE(review): the raw URL path is the stats key, so paths
            # embedding IDs create one endpoint_stats entry per distinct URL.
            method = request.method
            path = request.url.path
            # perf_counter is monotonic, unlike time.time(), so wall-clock
            # adjustments cannot produce negative or inflated durations.
            start_time = time.perf_counter()
            try:
                return await call_next(request)
            finally:
                # Runs even when the downstream app raises, so failing requests
                # still appear in the statistics. Presumably this measures
                # until call_next returns the response object; body streaming
                # after that is likely excluded - TODO confirm.
                self._record_timing(method, path, time.perf_counter() - start_time)
        finally:
            _request_ctx.reset(token)

    def _record_timing(self, method, path, process_time):
        """Store request timing, and warn and record a slow entry above the threshold."""
        DBConnectionMonitor.stats.record_request(method, path, process_time)
        if process_time > self.slow_query_threshold:
            logger.warning("Slow request: %s %s took %.2fs", method, path, process_time)
            DBConnectionMonitor.stats.record_slow_query(method, path, process_time)


def setup_db_monitor(app: FastAPI, log_interval: int = 100, slow_query_threshold: float = 1.0):
    """Install ``DBMonitorMiddleware`` on *app* and log pool stats on startup/shutdown.

    Args:
        app: The FastAPI application to instrument.
        log_interval: Forwarded to ``DBMonitorMiddleware``.
        slow_query_threshold: Forwarded to ``DBMonitorMiddleware`` (seconds).
    """
    app.add_middleware(
        DBMonitorMiddleware,
        log_interval=log_interval,
        slow_query_threshold=slow_query_threshold,
    )

    # NOTE(review): ``on_event`` is deprecated in recent FastAPI versions in
    # favour of lifespan handlers. It is also not run when the app defines a
    # ``lifespan``. It is kept here to avoid changing existing app setup.
    @app.on_event("startup")
    async def startup_db_monitor():
        logger.info("Starting DB connection monitoring")
        DBConnectionMonitor.log_connection_stats()

    @app.on_event("shutdown")
    async def shutdown_db_monitor():
        logger.info("Final DB connection stats before shutdown")
        DBConnectionMonitor.log_connection_stats()