first commit
This commit is contained in:
commit
7bd63eefc5
42
.gitignore
vendored
Normal file
42
.gitignore
vendored
Normal file
@ -0,0 +1,42 @@
|
||||
# 环境变量
|
||||
.env
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
|
||||
# 虚拟环境
|
||||
venv/
|
||||
env/
|
||||
ENV/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# 日志
|
||||
*.log
|
||||
|
||||
# 操作系统
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
259
README.md
Normal file
259
README.md
Normal file
@ -0,0 +1,259 @@
|
||||
# AI-Dressing API
|
||||
|
||||
基于 FastAPI 框架和阿里云 DashScope 接口的 AI 服务 API。
|
||||
|
||||
## 功能特点
|
||||
|
||||
- 集成阿里云 DashScope 的大模型 API,支持通义千问系列模型
|
||||
- 提供图像生成 API,支持 Stable Diffusion XL、万相等模型
|
||||
- 集成腾讯云 COS 对象存储服务,支持文件上传、下载等功能
|
||||
- 使用 MySQL 数据库存储服装数据
|
||||
- RESTful API 风格,支持异步处理
|
||||
- 易于扩展的模块化设计
|
||||
|
||||
## 环境要求
|
||||
|
||||
- Python 3.8+
|
||||
- 阿里云 DashScope API 密钥
|
||||
- 腾讯云 API 密钥(用于 COS 对象存储)
|
||||
- MySQL 5.7+ 数据库
|
||||
|
||||
## 数据库设置
|
||||
|
||||
在使用应用程序之前,您需要创建 MySQL 数据库。登录到 MySQL 并执行以下命令:
|
||||
|
||||
```sql
|
||||
CREATE DATABASE ai_dressing CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
||||
CREATE USER 'ai_user'@'%' IDENTIFIED BY 'your_password';
|
||||
GRANT ALL PRIVILEGES ON ai_dressing.* TO 'ai_user'@'%';
|
||||
FLUSH PRIVILEGES;
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 1. 安装依赖
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 2. 配置环境变量
|
||||
|
||||
复制 `.env.example` 文件为 `.env`,并填写您的 API 密钥:
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
然后编辑 `.env` 文件:
|
||||
|
||||
```
|
||||
# DashScope API密钥
|
||||
DASHSCOPE_API_KEY=your_api_key_here
|
||||
|
||||
# 腾讯云配置
|
||||
QCLOUD_SECRET_ID=your_secret_id_here
|
||||
QCLOUD_SECRET_KEY=your_secret_key_here
|
||||
QCLOUD_COS_REGION=ap-guangzhou
|
||||
QCLOUD_COS_BUCKET=your-bucket-name
|
||||
QCLOUD_COS_DOMAIN=https://your-bucket-name.cos.ap-guangzhou.myqcloud.com
|
||||
|
||||
# 数据库配置
|
||||
DB_HOST=localhost
|
||||
DB_PORT=3306
|
||||
DB_USER=ai_user
|
||||
DB_PASSWORD=your_password
|
||||
DB_NAME=ai_dressing
|
||||
```
|
||||
|
||||
### 3. 数据库迁移
|
||||
|
||||
首次运行或模型变更后,需要创建或更新数据库表:
|
||||
|
||||
```bash
|
||||
# 创建迁移(仅在模型变更时需要)
|
||||
python create_migration.py create -m "初始化数据库"
|
||||
|
||||
# 应用迁移
|
||||
python create_migration.py upgrade
|
||||
```
|
||||
|
||||
### 4. 启动服务
|
||||
|
||||
```bash
|
||||
python run.py
|
||||
```
|
||||
|
||||
服务将在 http://localhost:9001 启动,您可以访问 http://localhost:9001/docs 查看 API 文档。
|
||||
|
||||
## API 文档
|
||||
|
||||
启动服务后,访问以下地址查看自动生成的 API 文档:
|
||||
|
||||
- Swagger UI: http://localhost:9001/docs
|
||||
- ReDoc: http://localhost:9001/redoc
|
||||
|
||||
## 主要 API 端点
|
||||
|
||||
### 大模型对话
|
||||
|
||||
```
|
||||
POST /api/dashscope/chat
|
||||
```
|
||||
|
||||
### 图像生成
|
||||
|
||||
```
|
||||
POST /api/dashscope/generate-image
|
||||
```
|
||||
|
||||
### 获取支持的模型列表
|
||||
|
||||
```
|
||||
GET /api/dashscope/models
|
||||
```
|
||||
|
||||
### 文件上传到腾讯云 COS
|
||||
|
||||
```
|
||||
POST /api/qcloud/upload
|
||||
```
|
||||
|
||||
### 从腾讯云 COS 获取文件列表
|
||||
|
||||
```
|
||||
GET /api/qcloud/files
|
||||
```
|
||||
|
||||
### 生成 COS 临时上传凭证
|
||||
|
||||
```
|
||||
POST /api/qcloud/sts-token
|
||||
```
|
||||
|
||||
### 服装管理 API
|
||||
|
||||
```
|
||||
# 创建服装
|
||||
POST /api/dresses/
|
||||
|
||||
# 获取服装列表
|
||||
GET /api/dresses/
|
||||
|
||||
# 获取单个服装
|
||||
GET /api/dresses/{dress_id}
|
||||
|
||||
# 更新服装
|
||||
PUT /api/dresses/{dress_id}
|
||||
|
||||
# 删除服装
|
||||
DELETE /api/dresses/{dress_id}
|
||||
```
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
.
|
||||
├── alembic/ # 数据库迁移相关文件
|
||||
│ ├── versions/ # 迁移版本文件
|
||||
│ ├── env.py # 迁移环境配置
|
||||
│ └── script.py.mako # 迁移脚本模板
|
||||
├── app/ # 应用主目录
|
||||
│ ├── __init__.py
|
||||
│ ├── main.py # FastAPI 应用程序
|
||||
│ ├── database/ # 数据库相关
|
||||
│ │ └── __init__.py
|
||||
│ ├── models/ # 数据库模型
|
||||
│ │ ├── __init__.py
|
||||
│ │ └── dress.py
|
||||
│ ├── routers/ # API 路由
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── dashscope_router.py
|
||||
│ │ ├── qcloud_router.py
|
||||
│ │ └── dress_router.py
|
||||
│ ├── schemas/ # Pydantic 模型
|
||||
│ │ ├── __init__.py
|
||||
│ │ └── dress.py
|
||||
│ ├── services/ # 业务逻辑服务
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── dashscope_service.py
|
||||
│ │ └── qcloud_service.py
|
||||
│ └── utils/ # 工具类
|
||||
│ ├── __init__.py
|
||||
│ └── config.py
|
||||
├── .env.example # 环境变量示例
|
||||
├── alembic.ini # Alembic 配置
|
||||
├── create_migration.py # 数据库迁移工具
|
||||
├── requirements.txt # 项目依赖
|
||||
├── README.md # 项目文档
|
||||
└── run.py # 应用入口
|
||||
```
|
||||
|
||||
## 开发指南
|
||||
|
||||
### 添加新的路由
|
||||
|
||||
1. 在 `app/routers/` 目录下创建新的路由文件
|
||||
2. 在 `app/main.py` 中注册新的路由
|
||||
|
||||
### 添加新的数据库模型
|
||||
|
||||
1. 在 `app/models/` 目录下创建新的模型文件
|
||||
2. 在 `app/schemas/` 目录下创建对应的 Pydantic 模型
|
||||
3. 创建数据库迁移: `python create_migration.py create -m "添加新模型"`
|
||||
4. 应用迁移: `python create_migration.py upgrade`
|
||||
|
||||
### 添加新的服务
|
||||
|
||||
在 `app/services/` 目录下创建新的服务类。
|
||||
|
||||
## 腾讯云 COS 使用指南
|
||||
|
||||
### 前端直传文件
|
||||
|
||||
1. 首先从服务端获取临时上传凭证:
|
||||
```javascript
|
||||
const response = await fetch('/api/qcloud/sts-token', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
},
|
||||
body: 'allow_prefix=uploads/&duration_seconds=1800'
|
||||
});
|
||||
const stsData = await response.json();
|
||||
```
|
||||
|
||||
2. 使用临时凭证进行文件上传:
|
||||
```javascript
|
||||
// 使用腾讯云COS SDK上传
|
||||
const cos = new COS({
|
||||
getAuthorization: function(options, callback) {
|
||||
callback({
|
||||
TmpSecretId: stsData.credentials.tmpSecretId,
|
||||
TmpSecretKey: stsData.credentials.tmpSecretKey,
|
||||
SecurityToken: stsData.credentials.sessionToken,
|
||||
ExpiredTime: stsData.expiration
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
cos.putObject({
|
||||
Bucket: stsData.bucket,
|
||||
Region: stsData.region,
|
||||
Key: 'uploads/example.jpg',
|
||||
Body: file,
|
||||
onProgress: function(progressData) {
|
||||
console.log(progressData);
|
||||
}
|
||||
}, function(err, data) {
|
||||
if (err) {
|
||||
console.error('上传失败', err);
|
||||
} else {
|
||||
console.log('上传成功', data);
|
||||
}
|
||||
});
|
||||
```
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT
|
||||
99
alembic.ini
Normal file
99
alembic.ini
Normal file
@ -0,0 +1,99 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
script_location = alembic
|
||||
|
||||
# template used to generate migration files
|
||||
# file_template = %%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python-dateutil library that can be
|
||||
# installed by adding `python-dateutil` to the requirements.
|
||||
# For example: timezone = UTC
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the
|
||||
# "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to alembic/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "version_path_separator"
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
|
||||
|
||||
# version path separator; As mentioned above, this is the character used to split
|
||||
# version_locations. Valid values are:
|
||||
#
|
||||
# version_path_separator = :
|
||||
# version_path_separator = ;
|
||||
# version_path_separator = space
|
||||
version_path_separator = os # default: use os.pathsep
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
87
alembic/env.py
Normal file
87
alembic/env.py
Normal file
@ -0,0 +1,87 @@
|
||||
from logging.config import fileConfig
|
||||
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
|
||||
from alembic import context
|
||||
|
||||
# 导入应用的数据库配置和模型
|
||||
from app.database import Base
|
||||
from app.utils.config import get_settings
|
||||
from app.models.dress import Dress # 导入所有模型
|
||||
from app.models.tryon import TryOn
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
config = context.config
|
||||
|
||||
# 从应用程序配置中获取数据库URL
|
||||
settings = get_settings()
|
||||
config.set_main_option("sqlalchemy.url", settings.database_url)
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# add your model's MetaData object here
|
||||
# for 'autogenerate' support
|
||||
# from myapp import mymodel
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
|
||||
def run_migrations_offline():
|
||||
"""Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online():
|
||||
"""Run migrations in 'online' mode.
|
||||
|
||||
In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = engine_from_config(
|
||||
config.get_section(config.config_ini_section),
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
24
alembic/script.py.mako
Normal file
24
alembic/script.py.mako
Normal file
@ -0,0 +1,24 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = ${repr(up_revision)}
|
||||
down_revision = ${repr(down_revision)}
|
||||
branch_labels = ${repr(branch_labels)}
|
||||
depends_on = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade():
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade():
|
||||
${downgrades if downgrades else "pass"}
|
||||
42
alembic/versions/f166e37d7b53_添加试穿记录表.py
Normal file
42
alembic/versions/f166e37d7b53_添加试穿记录表.py
Normal file
@ -0,0 +1,42 @@
|
||||
"""添加试穿记录表
|
||||
|
||||
Revision ID: f166e37d7b53
|
||||
Revises:
|
||||
Create Date: 2025-03-21 15:58:04.889147
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'f166e37d7b53'
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('tryons',
|
||||
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column('top_garment_url', sa.String(length=1024), nullable=True, comment='上衣图片URL'),
|
||||
sa.Column('bottom_garment_url', sa.String(length=1024), nullable=True, comment='下衣图片URL'),
|
||||
sa.Column('person_image_url', sa.String(length=1024), nullable=True, comment='人物图片URL'),
|
||||
sa.Column('request_id', sa.String(length=255), nullable=True, comment='请求ID'),
|
||||
sa.Column('task_id', sa.String(length=255), nullable=True, comment='任务ID'),
|
||||
sa.Column('task_status', sa.String(length=50), nullable=True, comment='任务状态'),
|
||||
sa.Column('completion_url', sa.String(length=1024), nullable=True, comment='生成图片URL'),
|
||||
sa.Column('created_at', sa.DateTime(), nullable=True, comment='创建时间'),
|
||||
sa.Column('updated_at', sa.DateTime(), nullable=True, comment='更新时间'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_tryons_id'), 'tryons', ['id'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_tryons_id'), table_name='tryons')
|
||||
op.drop_table('tryons')
|
||||
# ### end Alembic commands ###
|
||||
1
app/__init__.py
Normal file
1
app/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
|
||||
28
app/database/__init__.py
Normal file
28
app/database/__init__.py
Normal file
@ -0,0 +1,28 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from app.utils.config import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
# 创建数据库引擎
|
||||
engine = create_engine(
|
||||
settings.database_url,
|
||||
pool_pre_ping=True, # 自动检测断开的连接并重新连接
|
||||
pool_recycle=3600, # 每小时回收连接
|
||||
echo=settings.debug, # 在调试模式下打印SQL语句
|
||||
)
|
||||
|
||||
# 创建会话工厂
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
# 创建Base类,所有模型都将继承这个类
|
||||
Base = declarative_base()
|
||||
|
||||
# 获取数据库会话的依赖函数
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
68
app/main.py
Normal file
68
app/main.py
Normal file
@ -0,0 +1,68 @@
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import os
|
||||
import logging
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from app.routers import dashscope_router, qcloud_router, dress_router, tryon_router
|
||||
from app.utils.config import get_settings
|
||||
from app.database import Base, engine
|
||||
|
||||
# 创建数据库表
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI(
|
||||
title="AI-Dressing API",
|
||||
description="基于 DashScope 的 AI 服务 API",
|
||||
version="0.1.0",
|
||||
)
|
||||
|
||||
# 添加 CORS 中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # 在生产环境中,应该指定确切的域名
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 注册路由
|
||||
app.include_router(dashscope_router.router, prefix="/api/dashscope", tags=["DashScope"])
|
||||
app.include_router(qcloud_router.router, prefix="/api/qcloud", tags=["腾讯云"])
|
||||
app.include_router(dress_router.router, prefix="/api/dresses", tags=["服装"])
|
||||
app.include_router(tryon_router.router, prefix="/api/tryons", tags=["试穿"])
|
||||
|
||||
@app.get("/", tags=["健康检查"])
|
||||
async def health_check():
|
||||
"""API 健康检查端点"""
|
||||
return {"status": "正常", "message": "服务运行中"}
|
||||
|
||||
@app.get("/info", tags=["服务信息"])
|
||||
async def get_info():
|
||||
"""获取服务基本信息"""
|
||||
settings = get_settings()
|
||||
return {
|
||||
"app_name": "AI-Dressing API",
|
||||
"version": "0.1.0",
|
||||
"debug_mode": settings.debug,
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
settings = get_settings()
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host=settings.host,
|
||||
port=settings.port,
|
||||
reload=settings.debug,
|
||||
)
|
||||
1
app/models/__init__.py
Normal file
1
app/models/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
|
||||
27
app/models/dress.py
Normal file
27
app/models/dress.py
Normal file
@ -0,0 +1,27 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Text, func, Enum
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from app.database import Base
|
||||
import datetime
|
||||
import enum
|
||||
|
||||
class GarmentType(enum.Enum):
|
||||
"""服装类型枚举"""
|
||||
TOP_GARMENT = "TOP_GARMENT"
|
||||
BOTTOM_GARMENT = "BOTTOM_GARMENT"
|
||||
|
||||
class Dress(Base):
|
||||
"""服装模型类"""
|
||||
__tablename__ = "dresses"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
|
||||
name = Column(String(255), nullable=False, comment="服装名称")
|
||||
image_url = Column(String(1024), nullable=True, comment="服装图片URL")
|
||||
garment_type = Column(Enum(GarmentType), nullable=True, comment="服装类型(上衣/下衣)")
|
||||
|
||||
# 以下是一些可选的附加字段,根据需要可以去掉或添加
|
||||
description = Column(Text, nullable=True, comment="服装描述")
|
||||
created_at = Column(DateTime, default=datetime.datetime.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now(), onupdate=datetime.datetime.now, comment="更新时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Dress(id={self.id}, name='{self.name}')>"
|
||||
23
app/models/tryon.py
Normal file
23
app/models/tryon.py
Normal file
@ -0,0 +1,23 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Text, func
|
||||
from app.database import Base
|
||||
import datetime
|
||||
|
||||
class TryOn(Base):
|
||||
"""试穿记录模型类"""
|
||||
__tablename__ = "tryons"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
|
||||
top_garment_url = Column(String(1024), nullable=True, comment="上衣图片URL")
|
||||
bottom_garment_url = Column(String(1024), nullable=True, comment="下衣图片URL")
|
||||
person_image_url = Column(String(1024), nullable=True, comment="人物图片URL")
|
||||
|
||||
request_id = Column(String(255), nullable=True, comment="请求ID")
|
||||
task_id = Column(String(255), nullable=True, comment="任务ID")
|
||||
task_status = Column(String(50), nullable=True, comment="任务状态")
|
||||
completion_url = Column(String(1024), nullable=True, comment="生成图片URL")
|
||||
|
||||
created_at = Column(DateTime, default=datetime.datetime.now(), comment="创建时间")
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now(), onupdate=datetime.datetime.now, comment="更新时间")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<TryOn(id={self.id}, request_id='{self.request_id}')>"
|
||||
1
app/routers/__init__.py
Normal file
1
app/routers/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
|
||||
262
app/routers/dashscope_router.py
Normal file
262
app/routers/dashscope_router.py
Normal file
@ -0,0 +1,262 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
import dashscope
|
||||
from app.services.dashscope_service import DashScopeService
|
||||
from app.utils.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str # 'user' 或 'assistant'
|
||||
content: str
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
messages: List[ChatMessage]
|
||||
model: Optional[str] = "qwen-max" # 默认使用通义千问MAX模型
|
||||
max_tokens: Optional[int] = 2048
|
||||
temperature: Optional[float] = 0.7
|
||||
stream: Optional[bool] = False
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
response: str
|
||||
usage: Dict[str, Any]
|
||||
request_id: str
|
||||
|
||||
@router.post("/chat", response_model=ChatResponse)
|
||||
async def chat_completion(
|
||||
request: ChatRequest,
|
||||
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
|
||||
):
|
||||
"""
|
||||
调用DashScope的大模型进行对话
|
||||
"""
|
||||
try:
|
||||
messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
|
||||
response = await dashscope_service.chat_completion(
|
||||
messages=messages,
|
||||
model=request.model,
|
||||
max_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
stream=request.stream
|
||||
)
|
||||
|
||||
# 确保我们能正确访问响应字段,根据dashscope的最新版本调整
|
||||
response_text = ""
|
||||
if hasattr(response, 'output') and hasattr(response.output, 'text'):
|
||||
response_text = response.output.text
|
||||
elif hasattr(response, 'output') and isinstance(response.output, dict) and 'text' in response.output:
|
||||
response_text = response.output['text']
|
||||
elif hasattr(response, 'choices') and len(response.choices) > 0:
|
||||
response_text = response.choices[0].message.content
|
||||
else:
|
||||
logger.warning(f"无法解析DashScope响应: {response}")
|
||||
# 尝试保守提取
|
||||
try:
|
||||
if hasattr(response, 'output'):
|
||||
if isinstance(response.output, dict):
|
||||
response_text = str(response.output.get('text', ''))
|
||||
else:
|
||||
response_text = str(response.output)
|
||||
else:
|
||||
response_text = str(response)
|
||||
except Exception as e:
|
||||
logger.error(f"提取响应文本失败: {str(e)}")
|
||||
response_text = "无法获取模型响应文本"
|
||||
|
||||
# 同样确保我们能正确访问usage和request_id
|
||||
usage = {}
|
||||
if hasattr(response, 'usage'):
|
||||
usage = response.usage if isinstance(response.usage, dict) else vars(response.usage)
|
||||
|
||||
request_id = ""
|
||||
if hasattr(response, 'request_id'):
|
||||
request_id = response.request_id
|
||||
|
||||
return {
|
||||
"response": response_text,
|
||||
"usage": usage,
|
||||
"request_id": request_id
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"DashScope API 调用错误: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"模型调用失败: {str(e)}")
|
||||
|
||||
class ImageGenerationRequest(BaseModel):
|
||||
prompt: str
|
||||
negative_prompt: Optional[str] = None
|
||||
model: Optional[str] = "stable-diffusion-xl"
|
||||
n: Optional[int] = 1
|
||||
size: Optional[str] = "1024*1024"
|
||||
|
||||
class ImageGenerationResponse(BaseModel):
|
||||
images: List[str]
|
||||
request_id: str
|
||||
|
||||
@router.post("/generate-image", response_model=ImageGenerationResponse)
|
||||
async def generate_image(
|
||||
request: ImageGenerationRequest,
|
||||
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
|
||||
):
|
||||
"""
|
||||
调用DashScope的图像生成API
|
||||
"""
|
||||
try:
|
||||
response = await dashscope_service.generate_image(
|
||||
prompt=request.prompt,
|
||||
negative_prompt=request.negative_prompt,
|
||||
model=request.model,
|
||||
n=request.n,
|
||||
size=request.size
|
||||
)
|
||||
|
||||
# 确保我们能正确访问images和request_id
|
||||
images = []
|
||||
if hasattr(response, 'output') and hasattr(response.output, 'images'):
|
||||
images = response.output.images
|
||||
elif hasattr(response, 'output') and isinstance(response.output, dict) and 'images' in response.output:
|
||||
images = response.output['images']
|
||||
else:
|
||||
logger.warning(f"无法解析DashScope图像生成响应: {response}")
|
||||
try:
|
||||
if hasattr(response, 'output'):
|
||||
if isinstance(response.output, dict):
|
||||
images = response.output.get('images', [])
|
||||
else:
|
||||
images = []
|
||||
else:
|
||||
images = []
|
||||
except Exception as e:
|
||||
logger.error(f"提取图像URL失败: {str(e)}")
|
||||
images = []
|
||||
|
||||
request_id = ""
|
||||
if hasattr(response, 'request_id'):
|
||||
request_id = response.request_id
|
||||
|
||||
return {
|
||||
"images": images,
|
||||
"request_id": request_id
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"DashScope 图像生成API调用错误: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"图像生成失败: {str(e)}")
|
||||
|
||||
class TryOnRequest(BaseModel):
|
||||
"""试穿请求模型"""
|
||||
person_image_url: str
|
||||
top_garment_url: Optional[str] = None
|
||||
bottom_garment_url: Optional[str] = None
|
||||
resolution: Optional[int] = -1
|
||||
restore_face: Optional[bool] = True
|
||||
|
||||
class TaskInfo(BaseModel):
|
||||
"""任务信息模型"""
|
||||
task_id: str
|
||||
task_status: str
|
||||
|
||||
class TryOnResponse(BaseModel):
|
||||
"""试穿响应模型"""
|
||||
output: TaskInfo
|
||||
request_id: str
|
||||
|
||||
@router.post("/try-on", response_model=TryOnResponse)
|
||||
async def generate_tryon(
|
||||
request: TryOnRequest,
|
||||
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
|
||||
):
|
||||
"""
|
||||
调用DashScope的人物试衣API
|
||||
|
||||
该API用于给人物图片试穿上衣和/或下衣,返回合成图片的任务ID
|
||||
|
||||
- **person_image_url**: 人物图片URL(必填)
|
||||
- **top_garment_url**: 上衣图片URL(可选)
|
||||
- **bottom_garment_url**: 下衣图片URL(可选)
|
||||
- **resolution**: 分辨率,-1表示自动(可选)
|
||||
- **restore_face**: 是否修复面部(可选)
|
||||
|
||||
注意:top_garment_url和bottom_garment_url至少需要提供一个
|
||||
"""
|
||||
try:
|
||||
# 验证参数
|
||||
if not request.top_garment_url and not request.bottom_garment_url:
|
||||
raise HTTPException(status_code=400, detail="上衣和下衣图片至少需要提供一个")
|
||||
|
||||
response = await dashscope_service.generate_tryon(
|
||||
person_image_url=request.person_image_url,
|
||||
top_garment_url=request.top_garment_url,
|
||||
bottom_garment_url=request.bottom_garment_url,
|
||||
resolution=request.resolution,
|
||||
restore_face=request.restore_face
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
return {
|
||||
"output": {
|
||||
"task_id": response.get("output", {}).get("task_id", ""),
|
||||
"task_status": response.get("output", {}).get("task_status", "")
|
||||
},
|
||||
"request_id": response.get("request_id", "")
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"DashScope 试穿API调用错误: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"试穿请求失败: {str(e)}")
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
"""任务状态响应模型"""
|
||||
task_id: str
|
||||
task_status: str
|
||||
completion_url: Optional[str] = None
|
||||
request_id: str
|
||||
|
||||
@router.get("/try-on/{task_id}", response_model=TaskStatusResponse)
|
||||
async def check_tryon_status(
|
||||
task_id: str,
|
||||
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
|
||||
):
|
||||
"""
|
||||
查询试穿任务状态
|
||||
|
||||
- **task_id**: 任务ID
|
||||
"""
|
||||
try:
|
||||
response = await dashscope_service.check_tryon_status(task_id)
|
||||
|
||||
# 构建响应
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"task_status": response.get("output", {}).get("task_status", ""),
|
||||
"completion_url": response.get("output", {}).get("url", ""),
|
||||
"request_id": response.get("request_id", "")
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"查询试穿任务状态错误: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"查询任务状态失败: {str(e)}")
|
||||
|
||||
@router.get("/models")
|
||||
async def list_available_models():
|
||||
"""
|
||||
列出DashScope上可用的模型
|
||||
"""
|
||||
try:
|
||||
models = {
|
||||
"chat_models": [
|
||||
{"id": "qwen-max", "name": "通义千问MAX"},
|
||||
{"id": "qwen-plus", "name": "通义千问Plus"},
|
||||
{"id": "qwen-turbo", "name": "通义千问Turbo"}
|
||||
],
|
||||
"image_models": [
|
||||
{"id": "stable-diffusion-xl", "name": "Stable Diffusion XL"},
|
||||
{"id": "wanx-v1", "name": "万相"}
|
||||
],
|
||||
"tryon_models": [
|
||||
{"id": "aitryon", "name": "AI试穿"}
|
||||
]
|
||||
}
|
||||
return models
|
||||
except Exception as e:
|
||||
logger.error(f"获取模型列表错误: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}")
|
||||
162
app/routers/dress_router.py
Normal file
162
app/routers/dress_router.py
Normal file
@ -0,0 +1,162 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, Path
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import logging
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.dress import Dress, GarmentType
|
||||
from app.schemas.dress import DressCreate, DressUpdate, DressResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/", response_model=DressResponse, status_code=201)
|
||||
async def create_dress(
|
||||
dress: DressCreate,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
创建一个新的服装记录
|
||||
|
||||
- **name**: 服装名称(必填)
|
||||
- **image_url**: 服装图片URL(可选)
|
||||
- **garment_type**: 服装类型(TOP_GARMENT:上衣, BOTTOM_GARMENT:下衣)(可选)
|
||||
- **description**: 服装描述(可选)
|
||||
"""
|
||||
try:
|
||||
db_dress = Dress(
|
||||
name=dress.name,
|
||||
image_url=dress.image_url,
|
||||
garment_type=dress.garment_type,
|
||||
description=dress.description
|
||||
)
|
||||
db.add(db_dress)
|
||||
db.commit()
|
||||
db.refresh(db_dress)
|
||||
return db_dress
|
||||
except Exception as e:
|
||||
logger.error(f"创建服装记录失败: {str(e)}")
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"创建服装记录失败: {str(e)}")
|
||||
|
||||
@router.get("/", response_model=List[DressResponse])
|
||||
async def get_dresses(
|
||||
skip: int = Query(0, description="跳过的记录数量"),
|
||||
limit: int = Query(100, description="返回的最大记录数量"),
|
||||
name: Optional[str] = Query(None, description="按名称过滤"),
|
||||
garment_type: Optional[str] = Query(None, description="按服装类型过滤(TOP_GARMENT:上衣, BOTTOM_GARMENT:下衣)"),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取服装列表,支持分页和过滤
|
||||
|
||||
- **skip**: 跳过的记录数量,用于分页
|
||||
- **limit**: 返回的最大记录数量,用于分页
|
||||
- **name**: 按名称过滤(可选)
|
||||
- **garment_type**: 按服装类型过滤(可选)
|
||||
"""
|
||||
try:
|
||||
query = db.query(Dress)
|
||||
|
||||
if name:
|
||||
query = query.filter(Dress.name.ilike(f"%{name}%"))
|
||||
|
||||
if garment_type:
|
||||
try:
|
||||
garment_type_enum = GarmentType[garment_type]
|
||||
query = query.filter(Dress.garment_type == garment_type_enum)
|
||||
except KeyError:
|
||||
logger.warning(f"无效的服装类型: {garment_type}")
|
||||
# 继续查询,但不应用无效的过滤条件
|
||||
|
||||
dresses = query.offset(skip).limit(limit).all()
|
||||
return dresses
|
||||
except Exception as e:
|
||||
logger.error(f"获取服装列表失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取服装列表失败: {str(e)}")
|
||||
|
||||
@router.get("/{dress_id}", response_model=DressResponse)
|
||||
async def get_dress(
|
||||
dress_id: int = Path(..., description="服装ID"),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
根据ID获取服装详情
|
||||
|
||||
- **dress_id**: 服装ID
|
||||
"""
|
||||
try:
|
||||
dress = db.query(Dress).filter(Dress.id == dress_id).first()
|
||||
if not dress:
|
||||
raise HTTPException(status_code=404, detail=f"未找到ID为{dress_id}的服装")
|
||||
return dress
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取服装详情失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取服装详情失败: {str(e)}")
|
||||
|
||||
@router.put("/{dress_id}", response_model=DressResponse)
|
||||
async def update_dress(
|
||||
dress_id: int = Path(..., description="服装ID"),
|
||||
dress: DressUpdate = None,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
更新服装信息
|
||||
|
||||
- **dress_id**: 服装ID
|
||||
- **name**: 服装名称(可选)
|
||||
- **image_url**: 服装图片URL(可选)
|
||||
- **garment_type**: 服装类型(TOP_GARMENT:上衣, BOTTOM_GARMENT:下衣)(可选)
|
||||
- **description**: 服装描述(可选)
|
||||
"""
|
||||
try:
|
||||
db_dress = db.query(Dress).filter(Dress.id == dress_id).first()
|
||||
if not db_dress:
|
||||
raise HTTPException(status_code=404, detail=f"未找到ID为{dress_id}的服装")
|
||||
|
||||
# 更新提供的字段
|
||||
if dress.name is not None:
|
||||
db_dress.name = dress.name
|
||||
if dress.image_url is not None:
|
||||
db_dress.image_url = dress.image_url
|
||||
if dress.garment_type is not None:
|
||||
db_dress.garment_type = dress.garment_type
|
||||
if dress.description is not None:
|
||||
db_dress.description = dress.description
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_dress)
|
||||
return db_dress
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新服装信息失败: {str(e)}")
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"更新服装信息失败: {str(e)}")
|
||||
|
||||
@router.delete("/{dress_id}", status_code=204)
|
||||
async def delete_dress(
|
||||
dress_id: int = Path(..., description="服装ID"),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
删除服装
|
||||
|
||||
- **dress_id**: 服装ID
|
||||
"""
|
||||
try:
|
||||
db_dress = db.query(Dress).filter(Dress.id == dress_id).first()
|
||||
if not db_dress:
|
||||
raise HTTPException(status_code=404, detail=f"未找到ID为{dress_id}的服装")
|
||||
|
||||
db.delete(db_dress)
|
||||
db.commit()
|
||||
return None
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"删除服装失败: {str(e)}")
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"删除服装失败: {str(e)}")
|
||||
155
app/routers/qcloud_router.py
Normal file
155
app/routers/qcloud_router.py
Normal file
@ -0,0 +1,155 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends, UploadFile, File, Form, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
from app.services.qcloud_service import QCloudCOSService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
class STSTokenResponse(BaseModel):
|
||||
"""STS临时凭证响应"""
|
||||
credentials: Dict[str, str]
|
||||
expiration: str
|
||||
request_id: str
|
||||
bucket: str
|
||||
region: str
|
||||
|
||||
class FileInfo(BaseModel):
|
||||
"""文件信息"""
|
||||
key: str
|
||||
url: str
|
||||
size: int
|
||||
last_modified: str
|
||||
etag: str
|
||||
|
||||
class ListFilesResponse(BaseModel):
|
||||
"""列出文件响应"""
|
||||
files: List[FileInfo]
|
||||
is_truncated: bool
|
||||
next_marker: Optional[str] = None
|
||||
common_prefixes: Optional[List[Dict[str, Any]]] = None
|
||||
|
||||
@router.post("/upload", tags=["腾讯云COS"])
|
||||
async def upload_file(
|
||||
file: UploadFile = File(...),
|
||||
directory: str = Form("uploads"),
|
||||
qcloud_service: QCloudCOSService = Depends(lambda: QCloudCOSService())
|
||||
):
|
||||
"""
|
||||
上传文件到腾讯云COS
|
||||
|
||||
- **file**: 要上传的文件
|
||||
- **directory**: 存储目录(默认为"uploads")
|
||||
"""
|
||||
try:
|
||||
# 读取文件内容
|
||||
file_content = await file.read()
|
||||
|
||||
# 上传文件
|
||||
result = await qcloud_service.upload_file(
|
||||
file_content=file_content,
|
||||
file_name=file.filename,
|
||||
directory=directory,
|
||||
content_type=file.content_type
|
||||
)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"文件上传失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"文件上传失败: {str(e)}")
|
||||
finally:
|
||||
# 关闭文件
|
||||
await file.close()
|
||||
|
||||
@router.delete("/files/{key:path}", tags=["腾讯云COS"])
|
||||
async def delete_file(
|
||||
key: str,
|
||||
qcloud_service: QCloudCOSService = Depends(lambda: QCloudCOSService())
|
||||
):
|
||||
"""
|
||||
从腾讯云COS删除文件
|
||||
|
||||
- **key**: 文件的对象键(COS路径)
|
||||
"""
|
||||
try:
|
||||
success = await qcloud_service.delete_file(key=key)
|
||||
if success:
|
||||
return {"message": "文件删除成功", "key": key}
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail=f"文件删除失败,文件可能不存在")
|
||||
except Exception as e:
|
||||
logger.error(f"文件删除失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"文件删除失败: {str(e)}")
|
||||
|
||||
@router.get("/files/url/{key:path}", tags=["腾讯云COS"])
|
||||
async def get_file_url(
|
||||
key: str,
|
||||
expires: int = Query(3600, description="URL有效期(秒)"),
|
||||
qcloud_service: QCloudCOSService = Depends(lambda: QCloudCOSService())
|
||||
):
|
||||
"""
|
||||
获取COS文件的临时访问URL
|
||||
|
||||
- **key**: 文件的对象键(COS路径)
|
||||
- **expires**: URL的有效期(秒),默认3600秒
|
||||
"""
|
||||
try:
|
||||
url = await qcloud_service.get_file_url(key=key, expires=expires)
|
||||
return {"url": url, "key": key, "expires": expires}
|
||||
except Exception as e:
|
||||
logger.error(f"获取文件URL失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取文件URL失败: {str(e)}")
|
||||
|
||||
@router.post("/sts-token", response_model=STSTokenResponse, tags=["腾讯云COS"])
|
||||
async def generate_sts_token(
|
||||
allow_prefix: str = Form("*", description="允许操作的对象前缀"),
|
||||
duration_seconds: int = Form(1800, description="凭证有效期(秒)"),
|
||||
qcloud_service: QCloudCOSService = Depends(lambda: QCloudCOSService())
|
||||
):
|
||||
"""
|
||||
生成腾讯云COS临时访问凭证(STS),用于前端直传
|
||||
|
||||
- **allow_prefix**: 允许操作的对象前缀,默认为'*'表示所有对象
|
||||
- **duration_seconds**: 凭证有效期(秒),默认1800秒
|
||||
"""
|
||||
try:
|
||||
sts_result = await qcloud_service.generate_cos_sts_token(
|
||||
allow_prefix=allow_prefix,
|
||||
duration_seconds=duration_seconds
|
||||
)
|
||||
# 添加桶和区域信息,方便前端使用
|
||||
sts_result['bucket'] = qcloud_service.bucket
|
||||
sts_result['region'] = qcloud_service.region
|
||||
|
||||
return sts_result
|
||||
except Exception as e:
|
||||
logger.error(f"生成STS临时凭证失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"生成STS临时凭证失败: {str(e)}")
|
||||
|
||||
@router.get("/files", response_model=ListFilesResponse, tags=["腾讯云COS"])
|
||||
async def list_files(
|
||||
directory: str = Query("", description="目录前缀"),
|
||||
limit: int = Query(100, description="返回的最大文件数"),
|
||||
marker: str = Query("", description="分页标记"),
|
||||
qcloud_service: QCloudCOSService = Depends(lambda: QCloudCOSService())
|
||||
):
|
||||
"""
|
||||
列出COS中的文件
|
||||
|
||||
- **directory**: 目录前缀
|
||||
- **limit**: 返回的最大文件数
|
||||
- **marker**: 分页标记
|
||||
"""
|
||||
try:
|
||||
result = await qcloud_service.list_files(
|
||||
directory=directory,
|
||||
limit=limit,
|
||||
marker=marker
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"列出文件失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"列出文件失败: {str(e)}")
|
||||
225
app/routers/tryon_router.py
Normal file
225
app/routers/tryon_router.py
Normal file
@ -0,0 +1,225 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, Path, BackgroundTasks
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import httpx
|
||||
import logging
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.tryon import TryOn
|
||||
from app.schemas.tryon import (
|
||||
TryOnCreate, TryOnUpdate, TryOnResponse,
|
||||
AiTryonRequest, AiTryonResponse, TaskInfo
|
||||
)
|
||||
from app.services.dashscope_service import DashScopeService
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
# 从环境变量获取API密钥
|
||||
DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY")
|
||||
DASHSCOPE_API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2image/image-synthesis"
|
||||
|
||||
@router.post("/", response_model=TryOnResponse, status_code=201)
|
||||
async def create_tryon(
|
||||
tryon_data: TryOnCreate,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
创建一个试穿记录并发送到阿里百炼平台
|
||||
|
||||
- **top_garment_url**: 上衣图片URL(可选)
|
||||
- **bottom_garment_url**: 下衣图片URL(可选)
|
||||
- **person_image_url**: 人物图片URL(必填)
|
||||
|
||||
注意:上衣和下衣至少需要提供一个
|
||||
"""
|
||||
if not tryon_data.top_garment_url and not tryon_data.bottom_garment_url:
|
||||
raise HTTPException(status_code=400, detail="上衣和下衣图片至少需要提供一个")
|
||||
|
||||
try:
|
||||
# 创建试穿记录
|
||||
db_tryon = TryOn(
|
||||
top_garment_url=tryon_data.top_garment_url,
|
||||
bottom_garment_url=tryon_data.bottom_garment_url,
|
||||
person_image_url=tryon_data.person_image_url
|
||||
)
|
||||
db.add(db_tryon)
|
||||
db.commit()
|
||||
db.refresh(db_tryon)
|
||||
|
||||
# 在后台发送请求到阿里百炼平台
|
||||
background_tasks.add_task(
|
||||
send_tryon_request,
|
||||
db=db,
|
||||
tryon_id=db_tryon.id,
|
||||
top_garment_url=tryon_data.top_garment_url,
|
||||
bottom_garment_url=tryon_data.bottom_garment_url,
|
||||
person_image_url=tryon_data.person_image_url
|
||||
)
|
||||
|
||||
return db_tryon
|
||||
except Exception as e:
|
||||
logger.error(f"创建试穿记录失败: {str(e)}")
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"创建试穿记录失败: {str(e)}")
|
||||
|
||||
async def send_tryon_request(
|
||||
db: Session,
|
||||
tryon_id: int,
|
||||
top_garment_url: Optional[str],
|
||||
bottom_garment_url: Optional[str],
|
||||
person_image_url: str
|
||||
):
|
||||
"""发送试穿请求到阿里百炼平台"""
|
||||
try:
|
||||
# 创建DashScopeService实例
|
||||
dashscope_service = DashScopeService()
|
||||
|
||||
# 调用服务发送试穿请求
|
||||
response = await dashscope_service.generate_tryon(
|
||||
person_image_url=person_image_url,
|
||||
top_garment_url=top_garment_url,
|
||||
bottom_garment_url=bottom_garment_url
|
||||
)
|
||||
|
||||
# 更新数据库记录
|
||||
db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
|
||||
if db_tryon:
|
||||
db_tryon.request_id = response.get("request_id")
|
||||
db_tryon.task_id = response.get("output", {}).get("task_id")
|
||||
db_tryon.task_status = response.get("output", {}).get("task_status")
|
||||
db.commit()
|
||||
|
||||
logger.info(f"试穿请求发送成功,任务ID: {db_tryon.task_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"发送试穿请求异常: {str(e)}")
|
||||
|
||||
# 更新数据库记录状态
|
||||
db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
|
||||
if db_tryon:
|
||||
db_tryon.task_status = "ERROR"
|
||||
db.commit()
|
||||
|
||||
@router.get("/", response_model=List[TryOnResponse])
|
||||
async def get_tryons(
|
||||
skip: int = Query(0, description="跳过的记录数量"),
|
||||
limit: int = Query(100, description="返回的最大记录数量"),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取试穿记录列表,支持分页
|
||||
|
||||
- **skip**: 跳过的记录数量,用于分页
|
||||
- **limit**: 返回的最大记录数量,用于分页
|
||||
"""
|
||||
try:
|
||||
tryons = db.query(TryOn).order_by(TryOn.created_at.desc()).offset(skip).limit(limit).all()
|
||||
return tryons
|
||||
except Exception as e:
|
||||
logger.error(f"获取试穿记录列表失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取试穿记录列表失败: {str(e)}")
|
||||
|
||||
@router.get("/{tryon_id}", response_model=TryOnResponse)
|
||||
async def get_tryon(
|
||||
tryon_id: int = Path(..., description="试穿记录ID"),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
根据ID获取试穿记录详情
|
||||
|
||||
- **tryon_id**: 试穿记录ID
|
||||
"""
|
||||
try:
|
||||
tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
|
||||
if not tryon:
|
||||
raise HTTPException(status_code=404, detail=f"未找到ID为{tryon_id}的试穿记录")
|
||||
return tryon
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取试穿记录详情失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"获取试穿记录详情失败: {str(e)}")
|
||||
|
||||
@router.put("/{tryon_id}", response_model=TryOnResponse)
|
||||
async def update_tryon(
|
||||
tryon_id: int = Path(..., description="试穿记录ID"),
|
||||
tryon_data: TryOnUpdate = None,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
更新试穿记录信息
|
||||
|
||||
- **tryon_id**: 试穿记录ID
|
||||
- **task_status**: 任务状态(可选)
|
||||
- **completion_url**: 生成图片URL(可选)
|
||||
"""
|
||||
try:
|
||||
db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
|
||||
if not db_tryon:
|
||||
raise HTTPException(status_code=404, detail=f"未找到ID为{tryon_id}的试穿记录")
|
||||
|
||||
# 更新提供的字段
|
||||
if tryon_data.task_status is not None:
|
||||
db_tryon.task_status = tryon_data.task_status
|
||||
if tryon_data.completion_url is not None:
|
||||
db_tryon.completion_url = tryon_data.completion_url
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_tryon)
|
||||
return db_tryon
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新试穿记录失败: {str(e)}")
|
||||
db.rollback()
|
||||
raise HTTPException(status_code=500, detail=f"更新试穿记录失败: {str(e)}")
|
||||
|
||||
@router.post("/{tryon_id}/check", response_model=TryOnResponse)
|
||||
async def check_tryon_status(
|
||||
tryon_id: int = Path(..., description="试穿记录ID"),
|
||||
db: Session = Depends(get_db),
|
||||
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
|
||||
):
|
||||
"""
|
||||
检查试穿任务状态
|
||||
|
||||
- **tryon_id**: 试穿记录ID
|
||||
"""
|
||||
try:
|
||||
db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
|
||||
if not db_tryon:
|
||||
raise HTTPException(status_code=404, detail=f"未找到ID为{tryon_id}的试穿记录")
|
||||
|
||||
if not db_tryon.task_id:
|
||||
raise HTTPException(status_code=400, detail=f"试穿记录未包含任务ID")
|
||||
|
||||
# 调用DashScopeService检查任务状态
|
||||
try:
|
||||
status_response = await dashscope_service.check_tryon_status(db_tryon.task_id)
|
||||
|
||||
# 更新数据库记录
|
||||
db_tryon.task_status = status_response.get("output", {}).get("task_status")
|
||||
|
||||
# 如果任务完成,保存结果URL
|
||||
if db_tryon.task_status == "SUCCEEDED":
|
||||
db_tryon.completion_url = status_response.get("output", {}).get("url")
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_tryon)
|
||||
|
||||
logger.info(f"试穿任务状态更新: {db_tryon.task_status}")
|
||||
except Exception as e:
|
||||
logger.error(f"调用DashScope API检查任务状态失败: {str(e)}")
|
||||
|
||||
return db_tryon
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"检查试穿任务状态失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"检查试穿任务状态失败: {str(e)}")
|
||||
1
app/schemas/__init__.py
Normal file
1
app/schemas/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
|
||||
36
app/schemas/dress.py
Normal file
36
app/schemas/dress.py
Normal file
@ -0,0 +1,36 @@
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from app.models.dress import GarmentType
|
||||
|
||||
class DressBase(BaseModel):
|
||||
"""服装基础模型"""
|
||||
name: str = Field(..., description="服装名称", example="夏季连衣裙")
|
||||
image_url: Optional[str] = Field(None, description="服装图片URL", example="https://example.com/dress1.jpg")
|
||||
garment_type: Optional[GarmentType] = Field(None, description="服装类型(上衣/下衣)", example="TOP_GARMENT")
|
||||
description: Optional[str] = Field(None, description="服装描述", example="一款适合夏季穿着的轻薄连衣裙")
|
||||
|
||||
class DressCreate(DressBase):
|
||||
"""创建服装的请求模型"""
|
||||
pass
|
||||
|
||||
class DressUpdate(BaseModel):
|
||||
"""更新服装的请求模型"""
|
||||
name: Optional[str] = Field(None, description="服装名称", example="夏季连衣裙")
|
||||
image_url: Optional[str] = Field(None, description="服装图片URL", example="https://example.com/dress1.jpg")
|
||||
garment_type: Optional[GarmentType] = Field(None, description="服装类型(上衣/下衣)", example="TOP_GARMENT")
|
||||
description: Optional[str] = Field(None, description="服装描述", example="一款适合夏季穿着的轻薄连衣裙")
|
||||
|
||||
class DressInDB(DressBase):
|
||||
"""数据库中的服装模型"""
|
||||
id: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
class DressResponse(DressInDB):
|
||||
"""服装API响应模型"""
|
||||
pass
|
||||
54
app/schemas/tryon.py
Normal file
54
app/schemas/tryon.py
Normal file
@ -0,0 +1,54 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
class TryOnBase(BaseModel):
|
||||
"""试穿基础模型"""
|
||||
top_garment_url: Optional[str] = Field(None, description="上衣图片URL", example="https://example.com/top.jpg")
|
||||
bottom_garment_url: Optional[str] = Field(None, description="下衣图片URL", example="https://example.com/bottom.jpg")
|
||||
person_image_url: str = Field(..., description="人物图片URL", example="https://example.com/person.jpg")
|
||||
|
||||
class TryOnCreate(TryOnBase):
|
||||
"""创建试穿记录的请求模型"""
|
||||
pass
|
||||
|
||||
class TryOnUpdate(BaseModel):
|
||||
"""更新试穿记录的请求模型"""
|
||||
task_status: Optional[str] = Field(None, description="任务状态", example="SUCCEEDED")
|
||||
completion_url: Optional[str] = Field(None, description="生成图片URL", example="https://example.com/result.jpg")
|
||||
|
||||
class TryOnInDB(TryOnBase):
|
||||
"""数据库中的试穿记录模型"""
|
||||
id: int
|
||||
request_id: Optional[str] = None
|
||||
task_id: Optional[str] = None
|
||||
task_status: Optional[str] = None
|
||||
completion_url: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
orm_mode = True
|
||||
|
||||
class TryOnResponse(TryOnInDB):
|
||||
"""试穿记录API响应模型"""
|
||||
pass
|
||||
|
||||
class AiTryonRequest(BaseModel):
|
||||
"""阿里百炼平台试穿请求模型"""
|
||||
model: str = "aitryon"
|
||||
input: TryOnBase
|
||||
parameters: dict = {
|
||||
"resolution": -1,
|
||||
"restore_face": True
|
||||
}
|
||||
|
||||
class TaskInfo(BaseModel):
|
||||
"""任务信息模型"""
|
||||
task_id: str
|
||||
task_status: str
|
||||
|
||||
class AiTryonResponse(BaseModel):
|
||||
"""阿里百炼平台试穿响应模型"""
|
||||
output: TaskInfo
|
||||
request_id: str
|
||||
1
app/services/__init__.py
Normal file
1
app/services/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
|
||||
230
app/services/dashscope_service.py
Normal file
230
app/services/dashscope_service.py
Normal file
@ -0,0 +1,230 @@
|
||||
import os
|
||||
import logging
|
||||
import dashscope
|
||||
from dashscope import Generation
|
||||
# 修改导入语句,dashscope的API响应可能改变了结构
|
||||
from typing import List, Dict, Any, Optional
|
||||
import asyncio
|
||||
import httpx
|
||||
from app.utils.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DashScopeService:
|
||||
"""DashScope服务类,提供对DashScope API的调用封装"""
|
||||
|
||||
def __init__(self):
|
||||
settings = get_settings()
|
||||
self.api_key = settings.dashscope_api_key
|
||||
# 配置DashScope
|
||||
dashscope.api_key = self.api_key
|
||||
# 配置API URL
|
||||
self.image_synthesis_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2image/image-synthesis"
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
model: str = "qwen-max",
|
||||
max_tokens: int = 2048,
|
||||
temperature: float = 0.7,
|
||||
stream: bool = False
|
||||
):
|
||||
"""
|
||||
调用DashScope的大模型API进行对话
|
||||
|
||||
Args:
|
||||
messages: 对话历史记录
|
||||
model: 模型名称
|
||||
max_tokens: 最大生成token数
|
||||
temperature: 温度参数,控制随机性
|
||||
stream: 是否流式输出
|
||||
|
||||
Returns:
|
||||
ApiResponse: DashScope的API响应
|
||||
"""
|
||||
try:
|
||||
# 为了不阻塞FastAPI的异步性能,我们使用run_in_executor运行同步API
|
||||
loop = asyncio.get_event_loop()
|
||||
response = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: Generation.call(
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
result_format='message',
|
||||
stream=stream,
|
||||
)
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"DashScope API请求失败,状态码:{response.status_code}, 错误信息:{response.message}")
|
||||
raise Exception(f"API调用失败: {response.message}")
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"DashScope聊天API调用出错: {str(e)}")
|
||||
raise e
|
||||
|
||||
async def generate_image(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: Optional[str] = None,
|
||||
model: str = "stable-diffusion-xl",
|
||||
n: int = 1,
|
||||
size: str = "1024*1024"
|
||||
):
|
||||
"""
|
||||
调用DashScope的图像生成API
|
||||
|
||||
Args:
|
||||
prompt: 生成图像的文本描述
|
||||
negative_prompt: 负面提示词
|
||||
model: 模型名称
|
||||
n: 生成图像数量
|
||||
size: 图像尺寸
|
||||
|
||||
Returns:
|
||||
ApiResponse: DashScope的API响应
|
||||
"""
|
||||
try:
|
||||
# 构建请求参数
|
||||
params = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"n": n,
|
||||
"size": size,
|
||||
}
|
||||
|
||||
if negative_prompt:
|
||||
params["negative_prompt"] = negative_prompt
|
||||
|
||||
# 异步调用图像生成API
|
||||
loop = asyncio.get_event_loop()
|
||||
response = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: dashscope.ImageSynthesis.call(**params)
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"DashScope 图像生成API请求失败,状态码:{response.status_code}, 错误信息:{response.message}")
|
||||
raise Exception(f"图像生成API调用失败: {response.message}")
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"DashScope图像生成API调用出错: {str(e)}")
|
||||
raise e
|
||||
|
||||
async def generate_tryon(
|
||||
self,
|
||||
person_image_url: str,
|
||||
top_garment_url: Optional[str] = None,
|
||||
bottom_garment_url: Optional[str] = None,
|
||||
resolution: int = -1,
|
||||
restore_face: bool = True
|
||||
):
|
||||
"""
|
||||
调用阿里百炼平台的试衣服务
|
||||
|
||||
Args:
|
||||
person_image_url: 人物图片URL
|
||||
top_garment_url: 上衣图片URL
|
||||
bottom_garment_url: 下衣图片URL
|
||||
resolution: 分辨率,-1表示自动
|
||||
restore_face: 是否修复面部
|
||||
|
||||
Returns:
|
||||
Dict: 包含任务ID和请求ID的响应
|
||||
"""
|
||||
try:
|
||||
# 验证参数
|
||||
if not person_image_url:
|
||||
raise ValueError("人物图片URL不能为空")
|
||||
|
||||
if not top_garment_url and not bottom_garment_url:
|
||||
raise ValueError("上衣和下衣图片至少需要提供一个")
|
||||
|
||||
# 构建请求数据
|
||||
request_data = {
|
||||
"model": "aitryon",
|
||||
"input": {
|
||||
"person_image_url": person_image_url
|
||||
},
|
||||
"parameters": {
|
||||
"resolution": resolution,
|
||||
"restore_face": restore_face
|
||||
}
|
||||
}
|
||||
|
||||
# 添加可选字段
|
||||
if top_garment_url:
|
||||
request_data["input"]["top_garment_url"] = top_garment_url
|
||||
if bottom_garment_url:
|
||||
request_data["input"]["bottom_garment_url"] = bottom_garment_url
|
||||
|
||||
# 构建请求头
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# 发送请求
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.image_synthesis_url,
|
||||
json=request_data,
|
||||
headers=headers,
|
||||
timeout=30.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
logger.info(f"试穿请求发送成功,任务ID: {response_data.get('output', {}).get('task_id')}")
|
||||
return response_data
|
||||
else:
|
||||
error_msg = f"试穿请求失败: {response.status_code} - {response.text}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
except Exception as e:
|
||||
logger.error(f"DashScope试穿API调用出错: {str(e)}")
|
||||
raise e
|
||||
|
||||
async def check_tryon_status(self, task_id: str):
|
||||
"""
|
||||
检查试穿任务状态
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
|
||||
Returns:
|
||||
Dict: 任务状态信息
|
||||
"""
|
||||
try:
|
||||
# 构建请求头
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# 构建请求URL
|
||||
status_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}"
|
||||
|
||||
# 发送请求
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
status_url,
|
||||
headers=headers,
|
||||
timeout=30.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
logger.info(f"试穿任务状态查询成功: {response_data.get('output', {}).get('task_status')}")
|
||||
return response_data
|
||||
else:
|
||||
error_msg = f"试穿任务状态查询失败: {response.status_code} - {response.text}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
except Exception as e:
|
||||
logger.error(f"查询试穿任务状态出错: {str(e)}")
|
||||
raise e
|
||||
259
app/services/qcloud_service.py
Normal file
259
app/services/qcloud_service.py
Normal file
@ -0,0 +1,259 @@
|
||||
import os
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
import base64
|
||||
import hashlib
|
||||
from typing import Optional, List, Dict, Any, BinaryIO, Union
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from qcloud_cos import CosConfig, CosS3Client
|
||||
from qcloud_cos.cos_exception import CosServiceError, CosClientError
|
||||
import sts.sts
|
||||
|
||||
from app.utils.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class QCloudCOSService:
|
||||
"""腾讯云对象存储(COS)服务类"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化腾讯云COS服务"""
|
||||
settings = get_settings()
|
||||
self.secret_id = settings.qcloud_secret_id
|
||||
self.secret_key = settings.qcloud_secret_key
|
||||
self.region = settings.qcloud_cos_region
|
||||
self.bucket = settings.qcloud_cos_bucket
|
||||
self.domain = settings.qcloud_cos_domain
|
||||
|
||||
# 创建COS客户端
|
||||
config = CosConfig(
|
||||
Region=self.region,
|
||||
SecretId=self.secret_id,
|
||||
SecretKey=self.secret_key
|
||||
)
|
||||
self.client = CosS3Client(config)
|
||||
|
||||
async def upload_file(
|
||||
self,
|
||||
file_content: Union[bytes, BinaryIO],
|
||||
file_name: Optional[str] = None,
|
||||
directory: str = "uploads",
|
||||
content_type: Optional[str] = None
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
上传文件到腾讯云COS
|
||||
|
||||
Args:
|
||||
file_content: 文件内容(bytes或文件对象)
|
||||
file_name: 文件名,如果不提供则生成随机文件名
|
||||
directory: 存储目录
|
||||
content_type: 文件MIME类型
|
||||
|
||||
Returns:
|
||||
Dict[str, str]: 包含文件URL等信息的字典
|
||||
"""
|
||||
try:
|
||||
# 生成文件名(如果未提供)
|
||||
if not file_name:
|
||||
# 使用UUID生成随机文件名
|
||||
random_filename = str(uuid.uuid4())
|
||||
if isinstance(file_content, bytes):
|
||||
# 对于字节流,我们无法确定文件扩展名,默认用.bin
|
||||
file_name = f"{random_filename}.bin"
|
||||
else:
|
||||
# 假设file_content是文件对象,尝试从其name属性获取扩展名
|
||||
if hasattr(file_content, 'name'):
|
||||
original_name = os.path.basename(file_content.name)
|
||||
ext = os.path.splitext(original_name)[1]
|
||||
file_name = f"{random_filename}{ext}"
|
||||
else:
|
||||
file_name = f"{random_filename}.bin"
|
||||
|
||||
# 构建对象键(文件在COS中的完整路径)
|
||||
key = f"{directory}/{file_name}"
|
||||
|
||||
# 上传文件
|
||||
response = self.client.put_object(
|
||||
Bucket=self.bucket,
|
||||
Body=file_content,
|
||||
Key=key,
|
||||
ContentType=content_type
|
||||
)
|
||||
|
||||
# 构建文件URL
|
||||
file_url = urljoin(self.domain, key)
|
||||
|
||||
return {
|
||||
"url": file_url,
|
||||
"key": key,
|
||||
"file_name": file_name,
|
||||
"content_type": content_type,
|
||||
"etag": response["ETag"] if "ETag" in response else None
|
||||
}
|
||||
|
||||
except (CosServiceError, CosClientError) as e:
|
||||
logger.error(f"腾讯云COS上传文件失败: {str(e)}")
|
||||
raise Exception(f"文件上传失败: {str(e)}")
|
||||
|
||||
async def delete_file(self, key: str) -> bool:
|
||||
"""
|
||||
从腾讯云COS删除文件
|
||||
|
||||
Args:
|
||||
key: 文件的对象键(COS路径)
|
||||
|
||||
Returns:
|
||||
bool: 删除是否成功
|
||||
"""
|
||||
try:
|
||||
self.client.delete_object(
|
||||
Bucket=self.bucket,
|
||||
Key=key
|
||||
)
|
||||
return True
|
||||
except (CosServiceError, CosClientError) as e:
|
||||
logger.error(f"腾讯云COS删除文件失败: {str(e)}")
|
||||
return False
|
||||
|
||||
async def get_file_url(self, key: str, expires: int = 3600) -> str:
|
||||
"""
|
||||
获取COS文件的临时访问URL
|
||||
|
||||
Args:
|
||||
key: 文件的对象键(COS路径)
|
||||
expires: URL的有效期(秒)
|
||||
|
||||
Returns:
|
||||
str: 临时访问URL
|
||||
"""
|
||||
try:
|
||||
url = self.client.get_presigned_url(
|
||||
Method='GET',
|
||||
Bucket=self.bucket,
|
||||
Key=key,
|
||||
Expired=expires
|
||||
)
|
||||
return url
|
||||
except (CosServiceError, CosClientError) as e:
|
||||
logger.error(f"获取腾讯云COS文件URL失败: {str(e)}")
|
||||
raise Exception(f"获取文件URL失败: {str(e)}")
|
||||
|
||||
async def generate_cos_sts_token(
|
||||
self,
|
||||
allow_actions: Optional[List[str]] = None,
|
||||
allow_prefix: str = "*",
|
||||
duration_seconds: int = 1800
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
生成COS的临时安全凭证(STS),用于前端直传
|
||||
|
||||
Args:
|
||||
allow_actions: 允许的COS操作列表
|
||||
allow_prefix: 允许操作的对象前缀
|
||||
duration_seconds: 凭证有效期(秒)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: STS凭证信息
|
||||
"""
|
||||
try:
|
||||
if allow_actions is None:
|
||||
# 默认只允许上传操作
|
||||
allow_actions = [
|
||||
'name/cos:PutObject',
|
||||
'name/cos:PostObject',
|
||||
'name/cos:InitiateMultipartUpload',
|
||||
'name/cos:ListMultipartUploads',
|
||||
'name/cos:ListParts',
|
||||
'name/cos:UploadPart',
|
||||
'name/cos:CompleteMultipartUpload'
|
||||
]
|
||||
|
||||
# 配置STS
|
||||
config = {
|
||||
'url': 'https://sts.tencentcloudapi.com/',
|
||||
'domain': 'sts.tencentcloudapi.com',
|
||||
'duration_seconds': duration_seconds,
|
||||
'secret_id': self.secret_id,
|
||||
'secret_key': self.secret_key,
|
||||
'region': self.region,
|
||||
'policy': {
|
||||
'version': '2.0',
|
||||
'statement': [
|
||||
{
|
||||
'action': allow_actions,
|
||||
'effect': 'allow',
|
||||
'resource': [
|
||||
f'qcs::cos:{self.region}:uid/:{self.bucket}/{allow_prefix}'
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
sts_client = sts.sts.Sts(config)
|
||||
response = sts_client.get_credential()
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成腾讯云COS STS凭证失败: {str(e)}")
|
||||
raise Exception(f"生成COS临时凭证失败: {str(e)}")
|
||||
|
||||
async def list_files(
|
||||
self,
|
||||
directory: str = "",
|
||||
limit: int = 100,
|
||||
marker: str = ""
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
列出COS中的文件
|
||||
|
||||
Args:
|
||||
directory: 目录前缀
|
||||
limit: 返回的最大文件数
|
||||
marker: 分页标记
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 文件列表信息
|
||||
"""
|
||||
try:
|
||||
# 处理目录前缀
|
||||
prefix = directory
|
||||
if prefix and not prefix.endswith('/'):
|
||||
prefix += '/'
|
||||
|
||||
# 调用COS API列出对象
|
||||
response = self.client.list_objects(
|
||||
Bucket=self.bucket,
|
||||
Prefix=prefix,
|
||||
Marker=marker,
|
||||
MaxKeys=limit
|
||||
)
|
||||
|
||||
# 处理响应
|
||||
files = []
|
||||
if 'Contents' in response:
|
||||
for item in response['Contents']:
|
||||
key = item.get('Key', '')
|
||||
# 过滤出文件(忽略目录)
|
||||
if not key.endswith('/'):
|
||||
file_url = urljoin(self.domain, key)
|
||||
files.append({
|
||||
'key': key,
|
||||
'url': file_url,
|
||||
'size': item.get('Size', 0),
|
||||
'last_modified': item.get('LastModified', ''),
|
||||
'etag': item.get('ETag', '').strip('"')
|
||||
})
|
||||
|
||||
return {
|
||||
'files': files,
|
||||
'is_truncated': response.get('IsTruncated', False),
|
||||
'next_marker': response.get('NextMarker', ''),
|
||||
'common_prefixes': response.get('CommonPrefixes', [])
|
||||
}
|
||||
|
||||
except (CosServiceError, CosClientError) as e:
|
||||
logger.error(f"列出腾讯云COS文件失败: {str(e)}")
|
||||
raise Exception(f"列出文件失败: {str(e)}")
|
||||
1
app/utils/__init__.py
Normal file
1
app/utils/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
|
||||
83
app/utils/config.py
Normal file
83
app/utils/config.py
Normal file
@ -0,0 +1,83 @@
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 加载环境变量(如果直接从main导入,这里可能是冗余的,但为了安全起见)
|
||||
load_dotenv()
|
||||
|
||||
class Settings(BaseModel):
|
||||
"""应用程序配置类"""
|
||||
# DashScope配置
|
||||
dashscope_api_key: str = os.getenv("DASHSCOPE_API_KEY", "sk-caa199589f1c451aaac471fad2986e28")
|
||||
|
||||
# 服务器配置
|
||||
host: str = os.getenv("HOST", "0.0.0.0")
|
||||
port: int = int(os.getenv("PORT", "9001"))
|
||||
debug: bool = os.getenv("DEBUG", "False").lower() in ["true", "1", "t", "yes"]
|
||||
|
||||
# 腾讯云配置
|
||||
qcloud_secret_id: str = os.getenv("QCLOUD_SECRET_ID", "")
|
||||
qcloud_secret_key: str = os.getenv("QCLOUD_SECRET_KEY", "")
|
||||
qcloud_cos_region: str = os.getenv("QCLOUD_COS_REGION", "ap-guangzhou")
|
||||
qcloud_cos_bucket: str = os.getenv("QCLOUD_COS_BUCKET", "")
|
||||
qcloud_cos_domain: str = os.getenv("QCLOUD_COS_DOMAIN", "")
|
||||
|
||||
# 数据库配置
|
||||
db_host: str = os.getenv("DB_HOST", "localhost")
|
||||
db_port: int = int(os.getenv("DB_PORT", "3306"))
|
||||
db_user: str = os.getenv("DB_USER", "root")
|
||||
db_password: str = os.getenv("DB_PASSWORD", "password")
|
||||
db_name: str = os.getenv("DB_NAME", "ai_dressing")
|
||||
|
||||
# 数据库URL
|
||||
@property
|
||||
def database_url(self) -> str:
|
||||
"""
|
||||
构建SQLAlchemy数据库连接URL
|
||||
|
||||
Returns:
|
||||
str: 数据库连接URL
|
||||
"""
|
||||
return f"mysql+pymysql://{self.db_user}:{self.db_password}@{self.db_host}:{self.db_port}/{self.db_name}"
|
||||
|
||||
# 其他配置可以根据需要添加
|
||||
log_level: str = os.getenv("LOG_LEVEL", "INFO")
|
||||
|
||||
@lru_cache()
|
||||
def get_settings() -> Settings:
|
||||
"""
|
||||
获取应用程序配置
|
||||
使用lru_cache装饰器避免重复读取环境变量,提高性能
|
||||
|
||||
Returns:
|
||||
Settings: 配置对象
|
||||
"""
|
||||
return Settings()
|
||||
|
||||
def validate_api_key():
|
||||
"""
|
||||
验证API密钥是否已配置
|
||||
|
||||
Raises:
|
||||
ValueError: 如果API密钥未配置
|
||||
"""
|
||||
settings = get_settings()
|
||||
if not settings.dashscope_api_key:
|
||||
raise ValueError("DASHSCOPE_API_KEY未设置,请在.env文件中配置")
|
||||
return True
|
||||
|
||||
def validate_qcloud_config():
|
||||
"""
|
||||
验证腾讯云配置是否已设置
|
||||
|
||||
Raises:
|
||||
ValueError: 如果腾讯云配置未正确设置
|
||||
"""
|
||||
settings = get_settings()
|
||||
if not settings.qcloud_secret_id or not settings.qcloud_secret_key:
|
||||
raise ValueError("腾讯云SecretId和SecretKey未设置,请在.env文件中配置")
|
||||
if not settings.qcloud_cos_bucket:
|
||||
raise ValueError("腾讯云COS存储桶名称未设置,请在.env文件中配置")
|
||||
return True
|
||||
60
create_migration.py
Normal file
60
create_migration.py
Normal file
@ -0,0 +1,60 @@
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
from alembic.config import Config
|
||||
from alembic import command
|
||||
|
||||
def create_migration(message=None):
|
||||
"""
|
||||
创建数据库迁移
|
||||
|
||||
Args:
|
||||
message: 迁移消息,描述此次迁移的内容
|
||||
"""
|
||||
# 获取Alembic配置
|
||||
alembic_cfg = Config("alembic.ini")
|
||||
|
||||
# 创建迁移
|
||||
if message:
|
||||
command.revision(alembic_cfg, message=message, autogenerate=True)
|
||||
else:
|
||||
command.revision(alembic_cfg, message="自动生成的迁移", autogenerate=True)
|
||||
|
||||
def upgrade_database():
|
||||
"""更新数据库到最新版本"""
|
||||
alembic_cfg = Config("alembic.ini")
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
|
||||
def downgrade_database(revision):
|
||||
"""回滚数据库到指定版本"""
|
||||
alembic_cfg = Config("alembic.ini")
|
||||
command.downgrade(alembic_cfg, revision)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="数据库迁移工具")
|
||||
subparsers = parser.add_subparsers(dest="command", help="命令")
|
||||
|
||||
# 创建迁移命令
|
||||
create_parser = subparsers.add_parser("create", help="创建数据库迁移")
|
||||
create_parser.add_argument("-m", "--message", help="迁移描述消息")
|
||||
|
||||
# 更新数据库命令
|
||||
upgrade_parser = subparsers.add_parser("upgrade", help="更新数据库到最新版本")
|
||||
|
||||
# 回滚数据库命令
|
||||
downgrade_parser = subparsers.add_parser("downgrade", help="回滚数据库到指定版本")
|
||||
downgrade_parser.add_argument("revision", help="目标版本,例如:-1表示回滚一个版本")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.command == "create":
|
||||
create_migration(args.message)
|
||||
print("迁移文件已创建,请查看 alembic/versions 目录")
|
||||
elif args.command == "upgrade":
|
||||
upgrade_database()
|
||||
print("数据库已更新到最新版本")
|
||||
elif args.command == "downgrade":
|
||||
downgrade_database(args.revision)
|
||||
print(f"数据库已回滚到版本: {args.revision}")
|
||||
else:
|
||||
parser.print_help()
|
||||
12
requirements.txt
Normal file
12
requirements.txt
Normal file
@ -0,0 +1,12 @@
|
||||
fastapi==0.104.0
|
||||
uvicorn==0.23.2
|
||||
dashscope>=1.13.0
|
||||
python-dotenv==1.0.0
|
||||
pydantic==2.4.2
|
||||
httpx==0.25.0
|
||||
cos-python-sdk-v5==1.9.26
|
||||
qcloud-python-sts==3.1.4
|
||||
sqlalchemy==2.0.23
|
||||
pymysql==1.1.0
|
||||
cryptography==41.0.5
|
||||
alembic==1.12.1
|
||||
105
run.py
Normal file
105
run.py
Normal file
@ -0,0 +1,105 @@
|
||||
import uvicorn
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
from app.utils.config import get_settings, validate_api_key, validate_qcloud_config
|
||||
from app.database import engine
|
||||
|
||||
def check_dependencies():
|
||||
"""检查依赖项是否正确安装"""
|
||||
try:
|
||||
import dashscope
|
||||
print(f"已加载 DashScope 版本: {getattr(dashscope, '__version__', '未知')}")
|
||||
except ImportError as e:
|
||||
print(f"错误: 无法导入 dashscope 模块: {str(e)}")
|
||||
print("请确保已正确安装 dashscope: pip install -U dashscope")
|
||||
return False
|
||||
|
||||
try:
|
||||
from qcloud_cos import CosConfig
|
||||
print("已加载腾讯云 COS SDK")
|
||||
except ImportError as e:
|
||||
print(f"警告: 无法导入腾讯云 COS SDK: {str(e)}")
|
||||
print("腾讯云 COS 相关功能可能无法使用")
|
||||
print("请确保已正确安装: pip install -U cos-python-sdk-v5 qcloud-python-sts")
|
||||
|
||||
try:
|
||||
import sqlalchemy
|
||||
print(f"已加载 SQLAlchemy 版本: {sqlalchemy.__version__}")
|
||||
import pymysql
|
||||
print(f"已加载 PyMySQL 版本: {pymysql.__version__}")
|
||||
except ImportError as e:
|
||||
print(f"错误: 无法导入数据库模块: {str(e)}")
|
||||
print("请确保已正确安装: pip install -U sqlalchemy pymysql")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def check_database_connection():
|
||||
"""检查数据库连接"""
|
||||
try:
|
||||
settings = get_settings()
|
||||
|
||||
# 尝试连接数据库
|
||||
with engine.connect() as connection:
|
||||
print(f"成功连接到数据库: {settings.db_name}")
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"警告: 无法连接到数据库: {str(e)}")
|
||||
print("请确保MySQL服务已启动,并且数据库配置正确")
|
||||
print("您可以使用以下命令创建数据库:")
|
||||
settings = get_settings()
|
||||
print(f"CREATE DATABASE {settings.db_name} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;")
|
||||
print(f"GRANT ALL PRIVILEGES ON {settings.db_name}.* TO '{settings.db_user}'@'%' IDENTIFIED BY '{settings.db_password}';")
|
||||
print("FLUSH PRIVILEGES;")
|
||||
|
||||
# 询问是否继续
|
||||
if input("是否继续启动应用程序?(y/n): ").lower() not in ["y", "yes"]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
# 检查依赖项
|
||||
if not check_dependencies():
|
||||
sys.exit(1)
|
||||
|
||||
# 验证API密钥是否配置
|
||||
validate_api_key()
|
||||
|
||||
# 验证腾讯云配置
|
||||
try:
|
||||
validate_qcloud_config()
|
||||
print("腾讯云COS配置已验证")
|
||||
except ValueError as e:
|
||||
print(f"警告: {str(e)}")
|
||||
print("腾讯云COS相关功能可能无法正常使用,但应用程序将继续启动")
|
||||
|
||||
# 检查数据库连接
|
||||
if not check_database_connection():
|
||||
sys.exit(1)
|
||||
|
||||
# 获取配置
|
||||
settings = get_settings()
|
||||
|
||||
# 配置日志级别
|
||||
log_level = getattr(logging, settings.log_level.upper(), logging.INFO)
|
||||
logging.basicConfig(level=log_level)
|
||||
|
||||
# 启动服务器
|
||||
print(f"启动服务,访问地址: http://{settings.host}:{settings.port}")
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host=settings.host,
|
||||
port=settings.port,
|
||||
reload=settings.debug,
|
||||
log_level=settings.log_level.lower()
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"启动失败: {str(e)}")
|
||||
if hasattr(e, "__traceback__"):
|
||||
import traceback
|
||||
traceback.print_exception(type(e), e, e.__traceback__)
|
||||
exit(1)
|
||||
Loading…
Reference in New Issue
Block a user