first commit

This commit is contained in:
aaron 2025-03-21 17:06:54 +08:00
commit 7bd63eefc5
28 changed files with 2348 additions and 0 deletions

42
.gitignore vendored Normal file
View File

@ -0,0 +1,42 @@
# 环境变量
.env
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# 虚拟环境
venv/
env/
ENV/
# IDE
.idea/
.vscode/
*.swp
*.swo
# 日志
*.log
# 操作系统
.DS_Store
Thumbs.db

259
README.md Normal file
View File

@ -0,0 +1,259 @@
# AI-Dressing API
基于 FastAPI 框架和阿里云 DashScope 接口的 AI 服务 API。
## 功能特点
- 集成阿里云 DashScope 的大模型 API支持通义千问系列模型
- 提供图像生成 API支持 Stable Diffusion XL、万相等模型
- 集成腾讯云 COS 对象存储服务,支持文件上传、下载等功能
- 使用 MySQL 数据库存储服装数据
- RESTful API 风格,支持异步处理
- 易于扩展的模块化设计
## 环境要求
- Python 3.8+
- 阿里云 DashScope API 密钥
- 腾讯云 API 密钥(用于 COS 对象存储)
- MySQL 5.7+ 数据库
## 数据库设置
在使用应用程序之前,您需要创建 MySQL 数据库。登录到 MySQL 并执行以下命令:
```sql
CREATE DATABASE ai_dressing CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
CREATE USER 'ai_user'@'%' IDENTIFIED BY 'your_password';
GRANT ALL PRIVILEGES ON ai_dressing.* TO 'ai_user'@'%';
FLUSH PRIVILEGES;
```
## 快速开始
### 1. 安装依赖
```bash
pip install -r requirements.txt
```
### 2. 配置环境变量
复制 `.env.example` 文件为 `.env`,并填写您的 API 密钥:
```bash
cp .env.example .env
```
然后编辑 `.env` 文件:
```
# DashScope API密钥
DASHSCOPE_API_KEY=your_api_key_here
# 腾讯云配置
QCLOUD_SECRET_ID=your_secret_id_here
QCLOUD_SECRET_KEY=your_secret_key_here
QCLOUD_COS_REGION=ap-guangzhou
QCLOUD_COS_BUCKET=your-bucket-name
QCLOUD_COS_DOMAIN=https://your-bucket-name.cos.ap-guangzhou.myqcloud.com
# 数据库配置
DB_HOST=localhost
DB_PORT=3306
DB_USER=ai_user
DB_PASSWORD=your_password
DB_NAME=ai_dressing
```
### 3. 数据库迁移
首次运行或模型变更后,需要创建或更新数据库表:
```bash
# 创建迁移(仅在模型变更时需要)
python create_migration.py create -m "初始化数据库"
# 应用迁移
python create_migration.py upgrade
```
### 4. 启动服务
```bash
python run.py
```
服务将在 http://localhost:9001 启动,您可以访问 http://localhost:9001/docs 查看 API 文档。
## API 文档
启动服务后,访问以下地址查看自动生成的 API 文档:
- Swagger UI: http://localhost:9001/docs
- ReDoc: http://localhost:9001/redoc
## 主要 API 端点
### 大模型对话
```
POST /api/dashscope/chat
```
### 图像生成
```
POST /api/dashscope/generate-image
```
### 获取支持的模型列表
```
GET /api/dashscope/models
```
### 文件上传到腾讯云 COS
```
POST /api/qcloud/upload
```
### 从腾讯云 COS 获取文件列表
```
GET /api/qcloud/files
```
### 生成 COS 临时上传凭证
```
POST /api/qcloud/sts-token
```
### 服装管理 API
```
# 创建服装
POST /api/dresses/
# 获取服装列表
GET /api/dresses/
# 获取单个服装
GET /api/dresses/{dress_id}
# 更新服装
PUT /api/dresses/{dress_id}
# 删除服装
DELETE /api/dresses/{dress_id}
```
## 项目结构
```
.
├── alembic/ # 数据库迁移相关文件
│ ├── versions/ # 迁移版本文件
│ ├── env.py # 迁移环境配置
│ └── script.py.mako # 迁移脚本模板
├── app/ # 应用主目录
│ ├── __init__.py
│ ├── main.py # FastAPI 应用程序
│ ├── database/ # 数据库相关
│ │ └── __init__.py
│ ├── models/ # 数据库模型
│ │ ├── __init__.py
│ │ └── dress.py
│ ├── routers/ # API 路由
│ │ ├── __init__.py
│ │ ├── dashscope_router.py
│ │ ├── qcloud_router.py
│ │ └── dress_router.py
│ ├── schemas/ # Pydantic 模型
│ │ ├── __init__.py
│ │ └── dress.py
│ ├── services/ # 业务逻辑服务
│ │ ├── __init__.py
│ │ ├── dashscope_service.py
│ │ └── qcloud_service.py
│ └── utils/ # 工具类
│ ├── __init__.py
│ └── config.py
├── .env.example # 环境变量示例
├── alembic.ini # Alembic 配置
├── create_migration.py # 数据库迁移工具
├── requirements.txt # 项目依赖
├── README.md # 项目文档
└── run.py # 应用入口
```
## 开发指南
### 添加新的路由
1. 在 `app/routers/` 目录下创建新的路由文件
2. 在 `app/main.py` 中注册新的路由
### 添加新的数据库模型
1. 在 `app/models/` 目录下创建新的模型文件
2. 在 `app/schemas/` 目录下创建对应的 Pydantic 模型
3. 创建数据库迁移: `python create_migration.py create -m "添加新模型"`
4. 应用迁移: `python create_migration.py upgrade`
### 添加新的服务
`app/services/` 目录下创建新的服务类。
## 腾讯云 COS 使用指南
### 前端直传文件
1. 首先从服务端获取临时上传凭证:
```javascript
const response = await fetch('/api/qcloud/sts-token', {
method: 'POST',
headers: {
'Content-Type': 'application/x-www-form-urlencoded',
},
body: 'allow_prefix=uploads/&duration_seconds=1800'
});
const stsData = await response.json();
```
2. 使用临时凭证进行文件上传:
```javascript
// 使用腾讯云COS SDK上传
const cos = new COS({
getAuthorization: function(options, callback) {
callback({
TmpSecretId: stsData.credentials.tmpSecretId,
TmpSecretKey: stsData.credentials.tmpSecretKey,
SecurityToken: stsData.credentials.sessionToken,
ExpiredTime: stsData.expiration
});
}
});
cos.putObject({
Bucket: stsData.bucket,
Region: stsData.region,
Key: 'uploads/example.jpg',
Body: file,
onProgress: function(progressData) {
console.log(progressData);
}
}, function(err, data) {
if (err) {
console.error('上传失败', err);
} else {
console.log('上传成功', data);
}
});
```
## 许可证
MIT

99
alembic.ini Normal file
View File

@ -0,0 +1,99 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = alembic
# template used to generate migration files
# file_template = %%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python-dateutil library that can be
# installed by adding `python-dateutil` to the requirements.
# For example: timezone = UTC
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator"
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. Valid values are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # default: use os.pathsep
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

87
alembic/env.py Normal file
View File

@ -0,0 +1,87 @@
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
# 导入应用的数据库配置和模型
from app.database import Base
from app.utils.config import get_settings
from app.models.dress import Dress # 导入所有模型
from app.models.tryon import TryOn
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# 从应用程序配置中获取数据库URL
settings = get_settings()
config.set_main_option("sqlalchemy.url", settings.database_url)
# Interpret the config file for Python logging.
# This line sets up loggers basically.
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def run_migrations_offline():
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online():
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

