aidress/app/routers/tryon_router.py
2025-03-21 17:06:54 +08:00

225 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import APIRouter, HTTPException, Depends, Query, Path, BackgroundTasks
from sqlalchemy.orm import Session
from typing import List, Optional
import httpx
import logging
import os
from dotenv import load_dotenv
from app.database import get_db
from app.models.tryon import TryOn
from app.schemas.tryon import (
TryOnCreate, TryOnUpdate, TryOnResponse,
AiTryonRequest, AiTryonResponse, TaskInfo
)
from app.services.dashscope_service import DashScopeService
# Load environment variables via python-dotenv; this must run before the
# os.getenv() call below so values defined there are visible.
load_dotenv()
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router that holds the try-on endpoints defined in this file.
router = APIRouter()
# DashScope (Alibaba Bailian) API key read from the environment; None when unset.
# NOTE(review): neither DASHSCOPE_API_KEY nor DASHSCOPE_API_URL is referenced by
# the code visible in this router (requests go through DashScopeService). Check
# whether other modules import them before removing either one.
DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY")
# DashScope image-to-image synthesis endpoint.
DASHSCOPE_API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2image/image-synthesis"
@router.post("/", response_model=TryOnResponse, status_code=201)
async def create_tryon(
    tryon_data: TryOnCreate,
    background_tasks: BackgroundTasks,
    db: Session = Depends(get_db)
):
    """
    Create a try-on record and submit it to the Alibaba Bailian platform.

    - **top_garment_url**: URL of the top garment image (optional)
    - **bottom_garment_url**: URL of the bottom garment image (optional)
    - **person_image_url**: URL of the person image (required)

    At least one of the top and bottom garment images must be provided,
    otherwise the request is rejected with 400. The record is persisted
    first, then the Bailian call is scheduled as a background task;
    that task (send_tryon_request) fills in the request/task fields later.
    Any failure while persisting or scheduling rolls back and returns 500.
    """
    top_url = tryon_data.top_garment_url
    bottom_url = tryon_data.bottom_garment_url
    person_url = tryon_data.person_image_url

    if not (top_url or bottom_url):
        raise HTTPException(status_code=400, detail="上衣和下衣图片至少需要提供一个")

    try:
        record = TryOn(
            top_garment_url=top_url,
            bottom_garment_url=bottom_url,
            person_image_url=person_url,
        )
        db.add(record)
        db.commit()
        db.refresh(record)

        # NOTE(review): `db` is the request-scoped session from get_db. The
        # background task runs after the response is sent; depending on how
        # get_db cleans up and on the FastAPI version, this session may already
        # be closed by then. Consider letting the task open its own session.
        background_tasks.add_task(
            send_tryon_request,
            db=db,
            tryon_id=record.id,
            top_garment_url=top_url,
            bottom_garment_url=bottom_url,
            person_image_url=person_url,
        )
    except Exception as exc:
        logger.error(f"创建试穿记录失败: {str(exc)}")
        db.rollback()
        raise HTTPException(status_code=500, detail=f"创建试穿记录失败: {str(exc)}")
    return record
async def send_tryon_request(
    db: Session,
    tryon_id: int,
    top_garment_url: Optional[str],
    bottom_garment_url: Optional[str],
    person_image_url: str
):
    """Submit a try-on job to Alibaba Bailian (DashScope) and record the outcome.

    Scheduled as a background task by ``create_tryon``. On success, copies
    ``request_id``, ``output.task_id`` and ``output.task_status`` from the
    DashScope response onto the TryOn row with id ``tryon_id``. On any
    failure, tries to mark that row's ``task_status`` as ``"ERROR"``.
    Errors are logged rather than propagated, since no caller is left to
    receive them.

    NOTE(review): ``db`` is the request-scoped session handed over by the
    endpoint; depending on ``get_db`` and the FastAPI version it may already
    be closed when this runs. Consider opening a dedicated session here.

    Args:
        db: SQLAlchemy session used to load and update the TryOn row.
        tryon_id: Primary key of the TryOn row to update.
        top_garment_url: Top garment image URL, or None.
        bottom_garment_url: Bottom garment image URL, or None.
        person_image_url: Person image URL.
    """
    try:
        dashscope_service = DashScopeService()
        response = await dashscope_service.generate_tryon(
            person_image_url=person_image_url,
            top_garment_url=top_garment_url,
            bottom_garment_url=bottom_garment_url
        )
        # `output` may be absent or explicitly null in an error payload;
        # `.get("output", {})` would return None in the latter case.
        output = response.get("output") or {}

        db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
        if db_tryon:
            db_tryon.request_id = response.get("request_id")
            db_tryon.task_id = output.get("task_id")
            db_tryon.task_status = output.get("task_status")
            db.commit()
            logger.info(f"试穿请求发送成功任务ID: {db_tryon.task_id}")
        else:
            logger.warning(f"未找到ID为{tryon_id}的试穿记录")
    except Exception as e:
        logger.exception(f"发送试穿请求异常: {str(e)}")
        try:
            # If the failure came from the commit above, the session refuses
            # further queries until it is rolled back.
            db.rollback()
            db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
            if db_tryon:
                db_tryon.task_status = "ERROR"
                db.commit()
        except Exception as mark_err:
            # Best effort only: there is no caller to report this to.
            logger.exception(f"更新试穿记录ERROR状态失败: {str(mark_err)}")
