from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
import logging
from typing import List, Optional, Dict, Any
import dashscope

from app.services.dashscope_service import DashScopeService
from app.utils.config import get_settings

logger = logging.getLogger(__name__)
router = APIRouter()


class ChatMessage(BaseModel):
    role: str  # 'user' or 'assistant'
    content: str


class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    model: Optional[str] = "qwen-max"  # defaults to the Tongyi Qianwen Max model
    max_tokens: Optional[int] = 2048
    temperature: Optional[float] = 0.7
    stream: Optional[bool] = False


class ChatResponse(BaseModel):
    response: str
    usage: Dict[str, Any]
    request_id: str


@router.post("/chat", response_model=ChatResponse)
async def chat_completion(
    request: ChatRequest,
    dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
):
    """
    Run a chat completion against a DashScope large language model.
    """
    try:
        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
        response = await dashscope_service.chat_completion(
            messages=messages,
            model=request.model,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            stream=request.stream
        )

        # The response layout differs between dashscope SDK versions, so probe the
        # known shapes (attribute-style, dict-style, OpenAI-style choices) in order.
        response_text = ""
        if hasattr(response, 'output') and hasattr(response.output, 'text'):
            response_text = response.output.text
        elif hasattr(response, 'output') and isinstance(response.output, dict) and 'text' in response.output:
            response_text = response.output['text']
        elif hasattr(response, 'choices') and len(response.choices) > 0:
            response_text = response.choices[0].message.content
        else:
            logger.warning(f"Unable to parse DashScope response: {response}")
            # Fall back to a conservative best-effort extraction
            try:
                if hasattr(response, 'output'):
                    if isinstance(response.output, dict):
                        response_text = str(response.output.get('text', ''))
                    else:
                        response_text = str(response.output)
                else:
                    response_text = str(response)
            except Exception as e:
                logger.error(f"Failed to extract response text: {str(e)}")
                response_text = "Unable to retrieve the model response text"

        # usage and request_id may likewise be attributes or dict entries depending on the SDK version
        usage = {}
        if hasattr(response, 'usage'):
            usage = response.usage if isinstance(response.usage, dict) else vars(response.usage)

        request_id = ""
        if hasattr(response, 'request_id'):
            request_id = response.request_id

        return {
            "response": response_text,
            "usage": usage,
            "request_id": request_id
        }
    except Exception as e:
        logger.error(f"DashScope API call failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Model call failed: {str(e)}")
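
# Illustrative request for POST /chat (sketch only, not executed as part of this
# module; the "/api" mount path and local host/port are assumptions, adjust them
# to wherever this router is actually included):
#
#   import httpx
#   payload = {"messages": [{"role": "user", "content": "Hello"}], "model": "qwen-max"}
#   resp = httpx.post("http://localhost:8000/api/chat", json=payload)
#   print(resp.json()["response"])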


class ImageGenerationRequest(BaseModel):
    prompt: str
    negative_prompt: Optional[str] = None
    model: Optional[str] = "stable-diffusion-xl"
    n: Optional[int] = 1
    size: Optional[str] = "1024*1024"


class ImageGenerationResponse(BaseModel):
    images: List[str]
    request_id: str


@router.post("/generate-image", response_model=ImageGenerationResponse)
async def generate_image(
    request: ImageGenerationRequest,
    dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
):
    """
    Generate images through the DashScope image generation API.
    """
    try:
        response = await dashscope_service.generate_image(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            model=request.model,
            n=request.n,
            size=request.size
        )

        # As with chat, check both attribute-style and dict-style access because the
        # response layout depends on the dashscope SDK version.
        images = []
        if hasattr(response, 'output') and hasattr(response.output, 'images'):
            images = response.output.images
        elif hasattr(response, 'output') and isinstance(response.output, dict) and 'images' in response.output:
            images = response.output['images']
        else:
            logger.warning(f"Unable to parse DashScope image generation response: {response}")
            try:
                if hasattr(response, 'output') and isinstance(response.output, dict):
                    images = response.output.get('images', [])
                else:
                    images = []
            except Exception as e:
                logger.error(f"Failed to extract image URLs: {str(e)}")
                images = []

        request_id = ""
        if hasattr(response, 'request_id'):
            request_id = response.request_id

        return {
            "images": images,
            "request_id": request_id
        }
    except Exception as e:
        logger.error(f"DashScope image generation API call failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}")
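
# Illustrative request for POST /generate-image (sketch only, not executed; the
# size string follows the ImageGenerationRequest default above, and the "/api"
# mount path is an assumption):
#
#   import httpx
#   payload = {"prompt": "a watercolor fox", "n": 1, "size": "1024*1024"}
#   resp = httpx.post("http://localhost:8000/api/generate-image", json=payload)
#   print(resp.json()["images"])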


class TryOnRequest(BaseModel):
    """Try-on request model."""
    person_image_url: str
    top_garment_url: Optional[str] = None
    bottom_garment_url: Optional[str] = None
    resolution: Optional[int] = -1
    restore_face: Optional[bool] = True


class TaskInfo(BaseModel):
    """Task information model."""
    task_id: str
    task_status: str


class TryOnResponse(BaseModel):
    """Try-on response model."""
    output: TaskInfo
    request_id: str


@router.post("/try-on", response_model=TryOnResponse)
async def generate_tryon(
    request: TryOnRequest,
    dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
):
    """
    Submit a job to the DashScope outfit try-on API.

    The API dresses the person in the given image with a top and/or bottom garment
    and returns the ID of the asynchronous task that produces the composite image.

    - **person_image_url**: URL of the person image (required)
    - **top_garment_url**: URL of the top garment image (optional)
    - **bottom_garment_url**: URL of the bottom garment image (optional)
    - **resolution**: output resolution, -1 means automatic (optional)
    - **restore_face**: whether to restore the face (optional)

    Note: at least one of top_garment_url and bottom_garment_url must be provided.
    """
    try:
        # Validate parameters
        if not request.top_garment_url and not request.bottom_garment_url:
            raise HTTPException(status_code=400, detail="At least one of the top and bottom garment images must be provided")

        response = await dashscope_service.generate_tryon(
            person_image_url=request.person_image_url,
            top_garment_url=request.top_garment_url,
            bottom_garment_url=request.bottom_garment_url,
            resolution=request.resolution,
            restore_face=request.restore_face
        )

        # Build the response
        return {
            "output": {
                "task_id": response.get("output", {}).get("task_id", ""),
                "task_status": response.get("output", {}).get("task_status", "")
            },
            "request_id": response.get("request_id", "")
        }
    except HTTPException:
        # Re-raise the 400 validation error above instead of wrapping it in a 500
        raise
    except Exception as e:
        logger.error(f"DashScope try-on API call failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Try-on request failed: {str(e)}")
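
# Illustrative request for POST /try-on (sketch only, not executed; the garment and
# person URLs are placeholders and the "/api" mount path is an assumption):
#
#   import httpx
#   payload = {
#       "person_image_url": "https://example.com/person.jpg",
#       "top_garment_url": "https://example.com/top.jpg",
#   }
#   resp = httpx.post("http://localhost:8000/api/try-on", json=payload)
#   task_id = resp.json()["output"]["task_id"]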


class TaskStatusResponse(BaseModel):
    """Task status response model."""
    task_id: str
    task_status: str
    completion_url: Optional[str] = None
    request_id: str


@router.get("/try-on/{task_id}", response_model=TaskStatusResponse)
async def check_tryon_status(
    task_id: str,
    dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
):
    """
    Query the status of a try-on task.

    - **task_id**: the task ID returned by the /try-on endpoint
    """
    try:
        response = await dashscope_service.check_tryon_status(task_id)

        # Build the response
        return {
            "task_id": task_id,
            "task_status": response.get("output", {}).get("task_status", ""),
            "completion_url": response.get("output", {}).get("url", ""),
            "request_id": response.get("request_id", "")
        }
    except Exception as e:
        logger.error(f"Failed to query try-on task status: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to query task status: {str(e)}")
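
# Illustrative polling loop for GET /try-on/{task_id} (sketch only, not executed;
# task_id is assumed to come from the /try-on example above, the terminal status
# names follow DashScope's async-task convention, and the "/api" mount path is an
# assumption):
#
#   import time, httpx
#   while True:
#       status = httpx.get(f"http://localhost:8000/api/try-on/{task_id}").json()
#       if status["task_status"] in ("SUCCEEDED", "FAILED"):
#           break
#       time.sleep(2)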


@router.get("/models")
async def list_available_models():
    """
    List the DashScope models available through this service.
    """
    try:
        models = {
            "chat_models": [
                {"id": "qwen-max", "name": "Tongyi Qianwen Max"},
                {"id": "qwen-plus", "name": "Tongyi Qianwen Plus"},
                {"id": "qwen-turbo", "name": "Tongyi Qianwen Turbo"}
            ],
            "image_models": [
                {"id": "stable-diffusion-xl", "name": "Stable Diffusion XL"},
                {"id": "wanx-v1", "name": "Tongyi Wanxiang"}
            ],
            "tryon_models": [
                {"id": "aitryon", "name": "AI Try-On"}
            ]
        }
        return models
    except Exception as e:
        logger.error(f"Failed to fetch the model list: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Failed to fetch the model list: {str(e)}")