# aidress/app/routers/dashscope_router.py
# 2025-03-21 17:06:54 +08:00
#
# 262 lines
# 9.2 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
import logging
from typing import List, Optional, Dict, Any
import dashscope
from app.services.dashscope_service import DashScopeService
from app.utils.config import get_settings
# Module-level logger named after this module; used by every endpoint below.
logger = logging.getLogger(__name__)
# Router holding the DashScope endpoints defined in this module.
# NOTE(review): the mount prefix is chosen wherever this router is included — not visible here.
router = APIRouter()
class ChatMessage(BaseModel):
    """A single chat turn, forwarded to the service as a ``{"role", "content"}`` dict."""

    role: str  # 'user' or 'assistant'
    content: str  # text of the message
class ChatRequest(BaseModel):
    """Request body for ``POST /chat``; every field is passed to ``DashScopeService.chat_completion``."""

    messages: List[ChatMessage]  # the conversation to send to the model
    model: Optional[str] = "qwen-max"  # defaults to the Tongyi Qianwen MAX model
    max_tokens: Optional[int] = 2048
    temperature: Optional[float] = 0.7
    # NOTE(review): forwarded as-is, but /chat always answers with one JSON body;
    # confirm what the service returns when stream=True.
    stream: Optional[bool] = False
class ChatResponse(BaseModel):
    """Response body for ``POST /chat``."""

    response: str  # reply text extracted from the DashScope response
    usage: Dict[str, Any]  # usage info reported by DashScope ({} when the response has none)
    request_id: str  # DashScope request id ("" when the response has none)
def _chat_field(obj: Any, name: str) -> Any:
    """Return ``obj[name]`` (for dicts) or ``obj.name``; ``None`` when absent."""
    if isinstance(obj, dict) and obj.get(name) is not None:
        return obj[name]
    try:
        return getattr(obj, name, None)
    except KeyError:
        # Dict-backed wrappers may raise KeyError (not AttributeError) from
        # __getattr__, which getattr's default does not cover.
        return None


def _chat_response_text(response: Any) -> str:
    """Extract the reply text from a DashScope chat response.

    Checked in order: ``output.text``; then the first choice's
    ``message.content`` under ``output.choices`` or ``choices``; then a
    string fallback. Never returns ``None``, because
    ``ChatResponse.response`` is a required ``str``.
    """
    output = _chat_field(response, "output")
    text = _chat_field(output, "text")
    if text is not None:
        return str(text)
    for holder in (output, response):
        choices = _chat_field(holder, "choices")
        if choices:
            content = _chat_field(_chat_field(choices[0], "message"), "content")
            if content is not None:
                return str(content)
    logger.warning("无法解析DashScope响应: %s", response)
    if isinstance(output, dict):
        # A dict output without any text yields an empty reply.
        return ""
    return str(output) if output is not None else str(response)


def _chat_usage(response: Any) -> Dict[str, Any]:
    """Return the response's usage info as a dict; ``{}`` when missing or not convertible."""
    usage = _chat_field(response, "usage")
    if usage is None:
        # The original called vars(None) here, which raised TypeError.
        return {}
    if isinstance(usage, dict):
        return dict(usage)
    try:
        return dict(vars(usage))
    except TypeError:
        # vars() rejects objects without a __dict__ (e.g. __slots__ classes).
        return {}


@router.post("/chat", response_model=ChatResponse)
async def chat_completion(
    request: ChatRequest,
    dashscope_service: DashScopeService = Depends(lambda: DashScopeService()),
):
    """Run a chat completion against a DashScope large language model.

    The request's messages and generation parameters are forwarded to
    ``DashScopeService.chat_completion``. The reply text, usage info and
    request id are then extracted from whichever response shape comes back.

    Raises:
        HTTPException: 500 when the service call fails, or when the response
            reports a non-200 ``status_code``.
    """
    try:
        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
        # NOTE(review): stream=True is forwarded, but this endpoint always
        # returns a single JSON body; confirm the service's streaming return type.
        response = await dashscope_service.chat_completion(
            messages=messages,
            model=request.model,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            stream=request.stream,
        )
        # Per the DashScope SDK docs, a failed call reports status_code/message
        # and carries no usable output; surface it instead of replying "None".
        status_code = _chat_field(response, "status_code")
        if isinstance(status_code, int) and status_code != 200:
            reason = _chat_field(response, "message") or status_code
            raise HTTPException(status_code=500, detail=f"模型调用失败: {reason}")
        request_id = _chat_field(response, "request_id")
        return {
            "response": _chat_response_text(response),
            "usage": _chat_usage(response),
            "request_id": "" if request_id is None else str(request_id),
        }
    except HTTPException:
        # Already carries the intended status/detail; don't re-wrap it.
        raise
    except Exception as e:
        logger.exception("DashScope API 调用错误: %s", e)
        raise HTTPException(status_code=500, detail=f"模型调用失败: {str(e)}") from e
