update
This commit is contained in:
parent
41e61b364d
commit
f3b05805bb
@ -1 +0,0 @@
|
|||||||
# API相关初始化文件
|
|
||||||
@ -1 +0,0 @@
|
|||||||
# API v1 版本初始化文件
|
|
||||||
@ -7,6 +7,7 @@ from app.services import wechat as wechat_service
|
|||||||
from app.services import user as user_service
|
from app.services import user as user_service
|
||||||
from app.core.security import create_access_token
|
from app.core.security import create_access_token
|
||||||
from app.db.database import get_db
|
from app.db.database import get_db
|
||||||
|
from app.core.exceptions import BusinessError
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -22,36 +23,29 @@ async def wechat_login(
|
|||||||
- 如果用户不存在,则创建新用户
|
- 如果用户不存在,则创建新用户
|
||||||
- 生成JWT令牌
|
- 生成JWT令牌
|
||||||
"""
|
"""
|
||||||
try:
|
# 调用微信API获取openid和unionid
|
||||||
# 调用微信API获取openid和unionid
|
openid, unionid = await wechat_service.code2session(login_data.code)
|
||||||
openid, unionid = await wechat_service.code2session(login_data.code)
|
|
||||||
|
# 检查用户是否存在
|
||||||
# 检查用户是否存在
|
existing_user = await user_service.get_user_by_openid(db, openid=openid)
|
||||||
existing_user = await user_service.get_user_by_openid(db, openid=openid)
|
is_new_user = existing_user is None
|
||||||
is_new_user = existing_user is None
|
|
||||||
|
if is_new_user:
|
||||||
if is_new_user:
|
# 创建新用户
|
||||||
# 创建新用户
|
user_create = UserCreate(
|
||||||
user_create = UserCreate(
|
openid=openid,
|
||||||
openid=openid,
|
unionid=unionid
|
||||||
unionid=unionid
|
|
||||||
)
|
|
||||||
user = await user_service.create_user(db, user=user_create)
|
|
||||||
else:
|
|
||||||
user = existing_user
|
|
||||||
|
|
||||||
# 创建访问令牌 - 使用openid作为标识
|
|
||||||
access_token = create_access_token(subject=openid)
|
|
||||||
|
|
||||||
# 返回登录响应
|
|
||||||
return LoginResponse(
|
|
||||||
access_token=access_token,
|
|
||||||
is_new_user=is_new_user,
|
|
||||||
openid=openid
|
|
||||||
)
|
)
|
||||||
|
user = await user_service.create_user(db, user=user_create)
|
||||||
|
else:
|
||||||
|
user = existing_user
|
||||||
|
|
||||||
except wechat_service.WechatLoginError as e:
|
# 创建访问令牌 - 使用openid作为标识
|
||||||
raise HTTPException(
|
access_token = create_access_token(subject=openid)
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=str(e)
|
# 返回登录响应
|
||||||
)
|
return LoginResponse(
|
||||||
|
access_token=access_token,
|
||||||
|
is_new_user=is_new_user,
|
||||||
|
openid=openid
|
||||||
|
)
|
||||||
@ -7,6 +7,7 @@ from app.services import user as user_service
|
|||||||
from app.db.database import get_db
|
from app.db.database import get_db
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.models.users import User as UserModel
|
from app.models.users import User as UserModel
|
||||||
|
from app.core.exceptions import BusinessError
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@ -33,4 +34,17 @@ async def update_user_me(
|
|||||||
需要JWT令牌认证
|
需要JWT令牌认证
|
||||||
"""
|
"""
|
||||||
user = await user_service.update_user(db, user_id=current_user.id, user_update=user_update)
|
user = await user_service.update_user(db, user_id=current_user.id, user_update=user_update)
|
||||||
return user
|
if user is None:
|
||||||
|
raise BusinessError("用户更新失败", code=500)
|
||||||
|
return user
|
||||||
|
|
||||||
|
@router.get("/{user_id}", response_model=User, tags=["users"])
|
||||||
|
async def read_user(
|
||||||
|
user_id: int,
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""获取指定用户信息"""
|
||||||
|
db_user = await user_service.get_user(db, user_id=user_id)
|
||||||
|
if db_user is None:
|
||||||
|
raise BusinessError("用户不存在", code=404)
|
||||||
|
return db_user
|
||||||
46
app/core/exceptions.py
Normal file
46
app/core/exceptions.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
from fastapi import HTTPException, Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from app.schemas.response import ErrorResponse
|
||||||
|
|
||||||
|
class BusinessError(Exception):
|
||||||
|
"""业务错误异常,使用标准响应格式"""
|
||||||
|
def __init__(self, message: str, code: int = 500):
|
||||||
|
self.message = message
|
||||||
|
self.code = code
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
# 业务异常处理器
|
||||||
|
async def business_exception_handler(request: Request, exc: BusinessError):
|
||||||
|
"""将业务异常转换为标准响应格式"""
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=200, # 返回200状态码,但在响应内容中设置业务错误码
|
||||||
|
content=ErrorResponse(
|
||||||
|
code=exc.code,
|
||||||
|
message=exc.message
|
||||||
|
).model_dump()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 请求验证错误处理器
|
||||||
|
async def validation_exception_handler(request: Request, exc: RequestValidationError):
|
||||||
|
"""将请求验证错误转换为标准错误响应"""
|
||||||
|
error_messages = []
|
||||||
|
for error in exc.errors():
|
||||||
|
loc = error.get("loc", [])
|
||||||
|
loc_str = " -> ".join(str(l) for l in loc)
|
||||||
|
error_messages.append(f"{loc_str}: {error.get('msg')}")
|
||||||
|
|
||||||
|
error_message = ", ".join(error_messages)
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=422, # 保持422状态码,表示验证错误
|
||||||
|
content=ErrorResponse(
|
||||||
|
code=422,
|
||||||
|
message=f"请求参数验证错误: {error_message}"
|
||||||
|
).model_dump()
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_exception_handlers(app):
|
||||||
|
"""添加异常处理器到FastAPI应用"""
|
||||||
|
app.add_exception_handler(BusinessError, business_exception_handler)
|
||||||
|
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||||
68
app/core/middleware.py
Normal file
68
app/core/middleware.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
from typing import Any, Callable, Dict, Optional
|
||||||
|
from fastapi import FastAPI, Request, Response
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
import json
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.types import ASGIApp
|
||||||
|
from app.schemas.response import StandardResponse, ErrorResponse
|
||||||
|
|
||||||
|
class ResponseMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""
|
||||||
|
中间件:统一处理API响应格式
|
||||||
|
|
||||||
|
请求正确:{code:200, data:Any}
|
||||||
|
业务错误:{code:500, message:""}
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||||
|
# 不需要处理的路径
|
||||||
|
exclude_paths = ["/docs", "/redoc", "/openapi.json"]
|
||||||
|
if any(request.url.path.startswith(path) for path in exclude_paths):
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
response = await call_next(request)
|
||||||
|
|
||||||
|
# 如果是HTTPException,直接返回,不进行封装
|
||||||
|
if response.status_code >= 400:
|
||||||
|
return response
|
||||||
|
|
||||||
|
# 正常响应进行封装
|
||||||
|
if response.status_code < 300 and response.headers.get("content-type") == "application/json":
|
||||||
|
response_body = [chunk async for chunk in response.body_iterator]
|
||||||
|
response_body = b"".join(response_body)
|
||||||
|
|
||||||
|
if response_body:
|
||||||
|
try:
|
||||||
|
data = json.loads(response_body)
|
||||||
|
|
||||||
|
# 已经是统一格式的响应,不再封装
|
||||||
|
if isinstance(data, dict) and "code" in data and ("data" in data or "message" in data):
|
||||||
|
return Response(
|
||||||
|
content=response_body,
|
||||||
|
status_code=response.status_code,
|
||||||
|
headers=dict(response.headers),
|
||||||
|
media_type=response.media_type
|
||||||
|
)
|
||||||
|
|
||||||
|
# 封装为标准响应格式
|
||||||
|
result = StandardResponse(code=200, data=data).model_dump()
|
||||||
|
return JSONResponse(
|
||||||
|
content=result,
|
||||||
|
status_code=response.status_code,
|
||||||
|
headers=dict(response.headers)
|
||||||
|
)
|
||||||
|
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# 非JSON响应,直接返回
|
||||||
|
return Response(
|
||||||
|
content=response_body,
|
||||||
|
status_code=response.status_code,
|
||||||
|
headers=dict(response.headers),
|
||||||
|
media_type=response.media_type
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
def add_response_middleware(app: FastAPI) -> None:
|
||||||
|
"""添加响应处理中间件到FastAPI应用"""
|
||||||
|
app.add_middleware(ResponseMiddleware)
|
||||||
@ -13,9 +13,9 @@ def create_access_token(
|
|||||||
subject: Union[str, Any], expires_delta: Optional[timedelta] = None
|
subject: Union[str, Any], expires_delta: Optional[timedelta] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
if expires_delta:
|
if expires_delta:
|
||||||
expire = datetime.now() + expires_delta
|
expire = datetime.utcnow() + expires_delta
|
||||||
else:
|
else:
|
||||||
expire = datetime.now() + timedelta(
|
expire = datetime.utcnow() + timedelta(
|
||||||
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES
|
||||||
)
|
)
|
||||||
to_encode = {"exp": expire, "sub": str(subject)}
|
to_encode = {"exp": expire, "sub": str(subject)}
|
||||||
|
|||||||
@ -1 +0,0 @@
|
|||||||
# 数据库相关初始化文件
|
|
||||||
@ -1,2 +0,0 @@
|
|||||||
# 导入所有模型,确保它们被SQLAlchemy注册
|
|
||||||
from app.models.users import User
|
|
||||||
@ -1 +0,0 @@
|
|||||||
# 数据模式初始化文件
|
|
||||||
35
app/schemas/response.py
Normal file
35
app/schemas/response.py
Normal file
@ -0,0 +1,35 @@
|
|||||||
|
from typing import Any, Generic, Optional, TypeVar
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
# from pydantic.generics import GenericModel
|
||||||
|
|
||||||
|
T = TypeVar('T')
|
||||||
|
|
||||||
|
class StandardResponse(BaseModel, Generic[T]):
|
||||||
|
"""标准API响应格式"""
|
||||||
|
code: int = 200
|
||||||
|
message: Optional[str] = None
|
||||||
|
data: Optional[T] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"code": 200,
|
||||||
|
"message": None,
|
||||||
|
"data": {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class ErrorResponse(BaseModel):
|
||||||
|
"""错误响应格式"""
|
||||||
|
code: int = 500
|
||||||
|
message: str
|
||||||
|
data: None = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"example": {
|
||||||
|
"code": 500,
|
||||||
|
"message": "业务处理错误",
|
||||||
|
"data": None
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1 +0,0 @@
|
|||||||
# 服务层初始化文件
|
|
||||||
@ -1,6 +1,7 @@
|
|||||||
import httpx
|
import httpx
|
||||||
from typing import Optional, Dict, Any, Tuple
|
from typing import Optional, Dict, Any, Tuple
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
from app.core.exceptions import BusinessError
|
||||||
|
|
||||||
class WechatLoginError(Exception):
|
class WechatLoginError(Exception):
|
||||||
"""微信登录错误"""
|
"""微信登录错误"""
|
||||||
@ -17,7 +18,7 @@ async def code2session(code: str) -> Tuple[str, Optional[str]]:
|
|||||||
Tuple[str, Optional[str]]: (openid, unionid)
|
Tuple[str, Optional[str]]: (openid, unionid)
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
WechatLoginError: 当微信API调用失败时
|
BusinessError: 当微信API调用失败时
|
||||||
"""
|
"""
|
||||||
url = "https://api.weixin.qq.com/sns/jscode2session"
|
url = "https://api.weixin.qq.com/sns/jscode2session"
|
||||||
params = {
|
params = {
|
||||||
@ -33,17 +34,19 @@ async def code2session(code: str) -> Tuple[str, Optional[str]]:
|
|||||||
result = response.json()
|
result = response.json()
|
||||||
|
|
||||||
if "errcode" in result and result["errcode"] != 0:
|
if "errcode" in result and result["errcode"] != 0:
|
||||||
raise WechatLoginError(f"微信登录失败: {result.get('errmsg', '未知错误')}")
|
raise BusinessError(f"微信登录失败: {result.get('errmsg', '未知错误')}", code=500)
|
||||||
|
|
||||||
openid = result.get("openid")
|
openid = result.get("openid")
|
||||||
unionid = result.get("unionid") # 可能为None
|
unionid = result.get("unionid") # 可能为None
|
||||||
|
|
||||||
if not openid:
|
if not openid:
|
||||||
raise WechatLoginError("无法获取openid")
|
raise BusinessError("无法获取openid", code=500)
|
||||||
|
|
||||||
return openid, unionid
|
return openid, unionid
|
||||||
|
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
raise WechatLoginError(f"网络请求失败: {str(e)}")
|
raise BusinessError(f"网络请求失败: {str(e)}", code=500)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise WechatLoginError(f"未知错误: {str(e)}")
|
if isinstance(e, BusinessError):
|
||||||
|
raise
|
||||||
|
raise BusinessError(f"未知错误: {str(e)}", code=500)
|
||||||
12
main.py
12
main.py
@ -2,7 +2,8 @@ from fastapi import FastAPI
|
|||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.api.v1.api import api_router
|
from app.api.v1.api import api_router
|
||||||
from app.db.init_db import init_db
|
from app.core.middleware import add_response_middleware
|
||||||
|
from app.core.exceptions import add_exception_handlers
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title=settings.PROJECT_NAME,
|
title=settings.PROJECT_NAME,
|
||||||
@ -19,6 +20,12 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 添加响应中间件
|
||||||
|
add_response_middleware(app)
|
||||||
|
|
||||||
|
# 添加异常处理器
|
||||||
|
add_exception_handlers(app)
|
||||||
|
|
||||||
# 包含API路由
|
# 包含API路由
|
||||||
app.include_router(api_router, prefix=settings.API_V1_STR)
|
app.include_router(api_router, prefix=settings.API_V1_STR)
|
||||||
|
|
||||||
@ -33,6 +40,9 @@ async def health_check():
|
|||||||
# 应用启动事件
|
# 应用启动事件
|
||||||
@app.on_event("startup")
|
@app.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
|
# 延迟导入,避免循环导入问题
|
||||||
|
from app.db.init_db import init_db
|
||||||
|
# 调用异步初始化函数
|
||||||
await init_db()
|
await init_db()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user