diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..9f3fe46 --- /dev/null +++ b/.env.example @@ -0,0 +1,22 @@ +# 数据库配置 +DB_HOST=localhost +DB_PORT=3306 +DB_USER=ai_user +DB_PASSWORD=your_password +DB_NAME=ai_dressing + +# 阿里云大模型API配置 +DASHSCOPE_API_KEY=your_dashscope_api_key + +# 腾讯云配置 +QCLOUD_SECRET_ID=your_qcloud_secret_id +QCLOUD_SECRET_KEY=your_qcloud_secret_key +QCLOUD_COS_REGION=ap-chengdu +QCLOUD_COS_BUCKET=your-bucket-name +QCLOUD_COS_DOMAIN=https://your-bucket-name.cos.ap-chengdu.myqcloud.com + +# 应用程序配置 +HOST=0.0.0.0 +PORT=9001 +DEBUG=False +LOG_LEVEL=INFO \ No newline at end of file diff --git a/app/database/__init__.py b/app/database/__init__.py index d55ad07..21d5df8 100644 --- a/app/database/__init__.py +++ b/app/database/__init__.py @@ -9,8 +9,7 @@ settings = get_settings() engine = create_engine( settings.database_url, pool_pre_ping=True, # 自动检测断开的连接并重新连接 - pool_recycle=3600, # 每小时回收连接 - echo=settings.debug, # 在调试模式下打印SQL语句 + pool_recycle=3600 # 每小时回收连接 ) # 创建会话工厂 diff --git a/app/main.py b/app/main.py index e6880ac..782b0e4 100644 --- a/app/main.py +++ b/app/main.py @@ -4,7 +4,7 @@ import os import logging from dotenv import load_dotenv -from app.routers import dashscope_router, qcloud_router, dress_router, tryon_router +from app.routers import qcloud_router, dress_router, tryon_router from app.utils.config import get_settings from app.database import Base, engine @@ -37,7 +37,6 @@ app.add_middleware( ) # 注册路由 -app.include_router(dashscope_router.router, prefix="/api/dashscope", tags=["DashScope"]) app.include_router(qcloud_router.router, prefix="/api/qcloud", tags=["腾讯云"]) app.include_router(dress_router.router, prefix="/api/dresses", tags=["服装"]) app.include_router(tryon_router.router, prefix="/api/tryons", tags=["试穿"]) diff --git a/app/routers/dashscope_router.py b/app/routers/dashscope_router.py deleted file mode 100644 index dd3c6d1..0000000 --- a/app/routers/dashscope_router.py +++ /dev/null @@ -1,262 +0,0 @@ -from fastapi import APIRouter, HTTPException, Depends -from pydantic import BaseModel -import logging -from typing import List, Optional, Dict, Any -import dashscope -from app.services.dashscope_service import DashScopeService -from app.utils.config import get_settings - -logger = logging.getLogger(__name__) -router = APIRouter() - -class ChatMessage(BaseModel): - role: str # 'user' 或 'assistant' - content: str - -class ChatRequest(BaseModel): - messages: List[ChatMessage] - model: Optional[str] = "qwen-max" # 默认使用通义千问MAX模型 - max_tokens: Optional[int] = 2048 - temperature: Optional[float] = 0.7 - stream: Optional[bool] = False - -class ChatResponse(BaseModel): - response: str - usage: Dict[str, Any] - request_id: str - -@router.post("/chat", response_model=ChatResponse) -async def chat_completion( - request: ChatRequest, - dashscope_service: DashScopeService = Depends(lambda: DashScopeService()) -): - """ - 调用DashScope的大模型进行对话 - """ - try: - messages = [{"role": msg.role, "content": msg.content} for msg in request.messages] - response = await dashscope_service.chat_completion( - messages=messages, - model=request.model, - max_tokens=request.max_tokens, - temperature=request.temperature, - stream=request.stream - ) - - # 确保我们能正确访问响应字段,根据dashscope的最新版本调整 - response_text = "" - if hasattr(response, 'output') and hasattr(response.output, 'text'): - response_text = response.output.text - elif hasattr(response, 'output') and isinstance(response.output, dict) and 'text' in response.output: - response_text = response.output['text'] - elif hasattr(response, 'choices') and len(response.choices) > 0: - response_text = response.choices[0].message.content - else: - logger.warning(f"无法解析DashScope响应: {response}") - # 尝试保守提取 - try: - if hasattr(response, 'output'): - if isinstance(response.output, dict): - response_text = str(response.output.get('text', '')) - else: - response_text = str(response.output) - else: - response_text = str(response) - except Exception as e: - logger.error(f"提取响应文本失败: {str(e)}") - response_text = "无法获取模型响应文本" - - # 同样确保我们能正确访问usage和request_id - usage = {} - if hasattr(response, 'usage'): - usage = response.usage if isinstance(response.usage, dict) else vars(response.usage) - - request_id = "" - if hasattr(response, 'request_id'): - request_id = response.request_id - - return { - "response": response_text, - "usage": usage, - "request_id": request_id - } - except Exception as e: - logger.error(f"DashScope API 调用错误: {str(e)}") - raise HTTPException(status_code=500, detail=f"模型调用失败: {str(e)}") - -class ImageGenerationRequest(BaseModel): - prompt: str - negative_prompt: Optional[str] = None - model: Optional[str] = "stable-diffusion-xl" - n: Optional[int] = 1 - size: Optional[str] = "1024*1024" - -class ImageGenerationResponse(BaseModel): - images: List[str] - request_id: str - -@router.post("/generate-image", response_model=ImageGenerationResponse) -async def generate_image( - request: ImageGenerationRequest, - dashscope_service: DashScopeService = Depends(lambda: DashScopeService()) -): - """ - 调用DashScope的图像生成API - """ - try: - response = await dashscope_service.generate_image( - prompt=request.prompt, - negative_prompt=request.negative_prompt, - model=request.model, - n=request.n, - size=request.size - ) - - # 确保我们能正确访问images和request_id - images = [] - if hasattr(response, 'output') and hasattr(response.output, 'images'): - images = response.output.images - elif hasattr(response, 'output') and isinstance(response.output, dict) and 'images' in response.output: - images = response.output['images'] - else: - logger.warning(f"无法解析DashScope图像生成响应: {response}") - try: - if hasattr(response, 'output'): - if isinstance(response.output, dict): - images = response.output.get('images', []) - else: - images = [] - else: - images = [] - except Exception as e: - logger.error(f"提取图像URL失败: {str(e)}") - images = [] - - request_id = "" - if hasattr(response, 'request_id'): - request_id = response.request_id - - return { - "images": images, - "request_id": request_id - } - except Exception as e: - logger.error(f"DashScope 图像生成API调用错误: {str(e)}") - raise HTTPException(status_code=500, detail=f"图像生成失败: {str(e)}") - -class TryOnRequest(BaseModel): - """试穿请求模型""" - person_image_url: str - top_garment_url: Optional[str] = None - bottom_garment_url: Optional[str] = None - resolution: Optional[int] = -1 - restore_face: Optional[bool] = True - -class TaskInfo(BaseModel): - """任务信息模型""" - task_id: str - task_status: str - -class TryOnResponse(BaseModel): - """试穿响应模型""" - output: TaskInfo - request_id: str - -@router.post("/try-on", response_model=TryOnResponse) -async def generate_tryon( - request: TryOnRequest, - dashscope_service: DashScopeService = Depends(lambda: DashScopeService()) -): - """ - 调用DashScope的人物试衣API - - 该API用于给人物图片试穿上衣和/或下衣,返回合成图片的任务ID - - - **person_image_url**: 人物图片URL(必填) - - **top_garment_url**: 上衣图片URL(可选) - - **bottom_garment_url**: 下衣图片URL(可选) - - **resolution**: 分辨率,-1表示自动(可选) - - **restore_face**: 是否修复面部(可选) - - 注意:top_garment_url和bottom_garment_url至少需要提供一个 - """ - try: - # 验证参数 - if not request.top_garment_url and not request.bottom_garment_url: - raise HTTPException(status_code=400, detail="上衣和下衣图片至少需要提供一个") - - response = await dashscope_service.generate_tryon( - person_image_url=request.person_image_url, - top_garment_url=request.top_garment_url, - bottom_garment_url=request.bottom_garment_url, - resolution=request.resolution, - restore_face=request.restore_face - ) - - # 构建响应 - return { - "output": { - "task_id": response.get("output", {}).get("task_id", ""), - "task_status": response.get("output", {}).get("task_status", "") - }, - "request_id": response.get("request_id", "") - } - except Exception as e: - logger.error(f"DashScope 试穿API调用错误: {str(e)}") - raise HTTPException(status_code=500, detail=f"试穿请求失败: {str(e)}") - -class TaskStatusResponse(BaseModel): - """任务状态响应模型""" - task_id: str - task_status: str - completion_url: Optional[str] = None - request_id: str - -@router.get("/try-on/{task_id}", response_model=TaskStatusResponse) -async def check_tryon_status( - task_id: str, - dashscope_service: DashScopeService = Depends(lambda: DashScopeService()) -): - """ - 查询试穿任务状态 - - - **task_id**: 任务ID - """ - try: - response = await dashscope_service.check_tryon_status(task_id) - - # 构建响应 - return { - "task_id": task_id, - "task_status": response.get("output", {}).get("task_status", ""), - "completion_url": response.get("output", {}).get("url", ""), - "request_id": response.get("request_id", "") - } - except Exception as e: - logger.error(f"查询试穿任务状态错误: {str(e)}") - raise HTTPException(status_code=500, detail=f"查询任务状态失败: {str(e)}") - -@router.get("/models") -async def list_available_models(): - """ - 列出DashScope上可用的模型 - """ - try: - models = { - "chat_models": [ - {"id": "qwen-max", "name": "通义千问MAX"}, - {"id": "qwen-plus", "name": "通义千问Plus"}, - {"id": "qwen-turbo", "name": "通义千问Turbo"} - ], - "image_models": [ - {"id": "stable-diffusion-xl", "name": "Stable Diffusion XL"}, - {"id": "wanx-v1", "name": "万相"} - ], - "tryon_models": [ - {"id": "aitryon", "name": "AI试穿"} - ] - } - return models - except Exception as e: - logger.error(f"获取模型列表错误: {str(e)}") - raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}") \ No newline at end of file diff --git a/app/routers/tryon_router.py b/app/routers/tryon_router.py index b40e466..0b43041 100644 --- a/app/routers/tryon_router.py +++ b/app/routers/tryon_router.py @@ -92,8 +92,8 @@ async def send_tryon_request( db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first() if db_tryon: db_tryon.request_id = response.get("request_id") - db_tryon.task_id = response.get("output", {}).get("task_id") - db_tryon.task_status = response.get("output", {}).get("task_status") + db_tryon.task_id = response.get("task_id") + db_tryon.task_status = response.get("status") db.commit() logger.info(f"试穿请求发送成功,任务ID: {db_tryon.task_id}") @@ -204,11 +204,11 @@ async def check_tryon_status( status_response = await dashscope_service.check_tryon_status(db_tryon.task_id) # 更新数据库记录 - db_tryon.task_status = status_response.get("output", {}).get("task_status") + db_tryon.task_status = status_response.get("status") # 如果任务完成,保存结果URL if db_tryon.task_status == "SUCCEEDED": - db_tryon.completion_url = status_response.get("output", {}).get("url") + db_tryon.completion_url = status_response.get("image_url") db.commit() db.refresh(db_tryon) diff --git a/app/services/dashscope_service.py b/app/services/dashscope_service.py index 07f7e71..3d5877c 100644 --- a/app/services/dashscope_service.py +++ b/app/services/dashscope_service.py @@ -165,11 +165,13 @@ class DashScopeService: # 构建请求头 headers = { "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" + "Content-Type": "application/json", + "X-DashScope-Async": "enable" } # 发送请求 async with httpx.AsyncClient() as client: + logger.info(f"发送试穿请求: {request_data}") response = await client.post( self.image_synthesis_url, json=request_data, @@ -177,10 +179,32 @@ class DashScopeService: timeout=30.0 ) - if response.status_code == 200: - response_data = response.json() - logger.info(f"试穿请求发送成功,任务ID: {response_data.get('output', {}).get('task_id')}") - return response_data + response_data = response.json() + logger.info(f"试穿API响应: {response_data}") + + if response.status_code == 200 or response.status_code == 202: # 202表示异步任务已接受 + # 提取任务ID,适应不同的API响应格式 + task_id = None + if 'output' in response_data and 'task_id' in response_data['output']: + task_id = response_data['output']['task_id'] + elif 'task_id' in response_data: + task_id = response_data['task_id'] + + if task_id: + logger.info(f"试穿请求发送成功,任务ID: {task_id}") + return { + "task_id": task_id, + "request_id": response_data.get('request_id'), + "status": "processing" + } + else: + # 如果没有任务ID,这可能是同步响应 + logger.info("收到同步响应,没有任务ID") + return { + "status": "completed", + "result": response_data.get('output', {}), + "request_id": response_data.get('request_id') + } else: error_msg = f"试穿请求失败: {response.status_code} - {response.text}" logger.error(error_msg) @@ -219,8 +243,26 @@ class DashScopeService: if response.status_code == 200: response_data = response.json() - logger.info(f"试穿任务状态查询成功: {response_data.get('output', {}).get('task_status')}") - return response_data + print(response_data) + status = response_data.get('output', {}).get('task_status', '') + logger.info(f"试穿任务状态查询成功: {status}") + + # 检查是否完成并返回结果 + if status.lower() == 'succeeded': + image_url = response_data.get('output', {}).get('image_url') + if image_url: + logger.info(f"试穿任务完成,结果URL: {image_url}") + return { + "status": status, + "task_id": task_id, + "image_url": image_url, + "result": response_data.get('output', {}) + } + + return { + "status": status, + "task_id": task_id + } else: error_msg = f"试穿任务状态查询失败: {response.status_code} - {response.text}" logger.error(error_msg) diff --git a/app/utils/config.py b/app/utils/config.py index 4659ced..415750a 100644 --- a/app/utils/config.py +++ b/app/utils/config.py @@ -1,16 +1,22 @@ import os +import logging from functools import lru_cache from typing import Optional from pydantic import BaseModel -from dotenv import load_dotenv +from dotenv import load_dotenv, find_dotenv -# 加载环境变量(如果直接从main导入,这里可能是冗余的,但为了安全起见) +# 配置日志 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("config") + +# 尝试加载环境变量 +logger.info("开始加载环境变量...") load_dotenv() class Settings(BaseModel): """应用程序配置类""" # DashScope配置 - dashscope_api_key: str = os.getenv("DASHSCOPE_API_KEY", "sk-caa199589f1c451aaac471fad2986e28") + dashscope_api_key: str = os.getenv("DASHSCOPE_API_KEY", "") # 服务器配置 host: str = os.getenv("HOST", "0.0.0.0") @@ -18,17 +24,17 @@ class Settings(BaseModel): debug: bool = os.getenv("DEBUG", "False").lower() in ["true", "1", "t", "yes"] # 腾讯云配置 - qcloud_secret_id: str = os.getenv("QCLOUD_SECRET_ID", "AKIDxnbGj281iHtKallqqzvlV5YxBCrPltnS") - qcloud_secret_key: str = os.getenv("QCLOUD_SECRET_KEY", "ta6PXTMBsX7dzA7IN6uYUFn8F9uTovoU") + qcloud_secret_id: str = os.getenv("QCLOUD_SECRET_ID", "") + qcloud_secret_key: str = os.getenv("QCLOUD_SECRET_KEY", "") qcloud_cos_region: str = os.getenv("QCLOUD_COS_REGION", "ap-chengdu") - qcloud_cos_bucket: str = os.getenv("QCLOUD_COS_BUCKET", "aidress-1311994147") - qcloud_cos_domain: str = os.getenv("QCLOUD_COS_DOMAIN", "https://aidress-1311994147.cos.ap-chengdu.myqcloud.com") + qcloud_cos_bucket: str = os.getenv("QCLOUD_COS_BUCKET", "") + qcloud_cos_domain: str = os.getenv("QCLOUD_COS_DOMAIN", "") # 数据库配置 - db_host: str = os.getenv("DB_HOST", "gz-cynosdbmysql-grp-2j1cnopr.sql.tencentcdb.com") - db_port: int = int(os.getenv("DB_PORT", "27469")) + db_host: str = os.getenv("DB_HOST", "localhost") + db_port: int = int(os.getenv("DB_PORT", "3306")) db_user: str = os.getenv("DB_USER", "root") - db_password: str = os.getenv("DB_PASSWORD", "Aa#223388") + db_password: str = os.getenv("DB_PASSWORD", "") db_name: str = os.getenv("DB_NAME", "aidress") # 数据库URL @@ -54,7 +60,9 @@ def get_settings() -> Settings: Returns: Settings: 配置对象 """ - return Settings() + settings = Settings() + logger.info(f"配置已加载:DB_HOST={settings.db_host}, DB_NAME={settings.db_name}") + return settings def validate_api_key(): """