272 lines
9.9 KiB
Python
272 lines
9.9 KiB
Python
import os
|
||
import logging
|
||
import dashscope
|
||
from dashscope import Generation
|
||
# 修改导入语句,dashscope的API响应可能改变了结构
|
||
from typing import List, Dict, Any, Optional
|
||
import asyncio
|
||
import httpx
|
||
from app.utils.config import get_settings
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
class DashScopeService:
|
||
"""DashScope服务类,提供对DashScope API的调用封装"""
|
||
|
||
def __init__(self):
|
||
settings = get_settings()
|
||
self.api_key = settings.dashscope_api_key
|
||
# 配置DashScope
|
||
dashscope.api_key = self.api_key
|
||
# 配置API URL
|
||
self.image_synthesis_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2image/image-synthesis"
|
||
|
||
async def chat_completion(
|
||
self,
|
||
messages: List[Dict[str, str]],
|
||
model: str = "qwen-max",
|
||
max_tokens: int = 2048,
|
||
temperature: float = 0.7,
|
||
stream: bool = False
|
||
):
|
||
"""
|
||
调用DashScope的大模型API进行对话
|
||
|
||
Args:
|
||
messages: 对话历史记录
|
||
model: 模型名称
|
||
max_tokens: 最大生成token数
|
||
temperature: 温度参数,控制随机性
|
||
stream: 是否流式输出
|
||
|
||
Returns:
|
||
ApiResponse: DashScope的API响应
|
||
"""
|
||
try:
|
||
# 为了不阻塞FastAPI的异步性能,我们使用run_in_executor运行同步API
|
||
loop = asyncio.get_event_loop()
|
||
response = await loop.run_in_executor(
|
||
None,
|
||
lambda: Generation.call(
|
||
model=model,
|
||
messages=messages,
|
||
max_tokens=max_tokens,
|
||
temperature=temperature,
|
||
result_format='message',
|
||
stream=stream,
|
||
)
|
||
)
|
||
|
||
if response.status_code != 200:
|
||
logger.error(f"DashScope API请求失败,状态码:{response.status_code}, 错误信息:{response.message}")
|
||
raise Exception(f"API调用失败: {response.message}")
|
||
|
||
return response
|
||
except Exception as e:
|
||
logger.error(f"DashScope聊天API调用出错: {str(e)}")
|
||
raise e
|
||
|
||
async def generate_image(
|
||
self,
|
||
prompt: str,
|
||
negative_prompt: Optional[str] = None,
|
||
model: str = "stable-diffusion-xl",
|
||
n: int = 1,
|
||
size: str = "1024*1024"
|
||
):
|
||
"""
|
||
调用DashScope的图像生成API
|
||
|
||
Args:
|
||
prompt: 生成图像的文本描述
|
||
negative_prompt: 负面提示词
|
||
model: 模型名称
|
||
n: 生成图像数量
|
||
size: 图像尺寸
|
||
|
||
Returns:
|
||
ApiResponse: DashScope的API响应
|
||
"""
|
||
try:
|
||
# 构建请求参数
|
||
params = {
|
||
"model": model,
|
||
"prompt": prompt,
|
||
"n": n,
|
||
"size": size,
|
||
}
|
||
|
||
if negative_prompt:
|
||
params["negative_prompt"] = negative_prompt
|
||
|
||
# 异步调用图像生成API
|
||
loop = asyncio.get_event_loop()
|
||
response = await loop.run_in_executor(
|
||
None,
|
||
lambda: dashscope.ImageSynthesis.call(**params)
|
||
)
|
||
|
||
if response.status_code != 200:
|
||
logger.error(f"DashScope 图像生成API请求失败,状态码:{response.status_code}, 错误信息:{response.message}")
|
||
raise Exception(f"图像生成API调用失败: {response.message}")
|
||
|
||
return response
|
||
except Exception as e:
|
||
logger.error(f"DashScope图像生成API调用出错: {str(e)}")
|
||
raise e
|
||
|
||
async def generate_tryon(
|
||
self,
|
||
person_image_url: str,
|
||
top_garment_url: Optional[str] = None,
|
||
bottom_garment_url: Optional[str] = None,
|
||
resolution: int = -1,
|
||
restore_face: bool = True
|
||
):
|
||
"""
|
||
调用阿里百炼平台的试衣服务
|
||
|
||
Args:
|
||
person_image_url: 人物图片URL
|
||
top_garment_url: 上衣图片URL
|
||
bottom_garment_url: 下衣图片URL
|
||
resolution: 分辨率,-1表示自动
|
||
restore_face: 是否修复面部
|
||
|
||
Returns:
|
||
Dict: 包含任务ID和请求ID的响应
|
||
"""
|
||
try:
|
||
# 验证参数
|
||
if not person_image_url:
|
||
raise ValueError("人物图片URL不能为空")
|
||
|
||
if not top_garment_url and not bottom_garment_url:
|
||
raise ValueError("上衣和下衣图片至少需要提供一个")
|
||
|
||
# 构建请求数据
|
||
request_data = {
|
||
"model": "aitryon",
|
||
"input": {
|
||
"person_image_url": person_image_url
|
||
},
|
||
"parameters": {
|
||
"resolution": resolution,
|
||
"restore_face": restore_face
|
||
}
|
||
}
|
||
|
||
# 添加可选字段
|
||
if top_garment_url:
|
||
request_data["input"]["top_garment_url"] = top_garment_url
|
||
if bottom_garment_url:
|
||
request_data["input"]["bottom_garment_url"] = bottom_garment_url
|
||
|
||
# 构建请求头
|
||
headers = {
|
||
"Authorization": f"Bearer {self.api_key}",
|
||
"Content-Type": "application/json",
|
||
"X-DashScope-Async": "enable"
|
||
}
|
||
|
||
# 发送请求
|
||
async with httpx.AsyncClient() as client:
|
||
logger.info(f"发送试穿请求: {request_data}")
|
||
response = await client.post(
|
||
self.image_synthesis_url,
|
||
json=request_data,
|
||
headers=headers,
|
||
timeout=30.0
|
||
)
|
||
|
||
response_data = response.json()
|
||
logger.info(f"试穿API响应: {response_data}")
|
||
|
||
if response.status_code == 200 or response.status_code == 202: # 202表示异步任务已接受
|
||
# 提取任务ID,适应不同的API响应格式
|
||
task_id = None
|
||
if 'output' in response_data and 'task_id' in response_data['output']:
|
||
task_id = response_data['output']['task_id']
|
||
elif 'task_id' in response_data:
|
||
task_id = response_data['task_id']
|
||
|
||
if task_id:
|
||
logger.info(f"试穿请求发送成功,任务ID: {task_id}")
|
||
return {
|
||
"task_id": task_id,
|
||
"request_id": response_data.get('request_id'),
|
||
"status": "processing"
|
||
}
|
||
else:
|
||
# 如果没有任务ID,这可能是同步响应
|
||
logger.info("收到同步响应,没有任务ID")
|
||
return {
|
||
"status": "completed",
|
||
"result": response_data.get('output', {}),
|
||
"request_id": response_data.get('request_id')
|
||
}
|
||
else:
|
||
error_msg = f"试穿请求失败: {response.status_code} - {response.text}"
|
||
logger.error(error_msg)
|
||
raise Exception(error_msg)
|
||
except Exception as e:
|
||
logger.error(f"DashScope试穿API调用出错: {str(e)}")
|
||
raise e
|
||
|
||
async def check_tryon_status(self, task_id: str):
|
||
"""
|
||
检查试穿任务状态
|
||
|
||
Args:
|
||
task_id: 任务ID
|
||
|
||
Returns:
|
||
Dict: 任务状态信息
|
||
"""
|
||
try:
|
||
# 构建请求头
|
||
headers = {
|
||
"Authorization": f"Bearer {self.api_key}",
|
||
"Content-Type": "application/json"
|
||
}
|
||
|
||
# 构建请求URL
|
||
status_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}"
|
||
|
||
# 发送请求
|
||
async with httpx.AsyncClient() as client:
|
||
response = await client.get(
|
||
status_url,
|
||
headers=headers,
|
||
timeout=30.0
|
||
)
|
||
|
||
if response.status_code == 200:
|
||
response_data = response.json()
|
||
print(response_data)
|
||
status = response_data.get('output', {}).get('task_status', '')
|
||
logger.info(f"试穿任务状态查询成功: {status}")
|
||
|
||
# 检查是否完成并返回结果
|
||
if status.lower() == 'succeeded':
|
||
image_url = response_data.get('output', {}).get('image_url')
|
||
if image_url:
|
||
logger.info(f"试穿任务完成,结果URL: {image_url}")
|
||
return {
|
||
"status": status,
|
||
"task_id": task_id,
|
||
"image_url": image_url,
|
||
"result": response_data.get('output', {})
|
||
}
|
||
|
||
return {
|
||
"status": status,
|
||
"task_id": task_id
|
||
}
|
||
else:
|
||
error_msg = f"试穿任务状态查询失败: {response.status_code} - {response.text}"
|
||
logger.error(error_msg)
|
||
raise Exception(error_msg)
|
||
except Exception as e:
|
||
logger.error(f"查询试穿任务状态出错: {str(e)}")
|
||
raise e |