update
This commit is contained in:
parent
41e61b364d
commit
f3b05805bb
@ -1 +0,0 @@
|
||||
# API相关初始化文件
|
||||
@ -1 +0,0 @@
|
||||
# API v1 版本初始化文件
|
||||
@ -7,6 +7,7 @@ from app.services import wechat as wechat_service
|
||||
from app.services import user as user_service
|
||||
from app.core.security import create_access_token
|
||||
from app.db.database import get_db
|
||||
from app.core.exceptions import BusinessError
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@ -22,36 +23,29 @@ async def wechat_login(
|
||||
- 如果用户不存在,则创建新用户
|
||||
- 生成JWT令牌
|
||||
"""
|
||||
try:
|
||||
# 调用微信API获取openid和unionid
|
||||
openid, unionid = await wechat_service.code2session(login_data.code)
|
||||
|
||||
# 检查用户是否存在
|
||||
existing_user = await user_service.get_user_by_openid(db, openid=openid)
|
||||
is_new_user = existing_user is None
|
||||
|
||||
if is_new_user:
|
||||
# 创建新用户
|
||||
user_create = UserCreate(
|
||||
openid=openid,
|
||||
unionid=unionid
|
||||
)
|
||||
user = await user_service.create_user(db, user=user_create)
|
||||
else:
|
||||
user = existing_user
|
||||
|
||||
# 创建访问令牌 - 使用openid作为标识
|
||||
access_token = create_access_token(subject=openid)
|
||||
|
||||
# 返回登录响应
|
||||
return LoginResponse(
|
||||
access_token=access_token,
|
||||
is_new_user=is_new_user,
|
||||
openid=openid
|
||||
# 调用微信API获取openid和unionid
|
||||
openid, unionid = await wechat_service.code2session(login_data.code)
|
||||
|
||||
# 检查用户是否存在
|
||||
existing_user = await user_service.get_user_by_openid(db, openid=openid)
|
||||
is_new_user = existing_user is None
|
||||
|
||||
if is_new_user:
|
||||
# 创建新用户
|
||||
user_create = UserCreate(
|
||||
openid=openid,
|
||||
unionid=unionid
|
||||
)
|
||||
user = await user_service.create_user(db, user=user_create)
|
||||
else:
|
||||
user = existing_user
|
||||
|
||||
except wechat_service.WechatLoginError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e)
|
||||
)
|
||||
# 创建访问令牌 - 使用openid作为标识
|
||||
access_token = create_access_token(subject=openid)
|
||||
|
||||
# 返回登录响应
|
||||
return LoginResponse(
|
||||
access_token=access_token,
|
||||
is_new_user=is_new_user,
|
||||
openid=openid
|
||||
)
|
||||
@ -7,6 +7,7 @@ from app.services import user as user_service
|
||||
from app.db.database import get_db
|
||||
from app.api.deps import get_current_user
|
||||
from app.models.users import User as UserModel
|
||||
from app.core.exceptions import BusinessError
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@ -33,4 +34,17 @@ async def update_user_me(
|
||||
需要JWT令牌认证
|
||||
"""
|
||||
user = await user_service.update_user(db, user_id=current_user.id, user_update=user_update)
|
||||
return user
|
||||
if user is None:
|
||||
raise BusinessError("用户更新失败", code=500)
|
||||
return user
|
||||
|
||||
@router.get("/{user_id}", response_model=User, tags=["users"])
|
||||
async def read_user(
|
||||
user_id: int,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取指定用户信息"""
|
||||
db_user = await user_service.get_user(db, user_id=user_id)
|
||||
if db_user is None:
|
||||
raise BusinessError("用户不存在", code=404)
|
||||
return db_user
|
||||
46
app/core/exceptions.py
Normal file
46
app/core/exceptions.py
Normal file
@ -0,0 +1,46 @@
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from app.schemas.response import ErrorResponse
|
||||
|
||||
class BusinessError(Exception):
|
||||
"""业务错误异常,使用标准响应格式"""
|
||||
def __init__(self, message: str, code: int = 500):
|
||||
self.message = message
|
||||
self.code = code
|
||||
super().__init__(self.message)
|
||||
|
||||
# 业务异常处理器
|
||||
async def business_exception_handler(request: Request, exc: BusinessError):
|
||||
"""将业务异常转换为标准响应格式"""
|
||||
return JSONResponse(
|
||||
status_code=200, # 返回200状态码,但在响应内容中设置业务错误码
|
||||
content=ErrorResponse(
|
||||
code=exc.code,
|
||||
message=exc.message
|
||||
).model_dump()
|
||||
)
|
||||
|
||||
# 请求验证错误处理器
|
||||
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||
"""将请求验证错误转换为标准错误响应"""
|
||||
error_messages = []
|
||||
for error in exc.errors():
|
||||
loc = error.get("loc", [])
|
||||
loc_str = " -> ".join(str(l) for l in loc)
|
||||
error_messages.append(f"{loc_str}: {error.get('msg')}")
|
||||
|
||||
error_message = ", ".join(error_messages)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=422, # 保持422状态码,表示验证错误
|
||||
content=ErrorResponse(
|
||||
code=422,
|
||||
message=f"请求参数验证错误: {error_message}"
|
||||
).model_dump()
|
||||
)
|
||||
|
||||
def add_exception_handlers(app):
|
||||
"""添加异常处理器到FastAPI应用"""
|
||||
app.add_exception_handler(BusinessError, business_exception_handler)
|
||||
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||
68
app/core/middleware.py
Normal file
68
app/core/middleware.py
Normal file
@ -0,0 +1,68 @@
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from fastapi import FastAPI, Request, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
import json
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
from app.schemas.response import StandardResponse, ErrorResponse
|
||||
|
||||
class ResponseMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
中间件:统一处理API响应格式
|
||||
|
||||
请求正确:{code:200, data:Any}
|
||||
业务错误:{code:500, message:""}
|
||||
"""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
# 不需要处理的路径
|
||||
exclude_paths = ["/docs", "/redoc", "/openapi.json"]
|
||||
if any(request.url.path.startswith(path) for path in exclude_paths):
|
||||
return await call_next(request)
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
# 如果是HTTPException,直接返回,不进行封装
|
||||
if response.status_code >= 400:
|
||||
return response
|
||||
|
||||
# 正常响应进行封装
|
||||
if response.status_code < 300 and response.headers.get("content-type") == "application/json":
|
||||
response_body = [chunk async for chunk in response.body_iterator]
|
||||
response_body = b"".join(response_body)
|
||||
|
||||
if response_body:
|
||||
try:
|
||||
data = json.loads(response_body)
|
||||
|
||||
# 已经是统一格式的响应,不再封装
|
||||
if isinstance(data, dict) and "code" in data and ("data" in data or "message" in data):
|
||||
return Response(
|
||||
content=response_body,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.media_type
|
||||
)
|
||||
|
||||
# 封装为标准响应格式
|
||||
result = StandardResponse(code=200, data=data).model_dump()
|
||||
return JSONResponse(
|
||||
content=result,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers)
|
||||
)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# 非JSON响应,直接返回
|
||||
return Response(
|
||||
content=response_body,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.media_type
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def add_response_middleware(app: FastAPI) -> None:
|
||||
"""添加响应处理中间件到FastAPI应用"""
|
||||
app.add_middleware(ResponseMiddleware)
|
||||
@ -13,9 +13,9 @@ def create_access_token(
|
||||
subject: Union[str, Any], expires_delta: Optional[timedelta] = None
|
||||
) -> str:
|
||||
if expires_delta:
|
||||
expire = datetime.now() + expires_delta
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.now() + timedelta(
|
||||
expire = datetime.utcnow() + timedelta(
|
||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||
)
|
||||
to_encode = {"exp": expire, "sub": str(subject)}
|
||||
|
||||
@ -1 +0,0 @@
|
||||
# 数据库相关初始化文件
|
||||
@ -1,2 +0,0 @@
|
||||
# 导入所有模型,确保它们被SQLAlchemy注册
|
||||
from app.models.users import User
|
||||
@ -1 +0,0 @@
|
||||
# 数据模式初始化文件
|
||||
35
app/schemas/response.py
Normal file
35
app/schemas/response.py
Normal file
@ -0,0 +1,35 @@
|
||||
from typing import Any, Generic, Optional, TypeVar
|
||||
from pydantic import BaseModel, Field
|
||||
# from pydantic.generics import GenericModel
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
class StandardResponse(BaseModel, Generic[T]):
|
||||
"""标准API响应格式"""
|
||||
code: int = 200
|
||||
message: Optional[str] = None
|
||||
data: Optional[T] = None
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"code": 200,
|
||||
"message": None,
|
||||
"data": {}
|
||||
}
|
||||
}
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""错误响应格式"""
|
||||
code: int = 500
|
||||
message: str
|
||||
data: None = None
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"code": 500,
|
||||
"message": "业务处理错误",
|
||||
"data": None
|
||||
}
|
||||
}
|
||||
@ -1 +0,0 @@
|
||||
# 服务层初始化文件
|
||||
@ -1,6 +1,7 @@
|
||||
import httpx
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
from app.core.config import settings
|
||||
from app.core.exceptions import BusinessError
|
||||
|
||||
class WechatLoginError(Exception):
|
||||
"""微信登录错误"""
|
||||
@ -17,7 +18,7 @@ async def code2session(code: str) -> Tuple[str, Optional[str]]:
|
||||
Tuple[str, Optional[str]]: (openid, unionid)
|
||||
|
||||
Raises:
|
||||
WechatLoginError: 当微信API调用失败时
|
||||
BusinessError: 当微信API调用失败时
|
||||
"""
|
||||
url = "https://api.weixin.qq.com/sns/jscode2session"
|
||||
params = {
|
||||
@ -33,17 +34,19 @@ async def code2session(code: str) -> Tuple[str, Optional[str]]:
|
||||
result = response.json()
|
||||
|
||||
if "errcode" in result and result["errcode"] != 0:
|
||||
raise WechatLoginError(f"微信登录失败: {result.get('errmsg', '未知错误')}")
|
||||
raise BusinessError(f"微信登录失败: {result.get('errmsg', '未知错误')}", code=500)
|
||||
|
||||
openid = result.get("openid")
|
||||
unionid = result.get("unionid") # 可能为None
|
||||
|
||||
if not openid:
|
||||
raise WechatLoginError("无法获取openid")
|
||||
raise BusinessError("无法获取openid", code=500)
|
||||
|
||||
return openid, unionid
|
||||
|
||||
except httpx.RequestError as e:
|
||||
raise WechatLoginError(f"网络请求失败: {str(e)}")
|
||||
raise BusinessError(f"网络请求失败: {str(e)}", code=500)
|
||||
except Exception as e:
|
||||
raise WechatLoginError(f"未知错误: {str(e)}")
|
||||
if isinstance(e, BusinessError):
|
||||
raise
|
||||
raise BusinessError(f"未知错误: {str(e)}", code=500)
|
||||
12
main.py
12
main.py
@ -2,7 +2,8 @@ from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from app.core.config import settings
|
||||
from app.api.v1.api import api_router
|
||||
from app.db.init_db import init_db
|
||||
from app.core.middleware import add_response_middleware
|
||||
from app.core.exceptions import add_exception_handlers
|
||||
|
||||
app = FastAPI(
|
||||
title=settings.PROJECT_NAME,
|
||||
@ -19,6 +20,12 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 添加响应中间件
|
||||
add_response_middleware(app)
|
||||
|
||||
# 添加异常处理器
|
||||
add_exception_handlers(app)
|
||||
|
||||
# 包含API路由
|
||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||
|
||||
@ -33,6 +40,9 @@ async def health_check():
|
||||
# 应用启动事件
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
# 延迟导入,避免循环导入问题
|
||||
from app.db.init_db import init_db
|
||||
# 调用异步初始化函数
|
||||
await init_db()
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Loading…
Reference in New Issue
Block a user