first commit
This commit is contained in:
commit
7bd63eefc5
42
.gitignore
vendored
Normal file
42
.gitignore
vendored
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
# 环境变量
|
||||||
|
.env
|
||||||
|
|
||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.so
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
|
||||||
|
# 虚拟环境
|
||||||
|
venv/
|
||||||
|
env/
|
||||||
|
ENV/
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
|
||||||
|
# 日志
|
||||||
|
*.log
|
||||||
|
|
||||||
|
# 操作系统
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
259
README.md
Normal file
259
README.md
Normal file
@ -0,0 +1,259 @@
|
|||||||
|
# AI-Dressing API
|
||||||
|
|
||||||
|
基于 FastAPI 框架和阿里云 DashScope 接口的 AI 服务 API。
|
||||||
|
|
||||||
|
## 功能特点
|
||||||
|
|
||||||
|
- 集成阿里云 DashScope 的大模型 API,支持通义千问系列模型
|
||||||
|
- 提供图像生成 API,支持 Stable Diffusion XL、万相等模型
|
||||||
|
- 集成腾讯云 COS 对象存储服务,支持文件上传、下载等功能
|
||||||
|
- 使用 MySQL 数据库存储服装数据
|
||||||
|
- RESTful API 风格,支持异步处理
|
||||||
|
- 易于扩展的模块化设计
|
||||||
|
|
||||||
|
## 环境要求
|
||||||
|
|
||||||
|
- Python 3.8+
|
||||||
|
- 阿里云 DashScope API 密钥
|
||||||
|
- 腾讯云 API 密钥(用于 COS 对象存储)
|
||||||
|
- MySQL 5.7+ 数据库
|
||||||
|
|
||||||
|
## 数据库设置
|
||||||
|
|
||||||
|
在使用应用程序之前,您需要创建 MySQL 数据库。登录到 MySQL 并执行以下命令:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
CREATE DATABASE ai_dressing CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
||||||
|
CREATE USER 'ai_user'@'%' IDENTIFIED BY 'your_password';
|
||||||
|
GRANT ALL PRIVILEGES ON ai_dressing.* TO 'ai_user'@'%';
|
||||||
|
FLUSH PRIVILEGES;
|
||||||
|
```
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
### 1. 安装依赖
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 配置环境变量
|
||||||
|
|
||||||
|
复制 `.env.example` 文件为 `.env`,并填写您的 API 密钥:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp .env.example .env
|
||||||
|
```
|
||||||
|
|
||||||
|
然后编辑 `.env` 文件:
|
||||||
|
|
||||||
|
```
|
||||||
|
# DashScope API密钥
|
||||||
|
DASHSCOPE_API_KEY=your_api_key_here
|
||||||
|
|
||||||
|
# 腾讯云配置
|
||||||
|
QCLOUD_SECRET_ID=your_secret_id_here
|
||||||
|
QCLOUD_SECRET_KEY=your_secret_key_here
|
||||||
|
QCLOUD_COS_REGION=ap-guangzhou
|
||||||
|
QCLOUD_COS_BUCKET=your-bucket-name
|
||||||
|
QCLOUD_COS_DOMAIN=https://your-bucket-name.cos.ap-guangzhou.myqcloud.com
|
||||||
|
|
||||||
|
# 数据库配置
|
||||||
|
DB_HOST=localhost
|
||||||
|
DB_PORT=3306
|
||||||
|
DB_USER=ai_user
|
||||||
|
DB_PASSWORD=your_password
|
||||||
|
DB_NAME=ai_dressing
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 数据库迁移
|
||||||
|
|
||||||
|
首次运行或模型变更后,需要创建或更新数据库表:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 创建迁移(仅在模型变更时需要)
|
||||||
|
python create_migration.py create -m "初始化数据库"
|
||||||
|
|
||||||
|
# 应用迁移
|
||||||
|
python create_migration.py upgrade
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. 启动服务
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python run.py
|
||||||
|
```
|
||||||
|
|
||||||
|
服务将在 http://localhost:9001 启动,您可以访问 http://localhost:9001/docs 查看 API 文档。
|
||||||
|
|
||||||
|
## API 文档
|
||||||
|
|
||||||
|
启动服务后,访问以下地址查看自动生成的 API 文档:
|
||||||
|
|
||||||
|
- Swagger UI: http://localhost:9001/docs
|
||||||
|
- ReDoc: http://localhost:9001/redoc
|
||||||
|
|
||||||
|
## 主要 API 端点
|
||||||
|
|
||||||
|
### 大模型对话
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /api/dashscope/chat
|
||||||
|
```
|
||||||
|
|
||||||
|
### 图像生成
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /api/dashscope/generate-image
|
||||||
|
```
|
||||||
|
|
||||||
|
### 获取支持的模型列表
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /api/dashscope/models
|
||||||
|
```
|
||||||
|
|
||||||
|
### 文件上传到腾讯云 COS
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /api/qcloud/upload
|
||||||
|
```
|
||||||
|
|
||||||
|
### 从腾讯云 COS 获取文件列表
|
||||||
|
|
||||||
|
```
|
||||||
|
GET /api/qcloud/files
|
||||||
|
```
|
||||||
|
|
||||||
|
### 生成 COS 临时上传凭证
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /api/qcloud/sts-token
|
||||||
|
```
|
||||||
|
|
||||||
|
### 服装管理 API
|
||||||
|
|
||||||
|
```
|
||||||
|
# 创建服装
|
||||||
|
POST /api/dresses/
|
||||||
|
|
||||||
|
# 获取服装列表
|
||||||
|
GET /api/dresses/
|
||||||
|
|
||||||
|
# 获取单个服装
|
||||||
|
GET /api/dresses/{dress_id}
|
||||||
|
|
||||||
|
# 更新服装
|
||||||
|
PUT /api/dresses/{dress_id}
|
||||||
|
|
||||||
|
# 删除服装
|
||||||
|
DELETE /api/dresses/{dress_id}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 项目结构
|
||||||
|
|
||||||
|
```
|
||||||
|
.
|
||||||
|
├── alembic/ # 数据库迁移相关文件
|
||||||
|
│ ├── versions/ # 迁移版本文件
|
||||||
|
│ ├── env.py # 迁移环境配置
|
||||||
|
│ └── script.py.mako # 迁移脚本模板
|
||||||
|
├── app/ # 应用主目录
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── main.py # FastAPI 应用程序
|
||||||
|
│ ├── database/ # 数据库相关
|
||||||
|
│ │ └── __init__.py
|
||||||
|
│ ├── models/ # 数据库模型
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ └── dress.py
|
||||||
|
│ ├── routers/ # API 路由
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── dashscope_router.py
|
||||||
|
│ │ ├── qcloud_router.py
|
||||||
|
│ │ └── dress_router.py
|
||||||
|
│ ├── schemas/ # Pydantic 模型
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ └── dress.py
|
||||||
|
│ ├── services/ # 业务逻辑服务
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── dashscope_service.py
|
||||||
|
│ │ └── qcloud_service.py
|
||||||
|
│ └── utils/ # 工具类
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ └── config.py
|
||||||
|
├── .env.example # 环境变量示例
|
||||||
|
├── alembic.ini # Alembic 配置
|
||||||
|
├── create_migration.py # 数据库迁移工具
|
||||||
|
├── requirements.txt # 项目依赖
|
||||||
|
├── README.md # 项目文档
|
||||||
|
└── run.py # 应用入口
|
||||||
|
```
|
||||||
|
|
||||||
|
## 开发指南
|
||||||
|
|
||||||
|
### 添加新的路由
|
||||||
|
|
||||||
|
1. 在 `app/routers/` 目录下创建新的路由文件
|
||||||
|
2. 在 `app/main.py` 中注册新的路由
|
||||||
|
|
||||||
|
### 添加新的数据库模型
|
||||||
|
|
||||||
|
1. 在 `app/models/` 目录下创建新的模型文件
|
||||||
|
2. 在 `app/schemas/` 目录下创建对应的 Pydantic 模型
|
||||||
|
3. 创建数据库迁移: `python create_migration.py create -m "添加新模型"`
|
||||||
|
4. 应用迁移: `python create_migration.py upgrade`
|
||||||
|
|
||||||
|
### 添加新的服务
|
||||||
|
|
||||||
|
在 `app/services/` 目录下创建新的服务类。
|
||||||
|
|
||||||
|
## 腾讯云 COS 使用指南
|
||||||
|
|
||||||
|
### 前端直传文件
|
||||||
|
|
||||||
|
1. 首先从服务端获取临时上传凭证:
|
||||||
|
```javascript
|
||||||
|
const response = await fetch('/api/qcloud/sts-token', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/x-www-form-urlencoded',
|
||||||
|
},
|
||||||
|
body: 'allow_prefix=uploads/&duration_seconds=1800'
|
||||||
|
});
|
||||||
|
const stsData = await response.json();
|
||||||
|
```
|
||||||
|
|
||||||
|
2. 使用临时凭证进行文件上传:
|
||||||
|
```javascript
|
||||||
|
// 使用腾讯云COS SDK上传
|
||||||
|
const cos = new COS({
|
||||||
|
getAuthorization: function(options, callback) {
|
||||||
|
callback({
|
||||||
|
TmpSecretId: stsData.credentials.tmpSecretId,
|
||||||
|
TmpSecretKey: stsData.credentials.tmpSecretKey,
|
||||||
|
SecurityToken: stsData.credentials.sessionToken,
|
||||||
|
ExpiredTime: stsData.expiration
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
cos.putObject({
|
||||||
|
Bucket: stsData.bucket,
|
||||||
|
Region: stsData.region,
|
||||||
|
Key: 'uploads/example.jpg',
|
||||||
|
Body: file,
|
||||||
|
onProgress: function(progressData) {
|
||||||
|
console.log(progressData);
|
||||||
|
}
|
||||||
|
}, function(err, data) {
|
||||||
|
if (err) {
|
||||||
|
console.error('上传失败', err);
|
||||||
|
} else {
|
||||||
|
console.log('上传成功', data);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
```
|
||||||
|
|
||||||
|
## 许可证
|
||||||
|
|
||||||
|
MIT
|
||||||
99
alembic.ini
Normal file
99
alembic.ini
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
# A generic, single database configuration.
|
||||||
|
|
||||||
|
[alembic]
|
||||||
|
# path to migration scripts
|
||||||
|
script_location = alembic
|
||||||
|
|
||||||
|
# template used to generate migration files
|
||||||
|
# file_template = %%(rev)s_%%(slug)s
|
||||||
|
|
||||||
|
# sys.path path, will be prepended to sys.path if present.
|
||||||
|
# defaults to the current working directory.
|
||||||
|
prepend_sys_path = .
|
||||||
|
|
||||||
|
# timezone to use when rendering the date within the migration file
|
||||||
|
# as well as the filename.
|
||||||
|
# If specified, requires the python-dateutil library that can be
|
||||||
|
# installed by adding `python-dateutil` to the requirements.
|
||||||
|
# For example: timezone = UTC
|
||||||
|
# timezone =
|
||||||
|
|
||||||
|
# max length of characters to apply to the
|
||||||
|
# "slug" field
|
||||||
|
# truncate_slug_length = 40
|
||||||
|
|
||||||
|
# set to 'true' to run the environment during
|
||||||
|
# the 'revision' command, regardless of autogenerate
|
||||||
|
# revision_environment = false
|
||||||
|
|
||||||
|
# set to 'true' to allow .pyc and .pyo files without
|
||||||
|
# a source .py file to be detected as revisions in the
|
||||||
|
# versions/ directory
|
||||||
|
# sourceless = false
|
||||||
|
|
||||||
|
# version location specification; This defaults
|
||||||
|
# to alembic/versions. When using multiple version
|
||||||
|
# directories, initial revisions must be specified with --version-path.
|
||||||
|
# The path separator used here should be the separator specified by "version_path_separator"
|
||||||
|
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
|
||||||
|
|
||||||
|
# version path separator; As mentioned above, this is the character used to split
|
||||||
|
# version_locations. Valid values are:
|
||||||
|
#
|
||||||
|
# version_path_separator = :
|
||||||
|
# version_path_separator = ;
|
||||||
|
# version_path_separator = space
|
||||||
|
version_path_separator = os # default: use os.pathsep
|
||||||
|
|
||||||
|
# the output encoding used when revision files
|
||||||
|
# are written from script.py.mako
|
||||||
|
# output_encoding = utf-8
|
||||||
|
|
||||||
|
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||||
|
|
||||||
|
|
||||||
|
[post_write_hooks]
|
||||||
|
# post_write_hooks defines scripts or Python functions that are run
|
||||||
|
# on newly generated revision scripts. See the documentation for further
|
||||||
|
# detail and examples
|
||||||
|
|
||||||
|
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||||
|
# hooks = black
|
||||||
|
# black.type = console_scripts
|
||||||
|
# black.entrypoint = black
|
||||||
|
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||||
|
|
||||||
|
# Logging configuration
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARN
|
||||||
|
handlers = console
|
||||||
|
qualname =
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARN
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
87
alembic/env.py
Normal file
87
alembic/env.py
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
from sqlalchemy import engine_from_config
|
||||||
|
from sqlalchemy import pool
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
|
||||||
|
# 导入应用的数据库配置和模型
|
||||||
|
from app.database import Base
|
||||||
|
from app.utils.config import get_settings
|
||||||
|
from app.models.dress import Dress # 导入所有模型
|
||||||
|
from app.models.tryon import TryOn
|
||||||
|
|
||||||
|
# this is the Alembic Config object, which provides
|
||||||
|
# access to the values within the .ini file in use.
|
||||||
|
config = context.config
|
||||||
|
|
||||||
|
# 从应用程序配置中获取数据库URL
|
||||||
|
settings = get_settings()
|
||||||
|
config.set_main_option("sqlalchemy.url", settings.database_url)
|
||||||
|
|
||||||
|
# Interpret the config file for Python logging.
|
||||||
|
# This line sets up loggers basically.
|
||||||
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
|
# add your model's MetaData object here
|
||||||
|
# for 'autogenerate' support
|
||||||
|
# from myapp import mymodel
|
||||||
|
# target_metadata = mymodel.Base.metadata
|
||||||
|
target_metadata = Base.metadata
|
||||||
|
|
||||||
|
# other values from the config, defined by the needs of env.py,
|
||||||
|
# can be acquired:
|
||||||
|
# my_important_option = config.get_main_option("my_important_option")
|
||||||
|
# ... etc.
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline():
|
||||||
|
"""Run migrations in 'offline' mode.
|
||||||
|
|
||||||
|
This configures the context with just a URL
|
||||||
|
and not an Engine, though an Engine is acceptable
|
||||||
|
here as well. By skipping the Engine creation
|
||||||
|
we don't even need a DBAPI to be available.
|
||||||
|
|
||||||
|
Calls to context.execute() here emit the given string to the
|
||||||
|
script output.
|
||||||
|
|
||||||
|
"""
|
||||||
|
url = config.get_main_option("sqlalchemy.url")
|
||||||
|
context.configure(
|
||||||
|
url=url,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
literal_binds=True,
|
||||||
|
dialect_opts={"paramstyle": "named"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online():
|
||||||
|
"""Run migrations in 'online' mode.
|
||||||
|
|
||||||
|
In this scenario we need to create an Engine
|
||||||
|
and associate a connection with the context.
|
||||||
|
|
||||||
|
"""
|
||||||
|
connectable = engine_from_config(
|
||||||
|
config.get_section(config.config_ini_section),
|
||||||
|
prefix="sqlalchemy.",
|
||||||
|
poolclass=pool.NullPool,
|
||||||
|
)
|
||||||
|
|
||||||
|
with connectable.connect() as connection:
|
||||||
|
context.configure(
|
||||||
|
connection=connection, target_metadata=target_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
if context.is_offline_mode():
|
||||||
|
run_migrations_offline()
|
||||||
|
else:
|
||||||
|
run_migrations_online()
|
||||||
24
alembic/script.py.mako
Normal file
24
alembic/script.py.mako
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
"""${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = ${repr(up_revision)}
|
||||||
|
down_revision = ${repr(down_revision)}
|
||||||
|
branch_labels = ${repr(branch_labels)}
|
||||||
|
depends_on = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
42
alembic/versions/f166e37d7b53_添加试穿记录表.py
Normal file
42
alembic/versions/f166e37d7b53_添加试穿记录表.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
"""添加试穿记录表
|
||||||
|
|
||||||
|
Revision ID: f166e37d7b53
|
||||||
|
Revises:
|
||||||
|
Create Date: 2025-03-21 15:58:04.889147
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'f166e37d7b53'
|
||||||
|
down_revision = None
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('tryons',
|
||||||
|
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column('top_garment_url', sa.String(length=1024), nullable=True, comment='上衣图片URL'),
|
||||||
|
sa.Column('bottom_garment_url', sa.String(length=1024), nullable=True, comment='下衣图片URL'),
|
||||||
|
sa.Column('person_image_url', sa.String(length=1024), nullable=True, comment='人物图片URL'),
|
||||||
|
sa.Column('request_id', sa.String(length=255), nullable=True, comment='请求ID'),
|
||||||
|
sa.Column('task_id', sa.String(length=255), nullable=True, comment='任务ID'),
|
||||||
|
sa.Column('task_status', sa.String(length=50), nullable=True, comment='任务状态'),
|
||||||
|
sa.Column('completion_url', sa.String(length=1024), nullable=True, comment='生成图片URL'),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=True, comment='创建时间'),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), nullable=True, comment='更新时间'),
|
||||||
|
sa.PrimaryKeyConstraint('id')
|
||||||
|
)
|
||||||
|
op.create_index(op.f('ix_tryons_id'), 'tryons', ['id'], unique=False)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_index(op.f('ix_tryons_id'), table_name='tryons')
|
||||||
|
op.drop_table('tryons')
|
||||||
|
# ### end Alembic commands ###
|
||||||
1
app/__init__.py
Normal file
1
app/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
|
||||||
28
app/database/__init__.py
Normal file
28
app/database/__init__.py
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from app.utils.config import get_settings
|
||||||
|
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
# 创建数据库引擎
|
||||||
|
engine = create_engine(
|
||||||
|
settings.database_url,
|
||||||
|
pool_pre_ping=True, # 自动检测断开的连接并重新连接
|
||||||
|
pool_recycle=3600, # 每小时回收连接
|
||||||
|
echo=settings.debug, # 在调试模式下打印SQL语句
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建会话工厂
|
||||||
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
|
# 创建Base类,所有模型都将继承这个类
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
# 获取数据库会话的依赖函数
|
||||||
|
def get_db():
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
68
app/main.py
Normal file
68
app/main.py
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from app.routers import dashscope_router, qcloud_router, dress_router, tryon_router
|
||||||
|
from app.utils.config import get_settings
|
||||||
|
from app.database import Base, engine
|
||||||
|
|
||||||
|
# 创建数据库表
|
||||||
|
Base.metadata.create_all(bind=engine)
|
||||||
|
|
||||||
|
# 加载环境变量
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="AI-Dressing API",
|
||||||
|
description="基于 DashScope 的 AI 服务 API",
|
||||||
|
version="0.1.0",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加 CORS 中间件
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"], # 在生产环境中,应该指定确切的域名
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册路由
|
||||||
|
app.include_router(dashscope_router.router, prefix="/api/dashscope", tags=["DashScope"])
|
||||||
|
app.include_router(qcloud_router.router, prefix="/api/qcloud", tags=["腾讯云"])
|
||||||
|
app.include_router(dress_router.router, prefix="/api/dresses", tags=["服装"])
|
||||||
|
app.include_router(tryon_router.router, prefix="/api/tryons", tags=["试穿"])
|
||||||
|
|
||||||
|
@app.get("/", tags=["健康检查"])
|
||||||
|
async def health_check():
|
||||||
|
"""API 健康检查端点"""
|
||||||
|
return {"status": "正常", "message": "服务运行中"}
|
||||||
|
|
||||||
|
@app.get("/info", tags=["服务信息"])
|
||||||
|
async def get_info():
|
||||||
|
"""获取服务基本信息"""
|
||||||
|
settings = get_settings()
|
||||||
|
return {
|
||||||
|
"app_name": "AI-Dressing API",
|
||||||
|
"version": "0.1.0",
|
||||||
|
"debug_mode": settings.debug,
|
||||||
|
}
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
settings = get_settings()
|
||||||
|
uvicorn.run(
|
||||||
|
"app.main:app",
|
||||||
|
host=settings.host,
|
||||||
|
port=settings.port,
|
||||||
|
reload=settings.debug,
|
||||||
|
)
|
||||||
1
app/models/__init__.py
Normal file
1
app/models/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
|
||||||
27
app/models/dress.py
Normal file
27
app/models/dress.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
from sqlalchemy import Column, Integer, String, DateTime, Text, func, Enum
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from app.database import Base
|
||||||
|
import datetime
|
||||||
|
import enum
|
||||||
|
|
||||||
|
class GarmentType(enum.Enum):
|
||||||
|
"""服装类型枚举"""
|
||||||
|
TOP_GARMENT = "TOP_GARMENT"
|
||||||
|
BOTTOM_GARMENT = "BOTTOM_GARMENT"
|
||||||
|
|
||||||
|
class Dress(Base):
|
||||||
|
"""服装模型类"""
|
||||||
|
__tablename__ = "dresses"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
|
||||||
|
name = Column(String(255), nullable=False, comment="服装名称")
|
||||||
|
image_url = Column(String(1024), nullable=True, comment="服装图片URL")
|
||||||
|
garment_type = Column(Enum(GarmentType), nullable=True, comment="服装类型(上衣/下衣)")
|
||||||
|
|
||||||
|
# 以下是一些可选的附加字段,根据需要可以去掉或添加
|
||||||
|
description = Column(Text, nullable=True, comment="服装描述")
|
||||||
|
created_at = Column(DateTime, default=datetime.datetime.now(), comment="创建时间")
|
||||||
|
updated_at = Column(DateTime, default=datetime.datetime.now(), onupdate=datetime.datetime.now, comment="更新时间")
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<Dress(id={self.id}, name='{self.name}')>"
|
||||||
23
app/models/tryon.py
Normal file
23
app/models/tryon.py
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
from sqlalchemy import Column, Integer, String, DateTime, Text, func
|
||||||
|
from app.database import Base
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
class TryOn(Base):
|
||||||
|
"""试穿记录模型类"""
|
||||||
|
__tablename__ = "tryons"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
|
||||||
|
top_garment_url = Column(String(1024), nullable=True, comment="上衣图片URL")
|
||||||
|
bottom_garment_url = Column(String(1024), nullable=True, comment="下衣图片URL")
|
||||||
|
person_image_url = Column(String(1024), nullable=True, comment="人物图片URL")
|
||||||
|
|
||||||
|
request_id = Column(String(255), nullable=True, comment="请求ID")
|
||||||
|
task_id = Column(String(255), nullable=True, comment="任务ID")
|
||||||
|
task_status = Column(String(50), nullable=True, comment="任务状态")
|
||||||
|
completion_url = Column(String(1024), nullable=True, comment="生成图片URL")
|
||||||
|
|
||||||
|
created_at = Column(DateTime, default=datetime.datetime.now(), comment="创建时间")
|
||||||
|
updated_at = Column(DateTime, default=datetime.datetime.now(), onupdate=datetime.datetime.now, comment="更新时间")
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<TryOn(id={self.id}, request_id='{self.request_id}')>"
|
||||||
1
app/routers/__init__.py
Normal file
1
app/routers/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
|
||||||
262
app/routers/dashscope_router.py
Normal file
262
app/routers/dashscope_router.py
Normal file
@ -0,0 +1,262 @@
|
|||||||
|
from fastapi import APIRouter, HTTPException, Depends
|
||||||
|
from pydantic import BaseModel
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
import dashscope
|
||||||
|
from app.services.dashscope_service import DashScopeService
|
||||||
|
from app.utils.config import get_settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
class ChatMessage(BaseModel):
|
||||||
|
role: str # 'user' 或 'assistant'
|
||||||
|
content: str
|
||||||
|
|
||||||
|
class ChatRequest(BaseModel):
|
||||||
|
messages: List[ChatMessage]
|
||||||
|
model: Optional[str] = "qwen-max" # 默认使用通义千问MAX模型
|
||||||
|
max_tokens: Optional[int] = 2048
|
||||||
|
temperature: Optional[float] = 0.7
|
||||||
|
stream: Optional[bool] = False
|
||||||
|
|
||||||
|
class ChatResponse(BaseModel):
|
||||||
|
response: str
|
||||||
|
usage: Dict[str, Any]
|
||||||
|
request_id: str
|
||||||
|
|
||||||
|
@router.post("/chat", response_model=ChatResponse)
|
||||||
|
async def chat_completion(
|
||||||
|
request: ChatRequest,
|
||||||
|
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
调用DashScope的大模型进行对话
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
|
||||||
|
response = await dashscope_service.chat_completion(
|
||||||
|
messages=messages,
|
||||||
|
model=request.model,
|
||||||
|
max_tokens=request.max_tokens,
|
||||||
|
temperature=request.temperature,
|
||||||
|
stream=request.stream
|
||||||
|
)
|
||||||
|
|
||||||
|
# 确保我们能正确访问响应字段,根据dashscope的最新版本调整
|
||||||
|
response_text = ""
|
||||||
|
if hasattr(response, 'output') and hasattr(response.output, 'text'):
|
||||||
|
response_text = response.output.text
|
||||||
|
elif hasattr(response, 'output') and isinstance(response.output, dict) and 'text' in response.output:
|
||||||
|
response_text = response.output['text']
|
||||||
|
elif hasattr(response, 'choices') and len(response.choices) > 0:
|
||||||
|
response_text = response.choices[0].message.content
|
||||||
|
else:
|
||||||
|
logger.warning(f"无法解析DashScope响应: {response}")
|
||||||
|
# 尝试保守提取
|
||||||
|
try:
|
||||||
|
if hasattr(response, 'output'):
|
||||||
|
if isinstance(response.output, dict):
|
||||||
|
response_text = str(response.output.get('text', ''))
|
||||||
|
else:
|
||||||
|
response_text = str(response.output)
|
||||||
|
else:
|
||||||
|
response_text = str(response)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"提取响应文本失败: {str(e)}")
|
||||||
|
response_text = "无法获取模型响应文本"
|
||||||
|
|
||||||
|
# 同样确保我们能正确访问usage和request_id
|
||||||
|
usage = {}
|
||||||
|
if hasattr(response, 'usage'):
|
||||||
|
usage = response.usage if isinstance(response.usage, dict) else vars(response.usage)
|
||||||
|
|
||||||
|
request_id = ""
|
||||||
|
if hasattr(response, 'request_id'):
|
||||||
|
request_id = response.request_id
|
||||||
|
|
||||||
|
return {
|
||||||
|
"response": response_text,
|
||||||
|
"usage": usage,
|
||||||
|
"request_id": request_id
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"DashScope API 调用错误: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"模型调用失败: {str(e)}")
|
||||||
|
|
||||||
|
class ImageGenerationRequest(BaseModel):
|
||||||
|
prompt: str
|
||||||
|
negative_prompt: Optional[str] = None
|
||||||
|
model: Optional[str] = "stable-diffusion-xl"
|
||||||
|
n: Optional[int] = 1
|
||||||
|
size: Optional[str] = "1024*1024"
|
||||||
|
|
||||||
|
class ImageGenerationResponse(BaseModel):
|
||||||
|
images: List[str]
|
||||||
|
request_id: str
|
||||||
|
|
||||||
|
@router.post("/generate-image", response_model=ImageGenerationResponse)
|
||||||
|
async def generate_image(
|
||||||
|
request: ImageGenerationRequest,
|
||||||
|
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
调用DashScope的图像生成API
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = await dashscope_service.generate_image(
|
||||||
|
prompt=request.prompt,
|
||||||
|
negative_prompt=request.negative_prompt,
|
||||||
|
model=request.model,
|
||||||
|
n=request.n,
|
||||||
|
size=request.size
|
||||||
|
)
|
||||||
|
|
||||||
|
# 确保我们能正确访问images和request_id
|
||||||
|
images = []
|
||||||
|
if hasattr(response, 'output') and hasattr(response.output, 'images'):
|
||||||
|
images = response.output.images
|
||||||
|
elif hasattr(response, 'output') and isinstance(response.output, dict) and 'images' in response.output:
|
||||||
|
images = response.output['images']
|
||||||
|
else:
|
||||||
|
logger.warning(f"无法解析DashScope图像生成响应: {response}")
|
||||||
|
try:
|
||||||
|
if hasattr(response, 'output'):
|
||||||
|
if isinstance(response.output, dict):
|
||||||
|
images = response.output.get('images', [])
|
||||||
|
else:
|
||||||
|
images = []
|
||||||
|
else:
|
||||||
|
images = []
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"提取图像URL失败: {str(e)}")
|
||||||
|
images = []
|
||||||
|
|
||||||
|
request_id = ""
|
||||||
|
if hasattr(response, 'request_id'):
|
||||||
|
request_id = response.request_id
|
||||||
|
|
||||||
|
return {
|
||||||
|
"images": images,
|
||||||
|
"request_id": request_id
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"DashScope 图像生成API调用错误: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"图像生成失败: {str(e)}")
|
||||||
|
|
||||||
|
class TryOnRequest(BaseModel):
|
||||||
|
"""试穿请求模型"""
|
||||||
|
person_image_url: str
|
||||||
|
top_garment_url: Optional[str] = None
|
||||||
|
bottom_garment_url: Optional[str] = None
|
||||||
|
resolution: Optional[int] = -1
|
||||||
|
restore_face: Optional[bool] = True
|
||||||
|
|
||||||
|
class TaskInfo(BaseModel):
|
||||||
|
"""任务信息模型"""
|
||||||
|
task_id: str
|
||||||
|
task_status: str
|
||||||
|
|
||||||
|
class TryOnResponse(BaseModel):
|
||||||
|
"""试穿响应模型"""
|
||||||
|
output: TaskInfo
|
||||||
|
request_id: str
|
||||||
|
|
||||||
|
@router.post("/try-on", response_model=TryOnResponse)
|
||||||
|
async def generate_tryon(
|
||||||
|
request: TryOnRequest,
|
||||||
|
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
调用DashScope的人物试衣API
|
||||||
|
|
||||||
|
该API用于给人物图片试穿上衣和/或下衣,返回合成图片的任务ID
|
||||||
|
|
||||||
|
- **person_image_url**: 人物图片URL(必填)
|
||||||
|
- **top_garment_url**: 上衣图片URL(可选)
|
||||||
|
- **bottom_garment_url**: 下衣图片URL(可选)
|
||||||
|
- **resolution**: 分辨率,-1表示自动(可选)
|
||||||
|
- **restore_face**: 是否修复面部(可选)
|
||||||
|
|
||||||
|
注意:top_garment_url和bottom_garment_url至少需要提供一个
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 验证参数
|
||||||
|
if not request.top_garment_url and not request.bottom_garment_url:
|
||||||
|
raise HTTPException(status_code=400, detail="上衣和下衣图片至少需要提供一个")
|
||||||
|
|
||||||
|
response = await dashscope_service.generate_tryon(
|
||||||
|
person_image_url=request.person_image_url,
|
||||||
|
top_garment_url=request.top_garment_url,
|
||||||
|
bottom_garment_url=request.bottom_garment_url,
|
||||||
|
resolution=request.resolution,
|
||||||
|
restore_face=request.restore_face
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建响应
|
||||||
|
return {
|
||||||
|
"output": {
|
||||||
|
"task_id": response.get("output", {}).get("task_id", ""),
|
||||||
|
"task_status": response.get("output", {}).get("task_status", "")
|
||||||
|
},
|
||||||
|
"request_id": response.get("request_id", "")
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"DashScope 试穿API调用错误: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"试穿请求失败: {str(e)}")
|
||||||
|
|
||||||
|
class TaskStatusResponse(BaseModel):
|
||||||
|
"""任务状态响应模型"""
|
||||||
|
task_id: str
|
||||||
|
task_status: str
|
||||||
|
completion_url: Optional[str] = None
|
||||||
|
request_id: str
|
||||||
|
|
||||||
|
@router.get("/try-on/{task_id}", response_model=TaskStatusResponse)
|
||||||
|
async def check_tryon_status(
|
||||||
|
task_id: str,
|
||||||
|
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
查询试穿任务状态
|
||||||
|
|
||||||
|
- **task_id**: 任务ID
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
response = await dashscope_service.check_tryon_status(task_id)
|
||||||
|
|
||||||
|
# 构建响应
|
||||||
|
return {
|
||||||
|
"task_id": task_id,
|
||||||
|
"task_status": response.get("output", {}).get("task_status", ""),
|
||||||
|
"completion_url": response.get("output", {}).get("url", ""),
|
||||||
|
"request_id": response.get("request_id", "")
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"查询试穿任务状态错误: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"查询任务状态失败: {str(e)}")
|
||||||
|
|
||||||
|
@router.get("/models")
|
||||||
|
async def list_available_models():
|
||||||
|
"""
|
||||||
|
列出DashScope上可用的模型
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
models = {
|
||||||
|
"chat_models": [
|
||||||
|
{"id": "qwen-max", "name": "通义千问MAX"},
|
||||||
|
{"id": "qwen-plus", "name": "通义千问Plus"},
|
||||||
|
{"id": "qwen-turbo", "name": "通义千问Turbo"}
|
||||||
|
],
|
||||||
|
"image_models": [
|
||||||
|
{"id": "stable-diffusion-xl", "name": "Stable Diffusion XL"},
|
||||||
|
{"id": "wanx-v1", "name": "万相"}
|
||||||
|
],
|
||||||
|
"tryon_models": [
|
||||||
|
{"id": "aitryon", "name": "AI试穿"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取模型列表错误: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}")
|
||||||
162
app/routers/dress_router.py
Normal file
162
app/routers/dress_router.py
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
from fastapi import APIRouter, HTTPException, Depends, Query, Path
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import List, Optional
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from app.database import get_db
|
||||||
|
from app.models.dress import Dress, GarmentType
|
||||||
|
from app.schemas.dress import DressCreate, DressUpdate, DressResponse
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.post("/", response_model=DressResponse, status_code=201)
|
||||||
|
async def create_dress(
|
||||||
|
dress: DressCreate,
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
创建一个新的服装记录
|
||||||
|
|
||||||
|
- **name**: 服装名称(必填)
|
||||||
|
- **image_url**: 服装图片URL(可选)
|
||||||
|
- **garment_type**: 服装类型(TOP_GARMENT:上衣, BOTTOM_GARMENT:下衣)(可选)
|
||||||
|
- **description**: 服装描述(可选)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
db_dress = Dress(
|
||||||
|
name=dress.name,
|
||||||
|
image_url=dress.image_url,
|
||||||
|
garment_type=dress.garment_type,
|
||||||
|
description=dress.description
|
||||||
|
)
|
||||||
|
db.add(db_dress)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_dress)
|
||||||
|
return db_dress
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建服装记录失败: {str(e)}")
|
||||||
|
db.rollback()
|
||||||
|
raise HTTPException(status_code=500, detail=f"创建服装记录失败: {str(e)}")
|
||||||
|
|
||||||
|
@router.get("/", response_model=List[DressResponse])
|
||||||
|
async def get_dresses(
|
||||||
|
skip: int = Query(0, description="跳过的记录数量"),
|
||||||
|
limit: int = Query(100, description="返回的最大记录数量"),
|
||||||
|
name: Optional[str] = Query(None, description="按名称过滤"),
|
||||||
|
garment_type: Optional[str] = Query(None, description="按服装类型过滤(TOP_GARMENT:上衣, BOTTOM_GARMENT:下衣)"),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取服装列表,支持分页和过滤
|
||||||
|
|
||||||
|
- **skip**: 跳过的记录数量,用于分页
|
||||||
|
- **limit**: 返回的最大记录数量,用于分页
|
||||||
|
- **name**: 按名称过滤(可选)
|
||||||
|
- **garment_type**: 按服装类型过滤(可选)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
query = db.query(Dress)
|
||||||
|
|
||||||
|
if name:
|
||||||
|
query = query.filter(Dress.name.ilike(f"%{name}%"))
|
||||||
|
|
||||||
|
if garment_type:
|
||||||
|
try:
|
||||||
|
garment_type_enum = GarmentType[garment_type]
|
||||||
|
query = query.filter(Dress.garment_type == garment_type_enum)
|
||||||
|
except KeyError:
|
||||||
|
logger.warning(f"无效的服装类型: {garment_type}")
|
||||||
|
# 继续查询,但不应用无效的过滤条件
|
||||||
|
|
||||||
|
dresses = query.offset(skip).limit(limit).all()
|
||||||
|
return dresses
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取服装列表失败: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"获取服装列表失败: {str(e)}")
|
||||||
|
|
||||||
|
@router.get("/{dress_id}", response_model=DressResponse)
|
||||||
|
async def get_dress(
|
||||||
|
dress_id: int = Path(..., description="服装ID"),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
根据ID获取服装详情
|
||||||
|
|
||||||
|
- **dress_id**: 服装ID
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
dress = db.query(Dress).filter(Dress.id == dress_id).first()
|
||||||
|
if not dress:
|
||||||
|
raise HTTPException(status_code=404, detail=f"未找到ID为{dress_id}的服装")
|
||||||
|
return dress
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取服装详情失败: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"获取服装详情失败: {str(e)}")
|
||||||
|
|
||||||
|
@router.put("/{dress_id}", response_model=DressResponse)
|
||||||
|
async def update_dress(
|
||||||
|
dress_id: int = Path(..., description="服装ID"),
|
||||||
|
dress: DressUpdate = None,
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
更新服装信息
|
||||||
|
|
||||||
|
- **dress_id**: 服装ID
|
||||||
|
- **name**: 服装名称(可选)
|
||||||
|
- **image_url**: 服装图片URL(可选)
|
||||||
|
- **garment_type**: 服装类型(TOP_GARMENT:上衣, BOTTOM_GARMENT:下衣)(可选)
|
||||||
|
- **description**: 服装描述(可选)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
db_dress = db.query(Dress).filter(Dress.id == dress_id).first()
|
||||||
|
if not db_dress:
|
||||||
|
raise HTTPException(status_code=404, detail=f"未找到ID为{dress_id}的服装")
|
||||||
|
|
||||||
|
# 更新提供的字段
|
||||||
|
if dress.name is not None:
|
||||||
|
db_dress.name = dress.name
|
||||||
|
if dress.image_url is not None:
|
||||||
|
db_dress.image_url = dress.image_url
|
||||||
|
if dress.garment_type is not None:
|
||||||
|
db_dress.garment_type = dress.garment_type
|
||||||
|
if dress.description is not None:
|
||||||
|
db_dress.description = dress.description
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_dress)
|
||||||
|
return db_dress
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"更新服装信息失败: {str(e)}")
|
||||||
|
db.rollback()
|
||||||
|
raise HTTPException(status_code=500, detail=f"更新服装信息失败: {str(e)}")
|
||||||
|
|
||||||
|
@router.delete("/{dress_id}", status_code=204)
|
||||||
|
async def delete_dress(
|
||||||
|
dress_id: int = Path(..., description="服装ID"),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
删除服装
|
||||||
|
|
||||||
|
- **dress_id**: 服装ID
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
db_dress = db.query(Dress).filter(Dress.id == dress_id).first()
|
||||||
|
if not db_dress:
|
||||||
|
raise HTTPException(status_code=404, detail=f"未找到ID为{dress_id}的服装")
|
||||||
|
|
||||||
|
db.delete(db_dress)
|
||||||
|
db.commit()
|
||||||
|
return None
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除服装失败: {str(e)}")
|
||||||
|
db.rollback()
|
||||||
|
raise HTTPException(status_code=500, detail=f"删除服装失败: {str(e)}")
|
||||||
155
app/routers/qcloud_router.py
Normal file
155
app/routers/qcloud_router.py
Normal file
@ -0,0 +1,155 @@
|
|||||||
|
from fastapi import APIRouter, HTTPException, Depends, UploadFile, File, Form, Query
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
import logging
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
|
||||||
|
from app.services.qcloud_service import QCloudCOSService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
class STSTokenResponse(BaseModel):
|
||||||
|
"""STS临时凭证响应"""
|
||||||
|
credentials: Dict[str, str]
|
||||||
|
expiration: str
|
||||||
|
request_id: str
|
||||||
|
bucket: str
|
||||||
|
region: str
|
||||||
|
|
||||||
|
class FileInfo(BaseModel):
|
||||||
|
"""文件信息"""
|
||||||
|
key: str
|
||||||
|
url: str
|
||||||
|
size: int
|
||||||
|
last_modified: str
|
||||||
|
etag: str
|
||||||
|
|
||||||
|
class ListFilesResponse(BaseModel):
|
||||||
|
"""列出文件响应"""
|
||||||
|
files: List[FileInfo]
|
||||||
|
is_truncated: bool
|
||||||
|
next_marker: Optional[str] = None
|
||||||
|
common_prefixes: Optional[List[Dict[str, Any]]] = None
|
||||||
|
|
||||||
|
@router.post("/upload", tags=["腾讯云COS"])
|
||||||
|
async def upload_file(
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
directory: str = Form("uploads"),
|
||||||
|
qcloud_service: QCloudCOSService = Depends(lambda: QCloudCOSService())
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
上传文件到腾讯云COS
|
||||||
|
|
||||||
|
- **file**: 要上传的文件
|
||||||
|
- **directory**: 存储目录(默认为"uploads")
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 读取文件内容
|
||||||
|
file_content = await file.read()
|
||||||
|
|
||||||
|
# 上传文件
|
||||||
|
result = await qcloud_service.upload_file(
|
||||||
|
file_content=file_content,
|
||||||
|
file_name=file.filename,
|
||||||
|
directory=directory,
|
||||||
|
content_type=file.content_type
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"文件上传失败: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"文件上传失败: {str(e)}")
|
||||||
|
finally:
|
||||||
|
# 关闭文件
|
||||||
|
await file.close()
|
||||||
|
|
||||||
|
@router.delete("/files/{key:path}", tags=["腾讯云COS"])
|
||||||
|
async def delete_file(
|
||||||
|
key: str,
|
||||||
|
qcloud_service: QCloudCOSService = Depends(lambda: QCloudCOSService())
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
从腾讯云COS删除文件
|
||||||
|
|
||||||
|
- **key**: 文件的对象键(COS路径)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
success = await qcloud_service.delete_file(key=key)
|
||||||
|
if success:
|
||||||
|
return {"message": "文件删除成功", "key": key}
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=404, detail=f"文件删除失败,文件可能不存在")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"文件删除失败: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"文件删除失败: {str(e)}")
|
||||||
|
|
||||||
|
@router.get("/files/url/{key:path}", tags=["腾讯云COS"])
|
||||||
|
async def get_file_url(
|
||||||
|
key: str,
|
||||||
|
expires: int = Query(3600, description="URL有效期(秒)"),
|
||||||
|
qcloud_service: QCloudCOSService = Depends(lambda: QCloudCOSService())
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取COS文件的临时访问URL
|
||||||
|
|
||||||
|
- **key**: 文件的对象键(COS路径)
|
||||||
|
- **expires**: URL的有效期(秒),默认3600秒
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
url = await qcloud_service.get_file_url(key=key, expires=expires)
|
||||||
|
return {"url": url, "key": key, "expires": expires}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取文件URL失败: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"获取文件URL失败: {str(e)}")
|
||||||
|
|
||||||
|
@router.post("/sts-token", response_model=STSTokenResponse, tags=["腾讯云COS"])
|
||||||
|
async def generate_sts_token(
|
||||||
|
allow_prefix: str = Form("*", description="允许操作的对象前缀"),
|
||||||
|
duration_seconds: int = Form(1800, description="凭证有效期(秒)"),
|
||||||
|
qcloud_service: QCloudCOSService = Depends(lambda: QCloudCOSService())
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
生成腾讯云COS临时访问凭证(STS),用于前端直传
|
||||||
|
|
||||||
|
- **allow_prefix**: 允许操作的对象前缀,默认为'*'表示所有对象
|
||||||
|
- **duration_seconds**: 凭证有效期(秒),默认1800秒
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
sts_result = await qcloud_service.generate_cos_sts_token(
|
||||||
|
allow_prefix=allow_prefix,
|
||||||
|
duration_seconds=duration_seconds
|
||||||
|
)
|
||||||
|
# 添加桶和区域信息,方便前端使用
|
||||||
|
sts_result['bucket'] = qcloud_service.bucket
|
||||||
|
sts_result['region'] = qcloud_service.region
|
||||||
|
|
||||||
|
return sts_result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"生成STS临时凭证失败: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"生成STS临时凭证失败: {str(e)}")
|
||||||
|
|
||||||
|
@router.get("/files", response_model=ListFilesResponse, tags=["腾讯云COS"])
|
||||||
|
async def list_files(
|
||||||
|
directory: str = Query("", description="目录前缀"),
|
||||||
|
limit: int = Query(100, description="返回的最大文件数"),
|
||||||
|
marker: str = Query("", description="分页标记"),
|
||||||
|
qcloud_service: QCloudCOSService = Depends(lambda: QCloudCOSService())
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
列出COS中的文件
|
||||||
|
|
||||||
|
- **directory**: 目录前缀
|
||||||
|
- **limit**: 返回的最大文件数
|
||||||
|
- **marker**: 分页标记
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = await qcloud_service.list_files(
|
||||||
|
directory=directory,
|
||||||
|
limit=limit,
|
||||||
|
marker=marker
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"列出文件失败: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"列出文件失败: {str(e)}")
|
||||||
225
app/routers/tryon_router.py
Normal file
225
app/routers/tryon_router.py
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
from fastapi import APIRouter, HTTPException, Depends, Query, Path, BackgroundTasks
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import List, Optional
|
||||||
|
import httpx
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from app.database import get_db
|
||||||
|
from app.models.tryon import TryOn
|
||||||
|
from app.schemas.tryon import (
|
||||||
|
TryOnCreate, TryOnUpdate, TryOnResponse,
|
||||||
|
AiTryonRequest, AiTryonResponse, TaskInfo
|
||||||
|
)
|
||||||
|
from app.services.dashscope_service import DashScopeService
|
||||||
|
|
||||||
|
# 加载环境变量
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
# 从环境变量获取API密钥
|
||||||
|
DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY")
|
||||||
|
DASHSCOPE_API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2image/image-synthesis"
|
||||||
|
|
||||||
|
@router.post("/", response_model=TryOnResponse, status_code=201)
|
||||||
|
async def create_tryon(
|
||||||
|
tryon_data: TryOnCreate,
|
||||||
|
background_tasks: BackgroundTasks,
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
创建一个试穿记录并发送到阿里百炼平台
|
||||||
|
|
||||||
|
- **top_garment_url**: 上衣图片URL(可选)
|
||||||
|
- **bottom_garment_url**: 下衣图片URL(可选)
|
||||||
|
- **person_image_url**: 人物图片URL(必填)
|
||||||
|
|
||||||
|
注意:上衣和下衣至少需要提供一个
|
||||||
|
"""
|
||||||
|
if not tryon_data.top_garment_url and not tryon_data.bottom_garment_url:
|
||||||
|
raise HTTPException(status_code=400, detail="上衣和下衣图片至少需要提供一个")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 创建试穿记录
|
||||||
|
db_tryon = TryOn(
|
||||||
|
top_garment_url=tryon_data.top_garment_url,
|
||||||
|
bottom_garment_url=tryon_data.bottom_garment_url,
|
||||||
|
person_image_url=tryon_data.person_image_url
|
||||||
|
)
|
||||||
|
db.add(db_tryon)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_tryon)
|
||||||
|
|
||||||
|
# 在后台发送请求到阿里百炼平台
|
||||||
|
background_tasks.add_task(
|
||||||
|
send_tryon_request,
|
||||||
|
db=db,
|
||||||
|
tryon_id=db_tryon.id,
|
||||||
|
top_garment_url=tryon_data.top_garment_url,
|
||||||
|
bottom_garment_url=tryon_data.bottom_garment_url,
|
||||||
|
person_image_url=tryon_data.person_image_url
|
||||||
|
)
|
||||||
|
|
||||||
|
return db_tryon
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建试穿记录失败: {str(e)}")
|
||||||
|
db.rollback()
|
||||||
|
raise HTTPException(status_code=500, detail=f"创建试穿记录失败: {str(e)}")
|
||||||
|
|
||||||
|
async def send_tryon_request(
|
||||||
|
db: Session,
|
||||||
|
tryon_id: int,
|
||||||
|
top_garment_url: Optional[str],
|
||||||
|
bottom_garment_url: Optional[str],
|
||||||
|
person_image_url: str
|
||||||
|
):
|
||||||
|
"""发送试穿请求到阿里百炼平台"""
|
||||||
|
try:
|
||||||
|
# 创建DashScopeService实例
|
||||||
|
dashscope_service = DashScopeService()
|
||||||
|
|
||||||
|
# 调用服务发送试穿请求
|
||||||
|
response = await dashscope_service.generate_tryon(
|
||||||
|
person_image_url=person_image_url,
|
||||||
|
top_garment_url=top_garment_url,
|
||||||
|
bottom_garment_url=bottom_garment_url
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新数据库记录
|
||||||
|
db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
|
||||||
|
if db_tryon:
|
||||||
|
db_tryon.request_id = response.get("request_id")
|
||||||
|
db_tryon.task_id = response.get("output", {}).get("task_id")
|
||||||
|
db_tryon.task_status = response.get("output", {}).get("task_status")
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
logger.info(f"试穿请求发送成功,任务ID: {db_tryon.task_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送试穿请求异常: {str(e)}")
|
||||||
|
|
||||||
|
# 更新数据库记录状态
|
||||||
|
db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
|
||||||
|
if db_tryon:
|
||||||
|
db_tryon.task_status = "ERROR"
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
@router.get("/", response_model=List[TryOnResponse])
|
||||||
|
async def get_tryons(
|
||||||
|
skip: int = Query(0, description="跳过的记录数量"),
|
||||||
|
limit: int = Query(100, description="返回的最大记录数量"),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取试穿记录列表,支持分页
|
||||||
|
|
||||||
|
- **skip**: 跳过的记录数量,用于分页
|
||||||
|
- **limit**: 返回的最大记录数量,用于分页
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
tryons = db.query(TryOn).order_by(TryOn.created_at.desc()).offset(skip).limit(limit).all()
|
||||||
|
return tryons
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取试穿记录列表失败: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"获取试穿记录列表失败: {str(e)}")
|
||||||
|
|
||||||
|
@router.get("/{tryon_id}", response_model=TryOnResponse)
|
||||||
|
async def get_tryon(
|
||||||
|
tryon_id: int = Path(..., description="试穿记录ID"),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
根据ID获取试穿记录详情
|
||||||
|
|
||||||
|
- **tryon_id**: 试穿记录ID
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
|
||||||
|
if not tryon:
|
||||||
|
raise HTTPException(status_code=404, detail=f"未找到ID为{tryon_id}的试穿记录")
|
||||||
|
return tryon
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取试穿记录详情失败: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"获取试穿记录详情失败: {str(e)}")
|
||||||
|
|
||||||
|
@router.put("/{tryon_id}", response_model=TryOnResponse)
|
||||||
|
async def update_tryon(
|
||||||
|
tryon_id: int = Path(..., description="试穿记录ID"),
|
||||||
|
tryon_data: TryOnUpdate = None,
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
更新试穿记录信息
|
||||||
|
|
||||||
|
- **tryon_id**: 试穿记录ID
|
||||||
|
- **task_status**: 任务状态(可选)
|
||||||
|
- **completion_url**: 生成图片URL(可选)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
|
||||||
|
if not db_tryon:
|
||||||
|
raise HTTPException(status_code=404, detail=f"未找到ID为{tryon_id}的试穿记录")
|
||||||
|
|
||||||
|
# 更新提供的字段
|
||||||
|
if tryon_data.task_status is not None:
|
||||||
|
db_tryon.task_status = tryon_data.task_status
|
||||||
|
if tryon_data.completion_url is not None:
|
||||||
|
db_tryon.completion_url = tryon_data.completion_url
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_tryon)
|
||||||
|
return db_tryon
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"更新试穿记录失败: {str(e)}")
|
||||||
|
db.rollback()
|
||||||
|
raise HTTPException(status_code=500, detail=f"更新试穿记录失败: {str(e)}")
|
||||||
|
|
||||||
|
@router.post("/{tryon_id}/check", response_model=TryOnResponse)
|
||||||
|
async def check_tryon_status(
|
||||||
|
tryon_id: int = Path(..., description="试穿记录ID"),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
dashscope_service: DashScopeService = Depends(lambda: DashScopeService())
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
检查试穿任务状态
|
||||||
|
|
||||||
|
- **tryon_id**: 试穿记录ID
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
|
||||||
|
if not db_tryon:
|
||||||
|
raise HTTPException(status_code=404, detail=f"未找到ID为{tryon_id}的试穿记录")
|
||||||
|
|
||||||
|
if not db_tryon.task_id:
|
||||||
|
raise HTTPException(status_code=400, detail=f"试穿记录未包含任务ID")
|
||||||
|
|
||||||
|
# 调用DashScopeService检查任务状态
|
||||||
|
try:
|
||||||
|
status_response = await dashscope_service.check_tryon_status(db_tryon.task_id)
|
||||||
|
|
||||||
|
# 更新数据库记录
|
||||||
|
db_tryon.task_status = status_response.get("output", {}).get("task_status")
|
||||||
|
|
||||||
|
# 如果任务完成,保存结果URL
|
||||||
|
if db_tryon.task_status == "SUCCEEDED":
|
||||||
|
db_tryon.completion_url = status_response.get("output", {}).get("url")
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_tryon)
|
||||||
|
|
||||||
|
logger.info(f"试穿任务状态更新: {db_tryon.task_status}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"调用DashScope API检查任务状态失败: {str(e)}")
|
||||||
|
|
||||||
|
return db_tryon
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"检查试穿任务状态失败: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"检查试穿任务状态失败: {str(e)}")
|
||||||
1
app/schemas/__init__.py
Normal file
1
app/schemas/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
|
||||||
36
app/schemas/dress.py
Normal file
36
app/schemas/dress.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
from pydantic import BaseModel, Field, HttpUrl
|
||||||
|
from typing import Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from app.models.dress import GarmentType
|
||||||
|
|
||||||
|
class DressBase(BaseModel):
|
||||||
|
"""服装基础模型"""
|
||||||
|
name: str = Field(..., description="服装名称", example="夏季连衣裙")
|
||||||
|
image_url: Optional[str] = Field(None, description="服装图片URL", example="https://example.com/dress1.jpg")
|
||||||
|
garment_type: Optional[GarmentType] = Field(None, description="服装类型(上衣/下衣)", example="TOP_GARMENT")
|
||||||
|
description: Optional[str] = Field(None, description="服装描述", example="一款适合夏季穿着的轻薄连衣裙")
|
||||||
|
|
||||||
|
class DressCreate(DressBase):
|
||||||
|
"""创建服装的请求模型"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class DressUpdate(BaseModel):
|
||||||
|
"""更新服装的请求模型"""
|
||||||
|
name: Optional[str] = Field(None, description="服装名称", example="夏季连衣裙")
|
||||||
|
image_url: Optional[str] = Field(None, description="服装图片URL", example="https://example.com/dress1.jpg")
|
||||||
|
garment_type: Optional[GarmentType] = Field(None, description="服装类型(上衣/下衣)", example="TOP_GARMENT")
|
||||||
|
description: Optional[str] = Field(None, description="服装描述", example="一款适合夏季穿着的轻薄连衣裙")
|
||||||
|
|
||||||
|
class DressInDB(DressBase):
|
||||||
|
"""数据库中的服装模型"""
|
||||||
|
id: int
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
orm_mode = True
|
||||||
|
|
||||||
|
class DressResponse(DressInDB):
|
||||||
|
"""服装API响应模型"""
|
||||||
|
pass
|
||||||
54
app/schemas/tryon.py
Normal file
54
app/schemas/tryon.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Optional
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
class TryOnBase(BaseModel):
|
||||||
|
"""试穿基础模型"""
|
||||||
|
top_garment_url: Optional[str] = Field(None, description="上衣图片URL", example="https://example.com/top.jpg")
|
||||||
|
bottom_garment_url: Optional[str] = Field(None, description="下衣图片URL", example="https://example.com/bottom.jpg")
|
||||||
|
person_image_url: str = Field(..., description="人物图片URL", example="https://example.com/person.jpg")
|
||||||
|
|
||||||
|
class TryOnCreate(TryOnBase):
|
||||||
|
"""创建试穿记录的请求模型"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class TryOnUpdate(BaseModel):
|
||||||
|
"""更新试穿记录的请求模型"""
|
||||||
|
task_status: Optional[str] = Field(None, description="任务状态", example="SUCCEEDED")
|
||||||
|
completion_url: Optional[str] = Field(None, description="生成图片URL", example="https://example.com/result.jpg")
|
||||||
|
|
||||||
|
class TryOnInDB(TryOnBase):
|
||||||
|
"""数据库中的试穿记录模型"""
|
||||||
|
id: int
|
||||||
|
request_id: Optional[str] = None
|
||||||
|
task_id: Optional[str] = None
|
||||||
|
task_status: Optional[str] = None
|
||||||
|
completion_url: Optional[str] = None
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
orm_mode = True
|
||||||
|
|
||||||
|
class TryOnResponse(TryOnInDB):
|
||||||
|
"""试穿记录API响应模型"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class AiTryonRequest(BaseModel):
|
||||||
|
"""阿里百炼平台试穿请求模型"""
|
||||||
|
model: str = "aitryon"
|
||||||
|
input: TryOnBase
|
||||||
|
parameters: dict = {
|
||||||
|
"resolution": -1,
|
||||||
|
"restore_face": True
|
||||||
|
}
|
||||||
|
|
||||||
|
class TaskInfo(BaseModel):
|
||||||
|
"""任务信息模型"""
|
||||||
|
task_id: str
|
||||||
|
task_status: str
|
||||||
|
|
||||||
|
class AiTryonResponse(BaseModel):
|
||||||
|
"""阿里百炼平台试穿响应模型"""
|
||||||
|
output: TaskInfo
|
||||||
|
request_id: str
|
||||||
1
app/services/__init__.py
Normal file
1
app/services/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
|
||||||
230
app/services/dashscope_service.py
Normal file
230
app/services/dashscope_service.py
Normal file
@ -0,0 +1,230 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import dashscope
|
||||||
|
from dashscope import Generation
|
||||||
|
# 修改导入语句,dashscope的API响应可能改变了结构
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
import asyncio
|
||||||
|
import httpx
|
||||||
|
from app.utils.config import get_settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class DashScopeService:
|
||||||
|
"""DashScope服务类,提供对DashScope API的调用封装"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
settings = get_settings()
|
||||||
|
self.api_key = settings.dashscope_api_key
|
||||||
|
# 配置DashScope
|
||||||
|
dashscope.api_key = self.api_key
|
||||||
|
# 配置API URL
|
||||||
|
self.image_synthesis_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2image/image-synthesis"
|
||||||
|
|
||||||
|
async def chat_completion(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
model: str = "qwen-max",
|
||||||
|
max_tokens: int = 2048,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
stream: bool = False
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
调用DashScope的大模型API进行对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: 对话历史记录
|
||||||
|
model: 模型名称
|
||||||
|
max_tokens: 最大生成token数
|
||||||
|
temperature: 温度参数,控制随机性
|
||||||
|
stream: 是否流式输出
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse: DashScope的API响应
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 为了不阻塞FastAPI的异步性能,我们使用run_in_executor运行同步API
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
response = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: Generation.call(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
result_format='message',
|
||||||
|
stream=stream,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"DashScope API请求失败,状态码:{response.status_code}, 错误信息:{response.message}")
|
||||||
|
raise Exception(f"API调用失败: {response.message}")
|
||||||
|
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"DashScope聊天API调用出错: {str(e)}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def generate_image(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: Optional[str] = None,
|
||||||
|
model: str = "stable-diffusion-xl",
|
||||||
|
n: int = 1,
|
||||||
|
size: str = "1024*1024"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
调用DashScope的图像生成API
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: 生成图像的文本描述
|
||||||
|
negative_prompt: 负面提示词
|
||||||
|
model: 模型名称
|
||||||
|
n: 生成图像数量
|
||||||
|
size: 图像尺寸
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse: DashScope的API响应
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 构建请求参数
|
||||||
|
params = {
|
||||||
|
"model": model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"n": n,
|
||||||
|
"size": size,
|
||||||
|
}
|
||||||
|
|
||||||
|
if negative_prompt:
|
||||||
|
params["negative_prompt"] = negative_prompt
|
||||||
|
|
||||||
|
# 异步调用图像生成API
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
response = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: dashscope.ImageSynthesis.call(**params)
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"DashScope 图像生成API请求失败,状态码:{response.status_code}, 错误信息:{response.message}")
|
||||||
|
raise Exception(f"图像生成API调用失败: {response.message}")
|
||||||
|
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"DashScope图像生成API调用出错: {str(e)}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def generate_tryon(
|
||||||
|
self,
|
||||||
|
person_image_url: str,
|
||||||
|
top_garment_url: Optional[str] = None,
|
||||||
|
bottom_garment_url: Optional[str] = None,
|
||||||
|
resolution: int = -1,
|
||||||
|
restore_face: bool = True
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
调用阿里百炼平台的试衣服务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
person_image_url: 人物图片URL
|
||||||
|
top_garment_url: 上衣图片URL
|
||||||
|
bottom_garment_url: 下衣图片URL
|
||||||
|
resolution: 分辨率,-1表示自动
|
||||||
|
restore_face: 是否修复面部
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: 包含任务ID和请求ID的响应
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 验证参数
|
||||||
|
if not person_image_url:
|
||||||
|
raise ValueError("人物图片URL不能为空")
|
||||||
|
|
||||||
|
if not top_garment_url and not bottom_garment_url:
|
||||||
|
raise ValueError("上衣和下衣图片至少需要提供一个")
|
||||||
|
|
||||||
|
# 构建请求数据
|
||||||
|
request_data = {
|
||||||
|
"model": "aitryon",
|
||||||
|
"input": {
|
||||||
|
"person_image_url": person_image_url
|
||||||
|
},
|
||||||
|
"parameters": {
|
||||||
|
"resolution": resolution,
|
||||||
|
"restore_face": restore_face
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# 添加可选字段
|
||||||
|
if top_garment_url:
|
||||||
|
request_data["input"]["top_garment_url"] = top_garment_url
|
||||||
|
if bottom_garment_url:
|
||||||
|
request_data["input"]["bottom_garment_url"] = bottom_garment_url
|
||||||
|
|
||||||
|
# 构建请求头
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送请求
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
self.image_synthesis_url,
|
||||||
|
json=request_data,
|
||||||
|
headers=headers,
|
||||||
|
timeout=30.0
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
response_data = response.json()
|
||||||
|
logger.info(f"试穿请求发送成功,任务ID: {response_data.get('output', {}).get('task_id')}")
|
||||||
|
return response_data
|
||||||
|
else:
|
||||||
|
error_msg = f"试穿请求失败: {response.status_code} - {response.text}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
raise Exception(error_msg)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"DashScope试穿API调用出错: {str(e)}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def check_tryon_status(self, task_id: str):
|
||||||
|
"""
|
||||||
|
检查试穿任务状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: 任务ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: 任务状态信息
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 构建请求头
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建请求URL
|
||||||
|
status_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}"
|
||||||
|
|
||||||
|
# 发送请求
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(
|
||||||
|
status_url,
|
||||||
|
headers=headers,
|
||||||
|
timeout=30.0
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
response_data = response.json()
|
||||||
|
logger.info(f"试穿任务状态查询成功: {response_data.get('output', {}).get('task_status')}")
|
||||||
|
return response_data
|
||||||
|
else:
|
||||||
|
error_msg = f"试穿任务状态查询失败: {response.status_code} - {response.text}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
raise Exception(error_msg)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"查询试穿任务状态出错: {str(e)}")
|
||||||
|
raise e
|
||||||
259
app/services/qcloud_service.py
Normal file
259
app/services/qcloud_service.py
Normal file
@ -0,0 +1,259 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
from typing import Optional, List, Dict, Any, BinaryIO, Union
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
|
from qcloud_cos import CosConfig, CosS3Client
|
||||||
|
from qcloud_cos.cos_exception import CosServiceError, CosClientError
|
||||||
|
import sts.sts
|
||||||
|
|
||||||
|
from app.utils.config import get_settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class QCloudCOSService:
|
||||||
|
"""腾讯云对象存储(COS)服务类"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""初始化腾讯云COS服务"""
|
||||||
|
settings = get_settings()
|
||||||
|
self.secret_id = settings.qcloud_secret_id
|
||||||
|
self.secret_key = settings.qcloud_secret_key
|
||||||
|
self.region = settings.qcloud_cos_region
|
||||||
|
self.bucket = settings.qcloud_cos_bucket
|
||||||
|
self.domain = settings.qcloud_cos_domain
|
||||||
|
|
||||||
|
# 创建COS客户端
|
||||||
|
config = CosConfig(
|
||||||
|
Region=self.region,
|
||||||
|
SecretId=self.secret_id,
|
||||||
|
SecretKey=self.secret_key
|
||||||
|
)
|
||||||
|
self.client = CosS3Client(config)
|
||||||
|
|
||||||
|
async def upload_file(
|
||||||
|
self,
|
||||||
|
file_content: Union[bytes, BinaryIO],
|
||||||
|
file_name: Optional[str] = None,
|
||||||
|
directory: str = "uploads",
|
||||||
|
content_type: Optional[str] = None
|
||||||
|
) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
上传文件到腾讯云COS
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_content: 文件内容(bytes或文件对象)
|
||||||
|
file_name: 文件名,如果不提供则生成随机文件名
|
||||||
|
directory: 存储目录
|
||||||
|
content_type: 文件MIME类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, str]: 包含文件URL等信息的字典
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 生成文件名(如果未提供)
|
||||||
|
if not file_name:
|
||||||
|
# 使用UUID生成随机文件名
|
||||||
|
random_filename = str(uuid.uuid4())
|
||||||
|
if isinstance(file_content, bytes):
|
||||||
|
# 对于字节流,我们无法确定文件扩展名,默认用.bin
|
||||||
|
file_name = f"{random_filename}.bin"
|
||||||
|
else:
|
||||||
|
# 假设file_content是文件对象,尝试从其name属性获取扩展名
|
||||||
|
if hasattr(file_content, 'name'):
|
||||||
|
original_name = os.path.basename(file_content.name)
|
||||||
|
ext = os.path.splitext(original_name)[1]
|
||||||
|
file_name = f"{random_filename}{ext}"
|
||||||
|
else:
|
||||||
|
file_name = f"{random_filename}.bin"
|
||||||
|
|
||||||
|
# 构建对象键(文件在COS中的完整路径)
|
||||||
|
key = f"{directory}/{file_name}"
|
||||||
|
|
||||||
|
# 上传文件
|
||||||
|
response = self.client.put_object(
|
||||||
|
Bucket=self.bucket,
|
||||||
|
Body=file_content,
|
||||||
|
Key=key,
|
||||||
|
ContentType=content_type
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建文件URL
|
||||||
|
file_url = urljoin(self.domain, key)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"url": file_url,
|
||||||
|
"key": key,
|
||||||
|
"file_name": file_name,
|
||||||
|
"content_type": content_type,
|
||||||
|
"etag": response["ETag"] if "ETag" in response else None
|
||||||
|
}
|
||||||
|
|
||||||
|
except (CosServiceError, CosClientError) as e:
|
||||||
|
logger.error(f"腾讯云COS上传文件失败: {str(e)}")
|
||||||
|
raise Exception(f"文件上传失败: {str(e)}")
|
||||||
|
|
||||||
|
async def delete_file(self, key: str) -> bool:
|
||||||
|
"""
|
||||||
|
从腾讯云COS删除文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: 文件的对象键(COS路径)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 删除是否成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.client.delete_object(
|
||||||
|
Bucket=self.bucket,
|
||||||
|
Key=key
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except (CosServiceError, CosClientError) as e:
|
||||||
|
logger.error(f"腾讯云COS删除文件失败: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def get_file_url(self, key: str, expires: int = 3600) -> str:
|
||||||
|
"""
|
||||||
|
获取COS文件的临时访问URL
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: 文件的对象键(COS路径)
|
||||||
|
expires: URL的有效期(秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 临时访问URL
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
url = self.client.get_presigned_url(
|
||||||
|
Method='GET',
|
||||||
|
Bucket=self.bucket,
|
||||||
|
Key=key,
|
||||||
|
Expired=expires
|
||||||
|
)
|
||||||
|
return url
|
||||||
|
except (CosServiceError, CosClientError) as e:
|
||||||
|
logger.error(f"获取腾讯云COS文件URL失败: {str(e)}")
|
||||||
|
raise Exception(f"获取文件URL失败: {str(e)}")
|
||||||
|
|
||||||
|
async def generate_cos_sts_token(
|
||||||
|
self,
|
||||||
|
allow_actions: Optional[List[str]] = None,
|
||||||
|
allow_prefix: str = "*",
|
||||||
|
duration_seconds: int = 1800
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
生成COS的临时安全凭证(STS),用于前端直传
|
||||||
|
|
||||||
|
Args:
|
||||||
|
allow_actions: 允许的COS操作列表
|
||||||
|
allow_prefix: 允许操作的对象前缀
|
||||||
|
duration_seconds: 凭证有效期(秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: STS凭证信息
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if allow_actions is None:
|
||||||
|
# 默认只允许上传操作
|
||||||
|
allow_actions = [
|
||||||
|
'name/cos:PutObject',
|
||||||
|
'name/cos:PostObject',
|
||||||
|
'name/cos:InitiateMultipartUpload',
|
||||||
|
'name/cos:ListMultipartUploads',
|
||||||
|
'name/cos:ListParts',
|
||||||
|
'name/cos:UploadPart',
|
||||||
|
'name/cos:CompleteMultipartUpload'
|
||||||
|
]
|
||||||
|
|
||||||
|
# 配置STS
|
||||||
|
config = {
|
||||||
|
'url': 'https://sts.tencentcloudapi.com/',
|
||||||
|
'domain': 'sts.tencentcloudapi.com',
|
||||||
|
'duration_seconds': duration_seconds,
|
||||||
|
'secret_id': self.secret_id,
|
||||||
|
'secret_key': self.secret_key,
|
||||||
|
'region': self.region,
|
||||||
|
'policy': {
|
||||||
|
'version': '2.0',
|
||||||
|
'statement': [
|
||||||
|
{
|
||||||
|
'action': allow_actions,
|
||||||
|
'effect': 'allow',
|
||||||
|
'resource': [
|
||||||
|
f'qcs::cos:{self.region}:uid/:{self.bucket}/{allow_prefix}'
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sts_client = sts.sts.Sts(config)
|
||||||
|
response = sts_client.get_credential()
|
||||||
|
return response
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"生成腾讯云COS STS凭证失败: {str(e)}")
|
||||||
|
raise Exception(f"生成COS临时凭证失败: {str(e)}")
|
||||||
|
|
||||||
|
async def list_files(
|
||||||
|
self,
|
||||||
|
directory: str = "",
|
||||||
|
limit: int = 100,
|
||||||
|
marker: str = ""
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
列出COS中的文件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
directory: 目录前缀
|
||||||
|
limit: 返回的最大文件数
|
||||||
|
marker: 分页标记
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: 文件列表信息
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 处理目录前缀
|
||||||
|
prefix = directory
|
||||||
|
if prefix and not prefix.endswith('/'):
|
||||||
|
prefix += '/'
|
||||||
|
|
||||||
|
# 调用COS API列出对象
|
||||||
|
response = self.client.list_objects(
|
||||||
|
Bucket=self.bucket,
|
||||||
|
Prefix=prefix,
|
||||||
|
Marker=marker,
|
||||||
|
MaxKeys=limit
|
||||||
|
)
|
||||||
|
|
||||||
|
# 处理响应
|
||||||
|
files = []
|
||||||
|
if 'Contents' in response:
|
||||||
|
for item in response['Contents']:
|
||||||
|
key = item.get('Key', '')
|
||||||
|
# 过滤出文件(忽略目录)
|
||||||
|
if not key.endswith('/'):
|
||||||
|
file_url = urljoin(self.domain, key)
|
||||||
|
files.append({
|
||||||
|
'key': key,
|
||||||
|
'url': file_url,
|
||||||
|
'size': item.get('Size', 0),
|
||||||
|
'last_modified': item.get('LastModified', ''),
|
||||||
|
'etag': item.get('ETag', '').strip('"')
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
'files': files,
|
||||||
|
'is_truncated': response.get('IsTruncated', False),
|
||||||
|
'next_marker': response.get('NextMarker', ''),
|
||||||
|
'common_prefixes': response.get('CommonPrefixes', [])
|
||||||
|
}
|
||||||
|
|
||||||
|
except (CosServiceError, CosClientError) as e:
|
||||||
|
logger.error(f"列出腾讯云COS文件失败: {str(e)}")
|
||||||
|
raise Exception(f"列出文件失败: {str(e)}")
|
||||||
1
app/utils/__init__.py
Normal file
1
app/utils/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
|
||||||
83
app/utils/config.py
Normal file
83
app/utils/config.py
Normal file
@ -0,0 +1,83 @@
|
|||||||
|
import os
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Optional
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# 加载环境变量(如果直接从main导入,这里可能是冗余的,但为了安全起见)
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
class Settings(BaseModel):
|
||||||
|
"""应用程序配置类"""
|
||||||
|
# DashScope配置
|
||||||
|
dashscope_api_key: str = os.getenv("DASHSCOPE_API_KEY", "sk-caa199589f1c451aaac471fad2986e28")
|
||||||
|
|
||||||
|
# 服务器配置
|
||||||
|
host: str = os.getenv("HOST", "0.0.0.0")
|
||||||
|
port: int = int(os.getenv("PORT", "9001"))
|
||||||
|
debug: bool = os.getenv("DEBUG", "False").lower() in ["true", "1", "t", "yes"]
|
||||||
|
|
||||||
|
# 腾讯云配置
|
||||||
|
qcloud_secret_id: str = os.getenv("QCLOUD_SECRET_ID", "")
|
||||||
|
qcloud_secret_key: str = os.getenv("QCLOUD_SECRET_KEY", "")
|
||||||
|
qcloud_cos_region: str = os.getenv("QCLOUD_COS_REGION", "ap-guangzhou")
|
||||||
|
qcloud_cos_bucket: str = os.getenv("QCLOUD_COS_BUCKET", "")
|
||||||
|
qcloud_cos_domain: str = os.getenv("QCLOUD_COS_DOMAIN", "")
|
||||||
|
|
||||||
|
# 数据库配置
|
||||||
|
db_host: str = os.getenv("DB_HOST", "localhost")
|
||||||
|
db_port: int = int(os.getenv("DB_PORT", "3306"))
|
||||||
|
db_user: str = os.getenv("DB_USER", "root")
|
||||||
|
db_password: str = os.getenv("DB_PASSWORD", "password")
|
||||||
|
db_name: str = os.getenv("DB_NAME", "ai_dressing")
|
||||||
|
|
||||||
|
# 数据库URL
|
||||||
|
@property
|
||||||
|
def database_url(self) -> str:
|
||||||
|
"""
|
||||||
|
构建SQLAlchemy数据库连接URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 数据库连接URL
|
||||||
|
"""
|
||||||
|
return f"mysql+pymysql://{self.db_user}:{self.db_password}@{self.db_host}:{self.db_port}/{self.db_name}"
|
||||||
|
|
||||||
|
# 其他配置可以根据需要添加
|
||||||
|
log_level: str = os.getenv("LOG_LEVEL", "INFO")
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
"""
|
||||||
|
获取应用程序配置
|
||||||
|
使用lru_cache装饰器避免重复读取环境变量,提高性能
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Settings: 配置对象
|
||||||
|
"""
|
||||||
|
return Settings()
|
||||||
|
|
||||||
|
def validate_api_key():
|
||||||
|
"""
|
||||||
|
验证API密钥是否已配置
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 如果API密钥未配置
|
||||||
|
"""
|
||||||
|
settings = get_settings()
|
||||||
|
if not settings.dashscope_api_key:
|
||||||
|
raise ValueError("DASHSCOPE_API_KEY未设置,请在.env文件中配置")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def validate_qcloud_config():
|
||||||
|
"""
|
||||||
|
验证腾讯云配置是否已设置
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 如果腾讯云配置未正确设置
|
||||||
|
"""
|
||||||
|
settings = get_settings()
|
||||||
|
if not settings.qcloud_secret_id or not settings.qcloud_secret_key:
|
||||||
|
raise ValueError("腾讯云SecretId和SecretKey未设置,请在.env文件中配置")
|
||||||
|
if not settings.qcloud_cos_bucket:
|
||||||
|
raise ValueError("腾讯云COS存储桶名称未设置,请在.env文件中配置")
|
||||||
|
return True
|
||||||
60
create_migration.py
Normal file
60
create_migration.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import argparse
|
||||||
|
from alembic.config import Config
|
||||||
|
from alembic import command
|
||||||
|
|
||||||
|
def create_migration(message=None):
|
||||||
|
"""
|
||||||
|
创建数据库迁移
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: 迁移消息,描述此次迁移的内容
|
||||||
|
"""
|
||||||
|
# 获取Alembic配置
|
||||||
|
alembic_cfg = Config("alembic.ini")
|
||||||
|
|
||||||
|
# 创建迁移
|
||||||
|
if message:
|
||||||
|
command.revision(alembic_cfg, message=message, autogenerate=True)
|
||||||
|
else:
|
||||||
|
command.revision(alembic_cfg, message="自动生成的迁移", autogenerate=True)
|
||||||
|
|
||||||
|
def upgrade_database():
|
||||||
|
"""更新数据库到最新版本"""
|
||||||
|
alembic_cfg = Config("alembic.ini")
|
||||||
|
command.upgrade(alembic_cfg, "head")
|
||||||
|
|
||||||
|
def downgrade_database(revision):
|
||||||
|
"""回滚数据库到指定版本"""
|
||||||
|
alembic_cfg = Config("alembic.ini")
|
||||||
|
command.downgrade(alembic_cfg, revision)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="数据库迁移工具")
|
||||||
|
subparsers = parser.add_subparsers(dest="command", help="命令")
|
||||||
|
|
||||||
|
# 创建迁移命令
|
||||||
|
create_parser = subparsers.add_parser("create", help="创建数据库迁移")
|
||||||
|
create_parser.add_argument("-m", "--message", help="迁移描述消息")
|
||||||
|
|
||||||
|
# 更新数据库命令
|
||||||
|
upgrade_parser = subparsers.add_parser("upgrade", help="更新数据库到最新版本")
|
||||||
|
|
||||||
|
# 回滚数据库命令
|
||||||
|
downgrade_parser = subparsers.add_parser("downgrade", help="回滚数据库到指定版本")
|
||||||
|
downgrade_parser.add_argument("revision", help="目标版本,例如:-1表示回滚一个版本")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.command == "create":
|
||||||
|
create_migration(args.message)
|
||||||
|
print("迁移文件已创建,请查看 alembic/versions 目录")
|
||||||
|
elif args.command == "upgrade":
|
||||||
|
upgrade_database()
|
||||||
|
print("数据库已更新到最新版本")
|
||||||
|
elif args.command == "downgrade":
|
||||||
|
downgrade_database(args.revision)
|
||||||
|
print(f"数据库已回滚到版本: {args.revision}")
|
||||||
|
else:
|
||||||
|
parser.print_help()
|
||||||
12
requirements.txt
Normal file
12
requirements.txt
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
fastapi==0.104.0
|
||||||
|
uvicorn==0.23.2
|
||||||
|
dashscope>=1.13.0
|
||||||
|
python-dotenv==1.0.0
|
||||||
|
pydantic==2.4.2
|
||||||
|
httpx==0.25.0
|
||||||
|
cos-python-sdk-v5==1.9.26
|
||||||
|
qcloud-python-sts==3.1.4
|
||||||
|
sqlalchemy==2.0.23
|
||||||
|
pymysql==1.1.0
|
||||||
|
cryptography==41.0.5
|
||||||
|
alembic==1.12.1
|
||||||
105
run.py
Normal file
105
run.py
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
import uvicorn
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from app.utils.config import get_settings, validate_api_key, validate_qcloud_config
|
||||||
|
from app.database import engine
|
||||||
|
|
||||||
|
def check_dependencies():
|
||||||
|
"""检查依赖项是否正确安装"""
|
||||||
|
try:
|
||||||
|
import dashscope
|
||||||
|
print(f"已加载 DashScope 版本: {getattr(dashscope, '__version__', '未知')}")
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"错误: 无法导入 dashscope 模块: {str(e)}")
|
||||||
|
print("请确保已正确安装 dashscope: pip install -U dashscope")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from qcloud_cos import CosConfig
|
||||||
|
print("已加载腾讯云 COS SDK")
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"警告: 无法导入腾讯云 COS SDK: {str(e)}")
|
||||||
|
print("腾讯云 COS 相关功能可能无法使用")
|
||||||
|
print("请确保已正确安装: pip install -U cos-python-sdk-v5 qcloud-python-sts")
|
||||||
|
|
||||||
|
try:
|
||||||
|
import sqlalchemy
|
||||||
|
print(f"已加载 SQLAlchemy 版本: {sqlalchemy.__version__}")
|
||||||
|
import pymysql
|
||||||
|
print(f"已加载 PyMySQL 版本: {pymysql.__version__}")
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"错误: 无法导入数据库模块: {str(e)}")
|
||||||
|
print("请确保已正确安装: pip install -U sqlalchemy pymysql")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def check_database_connection():
|
||||||
|
"""检查数据库连接"""
|
||||||
|
try:
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
# 尝试连接数据库
|
||||||
|
with engine.connect() as connection:
|
||||||
|
print(f"成功连接到数据库: {settings.db_name}")
|
||||||
|
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"警告: 无法连接到数据库: {str(e)}")
|
||||||
|
print("请确保MySQL服务已启动,并且数据库配置正确")
|
||||||
|
print("您可以使用以下命令创建数据库:")
|
||||||
|
settings = get_settings()
|
||||||
|
print(f"CREATE DATABASE {settings.db_name} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;")
|
||||||
|
print(f"GRANT ALL PRIVILEGES ON {settings.db_name}.* TO '{settings.db_user}'@'%' IDENTIFIED BY '{settings.db_password}';")
|
||||||
|
print("FLUSH PRIVILEGES;")
|
||||||
|
|
||||||
|
# 询问是否继续
|
||||||
|
if input("是否继续启动应用程序?(y/n): ").lower() not in ["y", "yes"]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
# 检查依赖项
|
||||||
|
if not check_dependencies():
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# 验证API密钥是否配置
|
||||||
|
validate_api_key()
|
||||||
|
|
||||||
|
# 验证腾讯云配置
|
||||||
|
try:
|
||||||
|
validate_qcloud_config()
|
||||||
|
print("腾讯云COS配置已验证")
|
||||||
|
except ValueError as e:
|
||||||
|
print(f"警告: {str(e)}")
|
||||||
|
print("腾讯云COS相关功能可能无法正常使用,但应用程序将继续启动")
|
||||||
|
|
||||||
|
# 检查数据库连接
|
||||||
|
if not check_database_connection():
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# 获取配置
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
# 配置日志级别
|
||||||
|
log_level = getattr(logging, settings.log_level.upper(), logging.INFO)
|
||||||
|
logging.basicConfig(level=log_level)
|
||||||
|
|
||||||
|
# 启动服务器
|
||||||
|
print(f"启动服务,访问地址: http://{settings.host}:{settings.port}")
|
||||||
|
uvicorn.run(
|
||||||
|
"app.main:app",
|
||||||
|
host=settings.host,
|
||||||
|
port=settings.port,
|
||||||
|
reload=settings.debug,
|
||||||
|
log_level=settings.log_level.lower()
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"启动失败: {str(e)}")
|
||||||
|
if hasattr(e, "__traceback__"):
|
||||||
|
import traceback
|
||||||
|
traceback.print_exception(type(e), e, e.__traceback__)
|
||||||
|
exit(1)
|
||||||
Loading…
Reference in New Issue
Block a user