deliveryman-api/app/core/db_monitor.py
2025-03-10 21:15:01 +08:00

241 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import contextvars
import datetime
import json
import logging
import re
import threading
import time
from collections import defaultdict, deque, Counter

from fastapi import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

from app.models.database import engine
logger = logging.getLogger(__name__)

# Regular expressions that capture the (first) table name from a SQL statement.
# The optional leading quote character lets quoted identifiers such as
# `orders`, "orders" or [orders] match as well as bare ones.
_IDENTIFIER = r'[`"\[]?(\w+)'
# re.DOTALL lets ``.*?`` span newlines, so multi-line statements (e.g. a
# column list split over several lines) still reach their FROM clause.
SELECT_PATTERN = re.compile(r'SELECT\s+.*?\s+FROM\s+' + _IDENTIFIER, re.IGNORECASE | re.DOTALL)
INSERT_PATTERN = re.compile(r'INSERT\s+INTO\s+' + _IDENTIFIER, re.IGNORECASE)
UPDATE_PATTERN = re.compile(r'UPDATE\s+' + _IDENTIFIER + r'[`"\]]?\s+SET', re.IGNORECASE)
DELETE_PATTERN = re.compile(r'DELETE\s+FROM\s+' + _IDENTIFIER, re.IGNORECASE)
def extract_table_name(sql):
    """Extract the primary table name from a SQL statement.

    The pattern matching the statement's leading keyword (SELECT, INSERT,
    UPDATE or DELETE) is tried first. A nested sub-select therefore does not
    win over the real target: ``INSERT INTO a SELECT ... FROM b`` yields ``a``,
    not ``b``. The remaining patterns are then tried in the order SELECT,
    INSERT, UPDATE, DELETE as a fallback.

    Args:
        sql: SQL statement text. Empty or ``None`` yields ``"unknown"``.

    Returns:
        The captured table name, or ``"unknown"`` when no pattern matches.
    """
    if not sql:
        return "unknown"
    by_keyword = {
        "SELECT": SELECT_PATTERN,
        "INSERT": INSERT_PATTERN,
        "UPDATE": UPDATE_PATTERN,
        "DELETE": DELETE_PATTERN,
    }
    words = sql.split(None, 1)
    leading = words[0].upper() if words else ""
    ordered = [SELECT_PATTERN, INSERT_PATTERN, UPDATE_PATTERN, DELETE_PATTERN]
    preferred = by_keyword.get(leading)
    if preferred is not None:
        # Promote the pattern for the statement's own verb to the front.
        ordered.remove(preferred)
        ordered.insert(0, preferred)
    for pattern in ordered:
        match = pattern.search(sql)
        if match:
            return match.group(1)
    return "unknown"
