241 lines
8.7 KiB
Python
241 lines
8.7 KiB
Python
import contextvars
import datetime
import json
import logging
import re
import threading
import time
from collections import Counter, defaultdict, deque

from fastapi import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

from app.models.database import engine

logger = logging.getLogger(__name__)

# Regexes that capture the target table of common SQL statements.
# SELECT uses DOTALL so that a column list spanning several lines (typical of
# generated / ORM SQL) can still be matched up to its FROM clause.
SELECT_PATTERN = re.compile(r'SELECT\s+.*?\s+FROM\s+(\w+)', re.IGNORECASE | re.DOTALL)
INSERT_PATTERN = re.compile(r'INSERT\s+INTO\s+(\w+)', re.IGNORECASE)
UPDATE_PATTERN = re.compile(r'UPDATE\s+(\w+)\s+SET', re.IGNORECASE)
DELETE_PATTERN = re.compile(r'DELETE\s+FROM\s+(\w+)', re.IGNORECASE)

# Leading statement keyword -> pattern that finds that statement's own table.
_PATTERN_BY_KEYWORD = {
    "SELECT": SELECT_PATTERN,
    "INSERT": INSERT_PATTERN,
    "UPDATE": UPDATE_PATTERN,
    "DELETE": DELETE_PATTERN,
}


def extract_table_name(sql):
    """Extract the table name targeted by a SQL statement.

    The pattern for the statement's leading keyword is tried first. This makes
    ``UPDATE t SET x = (SELECT y FROM s)`` and ``INSERT INTO t SELECT ... FROM s``
    report ``t``, not the subquery's ``s``. If that pattern does not match, or
    the statement starts with some other keyword (e.g. ``WITH``), all patterns
    are tried in SELECT/INSERT/UPDATE/DELETE order.

    Args:
        sql: SQL statement text. ``None`` or empty is accepted.

    Returns:
        The first captured ``\\w+`` table name, or ``"unknown"`` if nothing matches.
    """
    if not sql:
        return "unknown"

    keyword_match = re.match(r'\s*(\w+)', sql)
    keyword = keyword_match.group(1).upper() if keyword_match else ""
    preferred = _PATTERN_BY_KEYWORD.get(keyword)

    candidates = [preferred] if preferred is not None else []
    candidates.extend(p for p in _PATTERN_BY_KEYWORD.values() if p is not preferred)

    for pattern in candidates:
        match = pattern.search(sql)
        if match:
            return match.group(1)
    return "unknown"

