225 lines
7.9 KiB
Python
225 lines
7.9 KiB
Python
from fastapi import APIRouter, HTTPException, Depends, Query, Path, BackgroundTasks
|
||
from sqlalchemy.orm import Session
|
||
from typing import List, Optional
|
||
import httpx
|
||
import logging
|
||
import os
|
||
from dotenv import load_dotenv
|
||
|
||
from app.database import get_db
|
||
from app.models.tryon import TryOn
|
||
from app.schemas.tryon import (
|
||
TryOnCreate, TryOnUpdate, TryOnResponse,
|
||
AiTryonRequest, AiTryonResponse, TaskInfo
|
||
)
|
||
from app.services.dashscope_service import DashScopeService
|
||
|
||
# 加载环境变量
|
||
load_dotenv()
|
||
|
||
logger = logging.getLogger(__name__)
|
||
router = APIRouter()
|
||
|
||
# 从环境变量获取API密钥
|
||
DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY")
|
||
DASHSCOPE_API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2image/image-synthesis"
|
||
|
||
@router.post("/", response_model=TryOnResponse, status_code=201)
|
||
async def create_tryon(
|
||
tryon_data: TryOnCreate,
|
||
background_tasks: BackgroundTasks,
|
||
db: Session = Depends(get_db)
|
||
):
|
||
"""
|
||
创建一个试穿记录并发送到阿里百炼平台
|
||
|
||
- **top_garment_url**: 上衣图片URL(可选)
|
||
- **bottom_garment_url**: 下衣图片URL(可选)
|
||
- **person_image_url**: 人物图片URL(必填)
|
||
|
||
注意:上衣和下衣至少需要提供一个
|
||
"""
|
||
if not tryon_data.top_garment_url and not tryon_data.bottom_garment_url:
|
||
raise HTTPException(status_code=400, detail="上衣和下衣图片至少需要提供一个")
|
||
|
||
try:
|
||
# 创建试穿记录
|
||
db_tryon = TryOn(
|
||
top_garment_url=tryon_data.top_garment_url,
|
||
bottom_garment_url=tryon_data.bottom_garment_url,
|
||
person_image_url=tryon_data.person_image_url
|
||
)
|
||
db.add(db_tryon)
|
||
db.commit()
|
||
db.refresh(db_tryon)
|
||
|
||
# 在后台发送请求到阿里百炼平台
|
||
background_tasks.add_task(
|
||
send_tryon_request,
|
||
db=db,
|
||
tryon_id=db_tryon.id,
|
||
top_garment_url=tryon_data.top_garment_url,
|
||
bottom_garment_url=tryon_data.bottom_garment_url,
|
||
person_image_url=tryon_data.person_image_url
|
||
)
|
||
|
||
return db_tryon
|
||
except Exception as e:
|
||
logger.error(f"创建试穿记录失败: {str(e)}")
|
||
db.rollback()
|
||
raise HTTPException(status_code=500, detail=f"创建试穿记录失败: {str(e)}")
|
||
|
||
async def send_tryon_request(
|
||
db: Session,
|
||
tryon_id: int,
|
||
top_garment_url: Optional[str],
|
||
bottom_garment_url: Optional[str],
|
||
person_image_url: str
|
||
):
|
||
"""发送试穿请求到阿里百炼平台"""
|
||
try:
|
||
# 创建DashScopeService实例
|
||
dashscope_service = DashScopeService()
|
||
|
||
# 调用服务发送试穿请求
|
||
response = await dashscope_service.generate_tryon(
|
||
person_image_url=person_image_url,
|
||
top_garment_url=top_garment_url,
|
||
bottom_garment_url=bottom_garment_url
|
||
)
|
||
|
||
# 更新数据库记录
|
||
db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
|
||
if db_tryon:
|
||
db_tryon.request_id = response.get("request_id")
|
||
db_tryon.task_id = response.get("output", {}).get("task_id")
|
||
db_tryon.task_status = response.get("output", {}).get("task_status")
|
||
db.commit()
|
||
|
||
logger.info(f"试穿请求发送成功,任务ID: {db_tryon.task_id}")
|
||
except Exception as e:
|
||
logger.error(f"发送试穿请求异常: {str(e)}")
|
||
|
||
# 更新数据库记录状态
|
||
db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
|
||
if db_tryon:
|
||
db_tryon.task_status = "ERROR"
|
||
db.commit()
|
||
|
||
@router.get("/", response_model=List[TryOnResponse])
|
||
async def get_tryons(
|
||
skip: int = Query(0, description="跳过的记录数量"),
|
||
limit: int = Query(100, description="返回的最大记录数量"),
|
||
db: Session = Depends(get_db)
|
||
):
|
||
"""
|
||
获取试穿记录列表,支持分页
|
||
|
||
- **skip**: 跳过的记录数量,用于分页
|
||
- **limit**: 返回的最大记录数量,用于分页
|
||
"""
|
||
try:
|
||
tryons = db.query(TryOn).order_by(TryOn.created_at.desc()).offset(skip).limit(limit).all()
|
||
return tryons
|
||
except Exception as e:
|
||
logger.error(f"获取试穿记录列表失败: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"获取试穿记录列表失败: {str(e)}")
|
||
|
||
@router.get("/{tryon_id}", response_model=TryOnResponse)
|
||
async def get_tryon(
|
||
tryon_id: int = Path(..., description="试穿记录ID"),
|
||
db: Session = Depends(get_db)
|
||
):
|
||
"""
|
||
根据ID获取试穿记录详情
|
||
|
||
- **tryon_id**: 试穿记录ID
|
||
"""
|
||
try:
|
||
tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
|
||
if not tryon:
|
||
raise HTTPException(status_code=404, detail=f"未找到ID为{tryon_id}的试穿记录")
|
||
return tryon
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"获取试穿记录详情失败: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"获取试穿记录详情失败: {str(e)}")
|
||
|
||
@router.put("/{tryon_id}", response_model=TryOnResponse)
|
||
async def update_tryon(
|
||
tryon_id: int = Path(..., description="试穿记录ID"),
|
||
tryon_data: TryOnUpdate = None,
|
||
db: Session = Depends(get_db)
|
||
):
|
||
"""
|
||
更新试穿记录信息
|
||
|
||
- **tryon_id**: 试穿记录ID
|
||
- **task_status**: 任务状态(可选)
|
||
- **completion_url**: 生成图片URL(可选)
|
||
"""
|
||
try:
|
||
db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
|
||
if not db_tryon:
|
||
raise HTTPException(status_code=404, detail=f"未找到ID为{tryon_id}的试穿记录")
|
||
|
||
# 更新提供的字段
|
||
if tryon_data.task_status is not None:
|
||
db_tryon.task_status = tryon_data.task_status
|
||
if tryon_data.completion_url is not None:
|
||
db_tryon.completion_url = tryon_data.completion_url
|
||
|
||
db.commit()
|
||
db.refresh(db_tryon)
|
||
return db_tryon
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"更新试穿记录失败: {str(e)}")
|
||
db.rollback()
|
||
raise HTTPException(status_code=500, detail=f"更新试穿记录失败: {str(e)}")
|
||
|
||
@router.post("/{tryon_id}/check", response_model=TryOnResponse)
|
||
async def check_tryon_status(
|
||
tryon_id: int = Path(..., description="试穿记录ID"),
|
||
db: Session = Depends(get_db),
|
||
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
|
||
):
|
||
"""
|
||
检查试穿任务状态
|
||
|
||
- **tryon_id**: 试穿记录ID
|
||
"""
|
||
try:
|
||
db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
|
||
if not db_tryon:
|
||
raise HTTPException(status_code=404, detail=f"未找到ID为{tryon_id}的试穿记录")
|
||
|
||
if not db_tryon.task_id:
|
||
raise HTTPException(status_code=400, detail=f"试穿记录未包含任务ID")
|
||
|
||
# 调用DashScopeService检查任务状态
|
||
try:
|
||
status_response = await dashscope_service.check_tryon_status(db_tryon.task_id)
|
||
|
||
# 更新数据库记录
|
||
db_tryon.task_status = status_response.get("output", {}).get("task_status")
|
||
|
||
# 如果任务完成,保存结果URL
|
||
if db_tryon.task_status == "SUCCEEDED":
|
||
db_tryon.completion_url = status_response.get("output", {}).get("url")
|
||
|
||
db.commit()
|
||
db.refresh(db_tryon)
|
||
|
||
logger.info(f"试穿任务状态更新: {db_tryon.task_status}")
|
||
except Exception as e:
|
||
logger.error(f"调用DashScope API检查任务状态失败: {str(e)}")
|
||
|
||
return db_tryon
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"检查试穿任务状态失败: {str(e)}")
|
||
raise HTTPException(status_code=500, detail=f"检查试穿任务状态失败: {str(e)}") |