This commit is contained in:
aaron 2025-03-21 21:53:22 +08:00
parent 5ffa471403
commit 3d837597ba
7 changed files with 96 additions and 288 deletions

22
.env.example Normal file
View File

@ -0,0 +1,22 @@
# 数据库配置
DB_HOST=localhost
DB_PORT=3306
DB_USER=ai_user
DB_PASSWORD=your_password
DB_NAME=ai_dressing
# 阿里云大模型API配置
DASHSCOPE_API_KEY=your_dashscope_api_key
# 腾讯云配置
QCLOUD_SECRET_ID=your_qcloud_secret_id
QCLOUD_SECRET_KEY=your_qcloud_secret_key
QCLOUD_COS_REGION=ap-chengdu
QCLOUD_COS_BUCKET=your-bucket-name
QCLOUD_COS_DOMAIN=https://your-bucket-name.cos.ap-chengdu.myqcloud.com
# 应用程序配置
HOST=0.0.0.0
PORT=9001
DEBUG=False
LOG_LEVEL=INFO

View File

@ -9,8 +9,7 @@ settings = get_settings()
engine = create_engine(
settings.database_url,
pool_pre_ping=True, # 自动检测断开的连接并重新连接
pool_recycle=3600, # 每小时回收连接
echo=settings.debug, # 在调试模式下打印SQL语句
pool_recycle=3600 # 每小时回收连接
)
# 创建会话工厂

View File

@ -4,7 +4,7 @@ import os
import logging
from dotenv import load_dotenv
from app.routers import dashscope_router, qcloud_router, dress_router, tryon_router
from app.routers import qcloud_router, dress_router, tryon_router
from app.utils.config import get_settings
from app.database import Base, engine
@ -37,7 +37,6 @@ app.add_middleware(
)
# 注册路由
app.include_router(dashscope_router.router, prefix="/api/dashscope", tags=["DashScope"])
app.include_router(qcloud_router.router, prefix="/api/qcloud", tags=["腾讯云"])
app.include_router(dress_router.router, prefix="/api/dresses", tags=["服装"])
app.include_router(tryon_router.router, prefix="/api/tryons", tags=["试穿"])

View File

