61 lines
2.3 KiB
Python
61 lines
2.3 KiB
Python
from typing import Any, Callable, Dict, Optional
|
||
from fastapi import FastAPI, Request, Response
|
||
from fastapi.responses import JSONResponse
|
||
import json
|
||
from starlette.middleware.base import BaseHTTPMiddleware
|
||
from starlette.types import ASGIApp
|
||
from app.schemas.response import StandardResponse, ErrorResponse
|
||
|
||
class ResponseMiddleware(BaseHTTPMiddleware):
|
||
"""
|
||
中间件:统一处理API响应格式
|
||
|
||
请求正确:{code:200, data:Any}
|
||
业务错误:{code:500, message:""}
|
||
"""
|
||
|
||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||
# 不需要处理的路径
|
||
exclude_paths = ["/docs", "/redoc", "/openapi.json"]
|
||
if any(request.url.path.startswith(path) for path in exclude_paths):
|
||
return await call_next(request)
|
||
|
||
# 获取原始响应
|
||
response = await call_next(request)
|
||
|
||
# 如果不是200系列响应或不是JSON响应,直接返回
|
||
if response.status_code >= 300 or response.headers.get("content-type") != "application/json":
|
||
return response
|
||
|
||
# 读取响应内容
|
||
try:
|
||
# 使用JSONResponse的_render方法获取内容
|
||
if isinstance(response, JSONResponse):
|
||
# 获取原始数据
|
||
raw_data = response.body.decode("utf-8")
|
||
data = json.loads(raw_data)
|
||
|
||
# 已经是标准格式,不再封装
|
||
if isinstance(data, dict) and "code" in data and ("data" in data or "message" in data):
|
||
return response
|
||
|
||
# 创建新的标准响应
|
||
std_response = StandardResponse(code=200, data=data)
|
||
|
||
# 创建新的JSONResponse
|
||
return JSONResponse(
|
||
content=std_response.model_dump(),
|
||
status_code=response.status_code,
|
||
headers=dict(response.headers)
|
||
)
|
||
# 处理流式响应或其他类型响应
|
||
else:
|
||
return response
|
||
|
||
except Exception as e:
|
||
# 出现异常,返回原始响应
|
||
return response
|
||
|
||
def add_response_middleware(app: FastAPI) -> None:
|
||
"""添加响应处理中间件到FastAPI应用"""
|
||
app.add_middleware(ResponseMiddleware) |