This commit is contained in:
aaron 2025-04-09 11:06:26 +08:00
parent 41e61b364d
commit f3b05805bb
14 changed files with 210 additions and 47 deletions

View File

@ -1 +0,0 @@
# API相关初始化文件

View File

@ -1 +0,0 @@
# API v1 版本初始化文件

View File

@ -7,6 +7,7 @@ from app.services import wechat as wechat_service
from app.services import user as user_service
from app.core.security import create_access_token
from app.db.database import get_db
from app.core.exceptions import BusinessError
router = APIRouter()
@ -22,7 +23,6 @@ async def wechat_login(
- 如果用户不存在则创建新用户
- 生成JWT令牌
"""
try:
# 调用微信API获取openid和unionid
openid, unionid = await wechat_service.code2session(login_data.code)
@ -49,9 +49,3 @@ async def wechat_login(
is_new_user=is_new_user,
openid=openid
)
except wechat_service.WechatLoginError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e)
)

View File

@ -7,6 +7,7 @@ from app.services import user as user_service
from app.db.database import get_db
from app.api.deps import get_current_user
from app.models.users import User as UserModel
from app.core.exceptions import BusinessError
router = APIRouter()
@ -33,4 +34,17 @@ async def update_user_me(
需要JWT令牌认证
"""
user = await user_service.update_user(db, user_id=current_user.id, user_update=user_update)
if user is None:
raise BusinessError("用户更新失败", code=500)
return user
@router.get("/{user_id}", response_model=User, tags=["users"])
async def read_user(
user_id: int,
db: AsyncSession = Depends(get_db)
):
"""获取指定用户信息"""
db_user = await user_service.get_user(db, user_id=user_id)
if db_user is None:
raise BusinessError("用户不存在", code=404)
return db_user

46
app/core/exceptions.py Normal file
View File

@ -0,0 +1,46 @@
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from app.schemas.response import ErrorResponse
class BusinessError(Exception):
"""业务错误异常,使用标准响应格式"""
def __init__(self, message: str, code: int = 500):
self.message = message
self.code = code
super().__init__(self.message)
# 业务异常处理器
async def business_exception_handler(request: Request, exc: BusinessError):
"""将业务异常转换为标准响应格式"""
return JSONResponse(
status_code=200, # 返回200状态码但在响应内容中设置业务错误码
content=ErrorResponse(
code=exc.code,
message=exc.message
).model_dump()
)
# 请求验证错误处理器
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""将请求验证错误转换为标准错误响应"""
error_messages = []
for error in exc.errors():
loc = error.get("loc", [])
loc_str = " -> ".join(str(l) for l in loc)
error_messages.append(f"{loc_str}: {error.get('msg')}")
error_message = ", ".join(error_messages)
return JSONResponse(
status_code=422, # 保持422状态码表示验证错误
content=ErrorResponse(
code=422,
message=f"请求参数验证错误: {error_message}"
).model_dump()
)
def add_exception_handlers(app):
"""添加异常处理器到FastAPI应用"""
app.add_exception_handler(BusinessError, business_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)

68
app/core/middleware.py Normal file
View File

@ -0,0 +1,68 @@
from typing import Any, Callable, Dict, Optional
from fastapi import FastAPI, Request, Response
from fastapi.responses import JSONResponse
import json
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from app.schemas.response import StandardResponse, ErrorResponse
class ResponseMiddleware(BaseHTTPMiddleware):
"""
中间件统一处理API响应格式
请求正确{code:200, data:Any}
业务错误{code:500, message:""}
"""
async def dispatch(self, request: Request, call_next: Callable) -> Response:
# 不需要处理的路径
exclude_paths = ["/docs", "/redoc", "/openapi.json"]
if any(request.url.path.startswith(path) for path in exclude_paths):
return await call_next(request)
response = await call_next(request)
# 如果是HTTPException直接返回不进行封装
if response.status_code >= 400:
return response
# 正常响应进行封装
if response.status_code < 300 and response.headers.get("content-type") == "application/json":
response_body = [chunk async for chunk in response.body_iterator]
response_body = b"".join(response_body)
if response_body:
try:
data = json.loads(response_body)
# 已经是统一格式的响应,不再封装
if isinstance(data, dict) and "code" in data and ("data" in data or "message" in data):
return Response(
content=response_body,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type
)
# 封装为标准响应格式
result = StandardResponse(code=200, data=data).model_dump()
return JSONResponse(
content=result,
status_code=response.status_code,
headers=dict(response.headers)
)
except json.JSONDecodeError:
# 非JSON响应直接返回
return Response(
content=response_body,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type
)
return response
def add_response_middleware(app: FastAPI) -> None:
"""添加响应处理中间件到FastAPI应用"""
app.add_middleware(ResponseMiddleware)

View File

@ -13,9 +13,9 @@ def create_access_token(
subject: Union[str, Any], expires_delta: Optional[timedelta] = None
) -> str:
if expires_delta:
expire = datetime.now() + expires_delta
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.now() + timedelta(
expire = datetime.utcnow() + timedelta(
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
)
to_encode = {"exp": expire, "sub": str(subject)}

View File

@ -1 +0,0 @@
# 数据库相关初始化文件

View File

@ -1,2 +0,0 @@
# 导入所有模型确保它们被SQLAlchemy注册
from app.models.users import User

View File

@ -1 +0,0 @@
# 数据模式初始化文件

35
app/schemas/response.py Normal file
View File

@ -0,0 +1,35 @@
from typing import Any, Generic, Optional, TypeVar
from pydantic import BaseModel, Field
# from pydantic.generics import GenericModel
T = TypeVar('T')
class StandardResponse(BaseModel, Generic[T]):
"""标准API响应格式"""
code: int = 200
message: Optional[str] = None
data: Optional[T] = None
class Config:
json_schema_extra = {
"example": {
"code": 200,
"message": None,
"data": {}
}
}
class ErrorResponse(BaseModel):
"""错误响应格式"""
code: int = 500
message: str
data: None = None
class Config:
json_schema_extra = {
"example": {
"code": 500,
"message": "业务处理错误",
"data": None
}
}

View File

@ -1 +0,0 @@
# 服务层初始化文件

View File

@ -1,6 +1,7 @@
import httpx
from typing import Optional, Dict, Any, Tuple
from app.core.config import settings
from app.core.exceptions import BusinessError
class WechatLoginError(Exception):
"""微信登录错误"""
@ -17,7 +18,7 @@ async def code2session(code: str) -> Tuple[str, Optional[str]]:
Tuple[str, Optional[str]]: (openid, unionid)
Raises:
WechatLoginError: 当微信API调用失败时
BusinessError: 当微信API调用失败时
"""
url = "https://api.weixin.qq.com/sns/jscode2session"
params = {
@ -33,17 +34,19 @@ async def code2session(code: str) -> Tuple[str, Optional[str]]:
result = response.json()
if "errcode" in result and result["errcode"] != 0:
raise WechatLoginError(f"微信登录失败: {result.get('errmsg', '未知错误')}")
raise BusinessError(f"微信登录失败: {result.get('errmsg', '未知错误')}", code=500)
openid = result.get("openid")
unionid = result.get("unionid") # 可能为None
if not openid:
raise WechatLoginError("无法获取openid")
raise BusinessError("无法获取openid", code=500)
return openid, unionid
except httpx.RequestError as e:
raise WechatLoginError(f"网络请求失败: {str(e)}")
raise BusinessError(f"网络请求失败: {str(e)}", code=500)
except Exception as e:
raise WechatLoginError(f"未知错误: {str(e)}")
if isinstance(e, BusinessError):
raise
raise BusinessError(f"未知错误: {str(e)}", code=500)

12
main.py
View File

@ -2,7 +2,8 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.core.config import settings
from app.api.v1.api import api_router
from app.db.init_db import init_db
from app.core.middleware import add_response_middleware
from app.core.exceptions import add_exception_handlers
app = FastAPI(
title=settings.PROJECT_NAME,
@ -19,6 +20,12 @@ app.add_middleware(
allow_headers=["*"],
)
# 添加响应中间件
add_response_middleware(app)
# 添加异常处理器
add_exception_handlers(app)
# 包含API路由
app.include_router(api_router, prefix=settings.API_V1_STR)
@ -33,6 +40,9 @@ async def health_check():
# 应用启动事件
@app.on_event("startup")
async def startup_event():
# 延迟导入,避免循环导入问题
from app.db.init_db import init_db
# 调用异步初始化函数
await init_db()
if __name__ == "__main__":