# Source listing metadata: 66 lines, 2.6 KiB, Python.
import json
import logging
import os

from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response, JSONResponse

from app.models.api_response import APIResponseModel
from app.utils.response import APIResponse


class ResponseWrapperMiddleware(BaseHTTPMiddleware):
    """Wrap plain JSON responses in the project's standard response envelope.

    Active only when the environment variable ``USE_STANDARD_RESPONSE`` equals
    ``"1"``. Every ``application/json`` response whose body is not already a
    dict containing ``code``, ``data`` and ``message`` is re-emitted as an
    ``APIResponseModel`` with ``code=200``, a fixed success message, and the
    original JSON body as ``data``. Health-check paths, non-JSON responses
    (e.g. file downloads) and bodies that cannot be parsed or wrapped are
    passed through with their original bytes and headers.
    """

    # Paths that are never wrapped, so health probes see the raw payload.
    HEALTH_ENDPOINTS = frozenset({"/health", "/"})

    # Headers that the new JSONResponse computes itself and that must not be
    # copied from the original response: a stale Content-Length would not
    # match the re-encoded (larger) body.
    _RECOMPUTED_HEADERS = (b"content-length", b"content-type")

    async def dispatch(self, request: Request, call_next):
        """Forward *request* downstream and wrap its JSON response if enabled.

        Args:
            request: The incoming request.
            call_next: Starlette's downstream callable; awaiting it yields the
                response produced by the rest of the application.

        Returns:
            The downstream response untouched (wrapping disabled, health path,
            or non-JSON content); a copy rebuilt from the already-read body
            (empty, already in standard form, or not wrappable); or a new
            ``JSONResponse`` carrying the wrapped payload.
        """
        # Read on every request so the feature can be toggled without a restart.
        use_standard_response = os.environ.get("USE_STANDARD_RESPONSE") == "1"
        if request.url.path in self.HEALTH_ENDPOINTS or not use_standard_response:
            return await call_next(request)

        response = await call_next(request)

        # BaseHTTPMiddleware's call_next always returns a streaming Response
        # subclass, never a JSONResponse, so isinstance checks cannot detect
        # JSON; the Content-Type header is the reliable signal. Anything else
        # (file downloads, HTML, ...) is returned without touching its body.
        if not self._is_json_content_type(response.headers.get("content-type", "")):
            return response

        # Drain the whole body; it may arrive in several chunks. After this the
        # original response object can no longer be sent, so every path below
        # must return a freshly built response.
        raw_body = b"".join([chunk async for chunk in response.body_iterator])
        if not raw_body:
            return self._rebuild(response, raw_body)

        try:
            # json.loads accepts bytes and detects UTF-8/16/32 itself.
            json_body = json.loads(raw_body)
            if self._is_standard_payload(json_body):
                # Already wrapped by the endpoint: pass the bytes through as-is.
                return self._rebuild(response, raw_body)
            # NOTE(review): code/message are fixed even for 4xx/5xx responses,
            # so an error body is reported inside a success envelope -- confirm
            # this is the intended contract.
            api_response = APIResponseModel(
                code=200,
                message="操作成功",
                data=json_body,
            )
            wrapped = JSONResponse(
                content=api_response.dict(),
                status_code=response.status_code,
            )
        except Exception:
            # Best effort at the middleware boundary: a wrapping failure must
            # not fail the request, but it should not be silent either.
            logging.getLogger(__name__).warning(
                "Could not wrap JSON response for %s; returning it unchanged",
                request.url.path,
                exc_info=True,
            )
            return self._rebuild(response, raw_body)

        # Copy the remaining headers from raw_headers (not a dict) so repeated
        # headers such as multiple Set-Cookie values are preserved.
        wrapped.raw_headers.extend(
            (name, value)
            for name, value in response.raw_headers
            if name.lower() not in self._RECOMPUTED_HEADERS
        )
        return wrapped

    @staticmethod
    def _is_json_content_type(content_type: str) -> bool:
        """Return True if *content_type* is JSON, ignoring parameters like charset."""
        return content_type.split(";", 1)[0].strip().lower() == "application/json"

    @staticmethod
    def _is_standard_payload(payload) -> bool:
        """Return True if *payload* already has the code/data/message envelope keys."""
        return isinstance(payload, dict) and {"code", "data", "message"} <= payload.keys()

    @staticmethod
    def _rebuild(response: Response, body: bytes) -> Response:
        """Return a sendable copy of *response* whose body iterator was drained.

        Args:
            response: The downstream response whose body has been consumed.
            body: The exact bytes that were read from it.

        Returns:
            A new ``Response`` with the same status, body and raw headers.
        """
        rebuilt = Response(content=body, status_code=response.status_code)
        # The body is byte-identical, so the original headers (including any
        # Content-Length and every repeated header) remain valid verbatim.
        rebuilt.raw_headers = [(name, value) for name, value in response.raw_headers]
        return rebuilt