This commit is contained in:
aaron 2025-03-21 22:49:03 +08:00
parent 012b0bbde3
commit 94c741b024
24 changed files with 1168 additions and 305 deletions

View File

@ -1,21 +0,0 @@
# DashScope API密钥
DASHSCOPE_API_KEY=sk-caa199589f1c451aaac471fad2986e28
# 服务器配置
HOST=127.0.0.1
PORT=9001
DEBUG=True
# 腾讯云配置
QCLOUD_SECRET_ID=AKIDxnbGj281iHtKallqqzvlV5YxBCrPltnS
QCLOUD_SECRET_KEY=ta6PXTMBsX7dzA7IN6uYUFn8F9uTovoU
QCLOUD_COS_REGION=ap-chengdu
QCLOUD_COS_BUCKET=aidress-1311994147
QCLOUD_COS_DOMAIN=https://aidress-1311994147.cos.ap-chengdu.myqcloud.com
# 数据库配置
DB_HOST=gz-cynosdbmysql-grp-2j1cnopr.sql.tencentcdb.com
DB_PORT=27469
DB_USER=root
DB_PASSWORD=Aa#223388
DB_NAME=aidress

View File

@ -1,22 +1,30 @@
# 基本配置
ENV=development
HOST=0.0.0.0
PORT=9001
DEBUG=true
# 数据库配置 # 数据库配置
DB_HOST=localhost DB_HOST=localhost
DB_PORT=3306 DB_PORT=3306
DB_USER=ai_user DB_USER=ai_user
DB_PASSWORD=your_password DB_PASSWORD=yourpassword
DB_NAME=ai_dressing DB_NAME=ai_dressing
# 阿里云大模型API配置 # 阿里云DashScope配置
DASHSCOPE_API_KEY=your_dashscope_api_key DASHSCOPE_API_KEY=your_dashscope_api_key
DASHSCOPE_MODEL_NAME=qwen-vl-plus
# 腾讯云配置 # 腾讯云配置
QCLOUD_SECRET_ID=your_qcloud_secret_id QCLOUD_SECRET_ID=your_qcloud_secret_id
QCLOUD_SECRET_KEY=your_qcloud_secret_key QCLOUD_SECRET_KEY=your_qcloud_secret_key
QCLOUD_COS_REGION=ap-chengdu QCLOUD_COS_REGION=ap-chengdu
QCLOUD_COS_BUCKET=your-bucket-name QCLOUD_COS_BUCKET=your-bucket-name
QCLOUD_COS_DOMAIN=https://your-bucket-name.cos.ap-chengdu.myqcloud.com QCLOUD_COS_DOMAIN=https://your-bucket-domain.com
# 应用程序配置 # 应用特定配置
HOST=0.0.0.0 UPLOAD_DIR=uploads
PORT=9001
DEBUG=False
LOG_LEVEL=INFO LOG_LEVEL=INFO
# Docker配置
DOCKER_MYSQL_ROOT_PASSWORD=rootpassword

23
.gitignore vendored
View File