class ImageGenerationRequest(BaseModel):
    """Request body for ``POST /generate-image``; fields are passed to ``DashScopeService.generate_image``."""

    prompt: str  # text prompt for the image
    negative_prompt: Optional[str] = None  # optional negative prompt, forwarded as-is
    model: Optional[str] = "stable-diffusion-xl"
    n: Optional[int] = 1  # presumably the number of images to generate — confirm
    size: Optional[str] = "1024*1024"  # presumably "<width>*<height>" (note the '*' separator)
class ImageGenerationResponse(BaseModel):
    """Response body for ``POST /generate-image``."""

    images: List[str]  # generated image entries extracted from the response ([] if none found)
    request_id: str  # DashScope request id ("" when the response has none)
def _image_field(obj: Any, name: str) -> Any:
    """Return ``obj[name]`` (for dicts) or ``obj.name``; ``None`` when absent."""
    if isinstance(obj, dict) and obj.get(name) is not None:
        return obj[name]
    try:
        return getattr(obj, name, None)
    except KeyError:
        # Dict-backed wrappers may raise KeyError (not AttributeError) from
        # __getattr__, which getattr's default does not cover.
        return None


def _image_urls(response: Any) -> List[str]:
    """Collect the generated images from a DashScope image-generation response.

    Reads ``output.images`` when present. Otherwise it falls back to the
    ``output.results[i].url`` layout shown in the DashScope ImageSynthesis
    docs. Returns ``[]`` (after logging a warning) when neither is found.
    """
    output = _image_field(response, "output")
    images = _image_field(output, "images")
    if images is not None:
        # Guard against a bare string, which list() would split into characters.
        return [images] if isinstance(images, str) else list(images)
    results = _image_field(output, "results")
    if results:
        urls = (_image_field(item, "url") for item in results)
        return [str(url) for url in urls if url]
    logger.warning("无法解析DashScope图像生成响应: %s", response)
    return []


@router.post("/generate-image", response_model=ImageGenerationResponse)
async def generate_image(
    request: ImageGenerationRequest,
    dashscope_service: DashScopeService = Depends(lambda: DashScopeService()),
):
    """Generate images with a DashScope image-generation model.

    All request fields are forwarded to ``DashScopeService.generate_image``.

    Raises:
        HTTPException: 500 when the service call fails, or when the response
            reports a non-200 ``status_code``.
    """
    try:
        response = await dashscope_service.generate_image(
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            model=request.model,
            n=request.n,
            size=request.size,
        )
        # Per the DashScope SDK docs, failed calls report status_code/message.
        status_code = _image_field(response, "status_code")
        if isinstance(status_code, int) and status_code != 200:
            reason = _image_field(response, "message") or status_code
            raise HTTPException(status_code=500, detail=f"图像生成失败: {reason}")
        request_id = _image_field(response, "request_id")
        return {
            "images": _image_urls(response),
            "request_id": "" if request_id is None else str(request_id),
        }
    except HTTPException:
        # Already carries the intended status/detail; don't re-wrap it.
        raise
    except Exception as e:
        logger.exception("DashScope 图像生成API调用错误: %s", e)
        raise HTTPException(status_code=500, detail=f"图像生成失败: {str(e)}") from e
class TryOnRequest(BaseModel):
    """Virtual try-on request model."""

    person_image_url: str  # URL of the person image (required)
    # Garment image URLs; the /try-on handler rejects requests that give neither.
    top_garment_url: Optional[str] = None
    bottom_garment_url: Optional[str] = None
    resolution: Optional[int] = -1  # output resolution; -1 means automatic
    restore_face: Optional[bool] = True  # whether to restore the face
class TaskInfo(BaseModel):
    """Task information model: ID and status of a submitted try-on task."""

    task_id: str  # task ID taken from the service response's "output"
    task_status: str  # status string passed through from the service response
class TryOnResponse(BaseModel):
    """Try-on response model returned by ``POST /try-on``."""

    output: TaskInfo  # ID and status of the submitted task
    request_id: str  # request id from the service response
