This commit is contained in:
aaron 2025-03-10 21:15:01 +08:00
parent 595195dedc
commit 8e578953cc
3 changed files with 195 additions and 9 deletions

View File

@ -4,6 +4,7 @@ from sqlalchemy import text
from app.models.database import get_db, get_active_sessions_count, get_long_running_sessions from app.models.database import get_db, get_active_sessions_count, get_long_running_sessions
from app.core.response import success_response, ResponseModel from app.core.response import success_response, ResponseModel
from app.core.db_monitor import DBConnectionMonitor from app.core.db_monitor import DBConnectionMonitor
from typing import Optional, List
router = APIRouter() router = APIRouter()
@ -38,14 +39,42 @@ async def health_check(db: Session = Depends(get_db)):
async def performance_stats( async def performance_stats(
include_slow_queries: bool = Query(False, description="是否包含慢查询记录"), include_slow_queries: bool = Query(False, description="是否包含慢查询记录"),
include_long_sessions: bool = Query(False, description="是否包含长时间运行的会话详情"), include_long_sessions: bool = Query(False, description="是否包含长时间运行的会话详情"),
min_duration: Optional[float] = Query(None, description="慢查询最小持续时间(秒)"),
table_filter: Optional[str] = Query(None, description="按表名筛选慢查询"),
query_type: Optional[str] = Query(None, description="按查询类型筛选SELECT, INSERT, UPDATE, DELETE"),
limit: int = Query(20, ge=1, le=100, description="返回的慢查询数量限制"),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
"""获取详细的性能统计信息""" """获取详细的性能统计信息"""
# 获取所有统计信息 # 获取所有统计信息
all_stats = DBConnectionMonitor.get_all_stats() all_stats = DBConnectionMonitor.get_all_stats()
# 如果不需要包含慢查询记录,则移除它们以减少响应大小 # 处理慢查询记录
if not include_slow_queries and "slow_queries" in all_stats["performance_stats"]: if include_slow_queries and "slow_queries" in all_stats["performance_stats"]:
# 获取原始慢查询列表
slow_queries = all_stats["performance_stats"]["slow_queries"]
# 应用过滤条件
filtered_queries = slow_queries
if min_duration is not None:
filtered_queries = [q for q in filtered_queries if q.get("duration", 0) >= min_duration]
if table_filter:
filtered_queries = [q for q in filtered_queries if table_filter.lower() in q.get("table", "").lower()]
if query_type:
filtered_queries = [q for q in filtered_queries if q.get("query_type", "").upper() == query_type.upper()]
# 按持续时间排序并限制数量
sorted_queries = sorted(filtered_queries, key=lambda x: x.get("duration", 0), reverse=True)[:limit]
# 更新统计信息
all_stats["performance_stats"]["slow_queries"] = sorted_queries
all_stats["performance_stats"]["slow_queries_total_count"] = len(slow_queries)
all_stats["performance_stats"]["slow_queries_filtered_count"] = len(filtered_queries)
elif "slow_queries" in all_stats["performance_stats"]:
# 如果不包含慢查询记录,只返回计数
all_stats["performance_stats"]["slow_queries_count"] = len(all_stats["performance_stats"]["slow_queries"]) all_stats["performance_stats"]["slow_queries_count"] = len(all_stats["performance_stats"]["slow_queries"])
del all_stats["performance_stats"]["slow_queries"] del all_stats["performance_stats"]["slow_queries"]
@ -60,4 +89,52 @@ async def performance_stats(
else: else:
all_stats["sessions"]["long_running_count"] = len(get_long_running_sessions(threshold_seconds=30)) all_stats["sessions"]["long_running_count"] = len(get_long_running_sessions(threshold_seconds=30))
return success_response(data=all_stats) return success_response(data=all_stats)
@router.get("/slow-queries", response_model=ResponseModel)
async def get_slow_queries(
    min_duration: Optional[float] = Query(0.5, description="最小持续时间(秒)"),
    table_filter: Optional[str] = Query(None, description="按表名筛选"),
    query_type: Optional[str] = Query(None, description="按查询类型筛选SELECT, INSERT, UPDATE, DELETE"),
    path_filter: Optional[str] = Query(None, description="按API路径筛选"),
    limit: int = Query(50, ge=1, le=100, description="返回的记录数量限制"),
    db: Session = Depends(get_db)
):
    """Return recorded slow queries, narrowed by the optional filters.

    Filters are combined with AND. ``table_filter`` and ``path_filter`` are
    case-insensitive substring matches; ``query_type`` is a case-insensitive
    exact match. The surviving records are ordered by descending duration and
    cut to ``limit`` entries. Average/max/min durations are computed over the
    returned (displayed) records only and are omitted when nothing matches.
    """
    performance = DBConnectionMonitor.get_all_stats()["performance_stats"]
    recorded = performance.get("slow_queries", [])

    # Collect one predicate per filter that was actually supplied
    checks = []
    if min_duration is not None:
        checks.append(lambda q: q.get("duration", 0) >= min_duration)
    if table_filter:
        table_needle = table_filter.lower()
        checks.append(lambda q: table_needle in q.get("table", "").lower())
    if query_type:
        wanted_type = query_type.upper()
        checks.append(lambda q: q.get("query_type", "").upper() == wanted_type)
    if path_filter:
        path_needle = path_filter.lower()
        checks.append(lambda q: path_needle in q.get("path", "").lower())

    matching = [q for q in recorded if all(check(q) for check in checks)]

    # Slowest first, then keep only the requested number of records
    def _duration_of(entry):
        return entry.get("duration", 0)

    ranked = sorted(matching, key=_duration_of, reverse=True)
    shown = ranked[:limit]

    stats = {
        "total_count": len(recorded),
        "filtered_count": len(matching),
        "displayed_count": len(shown),
        "queries": shown
    }

    # Duration aggregates only make sense when at least one record is shown
    durations = [_duration_of(q) for q in shown]
    if durations:
        stats["avg_duration"] = sum(durations) / len(durations)
        stats["max_duration"] = max(durations)
        stats["min_duration"] = min(durations)

    return success_response(data=stats)

View File

@ -5,28 +5,89 @@ from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request from starlette.requests import Request
import time import time
import threading import threading
from collections import defaultdict, deque from collections import defaultdict, deque, Counter
import datetime import datetime
import json
import re
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Regular expressions that locate the target table of a SQL statement.
# DOTALL lets a SELECT column list span several lines before FROM.
SELECT_PATTERN = re.compile(r'SELECT\s+.*?\s+FROM\s+(\w+)', re.IGNORECASE | re.DOTALL)
INSERT_PATTERN = re.compile(r'INSERT\s+INTO\s+(\w+)', re.IGNORECASE)
UPDATE_PATTERN = re.compile(r'UPDATE\s+(\w+)\s+SET', re.IGNORECASE)
DELETE_PATTERN = re.compile(r'DELETE\s+FROM\s+(\w+)', re.IGNORECASE)


def extract_table_name(sql):
    """Extract the primary table name from a SQL statement.

    The pattern matching the statement's leading keyword is tried first and
    anchored at the start of the statement, so that e.g.
    ``INSERT INTO a SELECT ... FROM b`` is attributed to ``a`` rather than to
    the table of the embedded SELECT. If that does not match (for example a
    statement starting with ``WITH``), every pattern is searched anywhere in
    the text, in SELECT/INSERT/UPDATE/DELETE order.

    Args:
        sql: SQL text; may be empty or ``None``.

    Returns:
        The first identifier captured by the matching pattern, or
        ``"unknown"`` when nothing matches or ``sql`` is not a non-empty
        string.
    """
    if not sql or not isinstance(sql, str):
        return "unknown"

    patterns = {
        "SELECT": SELECT_PATTERN,
        "INSERT": INSERT_PATTERN,
        "UPDATE": UPDATE_PATTERN,
        "DELETE": DELETE_PATTERN,
    }

    statement = sql.lstrip()
    leading = statement.split(None, 1)
    primary = patterns.get(leading[0].upper()) if leading else None
    if primary is not None:
        # match() anchors at the statement start, so sub-selects are ignored
        match = primary.match(statement)
        if match:
            return match.group(1)

    # Fallback keeps the previous "first pattern found anywhere" behaviour
    for pattern in patterns.values():
        match = pattern.search(sql)
        if match:
            return match.group(1)
    return "unknown"
class DBStats: class DBStats:
"""数据库统计信息""" """数据库统计信息"""
def __init__(self): def __init__(self):
self.slow_queries = deque(maxlen=100) # 最多保存100条慢查询记录 self.slow_queries = deque(maxlen=100) # 最多保存100条慢查询记录
self.endpoint_stats = defaultdict(lambda: {"count": 0, "total_time": 0, "max_time": 0}) self.endpoint_stats = defaultdict(lambda: {"count": 0, "total_time": 0, "max_time": 0})
self.hourly_requests = defaultdict(int) self.hourly_requests = defaultdict(int)
self.table_access_count = Counter() # 表访问计数
self.query_patterns = Counter() # 查询模式计数
self.lock = threading.Lock() self.lock = threading.Lock()
def record_slow_query(self, method, path, duration, sql=None, params=None):
    """Store one slow-query record and update the table/pattern counters.

    Args:
        method: HTTP method of the request that issued the query.
        path: URL path of that request.
        duration: Query duration; stored as given.
        sql: Optional SQL text. When present it is classified as
            SELECT/INSERT/UPDATE/DELETE, its table name is extracted, and a
            whitespace-collapsed copy truncated to 500 characters is stored.
        params: Optional bound parameters (dict, list or tuple). Stored as a
            JSON string only together with ``sql``, only if serialisable and
            at most 500 characters long.

    NOTE(review): the original comment says SQL/params are meant to be
    attached "in debug mode", but nothing in this method checks a debug
    flag — confirm whether callers are expected to gate this.
    """
    table_name = "unknown"
    query_type = "unknown"
    compact_sql = None
    serialized_params = None

    # Parsing and serialisation are done before taking the lock so other
    # threads are not blocked on regex/JSON work.
    if sql:
        # All four recognised keywords are six characters long
        keyword = sql.strip()[:6].upper()
        if keyword in ("SELECT", "INSERT", "UPDATE", "DELETE"):
            query_type = keyword

        table_name = extract_table_name(sql)
        compact_sql = ' '.join(sql.split())[:500]  # collapse whitespace, truncate

        if params and isinstance(params, (dict, list, tuple)):
            try:
                candidate = json.dumps(params)
            except (TypeError, ValueError, RecursionError):
                # Parameters that cannot be serialised are deliberately
                # omitted; recording the query itself must still succeed.
                candidate = None
            if candidate is not None and len(candidate) <= 500:
                serialized_params = candidate

    query_info = {
        "method": method,
        "path": path,
        "duration": duration,
        "timestamp": datetime.datetime.now().isoformat(),
        "query_type": query_type,
        "table": table_name
    }
    if compact_sql is not None:
        query_info["sql"] = compact_sql
    if serialized_params is not None:
        query_info["params"] = serialized_params

    with self.lock:
        # Counters are updated even without SQL (bucketed under "unknown")
        self.table_access_count[table_name] += 1
        self.query_patterns[f"{query_type} {table_name}"] += 1
        self.slow_queries.append(query_info)
def record_request(self, method, path, duration): def record_request(self, method, path, duration):
"""记录请求统计""" """记录请求统计"""
@ -56,19 +117,47 @@ class DBStats:
reverse=True reverse=True
)[:10] )[:10]
# 获取最常访问的表前10个
top_tables = self.table_access_count.most_common(10)
# 获取最常见的查询模式前10个
top_query_patterns = self.query_patterns.most_common(10)
return { return {
"slow_queries": list(self.slow_queries), "slow_queries": list(self.slow_queries),
"top_slow_endpoints": sorted_endpoints, "top_slow_endpoints": sorted_endpoints,
"hourly_requests": dict(self.hourly_requests) "hourly_requests": dict(self.hourly_requests),
"top_accessed_tables": top_tables,
"top_query_patterns": top_query_patterns
} }
# Global request context shared by the middleware and the slow-query recorder
class _RequestContext:
    """Attribute bag whose values are local to the current execution context.

    Drop-in replacement for ``threading.local()``: attributes are set and
    read the same way, and reading an attribute that was never set raises
    ``AttributeError`` (so ``getattr(obj, name, default)`` keeps working).

    Values are kept in a ``contextvars.ContextVar`` instead of thread-local
    storage. With thread-local storage, all requests handled on the same
    event-loop thread share one slot and overwrite each other's values;
    with a context variable each asyncio task works on its own copy of the
    context, and context-propagating thread hand-offs (e.g.
    ``asyncio.to_thread``) carry the values along.
    """

    def __init__(self):
        # Function-scope import keeps this block self-contained
        import contextvars
        object.__setattr__(
            self,
            "_values",
            contextvars.ContextVar("db_monitor_request_context", default=None),
        )

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails
        current = self.__dict__["_values"].get()
        if not current or name not in current:
            raise AttributeError(name)
        return current[name]

    def __setattr__(self, name, value):
        var = self.__dict__["_values"]
        # Copy-on-write: never mutate a dict that another context may share
        var.set({**(var.get() or {}), name: value})

    def __delattr__(self, name):
        var = self.__dict__["_values"]
        current = var.get() or {}
        if name not in current:
            raise AttributeError(name)
        var.set({key: val for key, val in current.items() if key != name})


request_context = _RequestContext()
class DBConnectionMonitor: class DBConnectionMonitor:
"""数据库连接监控工具""" """数据库连接监控工具"""
# 初始化统计对象 # 初始化统计对象
stats = DBStats() stats = DBStats()
@staticmethod
def set_request_context(request):
    """Store the HTTP method and URL path of *request* on ``request_context``."""
    setattr(request_context, "method", request.method)
    setattr(request_context, "path", request.url.path)
@staticmethod
def get_request_context():
    """Return ``(method, path)`` from ``request_context``.

    Each element falls back to ``'UNKNOWN'`` when it has not been set.
    """
    current_method = getattr(request_context, 'method', 'UNKNOWN')
    current_path = getattr(request_context, 'path', 'UNKNOWN')
    return current_method, current_path
@staticmethod
def record_slow_query(duration, sql=None, params=None):
    """Record a slow query, tagged with the current request's method and path.

    Args:
        duration: Query duration, passed through unchanged.
        sql: Optional SQL text of the query.
        params: Optional bound parameters of the query.
    """
    http_method, url_path = DBConnectionMonitor.get_request_context()
    collector = DBConnectionMonitor.stats
    collector.record_slow_query(http_method, url_path, duration, sql, params)
@staticmethod @staticmethod
def get_connection_stats(): def get_connection_stats():
"""获取当前连接池状态""" """获取当前连接池状态"""
@ -108,6 +197,9 @@ class DBMonitorMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
self.request_count += 1 self.request_count += 1
# 设置请求上下文
DBConnectionMonitor.set_request_context(request)
# 每处理N个请求记录一次连接池状态 # 每处理N个请求记录一次连接池状态
if self.request_count % self.log_interval == 0: if self.request_count % self.log_interval == 0:
DBConnectionMonitor.log_connection_stats() DBConnectionMonitor.log_connection_stats()

View File

@ -37,11 +37,19 @@ Base = declarative_base()
active_sessions = {} active_sessions = {}
session_lock = threading.Lock() session_lock = threading.Lock()
# Resolve the monitoring tool lazily (deferred import avoids a circular dependency)
def get_db_monitor():
    """Return the ``DBConnectionMonitor`` class, importing its module at call time."""
    import app.core.db_monitor as monitor_module
    return monitor_module.DBConnectionMonitor
# 添加事件监听器,记录长时间运行的查询 # 添加事件监听器,记录长时间运行的查询
@event.listens_for(engine, "before_cursor_execute") @event.listens_for(engine, "before_cursor_execute")
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany): def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
"""在执行SQL查询前记录时间""" """在执行SQL查询前记录时间"""
conn.info.setdefault('query_start_time', []).append(time.time()) conn.info.setdefault('query_start_time', []).append(time.time())
conn.info.setdefault('query_statement', []).append(statement)
conn.info.setdefault('query_parameters', []).append(parameters)
if settings.DEBUG: if settings.DEBUG:
# 安全地记录SQL语句避免敏感信息泄露 # 安全地记录SQL语句避免敏感信息泄露
safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白 safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白
@ -51,12 +59,21 @@ def before_cursor_execute(conn, cursor, statement, parameters, context, executem
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
    """Compute the elapsed time of a statement and report slow queries.

    Pops the start time, statement and parameters that the before-hook
    pushed onto ``conn.info``. Statements that took longer than 0.5 seconds
    are logged (truncated) and handed to the DB monitor. A failure inside
    the monitor is logged and never propagates to the caller.
    """
    # Pop all three stacks unconditionally so they stay aligned with the
    # pushes made before execution, even when the query was fast.
    total = time.time() - conn.info['query_start_time'].pop(-1)
    statement = conn.info['query_statement'].pop(-1)
    parameters = conn.info['query_parameters'].pop(-1)

    # Only queries slower than 0.5 seconds are reported
    if total <= 0.5:
        return

    # Log a whitespace-collapsed, truncated statement to limit exposure
    safe_statement = ' '.join(statement.split())[:200]
    logger.warning(f"慢查询 ({total:.2f}s): {safe_statement}...")

    # Bound parameters may carry user data or credentials and the monitor
    # keeps what it receives in memory, so forward them only in DEBUG mode.
    recorded_parameters = parameters if settings.DEBUG else None

    # Hand the query to the monitoring system (best effort)
    try:
        monitor = get_db_monitor()
        monitor.record_slow_query(total, statement, recorded_parameters)
    except Exception as e:
        logger.error(f"记录慢查询失败: {str(e)}")
# 依赖项 # 依赖项
def get_db(): def get_db():