aaron 2025-03-10 21:15:01 +08:00
parent 595195dedc
commit 8e578953cc
3 changed files with 195 additions and 9 deletions

View File

@@ -4,6 +4,7 @@ from sqlalchemy import text
from app.models.database import get_db, get_active_sessions_count, get_long_running_sessions
from app.core.response import success_response, ResponseModel
from app.core.db_monitor import DBConnectionMonitor
from typing import Optional, List
router = APIRouter()
@@ -38,14 +39,42 @@ async def health_check(db: Session = Depends(get_db)):
async def performance_stats(
include_slow_queries: bool = Query(False, description="Whether to include slow query records"),
include_long_sessions: bool = Query(False, description="Whether to include details of long-running sessions"),
min_duration: Optional[float] = Query(None, description="Minimum slow query duration (seconds)"),
table_filter: Optional[str] = Query(None, description="Filter slow queries by table name"),
query_type: Optional[str] = Query(None, description="Filter by query type (SELECT, INSERT, UPDATE, DELETE)"),
limit: int = Query(20, ge=1, le=100, description="Maximum number of slow queries to return"),
db: Session = Depends(get_db)
):
"""获取详细的性能统计信息"""
# 获取所有统计信息
all_stats = DBConnectionMonitor.get_all_stats()
# 如果不需要包含慢查询记录,则移除它们以减少响应大小
if not include_slow_queries and "slow_queries" in all_stats["performance_stats"]:
# 处理慢查询记录
if include_slow_queries and "slow_queries" in all_stats["performance_stats"]:
# 获取原始慢查询列表
slow_queries = all_stats["performance_stats"]["slow_queries"]
# 应用过滤条件
filtered_queries = slow_queries
if min_duration is not None:
filtered_queries = [q for q in filtered_queries if q.get("duration", 0) >= min_duration]
if table_filter:
filtered_queries = [q for q in filtered_queries if table_filter.lower() in q.get("table", "").lower()]
if query_type:
filtered_queries = [q for q in filtered_queries if q.get("query_type", "").upper() == query_type.upper()]
# Sort by duration and cap the result size
sorted_queries = sorted(filtered_queries, key=lambda x: x.get("duration", 0), reverse=True)[:limit]
# Update the statistics payload
all_stats["performance_stats"]["slow_queries"] = sorted_queries
all_stats["performance_stats"]["slow_queries_total_count"] = len(slow_queries)
all_stats["performance_stats"]["slow_queries_filtered_count"] = len(filtered_queries)
elif "slow_queries" in all_stats["performance_stats"]:
# When slow query records are excluded, return only their count
all_stats["performance_stats"]["slow_queries_count"] = len(all_stats["performance_stats"]["slow_queries"])
del all_stats["performance_stats"]["slow_queries"]
@@ -60,4 +89,52 @@ async def performance_stats(
else:
all_stats["sessions"]["long_running_count"] = len(get_long_running_sessions(threshold_seconds=30))
return success_response(data=all_stats)
return success_response(data=all_stats)
@router.get("/slow-queries", response_model=ResponseModel)
async def get_slow_queries(
min_duration: Optional[float] = Query(0.5, description="Minimum duration (seconds)"),
table_filter: Optional[str] = Query(None, description="Filter by table name"),
query_type: Optional[str] = Query(None, description="Filter by query type (SELECT, INSERT, UPDATE, DELETE)"),
path_filter: Optional[str] = Query(None, description="Filter by API path"),
limit: int = Query(50, ge=1, le=100, description="Maximum number of records to return"),
db: Session = Depends(get_db)
):
"""获取慢查询记录,支持多种过滤条件"""
# 获取所有慢查询
all_stats = DBConnectionMonitor.get_all_stats()
slow_queries = all_stats["performance_stats"].get("slow_queries", [])
# Apply filters
filtered_queries = slow_queries
if min_duration is not None:
filtered_queries = [q for q in filtered_queries if q.get("duration", 0) >= min_duration]
if table_filter:
filtered_queries = [q for q in filtered_queries if table_filter.lower() in q.get("table", "").lower()]
if query_type:
filtered_queries = [q for q in filtered_queries if q.get("query_type", "").upper() == query_type.upper()]
if path_filter:
filtered_queries = [q for q in filtered_queries if path_filter.lower() in q.get("path", "").lower()]
# Sort by duration and cap the result size
sorted_queries = sorted(filtered_queries, key=lambda x: x.get("duration", 0), reverse=True)[:limit]
# Compute summary statistics
stats = {
"total_count": len(slow_queries),
"filtered_count": len(filtered_queries),
"displayed_count": len(sorted_queries),
"queries": sorted_queries
}
# If any queries remain, compute duration statistics
if sorted_queries:
stats["avg_duration"] = sum(q.get("duration", 0) for q in sorted_queries) / len(sorted_queries)
stats["max_duration"] = max(q.get("duration", 0) for q in sorted_queries)
stats["min_duration"] = min(q.get("duration", 0) for q in sorted_queries)
return success_response(data=stats)
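For reference, the new endpoint could be exercised from a client roughly like the sketch below. The base URL, the router prefix, and the shape of the success_response envelope are assumptions, since none of them are visible in this diff.

```python
import httpx

BASE_URL = "http://localhost:8000"  # hypothetical deployment URL

resp = httpx.get(
    f"{BASE_URL}/slow-queries",
    params={
        "min_duration": 1.0,       # only queries slower than 1 second
        "table_filter": "orders",  # substring match against the extracted table name
        "query_type": "SELECT",
        "limit": 20,
    },
)
resp.raise_for_status()
payload = resp.json()
# Assumes success_response wraps the stats under a "data" key
stats = payload.get("data", payload)
print(stats["filtered_count"], stats.get("avg_duration"))
```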

View File

@@ -5,28 +5,89 @@ from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
import time
import threading
from collections import defaultdict, deque
from collections import defaultdict, deque, Counter
import datetime
import json
import re
logger = logging.getLogger(__name__)
# Regular expressions for recognizing SQL query patterns
SELECT_PATTERN = re.compile(r'SELECT\s+.*?\s+FROM\s+(\w+)', re.IGNORECASE)
INSERT_PATTERN = re.compile(r'INSERT\s+INTO\s+(\w+)', re.IGNORECASE)
UPDATE_PATTERN = re.compile(r'UPDATE\s+(\w+)\s+SET', re.IGNORECASE)
DELETE_PATTERN = re.compile(r'DELETE\s+FROM\s+(\w+)', re.IGNORECASE)
def extract_table_name(sql):
"""从SQL语句中提取表名"""
for pattern in [SELECT_PATTERN, INSERT_PATTERN, UPDATE_PATTERN, DELETE_PATTERN]:
match = pattern.search(sql)
if match:
return match.group(1)
return "unknown"
class DBStats:
"""数据库统计信息"""
def __init__(self):
self.slow_queries = deque(maxlen=100) # 最多保存100条慢查询记录
self.endpoint_stats = defaultdict(lambda: {"count": 0, "total_time": 0, "max_time": 0})
self.hourly_requests = defaultdict(int)
self.table_access_count = Counter() # 表访问计数
self.query_patterns = Counter() # 查询模式计数
self.lock = threading.Lock()
def record_slow_query(self, method, path, duration):
def record_slow_query(self, method, path, duration, sql=None, params=None):
"""记录慢查询"""
with self.lock:
self.slow_queries.append({
# 提取表名和查询类型
table_name = "unknown"
query_type = "unknown"
if sql:
# Determine the query type
if sql.strip().upper().startswith("SELECT"):
query_type = "SELECT"
elif sql.strip().upper().startswith("INSERT"):
query_type = "INSERT"
elif sql.strip().upper().startswith("UPDATE"):
query_type = "UPDATE"
elif sql.strip().upper().startswith("DELETE"):
query_type = "DELETE"
# Extract the table name
table_name = extract_table_name(sql)
# Update table access statistics
self.table_access_count[table_name] += 1
# Update query pattern statistics
query_pattern = f"{query_type} {table_name}"
self.query_patterns[query_pattern] += 1
# Build the slow query record
query_info = {
"method": method,
"path": path,
"duration": duration,
"timestamp": datetime.datetime.now().isoformat()
})
"timestamp": datetime.datetime.now().isoformat(),
"query_type": query_type,
"table": table_name
}
# When SQL text is available, attach a sanitized copy of the statement and its parameters
if sql:
query_info["sql"] = ' '.join(sql.split())[:500] # 截断并移除多余空白
if params and isinstance(params, (dict, list, tuple)):
try:
# Try to serialize the parameters safely
param_str = json.dumps(params)
if len(param_str) <= 500:  # cap the serialized parameter length
query_info["params"] = param_str
except (TypeError, ValueError):  # parameters are not JSON-serializable; skip them
pass
self.slow_queries.append(query_info)
def record_request(self, method, path, duration):
"""记录请求统计"""
@@ -56,19 +117,47 @@ class DBStats:
reverse=True
)[:10]
# Top 10 most frequently accessed tables
top_tables = self.table_access_count.most_common(10)
# Top 10 most common query patterns
top_query_patterns = self.query_patterns.most_common(10)
return {
"slow_queries": list(self.slow_queries),
"top_slow_endpoints": sorted_endpoints,
"hourly_requests": dict(self.hourly_requests)
"hourly_requests": dict(self.hourly_requests),
"top_accessed_tables": top_tables,
"top_query_patterns": top_query_patterns
}
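As a rough illustration of how these counters accumulate, a standalone sketch using the DBStats class above (the methods, paths, and SQL shown are hypothetical):

```python
stats = DBStats()

# Two slow SELECTs against the same table
stats.record_slow_query("GET", "/api/users", 0.8,
                        sql="SELECT id, name FROM users WHERE email = %s",
                        params=("a@example.com",))
stats.record_slow_query("GET", "/api/users/1", 1.2,
                        sql="SELECT * FROM users WHERE id = %s",
                        params=(1,))

print(stats.table_access_count.most_common(1))  # [('users', 2)]
print(stats.query_patterns.most_common(1))      # [('SELECT users', 2)]
print(len(stats.slow_queries))                  # 2 (the deque caps this at 100)
```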
# Global request context
request_context = threading.local()
class DBConnectionMonitor:
"""数据库连接监控工具"""
# 初始化统计对象
stats = DBStats()
@staticmethod
def set_request_context(request):
"""设置当前请求上下文"""
request_context.method = request.method
request_context.path = request.url.path
@staticmethod
def get_request_context():
"""获取当前请求上下文"""
return getattr(request_context, 'method', 'UNKNOWN'), getattr(request_context, 'path', 'UNKNOWN')
@staticmethod
def record_slow_query(duration, sql=None, params=None):
"""记录慢查询"""
method, path = DBConnectionMonitor.get_request_context()
DBConnectionMonitor.stats.record_slow_query(method, path, duration, sql, params)
@staticmethod
def get_connection_stats():
"""获取当前连接池状态"""
@@ -108,6 +197,9 @@ class DBMonitorMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
self.request_count += 1
# Set the request context
DBConnectionMonitor.set_request_context(request)
# Log the connection pool status once every N requests
if self.request_count % self.log_interval == 0:
DBConnectionMonitor.log_connection_stats()
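For context, wiring the middleware into the application might look like the sketch below. The import path and the middleware's constructor arguments (e.g. its log interval) are assumptions, as they sit outside this hunk.

```python
from fastapi import FastAPI

from app.core.db_monitor import DBMonitorMiddleware

app = FastAPI()
# The middleware records the method/path context for every request, so slow
# queries captured by the SQLAlchemy hooks can be attributed to an endpoint.
app.add_middleware(DBMonitorMiddleware)
```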

View File

@@ -37,11 +37,19 @@ Base = declarative_base()
active_sessions = {}
session_lock = threading.Lock()
# Import the monitoring tool lazily to avoid a circular dependency
def get_db_monitor():
from app.core.db_monitor import DBConnectionMonitor
return DBConnectionMonitor
# Register event listeners to record long-running queries
@event.listens_for(engine, "before_cursor_execute")
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
"""在执行SQL查询前记录时间"""
conn.info.setdefault('query_start_time', []).append(time.time())
conn.info.setdefault('query_statement', []).append(statement)
conn.info.setdefault('query_parameters', []).append(parameters)
if settings.DEBUG:
# Log the SQL statement safely to avoid leaking sensitive information
safe_statement = ' '.join(statement.split())[:200]  # collapse whitespace and truncate
@@ -51,12 +59,21 @@ def before_cursor_execute(conn, cursor, statement, parameters, context, executem
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
"""在执行SQL查询后计算耗时并记录慢查询"""
total = time.time() - conn.info['query_start_time'].pop(-1)
statement = conn.info['query_statement'].pop(-1)
parameters = conn.info['query_parameters'].pop(-1)
# Record slow queries
if total > 0.5:  # record queries that take longer than 0.5 seconds
# Log the SQL statement safely to avoid leaking sensitive information
safe_statement = ' '.join(statement.split())[:200]  # collapse whitespace and truncate
logger.warning(f"Slow query ({total:.2f}s): {safe_statement}...")
# Report to the monitoring system
try:
monitor = get_db_monitor()
monitor.record_slow_query(total, statement, parameters)
except Exception as e:
logger.error(f"Failed to record slow query: {str(e)}")
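One minimal way to exercise this path end to end (a sketch only: SessionLocal is assumed to be the session factory defined elsewhere in this module, and pg_sleep is PostgreSQL-specific):

```python
from sqlalchemy import text

# Any statement slower than the 0.5s threshold should be logged and forwarded
# to DBConnectionMonitor.record_slow_query by the after_cursor_execute hook.
db = SessionLocal()  # assumed session factory, not visible in this diff
try:
    db.execute(text("SELECT pg_sleep(1)"))  # PostgreSQL-only one-second sleep
finally:
    db.close()
```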
# Dependency
def get_db():