This commit is contained in:
aaron 2025-04-09 16:18:52 +08:00
parent 16c8c11172
commit 10866b1cbb
11 changed files with 285 additions and 146 deletions

View File

@ -9,13 +9,14 @@ from app.services import user as user_service
from app.core.security import create_access_token
from app.db.database import get_db
from app.core.exceptions import BusinessError
from app.schemas.response import StandardResponse
# 创建日志记录器
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/login/wechat", response_model=LoginResponse, tags=["auth"])
@router.post("/login/wechat", tags=["auth"])
async def wechat_login(
login_data: WechatLogin,
db: AsyncSession = Depends(get_db)
@ -62,7 +63,7 @@ async def wechat_login(
)
# 返回标准响应
return response_data
return StandardResponse(code=200, data=response_data)
except Exception as e:
# 记录异常

View File

@ -1,4 +1,4 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi import APIRouter, Depends, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List
@ -10,11 +10,17 @@ from app.services import clothing as clothing_service
from app.db.database import get_db
from app.api.deps import get_current_user
from app.models.users import User as UserModel
from app.core.exceptions import BusinessError
from app.schemas.response import StandardResponse
import logging
# 创建日志记录器
logger = logging.getLogger(__name__)
router = APIRouter()
# 衣服分类API
@router.post("/categories", response_model=ClothingCategory, tags=["clothing-categories"])
@router.post("/categories", tags=["clothing-categories"])
async def create_category(
category: ClothingCategoryCreate,
db: AsyncSession = Depends(get_db),
@ -25,18 +31,24 @@ async def create_category(
需要JWT令牌认证
"""
return await clothing_service.create_category(db=db, category=category)
result = await clothing_service.create_category(db=db, category=category)
logger.info(f"创建分类成功: id={result.id}, name={result.name}")
# 手动返回标准响应格式
return StandardResponse(code=200, data=result)
@router.get("/categories", response_model=List[ClothingCategory], tags=["clothing-categories"])
@router.get("/categories", tags=["clothing-categories"])
async def read_categories(
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=100),
db: AsyncSession = Depends(get_db)
):
"""获取所有衣服分类"""
return await clothing_service.get_categories(db=db, skip=skip, limit=limit)
result = await clothing_service.get_categories(db=db, skip=skip, limit=limit)
logger.info(f"获取分类列表: 共{len(result)}")
# 手动返回标准响应格式
return StandardResponse(code=200, data=result)
@router.get("/categories/{category_id}", response_model=ClothingCategory, tags=["clothing-categories"])
@router.get("/categories/{category_id}", tags=["clothing-categories"])
async def read_category(
category_id: int,
db: AsyncSession = Depends(get_db)
@ -44,10 +56,11 @@ async def read_category(
"""获取单个衣服分类"""
db_category = await clothing_service.get_category(db, category_id=category_id)
if db_category is None:
raise HTTPException(status_code=404, detail="分类不存在")
return db_category
raise BusinessError("分类不存在", code=404)
# 手动返回标准响应格式
return StandardResponse(code=200, data=db_category)
@router.put("/categories/{category_id}", response_model=ClothingCategory, tags=["clothing-categories"])
@router.put("/categories/{category_id}", tags=["clothing-categories"])
async def update_category(
category_id: int,
category: ClothingCategoryUpdate,
@ -65,10 +78,11 @@ async def update_category(
category_update=category
)
if db_category is None:
raise HTTPException(status_code=404, detail="分类不存在")
return db_category
raise BusinessError("分类不存在", code=404)
# 手动返回标准响应格式
return StandardResponse(code=200, data=db_category)
@router.delete("/categories/{category_id}", response_model=ClothingCategory, tags=["clothing-categories"])
@router.delete("/categories/{category_id}", tags=["clothing-categories"])
async def delete_category(
category_id: int,
db: AsyncSession = Depends(get_db),
@ -81,11 +95,12 @@ async def delete_category(
"""
db_category = await clothing_service.delete_category(db=db, category_id=category_id)
if db_category is None:
raise HTTPException(status_code=404, detail="分类不存在")
return db_category
raise BusinessError("分类不存在", code=404)
# 手动返回标准响应格式
return StandardResponse(code=200, data=db_category)
# 衣服API
@router.post("/", response_model=Clothing, tags=["clothing"])
@router.post("/", tags=["clothing"])
async def create_clothing(
clothing: ClothingCreate,
db: AsyncSession = Depends(get_db),
@ -99,20 +114,25 @@ async def create_clothing(
# 检查分类是否存在
category = await clothing_service.get_category(db, category_id=clothing.clothing_category_id)
if category is None:
raise HTTPException(status_code=404, detail="指定的分类不存在")
return await clothing_service.create_clothing(db=db, clothing=clothing)
raise BusinessError("指定的分类不存在", code=404)
result = await clothing_service.create_clothing(db=db, clothing=clothing)
logger.info(f"创建衣服成功: id={result.id}")
# 手动返回标准响应格式
return StandardResponse(code=200, data=result)
@router.get("/", response_model=List[Clothing], tags=["clothing"])
@router.get("/", tags=["clothing"])
async def read_clothes(
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=100),
db: AsyncSession = Depends(get_db)
):
"""获取所有衣服"""
return await clothing_service.get_clothes(db=db, skip=skip, limit=limit)
result = await clothing_service.get_clothes(db=db, skip=skip, limit=limit)
# 手动返回标准响应格式
return StandardResponse(code=200, data=result)
@router.get("/by-category/{category_id}", response_model=List[Clothing], tags=["clothing"])
@router.get("/by-category/{category_id}", tags=["clothing"])
async def read_clothes_by_category(
category_id: int,
skip: int = Query(0, ge=0),
@ -123,16 +143,18 @@ async def read_clothes_by_category(
# 检查分类是否存在
category = await clothing_service.get_category(db, category_id=category_id)
if category is None:
raise HTTPException(status_code=404, detail="分类不存在")
raise BusinessError("分类不存在", code=404)
return await clothing_service.get_clothes_by_category(
result = await clothing_service.get_clothes_by_category(
db=db,
category_id=category_id,
skip=skip,
limit=limit
)
# 手动返回标准响应格式
return StandardResponse(code=200, data=result)
@router.get("/{clothing_id}", response_model=Clothing, tags=["clothing"])
@router.get("/{clothing_id}", tags=["clothing"])
async def read_clothing(
clothing_id: int,
db: AsyncSession = Depends(get_db)
@ -140,10 +162,11 @@ async def read_clothing(
"""获取单个衣服"""
db_clothing = await clothing_service.get_clothing(db, clothing_id=clothing_id)
if db_clothing is None:
raise HTTPException(status_code=404, detail="衣服不存在")
return db_clothing
raise BusinessError("衣服不存在", code=404)
# 手动返回标准响应格式
return StandardResponse(code=200, data=db_clothing)
@router.put("/{clothing_id}", response_model=Clothing, tags=["clothing"])
@router.put("/{clothing_id}", tags=["clothing"])
async def update_clothing(
clothing_id: int,
clothing: ClothingUpdate,
@ -159,7 +182,7 @@ async def update_clothing(
if clothing.clothing_category_id is not None:
category = await clothing_service.get_category(db, category_id=clothing.clothing_category_id)
if category is None:
raise HTTPException(status_code=404, detail="指定的分类不存在")
raise BusinessError("指定的分类不存在", code=404)
db_clothing = await clothing_service.update_clothing(
db=db,
@ -167,10 +190,11 @@ async def update_clothing(
clothing_update=clothing
)
if db_clothing is None:
raise HTTPException(status_code=404, detail="衣服不存在")
return db_clothing
raise BusinessError("衣服不存在", code=404)
# 手动返回标准响应格式
return StandardResponse(code=200, data=db_clothing)
@router.delete("/{clothing_id}", response_model=Clothing, tags=["clothing"])
@router.delete("/{clothing_id}", tags=["clothing"])
async def delete_clothing(
clothing_id: int,
db: AsyncSession = Depends(get_db),
@ -183,5 +207,6 @@ async def delete_clothing(
"""
db_clothing = await clothing_service.delete_clothing(db=db, clothing_id=clothing_id)
if db_clothing is None:
raise HTTPException(status_code=404, detail="衣服不存在")
return db_clothing
raise BusinessError("衣服不存在", code=404)
# 手动返回标准响应格式
return StandardResponse(code=200, data=db_clothing)

View File

@ -1,11 +1,17 @@
from fastapi import APIRouter, HTTPException
from app.schemas.response import StandardResponse
import logging
# 创建日志记录器
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get("/")
async def get_api_info():
return {
api_info = {
"name": "美搭Meida API",
"version": "v1",
"status": "active"
}
}
return StandardResponse(code=200, data=api_info)

View File

@ -2,17 +2,22 @@ from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List
from app.schemas.user_image import UserImage, UserImageCreate, UserImageUpdate
from app.schemas.user_image import UserImage, UserImageCreate, UserImageUpdate, UserImageWithUser
from app.services import user_image as user_image_service
from app.db.database import get_db
from app.api.deps import get_current_user
from app.models.users import User as UserModel
from app.core.exceptions import BusinessError
from app.schemas.response import StandardResponse
import logging
# 创建日志记录器
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/", response_model=UserImage, tags=["user-images"])
async def create_image(
@router.post("/", tags=["user-images"])
async def create_user_image(
image: UserImageCreate,
current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
@ -22,14 +27,14 @@ async def create_image(
需要JWT令牌认证
"""
return await user_image_service.create_user_image(
db=db,
user_id=current_user.id,
image=image
)
# 设置用户ID
image.user_id = current_user.id
result = await user_image_service.create_user_image(db=db, image=image)
return StandardResponse(code=200, data=result)
@router.get("/", response_model=List[UserImage], tags=["user-images"])
async def read_images(
@router.get("/", tags=["user-images"])
async def read_user_images(
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=100),
current_user: UserModel = Depends(get_current_user),
@ -40,14 +45,15 @@ async def read_images(
需要JWT令牌认证
"""
return await user_image_service.get_user_images(
db=db,
images = await user_image_service.get_user_images_by_user(
db,
user_id=current_user.id,
skip=skip,
limit=limit
)
return StandardResponse(code=200, data=images)
@router.get("/{image_id}", response_model=UserImage, tags=["user-images"])
@router.get("/{image_id}", tags=["user-images"])
async def read_image(
image_id: int,
current_user: UserModel = Depends(get_current_user),
@ -60,15 +66,15 @@ async def read_image(
"""
image = await user_image_service.get_user_image(db, image_id=image_id)
if image is None:
raise HTTPException(status_code=404, detail="图片不存在")
raise BusinessError("图片不存在", code=404)
if image.user_id != current_user.id:
raise HTTPException(status_code=403, detail="没有权限访问此图片")
return image
raise BusinessError("没有权限访问此图片", code=403)
return StandardResponse(code=200, data=image)
@router.put("/{image_id}", response_model=UserImage, tags=["user-images"])
@router.put("/{image_id}", tags=["user-images"])
async def update_image(
image_id: int,
image_update: UserImageUpdate,
image: UserImageUpdate,
current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
@ -77,22 +83,21 @@ async def update_image(
需要JWT令牌认证
"""
# 检查图片是否存在
# 检查图片是否存在且属于当前用户
db_image = await user_image_service.get_user_image(db, image_id=image_id)
if db_image is None:
raise HTTPException(status_code=404, detail="图片不存在")
raise BusinessError("图片不存在", code=404)
if db_image.user_id != current_user.id:
raise HTTPException(status_code=403, detail="没有权限更新此图片")
# 执行更新
raise BusinessError("没有权限更新此图片", code=403)
updated_image = await user_image_service.update_user_image(
db=db,
image_id=image_id,
image_update=image_update
image=image
)
return updated_image
return StandardResponse(code=200, data=updated_image)
@router.delete("/{image_id}", response_model=UserImage, tags=["user-images"])
@router.delete("/{image_id}", tags=["user-images"])
async def delete_image(
image_id: int,
current_user: UserModel = Depends(get_current_user),
@ -103,12 +108,12 @@ async def delete_image(
需要JWT令牌认证
"""
# 检查图片是否存在
# 检查图片是否存在且属于当前用户
db_image = await user_image_service.get_user_image(db, image_id=image_id)
if db_image is None:
raise HTTPException(status_code=404, detail="图片不存在")
raise BusinessError("图片不存在", code=404)
if db_image.user_id != current_user.id:
raise HTTPException(status_code=403, detail="没有权限删除此图片")
# 执行删除
return await user_image_service.delete_user_image(db=db, image_id=image_id)
raise BusinessError("没有权限删除此图片", code=403)
deleted_image = await user_image_service.delete_user_image(db=db, image_id=image_id)
return StandardResponse(code=200, data=deleted_image)

View File

@ -8,10 +8,15 @@ from app.db.database import get_db
from app.api.deps import get_current_user
from app.models.users import User as UserModel
from app.core.exceptions import BusinessError
from app.schemas.response import StandardResponse
import logging
# 创建日志记录器
logger = logging.getLogger(__name__)
router = APIRouter()
@router.get("/me", response_model=User, tags=["users"])
@router.get("/me", tags=["users"])
async def read_user_me(
current_user: UserModel = Depends(get_current_user)
):
@ -20,9 +25,9 @@ async def read_user_me(
需要JWT令牌认证
"""
return current_user
return StandardResponse(code=200, data=current_user)
@router.put("/me", response_model=User, tags=["users"])
@router.put("/me", tags=["users"])
async def update_user_me(
user_update: UserUpdate,
db: AsyncSession = Depends(get_db),
@ -33,12 +38,12 @@ async def update_user_me(
需要JWT令牌认证
"""
user = await user_service.update_user(db, user_id=current_user.id, user_update=user_update)
if user is None:
result = await user_service.update_user(db, user_id=current_user.id, user_update=user_update)
if result is None:
raise BusinessError("用户更新失败", code=500)
return user
return StandardResponse(code=200, data=result)
@router.get("/{user_id}", response_model=User, tags=["users"])
@router.get("/{user_id}", tags=["users"])
async def read_user(
user_id: int,
db: AsyncSession = Depends(get_db)
@ -47,4 +52,4 @@ async def read_user(
db_user = await user_service.get_user(db, user_id=user_id)
if db_user is None:
raise BusinessError("用户不存在", code=404)
return db_user
return StandardResponse(code=200, data=db_user)

View File

@ -1,8 +1,13 @@
from fastapi import HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
import logging
from app.schemas.response import ErrorResponse
# 创建日志记录器
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
class BusinessError(Exception):
"""业务错误异常,使用标准响应格式"""
def __init__(self, message: str, code: int = 500):
@ -13,12 +18,16 @@ class BusinessError(Exception):
# 业务异常处理器
async def business_exception_handler(request: Request, exc: BusinessError):
"""将业务异常转换为标准响应格式"""
logger.debug(f"处理业务异常: code={exc.code}, message={exc.message}")
error_response = ErrorResponse(
code=exc.code,
message=exc.message
)
return JSONResponse(
status_code=200, # 返回200状态码但在响应内容中设置业务错误码
content=ErrorResponse(
code=exc.code,
message=exc.message
).model_dump()
status_code=200, # 始终返回200状态码业务错误码在响应内容中
content=error_response.model_dump()
)
# 请求验证错误处理器
@ -31,16 +40,35 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
error_messages.append(f"{loc_str}: {error.get('msg')}")
error_message = ", ".join(error_messages)
logger.debug(f"处理请求验证错误: {error_message}")
error_response = ErrorResponse(
code=422,
message=f"请求参数验证错误: {error_message}"
)
return JSONResponse(
status_code=422, # 保持422状态码表示验证错误
content=ErrorResponse(
code=422,
message=f"请求参数验证错误: {error_message}"
).model_dump()
status_code=200, # 与业务异常一致返回200状态码
content=error_response.model_dump()
)
# HTTP异常处理器
async def http_exception_handler(request: Request, exc: HTTPException):
"""将HTTP异常转换为标准错误响应"""
logger.debug(f"处理HTTP异常: status_code={exc.status_code}, detail={exc.detail}")
error_response = ErrorResponse(
code=exc.status_code,
message=str(exc.detail)
)
return JSONResponse(
status_code=200, # 与业务异常一致返回200状态码
content=error_response.model_dump()
)
def add_exception_handlers(app):
"""添加异常处理器到FastAPI应用"""
app.add_exception_handler(BusinessError, business_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
app.add_exception_handler(HTTPException, http_exception_handler)

View File

@ -2,10 +2,26 @@ from typing import Any, Callable, Dict, Optional
from fastapi import FastAPI, Request, Response
from fastapi.responses import JSONResponse
import json
import logging
import sys
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from app.schemas.response import StandardResponse, ErrorResponse
# 创建日志记录器并确保它正确配置
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
# 添加控制台处理器确保日志显示
if not logger.handlers:
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.propagate = False # 避免重复日志
class ResponseMiddleware(BaseHTTPMiddleware):
"""
中间件统一处理API响应格式
@ -14,48 +30,80 @@ class ResponseMiddleware(BaseHTTPMiddleware):
业务错误{code:500, message:""}
"""
def __init__(self, app: ASGIApp):
logger.warning("======== 初始化响应中间件 ========")
super().__init__(app)
async def dispatch(self, request: Request, call_next: Callable) -> Response:
logger.warning(f"===== 中间件处理请求: {request.method} {request.url.path} =====")
# 不需要处理的路径
exclude_paths = ["/docs", "/redoc", "/openapi.json"]
if any(request.url.path.startswith(path) for path in exclude_paths):
logger.warning(f"跳过处理: {request.url.path}")
return await call_next(request)
# 获取原始响应
response = await call_next(request)
# 如果不是200系列响应或不是JSON响应直接返回
if response.status_code >= 300 or response.headers.get("content-type") != "application/json":
return response
# 读取响应内容
try:
# 使用JSONResponse的_render方法获取内容
if isinstance(response, JSONResponse):
# 获取原始数据
raw_data = response.body.decode("utf-8")
data = json.loads(raw_data)
# 已经是标准格式,不再封装
if isinstance(data, dict) and "code" in data and ("data" in data or "message" in data):
return response
# 创建新的标准响应
std_response = StandardResponse(code=200, data=data)
# 创建新的JSONResponse
return JSONResponse(
content=std_response.model_dump(),
status_code=response.status_code,
headers=dict(response.headers)
)
# 处理流式响应或其他类型响应
else:
response = await call_next(request)
logger.warning(f"请求处理完成: {request.url.path}, 状态码: {response.status_code}")
# 检查内容类型
content_type = response.headers.get("content-type", "")
logger.warning(f"响应内容类型: {content_type}")
# 如果不是JSON响应直接返回
if "application/json" not in content_type:
logger.warning(f"非JSON响应跳过处理: {content_type}")
return response
# 读取响应内容
try:
# 使用JSONResponse的方法获取内容
if isinstance(response, JSONResponse):
# 获取原始数据
raw_data = response.body.decode("utf-8")
data = json.loads(raw_data)
logger.warning(f"原始响应数据: {data}")
# 已经是标准格式,不再封装
if isinstance(data, dict) and "code" in data and ("data" in data or "message" in data):
logger.warning("响应已经是标准格式,不再包装")
return response
# 创建新的标准响应
std_response = StandardResponse(code=200, data=data)
logger.warning(f"包装为标准响应: code=200, data类型={type(data).__name__}")
# 创建新的JSONResponse
wrapped_response = JSONResponse(
content=std_response.model_dump(),
status_code=200, # 始终返回200状态码错误码在响应内容中
headers=dict(response.headers)
)
logger.warning("返回包装后的响应")
return wrapped_response
# 处理流式响应或其他类型响应
else:
logger.warning(f"非JSONResponse类型响应: {type(response)}")
return response
except Exception as e:
# 出现异常,记录日志
logger.error(f"处理响应内容时发生异常: {str(e)}", exc_info=True)
# 返回原始响应
return response
except Exception as e:
# 出现异常,返回原始响应
return response
logger.error(f"中间件处理请求时发生异常: {str(e)}", exc_info=True)
raise
# 使用函数工厂模式创建中间件,确保每次创建新实例
def create_response_middleware():
return ResponseMiddleware
def add_response_middleware(app: FastAPI) -> None:
"""添加响应处理中间件到FastAPI应用"""
app.add_middleware(ResponseMiddleware)
logger.warning("添加响应处理中间件到FastAPI应用")
# 使用函数工厂以确保每次获取新的中间件实例
app.add_middleware(create_response_middleware())

View File

@ -4,12 +4,13 @@ from contextlib import asynccontextmanager
import logging
from app.core.config import settings
from app.api.v1.api import api_router
from app.core.middleware import add_response_middleware
from app.core.exceptions import add_exception_handlers
from app.core.middleware import add_response_middleware
from app.schemas.response import StandardResponse
# 配置日志
logging.basicConfig(
level=logging.INFO,
level=logging.DEBUG, # 设置为DEBUG级别以显示更多日志
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
@ -32,6 +33,8 @@ app = FastAPI(
lifespan=lifespan
)
logger.info("开始配置应用中间件和路由...")
# 配置CORS
app.add_middleware(
CORSMiddleware,
@ -40,20 +43,25 @@ app.add_middleware(
allow_methods=["*"],
allow_headers=["*"],
)
# 添加响应中间件
add_response_middleware(app)
logger.info("CORS中间件已添加")
# 添加异常处理器
add_exception_handlers(app)
logger.info("异常处理器已添加")
# add_response_middleware(app)
# logger.info("响应中间件已添加")
# 包含API路由
app.include_router(api_router, prefix=settings.API_V1_STR)
logger.info(f"API路由已添加前缀: {settings.API_V1_STR}")
@app.get("/")
async def root():
return {"message": "欢迎使用美搭Meida API服务"}
logger.info("访问根路径")
return StandardResponse(code=200, data={"message": "欢迎使用美搭Meida API服务"})
@app.get("/health")
async def health_check():
return {"status": "healthy"}
logger.debug("健康检查")
return StandardResponse(code=200, data={"status": "healthy"})

View File

@ -6,11 +6,12 @@ T = TypeVar('T')
class StandardResponse(BaseModel, Generic[T]):
"""标准API响应格式"""
code: int = 200
message: Optional[str] = None
data: Optional[T] = None
code: int = Field(200, description="状态码200表示成功其他值表示错误")
message: Optional[str] = Field(None, description="消息,通常在错误时提供")
data: Optional[T] = Field(None, description="响应数据")
class Config:
from_attributes = True
json_schema_extra = {
"example": {
"code": 200,
@ -21,11 +22,12 @@ class StandardResponse(BaseModel, Generic[T]):
class ErrorResponse(BaseModel):
"""错误响应格式"""
code: int = 500
message: str
data: None = None
code: int = Field(500, description="错误码非200值")
message: str = Field(..., description="错误消息")
data: None = Field(None, description="数据字段错误时为null")
class Config:
from_attributes = True
json_schema_extra = {
"example": {
"code": 500,

View File

@ -1,6 +1,7 @@
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime
from app.schemas.user import User
class UserImageBase(BaseModel):
"""用户形象基础模式"""
@ -9,7 +10,7 @@ class UserImageBase(BaseModel):
class UserImageCreate(UserImageBase):
"""创建用户形象请求"""
pass
user_id: Optional[int] = Field(None, description="用户ID可选通常由系统设置")
class UserImageInDB(UserImageBase):
"""数据库中的用户形象数据"""
@ -27,4 +28,11 @@ class UserImage(UserImageInDB):
class UserImageUpdate(BaseModel):
"""更新用户形象请求"""
image_url: Optional[str] = Field(None, description="图片URL")
is_default: Optional[bool] = Field(None, description="是否为默认形象")
is_default: Optional[bool] = Field(None, description="是否为默认形象")
class UserImageWithUser(UserImage):
"""包含用户信息的用户形象响应"""
user: Optional[User] = Field(None, description="用户信息")
class Config:
from_attributes = True

View File

@ -10,7 +10,7 @@ async def get_user_image(db: AsyncSession, image_id: int):
result = await db.execute(select(UserImage).filter(UserImage.id == image_id))
return result.scalars().first()
async def get_user_images(db: AsyncSession, user_id: int, skip: int = 0, limit: int = 100):
async def get_user_images_by_user(db: AsyncSession, user_id: int, skip: int = 0, limit: int = 100):
"""获取用户的所有形象图片"""
result = await db.execute(
select(UserImage)
@ -21,14 +21,17 @@ async def get_user_images(db: AsyncSession, user_id: int, skip: int = 0, limit:
)
return result.scalars().all()
async def create_user_image(db: AsyncSession, user_id: int, image: UserImageCreate):
"""创建用户形象"""
async def create_user_image(db: AsyncSession, image: UserImageCreate):
"""创建用户形象
image: 包含user_id的UserImageCreate对象
"""
# 如果设置为默认形象,先重置用户的所有形象为非默认
if image.is_default:
await reset_default_images(db, user_id)
if image.is_default and image.user_id:
await reset_default_images(db, image.user_id)
db_image = UserImage(
user_id=user_id,
user_id=image.user_id,
image_url=image.image_url,
is_default=image.is_default
)
@ -37,23 +40,23 @@ async def create_user_image(db: AsyncSession, user_id: int, image: UserImageCrea
await db.refresh(db_image)
return db_image
async def update_user_image(db: AsyncSession, image_id: int, image_update: UserImageUpdate):
async def update_user_image(db: AsyncSession, image_id: int, image: UserImageUpdate):
"""更新用户形象"""
db_image = await get_user_image(db, image_id)
if not db_image:
return None
# 处理默认形象逻辑
if image_update.is_default is not None and image_update.is_default and not db_image.is_default:
if image.is_default is not None and image.is_default and not db_image.is_default:
# 如果设置为默认且之前不是默认,重置其他形象
await reset_default_images(db, db_image.user_id)
db_image.is_default = True
elif image_update.is_default is not None:
db_image.is_default = image_update.is_default
elif image.is_default is not None:
db_image.is_default = image.is_default
# 更新图片URL
if image_update.image_url:
db_image.image_url = image_update.image_url
if image.image_url:
db_image.image_url = image.image_url
await db.commit()
await db.refresh(db_image)