181 lines
6.9 KiB
Python
181 lines
6.9 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, status, Response
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
import logging
|
|
from fastapi.security import OAuth2PasswordRequestForm
|
|
from app.api import deps
|
|
from app.core import security
|
|
from app.core.config import settings
|
|
from app.schemas.tryon import TryonRequest
|
|
from app.schemas.user import User
|
|
from app.services import person_image as person_image_service
|
|
from app.services import clothing as clothing_service
|
|
from app.api.deps import get_current_user
|
|
from app.services.dashscope_service import DashScopeService
|
|
from app.schemas.response import StandardResponse
|
|
from app.models.tryon import TryonHistory, TryonStatus
|
|
from app.schemas.tryon import TryonHistoryModel
|
|
from app.core.exceptions import BusinessError
|
|
from app.services import cos as cos_service
|
|
from app.schemas.person_image import PersonImage
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import selectinload
|
|
# Module-level logger. Its level is intentionally not forced here so that the
# application's logging configuration decides what gets emitted.
logger = logging.getLogger(__name__)

# Router holding the try-on endpoints; any URL prefix is presumably applied
# where this router is included (not visible in this module).
router = APIRouter()
|
|
|
|
@router.post("", tags=["tryon"])
async def tryon(
    tryon_request: TryonRequest,
    db: AsyncSession = Depends(deps.get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Submit a virtual try-on task using the current user's default person image.

    Garment image URLs come from the request. When ``top_clothing_id`` /
    ``bottom_clothing_id`` is given, the clothing record is looked up and its
    ``image_url`` overrides the corresponding URL from the request.

    Returns:
        StandardResponse with ``data`` set to the id of the newly created
        TryonHistory when DashScope returns a ``task_id``. Otherwise a
        StandardResponse with ``code=500`` (still sent with HTTP status 200,
        since no status code is set on the response).

    Raises:
        HTTPException(404): the user has no default person image, or a
            referenced clothing item does not exist.
        HTTPException(400): neither a top nor a bottom garment was supplied.
    """
    # The try-on is always rendered on the user's default person image.
    person_image = await person_image_service.get_default_image(db, current_user.id)
    if not person_image:
        raise HTTPException(status_code=404, detail="默认形象不存在")

    top_clothing_id = tryon_request.top_clothing_id
    bottom_clothing_id = tryon_request.bottom_clothing_id

    # URLs supplied directly in the request are the fallback; a clothing id,
    # when present, takes precedence.
    top_clothing_url = tryon_request.top_clothing_url
    bottom_clothing_url = tryon_request.bottom_clothing_url

    # NOTE(review): clothing ownership is not checked here — confirm whether
    # get_clothing scopes results to the current user.
    if top_clothing_id:
        top_clothing = await clothing_service.get_clothing(db, top_clothing_id)
        if not top_clothing:
            raise HTTPException(status_code=404, detail="上衣不存在")
        top_clothing_url = top_clothing.image_url

    if bottom_clothing_id:
        bottom_clothing = await clothing_service.get_clothing(db, bottom_clothing_id)
        if not bottom_clothing:
            raise HTTPException(status_code=404, detail="下装不存在")
        bottom_clothing_url = bottom_clothing.image_url

    # A try-on needs at least one garment; fail fast instead of sending an
    # empty request to the external service.
    if not top_clothing_url and not bottom_clothing_url:
        raise HTTPException(status_code=400, detail="请至少选择一件衣物")

    dashscope_service = DashScopeService()
    tryon_result = await dashscope_service.generate_tryon(
        person_image.image_url,
        top_clothing_url,
        bottom_clothing_url,
    )

    # Guard against the service returning None instead of a dict.
    task_id = (tryon_result or {}).get("task_id")
    if not task_id:
        logger.error("Try-on submission returned no task_id: %r", tryon_result)
        return StandardResponse(code=500, message="试穿任务提交失败")

    tryon_history = TryonHistory(
        user_id=current_user.id,
        person_image_id=person_image.id,
        top_clothing_id=top_clothing_id,
        bottom_clothing_id=bottom_clothing_id,
        top_clothing_url=top_clothing_url,
        bottom_clothing_url=bottom_clothing_url,
        task_id=task_id,
        status=TryonStatus.GENERATING,
    )
    db.add(tryon_history)
    await db.commit()
    # Reload explicitly after commit: if the session expires instances on
    # commit, reading ``id`` would otherwise trigger implicit IO, which
    # AsyncSession does not support.
    await db.refresh(tryon_history)

    return StandardResponse(code=200, message="试穿任务已提交", data=tryon_history.id)
|
|
|
|
@router.get("/histories", tags=["tryon"])
async def get_tryon_histories(
    db: AsyncSession = Depends(deps.get_db),
    current_user: User = Depends(get_current_user),
):
    """
    List the current user's try-on history records, newest first.

    Each item is a dict with the record's id, its person image (serialized
    with ``PersonImage``, or None if the related image is missing), the
    garment ids/URLs, status, DashScope task id and completion URL.

    Returns:
        StandardResponse whose ``data`` is the list of history dicts.
    """
    query = (
        select(TryonHistory)
        .where(TryonHistory.user_id == current_user.id)
        .order_by(TryonHistory.create_time.desc())
        # Eager-load the relationship: implicit lazy loading is not
        # supported with AsyncSession.
        .options(selectinload(TryonHistory.person_image))
    )
    tryon_histories = (await db.execute(query)).scalars().all()

    result = [
        {
            "id": history.id,
            # A single record with a missing person image must not make
            # validation fail for the whole list.
            "person_image": (
                PersonImage.model_validate(history.person_image)
                if history.person_image is not None
                else None
            ),
            "top_clothing_url": history.top_clothing_url,
            "bottom_clothing_url": history.bottom_clothing_url,
            "top_clothing_id": history.top_clothing_id,
            "bottom_clothing_id": history.bottom_clothing_id,
            "status": history.status,
            "task_id": history.task_id,
            "completion_url": history.completion_url,
        }
        for history in tryon_histories
    ]

    return StandardResponse(code=200, message="试穿历史获取成功", data=result)
|
|
|
|
@router.delete("/history/{history_id}", tags=["tryon"])
async def delete_tryon_history(
    history_id: int,
    db: AsyncSession = Depends(deps.get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Delete one try-on history record belonging to the current user.

    Raises:
        BusinessError(404): no record with ``history_id`` exists.
        BusinessError(403): the record belongs to a different user.
    """
    lookup = select(TryonHistory).where(TryonHistory.id == history_id)
    record = (await db.execute(lookup)).scalar_one_or_none()

    # NOTE(review): answering 403 for another user's record reveals that the
    # id exists; the GET endpoints in this file answer 404 in that case.
    if record is None:
        raise BusinessError(code=404, message="试穿历史不存在")
    if record.user_id != current_user.id:
        raise BusinessError(code=403, message="无权限删除试穿历史")

    await db.delete(record)
    await db.commit()
    return StandardResponse(code=200, message="试穿历史删除成功")
|
|
|
|
|
|
@router.get("/history/{history_id}", tags=["tryon"])
async def get_tryon_history(
    history_id: int,
    db: AsyncSession = Depends(deps.get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Return a single try-on history record of the current user.

    The query filters by owner, so another user's record is reported exactly
    like a missing one.

    Raises:
        BusinessError(404): no matching record for this user.
    """
    owned_record = select(TryonHistory).where(
        TryonHistory.id == history_id,
        TryonHistory.user_id == current_user.id,
    )
    found = (await db.execute(owned_record)).scalar_one_or_none()
    if found is None:
        raise BusinessError(code=404, message="试穿历史不存在")

    payload = TryonHistoryModel.model_validate(found)
    return StandardResponse(code=200, message="试穿历史详情获取成功", data=payload)
|
|
|
|
@router.get("/history/{history_id}/check", tags=["tryon"])
async def check_tryon_status(
    history_id: int,
    db: AsyncSession = Depends(deps.get_db),
    current_user: User = Depends(get_current_user),
):
    """
    Refresh and return the status of one of the current user's try-on tasks.

    While the record is not yet in a terminal state (COMPLETED or FAILED),
    the DashScope task is polled:

    - ``SUCCEEDED`` with an ``image_url``: the image is copied via
      ``cos_service.upload_file_from_url`` and the record is marked COMPLETED
      with ``completion_url`` set to the returned URL.
    - ``FAILED``: the record is marked FAILED.
    - anything else (or a success without an image URL): the record is left
      unchanged, so a later call polls again.

    Returns:
        StandardResponse whose ``data`` is the (possibly updated) record as a
        ``TryonHistoryModel``.

    Raises:
        BusinessError(404): no matching record for this user.
    """
    query = select(TryonHistory).where(
        TryonHistory.id == history_id,
        TryonHistory.user_id == current_user.id,
    )
    tryon_history = (await db.execute(query)).scalar_one_or_none()
    if not tryon_history:
        raise BusinessError(code=404, message="试穿历史不存在")

    # Only poll while the task can still change; terminal records (including
    # FAILED ones) must not hit the external API on every check.
    if tryon_history.status not in (TryonStatus.COMPLETED, TryonStatus.FAILED):
        dashscope_service = DashScopeService()
        # Guard against the service returning None instead of a dict.
        completion = await dashscope_service.check_tryon_status(tryon_history.task_id) or {}
        remote_status = completion.get("status")

        if remote_status == "SUCCEEDED":
            completion_url = completion.get("image_url")
            logger.debug(
                "Try-on task %s succeeded, result url: %s",
                tryon_history.task_id,
                completion_url,
            )
            if completion_url:
                url = await cos_service.upload_file_from_url(completion_url)
                tryon_history.status = TryonStatus.COMPLETED
                tryon_history.completion_url = url
                await db.commit()
                await db.refresh(tryon_history)
            else:
                # Do not upload None; leave the record pending so it is retried.
                logger.warning(
                    "Try-on task %s reported SUCCEEDED without image_url",
                    tryon_history.task_id,
                )
        elif remote_status == "FAILED":
            tryon_history.status = TryonStatus.FAILED
            await db.commit()
            await db.refresh(tryon_history)

    return StandardResponse(code=200, message="", data=TryonHistoryModel.model_validate(tryon_history))
|