deliveryman-api/app/core/db_monitor.py
2025-03-10 20:18:52 +08:00

149 lines
5.1 KiB
Python

from app.models.database import engine
import logging
from fastapi import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
import time
import threading
from collections import defaultdict, deque
import datetime
# Module-level logger; connection-pool snapshots and slow-request warnings in this module are emitted through it.
logger = logging.getLogger(__name__)
class DBStats:
    """Thread-safe, in-memory store of request timing statistics.

    Three kinds of data are kept, all guarded by a single lock:

    * a bounded ring buffer of slow-request records,
    * per-endpoint aggregates (call count, summed duration, maximum duration),
    * request counts bucketed by local wall-clock hour ("YYYY-MM-DD HH:00").
    """

    def __init__(self, max_slow_queries: int = 100, max_hourly_buckets=None):
        """Create an empty statistics store.

        Args:
            max_slow_queries: Number of slow-request records to retain. Once the
                buffer is full, each new record evicts the oldest one.
            max_hourly_buckets: Optional cap on the number of hourly buckets kept
                in ``hourly_requests``; the oldest hours are dropped first.
                ``None`` (the default) keeps every hour ever seen.

        Raises:
            ValueError: If ``max_hourly_buckets`` is given and is less than 1.
        """
        if max_hourly_buckets is not None and max_hourly_buckets < 1:
            raise ValueError("max_hourly_buckets must be >= 1 or None")
        # Ring buffer: deque drops the oldest entry automatically when full.
        self.slow_queries = deque(maxlen=max_slow_queries)
        self.endpoint_stats = defaultdict(lambda: {"count": 0, "total_time": 0, "max_time": 0})
        self.hourly_requests = defaultdict(int)
        self.max_hourly_buckets = max_hourly_buckets
        self.lock = threading.Lock()

    def record_slow_query(self, method, path, duration):
        """Append one slow-request record (method, path, duration, ISO timestamp).

        Args:
            method: HTTP method of the request.
            path: Request path.
            duration: Processing time in seconds.
        """
        record = {
            "method": method,
            "path": path,
            "duration": duration,
            "timestamp": datetime.datetime.now().isoformat(),
        }
        with self.lock:
            self.slow_queries.append(record)

    def record_request(self, method, path, duration):
        """Add one completed request to the per-endpoint and hourly counters.

        Args:
            method: HTTP method; combined with *path* to form the endpoint key.
            path: Request path.
            duration: Processing time in seconds.
        """
        key = f"{method} {path}"
        hour = datetime.datetime.now().strftime("%Y-%m-%d %H:00")
        with self.lock:
            entry = self.endpoint_stats[key]
            entry["count"] += 1
            entry["total_time"] += duration
            entry["max_time"] = max(entry["max_time"], duration)
            self.hourly_requests[hour] += 1
            self._prune_hourly_buckets()

    def _prune_hourly_buckets(self):
        """Drop the oldest hourly buckets beyond ``max_hourly_buckets``; the caller must hold ``self.lock``."""
        if self.max_hourly_buckets is None:
            return
        while len(self.hourly_requests) > self.max_hourly_buckets:
            # "YYYY-MM-DD HH:00" keys sort chronologically, so min() is the oldest hour.
            del self.hourly_requests[min(self.hourly_requests)]

    def get_stats(self, top_n: int = 10):
        """Return a snapshot of the collected statistics.

        Args:
            top_n: How many endpoints to include in ``top_slow_endpoints``.

        Returns:
            A dict with:
              * ``slow_queries``: copies of the retained slow-request records;
              * ``top_slow_endpoints``: up to *top_n* ``(endpoint, avg_seconds)``
                tuples, sorted by average duration, slowest first;
              * ``hourly_requests``: mapping of hour bucket to request count.
        """
        with self.lock:
            avg_times = {
                endpoint: stats["total_time"] / stats["count"]
                for endpoint, stats in self.endpoint_stats.items()
                if stats["count"] > 0
            }
            top_slow = sorted(avg_times.items(), key=lambda item: item[1], reverse=True)[:top_n]
            return {
                # Copy each record so callers cannot mutate the stored history.
                "slow_queries": [dict(record) for record in self.slow_queries],
                "top_slow_endpoints": top_slow,
                "hourly_requests": dict(self.hourly_requests),
            }
