206 lines
7.7 KiB
Python
206 lines
7.7 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, status, Response
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
import logging
|
|
from fastapi.security import OAuth2PasswordRequestForm
|
|
from app.api import deps
|
|
from app.core import security
|
|
from app.core.config import settings
|
|
from app.schemas.tryon import TryonRequest
|
|
from app.schemas.user import User
|
|
from app.services import person_image as person_image_service
|
|
from app.services import clothing as clothing_service
|
|
from app.api.deps import get_current_user
|
|
from app.services.dashscope_service import DashScopeService
|
|
from app.schemas.response import StandardResponse
|
|
from app.models.tryon import TryonHistory, TryonStatus
|
|
from app.schemas.tryon import TryonHistoryModel
|
|
from app.core.exceptions import BusinessError
|
|
from app.services import cos as cos_service
|
|
from app.schemas.person_image import PersonImage
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import selectinload
|
|
# Router holding all try-on endpoints; its routes use paths relative to
# wherever the application includes it (the submit route's path is "").
router = APIRouter()

# Module logger. Its level is pinned to DEBUG here, independent of whatever
# level the application's logging configuration would otherwise assign.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
|
|
|
|
@router.post("", tags=["tryon"])
async def tryon(
    tryon_request: TryonRequest,
    db: AsyncSession = Depends(deps.get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Submit a virtual try-on task for the current user.

    Uses the user's default person image plus the top/bottom clothing URLs from
    the request, submits a generation task through ``DashScopeService``, stores
    a ``TryonHistory`` row in the GENERATING state and decrements the user's
    remaining try-on count.

    Raises:
        BusinessError: code 400 when the user has no try-on attempts left.
        HTTPException: 404 when the user has no default person image.

    Returns:
        StandardResponse(code=200) with ``tryon_history_id`` and the updated
        ``tryon_remain_count`` on success, or StandardResponse(code=500) when
        the service response carries no ``task_id``.
    """
    # Reject early if the user's try-on quota is used up.
    # NOTE(review): check-then-decrement is not atomic; two concurrent requests
    # can both pass this check. A conditional UPDATE ... WHERE count > 0 would
    # close the race.
    if current_user.tryon_remain_count <= 0:
        raise BusinessError(code=400, message="试穿次数不足")

    # The try-on is always rendered on the user's default person image.
    person_image = await person_image_service.get_default_image(db, current_user.id)
    if not person_image:
        raise HTTPException(status_code=404, detail="默认形象不存在")

    top_clothing_url = tryon_request.top_clothing_url
    bottom_clothing_url = tryon_request.bottom_clothing_url

    # Submit the generation task; the result is polled later via the check endpoint.
    dashscope_service = DashScopeService()
    tryon_result = await dashscope_service.generate_tryon(
        person_image.image_url,
        top_clothing_url,
        bottom_clothing_url,
    )

    task_id = tryon_result.get("task_id")
    if not task_id:
        logger.warning("Try-on task submission returned no task_id: %s", tryon_result)
        return StandardResponse(code=500, message="试穿任务提交失败")

    tryon_history = TryonHistory(
        user_id=current_user.id,
        person_image_url=person_image.image_url,
        top_clothing_url=top_clothing_url,
        bottom_clothing_url=bottom_clothing_url,
        task_id=task_id,
        status=TryonStatus.GENERATING
    )
    db.add(tryon_history)

    # Consume one try-on attempt (only once the task was actually submitted).
    current_user.tryon_remain_count -= 1

    # Flush so the INSERT runs and the generated primary key is populated, and
    # capture the response values *before* committing. With SQLAlchemy's default
    # expire_on_commit=True (the session setup is not visible here), attributes
    # are expired by commit, and reading them afterwards would need an implicit
    # lazy load, which an AsyncSession cannot perform.
    await db.flush()
    tryon_history_id = tryon_history.id
    tryon_remain_count = current_user.tryon_remain_count
    await db.commit()

    return StandardResponse(
        code=200,
        message="试穿任务已提交",
        data={"tryon_history_id": tryon_history_id, "tryon_remain_count": tryon_remain_count},
    )
|
|
|
|
@router.get("/histories", tags=["tryon"])
async def get_tryon_histories(
    db: AsyncSession = Depends(deps.get_db),
    current_user: User = Depends(get_current_user),
    skip: int = 0,
    limit: int = 10
):
    """
    List the current user's try-on histories, newest first (by ``create_time``).

    Args:
        skip: number of records to skip (pagination offset); negative values
            are treated as 0.
        limit: maximum number of records to return; negative values are
            treated as 0.

    Returns:
        StandardResponse whose ``data`` is a list of dicts with the keys
        id, person_image_url, top_clothing_url, bottom_clothing_url, status,
        task_id and completion_url.
    """
    # Negative values come straight from the query string. PostgreSQL rejects a
    # negative OFFSET/LIMIT, and SQLite treats a negative LIMIT as "no limit",
    # which would bypass pagination entirely.
    # NOTE(review): consider also capping `limit` to bound response size.
    skip = max(skip, 0)
    limit = max(limit, 0)

    stmt = (
        select(TryonHistory)
        .where(TryonHistory.user_id == current_user.id)
        .order_by(TryonHistory.create_time.desc())
        .offset(skip)
        .limit(limit)
    )
    tryon_histories = (await db.execute(stmt)).scalars().all()

    result = [
        {
            "id": history.id,
            "person_image_url": history.person_image_url,
            "top_clothing_url": history.top_clothing_url,
            "bottom_clothing_url": history.bottom_clothing_url,
            "status": history.status,
            "task_id": history.task_id,
            "completion_url": history.completion_url,
        }
        for history in tryon_histories
    ]

    return StandardResponse(code=200, message="试穿历史获取成功", data=result)
|
|
|
|
@router.delete("/history/{history_id}", tags=["tryon"])
async def delete_tryon_history(
    history_id: int,
    db: AsyncSession = Depends(deps.get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Delete a single try-on history record owned by the current user.

    Raises:
        BusinessError: code 404 when no record has ``history_id``;
            code 403 when the record belongs to a different user.
    """
    query = select(TryonHistory).where(TryonHistory.id == history_id)
    record = (await db.execute(query)).scalar_one_or_none()

    # Existence is checked before ownership, so a missing id yields 404 and a
    # foreign id yields 403.
    if not record:
        raise BusinessError(code=404, message="试穿历史不存在")
    if record.user_id != current_user.id:
        raise BusinessError(code=403, message="无权限删除试穿历史")

    await db.delete(record)
    await db.commit()

    return StandardResponse(code=200, message="试穿历史删除成功")
|
|
|
|
|
|
@router.get("/history/{history_id}", tags=["tryon"])
async def get_tryon_history(
    history_id: int,
    db: AsyncSession = Depends(deps.get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Return the details of one of the current user's try-on history records.

    The lookup is scoped to the authenticated user, as in the status-check and
    delete endpoints. Without this, any caller could read any user's record,
    including the person image URL, just by guessing ids.
    NOTE(review): if this route is meant to back a public share link, it needs
    an explicit share token instead of being left unauthenticated.

    Raises:
        BusinessError: code 404 when no record with ``history_id`` belongs to
            the current user.
    """
    stmt = select(TryonHistory).where(
        TryonHistory.id == history_id,
        TryonHistory.user_id == current_user.id,
    )
    tryon_history = (await db.execute(stmt)).scalar_one_or_none()

    # Report 404 for records owned by someone else too, so ids of other users'
    # records are not revealed.
    if not tryon_history:
        raise BusinessError(code=404, message="试穿历史不存在")

    return StandardResponse(code=200, message="试穿历史详情获取成功", data=TryonHistoryModel.model_validate(tryon_history))
|
|
|
|
@router.get("/history/{history_id}/check", tags=["tryon"])
async def check_tryon_status(
    history_id: int,
    db: AsyncSession = Depends(deps.get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Poll the try-on task status and persist the outcome on the history record.

    Records already COMPLETED or FAILED are returned without contacting the
    DashScope service. Otherwise the task status is fetched:
      - "SUCCEEDED": the result image is uploaded via
        ``cos_service.upload_file_from_url`` and the record is marked COMPLETED
        with the returned URL (left unchanged if no ``image_url`` was reported);
      - "FAILED": the record is marked FAILED;
      - anything else: the record is left unchanged.

    Raises:
        BusinessError: code 404 when no record with ``history_id`` belongs to
            the current user.
    """
    stmt = select(TryonHistory).where(
        TryonHistory.id == history_id,
        TryonHistory.user_id == current_user.id,
    )
    tryon_history = (await db.execute(stmt)).scalar_one_or_none()
    if not tryon_history:
        raise BusinessError(code=404, message="试穿历史不存在")

    # COMPLETED records need no polling. FAILED is only ever set from a FAILED
    # remote status, so re-polling it would just rewrite FAILED; skip the remote
    # round-trip for both.
    if tryon_history.status in (TryonStatus.COMPLETED, TryonStatus.FAILED):
        return StandardResponse(code=200, message="", data=TryonHistoryModel.model_validate(tryon_history))

    dashscope_service = DashScopeService()
    completion = await dashscope_service.check_tryon_status(tryon_history.task_id)
    remote_status = completion.get("status")

    if remote_status == "SUCCEEDED":
        completion_url = completion.get("image_url")
        logger.debug("Try-on task %s succeeded, result image: %s", tryon_history.task_id, completion_url)
        if completion_url:
            # Re-host the result via COS and store that URL rather than the
            # service's own URL (presumably because the latter is temporary).
            url = await cos_service.upload_file_from_url(completion_url)
            tryon_history.status = TryonStatus.COMPLETED
            tryon_history.completion_url = url
            await db.commit()
            # Reload after commit so the response is built from fresh state.
            await db.refresh(tryon_history)
        else:
            logger.warning("Try-on task %s reported SUCCEEDED without an image_url", tryon_history.task_id)
    elif remote_status == "FAILED":
        tryon_history.status = TryonStatus.FAILED
        await db.commit()
        await db.refresh(tryon_history)

    return StandardResponse(code=200, message="", data=TryonHistoryModel.model_validate(tryon_history))
|
|
|
|
|
|
@router.get("/comment", tags=["tryon"])
async def comment_tryon(
    history_id: int,
    db: AsyncSession = Depends(deps.get_db),
    current_user: User = Depends(get_current_user)
):
    """
    Return an outfit comment and score for a completed try-on, generating it on first use.

    The comment is produced by ``DashScopeService.generate_dressing_comment``
    from the record's ``completion_url`` and stored on the history record. Later
    calls return the stored values without calling the service again. The
    lookup is scoped to the authenticated user, as in the status-check endpoint,
    because this route writes to the record and triggers a remote service call.

    Args:
        history_id: id of the try-on history (query parameter).

    Raises:
        BusinessError: code 404 when no record with ``history_id`` belongs to
            the current user; code 400 when the try-on has not completed.
    """
    stmt = select(TryonHistory).where(
        TryonHistory.id == history_id,
        TryonHistory.user_id == current_user.id,
    )
    tryon_history = (await db.execute(stmt)).scalar_one_or_none()
    if not tryon_history:
        raise BusinessError(code=404, message="试穿历史不存在")

    # A comment needs the generated result image.
    if tryon_history.status != TryonStatus.COMPLETED:
        raise BusinessError(code=400, message="试穿未完成")

    # Return the stored comment if present. The score is checked against None
    # because a legitimate score of 0 is falsy and would otherwise force a new
    # remote call on every request.
    if tryon_history.comment and tryon_history.score is not None:
        return StandardResponse(code=200, message="穿搭点评获取成功", data=TryonHistoryModel.model_validate(tryon_history))

    dashscope_service = DashScopeService()
    comment = await dashscope_service.generate_dressing_comment(tryon_history.completion_url)

    comment_text = comment.get("comment")
    score = comment.get("score")
    if comment_text and score is not None:
        logger.info("穿搭点评: %s", comment)
        tryon_history.comment = comment_text
        tryon_history.score = score
        await db.commit()
        await db.refresh(tryon_history)
    else:
        # NOTE(review): the response below still reports success with an empty
        # comment; consider surfacing this as an error to the client.
        logger.warning("Dressing comment generation returned no usable result: %s", comment)

    return StandardResponse(code=200, message="穿搭点评获取成功", data=TryonHistoryModel.model_validate(tryon_history))
|
|
|