From 8e578953ccda626e8ea2450a7ddd0d371bdd5fbf Mon Sep 17 00:00:00 2001 From: aaron <> Date: Mon, 10 Mar 2025 21:15:01 +0800 Subject: [PATCH] update --- app/api/endpoints/health.py | 83 ++++++++++++++++++++++++++-- app/core/db_monitor.py | 104 +++++++++++++++++++++++++++++++++--- app/models/database.py | 17 ++++++ 3 files changed, 195 insertions(+), 9 deletions(-) diff --git a/app/api/endpoints/health.py b/app/api/endpoints/health.py index d5baeff..86a4cae 100644 --- a/app/api/endpoints/health.py +++ b/app/api/endpoints/health.py @@ -4,6 +4,7 @@ from sqlalchemy import text from app.models.database import get_db, get_active_sessions_count, get_long_running_sessions from app.core.response import success_response, ResponseModel from app.core.db_monitor import DBConnectionMonitor +from typing import Optional, List router = APIRouter() @@ -38,14 +39,42 @@ async def health_check(db: Session = Depends(get_db)): async def performance_stats( include_slow_queries: bool = Query(False, description="是否包含慢查询记录"), include_long_sessions: bool = Query(False, description="是否包含长时间运行的会话详情"), + min_duration: Optional[float] = Query(None, description="慢查询最小持续时间(秒)"), + table_filter: Optional[str] = Query(None, description="按表名筛选慢查询"), + query_type: Optional[str] = Query(None, description="按查询类型筛选(SELECT, INSERT, UPDATE, DELETE)"), + limit: int = Query(20, ge=1, le=100, description="返回的慢查询数量限制"), db: Session = Depends(get_db) ): """获取详细的性能统计信息""" # 获取所有统计信息 all_stats = DBConnectionMonitor.get_all_stats() - # 如果不需要包含慢查询记录,则移除它们以减少响应大小 - if not include_slow_queries and "slow_queries" in all_stats["performance_stats"]: + # 处理慢查询记录 + if include_slow_queries and "slow_queries" in all_stats["performance_stats"]: + # 获取原始慢查询列表 + slow_queries = all_stats["performance_stats"]["slow_queries"] + + # 应用过滤条件 + filtered_queries = slow_queries + + if min_duration is not None: + filtered_queries = [q for q in filtered_queries if q.get("duration", 0) >= min_duration] + + if table_filter: + filtered_queries = [q for q in filtered_queries if table_filter.lower() in q.get("table", "").lower()] + + if query_type: + filtered_queries = [q for q in filtered_queries if q.get("query_type", "").upper() == query_type.upper()] + + # 按持续时间排序并限制数量 + sorted_queries = sorted(filtered_queries, key=lambda x: x.get("duration", 0), reverse=True)[:limit] + + # 更新统计信息 + all_stats["performance_stats"]["slow_queries"] = sorted_queries + all_stats["performance_stats"]["slow_queries_total_count"] = len(slow_queries) + all_stats["performance_stats"]["slow_queries_filtered_count"] = len(filtered_queries) + elif "slow_queries" in all_stats["performance_stats"]: + # 如果不包含慢查询记录,只返回计数 all_stats["performance_stats"]["slow_queries_count"] = len(all_stats["performance_stats"]["slow_queries"]) del all_stats["performance_stats"]["slow_queries"] @@ -60,4 +89,52 @@ async def performance_stats( else: all_stats["sessions"]["long_running_count"] = len(get_long_running_sessions(threshold_seconds=30)) - return success_response(data=all_stats) \ No newline at end of file + return success_response(data=all_stats) + +@router.get("/slow-queries", response_model=ResponseModel) +async def get_slow_queries( + min_duration: Optional[float] = Query(0.5, description="最小持续时间(秒)"), + table_filter: Optional[str] = Query(None, description="按表名筛选"), + query_type: Optional[str] = Query(None, description="按查询类型筛选(SELECT, INSERT, UPDATE, DELETE)"), + path_filter: Optional[str] = Query(None, description="按API路径筛选"), + limit: int = Query(50, ge=1, le=100, description="返回的记录数量限制"), + db: Session = Depends(get_db) +): + """获取慢查询记录,支持多种过滤条件""" + # 获取所有慢查询 + all_stats = DBConnectionMonitor.get_all_stats() + slow_queries = all_stats["performance_stats"].get("slow_queries", []) + + # 应用过滤条件 + filtered_queries = slow_queries + + if min_duration is not None: + filtered_queries = [q for q in filtered_queries if q.get("duration", 0) >= min_duration] + + if table_filter: + filtered_queries = [q for q in filtered_queries if table_filter.lower() in q.get("table", "").lower()] + + if query_type: + filtered_queries = [q for q in filtered_queries if q.get("query_type", "").upper() == query_type.upper()] + + if path_filter: + filtered_queries = [q for q in filtered_queries if path_filter.lower() in q.get("path", "").lower()] + + # 按持续时间排序并限制数量 + sorted_queries = sorted(filtered_queries, key=lambda x: x.get("duration", 0), reverse=True)[:limit] + + # 计算统计信息 + stats = { + "total_count": len(slow_queries), + "filtered_count": len(filtered_queries), + "displayed_count": len(sorted_queries), + "queries": sorted_queries + } + + # 如果有查询,计算平均持续时间 + if sorted_queries: + stats["avg_duration"] = sum(q.get("duration", 0) for q in sorted_queries) / len(sorted_queries) + stats["max_duration"] = max(q.get("duration", 0) for q in sorted_queries) + stats["min_duration"] = min(q.get("duration", 0) for q in sorted_queries) + + return success_response(data=stats) \ No newline at end of file diff --git a/app/core/db_monitor.py b/app/core/db_monitor.py index 978862d..5526398 100644 --- a/app/core/db_monitor.py +++ b/app/core/db_monitor.py @@ -5,28 +5,89 @@ from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request import time import threading -from collections import defaultdict, deque +from collections import defaultdict, deque, Counter import datetime +import json +import re logger = logging.getLogger(__name__) +# SQL查询模式识别正则表达式 +SELECT_PATTERN = re.compile(r'SELECT\s+.*?\s+FROM\s+(\w+)', re.IGNORECASE) +INSERT_PATTERN = re.compile(r'INSERT\s+INTO\s+(\w+)', re.IGNORECASE) +UPDATE_PATTERN = re.compile(r'UPDATE\s+(\w+)\s+SET', re.IGNORECASE) +DELETE_PATTERN = re.compile(r'DELETE\s+FROM\s+(\w+)', re.IGNORECASE) + +def extract_table_name(sql): + """从SQL语句中提取表名""" + for pattern in [SELECT_PATTERN, INSERT_PATTERN, UPDATE_PATTERN, DELETE_PATTERN]: + match = pattern.search(sql) + if match: + return match.group(1) + return "unknown" + class DBStats: """数据库统计信息""" def __init__(self): self.slow_queries = deque(maxlen=100) # 最多保存100条慢查询记录 self.endpoint_stats = defaultdict(lambda: {"count": 0, "total_time": 0, "max_time": 0}) self.hourly_requests = defaultdict(int) + self.table_access_count = Counter() # 表访问计数 + self.query_patterns = Counter() # 查询模式计数 self.lock = threading.Lock() - def record_slow_query(self, method, path, duration): + def record_slow_query(self, method, path, duration, sql=None, params=None): """记录慢查询""" with self.lock: - self.slow_queries.append({ + # 提取表名和查询类型 + table_name = "unknown" + query_type = "unknown" + + if sql: + # 确定查询类型 + if sql.strip().upper().startswith("SELECT"): + query_type = "SELECT" + elif sql.strip().upper().startswith("INSERT"): + query_type = "INSERT" + elif sql.strip().upper().startswith("UPDATE"): + query_type = "UPDATE" + elif sql.strip().upper().startswith("DELETE"): + query_type = "DELETE" + + # 提取表名 + table_name = extract_table_name(sql) + + # 更新表访问统计 + self.table_access_count[table_name] += 1 + + # 更新查询模式统计 + query_pattern = f"{query_type} {table_name}" + self.query_patterns[query_pattern] += 1 + + # 记录慢查询 + query_info = { "method": method, "path": path, "duration": duration, - "timestamp": datetime.datetime.now().isoformat() - }) + "timestamp": datetime.datetime.now().isoformat(), + "query_type": query_type, + "table": table_name + } + + # 如果在调试模式下,添加SQL和参数信息 + if sql: + query_info["sql"] = ' '.join(sql.split())[:500] # 截断并移除多余空白 + + if params and isinstance(params, (dict, list, tuple)): + try: + # 尝试安全地序列化参数 + param_str = json.dumps(params) + if len(param_str) <= 500: # 限制参数长度 + query_info["params"] = param_str + except: + pass + + self.slow_queries.append(query_info) def record_request(self, method, path, duration): """记录请求统计""" @@ -56,19 +117,47 @@ class DBStats: reverse=True )[:10] + # 获取最常访问的表(前10个) + top_tables = self.table_access_count.most_common(10) + + # 获取最常见的查询模式(前10个) + top_query_patterns = self.query_patterns.most_common(10) + return { "slow_queries": list(self.slow_queries), "top_slow_endpoints": sorted_endpoints, - "hourly_requests": dict(self.hourly_requests) + "hourly_requests": dict(self.hourly_requests), + "top_accessed_tables": top_tables, + "top_query_patterns": top_query_patterns } +# 全局请求上下文 +request_context = threading.local() + class DBConnectionMonitor: """数据库连接监控工具""" # 初始化统计对象 stats = DBStats() + @staticmethod + def set_request_context(request): + """设置当前请求上下文""" + request_context.method = request.method + request_context.path = request.url.path + + @staticmethod + def get_request_context(): + """获取当前请求上下文""" + return getattr(request_context, 'method', 'UNKNOWN'), getattr(request_context, 'path', 'UNKNOWN') + + @staticmethod + def record_slow_query(duration, sql=None, params=None): + """记录慢查询""" + method, path = DBConnectionMonitor.get_request_context() + DBConnectionMonitor.stats.record_slow_query(method, path, duration, sql, params) + @staticmethod def get_connection_stats(): """获取当前连接池状态""" @@ -108,6 +197,9 @@ class DBMonitorMiddleware(BaseHTTPMiddleware): async def dispatch(self, request: Request, call_next): self.request_count += 1 + # 设置请求上下文 + DBConnectionMonitor.set_request_context(request) + # 每处理N个请求记录一次连接池状态 if self.request_count % self.log_interval == 0: DBConnectionMonitor.log_connection_stats() diff --git a/app/models/database.py b/app/models/database.py index eac6900..6ea6971 100644 --- a/app/models/database.py +++ b/app/models/database.py @@ -37,11 +37,19 @@ Base = declarative_base() active_sessions = {} session_lock = threading.Lock() +# 导入监控工具(延迟导入以避免循环依赖) +def get_db_monitor(): + from app.core.db_monitor import DBConnectionMonitor + return DBConnectionMonitor + # 添加事件监听器,记录长时间运行的查询 @event.listens_for(engine, "before_cursor_execute") def before_cursor_execute(conn, cursor, statement, parameters, context, executemany): """在执行SQL查询前记录时间""" conn.info.setdefault('query_start_time', []).append(time.time()) + conn.info.setdefault('query_statement', []).append(statement) + conn.info.setdefault('query_parameters', []).append(parameters) + if settings.DEBUG: # 安全地记录SQL语句,避免敏感信息泄露 safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白 @@ -51,12 +59,21 @@ def before_cursor_execute(conn, cursor, statement, parameters, context, executem def after_cursor_execute(conn, cursor, statement, parameters, context, executemany): """在执行SQL查询后计算耗时并记录慢查询""" total = time.time() - conn.info['query_start_time'].pop(-1) + statement = conn.info['query_statement'].pop(-1) + parameters = conn.info['query_parameters'].pop(-1) # 记录慢查询 if total > 0.5: # 记录超过0.5秒的查询 # 安全地记录SQL语句,避免敏感信息泄露 safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白 logger.warning(f"慢查询 ({total:.2f}s): {safe_statement}...") + + # 记录到监控系统 + try: + monitor = get_db_monitor() + monitor.record_slow_query(total, statement, parameters) + except Exception as e: + logger.error(f"记录慢查询失败: {str(e)}") # 依赖项 def get_db():