class DBStats:
    """Thread-safe in-memory store of database and request performance stats.

    Tracks:
      * the 100 most recent slow queries / slow requests,
      * per-endpoint request count, total time and max time,
      * request counts bucketed by local hour ("YYYY-MM-DD HH:00"),
      * per-table and per-"<query type> <table>" counts of slow SQL queries.

    All shared state is mutated and read under ``self.lock``.
    """

    # Statement keywords recognised as query types, checked in this order.
    _QUERY_TYPES = ("SELECT", "INSERT", "UPDATE", "DELETE")

    def __init__(self):
        self.slow_queries = deque(maxlen=100)  # keep at most the 100 newest slow-query records
        self.endpoint_stats = defaultdict(lambda: {"count": 0, "total_time": 0, "max_time": 0})
        self.hourly_requests = defaultdict(int)  # "YYYY-MM-DD HH:00" -> request count
        self.table_access_count = Counter()  # table name -> slow-query count
        self.query_patterns = Counter()  # "<query type> <table>" -> slow-query count
        self.lock = threading.Lock()

    @classmethod
    def _classify_query(cls, sql):
        """Return ``(query_type, table_name)`` for *sql*; "unknown" where undetermined."""
        if not sql:
            return "unknown", "unknown"
        head = sql.lstrip().upper()
        query_type = next((kw for kw in cls._QUERY_TYPES if head.startswith(kw)), "unknown")
        return query_type, extract_table_name(sql)

    @staticmethod
    def _serialize_params(params):
        """Return *params* as a JSON string of at most 500 chars, or None.

        Only non-empty dicts, lists and tuples are considered. Parameters that
        cannot be serialized, or whose JSON exceeds 500 chars, are dropped
        rather than failing the record.
        """
        if not params or not isinstance(params, (dict, list, tuple)):
            return None
        try:
            param_str = json.dumps(params)
        except (TypeError, ValueError, RecursionError):
            # Non-JSON types (datetime, Decimal, bytes...), circular references
            # or excessive nesting: skip the params, keep the slow-query record.
            return None
        return param_str if len(param_str) <= 500 else None

    def record_slow_query(self, method, path, duration, sql=None, params=None):
        """Record one slow query, or a slow request when *sql* is None.

        Args:
            method: HTTP method of the request that issued the query.
            path: URL path of that request.
            duration: Elapsed time in seconds.
            sql: Optional SQL text. Used to derive query type and table, and
                stored whitespace-collapsed and truncated to 500 chars.
            params: Optional query parameters. Stored as JSON when
                serializable and at most 500 chars long.
        """
        # Parsing and serialization touch no shared state, so do them before
        # taking the lock to keep the critical section short.
        query_type, table_name = self._classify_query(sql)
        query_info = {
            "method": method,
            "path": path,
            "duration": duration,
            "timestamp": datetime.datetime.now().isoformat(),
            "query_type": query_type,
            "table": table_name,
        }
        if sql:
            # Collapse whitespace runs and truncate so one huge statement
            # cannot bloat the stored record.
            query_info["sql"] = ' '.join(sql.split())[:500]
        param_str = self._serialize_params(params)
        if param_str is not None:
            query_info["params"] = param_str

        with self.lock:
            if sql:
                # Table/pattern counters are only meaningful when there is SQL
                # to attribute; slow requests without SQL would add noise.
                self.table_access_count[table_name] += 1
                self.query_patterns[f"{query_type} {table_name}"] += 1
            self.slow_queries.append(query_info)

    def record_request(self, method, path, duration):
        """Add one request's *duration* (seconds) to the endpoint and hourly stats."""
        key = f"{method} {path}"
        hour = datetime.datetime.now().strftime("%Y-%m-%d %H:00")
        with self.lock:
            entry = self.endpoint_stats[key]
            entry["count"] += 1
            entry["total_time"] += duration
            entry["max_time"] = max(entry["max_time"], duration)
            self.hourly_requests[hour] += 1

    def get_stats(self):
        """Return a snapshot of the collected statistics.

        Returns:
            A dict with:
              * ``slow_queries``: list of recorded slow-query dicts, oldest first.
              * ``top_slow_endpoints``: up to 10 ``(endpoint, avg_seconds)``
                pairs, slowest average first.
              * ``hourly_requests``: ``{hour: count}``.
              * ``top_accessed_tables``: up to 10 ``(table, count)`` pairs.
              * ``top_query_patterns``: up to 10 ``(pattern, count)`` pairs.
        """
        with self.lock:
            avg_times = {
                endpoint: entry["total_time"] / entry["count"]
                for endpoint, entry in self.endpoint_stats.items()
                if entry["count"] > 0
            }
            top_slow_endpoints = sorted(
                avg_times.items(), key=lambda item: item[1], reverse=True
            )[:10]
            return {
                "slow_queries": list(self.slow_queries),
                "top_slow_endpoints": top_slow_endpoints,
                "hourly_requests": dict(self.hourly_requests),
                "top_accessed_tables": self.table_access_count.most_common(10),
                "top_query_patterns": self.query_patterns.most_common(10),
            }