24
alembic/script.py.mako Normal file
View File

@ -0,0 +1,24 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}
def upgrade():
${upgrades if upgrades else "pass"}
def downgrade():
${downgrades if downgrades else "pass"}

View File

@ -0,0 +1,42 @@
"""添加试穿记录表
Revision ID: f166e37d7b53
Revises:
Create Date: 2025-03-21 15:58:04.889147
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'f166e37d7b53'
down_revision = None
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tryons',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('top_garment_url', sa.String(length=1024), nullable=True, comment='上衣图片URL'),
sa.Column('bottom_garment_url', sa.String(length=1024), nullable=True, comment='下衣图片URL'),
sa.Column('person_image_url', sa.String(length=1024), nullable=True, comment='人物图片URL'),
sa.Column('request_id', sa.String(length=255), nullable=True, comment='请求ID'),
sa.Column('task_id', sa.String(length=255), nullable=True, comment='任务ID'),
sa.Column('task_status', sa.String(length=50), nullable=True, comment='任务状态'),
sa.Column('completion_url', sa.String(length=1024), nullable=True, comment='生成图片URL'),
sa.Column('created_at', sa.DateTime(), nullable=True, comment='创建时间'),
sa.Column('updated_at', sa.DateTime(), nullable=True, comment='更新时间'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_tryons_id'), 'tryons', ['id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_tryons_id'), table_name='tryons')
op.drop_table('tryons')
# ### end Alembic commands ###

1
app/__init__.py Normal file
View File

@ -0,0 +1 @@

28
app/database/__init__.py Normal file
View File

@ -0,0 +1,28 @@
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from app.utils.config import get_settings
settings = get_settings()
# 创建数据库引擎
engine = create_engine(
settings.database_url,
pool_pre_ping=True, # 自动检测断开的连接并重新连接
pool_recycle=3600, # 每小时回收连接
echo=settings.debug, # 在调试模式下打印SQL语句
)
# 创建会话工厂
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# 创建Base类所有模型都将继承这个类
Base = declarative_base()
# 获取数据库会话的依赖函数
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()

68
app/main.py Normal file
View File

@ -0,0 +1,68 @@
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
import os
import logging
from dotenv import load_dotenv
from app.routers import dashscope_router, qcloud_router, dress_router, tryon_router
from app.utils.config import get_settings
from app.database import Base, engine
# 创建数据库表
Base.metadata.create_all(bind=engine)
# 加载环境变量
load_dotenv()
# 配置日志
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
app = FastAPI(
title="AI-Dressing API",
description="基于 DashScope 的 AI 服务 API",
version="0.1.0",
)
# 添加 CORS 中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 在生产环境中,应该指定确切的域名
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 注册路由
app.include_router(dashscope_router.router, prefix="/api/dashscope", tags=["DashScope"])
app.include_router(qcloud_router.router, prefix="/api/qcloud", tags=["腾讯云"])
app.include_router(dress_router.router, prefix="/api/dresses", tags=["服装"])
app.include_router(tryon_router.router, prefix="/api/tryons", tags=["试穿"])
@app.get("/", tags=["健康检查"])
async def health_check():
"""API 健康检查端点"""
return {"status": "正常", "message": "服务运行中"}
@app.get("/info", tags=["服务信息"])
async def get_info():
"""获取服务基本信息"""
settings = get_settings()
return {
"app_name": "AI-Dressing API",
"version": "0.1.0",
"debug_mode": settings.debug,
}
if __name__ == "__main__":
import uvicorn
settings = get_settings()
uvicorn.run(
"app.main:app",
host=settings.host,
port=settings.port,
reload=settings.debug,
)

1
app/models/__init__.py Normal file
View File

@ -0,0 +1 @@

27
app/models/dress.py Normal file
View File

@ -0,0 +1,27 @@
from sqlalchemy import Column, Integer, String, DateTime, Text, func, Enum
from sqlalchemy.ext.declarative import declarative_base
from app.database import Base
import datetime
import enum
class GarmentType(enum.Enum):
"""服装类型枚举"""
TOP_GARMENT = "TOP_GARMENT"
BOTTOM_GARMENT = "BOTTOM_GARMENT"
class Dress(Base):
"""服装模型类"""
__tablename__ = "dresses"
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
name = Column(String(255), nullable=False, comment="服装名称")
image_url = Column(String(1024), nullable=True, comment="服装图片URL")
garment_type = Column(Enum(GarmentType), nullable=True, comment="服装类型(上衣/下衣)")
# 以下是一些可选的附加字段,根据需要可以去掉或添加
description = Column(Text, nullable=True, comment="服装描述")
created_at = Column(DateTime, default=datetime.datetime.now(), comment="创建时间")
updated_at = Column(DateTime, default=datetime.datetime.now(), onupdate=datetime.datetime.now, comment="更新时间")
def __repr__(self):
return f"<Dress(id={self.id}, name='{self.name}')>"

23
app/models/tryon.py Normal file
View File

@ -0,0 +1,23 @@
from sqlalchemy import Column, Integer, String, DateTime, Text, func
from app.database import Base
import datetime
class TryOn(Base):
"""试穿记录模型类"""
__tablename__ = "tryons"
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
top_garment_url = Column(String(1024), nullable=True, comment="上衣图片URL")
bottom_garment_url = Column(String(1024), nullable=True, comment="下衣图片URL")
person_image_url = Column(String(1024), nullable=True, comment="人物图片URL")
request_id = Column(String(255), nullable=True, comment="请求ID")
task_id = Column(String(255), nullable=True, comment="任务ID")
task_status = Column(String(50), nullable=True, comment="任务状态")
completion_url = Column(String(1024), nullable=True, comment="生成图片URL")
created_at = Column(DateTime, default=datetime.datetime.now(), comment="创建时间")
updated_at = Column(DateTime, default=datetime.datetime.now(), onupdate=datetime.datetime.now, comment="更新时间")
def __repr__(self):
return f"<TryOn(id={self.id}, request_id='{self.request_id}')>"

1
app/routers/__init__.py Normal file
View File

@ -0,0 +1 @@

View File

@ -0,0 +1,262 @@
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
import logging
from typing import List, Optional, Dict, Any
import dashscope
from app.services.dashscope_service import DashScopeService
from app.utils.config import get_settings
logger = logging.getLogger(__name__)
router = APIRouter()
class ChatMessage(BaseModel):
role: str # 'user' 或 'assistant'
content: str
class ChatRequest(BaseModel):
messages: List[ChatMessage]
model: Optional[str] = "qwen-max" # 默认使用通义千问MAX模型
max_tokens: Optional[int] = 2048
temperature: Optional[float] = 0.7
stream: Optional[bool] = False
class ChatResponse(BaseModel):
response: str
usage: Dict[str, Any]
request_id: str
@router.post("/chat", response_model=ChatResponse)
async def chat_completion(
request: ChatRequest,
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
):
"""
调用DashScope的大模型进行对话
"""
try:
messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
response = await dashscope_service.chat_completion(
messages=messages,
model=request.model,
max_tokens=request.max_tokens,
temperature=request.temperature,
stream=request.stream
)
# 确保我们能正确访问响应字段根据dashscope的最新版本调整
response_text = ""
if hasattr(response, 'output') and hasattr(response.output, 'text'):
response_text = response.output.text
elif hasattr(response, 'output') and isinstance(response.output, dict) and 'text' in response.output:
response_text = response.output['text']
elif hasattr(response, 'choices') and len(response.choices) > 0:
response_text = response.choices[0].message.content
else:
logger.warning(f"无法解析DashScope响应: {response}")
# 尝试保守提取
try:
if hasattr(response, 'output'):
if isinstance(response.output, dict):
response_text = str(response.output.get('text', ''))
else:
response_text = str(response.output)
else:
response_text = str(response)
except Exception as e:
logger.error(f"提取响应文本失败: {str(e)}")
response_text = "无法获取模型响应文本"
# 同样确保我们能正确访问usage和request_id
usage = {}
if hasattr(response, 'usage'):
usage = response.usage if isinstance(response.usage, dict) else vars(response.usage)
request_id = ""
if hasattr(response, 'request_id'):
request_id = response.request_id
return {
"response": response_text,
"usage": usage,
"request_id": request_id
}
except Exception as e:
logger.error(f"DashScope API 调用错误: {str(e)}")
raise HTTPException(status_code=500, detail=f"模型调用失败: {str(e)}")
class ImageGenerationRequest(BaseModel):
prompt: str
negative_prompt: Optional[str] = None
model: Optional[str] = "stable-diffusion-xl"
n: Optional[int] = 1
size: Optional[str] = "1024*1024"
class ImageGenerationResponse(BaseModel):
images: List[str]
request_id: str
@router.post("/generate-image", response_model=ImageGenerationResponse)
async def generate_image(
request: ImageGenerationRequest,
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
):
"""
调用DashScope的图像生成API
"""
try:
response = await dashscope_service.generate_image(
prompt=request.prompt,
negative_prompt=request.negative_prompt,
model=request.model,
n=request.n,
size=request.size
)
# 确保我们能正确访问images和request_id
images = []
if hasattr(response, 'output') and hasattr(response.output, 'images'):
images = response.output.images
elif hasattr(response, 'output') and isinstance(response.output, dict) and 'images' in response.output:
images = response.output['images']
else:
logger.warning(f"无法解析DashScope图像生成响应: {response}")
try:
if hasattr(response, 'output'):
if isinstance(response.output, dict):
images = response.output.get('images', [])
else:
images = []
else:
images = []
except Exception as e:
logger.error(f"提取图像URL失败: {str(e)}")
images = []
request_id = ""
if hasattr(response, 'request_id'):
request_id = response.request_id
return {
"images": images,
"request_id": request_id
}
except Exception as e:
logger.error(f"DashScope 图像生成API调用错误: {str(e)}")
raise HTTPException(status_code=500, detail=f"图像生成失败: {str(e)}")
class TryOnRequest(BaseModel):
"""试穿请求模型"""
person_image_url: str
top_garment_url: Optional[str] = None
bottom_garment_url: Optional[str] = None
resolution: Optional[int] = -1
restore_face: Optional[bool] = True
class TaskInfo(BaseModel):
"""任务信息模型"""
task_id: str
task_status: str
class TryOnResponse(BaseModel):
"""试穿响应模型"""
output: TaskInfo
request_id: str
@router.post("/try-on", response_model=TryOnResponse)
async def generate_tryon(
request: TryOnRequest,
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
):
"""
调用DashScope的人物试衣API
该API用于给人物图片试穿上衣和/或下衣返回合成图片的任务ID
- **person_image_url**: 人物图片URL必填
- **top_garment_url**: 上衣图片URL可选
- **bottom_garment_url**: 下衣图片URL可选
- **resolution**: 分辨率-1表示自动可选
- **restore_face**: 是否修复面部可选
注意top_garment_url和bottom_garment_url至少需要提供一个
"""
try:
# 验证参数
if not request.top_garment_url and not request.bottom_garment_url:
raise HTTPException(status_code=400, detail="上衣和下衣图片至少需要提供一个")
response = await dashscope_service.generate_tryon(
person_image_url=request.person_image_url,
top_garment_url=request.top_garment_url,
bottom_garment_url=request.bottom_garment_url,
resolution=request.resolution,
restore_face=request.restore_face
)
# 构建响应
return {
"output": {
"task_id": response.get("output", {}).get("task_id", ""),
"task_status": response.get("output", {}).get("task_status", "")
},
"request_id": response.get("request_id", "")
}
except Exception as e:
logger.error(f"DashScope 试穿API调用错误: {str(e)}")
raise HTTPException(status_code=500, detail=f"试穿请求失败: {str(e)}")
class TaskStatusResponse(BaseModel):
"""任务状态响应模型"""
task_id: str
task_status: str
completion_url: Optional[str] = None
request_id: str
@router.get("/try-on/{task_id}", response_model=TaskStatusResponse)
async def check_tryon_status(
task_id: str,
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
):
"""
查询试穿任务状态
- **task_id**: 任务ID
"""
try:
response = await dashscope_service.check_tryon_status(task_id)
# 构建响应
return {
"task_id": task_id,
"task_status": response.get("output", {}).get("task_status", ""),
"completion_url": response.get("output", {}).get("url", ""),
"request_id": response.get("request_id", "")
}
except Exception as e:
logger.error(f"查询试穿任务状态错误: {str(e)}")
raise HTTPException(status_code=500, detail=f"查询任务状态失败: {str(e)}")
@router.get("/models")
async def list_available_models():
"""
列出DashScope上可用的模型
"""
try:
models = {
"chat_models": [
{"id": "qwen-max", "name": "通义千问MAX"},
{"id": "qwen-plus", "name": "通义千问Plus"},
{"id": "qwen-turbo", "name": "通义千问Turbo"}
],
"image_models": [
{"id": "stable-diffusion-xl", "name": "Stable Diffusion XL"},
{"id": "wanx-v1", "name": "万相"}
],
"tryon_models": [
{"id": "aitryon", "name": "AI试穿"}
]
}
return models
except Exception as e:
logger.error(f"获取模型列表错误: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}")

162
app/routers/dress_router.py Normal file
View File

@ -0,0 +1,162 @@
from fastapi import APIRouter, HTTPException, Depends, Query, Path
from sqlalchemy.orm import Session
from typing import List, Optional
import logging
from app.database import get_db
from app.models.dress import Dress, GarmentType
from app.schemas.dress import DressCreate, DressUpdate, DressResponse
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/", response_model=DressResponse, status_code=201)
async def create_dress(
dress: DressCreate,
db: Session = Depends(get_db)
):
"""
创建一个新的服装记录
- **name**: 服装名称必填
- **image_url**: 服装图片URL可选
- **garment_type**: 服装类型(TOP_GARMENT:上衣, BOTTOM_GARMENT:下衣)可选
- **description**: 服装描述可选
"""
try:
db_dress = Dress(
name=dress.name,
image_url=dress.image_url,
garment_type=dress.garment_type,
description=dress.description
)
db.add(db_dress)
db.commit()
db.refresh(db_dress)
return db_dress
except Exception as e:
logger.error(f"创建服装记录失败: {str(e)}")
db.rollback()
raise HTTPException(status_code=500, detail=f"创建服装记录失败: {str(e)}")
@router.get("/", response_model=List[DressResponse])
async def get_dresses(
skip: int = Query(0, description="跳过的记录数量"),
limit: int = Query(100, description="返回的最大记录数量"),
name: Optional[str] = Query(None, description="按名称过滤"),
garment_type: Optional[str] = Query(None, description="按服装类型过滤(TOP_GARMENT:上衣, BOTTOM_GARMENT:下衣)"),
db: Session = Depends(get_db)
):
"""
获取服装列表支持分页和过滤
- **skip**: 跳过的记录数量用于分页
- **limit**: 返回的最大记录数量用于分页
- **name**: 按名称过滤可选
- **garment_type**: 按服装类型过滤可选
"""
try:
query = db.query(Dress)
if name:
query = query.filter(Dress.name.ilike(f"%{name}%"))
if garment_type:
try:
garment_type_enum = GarmentType[garment_type]
query = query.filter(Dress.garment_type == garment_type_enum)
except KeyError:
logger.warning(f"无效的服装类型: {garment_type}")
# 继续查询,但不应用无效的过滤条件
dresses = query.offset(skip).limit(limit).all()
return dresses
except Exception as e:
logger.error(f"获取服装列表失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取服装列表失败: {str(e)}")
@router.get("/{dress_id}", response_model=DressResponse)
async def get_dress(
dress_id: int = Path(..., description="服装ID"),
db: Session = Depends(get_db)
):
"""
根据ID获取服装详情
- **dress_id**: 服装ID
"""
try:
dress = db.query(Dress).filter(Dress.id == dress_id).first()
if not dress:
raise HTTPException(status_code=404, detail=f"未找到ID为{dress_id}的服装")
return dress
except HTTPException:
raise
except Exception as e:
logger.error(f"获取服装详情失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取服装详情失败: {str(e)}")
@router.put("/{dress_id}", response_model=DressResponse)
async def update_dress(
dress_id: int = Path(..., description="服装ID"),
dress: DressUpdate = None,
db: Session = Depends(get_db)
):
"""
更新服装信息
- **dress_id**: 服装ID
- **name**: 服装名称可选
- **image_url**: 服装图片URL可选
- **garment_type**: 服装类型(TOP_GARMENT:上衣, BOTTOM_GARMENT:下衣)可选
- **description**: 服装描述可选
"""
try:
db_dress = db.query(Dress).filter(Dress.id == dress_id).first()
if not db_dress:
raise HTTPException(status_code=404, detail=f"未找到ID为{dress_id}的服装")
# 更新提供的字段
if dress.name is not None:
db_dress.name = dress.name
if dress.image_url is not None:
db_dress.image_url = dress.image_url
if dress.garment_type is not None:
db_dress.garment_type = dress.garment_type
if dress.description is not None:
db_dress.description = dress.description
db.commit()
db.refresh(db_dress)
return db_dress
except HTTPException:
raise
except Exception as e:
logger.error(f"更新服装信息失败: {str(e)}")
db.rollback()
raise HTTPException(status_code=500, detail=f"更新服装信息失败: {str(e)}")
@router.delete("/{dress_id}", status_code=204)
async def delete_dress(
dress_id: int = Path(..., description="服装ID"),
db: Session = Depends(get_db)
):
"""
删除服装
- **dress_id**: 服装ID
"""
try:
db_dress = db.query(Dress).filter(Dress.id == dress_id).first()
if not db_dress:
raise HTTPException(status_code=404, detail=f"未找到ID为{dress_id}的服装")
db.delete(db_dress)
db.commit()
return None
except HTTPException:
raise
except Exception as e:
logger.error(f"删除服装失败: {str(e)}")
db.rollback()
raise HTTPException(status_code=500, detail=f"删除服装失败: {str(e)}")

View File

@ -0,0 +1,155 @@
from fastapi import APIRouter, HTTPException, Depends, UploadFile, File, Form, Query
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import logging
from typing import List, Optional, Dict, Any
from app.services.qcloud_service import QCloudCOSService
logger = logging.getLogger(__name__)
router = APIRouter()
class STSTokenResponse(BaseModel):
"""STS临时凭证响应"""
credentials: Dict[str, str]
expiration: str
request_id: str
bucket: str
region: str
class FileInfo(BaseModel):
"""文件信息"""
key: str
url: str
size: int
last_modified: str
etag: str
class ListFilesResponse(BaseModel):
"""列出文件响应"""
files: List[FileInfo]
is_truncated: bool
next_marker: Optional[str] = None
common_prefixes: Optional[List[Dict[str, Any]]] = None
@router.post("/upload", tags=["腾讯云COS"])
async def upload_file(
file: UploadFile = File(...),
directory: str = Form("uploads"),
qcloud_service: QCloudCOSService = Depends(lambda: QCloudCOSService())
):
"""
上传文件到腾讯云COS
- **file**: 要上传的文件
- **directory**: 存储目录默认为"uploads"
"""
try:
# 读取文件内容
file_content = await file.read()
# 上传文件
result = await qcloud_service.upload_file(
file_content=file_content,
file_name=file.filename,
directory=directory,
content_type=file.content_type
)
return result
except Exception as e:
logger.error(f"文件上传失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"文件上传失败: {str(e)}")
finally:
# 关闭文件
await file.close()
@router.delete("/files/{key:path}", tags=["腾讯云COS"])
async def delete_file(
key: str,
qcloud_service: QCloudCOSService = Depends(lambda: QCloudCOSService())
):
"""
从腾讯云COS删除文件
- **key**: 文件的对象键COS路径
"""
try:
success = await qcloud_service.delete_file(key=key)
if success:
return {"message": "文件删除成功", "key": key}
else:
raise HTTPException(status_code=404, detail=f"文件删除失败,文件可能不存在")
except Exception as e:
logger.error(f"文件删除失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"文件删除失败: {str(e)}")
@router.get("/files/url/{key:path}", tags=["腾讯云COS"])
async def get_file_url(
key: str,
expires: int = Query(3600, description="URL有效期"),
qcloud_service: QCloudCOSService = Depends(lambda: QCloudCOSService())
):
"""
获取COS文件的临时访问URL
- **key**: 文件的对象键COS路径
- **expires**: URL的有效期默认3600秒
"""
try:
url = await qcloud_service.get_file_url(key=key, expires=expires)
return {"url": url, "key": key, "expires": expires}
except Exception as e:
logger.error(f"获取文件URL失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取文件URL失败: {str(e)}")
@router.post("/sts-token", response_model=STSTokenResponse, tags=["腾讯云COS"])
async def generate_sts_token(
allow_prefix: str = Form("*", description="允许操作的对象前缀"),
duration_seconds: int = Form(1800, description="凭证有效期(秒)"),
qcloud_service: QCloudCOSService = Depends(lambda: QCloudCOSService())
):
"""
生成腾讯云COS临时访问凭证STS用于前端直传
- **allow_prefix**: 允许操作的对象前缀默认为'*'表示所有对象
- **duration_seconds**: 凭证有效期默认1800秒
"""
try:
sts_result = await qcloud_service.generate_cos_sts_token(
allow_prefix=allow_prefix,
duration_seconds=duration_seconds
)
# 添加桶和区域信息,方便前端使用
sts_result['bucket'] = qcloud_service.bucket
sts_result['region'] = qcloud_service.region
return sts_result
except Exception as e:
logger.error(f"生成STS临时凭证失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"生成STS临时凭证失败: {str(e)}")
@router.get("/files", response_model=ListFilesResponse, tags=["腾讯云COS"])
async def list_files(
directory: str = Query("", description="目录前缀"),
limit: int = Query(100, description="返回的最大文件数"),
marker: str = Query("", description="分页标记"),
qcloud_service: QCloudCOSService = Depends(lambda: QCloudCOSService())
):
"""
列出COS中的文件
- **directory**: 目录前缀
- **limit**: 返回的最大文件数
- **marker**: 分页标记
"""
try:
result = await qcloud_service.list_files(
directory=directory,
limit=limit,
marker=marker
)
return result
except Exception as e:
logger.error(f"列出文件失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"列出文件失败: {str(e)}")

225
app/routers/tryon_router.py Normal file
View File

@ -0,0 +1,225 @@
from fastapi import APIRouter, HTTPException, Depends, Query, Path, BackgroundTasks
from sqlalchemy.orm import Session
from typing import List, Optional
import httpx
import logging
import os
from dotenv import load_dotenv
from app.database import get_db
from app.models.tryon import TryOn
from app.schemas.tryon import (
TryOnCreate, TryOnUpdate, TryOnResponse,
AiTryonRequest, AiTryonResponse, TaskInfo
)
from app.services.dashscope_service import DashScopeService
# 加载环境变量
load_dotenv()
logger = logging.getLogger(__name__)
router = APIRouter()
# 从环境变量获取API密钥
DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY")
DASHSCOPE_API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2image/image-synthesis"
@router.post("/", response_model=TryOnResponse, status_code=201)
async def create_tryon(
tryon_data: TryOnCreate,
background_tasks: BackgroundTasks,
db: Session = Depends(get_db)
):
"""
创建一个试穿记录并发送到阿里百炼平台
- **top_garment_url**: 上衣图片URL可选
- **bottom_garment_url**: 下衣图片URL可选
- **person_image_url**: 人物图片URL必填
注意上衣和下衣至少需要提供一个
"""
if not tryon_data.top_garment_url and not tryon_data.bottom_garment_url:
raise HTTPException(status_code=400, detail="上衣和下衣图片至少需要提供一个")
try:
# 创建试穿记录
db_tryon = TryOn(
top_garment_url=tryon_data.top_garment_url,
bottom_garment_url=tryon_data.bottom_garment_url,
person_image_url=tryon_data.person_image_url
)
db.add(db_tryon)
db.commit()
db.refresh(db_tryon)
# 在后台发送请求到阿里百炼平台
background_tasks.add_task(
send_tryon_request,
db=db,
tryon_id=db_tryon.id,
top_garment_url=tryon_data.top_garment_url,
bottom_garment_url=tryon_data.bottom_garment_url,
person_image_url=tryon_data.person_image_url
)
return db_tryon
except Exception as e:
logger.error(f"创建试穿记录失败: {str(e)}")
db.rollback()
raise HTTPException(status_code=500, detail=f"创建试穿记录失败: {str(e)}")
async def send_tryon_request(
db: Session,
tryon_id: int,
top_garment_url: Optional[str],
bottom_garment_url: Optional[str],
person_image_url: str
):
"""发送试穿请求到阿里百炼平台"""
try:
# 创建DashScopeService实例
dashscope_service = DashScopeService()
# 调用服务发送试穿请求
response = await dashscope_service.generate_tryon(
person_image_url=person_image_url,
top_garment_url=top_garment_url,
bottom_garment_url=bottom_garment_url
)
# 更新数据库记录
db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
if db_tryon:
db_tryon.request_id = response.get("request_id")
db_tryon.task_id = response.get("output", {}).get("task_id")
db_tryon.task_status = response.get("output", {}).get("task_status")
db.commit()
logger.info(f"试穿请求发送成功任务ID: {db_tryon.task_id}")
except Exception as e:
logger.error(f"发送试穿请求异常: {str(e)}")
# 更新数据库记录状态
db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
if db_tryon:
db_tryon.task_status = "ERROR"
db.commit()
@router.get("/", response_model=List[TryOnResponse])
async def get_tryons(
skip: int = Query(0, description="跳过的记录数量"),
limit: int = Query(100, description="返回的最大记录数量"),
db: Session = Depends(get_db)
):
"""
获取试穿记录列表支持分页
- **skip**: 跳过的记录数量用于分页
- **limit**: 返回的最大记录数量用于分页
"""
try:
tryons = db.query(TryOn).order_by(TryOn.created_at.desc()).offset(skip).limit(limit).all()
return tryons
except Exception as e:
logger.error(f"获取试穿记录列表失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取试穿记录列表失败: {str(e)}")
@router.get("/{tryon_id}", response_model=TryOnResponse)
async def get_tryon(
tryon_id: int = Path(..., description="试穿记录ID"),
db: Session = Depends(get_db)
):
"""
根据ID获取试穿记录详情
- **tryon_id**: 试穿记录ID
"""
try:
tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
if not tryon:
raise HTTPException(status_code=404, detail=f"未找到ID为{tryon_id}的试穿记录")
return tryon
except HTTPException:
raise
except Exception as e:
logger.error(f"获取试穿记录详情失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取试穿记录详情失败: {str(e)}")
@router.put("/{tryon_id}", response_model=TryOnResponse)
async def update_tryon(
tryon_id: int = Path(..., description="试穿记录ID"),
tryon_data: TryOnUpdate = None,
db: Session = Depends(get_db)
):
"""
更新试穿记录信息
- **tryon_id**: 试穿记录ID
- **task_status**: 任务状态可选
- **completion_url**: 生成图片URL可选
"""
try:
db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
if not db_tryon:
raise HTTPException(status_code=404, detail=f"未找到ID为{tryon_id}的试穿记录")
# 更新提供的字段
if tryon_data.task_status is not None:
db_tryon.task_status = tryon_data.task_status
if tryon_data.completion_url is not None:
db_tryon.completion_url = tryon_data.completion_url
db.commit()
db.refresh(db_tryon)
return db_tryon
except HTTPException:
raise
except Exception as e:
logger.error(f"更新试穿记录失败: {str(e)}")
db.rollback()
raise HTTPException(status_code=500, detail=f"更新试穿记录失败: {str(e)}")
@router.post("/{tryon_id}/check", response_model=TryOnResponse)
async def check_tryon_status(
tryon_id: int = Path(..., description="试穿记录ID"),
db: Session = Depends(get_db),
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
):
"""
检查试穿任务状态
- **tryon_id**: 试穿记录ID
"""
try:
db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
if not db_tryon:
raise HTTPException(status_code=404, detail=f"未找到ID为{tryon_id}的试穿记录")
if not db_tryon.task_id:
raise HTTPException(status_code=400, detail=f"试穿记录未包含任务ID")
# 调用DashScopeService检查任务状态
try:
status_response = await dashscope_service.check_tryon_status(db_tryon.task_id)
# 更新数据库记录
db_tryon.task_status = status_response.get("output", {}).get("task_status")
# 如果任务完成保存结果URL
if db_tryon.task_status == "SUCCEEDED":
db_tryon.completion_url = status_response.get("output", {}).get("url")
db.commit()
db.refresh(db_tryon)
logger.info(f"试穿任务状态更新: {db_tryon.task_status}")
except Exception as e:
logger.error(f"调用DashScope API检查任务状态失败: {str(e)}")
return db_tryon
except HTTPException:
raise
except Exception as e:
logger.error(f"检查试穿任务状态失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"检查试穿任务状态失败: {str(e)}")

1
app/schemas/__init__.py Normal file
View File

@ -0,0 +1 @@

36
app/schemas/dress.py Normal file
View File

@ -0,0 +1,36 @@
from pydantic import BaseModel, Field, HttpUrl
from typing import Optional
from datetime import datetime
from enum import Enum
from app.models.dress import GarmentType
class DressBase(BaseModel):
"""服装基础模型"""
name: str = Field(..., description="服装名称", example="夏季连衣裙")
image_url: Optional[str] = Field(None, description="服装图片URL", example="https://example.com/dress1.jpg")
garment_type: Optional[GarmentType] = Field(None, description="服装类型(上衣/下衣)", example="TOP_GARMENT")
description: Optional[str] = Field(None, description="服装描述", example="一款适合夏季穿着的轻薄连衣裙")
class DressCreate(DressBase):
"""创建服装的请求模型"""
pass
class DressUpdate(BaseModel):
"""更新服装的请求模型"""
name: Optional[str] = Field(None, description="服装名称", example="夏季连衣裙")
image_url: Optional[str] = Field(None, description="服装图片URL", example="https://example.com/dress1.jpg")
garment_type: Optional[GarmentType] = Field(None, description="服装类型(上衣/下衣)", example="TOP_GARMENT")
description: Optional[str] = Field(None, description="服装描述", example="一款适合夏季穿着的轻薄连衣裙")
class DressInDB(DressBase):
"""数据库中的服装模型"""
id: int
created_at: datetime
updated_at: datetime
class Config:
orm_mode = True
class DressResponse(DressInDB):
"""服装API响应模型"""
pass

54
app/schemas/tryon.py Normal file
View File

@ -0,0 +1,54 @@
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime
class TryOnBase(BaseModel):
"""试穿基础模型"""
top_garment_url: Optional[str] = Field(None, description="上衣图片URL", example="https://example.com/top.jpg")
bottom_garment_url: Optional[str] = Field(None, description="下衣图片URL", example="https://example.com/bottom.jpg")
person_image_url: str = Field(..., description="人物图片URL", example="https://example.com/person.jpg")
class TryOnCreate(TryOnBase):
"""创建试穿记录的请求模型"""
pass
class TryOnUpdate(BaseModel):
"""更新试穿记录的请求模型"""
task_status: Optional[str] = Field(None, description="任务状态", example="SUCCEEDED")
completion_url: Optional[str] = Field(None, description="生成图片URL", example="https://example.com/result.jpg")
class TryOnInDB(TryOnBase):
"""数据库中的试穿记录模型"""
id: int
request_id: Optional[str] = None
task_id: Optional[str] = None
task_status: Optional[str] = None
completion_url: Optional[str] = None
created_at: datetime
updated_at: datetime
class Config:
orm_mode = True
class TryOnResponse(TryOnInDB):
"""试穿记录API响应模型"""
pass
class AiTryonRequest(BaseModel):
"""阿里百炼平台试穿请求模型"""
model: str = "aitryon"
input: TryOnBase
parameters: dict = {
"resolution": -1,
"restore_face": True
}
class TaskInfo(BaseModel):
"""任务信息模型"""
task_id: str
task_status: str
class AiTryonResponse(BaseModel):
"""阿里百炼平台试穿响应模型"""
output: TaskInfo
request_id: str

1
app/services/__init__.py Normal file
View File

@ -0,0 +1 @@

View File

@ -0,0 +1,230 @@
import os
import logging
import dashscope
from dashscope import Generation
# 修改导入语句dashscope的API响应可能改变了结构
from typing import List, Dict, Any, Optional
import asyncio
import httpx
from app.utils.config import get_settings
logger = logging.getLogger(__name__)
class DashScopeService:
"""DashScope服务类提供对DashScope API的调用封装"""
def __init__(self):
settings = get_settings()
self.api_key = settings.dashscope_api_key
# 配置DashScope
dashscope.api_key = self.api_key
# 配置API URL
self.image_synthesis_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2image/image-synthesis"
async def chat_completion(
self,
messages: List[Dict[str, str]],
model: str = "qwen-max",
max_tokens: int = 2048,
temperature: float = 0.7,
stream: bool = False
):
"""
调用DashScope的大模型API进行对话
Args:
messages: 对话历史记录
model: 模型名称
max_tokens: 最大生成token数
temperature: 温度参数控制随机性
stream: 是否流式输出
Returns:
ApiResponse: DashScope的API响应
"""
try:
# 为了不阻塞FastAPI的异步性能我们使用run_in_executor运行同步API
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(
None,
lambda: Generation.call(
model=model,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
result_format='message',
stream=stream,
)
)
if response.status_code != 200:
logger.error(f"DashScope API请求失败状态码{response.status_code}, 错误信息:{response.message}")
raise Exception(f"API调用失败: {response.message}")
return response
except Exception as e:
logger.error(f"DashScope聊天API调用出错: {str(e)}")
raise e
async def generate_image(
self,
prompt: str,
negative_prompt: Optional[str] = None,
model: str = "stable-diffusion-xl",
n: int = 1,
size: str = "1024*1024"
):
"""
调用DashScope的图像生成API
Args:
prompt: 生成图像的文本描述
negative_prompt: 负面提示词
model: 模型名称
n: 生成图像数量
size: 图像尺寸
Returns:
ApiResponse: DashScope的API响应
"""
try:
# 构建请求参数
params = {
"model": model,
"prompt": prompt,
"n": n,
"size": size,
}
if negative_prompt:
params["negative_prompt"] = negative_prompt
# 异步调用图像生成API
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(
None,
lambda: dashscope.ImageSynthesis.call(**params)
)
if response.status_code != 200:
logger.error(f"DashScope 图像生成API请求失败状态码{response.status_code}, 错误信息:{response.message}")
raise Exception(f"图像生成API调用失败: {response.message}")
return response
except Exception as e:
logger.error(f"DashScope图像生成API调用出错: {str(e)}")
raise e
async def generate_tryon(
self,
person_image_url: str,
top_garment_url: Optional[str] = None,
bottom_garment_url: Optional[str] = None,
resolution: int = -1,
restore_face: bool = True
):
"""
调用阿里百炼平台的试衣服务
Args:
person_image_url: 人物图片URL
top_garment_url: 上衣图片URL
bottom_garment_url: 下衣图片URL
resolution: 分辨率-1表示自动
restore_face: 是否修复面部
Returns:
Dict: 包含任务ID和请求ID的响应
"""
try:
# 验证参数
if not person_image_url:
raise ValueError("人物图片URL不能为空")
if not top_garment_url and not bottom_garment_url:
raise ValueError("上衣和下衣图片至少需要提供一个")
# 构建请求数据
request_data = {
"model": "aitryon",
"input": {
"person_image_url": person_image_url
},
"parameters": {
"resolution": resolution,
"restore_face": restore_face
}
}
# 添加可选字段
if top_garment_url:
request_data["input"]["top_garment_url"] = top_garment_url
if bottom_garment_url:
request_data["input"]["bottom_garment_url"] = bottom_garment_url
# 构建请求头
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# 发送请求
async with httpx.AsyncClient() as client:
response = await client.post(
self.image_synthesis_url,
json=request_data,
headers=headers,
timeout=30.0
)
if response.status_code == 200:
response_data = response.json()
logger.info(f"试穿请求发送成功任务ID: {response_data.get('output', {}).get('task_id')}")
return response_data
else:
error_msg = f"试穿请求失败: {response.status_code} - {response.text}"
logger.error(error_msg)
raise Exception(error_msg)
except Exception as e:
logger.error(f"DashScope试穿API调用出错: {str(e)}")
raise e
async def check_tryon_status(self, task_id: str):
"""
检查试穿任务状态
Args:
task_id: 任务ID
Returns:
Dict: 任务状态信息
"""
try:
# 构建请求头
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# 构建请求URL
status_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}"
# 发送请求
async with httpx.AsyncClient() as client:
response = await client.get(
status_url,
headers=headers,
timeout=30.0
)
if response.status_code == 200:
response_data = response.json()
logger.info(f"试穿任务状态查询成功: {response_data.get('output', {}).get('task_status')}")
return response_data
else:
error_msg = f"试穿任务状态查询失败: {response.status_code} - {response.text}"
logger.error(error_msg)
raise Exception(error_msg)
except Exception as e:
logger.error(f"查询试穿任务状态出错: {str(e)}")
raise e

View File

@ -0,0 +1,259 @@
import os
import logging
import time
import uuid
import base64
import hashlib
from typing import Optional, List, Dict, Any, BinaryIO, Union
from urllib.parse import urljoin
from qcloud_cos import CosConfig, CosS3Client
from qcloud_cos.cos_exception import CosServiceError, CosClientError
import sts.sts
from app.utils.config import get_settings
logger = logging.getLogger(__name__)
class QCloudCOSService:
"""腾讯云对象存储COS服务类"""
def __init__(self):
"""初始化腾讯云COS服务"""
settings = get_settings()
self.secret_id = settings.qcloud_secret_id
self.secret_key = settings.qcloud_secret_key
self.region = settings.qcloud_cos_region
self.bucket = settings.qcloud_cos_bucket
self.domain = settings.qcloud_cos_domain
# 创建COS客户端
config = CosConfig(
Region=self.region,
SecretId=self.secret_id,
SecretKey=self.secret_key
)
self.client = CosS3Client(config)
async def upload_file(
self,
file_content: Union[bytes, BinaryIO],
file_name: Optional[str] = None,
directory: str = "uploads",
content_type: Optional[str] = None
) -> Dict[str, str]:
"""
上传文件到腾讯云COS
Args:
file_content: 文件内容bytes或文件对象
file_name: 文件名如果不提供则生成随机文件名
directory: 存储目录
content_type: 文件MIME类型
Returns:
Dict[str, str]: 包含文件URL等信息的字典
"""
try:
# 生成文件名(如果未提供)
if not file_name:
# 使用UUID生成随机文件名
random_filename = str(uuid.uuid4())
if isinstance(file_content, bytes):
# 对于字节流,我们无法确定文件扩展名,默认用.bin
file_name = f"{random_filename}.bin"
else:
# 假设file_content是文件对象尝试从其name属性获取扩展名
if hasattr(file_content, 'name'):
original_name = os.path.basename(file_content.name)
ext = os.path.splitext(original_name)[1]
file_name = f"{random_filename}{ext}"
else:
file_name = f"{random_filename}.bin"
# 构建对象键文件在COS中的完整路径
key = f"{directory}/{file_name}"
# 上传文件
response = self.client.put_object(
Bucket=self.bucket,
Body=file_content,
Key=key,
ContentType=content_type
)
# 构建文件URL
file_url = urljoin(self.domain, key)
return {
"url": file_url,
"key": key,
"file_name": file_name,
"content_type": content_type,
"etag": response["ETag"] if "ETag" in response else None
}
except (CosServiceError, CosClientError) as e:
logger.error(f"腾讯云COS上传文件失败: {str(e)}")
raise Exception(f"文件上传失败: {str(e)}")
async def delete_file(self, key: str) -> bool:
"""
从腾讯云COS删除文件
Args:
key: 文件的对象键COS路径
Returns:
bool: 删除是否成功
"""
try:
self.client.delete_object(
Bucket=self.bucket,
Key=key
)
return True
except (CosServiceError, CosClientError) as e:
logger.error(f"腾讯云COS删除文件失败: {str(e)}")
return False
async def get_file_url(self, key: str, expires: int = 3600) -> str:
"""
获取COS文件的临时访问URL
Args:
key: 文件的对象键COS路径
expires: URL的有效期
Returns:
str: 临时访问URL
"""
try:
url = self.client.get_presigned_url(
Method='GET',
Bucket=self.bucket,
Key=key,
Expired=expires
)
return url
except (CosServiceError, CosClientError) as e:
logger.error(f"获取腾讯云COS文件URL失败: {str(e)}")
raise Exception(f"获取文件URL失败: {str(e)}")
async def generate_cos_sts_token(
self,
allow_actions: Optional[List[str]] = None,
allow_prefix: str = "*",
duration_seconds: int = 1800
) -> Dict[str, Any]:
"""
生成COS的临时安全凭证STS用于前端直传
Args:
allow_actions: 允许的COS操作列表
allow_prefix: 允许操作的对象前缀
duration_seconds: 凭证有效期
Returns:
Dict[str, Any]: STS凭证信息
"""
try:
if allow_actions is None:
# 默认只允许上传操作
allow_actions = [
'name/cos:PutObject',
'name/cos:PostObject',
'name/cos:InitiateMultipartUpload',
'name/cos:ListMultipartUploads',
'name/cos:ListParts',
'name/cos:UploadPart',
'name/cos:CompleteMultipartUpload'
]
# 配置STS
config = {
'url': 'https://sts.tencentcloudapi.com/',
'domain': 'sts.tencentcloudapi.com',
'duration_seconds': duration_seconds,
'secret_id': self.secret_id,
'secret_key': self.secret_key,
'region': self.region,
'policy': {
'version': '2.0',
'statement': [
{
'action': allow_actions,
'effect': 'allow',
'resource': [
f'qcs::cos:{self.region}:uid/:{self.bucket}/{allow_prefix}'
]
}
]
}
}
sts_client = sts.sts.Sts(config)
response = sts_client.get_credential()
return response
except Exception as e:
logger.error(f"生成腾讯云COS STS凭证失败: {str(e)}")
raise Exception(f"生成COS临时凭证失败: {str(e)}")
async def list_files(
self,
directory: str = "",
limit: int = 100,
marker: str = ""
) -> Dict[str, Any]:
"""
列出COS中的文件
Args:
directory: 目录前缀
limit: 返回的最大文件数
marker: 分页标记
Returns:
Dict[str, Any]: 文件列表信息
"""
try:
# 处理目录前缀
prefix = directory
if prefix and not prefix.endswith('/'):
prefix += '/'
# 调用COS API列出对象
response = self.client.list_objects(
Bucket=self.bucket,
Prefix=prefix,
Marker=marker,
MaxKeys=limit
)
# 处理响应
files = []
if 'Contents' in response:
for item in response['Contents']:
key = item.get('Key', '')
# 过滤出文件(忽略目录)
if not key.endswith('/'):
file_url = urljoin(self.domain, key)
files.append({
'key': key,
'url': file_url,
'size': item.get('Size', 0),
'last_modified': item.get('LastModified', ''),
'etag': item.get('ETag', '').strip('"')
})
return {
'files': files,
'is_truncated': response.get('IsTruncated', False),
'next_marker': response.get('NextMarker', ''),
'common_prefixes': response.get('CommonPrefixes', [])
}
except (CosServiceError, CosClientError) as e:
logger.error(f"列出腾讯云COS文件失败: {str(e)}")
raise Exception(f"列出文件失败: {str(e)}")

1
app/utils/__init__.py Normal file
View File

@ -0,0 +1 @@

83
app/utils/config.py Normal file
View File

@ -0,0 +1,83 @@
import os
from functools import lru_cache
from typing import Optional
from pydantic import BaseModel
from dotenv import load_dotenv
# 加载环境变量如果直接从main导入这里可能是冗余的但为了安全起见
load_dotenv()
class Settings(BaseModel):
"""应用程序配置类"""
# DashScope配置
dashscope_api_key: str = os.getenv("DASHSCOPE_API_KEY", "sk-caa199589f1c451aaac471fad2986e28")
# 服务器配置
host: str = os.getenv("HOST", "0.0.0.0")
port: int = int(os.getenv("PORT", "9001"))
debug: bool = os.getenv("DEBUG", "False").lower() in ["true", "1", "t", "yes"]
# 腾讯云配置
qcloud_secret_id: str = os.getenv("QCLOUD_SECRET_ID", "")
qcloud_secret_key: str = os.getenv("QCLOUD_SECRET_KEY", "")
qcloud_cos_region: str = os.getenv("QCLOUD_COS_REGION", "ap-guangzhou")
qcloud_cos_bucket: str = os.getenv("QCLOUD_COS_BUCKET", "")
qcloud_cos_domain: str = os.getenv("QCLOUD_COS_DOMAIN", "")
# 数据库配置
db_host: str = os.getenv("DB_HOST", "localhost")
db_port: int = int(os.getenv("DB_PORT", "3306"))
db_user: str = os.getenv("DB_USER", "root")
db_password: str = os.getenv("DB_PASSWORD", "password")
db_name: str = os.getenv("DB_NAME", "ai_dressing")
# 数据库URL
@property
def database_url(self) -> str:
"""
构建SQLAlchemy数据库连接URL
Returns:
str: 数据库连接URL
"""
return f"mysql+pymysql://{self.db_user}:{self.db_password}@{self.db_host}:{self.db_port}/{self.db_name}"
# 其他配置可以根据需要添加
log_level: str = os.getenv("LOG_LEVEL", "INFO")
@lru_cache()
def get_settings() -> Settings:
"""
获取应用程序配置
使用lru_cache装饰器避免重复读取环境变量提高性能
Returns:
Settings: 配置对象
"""
return Settings()
def validate_api_key():
"""
验证API密钥是否已配置
Raises:
ValueError: 如果API密钥未配置
"""
settings = get_settings()
if not settings.dashscope_api_key:
raise ValueError("DASHSCOPE_API_KEY未设置请在.env文件中配置")
return True
def validate_qcloud_config():
"""
验证腾讯云配置是否已设置
Raises:
ValueError: 如果腾讯云配置未正确设置
"""
settings = get_settings()
if not settings.qcloud_secret_id or not settings.qcloud_secret_key:
raise ValueError("腾讯云SecretId和SecretKey未设置请在.env文件中配置")
if not settings.qcloud_cos_bucket:
raise ValueError("腾讯云COS存储桶名称未设置请在.env文件中配置")
return True

60
create_migration.py Normal file
View File

@ -0,0 +1,60 @@
import os
import sys
import argparse
from alembic.config import Config
from alembic import command
def create_migration(message=None):
"""
创建数据库迁移
Args:
message: 迁移消息描述此次迁移的内容
"""
# 获取Alembic配置
alembic_cfg = Config("alembic.ini")
# 创建迁移
if message:
command.revision(alembic_cfg, message=message, autogenerate=True)
else:
command.revision(alembic_cfg, message="自动生成的迁移", autogenerate=True)
def upgrade_database():
"""更新数据库到最新版本"""
alembic_cfg = Config("alembic.ini")
command.upgrade(alembic_cfg, "head")
def downgrade_database(revision):
"""回滚数据库到指定版本"""
alembic_cfg = Config("alembic.ini")
command.downgrade(alembic_cfg, revision)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="数据库迁移工具")
subparsers = parser.add_subparsers(dest="command", help="命令")
# 创建迁移命令
create_parser = subparsers.add_parser("create", help="创建数据库迁移")
create_parser.add_argument("-m", "--message", help="迁移描述消息")
# 更新数据库命令
upgrade_parser = subparsers.add_parser("upgrade", help="更新数据库到最新版本")
# 回滚数据库命令
downgrade_parser = subparsers.add_parser("downgrade", help="回滚数据库到指定版本")
downgrade_parser.add_argument("revision", help="目标版本,例如:-1表示回滚一个版本")
args = parser.parse_args()
if args.command == "create":
create_migration(args.message)
print("迁移文件已创建,请查看 alembic/versions 目录")
elif args.command == "upgrade":
upgrade_database()
print("数据库已更新到最新版本")
elif args.command == "downgrade":
downgrade_database(args.revision)
print(f"数据库已回滚到版本: {args.revision}")
else:
parser.print_help()

12
requirements.txt Normal file
View File

@ -0,0 +1,12 @@
fastapi==0.104.0
uvicorn==0.23.2
dashscope>=1.13.0
python-dotenv==1.0.0
pydantic==2.4.2
httpx==0.25.0
cos-python-sdk-v5==1.9.26
qcloud-python-sts==3.1.4
sqlalchemy==2.0.23
pymysql==1.1.0
cryptography==41.0.5
alembic==1.12.1

105
run.py Normal file
View File

@ -0,0 +1,105 @@
import uvicorn
import logging
import sys
import os
from app.utils.config import get_settings, validate_api_key, validate_qcloud_config
from app.database import engine
def check_dependencies():
"""检查依赖项是否正确安装"""
try:
import dashscope
print(f"已加载 DashScope 版本: {getattr(dashscope, '__version__', '未知')}")
except ImportError as e:
print(f"错误: 无法导入 dashscope 模块: {str(e)}")
print("请确保已正确安装 dashscope: pip install -U dashscope")
return False
try:
from qcloud_cos import CosConfig
print("已加载腾讯云 COS SDK")
except ImportError as e:
print(f"警告: 无法导入腾讯云 COS SDK: {str(e)}")
print("腾讯云 COS 相关功能可能无法使用")
print("请确保已正确安装: pip install -U cos-python-sdk-v5 qcloud-python-sts")
try:
import sqlalchemy
print(f"已加载 SQLAlchemy 版本: {sqlalchemy.__version__}")
import pymysql
print(f"已加载 PyMySQL 版本: {pymysql.__version__}")
except ImportError as e:
print(f"错误: 无法导入数据库模块: {str(e)}")
print("请确保已正确安装: pip install -U sqlalchemy pymysql")
return False
return True
def check_database_connection():
"""检查数据库连接"""
try:
settings = get_settings()
# 尝试连接数据库
with engine.connect() as connection:
print(f"成功连接到数据库: {settings.db_name}")
return True
except Exception as e:
print(f"警告: 无法连接到数据库: {str(e)}")
print("请确保MySQL服务已启动并且数据库配置正确")
print("您可以使用以下命令创建数据库:")
settings = get_settings()
print(f"CREATE DATABASE {settings.db_name} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;")
print(f"GRANT ALL PRIVILEGES ON {settings.db_name}.* TO '{settings.db_user}'@'%' IDENTIFIED BY '{settings.db_password}';")
print("FLUSH PRIVILEGES;")
# 询问是否继续
if input("是否继续启动应用程序?(y/n): ").lower() not in ["y", "yes"]:
return False
return True
if __name__ == "__main__":
try:
# 检查依赖项
if not check_dependencies():
sys.exit(1)
# 验证API密钥是否配置
validate_api_key()
# 验证腾讯云配置
try:
validate_qcloud_config()
print("腾讯云COS配置已验证")
except ValueError as e:
print(f"警告: {str(e)}")
print("腾讯云COS相关功能可能无法正常使用但应用程序将继续启动")
# 检查数据库连接
if not check_database_connection():
sys.exit(1)
# 获取配置
settings = get_settings()
# 配置日志级别
log_level = getattr(logging, settings.log_level.upper(), logging.INFO)
logging.basicConfig(level=log_level)
# 启动服务器
print(f"启动服务,访问地址: http://{settings.host}:{settings.port}")
uvicorn.run(
"app.main:app",
host=settings.host,
port=settings.port,
reload=settings.debug,
log_level=settings.log_level.lower()
)
except Exception as e:
logging.error(f"启动失败: {str(e)}")
if hasattr(e, "__traceback__"):
import traceback
traceback.print_exception(type(e), e, e.__traceback__)
exit(1)