update
This commit is contained in:
parent
5ffa471403
commit
3d837597ba
22
.env.example
Normal file
22
.env.example
Normal file
@ -0,0 +1,22 @@
|
||||
# 数据库配置
|
||||
DB_HOST=localhost
|
||||
DB_PORT=3306
|
||||
DB_USER=ai_user
|
||||
DB_PASSWORD=your_password
|
||||
DB_NAME=ai_dressing
|
||||
|
||||
# 阿里云大模型API配置
|
||||
DASHSCOPE_API_KEY=your_dashscope_api_key
|
||||
|
||||
# 腾讯云配置
|
||||
QCLOUD_SECRET_ID=your_qcloud_secret_id
|
||||
QCLOUD_SECRET_KEY=your_qcloud_secret_key
|
||||
QCLOUD_COS_REGION=ap-chengdu
|
||||
QCLOUD_COS_BUCKET=your-bucket-name
|
||||
QCLOUD_COS_DOMAIN=https://your-bucket-name.cos.ap-chengdu.myqcloud.com
|
||||
|
||||
# 应用程序配置
|
||||
HOST=0.0.0.0
|
||||
PORT=9001
|
||||
DEBUG=False
|
||||
LOG_LEVEL=INFO
|
||||
@ -9,8 +9,7 @@ settings = get_settings()
|
||||
engine = create_engine(
|
||||
settings.database_url,
|
||||
pool_pre_ping=True, # 自动检测断开的连接并重新连接
|
||||
pool_recycle=3600, # 每小时回收连接
|
||||
echo=settings.debug, # 在调试模式下打印SQL语句
|
||||
pool_recycle=3600 # 每小时回收连接
|
||||
)
|
||||
|
||||
# 创建会话工厂
|
||||
|
||||
@ -4,7 +4,7 @@ import os
|
||||
import logging
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from app.routers import dashscope_router, qcloud_router, dress_router, tryon_router
|
||||
from app.routers import qcloud_router, dress_router, tryon_router
|
||||
from app.utils.config import get_settings
|
||||
from app.database import Base, engine
|
||||
|
||||
@ -37,7 +37,6 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
app.include_router(dashscope_router.router, prefix="/api/dashscope", tags=["DashScope"])
|
||||
app.include_router(qcloud_router.router, prefix="/api/qcloud", tags=["腾讯云"])
|
||||
app.include_router(dress_router.router, prefix="/api/dresses", tags=["服装"])
|
||||
app.include_router(tryon_router.router, prefix="/api/tryons", tags=["试穿"])
|
||||
|
||||
@ -1,262 +0,0 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
import dashscope
|
||||
from app.services.dashscope_service import DashScopeService
|
||||
from app.utils.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str # 'user' 或 'assistant'
|
||||
content: str
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
messages: List[ChatMessage]
|
||||
model: Optional[str] = "qwen-max" # 默认使用通义千问MAX模型
|
||||
max_tokens: Optional[int] = 2048
|
||||
temperature: Optional[float] = 0.7
|
||||
stream: Optional[bool] = False
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
response: str
|
||||
usage: Dict[str, Any]
|
||||
request_id: str
|
||||
|
||||
@router.post("/chat", response_model=ChatResponse)
|
||||
async def chat_completion(
|
||||
request: ChatRequest,
|
||||
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
|
||||
):
|
||||
"""
|
||||
调用DashScope的大模型进行对话
|
||||
"""
|
||||
try:
|
||||
messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
|
||||
response = await dashscope_service.chat_completion(
|
||||
messages=messages,
|
||||
model=request.model,
|
||||
max_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
stream=request.stream
|
||||
)
|
||||
|
||||
# 确保我们能正确访问响应字段,根据dashscope的最新版本调整
|
||||
response_text = ""
|
||||
if hasattr(response, 'output') and hasattr(response.output, 'text'):
|
||||
response_text = response.output.text
|
||||
elif hasattr(response, 'output') and isinstance(response.output, dict) and 'text' in response.output:
|
||||
response_text = response.output['text']
|
||||
elif hasattr(response, 'choices') and len(response.choices) > 0:
|
||||
response_text = response.choices[0].message.content
|
||||
else:
|
||||
logger.warning(f"无法解析DashScope响应: {response}")
|
||||
# 尝试保守提取
|
||||
try:
|
||||
if hasattr(response, 'output'):
|
||||
if isinstance(response.output, dict):
|
||||
response_text = str(response.output.get('text', ''))
|
||||
else:
|
||||
response_text = str(response.output)
|
||||
else:
|
||||
response_text = str(response)
|
||||
except Exception as e:
|
||||
logger.error(f"提取响应文本失败: {str(e)}")
|
||||
response_text = "无法获取模型响应文本"
|
||||
|
||||
# 同样确保我们能正确访问usage和request_id
|
||||
usage = {}
|
||||
if hasattr(response, 'usage'):
|
||||
usage = response.usage if isinstance(response.usage, dict) else vars(response.usage)
|
||||
|
||||
request_id = ""
|
||||
if hasattr(response, 'request_id'):
|
||||
request_id = response.request_id
|
||||
|
||||
return {
|
||||
"response": response_text,
|
||||
"usage": usage,
|
||||
"request_id": request_id
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"DashScope API 调用错误: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"模型调用失败: {str(e)}")
|
||||
|
||||
class ImageGenerationRequest(BaseModel):
|
||||
prompt: str
|
||||
negative_prompt: Optional[str] = None
|
||||
model: Optional[str] = "stable-diffusion-xl"
|
||||
n: Optional[int] = 1
|
||||
size: Optional[str] = "1024*1024"
|
||||
|
||||
class ImageGenerationResponse(BaseModel):
|
||||
images: List[str]
|
||||
request_id: str
|
||||
|
||||
@router.post("/generate-image", response_model=ImageGenerationResponse)
|
||||
async def generate_image(
|
||||
request: ImageGenerationRequest,
|
||||
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
|
||||
):
|
||||
"""
|
||||
调用DashScope的图像生成API
|
||||
"""
|
||||
try:
|
||||
response = await dashscope_service.generate_image(
|
||||
prompt=request.prompt,
|
||||
negative_prompt=request.negative_prompt,
|
||||
model=request.model,
|
||||
n=request.n,
|
||||
size=request.size
|
||||
)
|
||||
|
||||
# 确保我们能正确访问images和request_id
|
||||
images = []
|
||||
if hasattr(response, 'output') and hasattr(response.output, 'images'):
|
||||
images = response.output.images
|
||||
elif hasattr(response, 'output') and isinstance(response.output, dict) and 'images' in response.output:
|
||||
images = response.output['images']
|
||||
else:
|
||||
logger.warning(f"无法解析DashScope图像生成响应: {response}")
|
||||
try:
|
||||
if hasattr(response, 'output'):
|
||||
if isinstance(response.output, dict):
|
||||
images = response.output.get('images', [])
|
||||
else:
|
||||
images = []
|
||||
else:
|
||||
images = []
|
||||
except Exception as e:
|
||||
logger.error(f"提取图像URL失败: {str(e)}")
|
||||
images = []
|
||||
|
||||
request_id = ""
|
||||
if hasattr(response, 'request_id'):
|
||||
request_id = response.request_id
|
||||
|
||||
return {
|
||||
"images": images,
|
||||
"request_id": request_id
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"DashScope 图像生成API调用错误: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"图像生成失败: {str(e)}")
|
||||
|
||||
class TryOnRequest(BaseModel):
|
||||
"""试穿请求模型"""
|
||||
person_image_url: str
|
||||
top_garment_url: Optional[str] = None
|
||||
bottom_garment_url: Optional[str] = None
|
||||
resolution: Optional[int] = -1
|
||||
restore_face: Optional[bool] = True
|
||||
|
||||
class TaskInfo(BaseModel):
|
||||
"""任务信息模型"""
|
||||
task_id: str
|
||||
task_status: str
|
||||
|
||||
class TryOnResponse(BaseModel):
|
||||
"""试穿响应模型"""
|
||||
output: TaskInfo
|
||||
request_id: str
|
||||
|
||||
@router.post("/try-on", response_model=TryOnResponse)
|
||||
async def generate_tryon(
|
||||
request: TryOnRequest,
|
||||
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
|
||||
):
|
||||
"""
|
||||
调用DashScope的人物试衣API
|
||||
|
||||
该API用于给人物图片试穿上衣和/或下衣,返回合成图片的任务ID
|
||||
|
||||
- **person_image_url**: 人物图片URL(必填)
|
||||
- **top_garment_url**: 上衣图片URL(可选)
|
||||
- **bottom_garment_url**: 下衣图片URL(可选)
|
||||
- **resolution**: 分辨率,-1表示自动(可选)
|
||||
- **restore_face**: 是否修复面部(可选)
|
||||
|
||||
注意:top_garment_url和bottom_garment_url至少需要提供一个
|
||||
"""
|
||||
try:
|
||||
# 验证参数
|
||||
if not request.top_garment_url and not request.bottom_garment_url:
|
||||
raise HTTPException(status_code=400, detail="上衣和下衣图片至少需要提供一个")
|
||||
|
||||
response = await dashscope_service.generate_tryon(
|
||||
person_image_url=request.person_image_url,
|
||||
top_garment_url=request.top_garment_url,
|
||||
bottom_garment_url=request.bottom_garment_url,
|
||||
resolution=request.resolution,
|
||||
restore_face=request.restore_face
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
return {
|
||||
"output": {
|
||||
"task_id": response.get("output", {}).get("task_id", ""),
|
||||
"task_status": response.get("output", {}).get("task_status", "")
|
||||
},
|
||||
"request_id": response.get("request_id", "")
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"DashScope 试穿API调用错误: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"试穿请求失败: {str(e)}")
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
"""任务状态响应模型"""
|
||||
task_id: str
|
||||
task_status: str
|
||||
completion_url: Optional[str] = None
|
||||
request_id: str
|
||||
|
||||
@router.get("/try-on/{task_id}", response_model=TaskStatusResponse)
|
||||
async def check_tryon_status(
|
||||
task_id: str,
|
||||
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
|
||||
):
|
||||
"""
|
||||
查询试穿任务状态
|
||||
|
||||
- **task_id**: 任务ID
|
||||
"""
|
||||
try:
|
||||
response = await dashscope_service.check_tryon_status(task_id)
|
||||
|
||||
# 构建响应
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"task_status": response.get("output", {}).get("task_status", ""),
|
||||
"completion_url": response.get("output", {}).get("url", ""),
|
||||
"request_id": response.get("request_id", "")
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"查询试穿任务状态错误: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"查询任务状态失败: {str(e)}")
|
||||
|
||||
@router.get("/models")
|
||||
async def list_available_models():
|
||||
"""
|
||||
列出DashScope上可用的模型
|
||||
"""
|
||||
try:
|
||||
models = {
|
||||
"chat_models": [
|
||||
{"id": "qwen-max", "name": "通义千问MAX"},
|
||||
{"id": "qwen-plus", "name": "通义千问Plus"},
|
||||
{"id": "qwen-turbo", "name": "通义千问Turbo"}
|
||||
],
|
||||
"image_models": [
|
||||
{"id": "stable-diffusion-xl", "name": "Stable Diffusion XL"},
|
||||
{"id": "wanx-v1", "name": "万相"}
|
||||
],
|
||||
"tryon_models": [
|
||||
{"id": "aitryon", "name": "AI试穿"}
|
||||
]
|
||||
}
|
||||
return models
|
||||
except Exception as e:
|
||||
logger.error(f"获取模型列表错误: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}")
|
||||
@ -92,8 +92,8 @@ async def send_tryon_request(
|
||||
db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
|
||||
if db_tryon:
|
||||
db_tryon.request_id = response.get("request_id")
|
||||
db_tryon.task_id = response.get("output", {}).get("task_id")
|
||||
db_tryon.task_status = response.get("output", {}).get("task_status")
|
||||
db_tryon.task_id = response.get("task_id")
|
||||
db_tryon.task_status = response.get("status")
|
||||
db.commit()
|
||||
|
||||
logger.info(f"试穿请求发送成功,任务ID: {db_tryon.task_id}")
|
||||
@ -204,11 +204,11 @@ async def check_tryon_status(
|
||||
status_response = await dashscope_service.check_tryon_status(db_tryon.task_id)
|
||||
|
||||
# 更新数据库记录
|
||||
db_tryon.task_status = status_response.get("output", {}).get("task_status")
|
||||
db_tryon.task_status = status_response.get("status")
|
||||
|
||||
# 如果任务完成,保存结果URL
|
||||
if db_tryon.task_status == "SUCCEEDED":
|
||||
db_tryon.completion_url = status_response.get("output", {}).get("url")
|
||||
db_tryon.completion_url = status_response.get("image_url")
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_tryon)
|
||||
|
||||
@ -165,11 +165,13 @@ class DashScopeService:
|
||||
# 构建请求头
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
"Content-Type": "application/json",
|
||||
"X-DashScope-Async": "enable"
|
||||
}
|
||||
|
||||
# 发送请求
|
||||
async with httpx.AsyncClient() as client:
|
||||
logger.info(f"发送试穿请求: {request_data}")
|
||||
response = await client.post(
|
||||
self.image_synthesis_url,
|
||||
json=request_data,
|
||||
@ -177,10 +179,32 @@ class DashScopeService:
|
||||
timeout=30.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
logger.info(f"试穿请求发送成功,任务ID: {response_data.get('output', {}).get('task_id')}")
|
||||
return response_data
|
||||
response_data = response.json()
|
||||
logger.info(f"试穿API响应: {response_data}")
|
||||
|
||||
if response.status_code == 200 or response.status_code == 202: # 202表示异步任务已接受
|
||||
# 提取任务ID,适应不同的API响应格式
|
||||
task_id = None
|
||||
if 'output' in response_data and 'task_id' in response_data['output']:
|
||||
task_id = response_data['output']['task_id']
|
||||
elif 'task_id' in response_data:
|
||||
task_id = response_data['task_id']
|
||||
|
||||
if task_id:
|
||||
logger.info(f"试穿请求发送成功,任务ID: {task_id}")
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"request_id": response_data.get('request_id'),
|
||||
"status": "processing"
|
||||
}
|
||||
else:
|
||||
# 如果没有任务ID,这可能是同步响应
|
||||
logger.info("收到同步响应,没有任务ID")
|
||||
return {
|
||||
"status": "completed",
|
||||
"result": response_data.get('output', {}),
|
||||
"request_id": response_data.get('request_id')
|
||||
}
|
||||
else:
|
||||
error_msg = f"试穿请求失败: {response.status_code} - {response.text}"
|
||||
logger.error(error_msg)
|
||||
@ -219,8 +243,26 @@ class DashScopeService:
|
||||
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
logger.info(f"试穿任务状态查询成功: {response_data.get('output', {}).get('task_status')}")
|
||||
return response_data
|
||||
print(response_data)
|
||||
status = response_data.get('output', {}).get('task_status', '')
|
||||
logger.info(f"试穿任务状态查询成功: {status}")
|
||||
|
||||
# 检查是否完成并返回结果
|
||||
if status.lower() == 'succeeded':
|
||||
image_url = response_data.get('output', {}).get('image_url')
|
||||
if image_url:
|
||||
logger.info(f"试穿任务完成,结果URL: {image_url}")
|
||||
return {
|
||||
"status": status,
|
||||
"task_id": task_id,
|
||||
"image_url": image_url,
|
||||
"result": response_data.get('output', {})
|
||||
}
|
||||
|
||||
return {
|
||||
"status": status,
|
||||
"task_id": task_id
|
||||
}
|
||||
else:
|
||||
error_msg = f"试穿任务状态查询失败: {response.status_code} - {response.text}"
|
||||
logger.error(error_msg)
|
||||
|
||||
@ -1,16 +1,22 @@
|
||||
import os
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
from dotenv import load_dotenv
|
||||
from dotenv import load_dotenv, find_dotenv
|
||||
|
||||
# 加载环境变量(如果直接从main导入,这里可能是冗余的,但为了安全起见)
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger("config")
|
||||
|
||||
# 尝试加载环境变量
|
||||
logger.info("开始加载环境变量...")
|
||||
load_dotenv()
|
||||
|
||||
class Settings(BaseModel):
|
||||
"""应用程序配置类"""
|
||||
# DashScope配置
|
||||
dashscope_api_key: str = os.getenv("DASHSCOPE_API_KEY", "sk-caa199589f1c451aaac471fad2986e28")
|
||||
dashscope_api_key: str = os.getenv("DASHSCOPE_API_KEY", "")
|
||||
|
||||
# 服务器配置
|
||||
host: str = os.getenv("HOST", "0.0.0.0")
|
||||
@ -18,17 +24,17 @@ class Settings(BaseModel):
|
||||
debug: bool = os.getenv("DEBUG", "False").lower() in ["true", "1", "t", "yes"]
|
||||
|
||||
# 腾讯云配置
|
||||
qcloud_secret_id: str = os.getenv("QCLOUD_SECRET_ID", "AKIDxnbGj281iHtKallqqzvlV5YxBCrPltnS")
|
||||
qcloud_secret_key: str = os.getenv("QCLOUD_SECRET_KEY", "ta6PXTMBsX7dzA7IN6uYUFn8F9uTovoU")
|
||||
qcloud_secret_id: str = os.getenv("QCLOUD_SECRET_ID", "")
|
||||
qcloud_secret_key: str = os.getenv("QCLOUD_SECRET_KEY", "")
|
||||
qcloud_cos_region: str = os.getenv("QCLOUD_COS_REGION", "ap-chengdu")
|
||||
qcloud_cos_bucket: str = os.getenv("QCLOUD_COS_BUCKET", "aidress-1311994147")
|
||||
qcloud_cos_domain: str = os.getenv("QCLOUD_COS_DOMAIN", "https://aidress-1311994147.cos.ap-chengdu.myqcloud.com")
|
||||
qcloud_cos_bucket: str = os.getenv("QCLOUD_COS_BUCKET", "")
|
||||
qcloud_cos_domain: str = os.getenv("QCLOUD_COS_DOMAIN", "")
|
||||
|
||||
# 数据库配置
|
||||
db_host: str = os.getenv("DB_HOST", "gz-cynosdbmysql-grp-2j1cnopr.sql.tencentcdb.com")
|
||||
db_port: int = int(os.getenv("DB_PORT", "27469"))
|
||||
db_host: str = os.getenv("DB_HOST", "localhost")
|
||||
db_port: int = int(os.getenv("DB_PORT", "3306"))
|
||||
db_user: str = os.getenv("DB_USER", "root")
|
||||
db_password: str = os.getenv("DB_PASSWORD", "Aa#223388")
|
||||
db_password: str = os.getenv("DB_PASSWORD", "")
|
||||
db_name: str = os.getenv("DB_NAME", "aidress")
|
||||
|
||||
# 数据库URL
|
||||
@ -54,7 +60,9 @@ def get_settings() -> Settings:
|
||||
Returns:
|
||||
Settings: 配置对象
|
||||
"""
|
||||
return Settings()
|
||||
settings = Settings()
|
||||
logger.info(f"配置已加载:DB_HOST={settings.db_host}, DB_NAME={settings.db_name}")
|
||||
return settings
|
||||
|
||||
def validate_api_key():
|
||||
"""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user