class DBStats:
    """In-memory request and slow-query statistics.

    The record_* and get_stats methods all run under ``self.lock``, so one
    instance can be shared between threads.
    """

    # Statement types recognised from the first keyword of a SQL string.
    _QUERY_TYPES = ("SELECT", "INSERT", "UPDATE", "DELETE")

    def __init__(self):
        # Most recent slow queries; entries beyond 100 are discarded (oldest first).
        self.slow_queries = deque(maxlen=100)
        # "METHOD path" -> {"count", "total_time", "max_time"}
        self.endpoint_stats = defaultdict(lambda: {"count": 0, "total_time": 0, "max_time": 0})
        # "YYYY-MM-DD HH:00" (local wall-clock hour) -> request count
        self.hourly_requests = defaultdict(int)
        # table name -> number of recorded SQL slow queries referencing it
        self.table_access_count = Counter()
        # "QUERYTYPE table" -> number of recorded SQL slow queries of that shape
        self.query_patterns = Counter()
        self.lock = threading.Lock()

    def record_slow_query(self, method, path, duration, sql=None, params=None):
        """Record one slow query, or one slow request when no SQL is given.

        Args:
            method: HTTP method of the request the query is attributed to.
            path: URL path of that request.
            duration: Elapsed time (seconds).
            sql: Optional SQL text. When present, its query type and table are
                derived and the table/pattern counters are updated.
            params: Optional bound parameters. They are stored as a JSON string
                only if they are a non-empty dict/list/tuple that serialises to
                at most 500 characters.
        """
        with self.lock:
            table_name = "unknown"
            query_type = "unknown"

            if sql:
                leading = sql.strip().upper()
                for candidate in self._QUERY_TYPES:
                    if leading.startswith(candidate):
                        query_type = candidate
                        break
                table_name = extract_table_name(sql)

                # Only SQL-bearing records count as table accesses. Slow HTTP
                # requests recorded without SQL would otherwise inflate an
                # "unknown" bucket in the table / pattern rankings.
                self.table_access_count[table_name] += 1
                self.query_patterns[f"{query_type} {table_name}"] += 1

            query_info = {
                "method": method,
                "path": path,
                "duration": duration,
                "timestamp": datetime.datetime.now().isoformat(),
                "query_type": query_type,
                "table": table_name
            }

            # NOTE(review): SQL text and parameters are stored unconditionally
            # (there is no debug-mode check) and are exposed via get_stats();
            # parameters may contain sensitive values.
            if sql:
                # Collapse all whitespace runs to single spaces, then truncate.
                query_info["sql"] = ' '.join(sql.split())[:500]

            if params and isinstance(params, (dict, list, tuple)):
                try:
                    param_str = json.dumps(params)
                except (TypeError, ValueError):
                    # Not JSON-serialisable (or circular): best effort, omit params.
                    param_str = None
                if param_str is not None and len(param_str) <= 500:
                    query_info["params"] = param_str

            self.slow_queries.append(query_info)

    def record_request(self, method, path, duration):
        """Add one request of *duration* seconds to the endpoint and hourly stats."""
        with self.lock:
            entry = self.endpoint_stats[f"{method} {path}"]
            entry["count"] += 1
            entry["total_time"] += duration
            entry["max_time"] = max(entry["max_time"], duration)

            # Bucket request counts by local wall-clock hour.
            hour = datetime.datetime.now().strftime("%Y-%m-%d %H:00")
            self.hourly_requests[hour] += 1

    def get_stats(self):
        """Return a snapshot of the collected statistics.

        Returns:
            A dict with these keys:

            - ``slow_queries``: list copy of the recorded slow-query dicts.
            - ``top_slow_endpoints``: up to 10 ``(endpoint, average_seconds)``
              pairs, slowest average first.
            - ``hourly_requests``: dict copy of the per-hour request counts.
            - ``top_accessed_tables`` and ``top_query_patterns``: up to 10
              ``(name, count)`` pairs each, most frequent first.
        """
        with self.lock:
            avg_times = {
                endpoint: entry["total_time"] / entry["count"]
                for endpoint, entry in self.endpoint_stats.items()
                if entry["count"] > 0
            }
            top_slow_endpoints = sorted(avg_times.items(), key=lambda item: item[1], reverse=True)[:10]

            return {
                "slow_queries": list(self.slow_queries),
                "top_slow_endpoints": top_slow_endpoints,
                "hourly_requests": dict(self.hourly_requests),
                "top_accessed_tables": self.table_access_count.most_common(10),
                "top_query_patterns": self.query_patterns.most_common(10)
            }

