149 lines
5.1 KiB
Python
149 lines
5.1 KiB
Python
from app.models.database import engine
|
|
import logging
|
|
from fastapi import FastAPI
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.requests import Request
|
|
import time
|
|
import threading
|
|
from collections import defaultdict, deque
|
|
import datetime
|
|
|
|
# Module-level logger; connection-pool snapshots and slow-request warnings
# produced by this module are emitted through it.
logger = logging.getLogger(__name__)
|
|
|
|
class DBStats:
    """Database/request statistics collector (thread-safe, in memory).

    Despite the name, entries are keyed by HTTP method and path. Three
    collections are maintained, all guarded by one ``threading.Lock``:

    * ``slow_queries`` - the most recent slow requests, kept in a bounded
      ring buffer of ``max_slow_queries`` entries.
    * ``endpoint_stats`` - per ``"METHOD path"`` key: request ``count``,
      summed ``total_time`` and the largest single ``max_time``.
    * ``hourly_requests`` - request counts bucketed by local wall-clock hour
      (``"YYYY-MM-DD HH:00"``), limited to the newest ``max_hours`` buckets
      so a long-running process does not accumulate one key per hour forever.
    """

    def __init__(self, max_slow_queries: int = 100, max_hours: "int | None" = 24 * 7):
        """Create empty statistics.

        Args:
            max_slow_queries: capacity of the slow-request ring buffer
                (older entries are discarded first).
            max_hours: number of hourly buckets to retain; ``None`` keeps
                every bucket (the unbounded pre-existing behaviour).

        Raises:
            ValueError: if ``max_hours`` is not ``None`` and is below 1, or
                if ``max_slow_queries`` is negative (raised by ``deque``).
        """
        if max_hours is not None and max_hours < 1:
            raise ValueError("max_hours must be >= 1 or None")
        self.max_hours = max_hours
        self.slow_queries = deque(maxlen=max_slow_queries)
        self.endpoint_stats = defaultdict(lambda: {"count": 0, "total_time": 0, "max_time": 0})
        self.hourly_requests = defaultdict(int)
        self.lock = threading.Lock()

    def record_slow_query(self, method, path, duration):
        """Append one slow request (with an ISO-8601 local timestamp) to the ring buffer."""
        entry = {
            "method": method,
            "path": path,
            "duration": duration,
            "timestamp": datetime.datetime.now().isoformat(),
        }
        with self.lock:
            self.slow_queries.append(entry)

    def record_request(self, method, path, duration):
        """Accumulate per-endpoint and per-hour counters for one request.

        Args:
            method: HTTP method, used as the first half of the endpoint key.
            path: request path (or route template), second half of the key.
            duration: elapsed time to add, in the caller's unit (seconds
                in this module's middleware).
        """
        key = f"{method} {path}"
        hour = datetime.datetime.now().strftime("%Y-%m-%d %H:00")
        with self.lock:
            entry = self.endpoint_stats[key]
            entry["count"] += 1
            entry["total_time"] += duration
            entry["max_time"] = max(entry["max_time"], duration)

            self.hourly_requests[hour] += 1
            self._prune_hourly_locked()

    def _prune_hourly_locked(self):
        """Drop the oldest hourly buckets beyond ``max_hours``; caller must hold ``lock``."""
        if self.max_hours is None:
            return
        while len(self.hourly_requests) > self.max_hours:
            # Keys are zero-padded "YYYY-MM-DD HH:00" strings, so the
            # lexicographic minimum is the chronologically oldest bucket.
            del self.hourly_requests[min(self.hourly_requests)]

    def get_stats(self, top_n: int = 10):
        """Return a snapshot of the collected statistics.

        Args:
            top_n: how many endpoints to include in ``top_slow_endpoints``.

        Returns:
            dict with ``slow_queries`` (list of dicts, oldest first),
            ``top_slow_endpoints`` (list of ``(endpoint_key, avg_duration)``
            tuples sorted by average duration, descending) and
            ``hourly_requests`` (plain ``dict`` copy of the hourly counters).
        """
        with self.lock:
            avg_times = {
                endpoint: stats["total_time"] / stats["count"]
                for endpoint, stats in self.endpoint_stats.items()
                if stats["count"] > 0
            }
            sorted_endpoints = sorted(
                avg_times.items(),
                key=lambda item: item[1],
                reverse=True,
            )[:top_n]

            return {
                "slow_queries": list(self.slow_queries),
                "top_slow_endpoints": sorted_endpoints,
                "hourly_requests": dict(self.hourly_requests),
            }
