diff --git a/app/api/v1/clothing.py b/app/api/v1/clothing.py index 9a512e8..11d9028 100644 --- a/app/api/v1/clothing.py +++ b/app/api/v1/clothing.py @@ -122,10 +122,11 @@ async def create_clothing( async def read_clothes( skip: int = Query(0, ge=0), limit: int = Query(100, ge=1, le=100), - db: AsyncSession = Depends(get_db) + db: AsyncSession = Depends(get_db), + current_user: UserModel = Depends(get_current_user) ): """获取所有衣服""" - clothes = await clothing_service.get_clothes(db=db, skip=skip, limit=limit) + clothes = await clothing_service.get_clothes(db=db, skip=skip, limit=limit, user_id=current_user.id) # 手动返回标准响应格式 return StandardResponse(code=200, data=[Clothing.model_validate(clothing) for clothing in clothes]) diff --git a/app/api/v1/tryon.py b/app/api/v1/tryon.py index 7dfa561..3abc2df 100644 --- a/app/api/v1/tryon.py +++ b/app/api/v1/tryon.py @@ -12,8 +12,10 @@ from app.services import clothing as clothing_service from app.api.deps import get_current_user from app.services.dashscope_service import DashScopeService from app.schemas.response import StandardResponse -from app.models.tryon import TryonHistory +from app.models.tryon import TryonHistory, TryonStatus from app.schemas.tryon import TryonHistoryModel +from app.core.exceptions import BusinessError +from app.services import cos as cos_service from sqlalchemy import select logger = logging.getLogger(__name__) @@ -67,7 +69,8 @@ async def tryon( bottom_clothing_id=bottom_clothing_id, top_clothing_url=top_clothing_url, bottom_clothing_url=bottom_clothing_url, - task_id=task_id + task_id=task_id, + status=TryonStatus.GENERATING ) db.add(tryon_history) await db.commit() @@ -87,3 +90,55 @@ async def get_tryon_histories( tryon_histories = histories.scalars().all() return StandardResponse(code=200, message="试穿历史获取成功", data=[TryonHistoryModel.model_validate(history) for history in tryon_histories]) + +@router.get("/tryon/history/{history_id}", tags=["tryon"]) +async def get_tryon_history( + history_id: int, + db: AsyncSession = Depends(deps.get_db), + current_user: User = Depends(get_current_user) +): + """ + 获取试穿历史详情 + """ + history = await db.execute(select(TryonHistory).where(TryonHistory.id == history_id, TryonHistory.user_id == current_user.id)) + tryon_history = history.scalar_one_or_none() + if not tryon_history: + raise BusinessError(code=404, message="试穿历史不存在") + + return StandardResponse(code=200, message="试穿历史详情获取成功", data=TryonHistoryModel.model_validate(tryon_history)) + +@router.get("/tryon/history/{history_id}/check", tags=["tryon"]) +async def check_tryon_status( + history_id: int, + db: AsyncSession = Depends(deps.get_db), + current_user: User = Depends(get_current_user) +): + """ + 检查试穿状态 + """ + history = await db.execute(select(TryonHistory).where(TryonHistory.id == history_id, TryonHistory.user_id == current_user.id)) + tryon_history = history.scalar_one_or_none() + if not tryon_history: + raise BusinessError(code=404, message="试穿历史不存在") + + dashscope_service = DashScopeService() + completion = await dashscope_service.check_tryon_status(tryon_history.task_id) + if completion.get("status") == "SUCCEEDED": + completion_url = completion.get("image_url") + + url = await cos_service.upload_file_from_url(completion_url, "tryon", tryon_history.id) + tryon_history.status = TryonStatus.COMPLETED + tryon_history.completion_url = url + await db.commit() + + return StandardResponse(code=200, message="检查试穿状态成功", data={ + "history_id": tryon_history.id, + "status": tryon_history.status, + "completion_url": url + }) + else: + return StandardResponse(code=200, message="检查试穿状态成功", data={ + "history_id": tryon_history.id, + "status": tryon_history.status, + "completion_url": None + }) diff --git a/app/core/exceptions.py b/app/core/exceptions.py index 1a81267..c36362b 100644 --- a/app/core/exceptions.py +++ b/app/core/exceptions.py @@ -63,8 +63,8 @@ async def http_exception_handler(request: Request, exc: HTTPException): ) return JSONResponse( - status_code=200, # 与业务异常一致,返回200状态码 - content=error_response.model_dump() + status_code=exc.status_code, + content=exc.detail ) def add_exception_handlers(app): diff --git a/app/services/clothing.py b/app/services/clothing.py index 5435bdc..eaae9e5 100644 --- a/app/services/clothing.py +++ b/app/services/clothing.py @@ -61,10 +61,11 @@ async def get_clothing(db: AsyncSession, clothing_id: int): result = await db.execute(select(Clothing).filter(Clothing.id == clothing_id)) return result.scalars().first() -async def get_clothes(db: AsyncSession, skip: int = 0, limit: int = 100): +async def get_clothes(db: AsyncSession, skip: int = 0, limit: int = 100, user_id: int = None): """获取所有衣服""" result = await db.execute( select(Clothing) + .filter(Clothing.user_id == user_id) .order_by(Clothing.create_time.desc()) .offset(skip) .limit(limit)