# Global per-request context: holds the HTTP method and path of the request
# currently being served.
#
# This used to be a ``threading.local()``. Under ASGI that is wrong on two
# counts. First, concurrent requests interleave on the event-loop thread and
# overwrite each other's values. Second, code running in worker threads never
# sees values set by the middleware on the loop thread. Backing the attributes
# with a ContextVar scopes them to the request's async task instead, while
# keeping the same attribute-style API:
#   request_context.method = ...
#   getattr(request_context, "path", default)
# NOTE(review): whether sync endpoints running in a threadpool inherit the
# context depends on Starlette/anyio copying contextvars into worker threads
# (recent versions do) — confirm for the versions in use.
class _ContextLocal:
    """Attribute namespace whose values are stored in a ``contextvars.ContextVar``."""

    def __init__(self, name="request_context"):
        # Bypass our own __setattr__, which would route this into the ContextVar.
        object.__setattr__(self, "_var", contextvars.ContextVar(name))

    def _snapshot(self):
        """Return the attribute dict visible in the current context (may be shared)."""
        var = self.__dict__.get("_var")
        return var.get({}) if var is not None else {}

    def __getattr__(self, name):
        # Only reached when normal lookup fails, i.e. for context attributes.
        try:
            return self._snapshot()[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        # Copy-on-write: never mutate a dict that a parent or sibling context
        # may still be referencing.
        data = dict(self._snapshot())
        data[name] = value
        self._var.set(data)

    def __delattr__(self, name):
        data = dict(self._snapshot())
        try:
            del data[name]
        except KeyError:
            raise AttributeError(name) from None
        self._var.set(data)


request_context = _ContextLocal()

class DBConnectionMonitor:
    """Static helpers for database connection-pool monitoring and slow-query tracking."""

    # Process-wide statistics object shared by every caller of this class.
    stats = DBStats()

    @staticmethod
    def set_request_context(request):
        """Store ``request.method`` and ``request.url.path`` in ``request_context``."""
        request_context.method = request.method
        request_context.path = request.url.path

    @staticmethod
    def get_request_context():
        """Return ``(method, path)`` for the current request; unset values read as ``'UNKNOWN'``."""
        return getattr(request_context, 'method', 'UNKNOWN'), getattr(request_context, 'path', 'UNKNOWN')

    @staticmethod
    def record_slow_query(duration, sql=None, params=None):
        """Record a slow query, attributing it to the current request context."""
        method, path = DBConnectionMonitor.get_request_context()
        DBConnectionMonitor.stats.record_slow_query(method, path, duration, sql, params)

    @staticmethod
    def get_connection_stats():
        """Return a snapshot of the engine's connection-pool counters.

        ``QueuePool`` provides ``size()``, ``checkedin()``, ``checkedout()`` and
        ``overflow()`` methods, but other pool classes (e.g. ``NullPool``) lack
        some of them. Metrics are read defensively rather than raising:

        - a missing metric is reported as ``None``;
        - a non-callable attribute is reported as its plain value.

        This matters because the function is called from request-handling
        middleware and at startup.
        """
        pool = engine.pool

        def _metric(attr_name):
            value = getattr(pool, attr_name, None)
            return value() if callable(value) else value

        return {
            "pool_size": _metric("size"),
            "checkedin": _metric("checkedin"),
            "checkedout": _metric("checkedout"),
            "overflow": _metric("overflow"),
        }

    @staticmethod
    def log_connection_stats():
        """Log the current connection-pool snapshot at INFO level and return it."""
        stats = DBConnectionMonitor.get_connection_stats()
        # Lazy %-formatting: the dict is only rendered if INFO is enabled.
        logger.info("DB Connection Pool Stats: %s", stats)
        return stats

    @staticmethod
    def get_all_stats():
        """Return the pool snapshot together with the collected performance stats."""
        return {
            "connection_pool": DBConnectionMonitor.get_connection_stats(),
            "performance_stats": DBConnectionMonitor.stats.get_stats()
        }

class DBMonitorMiddleware(BaseHTTPMiddleware):
    """Middleware that times requests and periodically logs connection-pool status."""

    def __init__(self, app: FastAPI, log_interval: int = 100, slow_query_threshold: float = 1.0):
        """Configure the middleware.

        Args:
            app: The application being wrapped.
            log_interval: Log pool status once every ``log_interval`` requests.
                A value <= 0 disables periodic logging; previously 0 raised
                ZeroDivisionError on every request.
            slow_query_threshold: Requests taking longer than this many seconds
                are logged and recorded as slow queries.
        """
        super().__init__(app)
        self.log_interval = log_interval
        self.request_count = 0
        self.slow_query_threshold = slow_query_threshold

    async def dispatch(self, request: Request, call_next):
        """Time the request, update stats, and occasionally log pool status."""
        self.request_count += 1

        # Make method/path available to code recording slow queries during this request.
        DBConnectionMonitor.set_request_context(request)

        # Every N requests, log the pool status.
        if self.log_interval > 0 and self.request_count % self.log_interval == 0:
            try:
                DBConnectionMonitor.log_connection_stats()
            except Exception:
                # Monitoring must never fail the request it is observing.
                logger.exception("Failed to log DB connection pool stats")

        # perf_counter is monotonic, so durations are unaffected by clock changes.
        start_time = time.perf_counter()
        response = await call_next(request)
        process_time = time.perf_counter() - start_time

        method = request.method
        path = request.url.path
        DBConnectionMonitor.stats.record_request(method, path, process_time)

        # A request slower than the threshold is recorded as a slow "query"
        # (whole-request time, no SQL attached).
        if process_time > self.slow_query_threshold:
            logger.warning("Slow request: %s %s took %.2fs", method, path, process_time)
            DBConnectionMonitor.stats.record_slow_query(method, path, process_time)

        return response

def setup_db_monitor(app: FastAPI, log_interval: int = 100, slow_query_threshold: float = 1.0):
    """Install database monitoring on *app*.

    Adds ``DBMonitorMiddleware`` with the given settings. Also registers
    startup and shutdown handlers that log a message and then the
    connection-pool status.

    NOTE(review): FastAPI documents ``on_event`` as deprecated in favour of
    lifespan handlers; kept here for compatibility with the existing app.
    """
    middleware_options = {
        "log_interval": log_interval,
        "slow_query_threshold": slow_query_threshold,
    }
    app.add_middleware(DBMonitorMiddleware, **middleware_options)

    async def startup_db_monitor():
        logger.info("Starting DB connection monitoring")
        DBConnectionMonitor.log_connection_stats()

    async def shutdown_db_monitor():
        logger.info("Final DB connection stats before shutdown")
        DBConnectionMonitor.log_connection_stats()

    # Same as decorating each handler with @app.on_event(...), in the same order.
    app.on_event("startup")(startup_db_monitor)
    app.on_event("shutdown")(shutdown_db_monitor)