@ -1,262 +0,0 @@
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
import logging
from typing import List, Optional, Dict, Any
import dashscope
from app.services.dashscope_service import DashScopeService
from app.utils.config import get_settings
logger = logging.getLogger(__name__)
router = APIRouter()
class ChatMessage(BaseModel):
role: str # 'user' 或 'assistant'
content: str
class ChatRequest(BaseModel):
messages: List[ChatMessage]
model: Optional[str] = "qwen-max" # 默认使用通义千问MAX模型
max_tokens: Optional[int] = 2048
temperature: Optional[float] = 0.7
stream: Optional[bool] = False
class ChatResponse(BaseModel):
response: str
usage: Dict[str, Any]
request_id: str
@router.post("/chat", response_model=ChatResponse)
async def chat_completion(
request: ChatRequest,
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
):
"""
调用DashScope的大模型进行对话
"""
try:
messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
response = await dashscope_service.chat_completion(
messages=messages,
model=request.model,
max_tokens=request.max_tokens,
temperature=request.temperature,
stream=request.stream
)
# 确保我们能正确访问响应字段根据dashscope的最新版本调整
response_text = ""
if hasattr(response, 'output') and hasattr(response.output, 'text'):
response_text = response.output.text
elif hasattr(response, 'output') and isinstance(response.output, dict) and 'text' in response.output:
response_text = response.output['text']
elif hasattr(response, 'choices') and len(response.choices) > 0:
response_text = response.choices[0].message.content
else:
logger.warning(f"无法解析DashScope响应: {response}")
# 尝试保守提取
try:
if hasattr(response, 'output'):
if isinstance(response.output, dict):
response_text = str(response.output.get('text', ''))
else:
response_text = str(response.output)
else:
response_text = str(response)
except Exception as e:
logger.error(f"提取响应文本失败: {str(e)}")
response_text = "无法获取模型响应文本"
# 同样确保我们能正确访问usage和request_id
usage = {}
if hasattr(response, 'usage'):
usage = response.usage if isinstance(response.usage, dict) else vars(response.usage)
request_id = ""
if hasattr(response, 'request_id'):
request_id = response.request_id
return {
"response": response_text,
"usage": usage,
"request_id": request_id
}
except Exception as e:
logger.error(f"DashScope API 调用错误: {str(e)}")
raise HTTPException(status_code=500, detail=f"模型调用失败: {str(e)}")
class ImageGenerationRequest(BaseModel):
prompt: str
negative_prompt: Optional[str] = None
model: Optional[str] = "stable-diffusion-xl"
n: Optional[int] = 1
size: Optional[str] = "1024*1024"
class ImageGenerationResponse(BaseModel):
images: List[str]
request_id: str
@router.post("/generate-image", response_model=ImageGenerationResponse)
async def generate_image(
request: ImageGenerationRequest,
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
):
"""
调用DashScope的图像生成API
"""
try:
response = await dashscope_service.generate_image(
prompt=request.prompt,
negative_prompt=request.negative_prompt,
model=request.model,
n=request.n,
size=request.size
)
# 确保我们能正确访问images和request_id
images = []
if hasattr(response, 'output') and hasattr(response.output, 'images'):
images = response.output.images
elif hasattr(response, 'output') and isinstance(response.output, dict) and 'images' in response.output:
images = response.output['images']
else:
logger.warning(f"无法解析DashScope图像生成响应: {response}")
try:
if hasattr(response, 'output'):
if isinstance(response.output, dict):
images = response.output.get('images', [])
else:
images = []
else:
images = []
except Exception as e:
logger.error(f"提取图像URL失败: {str(e)}")
images = []
request_id = ""
if hasattr(response, 'request_id'):
request_id = response.request_id
return {
"images": images,
"request_id": request_id
}
except Exception as e:
logger.error(f"DashScope 图像生成API调用错误: {str(e)}")
raise HTTPException(status_code=500, detail=f"图像生成失败: {str(e)}")
class TryOnRequest(BaseModel):
"""试穿请求模型"""
person_image_url: str
top_garment_url: Optional[str] = None
bottom_garment_url: Optional[str] = None
resolution: Optional[int] = -1
restore_face: Optional[bool] = True
class TaskInfo(BaseModel):
"""任务信息模型"""
task_id: str
task_status: str
class TryOnResponse(BaseModel):
"""试穿响应模型"""
output: TaskInfo
request_id: str
@router.post("/try-on", response_model=TryOnResponse)
async def generate_tryon(
request: TryOnRequest,
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
):
"""
调用DashScope的人物试衣API
该API用于给人物图片试穿上衣和/或下衣返回合成图片的任务ID
- **person_image_url**: 人物图片URL必填
- **top_garment_url**: 上衣图片URL可选
- **bottom_garment_url**: 下衣图片URL可选
- **resolution**: 分辨率-1表示自动可选
- **restore_face**: 是否修复面部可选
注意top_garment_url和bottom_garment_url至少需要提供一个
"""
try:
# 验证参数
if not request.top_garment_url and not request.bottom_garment_url:
raise HTTPException(status_code=400, detail="上衣和下衣图片至少需要提供一个")
response = await dashscope_service.generate_tryon(
person_image_url=request.person_image_url,
top_garment_url=request.top_garment_url,
bottom_garment_url=request.bottom_garment_url,
resolution=request.resolution,
restore_face=request.restore_face
)
# 构建响应
return {
"output": {
"task_id": response.get("output", {}).get("task_id", ""),
"task_status": response.get("output", {}).get("task_status", "")
},
"request_id": response.get("request_id", "")
}
except Exception as e:
logger.error(f"DashScope 试穿API调用错误: {str(e)}")
raise HTTPException(status_code=500, detail=f"试穿请求失败: {str(e)}")
class TaskStatusResponse(BaseModel):
"""任务状态响应模型"""
task_id: str
task_status: str
completion_url: Optional[str] = None
request_id: str
@router.get("/try-on/{task_id}", response_model=TaskStatusResponse)
async def check_tryon_status(
task_id: str,
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
):
"""
查询试穿任务状态
- **task_id**: 任务ID
"""
try:
response = await dashscope_service.check_tryon_status(task_id)
# 构建响应
return {
"task_id": task_id,
"task_status": response.get("output", {}).get("task_status", ""),
"completion_url": response.get("output", {}).get("url", ""),
"request_id": response.get("request_id", "")
}
except Exception as e:
logger.error(f"查询试穿任务状态错误: {str(e)}")
raise HTTPException(status_code=500, detail=f"查询任务状态失败: {str(e)}")
@router.get("/models")
async def list_available_models():
"""
列出DashScope上可用的模型
"""
try:
models = {
"chat_models": [
{"id": "qwen-max", "name": "通义千问MAX"},
{"id": "qwen-plus", "name": "通义千问Plus"},
{"id": "qwen-turbo", "name": "通义千问Turbo"}
],
"image_models": [
{"id": "stable-diffusion-xl", "name": "Stable Diffusion XL"},
{"id": "wanx-v1", "name": "万相"}
],
"tryon_models": [
{"id": "aitryon", "name": "AI试穿"}
]
}
return models
except Exception as e:
logger.error(f"获取模型列表错误: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}")