@router.post("/try-on", response_model=TryOnResponse)
async def generate_tryon(
    request: TryOnRequest,
    dashscope_service: DashScopeService = Depends(lambda: DashScopeService()),
):
    """Submit a DashScope virtual try-on task.

    Dresses the person image in the given top and/or bottom garment. Returns
    the ID of the task that produces the composite image; poll
    ``GET /try-on/{task_id}`` for its status.

    - **person_image_url**: URL of the person image (required)
    - **top_garment_url**: URL of the top garment image (optional)
    - **bottom_garment_url**: URL of the bottom garment image (optional)
    - **resolution**: output resolution, -1 for automatic (optional)
    - **restore_face**: whether to restore the face (optional)

    Note: at least one of top_garment_url and bottom_garment_url is required.

    Raises:
        HTTPException: 400 when neither garment URL is given; 500 when the
            service call fails.
    """
    # Validate outside the try block. Inside it, the generic handler below
    # turned this 400 into a 500 "试穿请求失败: 400: ..." error.
    if not request.top_garment_url and not request.bottom_garment_url:
        raise HTTPException(status_code=400, detail="上衣和下衣图片至少需要提供一个")
    try:
        response = await dashscope_service.generate_tryon(
            person_image_url=request.person_image_url,
            top_garment_url=request.top_garment_url,
            bottom_garment_url=request.bottom_garment_url,
            resolution=request.resolution,
            restore_face=request.restore_face,
        )
        # The service result is read as a mapping; "output" may be missing or null.
        output = response.get("output") or {}
        return {
            "output": {
                "task_id": output.get("task_id") or "",
                "task_status": output.get("task_status") or "",
            },
            "request_id": response.get("request_id") or "",
        }
    except HTTPException:
        # Already carries the intended status/detail; don't re-wrap it.
        raise
    except Exception as e:
        logger.exception("DashScope 试穿API调用错误: %s", e)
        raise HTTPException(status_code=500, detail=f"试穿请求失败: {str(e)}") from e
class TaskStatusResponse(BaseModel):
    """Task status response model returned by ``GET /try-on/{task_id}``."""

    task_id: str  # echoes the task ID from the request path
    task_status: str  # status string passed through from the service response
    completion_url: Optional[str] = None  # result URL read from the task output, when present
    request_id: str  # request id from the service response
@router.get("/try-on/{task_id}", response_model=TaskStatusResponse)
async def check_tryon_status(
    task_id: str,
    dashscope_service: DashScopeService = Depends(lambda: DashScopeService()),
):
    """Query the status of a try-on task.

    - **task_id**: task ID returned by ``POST /try-on``

    Raises:
        HTTPException: 500 when the status query fails; an HTTPException
            raised by the service is passed through unchanged.
    """
    try:
        response = await dashscope_service.check_tryon_status(task_id)
        # The service result is read as a mapping; "output" may be missing or null.
        output = response.get("output") or {}
        # NOTE(review): DashScope's aitryon task query appears to report the
        # result as output.image_url. "url" is still checked first for compatibility.
        completion_url = output.get("url") or output.get("image_url") or ""
        return {
            "task_id": task_id,
            "task_status": output.get("task_status") or "",
            "completion_url": completion_url,
            "request_id": response.get("request_id") or "",
        }
    except HTTPException:
        # Don't mask a deliberate status (e.g. from the service) as a 500.
        raise
    except Exception as e:
        logger.exception("查询试穿任务状态错误: %s", e)
        raise HTTPException(status_code=500, detail=f"查询任务状态失败: {str(e)}") from e
@router.get("/models")
async def list_available_models():
    """List the DashScope models this service offers, grouped by capability.

    The catalogue is a hard-coded literal built fresh on each call. It is not
    fetched from DashScope, so there is no failure path to handle. The former
    try/except around the literal could never fire.
    """
    return {
        "chat_models": [
            {"id": "qwen-max", "name": "通义千问MAX"},
            {"id": "qwen-plus", "name": "通义千问Plus"},
            {"id": "qwen-turbo", "name": "通义千问Turbo"},
        ],
        "image_models": [
            {"id": "stable-diffusion-xl", "name": "Stable Diffusion XL"},
            {"id": "wanx-v1", "name": "万相"},
        ],
        "tryon_models": [
            {"id": "aitryon", "name": "AI试穿"},
        ],
    }