109 lines
4.6 KiB
Python
109 lines
4.6 KiB
Python
from typing import Any, Callable, Dict, Optional
|
||
from fastapi import FastAPI, Request, Response
|
||
from fastapi.responses import JSONResponse
|
||
import json
|
||
import logging
|
||
import sys
|
||
from starlette.middleware.base import BaseHTTPMiddleware
|
||
from starlette.types import ASGIApp
|
||
from app.schemas.response import StandardResponse, ErrorResponse
|
||
|
||
# 创建日志记录器并确保它正确配置
|
||
logger = logging.getLogger(__name__)
|
||
logger.setLevel(logging.DEBUG)
|
||
|
||
# 添加控制台处理器确保日志显示
|
||
if not logger.handlers:
|
||
handler = logging.StreamHandler(sys.stdout)
|
||
handler.setLevel(logging.DEBUG)
|
||
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||
handler.setFormatter(formatter)
|
||
logger.addHandler(handler)
|
||
logger.propagate = False # 避免重复日志
|
||
|
||
|
||
class ResponseMiddleware(BaseHTTPMiddleware):
|
||
"""
|
||
中间件:统一处理API响应格式
|
||
|
||
请求正确:{code:200, data:Any}
|
||
业务错误:{code:500, message:""}
|
||
"""
|
||
|
||
def __init__(self, app: ASGIApp):
|
||
logger.warning("======== 初始化响应中间件 ========")
|
||
super().__init__(app)
|
||
|
||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||
logger.warning(f"===== 中间件处理请求: {request.method} {request.url.path} =====")
|
||
|
||
# 不需要处理的路径
|
||
exclude_paths = ["/docs", "/redoc", "/openapi.json"]
|
||
if any(request.url.path.startswith(path) for path in exclude_paths):
|
||
logger.warning(f"跳过处理: {request.url.path}")
|
||
return await call_next(request)
|
||
|
||
# 获取原始响应
|
||
try:
|
||
response = await call_next(request)
|
||
logger.warning(f"请求处理完成: {request.url.path}, 状态码: {response.status_code}")
|
||
|
||
# 检查内容类型
|
||
content_type = response.headers.get("content-type", "")
|
||
logger.warning(f"响应内容类型: {content_type}")
|
||
|
||
# 如果不是JSON响应,直接返回
|
||
if "application/json" not in content_type:
|
||
logger.warning(f"非JSON响应,跳过处理: {content_type}")
|
||
return response
|
||
|
||
# 读取响应内容
|
||
try:
|
||
# 使用JSONResponse的方法获取内容
|
||
if isinstance(response, JSONResponse):
|
||
# 获取原始数据
|
||
raw_data = response.body.decode("utf-8")
|
||
data = json.loads(raw_data)
|
||
|
||
logger.warning(f"原始响应数据: {data}")
|
||
|
||
# 已经是标准格式,不再封装
|
||
if isinstance(data, dict) and "code" in data and ("data" in data or "message" in data):
|
||
logger.warning("响应已经是标准格式,不再包装")
|
||
return response
|
||
|
||
# 创建新的标准响应
|
||
std_response = StandardResponse(code=200, data=data)
|
||
logger.warning(f"包装为标准响应: code=200, data类型={type(data).__name__}")
|
||
|
||
# 创建新的JSONResponse
|
||
wrapped_response = JSONResponse(
|
||
content=std_response.model_dump(),
|
||
status_code=200, # 始终返回200状态码,错误码在响应内容中
|
||
headers=dict(response.headers)
|
||
)
|
||
logger.warning("返回包装后的响应")
|
||
return wrapped_response
|
||
# 处理流式响应或其他类型响应
|
||
else:
|
||
logger.warning(f"非JSONResponse类型响应: {type(response)}")
|
||
return response
|
||
|
||
except Exception as e:
|
||
# 出现异常,记录日志
|
||
logger.error(f"处理响应内容时发生异常: {str(e)}", exc_info=True)
|
||
# 返回原始响应
|
||
return response
|
||
except Exception as e:
|
||
logger.error(f"中间件处理请求时发生异常: {str(e)}", exc_info=True)
|
||
raise
|
||
|
||
# 使用函数工厂模式创建中间件,确保每次创建新实例
|
||
def create_response_middleware():
|
||
return ResponseMiddleware
|
||
|
||
def add_response_middleware(app: FastAPI) -> None:
|
||
"""添加响应处理中间件到FastAPI应用"""
|
||
logger.warning("添加响应处理中间件到FastAPI应用")
|
||
# 使用函数工厂以确保每次获取新的中间件实例
|
||
app.add_middleware(create_response_middleware()) |