api/app/core/middleware.py
2025-04-09 16:18:52 +08:00

109 lines
4.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Any, Callable, Dict, Optional
from fastapi import FastAPI, Request, Response
from fastapi.responses import JSONResponse
import json
import logging
import sys
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from app.schemas.response import StandardResponse, ErrorResponse
# 创建日志记录器并确保它正确配置
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# 添加控制台处理器确保日志显示
if not logger.handlers:
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False # 避免重复日志
class ResponseMiddleware(BaseHTTPMiddleware):
"""
中间件统一处理API响应格式
请求正确:{code:200, data:Any}
业务错误:{code:500, message:""}
"""
def __init__(self, app: ASGIApp):
logger.warning("======== 初始化响应中间件 ========")
super().__init__(app)
async def dispatch(self, request: Request, call_next: Callable) -> Response:
logger.warning(f"===== 中间件处理请求: {request.method} {request.url.path} =====")
# 不需要处理的路径
exclude_paths = ["/docs", "/redoc", "/openapi.json"]
if any(request.url.path.startswith(path) for path in exclude_paths):
logger.warning(f"跳过处理: {request.url.path}")
return await call_next(request)
# 获取原始响应
try:
response = await call_next(request)
logger.warning(f"请求处理完成: {request.url.path}, 状态码: {response.status_code}")
# 检查内容类型
content_type = response.headers.get("content-type", "")
logger.warning(f"响应内容类型: {content_type}")
# 如果不是JSON响应直接返回
if "application/json" not in content_type:
logger.warning(f"非JSON响应跳过处理: {content_type}")
return response
# 读取响应内容
try:
# 使用JSONResponse的方法获取内容
if isinstance(response, JSONResponse):
# 获取原始数据
raw_data = response.body.decode("utf-8")
data = json.loads(raw_data)
logger.warning(f"原始响应数据: {data}")
# 已经是标准格式,不再封装
if isinstance(data, dict) and "code" in data and ("data" in data or "message" in data):
logger.warning("响应已经是标准格式,不再包装")
return response
# 创建新的标准响应
std_response = StandardResponse(code=200, data=data)
logger.warning(f"包装为标准响应: code=200, data类型={type(data).__name__}")
# 创建新的JSONResponse
wrapped_response = JSONResponse(
content=std_response.model_dump(),
status_code=200, # 始终返回200状态码错误码在响应内容中
headers=dict(response.headers)
)
logger.warning("返回包装后的响应")
return wrapped_response
# 处理流式响应或其他类型响应
else:
logger.warning(f"非JSONResponse类型响应: {type(response)}")
return response
except Exception as e:
# 出现异常,记录日志
logger.error(f"处理响应内容时发生异常: {str(e)}", exc_info=True)
# 返回原始响应
return response
except Exception as e:
logger.error(f"中间件处理请求时发生异常: {str(e)}", exc_info=True)
raise
# 使用函数工厂模式创建中间件,确保每次创建新实例
def create_response_middleware():
return ResponseMiddleware
def add_response_middleware(app: FastAPI) -> None:
"""添加响应处理中间件到FastAPI应用"""
logger.warning("添加响应处理中间件到FastAPI应用")
# 使用函数工厂以确保每次获取新的中间件实例
app.add_middleware(create_response_middleware())