update
This commit is contained in:
parent
595195dedc
commit
8e578953cc
@ -4,6 +4,7 @@ from sqlalchemy import text
|
||||
from app.models.database import get_db, get_active_sessions_count, get_long_running_sessions
|
||||
from app.core.response import success_response, ResponseModel
|
||||
from app.core.db_monitor import DBConnectionMonitor
|
||||
from typing import Optional, List
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@ -38,14 +39,42 @@ async def health_check(db: Session = Depends(get_db)):
|
||||
async def performance_stats(
|
||||
include_slow_queries: bool = Query(False, description="是否包含慢查询记录"),
|
||||
include_long_sessions: bool = Query(False, description="是否包含长时间运行的会话详情"),
|
||||
min_duration: Optional[float] = Query(None, description="慢查询最小持续时间(秒)"),
|
||||
table_filter: Optional[str] = Query(None, description="按表名筛选慢查询"),
|
||||
query_type: Optional[str] = Query(None, description="按查询类型筛选(SELECT, INSERT, UPDATE, DELETE)"),
|
||||
limit: int = Query(20, ge=1, le=100, description="返回的慢查询数量限制"),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取详细的性能统计信息"""
|
||||
# 获取所有统计信息
|
||||
all_stats = DBConnectionMonitor.get_all_stats()
|
||||
|
||||
# 如果不需要包含慢查询记录,则移除它们以减少响应大小
|
||||
if not include_slow_queries and "slow_queries" in all_stats["performance_stats"]:
|
||||
# 处理慢查询记录
|
||||
if include_slow_queries and "slow_queries" in all_stats["performance_stats"]:
|
||||
# 获取原始慢查询列表
|
||||
slow_queries = all_stats["performance_stats"]["slow_queries"]
|
||||
|
||||
# 应用过滤条件
|
||||
filtered_queries = slow_queries
|
||||
|
||||
if min_duration is not None:
|
||||
filtered_queries = [q for q in filtered_queries if q.get("duration", 0) >= min_duration]
|
||||
|
||||
if table_filter:
|
||||
filtered_queries = [q for q in filtered_queries if table_filter.lower() in q.get("table", "").lower()]
|
||||
|
||||
if query_type:
|
||||
filtered_queries = [q for q in filtered_queries if q.get("query_type", "").upper() == query_type.upper()]
|
||||
|
||||
# 按持续时间排序并限制数量
|
||||
sorted_queries = sorted(filtered_queries, key=lambda x: x.get("duration", 0), reverse=True)[:limit]
|
||||
|
||||
# 更新统计信息
|
||||
all_stats["performance_stats"]["slow_queries"] = sorted_queries
|
||||
all_stats["performance_stats"]["slow_queries_total_count"] = len(slow_queries)
|
||||
all_stats["performance_stats"]["slow_queries_filtered_count"] = len(filtered_queries)
|
||||
elif "slow_queries" in all_stats["performance_stats"]:
|
||||
# 如果不包含慢查询记录,只返回计数
|
||||
all_stats["performance_stats"]["slow_queries_count"] = len(all_stats["performance_stats"]["slow_queries"])
|
||||
del all_stats["performance_stats"]["slow_queries"]
|
||||
|
||||
@ -61,3 +90,51 @@ async def performance_stats(
|
||||
all_stats["sessions"]["long_running_count"] = len(get_long_running_sessions(threshold_seconds=30))
|
||||
|
||||
return success_response(data=all_stats)
|
||||
|
||||
@router.get("/slow-queries", response_model=ResponseModel)
|
||||
async def get_slow_queries(
|
||||
min_duration: Optional[float] = Query(0.5, description="最小持续时间(秒)"),
|
||||
table_filter: Optional[str] = Query(None, description="按表名筛选"),
|
||||
query_type: Optional[str] = Query(None, description="按查询类型筛选(SELECT, INSERT, UPDATE, DELETE)"),
|
||||
path_filter: Optional[str] = Query(None, description="按API路径筛选"),
|
||||
limit: int = Query(50, ge=1, le=100, description="返回的记录数量限制"),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""获取慢查询记录,支持多种过滤条件"""
|
||||
# 获取所有慢查询
|
||||
all_stats = DBConnectionMonitor.get_all_stats()
|
||||
slow_queries = all_stats["performance_stats"].get("slow_queries", [])
|
||||
|
||||
# 应用过滤条件
|
||||
filtered_queries = slow_queries
|
||||
|
||||
if min_duration is not None:
|
||||
filtered_queries = [q for q in filtered_queries if q.get("duration", 0) >= min_duration]
|
||||
|
||||
if table_filter:
|
||||
filtered_queries = [q for q in filtered_queries if table_filter.lower() in q.get("table", "").lower()]
|
||||
|
||||
if query_type:
|
||||
filtered_queries = [q for q in filtered_queries if q.get("query_type", "").upper() == query_type.upper()]
|
||||
|
||||
if path_filter:
|
||||
filtered_queries = [q for q in filtered_queries if path_filter.lower() in q.get("path", "").lower()]
|
||||
|
||||
# 按持续时间排序并限制数量
|
||||
sorted_queries = sorted(filtered_queries, key=lambda x: x.get("duration", 0), reverse=True)[:limit]
|
||||
|
||||
# 计算统计信息
|
||||
stats = {
|
||||
"total_count": len(slow_queries),
|
||||
"filtered_count": len(filtered_queries),
|
||||
"displayed_count": len(sorted_queries),
|
||||
"queries": sorted_queries
|
||||
}
|
||||
|
||||
# 如果有查询,计算平均持续时间
|
||||
if sorted_queries:
|
||||
stats["avg_duration"] = sum(q.get("duration", 0) for q in sorted_queries) / len(sorted_queries)
|
||||
stats["max_duration"] = max(q.get("duration", 0) for q in sorted_queries)
|
||||
stats["min_duration"] = min(q.get("duration", 0) for q in sorted_queries)
|
||||
|
||||
return success_response(data=stats)
|
||||
@ -5,28 +5,89 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
import time
|
||||
import threading
|
||||
from collections import defaultdict, deque
|
||||
from collections import defaultdict, deque, Counter
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# SQL查询模式识别正则表达式
|
||||
SELECT_PATTERN = re.compile(r'SELECT\s+.*?\s+FROM\s+(\w+)', re.IGNORECASE)
|
||||
INSERT_PATTERN = re.compile(r'INSERT\s+INTO\s+(\w+)', re.IGNORECASE)
|
||||
UPDATE_PATTERN = re.compile(r'UPDATE\s+(\w+)\s+SET', re.IGNORECASE)
|
||||
DELETE_PATTERN = re.compile(r'DELETE\s+FROM\s+(\w+)', re.IGNORECASE)
|
||||
|
||||
def extract_table_name(sql):
|
||||
"""从SQL语句中提取表名"""
|
||||
for pattern in [SELECT_PATTERN, INSERT_PATTERN, UPDATE_PATTERN, DELETE_PATTERN]:
|
||||
match = pattern.search(sql)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return "unknown"
|
||||
|
||||
class DBStats:
|
||||
"""数据库统计信息"""
|
||||
def __init__(self):
|
||||
self.slow_queries = deque(maxlen=100) # 最多保存100条慢查询记录
|
||||
self.endpoint_stats = defaultdict(lambda: {"count": 0, "total_time": 0, "max_time": 0})
|
||||
self.hourly_requests = defaultdict(int)
|
||||
self.table_access_count = Counter() # 表访问计数
|
||||
self.query_patterns = Counter() # 查询模式计数
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def record_slow_query(self, method, path, duration):
|
||||
def record_slow_query(self, method, path, duration, sql=None, params=None):
|
||||
"""记录慢查询"""
|
||||
with self.lock:
|
||||
self.slow_queries.append({
|
||||
# 提取表名和查询类型
|
||||
table_name = "unknown"
|
||||
query_type = "unknown"
|
||||
|
||||
if sql:
|
||||
# 确定查询类型
|
||||
if sql.strip().upper().startswith("SELECT"):
|
||||
query_type = "SELECT"
|
||||
elif sql.strip().upper().startswith("INSERT"):
|
||||
query_type = "INSERT"
|
||||
elif sql.strip().upper().startswith("UPDATE"):
|
||||
query_type = "UPDATE"
|
||||
elif sql.strip().upper().startswith("DELETE"):
|
||||
query_type = "DELETE"
|
||||
|
||||
# 提取表名
|
||||
table_name = extract_table_name(sql)
|
||||
|
||||
# 更新表访问统计
|
||||
self.table_access_count[table_name] += 1
|
||||
|
||||
# 更新查询模式统计
|
||||
query_pattern = f"{query_type} {table_name}"
|
||||
self.query_patterns[query_pattern] += 1
|
||||
|
||||
# 记录慢查询
|
||||
query_info = {
|
||||
"method": method,
|
||||
"path": path,
|
||||
"duration": duration,
|
||||
"timestamp": datetime.datetime.now().isoformat()
|
||||
})
|
||||
"timestamp": datetime.datetime.now().isoformat(),
|
||||
"query_type": query_type,
|
||||
"table": table_name
|
||||
}
|
||||
|
||||
# 如果在调试模式下,添加SQL和参数信息
|
||||
if sql:
|
||||
query_info["sql"] = ' '.join(sql.split())[:500] # 截断并移除多余空白
|
||||
|
||||
if params and isinstance(params, (dict, list, tuple)):
|
||||
try:
|
||||
# 尝试安全地序列化参数
|
||||
param_str = json.dumps(params)
|
||||
if len(param_str) <= 500: # 限制参数长度
|
||||
query_info["params"] = param_str
|
||||
except:
|
||||
pass
|
||||
|
||||
self.slow_queries.append(query_info)
|
||||
|
||||
def record_request(self, method, path, duration):
|
||||
"""记录请求统计"""
|
||||
@ -56,19 +117,47 @@ class DBStats:
|
||||
reverse=True
|
||||
)[:10]
|
||||
|
||||
# 获取最常访问的表(前10个)
|
||||
top_tables = self.table_access_count.most_common(10)
|
||||
|
||||
# 获取最常见的查询模式(前10个)
|
||||
top_query_patterns = self.query_patterns.most_common(10)
|
||||
|
||||
return {
|
||||
"slow_queries": list(self.slow_queries),
|
||||
"top_slow_endpoints": sorted_endpoints,
|
||||
"hourly_requests": dict(self.hourly_requests)
|
||||
"hourly_requests": dict(self.hourly_requests),
|
||||
"top_accessed_tables": top_tables,
|
||||
"top_query_patterns": top_query_patterns
|
||||
}
|
||||
|
||||
|
||||
# 全局请求上下文
|
||||
request_context = threading.local()
|
||||
|
||||
class DBConnectionMonitor:
|
||||
"""数据库连接监控工具"""
|
||||
|
||||
# 初始化统计对象
|
||||
stats = DBStats()
|
||||
|
||||
@staticmethod
|
||||
def set_request_context(request):
|
||||
"""设置当前请求上下文"""
|
||||
request_context.method = request.method
|
||||
request_context.path = request.url.path
|
||||
|
||||
@staticmethod
|
||||
def get_request_context():
|
||||
"""获取当前请求上下文"""
|
||||
return getattr(request_context, 'method', 'UNKNOWN'), getattr(request_context, 'path', 'UNKNOWN')
|
||||
|
||||
@staticmethod
|
||||
def record_slow_query(duration, sql=None, params=None):
|
||||
"""记录慢查询"""
|
||||
method, path = DBConnectionMonitor.get_request_context()
|
||||
DBConnectionMonitor.stats.record_slow_query(method, path, duration, sql, params)
|
||||
|
||||
@staticmethod
|
||||
def get_connection_stats():
|
||||
"""获取当前连接池状态"""
|
||||
@ -108,6 +197,9 @@ class DBMonitorMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
self.request_count += 1
|
||||
|
||||
# 设置请求上下文
|
||||
DBConnectionMonitor.set_request_context(request)
|
||||
|
||||
# 每处理N个请求记录一次连接池状态
|
||||
if self.request_count % self.log_interval == 0:
|
||||
DBConnectionMonitor.log_connection_stats()
|
||||
|
||||
@ -37,11 +37,19 @@ Base = declarative_base()
|
||||
active_sessions = {}
|
||||
session_lock = threading.Lock()
|
||||
|
||||
# 导入监控工具(延迟导入以避免循环依赖)
|
||||
def get_db_monitor():
|
||||
from app.core.db_monitor import DBConnectionMonitor
|
||||
return DBConnectionMonitor
|
||||
|
||||
# 添加事件监听器,记录长时间运行的查询
|
||||
@event.listens_for(engine, "before_cursor_execute")
|
||||
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||||
"""在执行SQL查询前记录时间"""
|
||||
conn.info.setdefault('query_start_time', []).append(time.time())
|
||||
conn.info.setdefault('query_statement', []).append(statement)
|
||||
conn.info.setdefault('query_parameters', []).append(parameters)
|
||||
|
||||
if settings.DEBUG:
|
||||
# 安全地记录SQL语句,避免敏感信息泄露
|
||||
safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白
|
||||
@ -51,6 +59,8 @@ def before_cursor_execute(conn, cursor, statement, parameters, context, executem
|
||||
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||||
"""在执行SQL查询后计算耗时并记录慢查询"""
|
||||
total = time.time() - conn.info['query_start_time'].pop(-1)
|
||||
statement = conn.info['query_statement'].pop(-1)
|
||||
parameters = conn.info['query_parameters'].pop(-1)
|
||||
|
||||
# 记录慢查询
|
||||
if total > 0.5: # 记录超过0.5秒的查询
|
||||
@ -58,6 +68,13 @@ def after_cursor_execute(conn, cursor, statement, parameters, context, executema
|
||||
safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白
|
||||
logger.warning(f"慢查询 ({total:.2f}s): {safe_statement}...")
|
||||
|
||||
# 记录到监控系统
|
||||
try:
|
||||
monitor = get_db_monitor()
|
||||
monitor.record_slow_query(total, statement, parameters)
|
||||
except Exception as e:
|
||||
logger.error(f"记录慢查询失败: {str(e)}")
|
||||
|
||||
# 依赖项
|
||||
def get_db():
|
||||
"""获取数据库会话"""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user