class DBConnectionMonitor:
    """Static helpers for inspecting the connection pool of ``engine``.

    Also owns the process-wide ``DBStats`` instance that request timings are
    recorded into.
    """

    # Shared statistics store, created once when the class is defined.
    stats = DBStats()

    @staticmethod
    def _pool_metric(pool, name):
        """Return ``pool.<name>()``, or ``None`` if the pool has no such method.

        NOTE(review): per the SQLAlchemy pooling docs, size/checkedin/checkedout/
        overflow are provided by queue-style pools; other pool classes may lack
        them, and a monitoring helper should not crash app startup because of it.
        """
        metric = getattr(pool, name, None)
        return metric() if callable(metric) else None

    @staticmethod
    def get_connection_stats():
        """Return a snapshot of the current connection-pool state.

        Returns:
            A dict with ``pool_size``, ``checkedin``, ``checkedout`` and
            ``overflow``; a value is ``None`` when the pool does not expose it.
        """
        pool = engine.pool
        return {
            "pool_size": DBConnectionMonitor._pool_metric(pool, "size"),
            "checkedin": DBConnectionMonitor._pool_metric(pool, "checkedin"),
            "checkedout": DBConnectionMonitor._pool_metric(pool, "checkedout"),
            "overflow": DBConnectionMonitor._pool_metric(pool, "overflow"),
        }

    @staticmethod
    def log_connection_stats():
        """Log the current connection-pool state at INFO level and return it."""
        stats = DBConnectionMonitor.get_connection_stats()
        # Lazy %-formatting: the message is only rendered if INFO is enabled.
        logger.info("DB Connection Pool Stats: %s", stats)
        return stats

    @staticmethod
    def get_all_stats():
        """Return the pool snapshot together with the request performance statistics."""
        return {
            "connection_pool": DBConnectionMonitor.get_connection_stats(),
            "performance_stats": DBConnectionMonitor.stats.get_stats(),
        }
class DBMonitorMiddleware(BaseHTTPMiddleware):
    """Middleware that times every request and periodically logs pool state.

    It does three things:

    * every ``log_interval`` requests, it logs the connection-pool status;
    * every request's duration is recorded in ``DBConnectionMonitor.stats``;
    * requests slower than ``slow_query_threshold`` seconds are logged as a
      warning and stored as slow queries.
    """

    def __init__(self, app: FastAPI, log_interval: int = 100, slow_query_threshold: float = 1.0):
        """Configure the middleware.

        Args:
            app: The ASGI application being wrapped.
            log_interval: Log pool stats once every this many requests;
                0 disables periodic logging.
            slow_query_threshold: Duration in seconds above which a request is
                treated as slow.
        """
        super().__init__(app)
        self.log_interval = log_interval
        self.request_count = 0
        self.slow_query_threshold = slow_query_threshold

    async def dispatch(self, request: Request, call_next):
        """Time the downstream handler and record the result, even if it raises."""
        self.request_count += 1
        # Guard against log_interval == 0, which would raise ZeroDivisionError.
        if self.log_interval and self.request_count % self.log_interval == 0:
            DBConnectionMonitor.log_connection_stats()

        # perf_counter is monotonic, so wall-clock adjustments cannot skew durations.
        # NOTE(review): with BaseHTTPMiddleware, call_next may return before a
        # streamed response body is fully sent; confirm the timing covers what you need.
        start_time = time.perf_counter()
        try:
            return await call_next(request)
        finally:
            # Record failed requests too, so slow failures are not hidden.
            process_time = time.perf_counter() - start_time
            self._record(request.method, request.url.path, process_time)

    def _record(self, method, path, process_time):
        """Store one request's timing and flag it when it exceeds the slow threshold."""
        DBConnectionMonitor.stats.record_request(method, path, process_time)
        if process_time > self.slow_query_threshold:
            logger.warning("Slow request: %s %s took %.2fs", method, path, process_time)
            DBConnectionMonitor.stats.record_slow_query(method, path, process_time)
def setup_db_monitor(app: FastAPI, log_interval: int = 100, slow_query_threshold: float = 1.0):
    """Install database monitoring on *app*.

    Adds ``DBMonitorMiddleware`` with the given log interval and slow-request
    threshold, then registers startup and shutdown handlers that each log a
    message followed by the current connection-pool state.
    """
    middleware_options = {
        "log_interval": log_interval,
        "slow_query_threshold": slow_query_threshold,
    }
    app.add_middleware(DBMonitorMiddleware, **middleware_options)

    async def startup_db_monitor():
        logger.info("Starting DB connection monitoring")
        DBConnectionMonitor.log_connection_stats()

    async def shutdown_db_monitor():
        logger.info("Final DB connection stats before shutdown")
        DBConnectionMonitor.log_connection_stats()

    # Same registration order as the decorator form: startup first, then shutdown.
    app.on_event("startup")(startup_db_monitor)
    app.on_event("shutdown")(shutdown_db_monitor)