From 7bd63eefc53c7115373c926c8c71c17457490379 Mon Sep 17 00:00:00 2001 From: aaron <> Date: Fri, 21 Mar 2025 17:06:54 +0800 Subject: [PATCH] first commit --- .gitignore | 42 +++ README.md | 259 +++++++++++++++++ alembic.ini | 99 +++++++ alembic/env.py | 87 ++++++ alembic/script.py.mako | 24 ++ .../versions/f166e37d7b53_添加试穿记录表.py | 42 +++ app/__init__.py | 1 + app/database/__init__.py | 28 ++ app/main.py | 68 +++++ app/models/__init__.py | 1 + app/models/dress.py | 27 ++ app/models/tryon.py | 23 ++ app/routers/__init__.py | 1 + app/routers/dashscope_router.py | 262 ++++++++++++++++++ app/routers/dress_router.py | 162 +++++++++++ app/routers/qcloud_router.py | 155 +++++++++++ app/routers/tryon_router.py | 225 +++++++++++++++ app/schemas/__init__.py | 1 + app/schemas/dress.py | 36 +++ app/schemas/tryon.py | 54 ++++ app/services/__init__.py | 1 + app/services/dashscope_service.py | 230 +++++++++++++++ app/services/qcloud_service.py | 259 +++++++++++++++++ app/utils/__init__.py | 1 + app/utils/config.py | 83 ++++++ create_migration.py | 60 ++++ requirements.txt | 12 + run.py | 105 +++++++ 28 files changed, 2348 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 alembic.ini create mode 100644 alembic/env.py create mode 100644 alembic/script.py.mako create mode 100644 alembic/versions/f166e37d7b53_添加试穿记录表.py create mode 100644 app/__init__.py create mode 100644 app/database/__init__.py create mode 100644 app/main.py create mode 100644 app/models/__init__.py create mode 100644 app/models/dress.py create mode 100644 app/models/tryon.py create mode 100644 app/routers/__init__.py create mode 100644 app/routers/dashscope_router.py create mode 100644 app/routers/dress_router.py create mode 100644 app/routers/qcloud_router.py create mode 100644 app/routers/tryon_router.py create mode 100644 app/schemas/__init__.py create mode 100644 app/schemas/dress.py create mode 100644 app/schemas/tryon.py create mode 100644 app/services/__init__.py create mode 100644 app/services/dashscope_service.py create mode 100644 app/services/qcloud_service.py create mode 100644 app/utils/__init__.py create mode 100644 app/utils/config.py create mode 100644 create_migration.py create mode 100644 requirements.txt create mode 100644 run.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ca721b8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,42 @@ +# 环境变量 +.env + +# Python +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg + +# 虚拟环境 +venv/ +env/ +ENV/ + +# IDE +.idea/ +.vscode/ +*.swp +*.swo + +# 日志 +*.log + +# 操作系统 +.DS_Store +Thumbs.db \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..65e60cc --- /dev/null +++ b/README.md @@ -0,0 +1,259 @@ +# AI-Dressing API + +基于 FastAPI 框架和阿里云 DashScope 接口的 AI 服务 API。 + +## 功能特点 + +- 集成阿里云 DashScope 的大模型 API,支持通义千问系列模型 +- 提供图像生成 API,支持 Stable Diffusion XL、万相等模型 +- 集成腾讯云 COS 对象存储服务,支持文件上传、下载等功能 +- 使用 MySQL 数据库存储服装数据 +- RESTful API 风格,支持异步处理 +- 易于扩展的模块化设计 + +## 环境要求 + +- Python 3.8+ +- 阿里云 DashScope API 密钥 +- 腾讯云 API 密钥(用于 COS 对象存储) +- MySQL 5.7+ 数据库 + +## 数据库设置 + +在使用应用程序之前,您需要创建 MySQL 数据库。登录到 MySQL 并执行以下命令: + +```sql +CREATE DATABASE ai_dressing CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; +CREATE USER 'ai_user'@'%' IDENTIFIED BY 'your_password'; +GRANT ALL PRIVILEGES ON ai_dressing.* TO 'ai_user'@'%'; +FLUSH PRIVILEGES; +``` + +## 快速开始 + +### 1. 安装依赖 + +```bash +pip install -r requirements.txt +``` + +### 2. 配置环境变量 + +复制 `.env.example` 文件为 `.env`,并填写您的 API 密钥: + +```bash +cp .env.example .env +``` + +然后编辑 `.env` 文件: + +``` +# DashScope API密钥 +DASHSCOPE_API_KEY=your_api_key_here + +# 腾讯云配置 +QCLOUD_SECRET_ID=your_secret_id_here +QCLOUD_SECRET_KEY=your_secret_key_here +QCLOUD_COS_REGION=ap-guangzhou +QCLOUD_COS_BUCKET=your-bucket-name +QCLOUD_COS_DOMAIN=https://your-bucket-name.cos.ap-guangzhou.myqcloud.com + +# 数据库配置 +DB_HOST=localhost +DB_PORT=3306 +DB_USER=ai_user +DB_PASSWORD=your_password +DB_NAME=ai_dressing +``` + +### 3. 数据库迁移 + +首次运行或模型变更后,需要创建或更新数据库表: + +```bash +# 创建迁移(仅在模型变更时需要) +python create_migration.py create -m "初始化数据库" + +# 应用迁移 +python create_migration.py upgrade +``` + +### 4. 启动服务 + +```bash +python run.py +``` + +服务将在 http://localhost:9001 启动,您可以访问 http://localhost:9001/docs 查看 API 文档。 + +## API 文档 + +启动服务后,访问以下地址查看自动生成的 API 文档: + +- Swagger UI: http://localhost:9001/docs +- ReDoc: http://localhost:9001/redoc + +## 主要 API 端点 + +### 大模型对话 + +``` +POST /api/dashscope/chat +``` + +### 图像生成 + +``` +POST /api/dashscope/generate-image +``` + +### 获取支持的模型列表 + +``` +GET /api/dashscope/models +``` + +### 文件上传到腾讯云 COS + +``` +POST /api/qcloud/upload +``` + +### 从腾讯云 COS 获取文件列表 + +``` +GET /api/qcloud/files +``` + +### 生成 COS 临时上传凭证 + +``` +POST /api/qcloud/sts-token +``` + +### 服装管理 API + +``` +# 创建服装 +POST /api/dresses/ + +# 获取服装列表 +GET /api/dresses/ + +# 获取单个服装 +GET /api/dresses/{dress_id} + +# 更新服装 +PUT /api/dresses/{dress_id} + +# 删除服装 +DELETE /api/dresses/{dress_id} +``` + +## 项目结构 + +``` +. +├── alembic/ # 数据库迁移相关文件 +│ ├── versions/ # 迁移版本文件 +│ ├── env.py # 迁移环境配置 +│ └── script.py.mako # 迁移脚本模板 +├── app/ # 应用主目录 +│ ├── __init__.py +│ ├── main.py # FastAPI 应用程序 +│ ├── database/ # 数据库相关 +│ │ └── __init__.py +│ ├── models/ # 数据库模型 +│ │ ├── __init__.py +│ │ └── dress.py +│ ├── routers/ # API 路由 +│ │ ├── __init__.py +│ │ ├── dashscope_router.py +│ │ ├── qcloud_router.py +│ │ └── dress_router.py +│ ├── schemas/ # Pydantic 模型 +│ │ ├── __init__.py +│ │ └── dress.py +│ ├── services/ # 业务逻辑服务 +│ │ ├── __init__.py +│ │ ├── dashscope_service.py +│ │ └── qcloud_service.py +│ └── utils/ # 工具类 +│ ├── __init__.py +│ └── config.py +├── .env.example # 环境变量示例 +├── alembic.ini # Alembic 配置 +├── create_migration.py # 数据库迁移工具 +├── requirements.txt # 项目依赖 +├── README.md # 项目文档 +└── run.py # 应用入口 +``` + +## 开发指南 + +### 添加新的路由 + +1. 在 `app/routers/` 目录下创建新的路由文件 +2. 在 `app/main.py` 中注册新的路由 + +### 添加新的数据库模型 + +1. 在 `app/models/` 目录下创建新的模型文件 +2. 在 `app/schemas/` 目录下创建对应的 Pydantic 模型 +3. 创建数据库迁移: `python create_migration.py create -m "添加新模型"` +4. 应用迁移: `python create_migration.py upgrade` + +### 添加新的服务 + +在 `app/services/` 目录下创建新的服务类。 + +## 腾讯云 COS 使用指南 + +### 前端直传文件 + +1. 首先从服务端获取临时上传凭证: +```javascript +const response = await fetch('/api/qcloud/sts-token', { + method: 'POST', + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + }, + body: 'allow_prefix=uploads/&duration_seconds=1800' +}); +const stsData = await response.json(); +``` + +2. 使用临时凭证进行文件上传: +```javascript +// 使用腾讯云COS SDK上传 +const cos = new COS({ + getAuthorization: function(options, callback) { + callback({ + TmpSecretId: stsData.credentials.tmpSecretId, + TmpSecretKey: stsData.credentials.tmpSecretKey, + SecurityToken: stsData.credentials.sessionToken, + ExpiredTime: stsData.expiration + }); + } +}); + +cos.putObject({ + Bucket: stsData.bucket, + Region: stsData.region, + Key: 'uploads/example.jpg', + Body: file, + onProgress: function(progressData) { + console.log(progressData); + } +}, function(err, data) { + if (err) { + console.error('上传失败', err); + } else { + console.log('上传成功', data); + } +}); +``` + +## 许可证 + +MIT \ No newline at end of file diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..8792f9f --- /dev/null +++ b/alembic.ini @@ -0,0 +1,99 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = alembic + +# template used to generate migration files +# file_template = %%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python-dateutil library that can be +# installed by adding `python-dateutil` to the requirements. +# For example: timezone = UTC +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to alembic/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" +# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. Valid values are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # default: use os.pathsep + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S \ No newline at end of file diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..b0374b8 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,87 @@ +from logging.config import fileConfig + +from sqlalchemy import engine_from_config +from sqlalchemy import pool + +from alembic import context + +# 导入应用的数据库配置和模型 +from app.database import Base +from app.utils.config import get_settings +from app.models.dress import Dress # 导入所有模型 +from app.models.tryon import TryOn + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# 从应用程序配置中获取数据库URL +settings = get_settings() +config.set_main_option("sqlalchemy.url", settings.database_url) + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline(): + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online(): + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() \ No newline at end of file diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..a2b869c --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade(): + ${upgrades if upgrades else "pass"} + + +def downgrade(): + ${downgrades if downgrades else "pass"} \ No newline at end of file diff --git a/alembic/versions/f166e37d7b53_添加试穿记录表.py b/alembic/versions/f166e37d7b53_添加试穿记录表.py new file mode 100644 index 0000000..db9f614 --- /dev/null +++ b/alembic/versions/f166e37d7b53_添加试穿记录表.py @@ -0,0 +1,42 @@ +"""添加试穿记录表 + +Revision ID: f166e37d7b53 +Revises: +Create Date: 2025-03-21 15:58:04.889147 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'f166e37d7b53' +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tryons', + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('top_garment_url', sa.String(length=1024), nullable=True, comment='上衣图片URL'), + sa.Column('bottom_garment_url', sa.String(length=1024), nullable=True, comment='下衣图片URL'), + sa.Column('person_image_url', sa.String(length=1024), nullable=True, comment='人物图片URL'), + sa.Column('request_id', sa.String(length=255), nullable=True, comment='请求ID'), + sa.Column('task_id', sa.String(length=255), nullable=True, comment='任务ID'), + sa.Column('task_status', sa.String(length=50), nullable=True, comment='任务状态'), + sa.Column('completion_url', sa.String(length=1024), nullable=True, comment='生成图片URL'), + sa.Column('created_at', sa.DateTime(), nullable=True, comment='创建时间'), + sa.Column('updated_at', sa.DateTime(), nullable=True, comment='更新时间'), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_tryons_id'), 'tryons', ['id'], unique=False) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_tryons_id'), table_name='tryons') + op.drop_table('tryons') + # ### end Alembic commands ### \ No newline at end of file diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..0519ecb --- /dev/null +++ b/app/__init__.py @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/app/database/__init__.py b/app/database/__init__.py new file mode 100644 index 0000000..d55ad07 --- /dev/null +++ b/app/database/__init__.py @@ -0,0 +1,28 @@ +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +from app.utils.config import get_settings + +settings = get_settings() + +# 创建数据库引擎 +engine = create_engine( + settings.database_url, + pool_pre_ping=True, # 自动检测断开的连接并重新连接 + pool_recycle=3600, # 每小时回收连接 + echo=settings.debug, # 在调试模式下打印SQL语句 +) + +# 创建会话工厂 +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +# 创建Base类,所有模型都将继承这个类 +Base = declarative_base() + +# 获取数据库会话的依赖函数 +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() \ No newline at end of file diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..4014b2e --- /dev/null +++ b/app/main.py @@ -0,0 +1,68 @@ +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +import os +import logging +from dotenv import load_dotenv + +from app.routers import dashscope_router, qcloud_router, dress_router, tryon_router +from app.utils.config import get_settings +from app.database import Base, engine + +# 创建数据库表 +Base.metadata.create_all(bind=engine) + +# 加载环境变量 +load_dotenv() + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +app = FastAPI( + title="AI-Dressing API", + description="基于 DashScope 的 AI 服务 API", + version="0.1.0", +) + +# 添加 CORS 中间件 +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # 在生产环境中,应该指定确切的域名 + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +# 注册路由 +app.include_router(dashscope_router.router, prefix="/api/dashscope", tags=["DashScope"]) +app.include_router(qcloud_router.router, prefix="/api/qcloud", tags=["腾讯云"]) +app.include_router(dress_router.router, prefix="/api/dresses", tags=["服装"]) +app.include_router(tryon_router.router, prefix="/api/tryons", tags=["试穿"]) + +@app.get("/", tags=["健康检查"]) +async def health_check(): + """API 健康检查端点""" + return {"status": "正常", "message": "服务运行中"} + +@app.get("/info", tags=["服务信息"]) +async def get_info(): + """获取服务基本信息""" + settings = get_settings() + return { + "app_name": "AI-Dressing API", + "version": "0.1.0", + "debug_mode": settings.debug, + } + +if __name__ == "__main__": + import uvicorn + settings = get_settings() + uvicorn.run( + "app.main:app", + host=settings.host, + port=settings.port, + reload=settings.debug, + ) \ No newline at end of file diff --git a/app/models/__init__.py b/app/models/__init__.py new file mode 100644 index 0000000..0519ecb --- /dev/null +++ b/app/models/__init__.py @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/app/models/dress.py b/app/models/dress.py new file mode 100644 index 0000000..5f2242b --- /dev/null +++ b/app/models/dress.py @@ -0,0 +1,27 @@ +from sqlalchemy import Column, Integer, String, DateTime, Text, func, Enum +from sqlalchemy.ext.declarative import declarative_base +from app.database import Base +import datetime +import enum + +class GarmentType(enum.Enum): + """服装类型枚举""" + TOP_GARMENT = "TOP_GARMENT" + BOTTOM_GARMENT = "BOTTOM_GARMENT" + +class Dress(Base): + """服装模型类""" + __tablename__ = "dresses" + + id = Column(Integer, primary_key=True, index=True, autoincrement=True) + name = Column(String(255), nullable=False, comment="服装名称") + image_url = Column(String(1024), nullable=True, comment="服装图片URL") + garment_type = Column(Enum(GarmentType), nullable=True, comment="服装类型(上衣/下衣)") + + # 以下是一些可选的附加字段,根据需要可以去掉或添加 + description = Column(Text, nullable=True, comment="服装描述") + created_at = Column(DateTime, default=datetime.datetime.now(), comment="创建时间") + updated_at = Column(DateTime, default=datetime.datetime.now(), onupdate=datetime.datetime.now, comment="更新时间") + + def __repr__(self): + return f"" \ No newline at end of file diff --git a/app/models/tryon.py b/app/models/tryon.py new file mode 100644 index 0000000..35f1d27 --- /dev/null +++ b/app/models/tryon.py @@ -0,0 +1,23 @@ +from sqlalchemy import Column, Integer, String, DateTime, Text, func +from app.database import Base +import datetime + +class TryOn(Base): + """试穿记录模型类""" + __tablename__ = "tryons" + + id = Column(Integer, primary_key=True, index=True, autoincrement=True) + top_garment_url = Column(String(1024), nullable=True, comment="上衣图片URL") + bottom_garment_url = Column(String(1024), nullable=True, comment="下衣图片URL") + person_image_url = Column(String(1024), nullable=True, comment="人物图片URL") + + request_id = Column(String(255), nullable=True, comment="请求ID") + task_id = Column(String(255), nullable=True, comment="任务ID") + task_status = Column(String(50), nullable=True, comment="任务状态") + completion_url = Column(String(1024), nullable=True, comment="生成图片URL") + + created_at = Column(DateTime, default=datetime.datetime.now(), comment="创建时间") + updated_at = Column(DateTime, default=datetime.datetime.now(), onupdate=datetime.datetime.now, comment="更新时间") + + def __repr__(self): + return f"" \ No newline at end of file diff --git a/app/routers/__init__.py b/app/routers/__init__.py new file mode 100644 index 0000000..0519ecb --- /dev/null +++ b/app/routers/__init__.py @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/app/routers/dashscope_router.py b/app/routers/dashscope_router.py new file mode 100644 index 0000000..dd3c6d1 --- /dev/null +++ b/app/routers/dashscope_router.py @@ -0,0 +1,262 @@ +from fastapi import APIRouter, HTTPException, Depends +from pydantic import BaseModel +import logging +from typing import List, Optional, Dict, Any +import dashscope +from app.services.dashscope_service import DashScopeService +from app.utils.config import get_settings + +logger = logging.getLogger(__name__) +router = APIRouter() + +class ChatMessage(BaseModel): + role: str # 'user' 或 'assistant' + content: str + +class ChatRequest(BaseModel): + messages: List[ChatMessage] + model: Optional[str] = "qwen-max" # 默认使用通义千问MAX模型 + max_tokens: Optional[int] = 2048 + temperature: Optional[float] = 0.7 + stream: Optional[bool] = False + +class ChatResponse(BaseModel): + response: str + usage: Dict[str, Any] + request_id: str + +@router.post("/chat", response_model=ChatResponse) +async def chat_completion( + request: ChatRequest, + dashscope_service: DashScopeService = Depends(lambda: DashScopeService()) +): + """ + 调用DashScope的大模型进行对话 + """ + try: + messages = [{"role": msg.role, "content": msg.content} for msg in request.messages] + response = await dashscope_service.chat_completion( + messages=messages, + model=request.model, + max_tokens=request.max_tokens, + temperature=request.temperature, + stream=request.stream + ) + + # 确保我们能正确访问响应字段,根据dashscope的最新版本调整 + response_text = "" + if hasattr(response, 'output') and hasattr(response.output, 'text'): + response_text = response.output.text + elif hasattr(response, 'output') and isinstance(response.output, dict) and 'text' in response.output: + response_text = response.output['text'] + elif hasattr(response, 'choices') and len(response.choices) > 0: + response_text = response.choices[0].message.content + else: + logger.warning(f"无法解析DashScope响应: {response}") + # 尝试保守提取 + try: + if hasattr(response, 'output'): + if isinstance(response.output, dict): + response_text = str(response.output.get('text', '')) + else: + response_text = str(response.output) + else: + response_text = str(response) + except Exception as e: + logger.error(f"提取响应文本失败: {str(e)}") + response_text = "无法获取模型响应文本" + + # 同样确保我们能正确访问usage和request_id + usage = {} + if hasattr(response, 'usage'): + usage = response.usage if isinstance(response.usage, dict) else vars(response.usage) + + request_id = "" + if hasattr(response, 'request_id'): + request_id = response.request_id + + return { + "response": response_text, + "usage": usage, + "request_id": request_id + } + except Exception as e: + logger.error(f"DashScope API 调用错误: {str(e)}") + raise HTTPException(status_code=500, detail=f"模型调用失败: {str(e)}") + +class ImageGenerationRequest(BaseModel): + prompt: str + negative_prompt: Optional[str] = None + model: Optional[str] = "stable-diffusion-xl" + n: Optional[int] = 1 + size: Optional[str] = "1024*1024" + +class ImageGenerationResponse(BaseModel): + images: List[str] + request_id: str + +@router.post("/generate-image", response_model=ImageGenerationResponse) +async def generate_image( + request: ImageGenerationRequest, + dashscope_service: DashScopeService = Depends(lambda: DashScopeService()) +): + """ + 调用DashScope的图像生成API + """ + try: + response = await dashscope_service.generate_image( + prompt=request.prompt, + negative_prompt=request.negative_prompt, + model=request.model, + n=request.n, + size=request.size + ) + + # 确保我们能正确访问images和request_id + images = [] + if hasattr(response, 'output') and hasattr(response.output, 'images'): + images = response.output.images + elif hasattr(response, 'output') and isinstance(response.output, dict) and 'images' in response.output: + images = response.output['images'] + else: + logger.warning(f"无法解析DashScope图像生成响应: {response}") + try: + if hasattr(response, 'output'): + if isinstance(response.output, dict): + images = response.output.get('images', []) + else: + images = [] + else: + images = [] + except Exception as e: + logger.error(f"提取图像URL失败: {str(e)}") + images = [] + + request_id = "" + if hasattr(response, 'request_id'): + request_id = response.request_id + + return { + "images": images, + "request_id": request_id + } + except Exception as e: + logger.error(f"DashScope 图像生成API调用错误: {str(e)}") + raise HTTPException(status_code=500, detail=f"图像生成失败: {str(e)}") + +class TryOnRequest(BaseModel): + """试穿请求模型""" + person_image_url: str + top_garment_url: Optional[str] = None + bottom_garment_url: Optional[str] = None + resolution: Optional[int] = -1 + restore_face: Optional[bool] = True + +class TaskInfo(BaseModel): + """任务信息模型""" + task_id: str + task_status: str + +class TryOnResponse(BaseModel): + """试穿响应模型""" + output: TaskInfo + request_id: str + +@router.post("/try-on", response_model=TryOnResponse) +async def generate_tryon( + request: TryOnRequest, + dashscope_service: DashScopeService = Depends(lambda: DashScopeService()) +): + """ + 调用DashScope的人物试衣API + + 该API用于给人物图片试穿上衣和/或下衣,返回合成图片的任务ID + + - **person_image_url**: 人物图片URL(必填) + - **top_garment_url**: 上衣图片URL(可选) + - **bottom_garment_url**: 下衣图片URL(可选) + - **resolution**: 分辨率,-1表示自动(可选) + - **restore_face**: 是否修复面部(可选) + + 注意:top_garment_url和bottom_garment_url至少需要提供一个 + """ + try: + # 验证参数 + if not request.top_garment_url and not request.bottom_garment_url: + raise HTTPException(status_code=400, detail="上衣和下衣图片至少需要提供一个") + + response = await dashscope_service.generate_tryon( + person_image_url=request.person_image_url, + top_garment_url=request.top_garment_url, + bottom_garment_url=request.bottom_garment_url, + resolution=request.resolution, + restore_face=request.restore_face + ) + + # 构建响应 + return { + "output": { + "task_id": response.get("output", {}).get("task_id", ""), + "task_status": response.get("output", {}).get("task_status", "") + }, + "request_id": response.get("request_id", "") + } + except Exception as e: + logger.error(f"DashScope 试穿API调用错误: {str(e)}") + raise HTTPException(status_code=500, detail=f"试穿请求失败: {str(e)}") + +class TaskStatusResponse(BaseModel): + """任务状态响应模型""" + task_id: str + task_status: str + completion_url: Optional[str] = None + request_id: str + +@router.get("/try-on/{task_id}", response_model=TaskStatusResponse) +async def check_tryon_status( + task_id: str, + dashscope_service: DashScopeService = Depends(lambda: DashScopeService()) +): + """ + 查询试穿任务状态 + + - **task_id**: 任务ID + """ + try: + response = await dashscope_service.check_tryon_status(task_id) + + # 构建响应 + return { + "task_id": task_id, + "task_status": response.get("output", {}).get("task_status", ""), + "completion_url": response.get("output", {}).get("url", ""), + "request_id": response.get("request_id", "") + } + except Exception as e: + logger.error(f"查询试穿任务状态错误: {str(e)}") + raise HTTPException(status_code=500, detail=f"查询任务状态失败: {str(e)}") + +@router.get("/models") +async def list_available_models(): + """ + 列出DashScope上可用的模型 + """ + try: + models = { + "chat_models": [ + {"id": "qwen-max", "name": "通义千问MAX"}, + {"id": "qwen-plus", "name": "通义千问Plus"}, + {"id": "qwen-turbo", "name": "通义千问Turbo"} + ], + "image_models": [ + {"id": "stable-diffusion-xl", "name": "Stable Diffusion XL"}, + {"id": "wanx-v1", "name": "万相"} + ], + "tryon_models": [ + {"id": "aitryon", "name": "AI试穿"} + ] + } + return models + except Exception as e: + logger.error(f"获取模型列表错误: {str(e)}") + raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}") \ No newline at end of file diff --git a/app/routers/dress_router.py b/app/routers/dress_router.py new file mode 100644 index 0000000..4eea999 --- /dev/null +++ b/app/routers/dress_router.py @@ -0,0 +1,162 @@ +from fastapi import APIRouter, HTTPException, Depends, Query, Path +from sqlalchemy.orm import Session +from typing import List, Optional +import logging + +from app.database import get_db +from app.models.dress import Dress, GarmentType +from app.schemas.dress import DressCreate, DressUpdate, DressResponse + +logger = logging.getLogger(__name__) +router = APIRouter() + +@router.post("/", response_model=DressResponse, status_code=201) +async def create_dress( + dress: DressCreate, + db: Session = Depends(get_db) +): + """ + 创建一个新的服装记录 + + - **name**: 服装名称(必填) + - **image_url**: 服装图片URL(可选) + - **garment_type**: 服装类型(TOP_GARMENT:上衣, BOTTOM_GARMENT:下衣)(可选) + - **description**: 服装描述(可选) + """ + try: + db_dress = Dress( + name=dress.name, + image_url=dress.image_url, + garment_type=dress.garment_type, + description=dress.description + ) + db.add(db_dress) + db.commit() + db.refresh(db_dress) + return db_dress + except Exception as e: + logger.error(f"创建服装记录失败: {str(e)}") + db.rollback() + raise HTTPException(status_code=500, detail=f"创建服装记录失败: {str(e)}") + +@router.get("/", response_model=List[DressResponse]) +async def get_dresses( + skip: int = Query(0, description="跳过的记录数量"), + limit: int = Query(100, description="返回的最大记录数量"), + name: Optional[str] = Query(None, description="按名称过滤"), + garment_type: Optional[str] = Query(None, description="按服装类型过滤(TOP_GARMENT:上衣, BOTTOM_GARMENT:下衣)"), + db: Session = Depends(get_db) +): + """ + 获取服装列表,支持分页和过滤 + + - **skip**: 跳过的记录数量,用于分页 + - **limit**: 返回的最大记录数量,用于分页 + - **name**: 按名称过滤(可选) + - **garment_type**: 按服装类型过滤(可选) + """ + try: + query = db.query(Dress) + + if name: + query = query.filter(Dress.name.ilike(f"%{name}%")) + + if garment_type: + try: + garment_type_enum = GarmentType[garment_type] + query = query.filter(Dress.garment_type == garment_type_enum) + except KeyError: + logger.warning(f"无效的服装类型: {garment_type}") + # 继续查询,但不应用无效的过滤条件 + + dresses = query.offset(skip).limit(limit).all() + return dresses + except Exception as e: + logger.error(f"获取服装列表失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"获取服装列表失败: {str(e)}") + +@router.get("/{dress_id}", response_model=DressResponse) +async def get_dress( + dress_id: int = Path(..., description="服装ID"), + db: Session = Depends(get_db) +): + """ + 根据ID获取服装详情 + + - **dress_id**: 服装ID + """ + try: + dress = db.query(Dress).filter(Dress.id == dress_id).first() + if not dress: + raise HTTPException(status_code=404, detail=f"未找到ID为{dress_id}的服装") + return dress + except HTTPException: + raise + except Exception as e: + logger.error(f"获取服装详情失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"获取服装详情失败: {str(e)}") + +@router.put("/{dress_id}", response_model=DressResponse) +async def update_dress( + dress_id: int = Path(..., description="服装ID"), + dress: DressUpdate = None, + db: Session = Depends(get_db) +): + """ + 更新服装信息 + + - **dress_id**: 服装ID + - **name**: 服装名称(可选) + - **image_url**: 服装图片URL(可选) + - **garment_type**: 服装类型(TOP_GARMENT:上衣, BOTTOM_GARMENT:下衣)(可选) + - **description**: 服装描述(可选) + """ + try: + db_dress = db.query(Dress).filter(Dress.id == dress_id).first() + if not db_dress: + raise HTTPException(status_code=404, detail=f"未找到ID为{dress_id}的服装") + + # 更新提供的字段 + if dress.name is not None: + db_dress.name = dress.name + if dress.image_url is not None: + db_dress.image_url = dress.image_url + if dress.garment_type is not None: + db_dress.garment_type = dress.garment_type + if dress.description is not None: + db_dress.description = dress.description + + db.commit() + db.refresh(db_dress) + return db_dress + except HTTPException: + raise + except Exception as e: + logger.error(f"更新服装信息失败: {str(e)}") + db.rollback() + raise HTTPException(status_code=500, detail=f"更新服装信息失败: {str(e)}") + +@router.delete("/{dress_id}", status_code=204) +async def delete_dress( + dress_id: int = Path(..., description="服装ID"), + db: Session = Depends(get_db) +): + """ + 删除服装 + + - **dress_id**: 服装ID + """ + try: + db_dress = db.query(Dress).filter(Dress.id == dress_id).first() + if not db_dress: + raise HTTPException(status_code=404, detail=f"未找到ID为{dress_id}的服装") + + db.delete(db_dress) + db.commit() + return None + except HTTPException: + raise + except Exception as e: + logger.error(f"删除服装失败: {str(e)}") + db.rollback() + raise HTTPException(status_code=500, detail=f"删除服装失败: {str(e)}") \ No newline at end of file diff --git a/app/routers/qcloud_router.py b/app/routers/qcloud_router.py new file mode 100644 index 0000000..2e91868 --- /dev/null +++ b/app/routers/qcloud_router.py @@ -0,0 +1,155 @@ +from fastapi import APIRouter, HTTPException, Depends, UploadFile, File, Form, Query +from fastapi.responses import JSONResponse +from pydantic import BaseModel +import logging +from typing import List, Optional, Dict, Any + +from app.services.qcloud_service import QCloudCOSService + +logger = logging.getLogger(__name__) +router = APIRouter() + +class STSTokenResponse(BaseModel): + """STS临时凭证响应""" + credentials: Dict[str, str] + expiration: str + request_id: str + bucket: str + region: str + +class FileInfo(BaseModel): + """文件信息""" + key: str + url: str + size: int + last_modified: str + etag: str + +class ListFilesResponse(BaseModel): + """列出文件响应""" + files: List[FileInfo] + is_truncated: bool + next_marker: Optional[str] = None + common_prefixes: Optional[List[Dict[str, Any]]] = None + +@router.post("/upload", tags=["腾讯云COS"]) +async def upload_file( + file: UploadFile = File(...), + directory: str = Form("uploads"), + qcloud_service: QCloudCOSService = Depends(lambda: QCloudCOSService()) +): + """ + 上传文件到腾讯云COS + + - **file**: 要上传的文件 + - **directory**: 存储目录(默认为"uploads") + """ + try: + # 读取文件内容 + file_content = await file.read() + + # 上传文件 + result = await qcloud_service.upload_file( + file_content=file_content, + file_name=file.filename, + directory=directory, + content_type=file.content_type + ) + + return result + except Exception as e: + logger.error(f"文件上传失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"文件上传失败: {str(e)}") + finally: + # 关闭文件 + await file.close() + +@router.delete("/files/{key:path}", tags=["腾讯云COS"]) +async def delete_file( + key: str, + qcloud_service: QCloudCOSService = Depends(lambda: QCloudCOSService()) +): + """ + 从腾讯云COS删除文件 + + - **key**: 文件的对象键(COS路径) + """ + try: + success = await qcloud_service.delete_file(key=key) + if success: + return {"message": "文件删除成功", "key": key} + else: + raise HTTPException(status_code=404, detail=f"文件删除失败,文件可能不存在") + except Exception as e: + logger.error(f"文件删除失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"文件删除失败: {str(e)}") + +@router.get("/files/url/{key:path}", tags=["腾讯云COS"]) +async def get_file_url( + key: str, + expires: int = Query(3600, description="URL有效期(秒)"), + qcloud_service: QCloudCOSService = Depends(lambda: QCloudCOSService()) +): + """ + 获取COS文件的临时访问URL + + - **key**: 文件的对象键(COS路径) + - **expires**: URL的有效期(秒),默认3600秒 + """ + try: + url = await qcloud_service.get_file_url(key=key, expires=expires) + return {"url": url, "key": key, "expires": expires} + except Exception as e: + logger.error(f"获取文件URL失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"获取文件URL失败: {str(e)}") + +@router.post("/sts-token", response_model=STSTokenResponse, tags=["腾讯云COS"]) +async def generate_sts_token( + allow_prefix: str = Form("*", description="允许操作的对象前缀"), + duration_seconds: int = Form(1800, description="凭证有效期(秒)"), + qcloud_service: QCloudCOSService = Depends(lambda: QCloudCOSService()) +): + """ + 生成腾讯云COS临时访问凭证(STS),用于前端直传 + + - **allow_prefix**: 允许操作的对象前缀,默认为'*'表示所有对象 + - **duration_seconds**: 凭证有效期(秒),默认1800秒 + """ + try: + sts_result = await qcloud_service.generate_cos_sts_token( + allow_prefix=allow_prefix, + duration_seconds=duration_seconds + ) + # 添加桶和区域信息,方便前端使用 + sts_result['bucket'] = qcloud_service.bucket + sts_result['region'] = qcloud_service.region + + return sts_result + except Exception as e: + logger.error(f"生成STS临时凭证失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"生成STS临时凭证失败: {str(e)}") + +@router.get("/files", response_model=ListFilesResponse, tags=["腾讯云COS"]) +async def list_files( + directory: str = Query("", description="目录前缀"), + limit: int = Query(100, description="返回的最大文件数"), + marker: str = Query("", description="分页标记"), + qcloud_service: QCloudCOSService = Depends(lambda: QCloudCOSService()) +): + """ + 列出COS中的文件 + + - **directory**: 目录前缀 + - **limit**: 返回的最大文件数 + - **marker**: 分页标记 + """ + try: + result = await qcloud_service.list_files( + directory=directory, + limit=limit, + marker=marker + ) + return result + except Exception as e: + logger.error(f"列出文件失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"列出文件失败: {str(e)}") \ No newline at end of file diff --git a/app/routers/tryon_router.py b/app/routers/tryon_router.py new file mode 100644 index 0000000..b40e466 --- /dev/null +++ b/app/routers/tryon_router.py @@ -0,0 +1,225 @@ +from fastapi import APIRouter, HTTPException, Depends, Query, Path, BackgroundTasks +from sqlalchemy.orm import Session +from typing import List, Optional +import httpx +import logging +import os +from dotenv import load_dotenv + +from app.database import get_db +from app.models.tryon import TryOn +from app.schemas.tryon import ( + TryOnCreate, TryOnUpdate, TryOnResponse, + AiTryonRequest, AiTryonResponse, TaskInfo +) +from app.services.dashscope_service import DashScopeService + +# 加载环境变量 +load_dotenv() + +logger = logging.getLogger(__name__) +router = APIRouter() + +# 从环境变量获取API密钥 +DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY") +DASHSCOPE_API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2image/image-synthesis" + +@router.post("/", response_model=TryOnResponse, status_code=201) +async def create_tryon( + tryon_data: TryOnCreate, + background_tasks: BackgroundTasks, + db: Session = Depends(get_db) +): + """ + 创建一个试穿记录并发送到阿里百炼平台 + + - **top_garment_url**: 上衣图片URL(可选) + - **bottom_garment_url**: 下衣图片URL(可选) + - **person_image_url**: 人物图片URL(必填) + + 注意:上衣和下衣至少需要提供一个 + """ + if not tryon_data.top_garment_url and not tryon_data.bottom_garment_url: + raise HTTPException(status_code=400, detail="上衣和下衣图片至少需要提供一个") + + try: + # 创建试穿记录 + db_tryon = TryOn( + top_garment_url=tryon_data.top_garment_url, + bottom_garment_url=tryon_data.bottom_garment_url, + person_image_url=tryon_data.person_image_url + ) + db.add(db_tryon) + db.commit() + db.refresh(db_tryon) + + # 在后台发送请求到阿里百炼平台 + background_tasks.add_task( + send_tryon_request, + db=db, + tryon_id=db_tryon.id, + top_garment_url=tryon_data.top_garment_url, + bottom_garment_url=tryon_data.bottom_garment_url, + person_image_url=tryon_data.person_image_url + ) + + return db_tryon + except Exception as e: + logger.error(f"创建试穿记录失败: {str(e)}") + db.rollback() + raise HTTPException(status_code=500, detail=f"创建试穿记录失败: {str(e)}") + +async def send_tryon_request( + db: Session, + tryon_id: int, + top_garment_url: Optional[str], + bottom_garment_url: Optional[str], + person_image_url: str +): + """发送试穿请求到阿里百炼平台""" + try: + # 创建DashScopeService实例 + dashscope_service = DashScopeService() + + # 调用服务发送试穿请求 + response = await dashscope_service.generate_tryon( + person_image_url=person_image_url, + top_garment_url=top_garment_url, + bottom_garment_url=bottom_garment_url + ) + + # 更新数据库记录 + db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first() + if db_tryon: + db_tryon.request_id = response.get("request_id") + db_tryon.task_id = response.get("output", {}).get("task_id") + db_tryon.task_status = response.get("output", {}).get("task_status") + db.commit() + + logger.info(f"试穿请求发送成功,任务ID: {db_tryon.task_id}") + except Exception as e: + logger.error(f"发送试穿请求异常: {str(e)}") + + # 更新数据库记录状态 + db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first() + if db_tryon: + db_tryon.task_status = "ERROR" + db.commit() + +@router.get("/", response_model=List[TryOnResponse]) +async def get_tryons( + skip: int = Query(0, description="跳过的记录数量"), + limit: int = Query(100, description="返回的最大记录数量"), + db: Session = Depends(get_db) +): + """ + 获取试穿记录列表,支持分页 + + - **skip**: 跳过的记录数量,用于分页 + - **limit**: 返回的最大记录数量,用于分页 + """ + try: + tryons = db.query(TryOn).order_by(TryOn.created_at.desc()).offset(skip).limit(limit).all() + return tryons + except Exception as e: + logger.error(f"获取试穿记录列表失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"获取试穿记录列表失败: {str(e)}") + +@router.get("/{tryon_id}", response_model=TryOnResponse) +async def get_tryon( + tryon_id: int = Path(..., description="试穿记录ID"), + db: Session = Depends(get_db) +): + """ + 根据ID获取试穿记录详情 + + - **tryon_id**: 试穿记录ID + """ + try: + tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first() + if not tryon: + raise HTTPException(status_code=404, detail=f"未找到ID为{tryon_id}的试穿记录") + return tryon + except HTTPException: + raise + except Exception as e: + logger.error(f"获取试穿记录详情失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"获取试穿记录详情失败: {str(e)}") + +@router.put("/{tryon_id}", response_model=TryOnResponse) +async def update_tryon( + tryon_id: int = Path(..., description="试穿记录ID"), + tryon_data: TryOnUpdate = None, + db: Session = Depends(get_db) +): + """ + 更新试穿记录信息 + + - **tryon_id**: 试穿记录ID + - **task_status**: 任务状态(可选) + - **completion_url**: 生成图片URL(可选) + """ + try: + db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first() + if not db_tryon: + raise HTTPException(status_code=404, detail=f"未找到ID为{tryon_id}的试穿记录") + + # 更新提供的字段 + if tryon_data.task_status is not None: + db_tryon.task_status = tryon_data.task_status + if tryon_data.completion_url is not None: + db_tryon.completion_url = tryon_data.completion_url + + db.commit() + db.refresh(db_tryon) + return db_tryon + except HTTPException: + raise + except Exception as e: + logger.error(f"更新试穿记录失败: {str(e)}") + db.rollback() + raise HTTPException(status_code=500, detail=f"更新试穿记录失败: {str(e)}") + +@router.post("/{tryon_id}/check", response_model=TryOnResponse) +async def check_tryon_status( + tryon_id: int = Path(..., description="试穿记录ID"), + db: Session = Depends(get_db), + dashscope_service: DashScopeService = Depends(lambda: DashScopeService()) +): + """ + 检查试穿任务状态 + + - **tryon_id**: 试穿记录ID + """ + try: + db_tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first() + if not db_tryon: + raise HTTPException(status_code=404, detail=f"未找到ID为{tryon_id}的试穿记录") + + if not db_tryon.task_id: + raise HTTPException(status_code=400, detail=f"试穿记录未包含任务ID") + + # 调用DashScopeService检查任务状态 + try: + status_response = await dashscope_service.check_tryon_status(db_tryon.task_id) + + # 更新数据库记录 + db_tryon.task_status = status_response.get("output", {}).get("task_status") + + # 如果任务完成,保存结果URL + if db_tryon.task_status == "SUCCEEDED": + db_tryon.completion_url = status_response.get("output", {}).get("url") + + db.commit() + db.refresh(db_tryon) + + logger.info(f"试穿任务状态更新: {db_tryon.task_status}") + except Exception as e: + logger.error(f"调用DashScope API检查任务状态失败: {str(e)}") + + return db_tryon + except HTTPException: + raise + except Exception as e: + logger.error(f"检查试穿任务状态失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"检查试穿任务状态失败: {str(e)}") \ No newline at end of file diff --git a/app/schemas/__init__.py b/app/schemas/__init__.py new file mode 100644 index 0000000..0519ecb --- /dev/null +++ b/app/schemas/__init__.py @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/app/schemas/dress.py b/app/schemas/dress.py new file mode 100644 index 0000000..53b02b7 --- /dev/null +++ b/app/schemas/dress.py @@ -0,0 +1,36 @@ +from pydantic import BaseModel, Field, HttpUrl +from typing import Optional +from datetime import datetime +from enum import Enum +from app.models.dress import GarmentType + +class DressBase(BaseModel): + """服装基础模型""" + name: str = Field(..., description="服装名称", example="夏季连衣裙") + image_url: Optional[str] = Field(None, description="服装图片URL", example="https://example.com/dress1.jpg") + garment_type: Optional[GarmentType] = Field(None, description="服装类型(上衣/下衣)", example="TOP_GARMENT") + description: Optional[str] = Field(None, description="服装描述", example="一款适合夏季穿着的轻薄连衣裙") + +class DressCreate(DressBase): + """创建服装的请求模型""" + pass + +class DressUpdate(BaseModel): + """更新服装的请求模型""" + name: Optional[str] = Field(None, description="服装名称", example="夏季连衣裙") + image_url: Optional[str] = Field(None, description="服装图片URL", example="https://example.com/dress1.jpg") + garment_type: Optional[GarmentType] = Field(None, description="服装类型(上衣/下衣)", example="TOP_GARMENT") + description: Optional[str] = Field(None, description="服装描述", example="一款适合夏季穿着的轻薄连衣裙") + +class DressInDB(DressBase): + """数据库中的服装模型""" + id: int + created_at: datetime + updated_at: datetime + + class Config: + orm_mode = True + +class DressResponse(DressInDB): + """服装API响应模型""" + pass \ No newline at end of file diff --git a/app/schemas/tryon.py b/app/schemas/tryon.py new file mode 100644 index 0000000..f77e588 --- /dev/null +++ b/app/schemas/tryon.py @@ -0,0 +1,54 @@ +from pydantic import BaseModel, Field +from typing import Optional +from datetime import datetime + +class TryOnBase(BaseModel): + """试穿基础模型""" + top_garment_url: Optional[str] = Field(None, description="上衣图片URL", example="https://example.com/top.jpg") + bottom_garment_url: Optional[str] = Field(None, description="下衣图片URL", example="https://example.com/bottom.jpg") + person_image_url: str = Field(..., description="人物图片URL", example="https://example.com/person.jpg") + +class TryOnCreate(TryOnBase): + """创建试穿记录的请求模型""" + pass + +class TryOnUpdate(BaseModel): + """更新试穿记录的请求模型""" + task_status: Optional[str] = Field(None, description="任务状态", example="SUCCEEDED") + completion_url: Optional[str] = Field(None, description="生成图片URL", example="https://example.com/result.jpg") + +class TryOnInDB(TryOnBase): + """数据库中的试穿记录模型""" + id: int + request_id: Optional[str] = None + task_id: Optional[str] = None + task_status: Optional[str] = None + completion_url: Optional[str] = None + created_at: datetime + updated_at: datetime + + class Config: + orm_mode = True + +class TryOnResponse(TryOnInDB): + """试穿记录API响应模型""" + pass + +class AiTryonRequest(BaseModel): + """阿里百炼平台试穿请求模型""" + model: str = "aitryon" + input: TryOnBase + parameters: dict = { + "resolution": -1, + "restore_face": True + } + +class TaskInfo(BaseModel): + """任务信息模型""" + task_id: str + task_status: str + +class AiTryonResponse(BaseModel): + """阿里百炼平台试穿响应模型""" + output: TaskInfo + request_id: str \ No newline at end of file diff --git a/app/services/__init__.py b/app/services/__init__.py new file mode 100644 index 0000000..0519ecb --- /dev/null +++ b/app/services/__init__.py @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/app/services/dashscope_service.py b/app/services/dashscope_service.py new file mode 100644 index 0000000..07f7e71 --- /dev/null +++ b/app/services/dashscope_service.py @@ -0,0 +1,230 @@ +import os +import logging +import dashscope +from dashscope import Generation +# 修改导入语句,dashscope的API响应可能改变了结构 +from typing import List, Dict, Any, Optional +import asyncio +import httpx +from app.utils.config import get_settings + +logger = logging.getLogger(__name__) + +class DashScopeService: + """DashScope服务类,提供对DashScope API的调用封装""" + + def __init__(self): + settings = get_settings() + self.api_key = settings.dashscope_api_key + # 配置DashScope + dashscope.api_key = self.api_key + # 配置API URL + self.image_synthesis_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2image/image-synthesis" + + async def chat_completion( + self, + messages: List[Dict[str, str]], + model: str = "qwen-max", + max_tokens: int = 2048, + temperature: float = 0.7, + stream: bool = False + ): + """ + 调用DashScope的大模型API进行对话 + + Args: + messages: 对话历史记录 + model: 模型名称 + max_tokens: 最大生成token数 + temperature: 温度参数,控制随机性 + stream: 是否流式输出 + + Returns: + ApiResponse: DashScope的API响应 + """ + try: + # 为了不阻塞FastAPI的异步性能,我们使用run_in_executor运行同步API + loop = asyncio.get_event_loop() + response = await loop.run_in_executor( + None, + lambda: Generation.call( + model=model, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + result_format='message', + stream=stream, + ) + ) + + if response.status_code != 200: + logger.error(f"DashScope API请求失败,状态码:{response.status_code}, 错误信息:{response.message}") + raise Exception(f"API调用失败: {response.message}") + + return response + except Exception as e: + logger.error(f"DashScope聊天API调用出错: {str(e)}") + raise e + + async def generate_image( + self, + prompt: str, + negative_prompt: Optional[str] = None, + model: str = "stable-diffusion-xl", + n: int = 1, + size: str = "1024*1024" + ): + """ + 调用DashScope的图像生成API + + Args: + prompt: 生成图像的文本描述 + negative_prompt: 负面提示词 + model: 模型名称 + n: 生成图像数量 + size: 图像尺寸 + + Returns: + ApiResponse: DashScope的API响应 + """ + try: + # 构建请求参数 + params = { + "model": model, + "prompt": prompt, + "n": n, + "size": size, + } + + if negative_prompt: + params["negative_prompt"] = negative_prompt + + # 异步调用图像生成API + loop = asyncio.get_event_loop() + response = await loop.run_in_executor( + None, + lambda: dashscope.ImageSynthesis.call(**params) + ) + + if response.status_code != 200: + logger.error(f"DashScope 图像生成API请求失败,状态码:{response.status_code}, 错误信息:{response.message}") + raise Exception(f"图像生成API调用失败: {response.message}") + + return response + except Exception as e: + logger.error(f"DashScope图像生成API调用出错: {str(e)}") + raise e + + async def generate_tryon( + self, + person_image_url: str, + top_garment_url: Optional[str] = None, + bottom_garment_url: Optional[str] = None, + resolution: int = -1, + restore_face: bool = True + ): + """ + 调用阿里百炼平台的试衣服务 + + Args: + person_image_url: 人物图片URL + top_garment_url: 上衣图片URL + bottom_garment_url: 下衣图片URL + resolution: 分辨率,-1表示自动 + restore_face: 是否修复面部 + + Returns: + Dict: 包含任务ID和请求ID的响应 + """ + try: + # 验证参数 + if not person_image_url: + raise ValueError("人物图片URL不能为空") + + if not top_garment_url and not bottom_garment_url: + raise ValueError("上衣和下衣图片至少需要提供一个") + + # 构建请求数据 + request_data = { + "model": "aitryon", + "input": { + "person_image_url": person_image_url + }, + "parameters": { + "resolution": resolution, + "restore_face": restore_face + } + } + + # 添加可选字段 + if top_garment_url: + request_data["input"]["top_garment_url"] = top_garment_url + if bottom_garment_url: + request_data["input"]["bottom_garment_url"] = bottom_garment_url + + # 构建请求头 + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + } + + # 发送请求 + async with httpx.AsyncClient() as client: + response = await client.post( + self.image_synthesis_url, + json=request_data, + headers=headers, + timeout=30.0 + ) + + if response.status_code == 200: + response_data = response.json() + logger.info(f"试穿请求发送成功,任务ID: {response_data.get('output', {}).get('task_id')}") + return response_data + else: + error_msg = f"试穿请求失败: {response.status_code} - {response.text}" + logger.error(error_msg) + raise Exception(error_msg) + except Exception as e: + logger.error(f"DashScope试穿API调用出错: {str(e)}") + raise e + + async def check_tryon_status(self, task_id: str): + """ + 检查试穿任务状态 + + Args: + task_id: 任务ID + + Returns: + Dict: 任务状态信息 + """ + try: + # 构建请求头 + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + } + + # 构建请求URL + status_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}" + + # 发送请求 + async with httpx.AsyncClient() as client: + response = await client.get( + status_url, + headers=headers, + timeout=30.0 + ) + + if response.status_code == 200: + response_data = response.json() + logger.info(f"试穿任务状态查询成功: {response_data.get('output', {}).get('task_status')}") + return response_data + else: + error_msg = f"试穿任务状态查询失败: {response.status_code} - {response.text}" + logger.error(error_msg) + raise Exception(error_msg) + except Exception as e: + logger.error(f"查询试穿任务状态出错: {str(e)}") + raise e \ No newline at end of file diff --git a/app/services/qcloud_service.py b/app/services/qcloud_service.py new file mode 100644 index 0000000..4399e9d --- /dev/null +++ b/app/services/qcloud_service.py @@ -0,0 +1,259 @@ +import os +import logging +import time +import uuid +import base64 +import hashlib +from typing import Optional, List, Dict, Any, BinaryIO, Union +from urllib.parse import urljoin + +from qcloud_cos import CosConfig, CosS3Client +from qcloud_cos.cos_exception import CosServiceError, CosClientError +import sts.sts + +from app.utils.config import get_settings + +logger = logging.getLogger(__name__) + +class QCloudCOSService: + """腾讯云对象存储(COS)服务类""" + + def __init__(self): + """初始化腾讯云COS服务""" + settings = get_settings() + self.secret_id = settings.qcloud_secret_id + self.secret_key = settings.qcloud_secret_key + self.region = settings.qcloud_cos_region + self.bucket = settings.qcloud_cos_bucket + self.domain = settings.qcloud_cos_domain + + # 创建COS客户端 + config = CosConfig( + Region=self.region, + SecretId=self.secret_id, + SecretKey=self.secret_key + ) + self.client = CosS3Client(config) + + async def upload_file( + self, + file_content: Union[bytes, BinaryIO], + file_name: Optional[str] = None, + directory: str = "uploads", + content_type: Optional[str] = None + ) -> Dict[str, str]: + """ + 上传文件到腾讯云COS + + Args: + file_content: 文件内容(bytes或文件对象) + file_name: 文件名,如果不提供则生成随机文件名 + directory: 存储目录 + content_type: 文件MIME类型 + + Returns: + Dict[str, str]: 包含文件URL等信息的字典 + """ + try: + # 生成文件名(如果未提供) + if not file_name: + # 使用UUID生成随机文件名 + random_filename = str(uuid.uuid4()) + if isinstance(file_content, bytes): + # 对于字节流,我们无法确定文件扩展名,默认用.bin + file_name = f"{random_filename}.bin" + else: + # 假设file_content是文件对象,尝试从其name属性获取扩展名 + if hasattr(file_content, 'name'): + original_name = os.path.basename(file_content.name) + ext = os.path.splitext(original_name)[1] + file_name = f"{random_filename}{ext}" + else: + file_name = f"{random_filename}.bin" + + # 构建对象键(文件在COS中的完整路径) + key = f"{directory}/{file_name}" + + # 上传文件 + response = self.client.put_object( + Bucket=self.bucket, + Body=file_content, + Key=key, + ContentType=content_type + ) + + # 构建文件URL + file_url = urljoin(self.domain, key) + + return { + "url": file_url, + "key": key, + "file_name": file_name, + "content_type": content_type, + "etag": response["ETag"] if "ETag" in response else None + } + + except (CosServiceError, CosClientError) as e: + logger.error(f"腾讯云COS上传文件失败: {str(e)}") + raise Exception(f"文件上传失败: {str(e)}") + + async def delete_file(self, key: str) -> bool: + """ + 从腾讯云COS删除文件 + + Args: + key: 文件的对象键(COS路径) + + Returns: + bool: 删除是否成功 + """ + try: + self.client.delete_object( + Bucket=self.bucket, + Key=key + ) + return True + except (CosServiceError, CosClientError) as e: + logger.error(f"腾讯云COS删除文件失败: {str(e)}") + return False + + async def get_file_url(self, key: str, expires: int = 3600) -> str: + """ + 获取COS文件的临时访问URL + + Args: + key: 文件的对象键(COS路径) + expires: URL的有效期(秒) + + Returns: + str: 临时访问URL + """ + try: + url = self.client.get_presigned_url( + Method='GET', + Bucket=self.bucket, + Key=key, + Expired=expires + ) + return url + except (CosServiceError, CosClientError) as e: + logger.error(f"获取腾讯云COS文件URL失败: {str(e)}") + raise Exception(f"获取文件URL失败: {str(e)}") + + async def generate_cos_sts_token( + self, + allow_actions: Optional[List[str]] = None, + allow_prefix: str = "*", + duration_seconds: int = 1800 + ) -> Dict[str, Any]: + """ + 生成COS的临时安全凭证(STS),用于前端直传 + + Args: + allow_actions: 允许的COS操作列表 + allow_prefix: 允许操作的对象前缀 + duration_seconds: 凭证有效期(秒) + + Returns: + Dict[str, Any]: STS凭证信息 + """ + try: + if allow_actions is None: + # 默认只允许上传操作 + allow_actions = [ + 'name/cos:PutObject', + 'name/cos:PostObject', + 'name/cos:InitiateMultipartUpload', + 'name/cos:ListMultipartUploads', + 'name/cos:ListParts', + 'name/cos:UploadPart', + 'name/cos:CompleteMultipartUpload' + ] + + # 配置STS + config = { + 'url': 'https://sts.tencentcloudapi.com/', + 'domain': 'sts.tencentcloudapi.com', + 'duration_seconds': duration_seconds, + 'secret_id': self.secret_id, + 'secret_key': self.secret_key, + 'region': self.region, + 'policy': { + 'version': '2.0', + 'statement': [ + { + 'action': allow_actions, + 'effect': 'allow', + 'resource': [ + f'qcs::cos:{self.region}:uid/:{self.bucket}/{allow_prefix}' + ] + } + ] + } + } + + sts_client = sts.sts.Sts(config) + response = sts_client.get_credential() + return response + + except Exception as e: + logger.error(f"生成腾讯云COS STS凭证失败: {str(e)}") + raise Exception(f"生成COS临时凭证失败: {str(e)}") + + async def list_files( + self, + directory: str = "", + limit: int = 100, + marker: str = "" + ) -> Dict[str, Any]: + """ + 列出COS中的文件 + + Args: + directory: 目录前缀 + limit: 返回的最大文件数 + marker: 分页标记 + + Returns: + Dict[str, Any]: 文件列表信息 + """ + try: + # 处理目录前缀 + prefix = directory + if prefix and not prefix.endswith('/'): + prefix += '/' + + # 调用COS API列出对象 + response = self.client.list_objects( + Bucket=self.bucket, + Prefix=prefix, + Marker=marker, + MaxKeys=limit + ) + + # 处理响应 + files = [] + if 'Contents' in response: + for item in response['Contents']: + key = item.get('Key', '') + # 过滤出文件(忽略目录) + if not key.endswith('/'): + file_url = urljoin(self.domain, key) + files.append({ + 'key': key, + 'url': file_url, + 'size': item.get('Size', 0), + 'last_modified': item.get('LastModified', ''), + 'etag': item.get('ETag', '').strip('"') + }) + + return { + 'files': files, + 'is_truncated': response.get('IsTruncated', False), + 'next_marker': response.get('NextMarker', ''), + 'common_prefixes': response.get('CommonPrefixes', []) + } + + except (CosServiceError, CosClientError) as e: + logger.error(f"列出腾讯云COS文件失败: {str(e)}") + raise Exception(f"列出文件失败: {str(e)}") \ No newline at end of file diff --git a/app/utils/__init__.py b/app/utils/__init__.py new file mode 100644 index 0000000..0519ecb --- /dev/null +++ b/app/utils/__init__.py @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/app/utils/config.py b/app/utils/config.py new file mode 100644 index 0000000..819f8ae --- /dev/null +++ b/app/utils/config.py @@ -0,0 +1,83 @@ +import os +from functools import lru_cache +from typing import Optional +from pydantic import BaseModel +from dotenv import load_dotenv + +# 加载环境变量(如果直接从main导入,这里可能是冗余的,但为了安全起见) +load_dotenv() + +class Settings(BaseModel): + """应用程序配置类""" + # DashScope配置 + dashscope_api_key: str = os.getenv("DASHSCOPE_API_KEY", "sk-caa199589f1c451aaac471fad2986e28") + + # 服务器配置 + host: str = os.getenv("HOST", "0.0.0.0") + port: int = int(os.getenv("PORT", "9001")) + debug: bool = os.getenv("DEBUG", "False").lower() in ["true", "1", "t", "yes"] + + # 腾讯云配置 + qcloud_secret_id: str = os.getenv("QCLOUD_SECRET_ID", "") + qcloud_secret_key: str = os.getenv("QCLOUD_SECRET_KEY", "") + qcloud_cos_region: str = os.getenv("QCLOUD_COS_REGION", "ap-guangzhou") + qcloud_cos_bucket: str = os.getenv("QCLOUD_COS_BUCKET", "") + qcloud_cos_domain: str = os.getenv("QCLOUD_COS_DOMAIN", "") + + # 数据库配置 + db_host: str = os.getenv("DB_HOST", "localhost") + db_port: int = int(os.getenv("DB_PORT", "3306")) + db_user: str = os.getenv("DB_USER", "root") + db_password: str = os.getenv("DB_PASSWORD", "password") + db_name: str = os.getenv("DB_NAME", "ai_dressing") + + # 数据库URL + @property + def database_url(self) -> str: + """ + 构建SQLAlchemy数据库连接URL + + Returns: + str: 数据库连接URL + """ + return f"mysql+pymysql://{self.db_user}:{self.db_password}@{self.db_host}:{self.db_port}/{self.db_name}" + + # 其他配置可以根据需要添加 + log_level: str = os.getenv("LOG_LEVEL", "INFO") + +@lru_cache() +def get_settings() -> Settings: + """ + 获取应用程序配置 + 使用lru_cache装饰器避免重复读取环境变量,提高性能 + + Returns: + Settings: 配置对象 + """ + return Settings() + +def validate_api_key(): + """ + 验证API密钥是否已配置 + + Raises: + ValueError: 如果API密钥未配置 + """ + settings = get_settings() + if not settings.dashscope_api_key: + raise ValueError("DASHSCOPE_API_KEY未设置,请在.env文件中配置") + return True + +def validate_qcloud_config(): + """ + 验证腾讯云配置是否已设置 + + Raises: + ValueError: 如果腾讯云配置未正确设置 + """ + settings = get_settings() + if not settings.qcloud_secret_id or not settings.qcloud_secret_key: + raise ValueError("腾讯云SecretId和SecretKey未设置,请在.env文件中配置") + if not settings.qcloud_cos_bucket: + raise ValueError("腾讯云COS存储桶名称未设置,请在.env文件中配置") + return True \ No newline at end of file diff --git a/create_migration.py b/create_migration.py new file mode 100644 index 0000000..b0ac96e --- /dev/null +++ b/create_migration.py @@ -0,0 +1,60 @@ +import os +import sys +import argparse +from alembic.config import Config +from alembic import command + +def create_migration(message=None): + """ + 创建数据库迁移 + + Args: + message: 迁移消息,描述此次迁移的内容 + """ + # 获取Alembic配置 + alembic_cfg = Config("alembic.ini") + + # 创建迁移 + if message: + command.revision(alembic_cfg, message=message, autogenerate=True) + else: + command.revision(alembic_cfg, message="自动生成的迁移", autogenerate=True) + +def upgrade_database(): + """更新数据库到最新版本""" + alembic_cfg = Config("alembic.ini") + command.upgrade(alembic_cfg, "head") + +def downgrade_database(revision): + """回滚数据库到指定版本""" + alembic_cfg = Config("alembic.ini") + command.downgrade(alembic_cfg, revision) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="数据库迁移工具") + subparsers = parser.add_subparsers(dest="command", help="命令") + + # 创建迁移命令 + create_parser = subparsers.add_parser("create", help="创建数据库迁移") + create_parser.add_argument("-m", "--message", help="迁移描述消息") + + # 更新数据库命令 + upgrade_parser = subparsers.add_parser("upgrade", help="更新数据库到最新版本") + + # 回滚数据库命令 + downgrade_parser = subparsers.add_parser("downgrade", help="回滚数据库到指定版本") + downgrade_parser.add_argument("revision", help="目标版本,例如:-1表示回滚一个版本") + + args = parser.parse_args() + + if args.command == "create": + create_migration(args.message) + print("迁移文件已创建,请查看 alembic/versions 目录") + elif args.command == "upgrade": + upgrade_database() + print("数据库已更新到最新版本") + elif args.command == "downgrade": + downgrade_database(args.revision) + print(f"数据库已回滚到版本: {args.revision}") + else: + parser.print_help() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..bbb9e70 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,12 @@ +fastapi==0.104.0 +uvicorn==0.23.2 +dashscope>=1.13.0 +python-dotenv==1.0.0 +pydantic==2.4.2 +httpx==0.25.0 +cos-python-sdk-v5==1.9.26 +qcloud-python-sts==3.1.4 +sqlalchemy==2.0.23 +pymysql==1.1.0 +cryptography==41.0.5 +alembic==1.12.1 \ No newline at end of file diff --git a/run.py b/run.py new file mode 100644 index 0000000..f5008f7 --- /dev/null +++ b/run.py @@ -0,0 +1,105 @@ +import uvicorn +import logging +import sys +import os +from app.utils.config import get_settings, validate_api_key, validate_qcloud_config +from app.database import engine + +def check_dependencies(): + """检查依赖项是否正确安装""" + try: + import dashscope + print(f"已加载 DashScope 版本: {getattr(dashscope, '__version__', '未知')}") + except ImportError as e: + print(f"错误: 无法导入 dashscope 模块: {str(e)}") + print("请确保已正确安装 dashscope: pip install -U dashscope") + return False + + try: + from qcloud_cos import CosConfig + print("已加载腾讯云 COS SDK") + except ImportError as e: + print(f"警告: 无法导入腾讯云 COS SDK: {str(e)}") + print("腾讯云 COS 相关功能可能无法使用") + print("请确保已正确安装: pip install -U cos-python-sdk-v5 qcloud-python-sts") + + try: + import sqlalchemy + print(f"已加载 SQLAlchemy 版本: {sqlalchemy.__version__}") + import pymysql + print(f"已加载 PyMySQL 版本: {pymysql.__version__}") + except ImportError as e: + print(f"错误: 无法导入数据库模块: {str(e)}") + print("请确保已正确安装: pip install -U sqlalchemy pymysql") + return False + + return True + +def check_database_connection(): + """检查数据库连接""" + try: + settings = get_settings() + + # 尝试连接数据库 + with engine.connect() as connection: + print(f"成功连接到数据库: {settings.db_name}") + + return True + except Exception as e: + print(f"警告: 无法连接到数据库: {str(e)}") + print("请确保MySQL服务已启动,并且数据库配置正确") + print("您可以使用以下命令创建数据库:") + settings = get_settings() + print(f"CREATE DATABASE {settings.db_name} CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;") + print(f"GRANT ALL PRIVILEGES ON {settings.db_name}.* TO '{settings.db_user}'@'%' IDENTIFIED BY '{settings.db_password}';") + print("FLUSH PRIVILEGES;") + + # 询问是否继续 + if input("是否继续启动应用程序?(y/n): ").lower() not in ["y", "yes"]: + return False + + return True + +if __name__ == "__main__": + try: + # 检查依赖项 + if not check_dependencies(): + sys.exit(1) + + # 验证API密钥是否配置 + validate_api_key() + + # 验证腾讯云配置 + try: + validate_qcloud_config() + print("腾讯云COS配置已验证") + except ValueError as e: + print(f"警告: {str(e)}") + print("腾讯云COS相关功能可能无法正常使用,但应用程序将继续启动") + + # 检查数据库连接 + if not check_database_connection(): + sys.exit(1) + + # 获取配置 + settings = get_settings() + + # 配置日志级别 + log_level = getattr(logging, settings.log_level.upper(), logging.INFO) + logging.basicConfig(level=log_level) + + # 启动服务器 + print(f"启动服务,访问地址: http://{settings.host}:{settings.port}") + uvicorn.run( + "app.main:app", + host=settings.host, + port=settings.port, + reload=settings.debug, + log_level=settings.log_level.lower() + ) + except Exception as e: + logging.error(f"启动失败: {str(e)}") + if hasattr(e, "__traceback__"): + import traceback + traceback.print_exception(type(e), e, e.__traceback__) + exit(1) \ No newline at end of file