# Global request context: a module-level thread-local holding the current
# request's HTTP method and path (written by
# DBConnectionMonitor.set_request_context in this module).
# NOTE(review): threading.local is shared by every asyncio task running on the
# same thread, so concurrent requests on one event loop can overwrite each
# other's values; code running in a worker thread sees only that thread's copy.
request_context = threading.local()
class DBConnectionMonitor:
    """Database connection monitoring helpers.

    All state is class-level. ``stats`` is a single process-wide DBStats
    instance. The current request's ``(method, path)`` is kept in a
    ContextVar so that each request (asyncio task) sees its own value.
    """

    # Process-wide statistics store shared by the middleware and by callers
    # of record_slow_query.
    stats = DBStats()

    # (method, path) of the request handled in the current context. Unlike
    # threading.local, a ContextVar is isolated per asyncio task, so
    # concurrent requests served by one event-loop thread do not clobber
    # each other's context.
    _current_request = contextvars.ContextVar("db_monitor_current_request", default=None)

    @staticmethod
    def set_request_context(request):
        """Remember *request*'s HTTP method and URL path for the current context."""
        method = request.method
        path = request.url.path
        DBConnectionMonitor._current_request.set((method, path))
        # Also populate the module-level thread-local, for backward
        # compatibility with any code that reads it directly.
        request_context.method = method
        request_context.path = path

    @staticmethod
    def get_request_context():
        """Return ``(method, path)`` for the current context.

        Returns ``('UNKNOWN', 'UNKNOWN')`` when no request context was set.
        """
        current = DBConnectionMonitor._current_request.get()
        if current is None:
            return 'UNKNOWN', 'UNKNOWN'
        return current

    @staticmethod
    def record_slow_query(duration, sql=None, params=None):
        """Record a slow query, attributed to the current request context.

        Args:
            duration: Elapsed time in seconds.
            sql: Optional SQL text of the query.
            params: Optional query parameters.
        """
        method, path = DBConnectionMonitor.get_request_context()
        DBConnectionMonitor.stats.record_slow_query(method, path, duration, sql, params)

    @staticmethod
    def get_connection_stats():
        """Return a snapshot of the engine's connection-pool counters.

        Returns:
            A dict with keys ``pool_size``, ``checkedin``, ``checkedout`` and
            ``overflow``. A counter the pool class does not provide is
            reported as ``None`` instead of raising.
        """
        pool = engine.pool

        def _read(attr):
            # Queue-style pools expose these counters as methods. Tolerate
            # pool classes that lack them or expose a plain attribute.
            value = getattr(pool, attr, None)
            return value() if callable(value) else value

        return {
            "pool_size": _read("size"),
            "checkedin": _read("checkedin"),
            "checkedout": _read("checkedout"),
            "overflow": _read("overflow"),
        }

    @staticmethod
    def log_connection_stats():
        """Log the current connection-pool counters at INFO level and return them."""
        stats = DBConnectionMonitor.get_connection_stats()
        logger.info("DB Connection Pool Stats: %s", stats)
        return stats

    @staticmethod
    def get_all_stats():
        """Return the pool counters together with the collected performance stats."""
        return {
            "connection_pool": DBConnectionMonitor.get_connection_stats(),
            "performance_stats": DBConnectionMonitor.stats.get_stats(),
        }
class DBMonitorMiddleware(BaseHTTPMiddleware):
    """Middleware that times each request and periodically logs pool status.

    Args:
        app: The application being wrapped.
        log_interval: Log the connection-pool status once every
            ``log_interval`` requests. A value <= 0 disables the periodic log.
        slow_query_threshold: Requests taking longer than this many seconds
            are logged as warnings and recorded as slow queries.
    """

    def __init__(self, app: FastAPI, log_interval: int = 100, slow_query_threshold: float = 1.0):
        super().__init__(app)
        self.log_interval = log_interval
        self.request_count = 0
        self.slow_query_threshold = slow_query_threshold

    async def dispatch(self, request: Request, call_next):
        self.request_count += 1
        # Make this request's method/path available to slow-query recording.
        DBConnectionMonitor.set_request_context(request)
        # Guarding against a non-positive interval avoids ZeroDivisionError
        # when monitoring is configured with log_interval=0.
        if self.log_interval > 0 and self.request_count % self.log_interval == 0:
            DBConnectionMonitor.log_connection_stats()

        method = request.method
        path = request.url.path
        # perf_counter is monotonic, so durations are immune to wall-clock jumps.
        start_time = time.perf_counter()
        try:
            return await call_next(request)
        finally:
            # Runs even when the downstream app raises, so failing requests
            # are still counted and timed.
            # NOTE(review): with BaseHTTPMiddleware this presumably measures
            # time until the response object is returned, not until a
            # streamed body finishes — confirm if streaming endpoints matter.
            process_time = time.perf_counter() - start_time
            DBConnectionMonitor.stats.record_request(method, path, process_time)
            if process_time > self.slow_query_threshold:
                logger.warning("Slow request: %s %s took %.2fs", method, path, process_time)
                DBConnectionMonitor.stats.record_slow_query(method, path, process_time)
def setup_db_monitor(app: FastAPI, log_interval: int = 100, slow_query_threshold: float = 1.0):
    """Install database monitoring on *app*.

    Adds DBMonitorMiddleware configured with *log_interval* and
    *slow_query_threshold*. Registers startup and shutdown handlers that each
    log a message and the current connection-pool status.

    NOTE(review): ``app.on_event`` is deprecated in recent FastAPI releases in
    favour of lifespan handlers — consider migrating.
    """
    middleware_options = {
        "log_interval": log_interval,
        "slow_query_threshold": slow_query_threshold,
    }
    app.add_middleware(DBMonitorMiddleware, **middleware_options)

    async def startup_db_monitor():
        logger.info("Starting DB connection monitoring")
        DBConnectionMonitor.log_connection_stats()

    async def shutdown_db_monitor():
        logger.info("Final DB connection stats before shutdown")
        DBConnectionMonitor.log_connection_stats()

    # Same as decorating each coroutine with @app.on_event(...).
    app.on_event("startup")(startup_db_monitor)
    app.on_event("shutdown")(shutdown_db_monitor)