View File

@ -92,8 +92,8 @@ async def send_tryon_request(
db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
if db_tryon:
db_tryon.request_id = response.get("request_id")
db_tryon.task_id = response.get("output", {}).get("task_id")
db_tryon.task_status = response.get("output", {}).get("task_status")
db_tryon.task_id = response.get("task_id")
db_tryon.task_status = response.get("status")
db.commit()
logger.info(f"试穿请求发送成功任务ID: {db_tryon.task_id}")
@ -204,11 +204,11 @@ async def check_tryon_status(
status_response = await dashscope_service.check_tryon_status(db_tryon.task_id)
# 更新数据库记录
db_tryon.task_status = status_response.get("output", {}).get("task_status")
db_tryon.task_status = status_response.get("status")
# 如果任务完成保存结果URL
if db_tryon.task_status == "SUCCEEDED":
db_tryon.completion_url = status_response.get("output", {}).get("url")
db_tryon.completion_url = status_response.get("image_url")
db.commit()
db.refresh(db_tryon)

View File

@ -165,11 +165,13 @@ class DashScopeService:
# 构建请求头
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
"Content-Type": "application/json",
"X-DashScope-Async": "enable"
}
# 发送请求
async with httpx.AsyncClient() as client:
logger.info(f"发送试穿请求: {request_data}")
response = await client.post(
self.image_synthesis_url,
json=request_data,
@ -177,10 +179,32 @@ class DashScopeService:
timeout=30.0
)
if response.status_code == 200:
response_data = response.json()
logger.info(f"试穿请求发送成功任务ID: {response_data.get('output', {}).get('task_id')}")
return response_data
logger.info(f"试穿API响应: {response_data}")
if response.status_code == 200 or response.status_code == 202: # 202表示异步任务已接受
# 提取任务ID适应不同的API响应格式
task_id = None
if 'output' in response_data and 'task_id' in response_data['output']:
task_id = response_data['output']['task_id']
elif 'task_id' in response_data:
task_id = response_data['task_id']
if task_id:
logger.info(f"试穿请求发送成功任务ID: {task_id}")
return {
"task_id": task_id,
"request_id": response_data.get('request_id'),
"status": "processing"
}
else:
# 如果没有任务ID这可能是同步响应
logger.info("收到同步响应没有任务ID")
return {
"status": "completed",
"result": response_data.get('output', {}),
"request_id": response_data.get('request_id')
}
else:
error_msg = f"试穿请求失败: {response.status_code} - {response.text}"
logger.error(error_msg)
@ -219,8 +243,26 @@ class DashScopeService:
if response.status_code == 200:
response_data = response.json()
logger.info(f"试穿任务状态查询成功: {response_data.get('output', {}).get('task_status')}")
return response_data
print(response_data)
status = response_data.get('output', {}).get('task_status', '')
logger.info(f"试穿任务状态查询成功: {status}")
# 检查是否完成并返回结果
if status.lower() == 'succeeded':
image_url = response_data.get('output', {}).get('image_url')
if image_url:
logger.info(f"试穿任务完成结果URL: {image_url}")
return {
"status": status,
"task_id": task_id,
"image_url": image_url,
"result": response_data.get('output', {})
}
return {
"status": status,
"task_id": task_id
}
else:
error_msg = f"试穿任务状态查询失败: {response.status_code} - {response.text}"
logger.error(error_msg)

View File

@ -1,16 +1,22 @@
import os
import logging
from functools import lru_cache
from typing import Optional
from pydantic import BaseModel
from dotenv import load_dotenv
from dotenv import load_dotenv, find_dotenv
# 加载环境变量如果直接从main导入这里可能是冗余的但为了安全起见
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("config")
# 尝试加载环境变量
logger.info("开始加载环境变量...")
load_dotenv()
class Settings(BaseModel):
"""应用程序配置类"""
# DashScope配置
dashscope_api_key: str = os.getenv("DASHSCOPE_API_KEY", "sk-caa199589f1c451aaac471fad2986e28")
dashscope_api_key: str = os.getenv("DASHSCOPE_API_KEY", "")
# 服务器配置
host: str = os.getenv("HOST", "0.0.0.0")
@ -18,17 +24,17 @@ class Settings(BaseModel):
debug: bool = os.getenv("DEBUG", "False").lower() in ["true", "1", "t", "yes"]
# 腾讯云配置
qcloud_secret_id: str = os.getenv("QCLOUD_SECRET_ID", "AKIDxnbGj281iHtKallqqzvlV5YxBCrPltnS")
qcloud_secret_key: str = os.getenv("QCLOUD_SECRET_KEY", "ta6PXTMBsX7dzA7IN6uYUFn8F9uTovoU")
qcloud_secret_id: str = os.getenv("QCLOUD_SECRET_ID", "")
qcloud_secret_key: str = os.getenv("QCLOUD_SECRET_KEY", "")
qcloud_cos_region: str = os.getenv("QCLOUD_COS_REGION", "ap-chengdu")
qcloud_cos_bucket: str = os.getenv("QCLOUD_COS_BUCKET", "aidress-1311994147")
qcloud_cos_domain: str = os.getenv("QCLOUD_COS_DOMAIN", "https://aidress-1311994147.cos.ap-chengdu.myqcloud.com")
qcloud_cos_bucket: str = os.getenv("QCLOUD_COS_BUCKET", "")
qcloud_cos_domain: str = os.getenv("QCLOUD_COS_DOMAIN", "")
# 数据库配置
db_host: str = os.getenv("DB_HOST", "gz-cynosdbmysql-grp-2j1cnopr.sql.tencentcdb.com")
db_port: int = int(os.getenv("DB_PORT", "27469"))
db_host: str = os.getenv("DB_HOST", "localhost")
db_port: int = int(os.getenv("DB_PORT", "3306"))
db_user: str = os.getenv("DB_USER", "root")
db_password: str = os.getenv("DB_PASSWORD", "Aa#223388")
db_password: str = os.getenv("DB_PASSWORD", "")
db_name: str = os.getenv("DB_NAME", "aidress")
# 数据库URL
@ -54,7 +60,9 @@ def get_settings() -> Settings:
Returns:
Settings: 配置对象
"""
return Settings()
settings = Settings()
logger.info(f"配置已加载DB_HOST={settings.db_host}, DB_NAME={settings.db_name}")
return settings
def validate_api_key():
"""