@router.get("/", response_model=List[TryOnResponse])
async def get_tryons(
    skip: int = Query(0, ge=0, description="跳过的记录数量"),
    limit: int = Query(100, ge=0, description="返回的最大记录数量"),
    db: Session = Depends(get_db)
):
    """
    List try-on records, newest first (by ``created_at``), with pagination.

    - **skip**: number of records to skip, for pagination
    - **limit**: maximum number of records to return, for pagination

    Both values must be >= 0; FastAPI rejects negatives with 422. Without
    that bound, a negative OFFSET/LIMIT is a database error on some backends
    (surfacing as a 500), and on SQLite a negative LIMIT means "no limit",
    which would bypass pagination entirely.
    """
    # NOTE(review): `limit` has no upper bound, so a client can still request
    # arbitrarily large pages; consider adding `le=` once callers are known.
    try:
        return (
            db.query(TryOn)
            .order_by(TryOn.created_at.desc())
            .offset(skip)
            .limit(limit)
            .all()
        )
    except Exception as e:
        logger.exception(f"获取试穿记录列表失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"获取试穿记录列表失败: {str(e)}") from e
@router.get("/{tryon_id}", response_model=TryOnResponse)
async def get_tryon(
    tryon_id: int = Path(..., description="试穿记录ID"),
    db: Session = Depends(get_db)
):
    """
    Fetch a single try-on record by its ID.

    - **tryon_id**: try-on record ID

    Responds with 404 when no record has that ID and 500 when the lookup
    itself fails.
    """
    # Only the database lookup can fail unexpectedly, so only it is guarded;
    # the 404 below is raised outside the try and needs no passthrough clause.
    try:
        record = db.query(TryOn).filter(TryOn.id == tryon_id).first()
    except Exception as exc:
        logger.error(f"获取试穿记录详情失败: {str(exc)}")
        raise HTTPException(status_code=500, detail=f"获取试穿记录详情失败: {str(exc)}")

    if not record:
        raise HTTPException(status_code=404, detail=f"未找到ID为{tryon_id}的试穿记录")
    return record
@router.put("/{tryon_id}", response_model=TryOnResponse)
async def update_tryon(
    tryon_id: int = Path(..., description="试穿记录ID"),
    tryon_data: Optional[TryOnUpdate] = None,
    db: Session = Depends(get_db)
):
    """
    Update a try-on record.

    - **tryon_id**: try-on record ID
    - **task_status**: task status (optional)
    - **completion_url**: URL of the generated image (optional)

    Only fields that are provided and not ``None`` are written. The request
    body itself is optional; a missing body is treated like an empty update
    and the record is returned unchanged. Responds with 404 if the record
    does not exist and 500 (after rolling back) on any other failure.
    """
    try:
        db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
        if not db_tryon:
            raise HTTPException(status_code=404, detail=f"未找到ID为{tryon_id}的试穿记录")

        # The body defaults to None; dereferencing it unconditionally used to
        # turn a body-less PUT into an AttributeError and a 500.
        if tryon_data is not None:
            if tryon_data.task_status is not None:
                db_tryon.task_status = tryon_data.task_status
            if tryon_data.completion_url is not None:
                db_tryon.completion_url = tryon_data.completion_url

        db.commit()
        db.refresh(db_tryon)
        return db_tryon
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"更新试穿记录失败: {str(e)}")
        db.rollback()
        raise HTTPException(status_code=500, detail=f"更新试穿记录失败: {str(e)}") from e
@router.post("/{tryon_id}/check", response_model=TryOnResponse)
async def check_tryon_status(
    tryon_id: int = Path(..., description="试穿记录ID"),
    db: Session = Depends(get_db),
    # Wrapping the constructor in a lambda keeps FastAPI from inspecting
    # DashScopeService.__init__ for request parameters.
    dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
):
    """
    Poll DashScope for a try-on task's status and sync it into the record.

    - **tryon_id**: try-on record ID

    Responds with 404 if the record does not exist and 400 if it has no task
    ID yet. Polling is best-effort: if the DashScope call or the commit
    fails, the error is logged, the session is rolled back, and the record is
    returned with its previously stored state. A response that carries no
    ``output.task_status`` leaves the stored status untouched instead of
    overwriting it with ``None``.
    """
    try:
        db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
        if not db_tryon:
            raise HTTPException(status_code=404, detail=f"未找到ID为{tryon_id}的试穿记录")
        if not db_tryon.task_id:
            raise HTTPException(status_code=400, detail="试穿记录未包含任务ID")

        try:
            status_response = await dashscope_service.check_tryon_status(db_tryon.task_id)
            # `output` may be absent or explicitly null in an error payload.
            output = status_response.get("output") or {}
            new_status = output.get("task_status")
            if new_status is not None:
                db_tryon.task_status = new_status
                if new_status == "SUCCEEDED":
                    # NOTE(review): DashScope's AI try-on task-query response
                    # documents the result as `output.image_url`; `url` is kept
                    # first for whatever DashScopeService currently returns.
                    # Verify which key DashScopeService actually passes through.
                    db_tryon.completion_url = output.get("url") or output.get("image_url")
                db.commit()
                db.refresh(db_tryon)
                logger.info(f"试穿任务状态更新: {db_tryon.task_status}")
            else:
                logger.warning(f"DashScope响应未包含任务状态: {status_response}")
        except Exception as e:
            logger.exception(f"调用DashScope API检查任务状态失败: {str(e)}")
            # A failed commit leaves the session unusable (and the in-memory
            # record dirty) until it is rolled back.
            db.rollback()
        return db_tryon
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"检查试穿任务状态失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"检查试穿任务状态失败: {str(e)}") from e