115 lines
4.1 KiB
Python
115 lines
4.1 KiB
Python
from fastapi import Request
|
||
from starlette.middleware.base import BaseHTTPMiddleware
|
||
import time
|
||
from app.core.logger import log_request_async
|
||
from app.core.security import verify_token, decode_jwt
|
||
import json
|
||
import copy
|
||
import asyncio
|
||
from starlette.responses import JSONResponse
|
||
import logging
|
||
|
||
# Logger shared by this module; the level is pinned to DEBUG so that every
# record emitted here is passed on to the configured handlers.
logger = logging.getLogger(name=__name__)
logger.setLevel(level=logging.DEBUG)
|
||
|
||
class RequestLoggerMiddleware(BaseHTTPMiddleware):
    """Record mutating HTTP requests (POST/PUT/DELETE) via ``log_request_async``.

    Header values listed in ``SENSITIVE_HEADERS`` are always masked, and
    credential fields in the JSON body are masked for the endpoints listed in
    ``SENSITIVE_PATHS``. Requests whose path ends with one of
    ``SKIPPED_PATH_SUFFIXES`` are passed through without being logged.
    """

    # HTTP methods whose requests are recorded; all others pass straight through.
    LOGGED_METHODS = {"POST", "PUT", "DELETE"}

    # Endpoints whose JSON bodies carry credentials that must be masked.
    SENSITIVE_PATHS = {
        "/api/user/login",
        "/api/user/password-login",
        "/api/user/reset-password",
    }

    # Body keys whose values are replaced with "***" (at any nesting depth)
    # on SENSITIVE_PATHS.
    SENSITIVE_FIELDS = {"password", "verify_code", "old_password", "new_password"}

    # Request header names (compared lower-cased) whose values are never logged.
    SENSITIVE_HEADERS = {"authorization", "cookie"}

    # Path suffixes of probe endpoints that are not logged.
    SKIPPED_PATH_SUFFIXES = ("/health", "/ping")

    # Strong references to logging tasks scheduled by _emit_log; without them
    # a pending task may be garbage-collected before it runs.
    _background_tasks: set = set()

    def filter_sensitive_data(self, data: dict, path: str) -> dict:
        """Return a copy of *data* with credential fields masked.

        Args:
            data: Parsed JSON request body (usually a dict, but any JSON value
                is accepted).
            path: Request URL path.

        Returns:
            *data* unchanged when it is empty or *path* is not a sensitive
            endpoint; otherwise a new structure in which every key from
            SENSITIVE_FIELDS, in nested objects and arrays too, maps to "***".
        """
        # Compare without trailing slashes: "/api/user/login/" is seen by the
        # middleware before any routing/redirect happens, and would otherwise
        # have its password logged in clear text.
        if not data or path.rstrip("/") not in self.SENSITIVE_PATHS:
            return data
        return self._mask_fields(data)

    def _mask_fields(self, value):
        """Recursively copy *value*, replacing SENSITIVE_FIELDS values with "***"."""
        if isinstance(value, dict):
            return {
                key: "***" if key in self.SENSITIVE_FIELDS else self._mask_fields(item)
                for key, item in value.items()
            }
        if isinstance(value, list):
            return [self._mask_fields(item) for item in value]
        # Scalars (str, int, None, ...) carry no keys to mask.
        return value

    def filter_headers(self, headers: dict) -> dict:
        """Return a copy of *headers* with credential-bearing values masked.

        Header names are matched case-insensitively against SENSITIVE_HEADERS;
        the input mapping is not modified.
        """
        return {
            name: "***" if name.lower() in self.SENSITIVE_HEADERS else value
            for name, value in headers.items()
        }

    async def _read_body(self, request: Request, path: str):
        """Read the request body and return a log-safe representation of it.

        Returns None for an empty body, the masked JSON value for a valid JSON
        body, or a small placeholder dict when the body cannot be decoded,
        parsed, or read. Never raises (other than cancellation).
        """
        try:
            # Request.body() caches the bytes on the request itself, so no
            # manual ``request._body`` assignment is needed.
            # NOTE(review): whether the downstream app can re-read this body
            # depends on the installed Starlette version's BaseHTTPMiddleware
            # — confirm when upgrading.
            body_bytes = await request.body()
            if not body_bytes:
                return None
            try:
                parsed = json.loads(body_bytes.decode())
            except json.JSONDecodeError:
                return {"raw": "无法解析的JSON数据"}
            except UnicodeDecodeError:
                return {"raw": "无法解码的二进制数据"}
            return self.filter_sensitive_data(parsed, path)
        except Exception as e:
            # Best effort: a body that cannot be read or masked must not break
            # the request itself.
            logger.exception("读取请求体异常: %s", e)
            return {"error": "读取请求体异常"}

    def _emit_log(self, log_data: dict) -> None:
        """Hand *log_data* to log_request_async without letting failures escape."""
        try:
            result = log_request_async(log_data)
            # NOTE(review): the call is not awaited. If log_request_async is a
            # coroutine function, merely calling it would drop the record, so
            # schedule the coroutine on the running loop. A plain function
            # (returning None or a Task) is unaffected by this check.
            if asyncio.iscoroutine(result):
                task = asyncio.create_task(result)
                self._background_tasks.add(task)
                task.add_done_callback(self._background_tasks.discard)
        except Exception as e:
            logger.exception("记录请求日志异常: %s", e)

    async def dispatch(self, request: Request, call_next):
        """Forward the request and record it (with sensitive data masked).

        Non-logged methods and health-check paths are forwarded untouched.
        If the downstream app raises, a generic 500 JSON response is returned
        and that outcome is recorded as well.
        """
        method = request.method
        if method not in self.LOGGED_METHODS:
            return await call_next(request)

        # perf_counter is monotonic, so wall-clock adjustments cannot skew
        # the measured duration.
        start_time = time.perf_counter()
        path = request.url.path

        # Health-check probes are not logged.
        if path.endswith(self.SKIPPED_PATH_SUFFIXES):
            return await call_next(request)

        headers = self.filter_headers(dict(request.headers))
        query_params = dict(request.query_params)
        body = await self._read_body(request, path)

        try:
            response = await call_next(request)
        except Exception as e:
            # logger.exception keeps the traceback that a plain error() drops.
            logger.exception("请求处理异常: %s", e)
            response = JSONResponse(
                status_code=500,
                content={"code": 500, "message": "服务器内部错误"},
            )

        response_time = int((time.perf_counter() - start_time) * 1000)

        # Failed (500) requests are recorded too, so they appear in the log.
        self._emit_log({
            "path": path,
            "method": method,
            "headers": headers,
            "query_params": query_params,
            "body": body,
            "ip_address": request.client.host if request.client else "unknown",
            "status_code": response.status_code,
            "response_time": response_time,
        })
        return response