|
|
|
|
|
|
class DBConnectionMonitor:
    """Database connection monitoring helpers.

    Exposes the engine's connection-pool counters and a process-wide
    ``DBStats`` instance that request-level code records into.
    """

    # Shared statistics collector, created once when the class is defined.
    stats = DBStats()

    # (reported key, pool attribute name) pairs read by get_connection_stats().
    _POOL_METRICS = (
        ("pool_size", "size"),
        ("checkedin", "checkedin"),
        ("checkedout", "checkedout"),
        ("overflow", "overflow"),
    )

    @staticmethod
    def get_connection_stats():
        """Return a snapshot of the engine's current connection-pool counters.

        Returns:
            dict with keys ``pool_size``, ``checkedin``, ``checkedout`` and
            ``overflow``. Each value is the result of calling the same-named
            pool method. If the pool exposes the name as a plain attribute,
            that value is used as-is; if the pool lacks it entirely, the value
            is ``None``. This keeps a snapshot from raising (and so from
            failing requests or startup/shutdown hooks) for pool classes that
            do not implement every counter.
            NOTE(review): non-queue pools such as NullPool/StaticPool are
            expected to hit the ``None`` path - confirm for the configured pool.
        """
        pool = engine.pool
        snapshot = {}
        for key, attr_name in DBConnectionMonitor._POOL_METRICS:
            value = getattr(pool, attr_name, None)
            snapshot[key] = value() if callable(value) else value
        return snapshot

    @staticmethod
    def log_connection_stats():
        """Log the current pool snapshot at INFO level and return it."""
        snapshot = DBConnectionMonitor.get_connection_stats()
        # Lazy %-formatting: the dict is only rendered if INFO is enabled.
        logger.info("DB Connection Pool Stats: %s", snapshot)
        return snapshot

    @staticmethod
    def get_all_stats():
        """Return the pool snapshot together with the request performance stats.

        Returns:
            dict with ``connection_pool`` (see ``get_connection_stats``) and
            ``performance_stats`` (see ``DBStats.get_stats``).
        """
        return {
            "connection_pool": DBConnectionMonitor.get_connection_stats(),
            "performance_stats": DBConnectionMonitor.stats.get_stats(),
        }
|
|
|
|
|
|
class DBMonitorMiddleware(BaseHTTPMiddleware):
    """Middleware that times every request and periodically logs pool stats.

    Each request's duration is recorded in ``DBConnectionMonitor.stats``.
    Requests slower than ``slow_query_threshold`` seconds are also logged as
    warnings and kept in the slow-query buffer. Every ``log_interval``
    requests, the connection-pool snapshot is logged.
    """

    def __init__(self, app: FastAPI, log_interval: int = 100, slow_query_threshold: float = 1.0):
        """Configure the middleware.

        Args:
            app: the wrapped ASGI application.
            log_interval: log the pool snapshot every N requests; a value
                <= 0 disables periodic logging (previously it raised
                ZeroDivisionError on every request).
            slow_query_threshold: duration in seconds above which a request
                is reported as slow.
        """
        super().__init__(app)
        self.log_interval = log_interval
        self.request_count = 0
        self.slow_query_threshold = slow_query_threshold

    async def dispatch(self, request: Request, call_next):
        """Time the downstream handler and record the result, even if it raises."""
        self.request_count += 1

        # Log the pool snapshot once every `log_interval` requests.
        if self.log_interval > 0 and self.request_count % self.log_interval == 0:
            DBConnectionMonitor.log_connection_stats()

        # perf_counter is monotonic, so durations are unaffected by wall-clock
        # adjustments (NTP, manual changes) that could skew time.time() deltas.
        start_time = time.perf_counter()
        try:
            return await call_next(request)
        finally:
            # Record in `finally` so requests whose handler raised are counted too.
            self._record(request, time.perf_counter() - start_time)

    @staticmethod
    def _endpoint_label(request: Request):
        """Return the matched route template if routing set one, else the raw path.

        Keying per-endpoint stats by template (e.g. ``/users/{user_id}``)
        instead of the concrete path keeps the number of keys bounded.
        NOTE(review): relies on the router storing the matched route with a
        ``.path`` attribute in ``scope["route"]`` (FastAPI's APIRoute does);
        unmatched requests (e.g. 404s) still fall back to the raw path.
        """
        route = request.scope.get("route")
        return getattr(route, "path", None) or request.url.path

    def _record(self, request: Request, process_time: float):
        """Feed one request's duration into the shared stats and slow-query log."""
        method = request.method
        DBConnectionMonitor.stats.record_request(method, self._endpoint_label(request), process_time)

        # Slow requests keep the concrete path for debugging; the buffer is bounded.
        if process_time > self.slow_query_threshold:
            path = request.url.path
            logger.warning("Slow request: %s %s took %.2fs", method, path, process_time)
            DBConnectionMonitor.stats.record_slow_query(method, path, process_time)
|
|
|
|
|
|
def setup_db_monitor(app: FastAPI, log_interval: int = 100, slow_query_threshold: float = 1.0):
    """Install database monitoring on *app*.

    Adds ``DBMonitorMiddleware`` with the given options and registers
    startup/shutdown handlers that each log one connection-pool snapshot.

    Args:
        app: the FastAPI application to instrument.
        log_interval: forwarded to ``DBMonitorMiddleware``.
        slow_query_threshold: forwarded to ``DBMonitorMiddleware`` (seconds).
    """
    middleware_options = {
        "log_interval": log_interval,
        "slow_query_threshold": slow_query_threshold,
    }
    app.add_middleware(DBMonitorMiddleware, **middleware_options)

    async def startup_db_monitor():
        logger.info("Starting DB connection monitoring")
        DBConnectionMonitor.log_connection_stats()

    async def shutdown_db_monitor():
        logger.info("Final DB connection stats before shutdown")
        DBConnectionMonitor.log_connection_stats()

    # Apply the on_event decorator factories explicitly; identical to using
    # them as @-decorators on the handlers above.
    # NOTE(review): newer FastAPI releases deprecate on_event in favour of
    # lifespan handlers - consider migrating when the app setup allows it.
    app.on_event("startup")(startup_db_monitor)
    app.on_event("shutdown")(shutdown_db_monitor)