86 lines
2.8 KiB
Python
86 lines
2.8 KiB
Python
from fastapi import Request, Response
|
||
from fastapi.responses import JSONResponse
|
||
from starlette.middleware.base import BaseHTTPMiddleware
|
||
from typing import Any, Dict, Optional, Union
|
||
import json
|
||
|
||
class ResponseWrapperMiddleware(BaseHTTPMiddleware):
|
||
"""
|
||
响应包装中间件,自动将所有API响应包装为标准格式
|
||
格式为:
|
||
{
|
||
"success": true/false,
|
||
"code": 200,
|
||
"message": "操作成功",
|
||
"data": 原始响应数据
|
||
}
|
||
"""
|
||
|
||
async def dispatch(
|
||
self, request: Request, call_next
|
||
) -> Response:
|
||
# 排除不需要包装的路径
|
||
if self._should_skip_path(request.url.path):
|
||
return await call_next(request)
|
||
|
||
# 调用下一个中间件或路由处理函数
|
||
response = await call_next(request)
|
||
|
||
# 如果响应是JSON,且未使用标准包装格式,则进行包装
|
||
if (
|
||
isinstance(response, JSONResponse) and
|
||
self._should_wrap_response(response)
|
||
):
|
||
return self._wrap_response(response)
|
||
|
||
return response
|
||
|
||
def _should_skip_path(self, path: str) -> bool:
|
||
"""判断是否跳过包装处理"""
|
||
# 跳过文档相关路径
|
||
skip_paths = ["/docs", "/redoc", "/openapi.json"]
|
||
|
||
for skip_path in skip_paths:
|
||
if path.startswith(skip_path):
|
||
return True
|
||
|
||
return False
|
||
|
||
def _should_wrap_response(self, response: JSONResponse) -> bool:
|
||
"""判断是否需要包装响应"""
|
||
try:
|
||
content = response.body.decode()
|
||
data = json.loads(content)
|
||
|
||
# 已经是标准格式则不需要再包装
|
||
if isinstance(data, dict) and "success" in data and "code" in data and "message" in data:
|
||
return False
|
||
|
||
return True
|
||
except Exception:
|
||
return False
|
||
|
||
def _wrap_response(self, response: JSONResponse) -> JSONResponse:
|
||
"""包装响应为标准格式"""
|
||
try:
|
||
# 解析原始响应内容
|
||
content = response.body.decode()
|
||
data = json.loads(content)
|
||
|
||
# 构造标准格式响应
|
||
wrapped_data = {
|
||
"success": response.status_code < 400,
|
||
"code": response.status_code,
|
||
"message": "操作成功" if response.status_code < 400 else "操作失败",
|
||
"data": data
|
||
}
|
||
|
||
# 创建新的响应
|
||
return JSONResponse(
|
||
content=wrapped_data,
|
||
status_code=response.status_code,
|
||
headers=dict(response.headers),
|
||
)
|
||
except Exception:
|
||
# 出错时返回原始响应
|
||
return response |