@ -1,5 +1,8 @@
# 环境变量 # 环境变量
# .env .env
.env.*
!.env.example
.venv
# Python # Python
__pycache__/ __pycache__/
@ -36,7 +39,25 @@ ENV/
# 日志 # 日志
*.log *.log
logs/
# 操作系统 # 操作系统
.DS_Store .DS_Store
Thumbs.db Thumbs.db
# 上传文件
uploads/
data/
# 数据库
*.db
*.sqlite
*.sqlite3
# Alembic版本
alembic/versions/*
!alembic/versions/.gitkeep
# Docker相关
.docker/
docker-compose.override.yml

View File

@ -1,78 +1,67 @@
# 使用Python 3.10作为基础镜像 # 使用多阶段构建,先安装构建依赖
FROM python:3.10-slim FROM python:3.9-slim AS builder
# 设置环境变量
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
# 设置时区 # 设置时区
RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime ENV TZ=Asia/Shanghai
# 清空所有默认源 RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
RUN rm -rf /etc/apt/sources.list.d/* && \
rm -f /etc/apt/sources.list
# 替换为阿里云源 # 设置工作目录
RUN echo "\ WORKDIR /build
deb https://mirrors.aliyun.com/debian/ bookworm main non-free-firmware contrib\n\
deb-src https://mirrors.aliyun.com/debian/ bookworm main non-free-firmware contrib\n\
deb https://mirrors.aliyun.com/debian/ bookworm-updates main non-free-firmware contrib\n\
deb-src https://mirrors.aliyun.com/debian/ bookworm-updates main non-free-firmware contrib\n\
deb https://mirrors.aliyun.com/debian/ bookworm-backports main non-free-firmware contrib\n\
deb-src https://mirrors.aliyun.com/debian/ bookworm-backports main non-free-firmware contrib\n\
deb https://mirrors.aliyun.com/debian-security bookworm-security main non-free-firmware contrib\n\
deb-src https://mirrors.aliyun.com/debian-security bookworm-security main non-free-firmware contrib\n\
" > /etc/apt/sources.list
# 安装系统依赖 # 安装netcat用于网络连接检查
RUN apt-get update \ RUN apt-get update && apt-get install -y --no-install-recommends \
&& apt-get install -y --no-install-recommends \ curl \
build-essential \ build-essential \
default-libmysqlclient-dev \ default-libmysqlclient-dev \
pkg-config \ netcat-openbsd \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# 复制项目依赖文件
COPY requirements.txt .
# 安装Python依赖
RUN pip install --no-cache-dir -r requirements.txt
# 最终阶段,使用轻量镜像
FROM python:3.9-slim
# 设置时区
ENV TZ=Asia/Shanghai
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
# 安装运行时依赖
RUN apt-get update && apt-get install -y --no-install-recommends \
curl \ curl \
nano \ default-libmysqlclient-dev \
netcat-openbsd \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
# 设置工作目录 # 设置工作目录
WORKDIR /app WORKDIR /app
# 安装Python依赖先于文件复制利用缓存 # 从构建阶段复制已安装的依赖
COPY requirements.txt . COPY --from=builder /usr/local/lib/python3.9/site-packages /usr/local/lib/python3.9/site-packages
RUN pip install -i https://mirrors.aliyun.com/pypi/simple/ -r requirements.txt \ COPY --from=builder /usr/local/bin /usr/local/bin
&& pip install -i https://mirrors.aliyun.com/pypi/simple/ uvicorn python-multipart python-dotenv
# 复制项目文件 # 复制应用代码
COPY app app/ COPY . .
COPY *.py ./
COPY entrypoint.sh .
# 创建一个默认的.env.example文件 # 确保脚本可执行
RUN echo "# 数据库配置\n\ RUN chmod +x entrypoint.sh
DB_HOST=db\n\
DB_PORT=3306\n\
DB_USER=ai_user\n\
DB_PASSWORD=yourpassword\n\
DB_NAME=ai_dressing\n\
\n\
# 阿里云DashScope配置\n\
DASHSCOPE_API_KEY=your_dashscope_api_key\n\
\n\
# 腾讯云配置\n\
QCLOUD_SECRET_ID=your_qcloud_secret_id\n\
QCLOUD_SECRET_KEY=your_qcloud_secret_key\n\
QCLOUD_COS_REGION=ap-chengdu\n\
QCLOUD_COS_BUCKET=your-bucket-name\n\
QCLOUD_COS_DOMAIN=https://your-bucket-domain.com\n\
" > /app/.env.example
# 确保entrypoint.sh可执行 # 设置环境变量
RUN chmod +x /app/entrypoint.sh ENV PYTHONPATH=/app:$PYTHONPATH
ENV PORT=8000
ENV HOST=0.0.0.0
# 暴露端口 # 暴露端口
EXPOSE 8000 EXPOSE 8000
# 设置入口点 # 添加健康检查
ENTRYPOINT ["/app/entrypoint.sh"] HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
# 启动命令 # 设置入口点
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--log-level", "info"] ENTRYPOINT ["./entrypoint.sh"]

View File

@ -156,6 +156,85 @@ python run.py
docker-compose down -v # 谨慎使用,这将删除数据库中的所有数据 docker-compose down -v # 谨慎使用,这将删除数据库中的所有数据
``` ```
## Docker构建和环境变量配置指南
## 使用.env文件构建Docker镜像
本项目提供了多种方式将`.env`文件中的环境变量传递到Docker容器中。
### 方法一使用build.sh脚本推荐
这种方法会自动读取`.env`文件中的所有环境变量并传递给Docker构建过程
```bash
# 赋予执行权限
chmod +x build.sh
# 执行构建脚本
./build.sh
```
构建完成后,可以直接运行容器:
```bash
docker run -p 9001:8000 ai-dressing:latest
```
### 方法二使用Docker Compose
使用专门的docker-compose文件进行构建该文件会读取.env中的环境变量
```bash
# 先加载.env文件到当前shell
export $(grep -v '^#' .env | xargs)
# 使用专用的构建配置文件
docker-compose -f docker-compose.build.yml build
# 启动服务
docker-compose -f docker-compose.build.yml up -d
```
### 方法三:手动构建
如果你希望手动指定部分环境变量,可以使用以下命令:
```bash
docker build \
--build-arg DASHSCOPE_API_KEY=your_key \
--build-arg DB_HOST=your_db_host \
--build-arg DB_USER=your_db_user \
--build-arg DB_PASSWORD=your_db_password \
-t ai-dressing:latest .
```
## 验证环境变量
构建完成后,你可以通过以下命令验证环境变量是否正确加载:
```bash
# 查看容器日志
docker logs ai-dressing-app
# 或者进入容器查看环境变量
docker exec -it ai-dressing-app bash -c "env | sort"
```
## 环境变量优先级
环境变量的加载优先级从高到低如下:
1. docker run或docker-compose启动时通过`-e`或`environment`指定的环境变量
2. 通过卷挂载到容器中的`.env`文件
3. 构建时通过`--build-arg`传入并保存到容器内`.env.built`文件的变量
4. 代码中的默认值
## 注意事项
- 生产环境中应避免将敏感信息硬编码在Dockerfile中
- 敏感信息应通过环境变量或Docker secrets进行管理
- 建议将`.env`文件添加到`.gitignore`中,避免意外提交
## API 文档 ## API 文档
启动服务后,访问以下地址查看自动生成的 API 文档: 启动服务后,访问以下地址查看自动生成的 API 文档:

View File

@ -0,0 +1,3 @@
from app.exceptions.http_exception import CustomHTTPException, setup_exception_handlers
__all__ = ["CustomHTTPException", "setup_exception_handlers"]

View File

@ -0,0 +1,90 @@
from fastapi import FastAPI, Request, status
from fastapi.responses import JSONResponse
from fastapi.exceptions import RequestValidationError
from starlette.exceptions import HTTPException as StarletteHTTPException
import logging
from typing import Any, Dict, Optional, Union
from app.utils.response import APIResponse
logger = logging.getLogger(__name__)
class CustomHTTPException(Exception):
"""自定义HTTP异常"""
def __init__(
self,
status_code: int = 400,
code: int = None,
message: str = "操作失败",
data: Any = None
):
self.status_code = status_code
self.code = code or status_code
self.message = message
self.data = data
super().__init__(self.message)
def setup_exception_handlers(app: FastAPI) -> None:
"""设置FastAPI应用的异常处理程序"""
@app.exception_handler(CustomHTTPException)
async def custom_http_exception_handler(request: Request, exc: CustomHTTPException):
"""处理自定义HTTP异常"""
logger.error(f"自定义HTTP异常: {exc.message} (状态码: {exc.status_code}, 业务码: {exc.code})")
return JSONResponse(
status_code=exc.status_code,
content=APIResponse.error(
message=exc.message,
code=exc.code,
data=exc.data
)
)
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
"""处理标准HTTP异常"""
logger.error(f"HTTP异常: {exc.detail} (状态码: {exc.status_code})")
return JSONResponse(
status_code=exc.status_code,
content=APIResponse.error(
message=str(exc.detail),
code=exc.status_code
)
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""处理请求验证异常"""
error_details = exc.errors()
error_messages = []
for error in error_details:
location = " -> ".join(str(loc) for loc in error["loc"])
message = f"{location}: {error['msg']}"
error_messages.append(message)
error_message = "; ".join(error_messages)
logger.error(f"验证错误: {error_message}")
return JSONResponse(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
content=APIResponse.error(
message="请求参数验证失败",
code=status.HTTP_422_UNPROCESSABLE_ENTITY,
data={"details": error_details}
)
)
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception):
"""处理通用异常"""
logger.exception(f"未处理的异常: {str(exc)}")
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=APIResponse.error(
message="服务器内部错误",
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
data={"detail": str(exc)} if app.debug else None
)
)

View File

@ -6,7 +6,10 @@ from dotenv import load_dotenv
from app.routers import qcloud_router, dress_router, tryon_router from app.routers import qcloud_router, dress_router, tryon_router
from app.utils.config import get_settings from app.utils.config import get_settings
from app.utils.response import APIResponse
from app.database import Base, engine from app.database import Base, engine
from app.middleware import ResponseWrapperMiddleware
from app.exceptions import setup_exception_handlers
# 创建数据库表 # 创建数据库表
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
@ -36,6 +39,12 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
# 添加响应包装中间件
app.add_middleware(ResponseWrapperMiddleware)
# 设置异常处理
setup_exception_handlers(app)
# 注册路由 # 注册路由
app.include_router(qcloud_router.router, prefix="/api/qcloud", tags=["腾讯云"]) app.include_router(qcloud_router.router, prefix="/api/qcloud", tags=["腾讯云"])
app.include_router(dress_router.router, prefix="/api/dresses", tags=["服装"]) app.include_router(dress_router.router, prefix="/api/dresses", tags=["服装"])
@ -44,22 +53,53 @@ app.include_router(tryon_router.router, prefix="/api/tryons", tags=["试穿"])
@app.get("/", tags=["健康检查"]) @app.get("/", tags=["健康检查"])
async def root(): async def root():
"""API 根端点""" """API 根端点"""
return {"status": "正常", "message": "服务运行中"} return APIResponse.ok(message="服务运行中")
@app.get("/health", tags=["健康检查"]) @app.get("/health", tags=["健康检查"])
async def health_check(): async def health_check():
"""健康检查端点用于Docker容器健康监控""" """健康检查端点用于Docker容器健康监控"""
return {"status": "healthy", "message": "服务运行正常"} health_data = {"checks": {}}
health_status = "healthy"
health_message = "服务运行正常"
# 检查数据库连接
try:
from app.database import SessionLocal
db = SessionLocal()
db.execute("SELECT 1")
db.close()
health_data["checks"]["database"] = {"status": "healthy", "message": "数据库连接正常"}
except Exception as e:
health_status = "unhealthy"
health_message = "服务异常"
health_data["checks"]["database"] = {"status": "unhealthy", "message": f"数据库连接失败: {str(e)}"}
# 检查API密钥
if os.getenv("DASHSCOPE_API_KEY"):
health_data["checks"]["dashscope"] = {"status": "configured", "message": "DashScope API密钥已配置"}
else:
health_data["checks"]["dashscope"] = {"status": "missing", "message": "DashScope API密钥未配置"}
if os.getenv("QCLOUD_SECRET_ID") and os.getenv("QCLOUD_SECRET_KEY"):
health_data["checks"]["qcloud"] = {"status": "configured", "message": "腾讯云凭证已配置"}
else:
health_data["checks"]["qcloud"] = {"status": "missing", "message": "腾讯云凭证未配置"}
if health_status == "healthy":
return APIResponse.ok(data=health_data, message=health_message)
else:
return APIResponse.error(message=health_message, code=503, data=health_data)
@app.get("/info", tags=["服务信息"]) @app.get("/info", tags=["服务信息"])
async def get_info(): async def get_info():
"""获取服务基本信息""" """获取服务基本信息"""
settings = get_settings() settings = get_settings()
return { info_data = {
"app_name": "AI-Dressing API", "app_name": "AI-Dressing API",
"version": "0.1.0", "version": "0.1.0",
"debug_mode": settings.debug, "debug_mode": settings.debug,
} }
return APIResponse.ok(data=info_data, message="服务信息获取成功")
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn

View File

@ -0,0 +1,3 @@
from app.middleware.response_wrapper import ResponseWrapperMiddleware
__all__ = ["ResponseWrapperMiddleware"]

View File

@ -0,0 +1,86 @@
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from typing import Any, Dict, Optional, Union
import json
class ResponseWrapperMiddleware(BaseHTTPMiddleware):
"""
响应包装中间件自动将所有API响应包装为标准格式
格式为
{
"success": true/false,
"code": 200,
"message": "操作成功",
"data": 原始响应数据
}
"""
async def dispatch(
self, request: Request, call_next
) -> Response:
# 排除不需要包装的路径
if self._should_skip_path(request.url.path):
return await call_next(request)
# 调用下一个中间件或路由处理函数
response = await call_next(request)
# 如果响应是JSON且未使用标准包装格式则进行包装
if (
isinstance(response, JSONResponse) and
self._should_wrap_response(response)
):
return self._wrap_response(response)
return response
def _should_skip_path(self, path: str) -> bool:
"""判断是否跳过包装处理"""
# 跳过文档相关路径
skip_paths = ["/docs", "/redoc", "/openapi.json"]
for skip_path in skip_paths:
if path.startswith(skip_path):
return True
return False
def _should_wrap_response(self, response: JSONResponse) -> bool:
"""判断是否需要包装响应"""
try:
content = response.body.decode()
data = json.loads(content)
# 已经是标准格式则不需要再包装
if isinstance(data, dict) and "success" in data and "code" in data and "message" in data:
return False
return True
except Exception:
return False
def _wrap_response(self, response: JSONResponse) -> JSONResponse:
"""包装响应为标准格式"""
try:
# 解析原始响应内容
content = response.body.decode()
data = json.loads(content)
# 构造标准格式响应
wrapped_data = {
"success": response.status_code < 400,
"code": response.status_code,
"message": "操作成功" if response.status_code < 400 else "操作失败",
"data": data
}
# 创建新的响应
return JSONResponse(
content=wrapped_data,
status_code=response.status_code,
headers=dict(response.headers),
)
except Exception:
# 出错时返回原始响应
return response

View File

@ -1,162 +1,90 @@
from fastapi import APIRouter, HTTPException, Depends, Query, Path from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import List, Optional from typing import List
import logging
from app.database import get_db from app.database import get_db
from app.models.dress import Dress, GarmentType from app.models.dress import Dress
from app.schemas.dress import DressCreate, DressUpdate, DressResponse from app.schemas.dress import DressCreate, DressUpdate, DressResponse, DressListResponse
from app.utils.response import APIResponse
logger = logging.getLogger(__name__)
router = APIRouter() router = APIRouter()
@router.post("/", response_model=DressResponse, status_code=201) @router.post("/", response_model=DressResponse)
async def create_dress( def create_dress(dress: DressCreate, db: Session = Depends(get_db)):
dress: DressCreate, """创建服装记录"""
db: Session = Depends(get_db)
):
"""
创建一个新的服装记录
- **name**: 服装名称必填
- **image_url**: 服装图片URL可选
- **garment_type**: 服装类型(TOP_GARMENT:上衣, BOTTOM_GARMENT:下衣)可选
- **description**: 服装描述可选
"""
try: try:
db_dress = Dress( db_dress = Dress(**dress.dict())
name=dress.name,
image_url=dress.image_url,
garment_type=dress.garment_type,
description=dress.description
)
db.add(db_dress) db.add(db_dress)
db.commit() db.commit()
db.refresh(db_dress) db.refresh(db_dress)
return db_dress return APIResponse.created(data=db_dress, message="服装创建成功")
except Exception as e: except Exception as e:
logger.error(f"创建服装记录失败: {str(e)}")
db.rollback() db.rollback()
raise HTTPException(status_code=500, detail=f"创建服装记录失败: {str(e)}") return APIResponse.error(message=f"服装创建失败: {str(e)}", code=500)
@router.get("/", response_model=List[DressResponse]) @router.get("/", response_model=DressListResponse)
async def get_dresses( def get_all_dresses(
skip: int = Query(0, description="跳过的记录数量"), skip: int = 0,
limit: int = Query(100, description="返回的最大记录数量"), limit: int = 100,
name: Optional[str] = Query(None, description="按名称过滤"),
garment_type: Optional[str] = Query(None, description="按服装类型过滤(TOP_GARMENT:上衣, BOTTOM_GARMENT:下衣)"),
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
""" """获取所有服装记录"""
获取服装列表支持分页和过滤
- **skip**: 跳过的记录数量用于分页
- **limit**: 返回的最大记录数量用于分页
- **name**: 按名称过滤可选
- **garment_type**: 按服装类型过滤可选
"""
try: try:
query = db.query(Dress) dresses = db.query(Dress).offset(skip).limit(limit).all()
total = db.query(Dress).count()
if name: return APIResponse.ok(
query = query.filter(Dress.name.ilike(f"%{name}%")) data={
"items": dresses,
if garment_type: "total": total,
try: "page": skip // limit + 1 if limit > 0 else 1,
garment_type_enum = GarmentType[garment_type] "size": limit
query = query.filter(Dress.garment_type == garment_type_enum) },
except KeyError: message="服装列表获取成功"
logger.warning(f"无效的服装类型: {garment_type}") )
# 继续查询,但不应用无效的过滤条件
dresses = query.offset(skip).limit(limit).all()
return dresses
except Exception as e: except Exception as e:
logger.error(f"获取服装列表失败: {str(e)}") return APIResponse.error(message=f"获取服装列表失败: {str(e)}", code=500)
raise HTTPException(status_code=500, detail=f"获取服装列表失败: {str(e)}")
@router.get("/{dress_id}", response_model=DressResponse) @router.get("/{dress_id}", response_model=DressResponse)
async def get_dress( def get_dress(dress_id: int, db: Session = Depends(get_db)):
dress_id: int = Path(..., description="服装ID"), """获取单个服装记录"""
db: Session = Depends(get_db)
):
"""
根据ID获取服装详情
- **dress_id**: 服装ID
"""
try: try:
dress = db.query(Dress).filter(Dress.id == dress_id).first() dress = db.query(Dress).filter(Dress.id == dress_id).first()
if not dress: if dress is None:
raise HTTPException(status_code=404, detail=f"未找到ID为{dress_id}的服装") return APIResponse.not_found(message=f"未找到ID为{dress_id}的服装")
return dress return APIResponse.ok(data=dress, message="服装获取成功")
except HTTPException:
raise
except Exception as e: except Exception as e:
logger.error(f"获取服装详情失败: {str(e)}") return APIResponse.error(message=f"获取服装失败: {str(e)}", code=500)
raise HTTPException(status_code=500, detail=f"获取服装详情失败: {str(e)}")
@router.put("/{dress_id}", response_model=DressResponse) @router.put("/{dress_id}", response_model=DressResponse)
async def update_dress( def update_dress(dress_id: int, dress: DressUpdate, db: Session = Depends(get_db)):
dress_id: int = Path(..., description="服装ID"), """更新服装记录"""
dress: DressUpdate = None,
db: Session = Depends(get_db)
):
"""
更新服装信息
- **dress_id**: 服装ID
- **name**: 服装名称可选
- **image_url**: 服装图片URL可选
- **garment_type**: 服装类型(TOP_GARMENT:上衣, BOTTOM_GARMENT:下衣)可选
- **description**: 服装描述可选
"""
try: try:
db_dress = db.query(Dress).filter(Dress.id == dress_id).first() db_dress = db.query(Dress).filter(Dress.id == dress_id).first()
if not db_dress: if db_dress is None:
raise HTTPException(status_code=404, detail=f"未找到ID为{dress_id}的服装") return APIResponse.not_found(message=f"未找到ID为{dress_id}的服装")
# 更新提供的字段 # 更新服装字段
if dress.name is not None: for field, value in dress.dict(exclude_unset=True).items():
db_dress.name = dress.name setattr(db_dress, field, value)
if dress.image_url is not None:
db_dress.image_url = dress.image_url
if dress.garment_type is not None:
db_dress.garment_type = dress.garment_type
if dress.description is not None:
db_dress.description = dress.description
db.commit() db.commit()
db.refresh(db_dress) db.refresh(db_dress)
return db_dress return APIResponse.ok(data=db_dress, message="服装更新成功")
except HTTPException:
raise
except Exception as e: except Exception as e:
logger.error(f"更新服装信息失败: {str(e)}")
db.rollback() db.rollback()
raise HTTPException(status_code=500, detail=f"更新服装信息失败: {str(e)}") return APIResponse.error(message=f"更新服装失败: {str(e)}", code=500)
@router.delete("/{dress_id}", status_code=204) @router.delete("/{dress_id}")
async def delete_dress( def delete_dress(dress_id: int, db: Session = Depends(get_db)):
dress_id: int = Path(..., description="服装ID"), """删除服装记录"""
db: Session = Depends(get_db)
):
"""
删除服装
- **dress_id**: 服装ID
"""
try: try:
db_dress = db.query(Dress).filter(Dress.id == dress_id).first() db_dress = db.query(Dress).filter(Dress.id == dress_id).first()
if not db_dress: if db_dress is None:
raise HTTPException(status_code=404, detail=f"未找到ID为{dress_id}的服装") return APIResponse.not_found(message=f"未找到ID为{dress_id}的服装")
db.delete(db_dress) db.delete(db_dress)
db.commit() db.commit()
return None return APIResponse.ok(message="服装删除成功")
except HTTPException:
raise
except Exception as e: except Exception as e:
logger.error(f"删除服装失败: {str(e)}")
db.rollback() db.rollback()
raise HTTPException(status_code=500, detail=f"删除服装失败: {str(e)}") return APIResponse.error(message=f"删除服装失败: {str(e)}", code=500)

View File

@ -13,7 +13,7 @@ from app.schemas.tryon import (
AiTryonRequest, AiTryonResponse, TaskInfo AiTryonRequest, AiTryonResponse, TaskInfo
) )
from app.services.dashscope_service import DashScopeService from app.services.dashscope_service import DashScopeService
from app.utils.response import APIResponse
# 加载环境变量 # 加载环境变量
load_dotenv() load_dotenv()
@ -63,7 +63,7 @@ async def create_tryon(
person_image_url=tryon_data.person_image_url person_image_url=tryon_data.person_image_url
) )
return db_tryon return APIResponse.ok(data=db_tryon, message="试穿记录创建成功")
except Exception as e: except Exception as e:
logger.error(f"创建试穿记录失败: {str(e)}") logger.error(f"创建试穿记录失败: {str(e)}")
db.rollback() db.rollback()
@ -106,7 +106,7 @@ async def send_tryon_request(
db_tryon.task_status = "ERROR" db_tryon.task_status = "ERROR"
db.commit() db.commit()
@router.get("/", response_model=List[TryOnResponse]) @router.get("", response_model=List[TryOnResponse])
async def get_tryons( async def get_tryons(
skip: int = Query(0, description="跳过的记录数量"), skip: int = Query(0, description="跳过的记录数量"),
limit: int = Query(100, description="返回的最大记录数量"), limit: int = Query(100, description="返回的最大记录数量"),
@ -119,8 +119,8 @@ async def get_tryons(
- **limit**: 返回的最大记录数量用于分页 - **limit**: 返回的最大记录数量用于分页
""" """
try: try:
tryons = db.query(TryOn).order_by(TryOn.created_at.desc()).offset(skip).limit(limit).all() tryons = db.query(TryOn).order_by(TryOn.id.desc()).offset(skip).limit(limit).all()
return tryons return APIResponse.ok(data=tryons, message="试穿记录列表获取成功")
except Exception as e: except Exception as e:
logger.error(f"获取试穿记录列表失败: {str(e)}") logger.error(f"获取试穿记录列表失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"获取试穿记录列表失败: {str(e)}") raise HTTPException(status_code=500, detail=f"获取试穿记录列表失败: {str(e)}")
@ -139,7 +139,7 @@ async def get_tryon(
tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first() tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first()
if not tryon: if not tryon:
raise HTTPException(status_code=404, detail=f"未找到ID为{tryon_id}的试穿记录") raise HTTPException(status_code=404, detail=f"未找到ID为{tryon_id}的试穿记录")
return tryon return APIResponse.ok(data=tryon, message="试穿记录获取成功")
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@ -172,7 +172,7 @@ async def update_tryon(
db.commit() db.commit()
db.refresh(db_tryon) db.refresh(db_tryon)
return db_tryon return APIResponse.ok(data=db_tryon, message="试穿记录更新成功")
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@ -217,7 +217,7 @@ async def check_tryon_status(
except Exception as e: except Exception as e:
logger.error(f"调用DashScope API检查任务状态失败: {str(e)}") logger.error(f"调用DashScope API检查任务状态失败: {str(e)}")
return db_tryon return APIResponse.ok(data=db_tryon, message="试穿任务状态检查成功")
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:

56
app/schemas/common.py Normal file
View File

@ -0,0 +1,56 @@
from pydantic import BaseModel, Field
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
T = TypeVar('T')
class StandardResponse(BaseModel, Generic[T]):
"""标准API响应模型"""
success: bool = Field(True, description="请求是否成功")
code: int = Field(200, description="业务状态码")
message: str = Field("操作成功", description="响应消息")
data: Optional[T] = Field(None, description="响应数据")
class PageInfo(BaseModel):
"""分页信息"""
page: int = Field(..., description="当前页码")
size: int = Field(..., description="每页大小")
total: int = Field(..., description="总记录数")
pages: int = Field(..., description="总页数")
class PagedResponse(BaseModel, Generic[T]):
"""分页响应数据"""
items: List[T] = Field(..., description="数据列表")
page_info: PageInfo = Field(..., description="分页信息")
class ErrorDetail(BaseModel):
"""错误详情"""
loc: List[str] = Field(..., description="错误位置")
msg: str = Field(..., description="错误消息")
type: str = Field(..., description="错误类型")
class ValidationError(BaseModel):
"""验证错误响应"""
detail: List[ErrorDetail]
class HealthCheck(BaseModel):
"""健康检查项目"""
status: str = Field(..., description="状态")
message: str = Field(..., description="消息")
class HealthCheckResponse(BaseModel):
"""健康检查响应"""
checks: Dict[str, HealthCheck] = Field(..., description="检查项目")
class TokenResponse(BaseModel):
"""令牌响应"""
access_token: str = Field(..., description="访问令牌")
token_type: str = Field("bearer", description="令牌类型")
expires_in: int = Field(..., description="过期时间(秒)")
class FileUploadResponse(BaseModel):
"""文件上传响应"""
file_id: str = Field(..., description="文件ID")
url: str = Field(..., description="文件URL")
file_name: str = Field(..., description="文件名")
content_type: str = Field(..., description="内容类型")
size: int = Field(..., description="文件大小(字节)")

View File

@ -1,9 +1,13 @@
from pydantic import BaseModel, Field, HttpUrl from pydantic import BaseModel, Field, HttpUrl
from typing import Optional from typing import Optional, List
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from app.models.dress import GarmentType from app.models.dress import GarmentType
class GarmentType(str, Enum):
TOP_GARMENT = "TOP_GARMENT"
BOTTOM_GARMENT = "BOTTOM_GARMENT"
class DressBase(BaseModel): class DressBase(BaseModel):
"""服装基础模型""" """服装基础模型"""
name: str = Field(..., description="服装名称", example="夏季连衣裙") name: str = Field(..., description="服装名称", example="夏季连衣裙")
@ -34,3 +38,19 @@ class DressInDB(DressBase):
class DressResponse(DressInDB): class DressResponse(DressInDB):
"""服装API响应模型""" """服装API响应模型"""
pass pass
class DressListResponse(BaseModel):
items: List[DressResponse]
total: int
page: int
size: int
class Config:
orm_mode = True
class StandardResponse(BaseModel):
"""标准API响应格式"""
success: bool = True
code: int = 200
message: str = "操作成功"
data: Optional[dict] = None

192
app/utils/client.py Normal file
View File

@ -0,0 +1,192 @@
import httpx
import json
import logging
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urljoin
logger = logging.getLogger(__name__)
class HttpClient:
"""
HTTP客户端工具类
封装了请求和响应处理逻辑
"""
def __init__(
self,
base_url: str = "",
timeout: int = 30,
headers: Optional[Dict[str, str]] = None,
verify_ssl: bool = True
):
"""
初始化HTTP客户端
Args:
base_url: API基础URL
timeout: 请求超时时间
headers: 默认请求头
verify_ssl: 是否验证SSL证书
"""
self.base_url = base_url
self.timeout = timeout
self.headers = headers or {}
self.verify_ssl = verify_ssl
async def request(
self,
method: str,
url: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Any] = None,
json_data: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[int] = None,
) -> Dict[str, Any]:
"""
发送HTTP请求
Args:
method: 请求方法 (GET, POST, PUT, DELETE等)
url: 请求URL路径
params: URL查询参数
data: 表单数据或二进制数据
json_data: JSON数据
headers: 请求头
timeout: 超时时间
Returns:
响应数据
Raises:
Exception: 请求失败时抛出异常
"""
if self.base_url:
full_url = urljoin(self.base_url, url)
else:
full_url = url
request_headers = {**self.headers}
if headers:
request_headers.update(headers)
timeout_value = timeout or self.timeout
try:
async with httpx.AsyncClient(verify=self.verify_ssl) as client:
response = await client.request(
method=method,
url=full_url,
params=params,
data=data,
json=json_data,
headers=request_headers,
timeout=timeout_value,
)
# 记录请求和响应
logger.debug(f"HTTP请求: {method} {full_url}")
logger.debug(f"状态码: {response.status_code}")
# 尝试解析JSON响应
response_data = None
try:
response_data = response.json()
except json.JSONDecodeError:
response_data = {"content": response.text}
# 检查状态码
response.raise_for_status()
# 验证是否成功
self._verify_success(response_data, response.status_code)
return response_data
except httpx.HTTPStatusError as e:
logger.error(f"HTTP状态错误: {e.response.status_code} - {e.response.text}")
try:
error_data = e.response.json()
msg = error_data.get("message", str(e))
except json.JSONDecodeError:
msg = e.response.text or str(e)
raise Exception(f"请求失败 ({e.response.status_code}): {msg}")
except httpx.RequestError as e:
logger.error(f"请求错误: {str(e)}")
raise Exception(f"请求错误: {str(e)}")
except Exception as e:
logger.error(f"未知错误: {str(e)}")
raise
def _verify_success(self, data: Dict[str, Any], status_code: int) -> None:
"""
验证响应是否成功
Args:
data: 响应数据
status_code: HTTP状态码
Raises:
Exception: 验证失败时抛出异常
"""
# 检查HTTP状态码
if status_code >= 400:
raise Exception(f"请求失败,状态码: {status_code}")
# 检查响应体中的success字段如果存在
if isinstance(data, dict) and "success" in data and data["success"] is False:
msg = data.get("message", "未知错误")
code = data.get("code", status_code)
raise Exception(f"请求失败 (code: {code}): {msg}")
async def get(
self,
url: str,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[int] = None,
) -> Dict[str, Any]:
"""发送GET请求"""
return await self.request("GET", url, params=params, headers=headers, timeout=timeout)
async def post(
self,
url: str,
data: Optional[Any] = None,
json_data: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[int] = None,
) -> Dict[str, Any]:
"""发送POST请求"""
return await self.request(
"POST", url, params=params, data=data, json_data=json_data,
headers=headers, timeout=timeout
)
async def put(
self,
url: str,
data: Optional[Any] = None,
json_data: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[int] = None,
) -> Dict[str, Any]:
"""发送PUT请求"""
return await self.request(
"PUT", url, params=params, data=data, json_data=json_data,
headers=headers, timeout=timeout
)
async def delete(
self,
url: str,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
timeout: Optional[int] = None,
) -> Dict[str, Any]:
"""发送DELETE请求"""
return await self.request("DELETE", url, params=params, headers=headers, timeout=timeout)

136
app/utils/response.py Normal file
View File

@ -0,0 +1,136 @@
from typing import Any, Dict, List, Optional, Union
from fastapi import status
from fastapi.responses import JSONResponse
class APIResponse:
"""API标准响应格式工具类"""
@staticmethod
def success(
data: Any = None,
message: str = "操作成功",
code: int = 200
) -> Dict[str, Any]:
"""
成功响应
Args:
data: 响应数据
message: 响应消息
code: 状态码
Returns:
标准响应格式的字典
"""
return {
"success": True,
"code": code,
"message": message,
"data": data
}
@staticmethod
def error(
message: str = "操作失败",
code: int = 400,
data: Any = None
) -> Dict[str, Any]:
"""
错误响应
Args:
message: 错误消息
code: 错误状态码
data: 附加错误数据
Returns:
标准响应格式的字典
"""
return {
"success": False,
"code": code,
"message": message,
"data": data
}
@staticmethod
def json_response(
data: Any = None,
message: str = "操作成功",
code: int = 200,
success: bool = True,
status_code: int = status.HTTP_200_OK,
headers: Dict[str, str] = None
) -> JSONResponse:
"""
返回JSONResponse对象
Args:
data: 响应数据
message: 响应消息
code: 业务状态码
success: 是否成功
status_code: HTTP状态码
headers: 自定义响应头
Returns:
JSONResponse对象
"""
content = {
"success": success,
"code": code,
"message": message,
"data": data
}
return JSONResponse(
content=content,
status_code=status_code,
headers=headers
)
# 常用响应码封装
@classmethod
def ok(cls, data: Any = None, message: str = "操作成功") -> Dict[str, Any]:
"""200 成功"""
return cls.success(data, message, 200)
@classmethod
def created(cls, data: Any = None, message: str = "创建成功") -> Dict[str, Any]:
"""201 创建成功"""
return cls.success(data, message, 201)
@classmethod
def accepted(cls, data: Any = None, message: str = "请求已接受") -> Dict[str, Any]:
"""202 已接受"""
return cls.success(data, message, 202)
@classmethod
def no_content(cls) -> Dict[str, Any]:
"""204 无内容"""
return cls.success(None, "无内容", 204)
@classmethod
def bad_request(cls, message: str = "请求参数错误") -> Dict[str, Any]:
"""400 请求错误"""
return cls.error(message, 400)
@classmethod
def unauthorized(cls, message: str = "未授权") -> Dict[str, Any]:
"""401 未授权"""
return cls.error(message, 401)
@classmethod
def forbidden(cls, message: str = "禁止访问") -> Dict[str, Any]:
"""403 禁止"""
return cls.error(message, 403)
@classmethod
def not_found(cls, message: str = "资源不存在") -> Dict[str, Any]:
"""404 不存在"""
return cls.error(message, 404)
@classmethod
def server_error(cls, message: str = "服务器内部错误") -> Dict[str, Any]:
"""500 服务器错误"""
return cls.error(message, 500)

View File

@ -4,12 +4,7 @@ services:
app: app:
build: build:
context: . context: .
args: dockerfile: Dockerfile
- DB_HOST=db
- DB_PORT=3306
- DB_USER=ai_user
- DB_PASSWORD=yourpassword
- DB_NAME=ai_dressing
container_name: ai-dressing-app container_name: ai-dressing-app
restart: always restart: always
ports: ports:
@ -18,6 +13,7 @@ services:
- ./.env - ./.env
# 环境变量可以覆盖.env文件中的值 # 环境变量可以覆盖.env文件中的值
environment: environment:
- ENV=development
- DB_HOST=db - DB_HOST=db
- DB_PORT=3306 - DB_PORT=3306
- DB_USER=ai_user - DB_USER=ai_user
@ -31,43 +27,55 @@ services:
- QCLOUD_COS_BUCKET=${QCLOUD_COS_BUCKET:-your-bucket-name} - QCLOUD_COS_BUCKET=${QCLOUD_COS_BUCKET:-your-bucket-name}
- QCLOUD_COS_DOMAIN=${QCLOUD_COS_DOMAIN:-https://your-bucket-domain.com} - QCLOUD_COS_DOMAIN=${QCLOUD_COS_DOMAIN:-https://your-bucket-domain.com}
- PYTHONPATH=/app - PYTHONPATH=/app
volumes:
- .:/app
- ./data/uploads:/app/uploads
depends_on: depends_on:
- db - db
volumes:
- ./.env:/app/.env:ro # 明确映射.env文件
# - ./app:/app/app # 开发时才使用,生产环境建议去掉
networks:
- ai-dressing-network
healthcheck: healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8000/health"] test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
interval: 30s interval: 30s
timeout: 10s timeout: 10s
retries: 3 retries: 3
start_period: 10s start_period: 10s
networks:
- ai-dressing-network
db: db:
image: mysql:8.0 image: mysql:8.0
container_name: ai-dressing-db container_name: ai-dressing-db
restart: always restart: always
ports:
- "3306:3306"
environment: environment:
- MYSQL_ROOT_PASSWORD=rootpassword - MYSQL_ROOT_PASSWORD=rootpassword
- MYSQL_DATABASE=ai_dressing
- MYSQL_USER=ai_user - MYSQL_USER=ai_user
- MYSQL_PASSWORD=yourpassword - MYSQL_PASSWORD=yourpassword
- MYSQL_DATABASE=ai_dressing ports:
- "3306:3306"
volumes: volumes:
- mysql-data:/var/lib/mysql - mysql-data:/var/lib/mysql
- ./mysql-init:/docker-entrypoint-initdb.d - ./init-scripts:/docker-entrypoint-initdb.d
command: --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci command: --default-authentication-plugin=mysql_native_password --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci
networks:
- ai-dressing-network
healthcheck: healthcheck:
test: ["CMD", "mysqladmin", "ping", "-h", "localhost", "-u", "root", "-p$$MYSQL_ROOT_PASSWORD"] test: ["CMD", "mysqladmin", "ping", "-h", "localhost", "-u", "root", "-p$$MYSQL_ROOT_PASSWORD"]
interval: 30s interval: 5s
timeout: 10s timeout: 5s
retries: 3 retries: 10
start_period: 30s networks:
- ai-dressing-network
adminer:
image: adminer
container_name: ai-dressing-adminer
restart: always
ports:
- "8080:8080"
environment:
- ADMINER_DEFAULT_SERVER=db
depends_on:
- db
networks:
- ai-dressing-network
volumes: volumes:
mysql-data: mysql-data:

View File

@ -1,52 +1,93 @@
#!/bin/bash #!/bin/bash
set -e set -e
echo "============= 容器启动 =============" echo "正在启动AI Dressing服务..."
echo "当前工作目录: $(pwd)"
echo "目录内容:"
ls -la
# 优先级1: 尝试从.env文件加载 # 检查是否存在.env文件不存在则使用内置的.env.built
if [ -f .env ]; then if [ ! -f ".env" ]; then
echo "找到.env文件加载环境变量..." echo "未找到.env文件使用构建时创建的.env.built文件"
# 打印文件内容(隐藏敏感信息) if [ -f ".env.built" ]; then
echo "文件内容预览(敏感信息已隐藏):" cp .env.built .env
grep -v "KEY\|PASSWORD\|SECRET" .env | cat -n echo "已复制.env.built到.env"
else
# 从.env文件导出所有环境变量 echo "警告:未找到.env.built文件将使用系统环境变量"
export $(grep -v '^#' .env | xargs) fi
echo "已从.env加载环境变量"
elif [ -f /app/.env ]; then
echo "在/app目录下找到.env文件加载环境变量..."
export $(grep -v '^#' /app/.env | xargs)
echo "已从/app/.env加载环境变量"
elif [ -f .env.docker ]; then
echo "找到.env.docker文件加载环境变量..."
export $(grep -v '^#' .env.docker | xargs)
echo "已从.env.docker加载环境变量"
else else
echo ".env文件不存在环境变量将从Docker环境中读取..." echo "使用已存在的.env文件"
fi fi
# 打印环境变量,确认是否已正确加载 # 输出关键环境变量(隐藏敏感信息)
echo "============= 环境变量检查 =============" echo "检查关键环境变量:"
echo "DB_HOST: $DB_HOST" # 数据库配置
echo "DB_PORT: $DB_PORT" if [ -n "$DB_HOST" ]; then
echo "DB_USER: $DB_USER" echo "- DB_HOST: 已设置 ✓"
echo "DB_NAME: $DB_NAME" else
echo "DASHSCOPE_API_KEY是否存在: $(if [ -n "$DASHSCOPE_API_KEY" ]; then echo "是"; else echo "否"; fi)" echo "- DB_HOST: 未设置 ❌"
echo "QCLOUD_SECRET_ID是否存在: $(if [ -n "$QCLOUD_SECRET_ID" ]; then echo "是"; else echo "否"; fi)" fi
# 确保Python能找到应用 if [ -n "$DB_PORT" ]; then
export PYTHONPATH=/app:$PYTHONPATH echo "- DB_PORT: 已设置 ✓"
else
echo "- DB_PORT: 未设置 ❌"
fi
echo "============= Python环境 =============" if [ -n "$DB_NAME" ]; then
echo "Python版本: $(python --version)" echo "- DB_NAME: 已设置 ✓"
echo "Python路径: $(which python)" else
echo "PYTHONPATH: $PYTHONPATH" echo "- DB_NAME: 未设置 ❌"
fi
echo "============= 启动应用 =============" # API密钥
echo "执行命令: $@" if [ -n "$DASHSCOPE_API_KEY" ]; then
echo "- DASHSCOPE_API_KEY: 已设置(已隐藏) ✓"
else
echo "- DASHSCOPE_API_KEY: 未设置 ❌"
fi
# 执行原始的命令 if [ -n "$QCLOUD_SECRET_ID" ] && [ -n "$QCLOUD_SECRET_KEY" ]; then
exec "$@" echo "- 腾讯云凭证: 已设置(已隐藏) ✓"
else
echo "- 腾讯云凭证: 未设置 ❌"
fi
# 检查依赖服务连接
echo "检查数据库连接..."
MAX_RETRIES=10
COUNT=0
if [ -n "$DB_HOST" ] && [ -n "$DB_PORT" ]; then
while [ $COUNT -lt $MAX_RETRIES ]; do
if nc -z -w3 $DB_HOST $DB_PORT; then
echo "数据库连接成功!"
break
fi
echo "等待数据库连接... ($((COUNT+1))/$MAX_RETRIES)"
COUNT=$((COUNT+1))
sleep 2
done
if [ $COUNT -eq $MAX_RETRIES ]; then
echo "警告:无法连接到数据库。服务将继续启动,但可能无法正常工作。"
fi
else
echo "跳过数据库连接检查未设置DB_HOST或DB_PORT"
fi
# 应用数据库迁移
echo "应用数据库迁移..."
if [ -f "create_migration.py" ]; then
python create_migration.py upgrade || echo "警告:数据库迁移失败,但将继续启动服务"
else
echo "未找到create_migration.py跳过数据库迁移"
fi
# 启动服务
echo "AI Dressing服务启动中..."
# 根据环境变量决定是否使用热重载
if [ "$ENV" = "development" ]; then
echo "以开发模式启动,启用热重载..."
exec uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
else
echo "以生产模式启动..."
exec uvicorn app.main:app --host 0.0.0.0 --port 8000 --workers 4
fi

View File

@ -0,0 +1,46 @@
-- 设置字符集和排序规则
SET NAMES utf8mb4;
SET GLOBAL time_zone = '+8:00';
-- 创建数据库(如果不存在)
CREATE DATABASE IF NOT EXISTS ai_dressing CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
-- 确保使用正确的数据库
USE ai_dressing;
-- 创建用户(如果不存在)
-- 注意在Docker环境中这些通常由MYSQL_USER和MYSQL_PASSWORD环境变量处理
-- 此处作为备用,确保权限正确设置
CREATE USER IF NOT EXISTS 'ai_user'@'%' IDENTIFIED BY 'yourpassword';
GRANT ALL PRIVILEGES ON ai_dressing.* TO 'ai_user'@'%';
FLUSH PRIVILEGES;
-- 设置必要的MySQL配置
SET GLOBAL max_connections = 500;
SET GLOBAL connect_timeout = 60;
SET GLOBAL wait_timeout = 600;
SET GLOBAL interactive_timeout = 600;
SET GLOBAL max_allowed_packet = 16777216; -- 16MB
-- 创建一个测试表,验证初始化是否成功
CREATE TABLE IF NOT EXISTS `system_info` (
`id` int(11) NOT NULL AUTO_INCREMENT,
`key` varchar(50) NOT NULL,
`value` text NOT NULL,
`description` varchar(255) DEFAULT NULL,
`created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP,
`updated_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
PRIMARY KEY (`id`),
UNIQUE KEY `key` (`key`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci;
-- 插入初始数据
INSERT INTO `system_info` (`key`, `value`, `description`)
VALUES
('system_name', 'AI Dressing System', '系统名称'),
('version', '1.0.0', '系统版本'),
('init_date', CURRENT_TIMESTAMP, '系统初始化日期')
ON DUPLICATE KEY UPDATE
`value` = VALUES(`value`),
`description` = VALUES(`description`),
`updated_at` = CURRENT_TIMESTAMP;

View File

@ -1,13 +1,32 @@
fastapi==0.104.0 # REST API依赖
uvicorn==0.23.2 fastapi>=0.68.0,<0.69.0
dashscope>=1.13.0 uvicorn>=0.15.0,<0.16.0
python-dotenv==1.0.0 python-multipart>=0.0.5,<0.1.0
pydantic==2.4.2 email-validator>=1.1.3,<2.0.0
httpx==0.25.0 pydantic>=1.8.0,<2.0.0
cos-python-sdk-v5==1.9.26 httpx>=0.23.0,<0.24.0
qcloud-python-sts==3.1.4
sqlalchemy==2.0.23 # 数据库依赖
pymysql==1.1.0 sqlalchemy>=1.4.0,<1.5.0
cryptography==41.0.5 alembic>=1.7.0,<1.8.0
alembic==1.12.1 pymysql>=1.0.2,<1.1.0
python-multipart==0.0.12 mysqlclient>=2.1.0,<2.2.0
# 阿里云SDK
dashscope>=1.5.0,<1.6.0
# 腾讯云SDK
cos-python-sdk-v5>=1.9.0,<2.0.0
# 工具库
python-dotenv>=0.19.1,<0.20.0
python-jose[cryptography]>=3.3.0,<3.4.0
passlib[bcrypt]>=1.7.4,<1.8.0
tenacity>=8.0.1,<8.1.0
loguru>=0.5.3,<0.6.0
dynaconf>=3.1.7,<3.2.0
# 测试依赖(开发环境)
pytest>=6.2.5,<6.3.0
pytest-asyncio>=0.18.0,<0.19.0
httpx>=0.23.0,<0.24.0

1
tests/__init__.py Normal file
View File

@ -0,0 +1 @@
# 测试包

1
tests/api/__init__.py Normal file
View File

@ -0,0 +1 @@
# API测试包

58
tests/api/test_health.py Normal file
View File

@ -0,0 +1,58 @@
from fastapi.testclient import TestClient
import pytest
from app.main import app
client = TestClient(app)
def test_root_endpoint():
"""测试根端点"""
response = client.get("/")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["code"] == 200
assert "服务运行中" in data["message"]
def test_health_endpoint():
"""测试健康检查端点"""
response = client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert "checks" in data["data"]
assert "database" in data["data"]["checks"]
# 验证响应格式
assert isinstance(data, dict)
assert "success" in data
assert "code" in data
assert "message" in data
assert "data" in data
# 验证success字段为布尔类型
assert isinstance(data["success"], bool)
# 验证code字段为整数
assert isinstance(data["code"], int)
# 验证message字段为字符串
assert isinstance(data["message"], str)
def test_info_endpoint():
"""测试服务信息端点"""
response = client.get("/info")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
assert data["code"] == 200
assert "服务信息获取成功" in data["message"]
# 验证响应的数据部分
assert "data" in data
assert "app_name" in data["data"]
assert "version" in data["data"]
assert "debug_mode" in data["data"]
if __name__ == "__main__":
pytest.main(["-v", "test_health.py"])

59
tests/conftest.py Normal file
View File

@ -0,0 +1,59 @@
import os
import sys
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
# 添加项目根目录到Python路径
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from app.main import app
from app.database import Base, get_db
# 使用SQLite内存数据库进行测试
SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db"
engine = create_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
@pytest.fixture(scope="session")
def db():
"""获取测试数据库会话"""
# 创建数据库表
Base.metadata.create_all(bind=engine)
# 创建会话
db = TestingSessionLocal()
try:
yield db
finally:
db.close()
# 清理:删除所有表
Base.metadata.drop_all(bind=engine)
# 删除测试数据库文件
if os.path.exists("./test.db"):
os.remove("./test.db")
@pytest.fixture
def client(db):
"""创建FastAPI测试客户端"""
def override_get_db():
try:
yield db
finally:
pass
# 替换依赖项
app.dependency_overrides[get_db] = override_get_db
# 创建测试客户端
with TestClient(app) as client:
yield client
# 清理:恢复原始依赖项
app.dependency_overrides = {}