diff --git a/.env.docker b/.env.docker deleted file mode 100644 index 9b54728..0000000 --- a/.env.docker +++ /dev/null @@ -1,21 +0,0 @@ -# DashScope API密钥 -DASHSCOPE_API_KEY=sk-caa199589f1c451aaac471fad2986e28 - -# 服务器配置 -HOST=127.0.0.1 -PORT=9001 -DEBUG=True - -# 腾讯云配置 -QCLOUD_SECRET_ID=AKIDxnbGj281iHtKallqqzvlV5YxBCrPltnS -QCLOUD_SECRET_KEY=ta6PXTMBsX7dzA7IN6uYUFn8F9uTovoU -QCLOUD_COS_REGION=ap-chengdu -QCLOUD_COS_BUCKET=aidress-1311994147 -QCLOUD_COS_DOMAIN=https://aidress-1311994147.cos.ap-chengdu.myqcloud.com - -# 数据库配置 -DB_HOST=gz-cynosdbmysql-grp-2j1cnopr.sql.tencentcdb.com -DB_PORT=27469 -DB_USER=root -DB_PASSWORD=Aa#223388 -DB_NAME=aidress \ No newline at end of file diff --git a/.env.example b/.env.example index 9f3fe46..ae22320 100644 --- a/.env.example +++ b/.env.example @@ -1,22 +1,30 @@ +# 基本配置 +ENV=development +HOST=0.0.0.0 +PORT=9001 +DEBUG=true + # 数据库配置 DB_HOST=localhost DB_PORT=3306 DB_USER=ai_user -DB_PASSWORD=your_password +DB_PASSWORD=yourpassword DB_NAME=ai_dressing -# 阿里云大模型API配置 +# 阿里云DashScope配置 DASHSCOPE_API_KEY=your_dashscope_api_key +DASHSCOPE_MODEL_NAME=qwen-vl-plus # 腾讯云配置 QCLOUD_SECRET_ID=your_qcloud_secret_id QCLOUD_SECRET_KEY=your_qcloud_secret_key QCLOUD_COS_REGION=ap-chengdu QCLOUD_COS_BUCKET=your-bucket-name -QCLOUD_COS_DOMAIN=https://your-bucket-name.cos.ap-chengdu.myqcloud.com +QCLOUD_COS_DOMAIN=https://your-bucket-domain.com -# 应用程序配置 -HOST=0.0.0.0 -PORT=9001 -DEBUG=False -LOG_LEVEL=INFO \ No newline at end of file +# 应用特定配置 +UPLOAD_DIR=uploads +LOG_LEVEL=INFO + +# Docker配置 +DOCKER_MYSQL_ROOT_PASSWORD=rootpassword \ No newline at end of file diff --git a/.gitignore b/.gitignore index bb1edf9..11d6e9a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,8 @@ # 环境变量 -# .env +.env +.env.* +!.env.example +.venv # Python __pycache__/ @@ -36,7 +39,25 @@ ENV/ # 日志 *.log +logs/ # 操作系统 .DS_Store -Thumbs.db \ No newline at end of file +Thumbs.db + +# 上传文件 +uploads/ +data/ + +# 数据库 +*.db +*.sqlite +*.sqlite3 + +# Alembic版本 +alembic/versions/* +!alembic/versions/.gitkeep + +# Docker相关 +.docker/ +docker-compose.override.yml \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index e6439ab..4f0d923 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,78 +1,67 @@ -# 使用Python 3.10作为基础镜像 -FROM python:3.10-slim - -# 设置环境变量 -ENV PYTHONUNBUFFERED=1 -ENV PYTHONDONTWRITEBYTECODE=1 +# 使用多阶段构建,先安装构建依赖 +FROM python:3.9-slim AS builder # 设置时区 -RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime -# 清空所有默认源 -RUN rm -rf /etc/apt/sources.list.d/* && \ - rm -f /etc/apt/sources.list +ENV TZ=Asia/Shanghai +RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone -# 替换为阿里云源 -RUN echo "\ - deb https://mirrors.aliyun.com/debian/ bookworm main non-free-firmware contrib\n\ - deb-src https://mirrors.aliyun.com/debian/ bookworm main non-free-firmware contrib\n\ - deb https://mirrors.aliyun.com/debian/ bookworm-updates main non-free-firmware contrib\n\ - deb-src https://mirrors.aliyun.com/debian/ bookworm-updates main non-free-firmware contrib\n\ - deb https://mirrors.aliyun.com/debian/ bookworm-backports main non-free-firmware contrib\n\ - deb-src https://mirrors.aliyun.com/debian/ bookworm-backports main non-free-firmware contrib\n\ - deb https://mirrors.aliyun.com/debian-security bookworm-security main non-free-firmware contrib\n\ - deb-src https://mirrors.aliyun.com/debian-security bookworm-security main non-free-firmware contrib\n\ - " > /etc/apt/sources.list +# 设置工作目录 +WORKDIR /build -# 安装系统依赖 -RUN apt-get update \ - && apt-get install -y --no-install-recommends \ - build-essential \ - default-libmysqlclient-dev \ - pkg-config \ - curl \ - nano \ +# 安装netcat用于网络连接检查 +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + build-essential \ + default-libmysqlclient-dev \ + netcat-openbsd \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +# 复制项目依赖文件 +COPY requirements.txt . + +# 安装Python依赖 +RUN pip install --no-cache-dir -r requirements.txt + +# 最终阶段,使用轻量镜像 +FROM python:3.9-slim + +# 设置时区 +ENV TZ=Asia/Shanghai +RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone + +# 安装运行时依赖 +RUN apt-get update && apt-get install -y --no-install-recommends \ + curl \ + default-libmysqlclient-dev \ + netcat-openbsd \ + && apt-get clean \ && rm -rf /var/lib/apt/lists/* # 设置工作目录 WORKDIR /app -# 安装Python依赖(先于文件复制,利用缓存) -COPY requirements.txt . -RUN pip install -i https://mirrors.aliyun.com/pypi/simple/ -r requirements.txt \ - && pip install -i https://mirrors.aliyun.com/pypi/simple/ uvicorn python-multipart python-dotenv +# 从构建阶段复制已安装的依赖 +COPY --from=builder /usr/local/lib/python3.9/site-packages /usr/local/lib/python3.9/site-packages +COPY --from=builder /usr/local/bin /usr/local/bin -# 复制项目文件 -COPY app app/ -COPY *.py ./ -COPY entrypoint.sh . +# 复制应用代码 +COPY . . -# 创建一个默认的.env.example文件 -RUN echo "# 数据库配置\n\ -DB_HOST=db\n\ -DB_PORT=3306\n\ -DB_USER=ai_user\n\ -DB_PASSWORD=yourpassword\n\ -DB_NAME=ai_dressing\n\ -\n\ -# 阿里云DashScope配置\n\ -DASHSCOPE_API_KEY=your_dashscope_api_key\n\ -\n\ -# 腾讯云配置\n\ -QCLOUD_SECRET_ID=your_qcloud_secret_id\n\ -QCLOUD_SECRET_KEY=your_qcloud_secret_key\n\ -QCLOUD_COS_REGION=ap-chengdu\n\ -QCLOUD_COS_BUCKET=your-bucket-name\n\ -QCLOUD_COS_DOMAIN=https://your-bucket-domain.com\n\ -" > /app/.env.example +# 确保脚本可执行 +RUN chmod +x entrypoint.sh -# 确保entrypoint.sh可执行 -RUN chmod +x /app/entrypoint.sh +# 设置环境变量 +ENV PYTHONPATH=/app:$PYTHONPATH +ENV PORT=8000 +ENV HOST=0.0.0.0 # 暴露端口 EXPOSE 8000 -# 设置入口点 -ENTRYPOINT ["/app/entrypoint.sh"] +# 添加健康检查 +HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:8000/health || exit 1 -# 启动命令 -CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--log-level", "info"] \ No newline at end of file +# 设置入口点 +ENTRYPOINT ["./entrypoint.sh"] \ No newline at end of file diff --git a/README.md b/README.md index 3f74b00..9ff3db9 100644 --- a/README.md +++ b/README.md @@ -156,6 +156,85 @@ python run.py docker-compose down -v # 谨慎使用,这将删除数据库中的所有数据 ``` +## Docker构建和环境变量配置指南 + +## 使用.env文件构建Docker镜像 + +本项目提供了多种方式将`.env`文件中的环境变量传递到Docker容器中。 + +### 方法一:使用build.sh脚本(推荐) + +这种方法会自动读取`.env`文件中的所有环境变量,并传递给Docker构建过程: + +```bash +# 赋予执行权限 +chmod +x build.sh + +# 执行构建脚本 +./build.sh +``` + +构建完成后,可以直接运行容器: + +```bash +docker run -p 9001:8000 ai-dressing:latest +``` + +### 方法二:使用Docker Compose + +使用专门的docker-compose文件进行构建,该文件会读取.env中的环境变量: + +```bash +# 先加载.env文件到当前shell +export $(grep -v '^#' .env | xargs) + +# 使用专用的构建配置文件 +docker-compose -f docker-compose.build.yml build + +# 启动服务 +docker-compose -f docker-compose.build.yml up -d +``` + +### 方法三:手动构建 + +如果你希望手动指定部分环境变量,可以使用以下命令: + +```bash +docker build \ + --build-arg DASHSCOPE_API_KEY=your_key \ + --build-arg DB_HOST=your_db_host \ + --build-arg DB_USER=your_db_user \ + --build-arg DB_PASSWORD=your_db_password \ + -t ai-dressing:latest . +``` + +## 验证环境变量 + +构建完成后,你可以通过以下命令验证环境变量是否正确加载: + +```bash +# 查看容器日志 +docker logs ai-dressing-app + +# 或者进入容器查看环境变量 +docker exec -it ai-dressing-app bash -c "env | sort" +``` + +## 环境变量优先级 + +环境变量的加载优先级从高到低如下: + +1. docker run或docker-compose启动时通过`-e`或`environment`指定的环境变量 +2. 通过卷挂载到容器中的`.env`文件 +3. 构建时通过`--build-arg`传入并保存到容器内`.env.built`文件的变量 +4. 代码中的默认值 + +## 注意事项 + +- 生产环境中应避免将敏感信息硬编码在Dockerfile中 +- 敏感信息应通过环境变量或Docker secrets进行管理 +- 建议将`.env`文件添加到`.gitignore`中,避免意外提交 + ## API 文档 启动服务后,访问以下地址查看自动生成的 API 文档: diff --git a/app/exceptions/__init__.py b/app/exceptions/__init__.py new file mode 100644 index 0000000..7bdcaad --- /dev/null +++ b/app/exceptions/__init__.py @@ -0,0 +1,3 @@ +from app.exceptions.http_exception import CustomHTTPException, setup_exception_handlers + +__all__ = ["CustomHTTPException", "setup_exception_handlers"] \ No newline at end of file diff --git a/app/exceptions/http_exception.py b/app/exceptions/http_exception.py new file mode 100644 index 0000000..cbf6309 --- /dev/null +++ b/app/exceptions/http_exception.py @@ -0,0 +1,90 @@ +from fastapi import FastAPI, Request, status +from fastapi.responses import JSONResponse +from fastapi.exceptions import RequestValidationError +from starlette.exceptions import HTTPException as StarletteHTTPException +import logging +from typing import Any, Dict, Optional, Union + +from app.utils.response import APIResponse + +logger = logging.getLogger(__name__) + +class CustomHTTPException(Exception): + """自定义HTTP异常""" + + def __init__( + self, + status_code: int = 400, + code: int = None, + message: str = "操作失败", + data: Any = None + ): + self.status_code = status_code + self.code = code or status_code + self.message = message + self.data = data + super().__init__(self.message) + +def setup_exception_handlers(app: FastAPI) -> None: + """设置FastAPI应用的异常处理程序""" + + @app.exception_handler(CustomHTTPException) + async def custom_http_exception_handler(request: Request, exc: CustomHTTPException): + """处理自定义HTTP异常""" + logger.error(f"自定义HTTP异常: {exc.message} (状态码: {exc.status_code}, 业务码: {exc.code})") + return JSONResponse( + status_code=exc.status_code, + content=APIResponse.error( + message=exc.message, + code=exc.code, + data=exc.data + ) + ) + + @app.exception_handler(StarletteHTTPException) + async def http_exception_handler(request: Request, exc: StarletteHTTPException): + """处理标准HTTP异常""" + logger.error(f"HTTP异常: {exc.detail} (状态码: {exc.status_code})") + return JSONResponse( + status_code=exc.status_code, + content=APIResponse.error( + message=str(exc.detail), + code=exc.status_code + ) + ) + + @app.exception_handler(RequestValidationError) + async def validation_exception_handler(request: Request, exc: RequestValidationError): + """处理请求验证异常""" + error_details = exc.errors() + error_messages = [] + + for error in error_details: + location = " -> ".join(str(loc) for loc in error["loc"]) + message = f"{location}: {error['msg']}" + error_messages.append(message) + + error_message = "; ".join(error_messages) + logger.error(f"验证错误: {error_message}") + + return JSONResponse( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + content=APIResponse.error( + message="请求参数验证失败", + code=status.HTTP_422_UNPROCESSABLE_ENTITY, + data={"details": error_details} + ) + ) + + @app.exception_handler(Exception) + async def general_exception_handler(request: Request, exc: Exception): + """处理通用异常""" + logger.exception(f"未处理的异常: {str(exc)}") + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content=APIResponse.error( + message="服务器内部错误", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + data={"detail": str(exc)} if app.debug else None + ) + ) \ No newline at end of file diff --git a/app/main.py b/app/main.py index 782b0e4..83492ae 100644 --- a/app/main.py +++ b/app/main.py @@ -6,7 +6,10 @@ from dotenv import load_dotenv from app.routers import qcloud_router, dress_router, tryon_router from app.utils.config import get_settings +from app.utils.response import APIResponse from app.database import Base, engine +from app.middleware import ResponseWrapperMiddleware +from app.exceptions import setup_exception_handlers # 创建数据库表 Base.metadata.create_all(bind=engine) @@ -36,6 +39,12 @@ app.add_middleware( allow_headers=["*"], ) +# 添加响应包装中间件 +app.add_middleware(ResponseWrapperMiddleware) + +# 设置异常处理 +setup_exception_handlers(app) + # 注册路由 app.include_router(qcloud_router.router, prefix="/api/qcloud", tags=["腾讯云"]) app.include_router(dress_router.router, prefix="/api/dresses", tags=["服装"]) @@ -44,22 +53,53 @@ app.include_router(tryon_router.router, prefix="/api/tryons", tags=["试穿"]) @app.get("/", tags=["健康检查"]) async def root(): """API 根端点""" - return {"status": "正常", "message": "服务运行中"} + return APIResponse.ok(message="服务运行中") @app.get("/health", tags=["健康检查"]) async def health_check(): """健康检查端点,用于Docker容器健康监控""" - return {"status": "healthy", "message": "服务运行正常"} + health_data = {"checks": {}} + health_status = "healthy" + health_message = "服务运行正常" + + # 检查数据库连接 + try: + from app.database import SessionLocal + db = SessionLocal() + db.execute("SELECT 1") + db.close() + health_data["checks"]["database"] = {"status": "healthy", "message": "数据库连接正常"} + except Exception as e: + health_status = "unhealthy" + health_message = "服务异常" + health_data["checks"]["database"] = {"status": "unhealthy", "message": f"数据库连接失败: {str(e)}"} + + # 检查API密钥 + if os.getenv("DASHSCOPE_API_KEY"): + health_data["checks"]["dashscope"] = {"status": "configured", "message": "DashScope API密钥已配置"} + else: + health_data["checks"]["dashscope"] = {"status": "missing", "message": "DashScope API密钥未配置"} + + if os.getenv("QCLOUD_SECRET_ID") and os.getenv("QCLOUD_SECRET_KEY"): + health_data["checks"]["qcloud"] = {"status": "configured", "message": "腾讯云凭证已配置"} + else: + health_data["checks"]["qcloud"] = {"status": "missing", "message": "腾讯云凭证未配置"} + + if health_status == "healthy": + return APIResponse.ok(data=health_data, message=health_message) + else: + return APIResponse.error(message=health_message, code=503, data=health_data) @app.get("/info", tags=["服务信息"]) async def get_info(): """获取服务基本信息""" settings = get_settings() - return { + info_data = { "app_name": "AI-Dressing API", "version": "0.1.0", "debug_mode": settings.debug, } + return APIResponse.ok(data=info_data, message="服务信息获取成功") if __name__ == "__main__": import uvicorn diff --git a/app/middleware/__init__.py b/app/middleware/__init__.py new file mode 100644 index 0000000..8d2216f --- /dev/null +++ b/app/middleware/__init__.py @@ -0,0 +1,3 @@ +from app.middleware.response_wrapper import ResponseWrapperMiddleware + +__all__ = ["ResponseWrapperMiddleware"] \ No newline at end of file diff --git a/app/middleware/response_wrapper.py b/app/middleware/response_wrapper.py new file mode 100644 index 0000000..7ac716b --- /dev/null +++ b/app/middleware/response_wrapper.py @@ -0,0 +1,86 @@ +from fastapi import Request, Response +from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware +from typing import Any, Dict, Optional, Union +import json + +class ResponseWrapperMiddleware(BaseHTTPMiddleware): + """ + 响应包装中间件,自动将所有API响应包装为标准格式 + 格式为: + { + "success": true/false, + "code": 200, + "message": "操作成功", + "data": 原始响应数据 + } + """ + + async def dispatch( + self, request: Request, call_next + ) -> Response: + # 排除不需要包装的路径 + if self._should_skip_path(request.url.path): + return await call_next(request) + + # 调用下一个中间件或路由处理函数 + response = await call_next(request) + + # 如果响应是JSON,且未使用标准包装格式,则进行包装 + if ( + isinstance(response, JSONResponse) and + self._should_wrap_response(response) + ): + return self._wrap_response(response) + + return response + + def _should_skip_path(self, path: str) -> bool: + """判断是否跳过包装处理""" + # 跳过文档相关路径 + skip_paths = ["/docs", "/redoc", "/openapi.json"] + + for skip_path in skip_paths: + if path.startswith(skip_path): + return True + + return False + + def _should_wrap_response(self, response: JSONResponse) -> bool: + """判断是否需要包装响应""" + try: + content = response.body.decode() + data = json.loads(content) + + # 已经是标准格式则不需要再包装 + if isinstance(data, dict) and "success" in data and "code" in data and "message" in data: + return False + + return True + except Exception: + return False + + def _wrap_response(self, response: JSONResponse) -> JSONResponse: + """包装响应为标准格式""" + try: + # 解析原始响应内容 + content = response.body.decode() + data = json.loads(content) + + # 构造标准格式响应 + wrapped_data = { + "success": response.status_code < 400, + "code": response.status_code, + "message": "操作成功" if response.status_code < 400 else "操作失败", + "data": data + } + + # 创建新的响应 + return JSONResponse( + content=wrapped_data, + status_code=response.status_code, + headers=dict(response.headers), + ) + except Exception: + # 出错时返回原始响应 + return response \ No newline at end of file diff --git a/app/routers/dress_router.py b/app/routers/dress_router.py index 4eea999..b0fe21d 100644 --- a/app/routers/dress_router.py +++ b/app/routers/dress_router.py @@ -1,162 +1,90 @@ -from fastapi import APIRouter, HTTPException, Depends, Query, Path +from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session -from typing import List, Optional -import logging +from typing import List from app.database import get_db -from app.models.dress import Dress, GarmentType -from app.schemas.dress import DressCreate, DressUpdate, DressResponse +from app.models.dress import Dress +from app.schemas.dress import DressCreate, DressUpdate, DressResponse, DressListResponse +from app.utils.response import APIResponse -logger = logging.getLogger(__name__) router = APIRouter() -@router.post("/", response_model=DressResponse, status_code=201) -async def create_dress( - dress: DressCreate, - db: Session = Depends(get_db) -): - """ - 创建一个新的服装记录 - - - **name**: 服装名称(必填) - - **image_url**: 服装图片URL(可选) - - **garment_type**: 服装类型(TOP_GARMENT:上衣, BOTTOM_GARMENT:下衣)(可选) - - **description**: 服装描述(可选) - """ +@router.post("/", response_model=DressResponse) +def create_dress(dress: DressCreate, db: Session = Depends(get_db)): + """创建服装记录""" try: - db_dress = Dress( - name=dress.name, - image_url=dress.image_url, - garment_type=dress.garment_type, - description=dress.description - ) + db_dress = Dress(**dress.dict()) db.add(db_dress) db.commit() db.refresh(db_dress) - return db_dress + return APIResponse.created(data=db_dress, message="服装创建成功") except Exception as e: - logger.error(f"创建服装记录失败: {str(e)}") db.rollback() - raise HTTPException(status_code=500, detail=f"创建服装记录失败: {str(e)}") + return APIResponse.error(message=f"服装创建失败: {str(e)}", code=500) -@router.get("/", response_model=List[DressResponse]) -async def get_dresses( - skip: int = Query(0, description="跳过的记录数量"), - limit: int = Query(100, description="返回的最大记录数量"), - name: Optional[str] = Query(None, description="按名称过滤"), - garment_type: Optional[str] = Query(None, description="按服装类型过滤(TOP_GARMENT:上衣, BOTTOM_GARMENT:下衣)"), +@router.get("/", response_model=DressListResponse) +def get_all_dresses( + skip: int = 0, + limit: int = 100, db: Session = Depends(get_db) ): - """ - 获取服装列表,支持分页和过滤 - - - **skip**: 跳过的记录数量,用于分页 - - **limit**: 返回的最大记录数量,用于分页 - - **name**: 按名称过滤(可选) - - **garment_type**: 按服装类型过滤(可选) - """ + """获取所有服装记录""" try: - query = db.query(Dress) - - if name: - query = query.filter(Dress.name.ilike(f"%{name}%")) - - if garment_type: - try: - garment_type_enum = GarmentType[garment_type] - query = query.filter(Dress.garment_type == garment_type_enum) - except KeyError: - logger.warning(f"无效的服装类型: {garment_type}") - # 继续查询,但不应用无效的过滤条件 - - dresses = query.offset(skip).limit(limit).all() - return dresses + dresses = db.query(Dress).offset(skip).limit(limit).all() + total = db.query(Dress).count() + return APIResponse.ok( + data={ + "items": dresses, + "total": total, + "page": skip // limit + 1 if limit > 0 else 1, + "size": limit + }, + message="服装列表获取成功" + ) except Exception as e: - logger.error(f"获取服装列表失败: {str(e)}") - raise HTTPException(status_code=500, detail=f"获取服装列表失败: {str(e)}") + return APIResponse.error(message=f"获取服装列表失败: {str(e)}", code=500) @router.get("/{dress_id}", response_model=DressResponse) -async def get_dress( - dress_id: int = Path(..., description="服装ID"), - db: Session = Depends(get_db) -): - """ - 根据ID获取服装详情 - - - **dress_id**: 服装ID - """ +def get_dress(dress_id: int, db: Session = Depends(get_db)): + """获取单个服装记录""" try: dress = db.query(Dress).filter(Dress.id == dress_id).first() - if not dress: - raise HTTPException(status_code=404, detail=f"未找到ID为{dress_id}的服装") - return dress - except HTTPException: - raise + if dress is None: + return APIResponse.not_found(message=f"未找到ID为{dress_id}的服装") + return APIResponse.ok(data=dress, message="服装获取成功") except Exception as e: - logger.error(f"获取服装详情失败: {str(e)}") - raise HTTPException(status_code=500, detail=f"获取服装详情失败: {str(e)}") + return APIResponse.error(message=f"获取服装失败: {str(e)}", code=500) @router.put("/{dress_id}", response_model=DressResponse) -async def update_dress( - dress_id: int = Path(..., description="服装ID"), - dress: DressUpdate = None, - db: Session = Depends(get_db) -): - """ - 更新服装信息 - - - **dress_id**: 服装ID - - **name**: 服装名称(可选) - - **image_url**: 服装图片URL(可选) - - **garment_type**: 服装类型(TOP_GARMENT:上衣, BOTTOM_GARMENT:下衣)(可选) - - **description**: 服装描述(可选) - """ +def update_dress(dress_id: int, dress: DressUpdate, db: Session = Depends(get_db)): + """更新服装记录""" try: db_dress = db.query(Dress).filter(Dress.id == dress_id).first() - if not db_dress: - raise HTTPException(status_code=404, detail=f"未找到ID为{dress_id}的服装") - - # 更新提供的字段 - if dress.name is not None: - db_dress.name = dress.name - if dress.image_url is not None: - db_dress.image_url = dress.image_url - if dress.garment_type is not None: - db_dress.garment_type = dress.garment_type - if dress.description is not None: - db_dress.description = dress.description + if db_dress is None: + return APIResponse.not_found(message=f"未找到ID为{dress_id}的服装") + + # 更新服装字段 + for field, value in dress.dict(exclude_unset=True).items(): + setattr(db_dress, field, value) db.commit() db.refresh(db_dress) - return db_dress - except HTTPException: - raise + return APIResponse.ok(data=db_dress, message="服装更新成功") except Exception as e: - logger.error(f"更新服装信息失败: {str(e)}") db.rollback() - raise HTTPException(status_code=500, detail=f"更新服装信息失败: {str(e)}") + return APIResponse.error(message=f"更新服装失败: {str(e)}", code=500) -@router.delete("/{dress_id}", status_code=204) -async def delete_dress( - dress_id: int = Path(..., description="服装ID"), - db: Session = Depends(get_db) -): - """ - 删除服装 - - - **dress_id**: 服装ID - """ +@router.delete("/{dress_id}") +def delete_dress(dress_id: int, db: Session = Depends(get_db)): + """删除服装记录""" try: db_dress = db.query(Dress).filter(Dress.id == dress_id).first() - if not db_dress: - raise HTTPException(status_code=404, detail=f"未找到ID为{dress_id}的服装") - + if db_dress is None: + return APIResponse.not_found(message=f"未找到ID为{dress_id}的服装") + db.delete(db_dress) db.commit() - return None - except HTTPException: - raise + return APIResponse.ok(message="服装删除成功") except Exception as e: - logger.error(f"删除服装失败: {str(e)}") db.rollback() - raise HTTPException(status_code=500, detail=f"删除服装失败: {str(e)}") \ No newline at end of file + return APIResponse.error(message=f"删除服装失败: {str(e)}", code=500) \ No newline at end of file diff --git a/app/routers/tryon_router.py b/app/routers/tryon_router.py index 0b43041..b05d5aa 100644 --- a/app/routers/tryon_router.py +++ b/app/routers/tryon_router.py @@ -13,7 +13,7 @@ from app.schemas.tryon import ( AiTryonRequest, AiTryonResponse, TaskInfo ) from app.services.dashscope_service import DashScopeService - +from app.utils.response import APIResponse # 加载环境变量 load_dotenv() @@ -63,7 +63,7 @@ async def create_tryon( person_image_url=tryon_data.person_image_url ) - return db_tryon + return APIResponse.ok(data=db_tryon, message="试穿记录创建成功") except Exception as e: logger.error(f"创建试穿记录失败: {str(e)}") db.rollback() @@ -106,7 +106,7 @@ async def send_tryon_request( db_tryon.task_status = "ERROR" db.commit() -@router.get("/", response_model=List[TryOnResponse]) +@router.get("", response_model=List[TryOnResponse]) async def get_tryons( skip: int = Query(0, description="跳过的记录数量"), limit: int = Query(100, description="返回的最大记录数量"), @@ -119,8 +119,8 @@ async def get_tryons( - **limit**: 返回的最大记录数量,用于分页 """ try: - tryons = db.query(TryOn).order_by(TryOn.created_at.desc()).offset(skip).limit(limit).all() - return tryons + tryons = db.query(TryOn).order_by(TryOn.id.desc()).offset(skip).limit(limit).all() + return APIResponse.ok(data=tryons, message="试穿记录列表获取成功") except Exception as e: logger.error(f"获取试穿记录列表失败: {str(e)}") raise HTTPException(status_code=500, detail=f"获取试穿记录列表失败: {str(e)}") @@ -139,7 +139,7 @@ async def get_tryon( tryon = db.query(TryOn).filter(TryOn.id == tryon_id).first() if not tryon: raise HTTPException(status_code=404, detail=f"未找到ID为{tryon_id}的试穿记录") - return tryon + return APIResponse.ok(data=tryon, message="试穿记录获取成功") except HTTPException: raise except Exception as e: @@ -172,7 +172,7 @@ async def update_tryon( db.commit() db.refresh(db_tryon) - return db_tryon + return APIResponse.ok(data=db_tryon, message="试穿记录更新成功") except HTTPException: raise except Exception as e: @@ -217,7 +217,7 @@ async def check_tryon_status( except Exception as e: logger.error(f"调用DashScope API检查任务状态失败: {str(e)}") - return db_tryon + return APIResponse.ok(data=db_tryon, message="试穿任务状态检查成功") except HTTPException: raise except Exception as e: diff --git a/app/schemas/common.py b/app/schemas/common.py new file mode 100644 index 0000000..c992035 --- /dev/null +++ b/app/schemas/common.py @@ -0,0 +1,56 @@ +from pydantic import BaseModel, Field +from typing import Any, Dict, Generic, List, Optional, TypeVar, Union + +T = TypeVar('T') + +class StandardResponse(BaseModel, Generic[T]): + """标准API响应模型""" + success: bool = Field(True, description="请求是否成功") + code: int = Field(200, description="业务状态码") + message: str = Field("操作成功", description="响应消息") + data: Optional[T] = Field(None, description="响应数据") + +class PageInfo(BaseModel): + """分页信息""" + page: int = Field(..., description="当前页码") + size: int = Field(..., description="每页大小") + total: int = Field(..., description="总记录数") + pages: int = Field(..., description="总页数") + +class PagedResponse(BaseModel, Generic[T]): + """分页响应数据""" + items: List[T] = Field(..., description="数据列表") + page_info: PageInfo = Field(..., description="分页信息") + +class ErrorDetail(BaseModel): + """错误详情""" + loc: List[str] = Field(..., description="错误位置") + msg: str = Field(..., description="错误消息") + type: str = Field(..., description="错误类型") + +class ValidationError(BaseModel): + """验证错误响应""" + detail: List[ErrorDetail] + +class HealthCheck(BaseModel): + """健康检查项目""" + status: str = Field(..., description="状态") + message: str = Field(..., description="消息") + +class HealthCheckResponse(BaseModel): + """健康检查响应""" + checks: Dict[str, HealthCheck] = Field(..., description="检查项目") + +class TokenResponse(BaseModel): + """令牌响应""" + access_token: str = Field(..., description="访问令牌") + token_type: str = Field("bearer", description="令牌类型") + expires_in: int = Field(..., description="过期时间(秒)") + +class FileUploadResponse(BaseModel): + """文件上传响应""" + file_id: str = Field(..., description="文件ID") + url: str = Field(..., description="文件URL") + file_name: str = Field(..., description="文件名") + content_type: str = Field(..., description="内容类型") + size: int = Field(..., description="文件大小(字节)") \ No newline at end of file diff --git a/app/schemas/dress.py b/app/schemas/dress.py index 53b02b7..08f2281 100644 --- a/app/schemas/dress.py +++ b/app/schemas/dress.py @@ -1,9 +1,13 @@ from pydantic import BaseModel, Field, HttpUrl -from typing import Optional +from typing import Optional, List from datetime import datetime from enum import Enum from app.models.dress import GarmentType +class GarmentType(str, Enum): + TOP_GARMENT = "TOP_GARMENT" + BOTTOM_GARMENT = "BOTTOM_GARMENT" + class DressBase(BaseModel): """服装基础模型""" name: str = Field(..., description="服装名称", example="夏季连衣裙") @@ -33,4 +37,20 @@ class DressInDB(DressBase): class DressResponse(DressInDB): """服装API响应模型""" - pass \ No newline at end of file + pass + +class DressListResponse(BaseModel): + items: List[DressResponse] + total: int + page: int + size: int + + class Config: + orm_mode = True + +class StandardResponse(BaseModel): + """标准API响应格式""" + success: bool = True + code: int = 200 + message: str = "操作成功" + data: Optional[dict] = None \ No newline at end of file diff --git a/app/utils/client.py b/app/utils/client.py new file mode 100644 index 0000000..f0bfd63 --- /dev/null +++ b/app/utils/client.py @@ -0,0 +1,192 @@ +import httpx +import json +import logging +from typing import Any, Dict, List, Optional, Union +from urllib.parse import urljoin + +logger = logging.getLogger(__name__) + +class HttpClient: + """ + HTTP客户端工具类 + 封装了请求和响应处理逻辑 + """ + + def __init__( + self, + base_url: str = "", + timeout: int = 30, + headers: Optional[Dict[str, str]] = None, + verify_ssl: bool = True + ): + """ + 初始化HTTP客户端 + + Args: + base_url: API基础URL + timeout: 请求超时时间(秒) + headers: 默认请求头 + verify_ssl: 是否验证SSL证书 + """ + self.base_url = base_url + self.timeout = timeout + self.headers = headers or {} + self.verify_ssl = verify_ssl + + async def request( + self, + method: str, + url: str, + params: Optional[Dict[str, Any]] = None, + data: Optional[Any] = None, + json_data: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + timeout: Optional[int] = None, + ) -> Dict[str, Any]: + """ + 发送HTTP请求 + + Args: + method: 请求方法 (GET, POST, PUT, DELETE等) + url: 请求URL路径 + params: URL查询参数 + data: 表单数据或二进制数据 + json_data: JSON数据 + headers: 请求头 + timeout: 超时时间(秒) + + Returns: + 响应数据 + + Raises: + Exception: 请求失败时抛出异常 + """ + if self.base_url: + full_url = urljoin(self.base_url, url) + else: + full_url = url + + request_headers = {**self.headers} + if headers: + request_headers.update(headers) + + timeout_value = timeout or self.timeout + + try: + async with httpx.AsyncClient(verify=self.verify_ssl) as client: + response = await client.request( + method=method, + url=full_url, + params=params, + data=data, + json=json_data, + headers=request_headers, + timeout=timeout_value, + ) + + # 记录请求和响应 + logger.debug(f"HTTP请求: {method} {full_url}") + logger.debug(f"状态码: {response.status_code}") + + # 尝试解析JSON响应 + response_data = None + try: + response_data = response.json() + except json.JSONDecodeError: + response_data = {"content": response.text} + + # 检查状态码 + response.raise_for_status() + + # 验证是否成功 + self._verify_success(response_data, response.status_code) + + return response_data + + except httpx.HTTPStatusError as e: + logger.error(f"HTTP状态错误: {e.response.status_code} - {e.response.text}") + try: + error_data = e.response.json() + msg = error_data.get("message", str(e)) + except json.JSONDecodeError: + msg = e.response.text or str(e) + raise Exception(f"请求失败 ({e.response.status_code}): {msg}") + + except httpx.RequestError as e: + logger.error(f"请求错误: {str(e)}") + raise Exception(f"请求错误: {str(e)}") + + except Exception as e: + logger.error(f"未知错误: {str(e)}") + raise + + def _verify_success(self, data: Dict[str, Any], status_code: int) -> None: + """ + 验证响应是否成功 + + Args: + data: 响应数据 + status_code: HTTP状态码 + + Raises: + Exception: 验证失败时抛出异常 + """ + # 检查HTTP状态码 + if status_code >= 400: + raise Exception(f"请求失败,状态码: {status_code}") + + # 检查响应体中的success字段(如果存在) + if isinstance(data, dict) and "success" in data and data["success"] is False: + msg = data.get("message", "未知错误") + code = data.get("code", status_code) + raise Exception(f"请求失败 (code: {code}): {msg}") + + async def get( + self, + url: str, + params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + timeout: Optional[int] = None, + ) -> Dict[str, Any]: + """发送GET请求""" + return await self.request("GET", url, params=params, headers=headers, timeout=timeout) + + async def post( + self, + url: str, + data: Optional[Any] = None, + json_data: Optional[Dict[str, Any]] = None, + params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + timeout: Optional[int] = None, + ) -> Dict[str, Any]: + """发送POST请求""" + return await self.request( + "POST", url, params=params, data=data, json_data=json_data, + headers=headers, timeout=timeout + ) + + async def put( + self, + url: str, + data: Optional[Any] = None, + json_data: Optional[Dict[str, Any]] = None, + params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + timeout: Optional[int] = None, + ) -> Dict[str, Any]: + """发送PUT请求""" + return await self.request( + "PUT", url, params=params, data=data, json_data=json_data, + headers=headers, timeout=timeout + ) + + async def delete( + self, + url: str, + params: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + timeout: Optional[int] = None, + ) -> Dict[str, Any]: + """发送DELETE请求""" + return await self.request("DELETE", url, params=params, headers=headers, timeout=timeout) \ No newline at end of file diff --git a/app/utils/response.py b/app/utils/response.py new file mode 100644 index 0000000..42514ae --- /dev/null +++ b/app/utils/response.py @@ -0,0 +1,136 @@ +from typing import Any, Dict, List, Optional, Union +from fastapi import status +from fastapi.responses import JSONResponse + +class APIResponse: + """API标准响应格式工具类""" + + @staticmethod + def success( + data: Any = None, + message: str = "操作成功", + code: int = 200 + ) -> Dict[str, Any]: + """ + 成功响应 + + Args: + data: 响应数据 + message: 响应消息 + code: 状态码 + + Returns: + 标准响应格式的字典 + """ + return { + "success": True, + "code": code, + "message": message, + "data": data + } + + @staticmethod + def error( + message: str = "操作失败", + code: int = 400, + data: Any = None + ) -> Dict[str, Any]: + """ + 错误响应 + + Args: + message: 错误消息 + code: 错误状态码 + data: 附加错误数据 + + Returns: + 标准响应格式的字典 + """ + return { + "success": False, + "code": code, + "message": message, + "data": data + } + + @staticmethod + def json_response( + data: Any = None, + message: str = "操作成功", + code: int = 200, + success: bool = True, + status_code: int = status.HTTP_200_OK, + headers: Dict[str, str] = None + ) -> JSONResponse: + """ + 返回JSONResponse对象 + + Args: + data: 响应数据 + message: 响应消息 + code: 业务状态码 + success: 是否成功 + status_code: HTTP状态码 + headers: 自定义响应头 + + Returns: + JSONResponse对象 + """ + content = { + "success": success, + "code": code, + "message": message, + "data": data + } + + return JSONResponse( + content=content, + status_code=status_code, + headers=headers + ) + + # 常用响应码封装 + @classmethod + def ok(cls, data: Any = None, message: str = "操作成功") -> Dict[str, Any]: + """200 成功""" + return cls.success(data, message, 200) + + @classmethod + def created(cls, data: Any = None, message: str = "创建成功") -> Dict[str, Any]: + """201 创建成功""" + return cls.success(data, message, 201) + + @classmethod + def accepted(cls, data: Any = None, message: str = "请求已接受") -> Dict[str, Any]: + """202 已接受""" + return cls.success(data, message, 202) + + @classmethod + def no_content(cls) -> Dict[str, Any]: + """204 无内容""" + return cls.success(None, "无内容", 204) + + @classmethod + def bad_request(cls, message: str = "请求参数错误") -> Dict[str, Any]: + """400 请求错误""" + return cls.error(message, 400) + + @classmethod + def unauthorized(cls, message: str = "未授权") -> Dict[str, Any]: + """401 未授权""" + return cls.error(message, 401) + + @classmethod + def forbidden(cls, message: str = "禁止访问") -> Dict[str, Any]: + """403 禁止""" + return cls.error(message, 403) + + @classmethod + def not_found(cls, message: str = "资源不存在") -> Dict[str, Any]: + """404 不存在""" + return cls.error(message, 404) + + @classmethod + def server_error(cls, message: str = "服务器内部错误") -> Dict[str, Any]: + """500 服务器错误""" + return cls.error(message, 500) \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index d9d1dfe..1d7ab9d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -4,12 +4,7 @@ services: app: build: context: . - args: - - DB_HOST=db - - DB_PORT=3306 - - DB_USER=ai_user - - DB_PASSWORD=yourpassword - - DB_NAME=ai_dressing + dockerfile: Dockerfile container_name: ai-dressing-app restart: always ports: @@ -18,6 +13,7 @@ services: - ./.env # 环境变量可以覆盖.env文件中的值 environment: + - ENV=development - DB_HOST=db - DB_PORT=3306 - DB_USER=ai_user @@ -31,43 +27,55 @@ services: - QCLOUD_COS_BUCKET=${QCLOUD_COS_BUCKET:-your-bucket-name} - QCLOUD_COS_DOMAIN=${QCLOUD_COS_DOMAIN:-https://your-bucket-domain.com} - PYTHONPATH=/app + volumes: + - .:/app + - ./data/uploads:/app/uploads depends_on: - db - volumes: - - ./.env:/app/.env:ro # 明确映射.env文件 - # - ./app:/app/app # 开发时才使用,生产环境建议去掉 - networks: - - ai-dressing-network healthcheck: test: ["CMD", "curl", "-f", "http://localhost:8000/health"] interval: 30s timeout: 10s retries: 3 start_period: 10s + networks: + - ai-dressing-network db: image: mysql:8.0 container_name: ai-dressing-db restart: always - ports: - - "3306:3306" environment: - MYSQL_ROOT_PASSWORD=rootpassword + - MYSQL_DATABASE=ai_dressing - MYSQL_USER=ai_user - MYSQL_PASSWORD=yourpassword - - MYSQL_DATABASE=ai_dressing + ports: + - "3306:3306" volumes: - mysql-data:/var/lib/mysql - - ./mysql-init:/docker-entrypoint-initdb.d - command: --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci - networks: - - ai-dressing-network + - ./init-scripts:/docker-entrypoint-initdb.d + command: --default-authentication-plugin=mysql_native_password --character-set-server=utf8mb4 --collation-server=utf8mb4_unicode_ci healthcheck: test: ["CMD", "mysqladmin", "ping", "-h", "localhost", "-u", "root", "-p$$MYSQL_ROOT_PASSWORD"] - interval: 30s - timeout: 10s - retries: 3 - start_period: 30s + interval: 5s + timeout: 5s + retries: 10 + networks: + - ai-dressing-network + + adminer: + image: adminer + container_name: ai-dressing-adminer + restart: always + ports: + - "8080:8080" + environment: + - ADMINER_DEFAULT_SERVER=db + depends_on: + - db + networks: + - ai-dressing-network volumes: mysql-data: diff --git a/entrypoint.sh b/entrypoint.sh index 567dcff..92a9900 100755 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -1,52 +1,93 @@ #!/bin/bash set -e -echo "============= 容器启动 =============" -echo "当前工作目录: $(pwd)" -echo "目录内容:" -ls -la +echo "正在启动AI Dressing服务..." -# 优先级1: 尝试从.env文件加载 -if [ -f .env ]; then - echo "找到.env文件,加载环境变量..." - # 打印文件内容(隐藏敏感信息) - echo "文件内容预览(敏感信息已隐藏):" - grep -v "KEY\|PASSWORD\|SECRET" .env | cat -n - - # 从.env文件导出所有环境变量 - export $(grep -v '^#' .env | xargs) - echo "已从.env加载环境变量" -elif [ -f /app/.env ]; then - echo "在/app目录下找到.env文件,加载环境变量..." - export $(grep -v '^#' /app/.env | xargs) - echo "已从/app/.env加载环境变量" -elif [ -f .env.docker ]; then - echo "找到.env.docker文件,加载环境变量..." - export $(grep -v '^#' .env.docker | xargs) - echo "已从.env.docker加载环境变量" +# 检查是否存在.env文件,不存在则使用内置的.env.built +if [ ! -f ".env" ]; then + echo "未找到.env文件,使用构建时创建的.env.built文件" + if [ -f ".env.built" ]; then + cp .env.built .env + echo "已复制.env.built到.env" + else + echo "警告:未找到.env.built文件,将使用系统环境变量" + fi else - echo ".env文件不存在,环境变量将从Docker环境中读取..." + echo "使用已存在的.env文件" fi -# 打印环境变量,确认是否已正确加载 -echo "============= 环境变量检查 =============" -echo "DB_HOST: $DB_HOST" -echo "DB_PORT: $DB_PORT" -echo "DB_USER: $DB_USER" -echo "DB_NAME: $DB_NAME" -echo "DASHSCOPE_API_KEY是否存在: $(if [ -n "$DASHSCOPE_API_KEY" ]; then echo "是"; else echo "否"; fi)" -echo "QCLOUD_SECRET_ID是否存在: $(if [ -n "$QCLOUD_SECRET_ID" ]; then echo "是"; else echo "否"; fi)" +# 输出关键环境变量(隐藏敏感信息) +echo "检查关键环境变量:" +# 数据库配置 +if [ -n "$DB_HOST" ]; then + echo "- DB_HOST: 已设置 ✓" +else + echo "- DB_HOST: 未设置 ❌" +fi -# 确保Python能找到应用 -export PYTHONPATH=/app:$PYTHONPATH +if [ -n "$DB_PORT" ]; then + echo "- DB_PORT: 已设置 ✓" +else + echo "- DB_PORT: 未设置 ❌" +fi -echo "============= Python环境 =============" -echo "Python版本: $(python --version)" -echo "Python路径: $(which python)" -echo "PYTHONPATH: $PYTHONPATH" +if [ -n "$DB_NAME" ]; then + echo "- DB_NAME: 已设置 ✓" +else + echo "- DB_NAME: 未设置 ❌" +fi -echo "============= 启动应用 =============" -echo "执行命令: $@" +# API密钥 +if [ -n "$DASHSCOPE_API_KEY" ]; then + echo "- DASHSCOPE_API_KEY: 已设置(已隐藏) ✓" +else + echo "- DASHSCOPE_API_KEY: 未设置 ❌" +fi -# 执行原始的命令 -exec "$@" \ No newline at end of file +if [ -n "$QCLOUD_SECRET_ID" ] && [ -n "$QCLOUD_SECRET_KEY" ]; then + echo "- 腾讯云凭证: 已设置(已隐藏) ✓" +else + echo "- 腾讯云凭证: 未设置 ❌" +fi + +# 检查依赖服务连接 +echo "检查数据库连接..." +MAX_RETRIES=10 +COUNT=0 + +if [ -n "$DB_HOST" ] && [ -n "$DB_PORT" ]; then + while [ $COUNT -lt $MAX_RETRIES ]; do + if nc -z -w3 $DB_HOST $DB_PORT; then + echo "数据库连接成功!" + break + fi + echo "等待数据库连接... ($((COUNT+1))/$MAX_RETRIES)" + COUNT=$((COUNT+1)) + sleep 2 + done + + if [ $COUNT -eq $MAX_RETRIES ]; then + echo "警告:无法连接到数据库。服务将继续启动,但可能无法正常工作。" + fi +else + echo "跳过数据库连接检查:未设置DB_HOST或DB_PORT" +fi + +# 应用数据库迁移 +echo "应用数据库迁移..." +if [ -f "create_migration.py" ]; then + python create_migration.py upgrade || echo "警告:数据库迁移失败,但将继续启动服务" +else + echo "未找到create_migration.py,跳过数据库迁移" +fi + +# 启动服务 +echo "AI Dressing服务启动中..." +# 根据环境变量决定是否使用热重载 +if [ "$ENV" = "development" ]; then + echo "以开发模式启动,启用热重载..." + exec uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload +else + echo "以生产模式启动..." + exec uvicorn app.main:app --host 0.0.0.0 --port 8000 --workers 4 +fi \ No newline at end of file diff --git a/init-scripts/01-init-db.sql b/init-scripts/01-init-db.sql new file mode 100644 index 0000000..88d2a64 --- /dev/null +++ b/init-scripts/01-init-db.sql @@ -0,0 +1,46 @@ +-- 设置字符集和排序规则 +SET NAMES utf8mb4; +SET GLOBAL time_zone = '+8:00'; + +-- 创建数据库(如果不存在) +CREATE DATABASE IF NOT EXISTS ai_dressing CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + +-- 确保使用正确的数据库 +USE ai_dressing; + +-- 创建用户(如果不存在) +-- 注意:在Docker环境中,这些通常由MYSQL_USER和MYSQL_PASSWORD环境变量处理 +-- 此处作为备用,确保权限正确设置 +CREATE USER IF NOT EXISTS 'ai_user'@'%' IDENTIFIED BY 'yourpassword'; +GRANT ALL PRIVILEGES ON ai_dressing.* TO 'ai_user'@'%'; +FLUSH PRIVILEGES; + +-- 设置必要的MySQL配置 +SET GLOBAL max_connections = 500; +SET GLOBAL connect_timeout = 60; +SET GLOBAL wait_timeout = 600; +SET GLOBAL interactive_timeout = 600; +SET GLOBAL max_allowed_packet = 16777216; -- 16MB + +-- 创建一个测试表,验证初始化是否成功 +CREATE TABLE IF NOT EXISTS `system_info` ( + `id` int(11) NOT NULL AUTO_INCREMENT, + `key` varchar(50) NOT NULL, + `value` text NOT NULL, + `description` varchar(255) DEFAULT NULL, + `created_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP, + `updated_at` timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + PRIMARY KEY (`id`), + UNIQUE KEY `key` (`key`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci; + +-- 插入初始数据 +INSERT INTO `system_info` (`key`, `value`, `description`) +VALUES +('system_name', 'AI Dressing System', '系统名称'), +('version', '1.0.0', '系统版本'), +('init_date', CURRENT_TIMESTAMP, '系统初始化日期') +ON DUPLICATE KEY UPDATE +`value` = VALUES(`value`), +`description` = VALUES(`description`), +`updated_at` = CURRENT_TIMESTAMP; \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index f12445a..1c2a3d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,32 @@ -fastapi==0.104.0 -uvicorn==0.23.2 -dashscope>=1.13.0 -python-dotenv==1.0.0 -pydantic==2.4.2 -httpx==0.25.0 -cos-python-sdk-v5==1.9.26 -qcloud-python-sts==3.1.4 -sqlalchemy==2.0.23 -pymysql==1.1.0 -cryptography==41.0.5 -alembic==1.12.1 -python-multipart==0.0.12 \ No newline at end of file +# REST API依赖 +fastapi>=0.68.0,<0.69.0 +uvicorn>=0.15.0,<0.16.0 +python-multipart>=0.0.5,<0.1.0 +email-validator>=1.1.3,<2.0.0 +pydantic>=1.8.0,<2.0.0 +httpx>=0.23.0,<0.24.0 + +# 数据库依赖 +sqlalchemy>=1.4.0,<1.5.0 +alembic>=1.7.0,<1.8.0 +pymysql>=1.0.2,<1.1.0 +mysqlclient>=2.1.0,<2.2.0 + +# 阿里云SDK +dashscope>=1.5.0,<1.6.0 + +# 腾讯云SDK +cos-python-sdk-v5>=1.9.0,<2.0.0 + +# 工具库 +python-dotenv>=0.19.1,<0.20.0 +python-jose[cryptography]>=3.3.0,<3.4.0 +passlib[bcrypt]>=1.7.4,<1.8.0 +tenacity>=8.0.1,<8.1.0 +loguru>=0.5.3,<0.6.0 +dynaconf>=3.1.7,<3.2.0 + +# 测试依赖(开发环境) +pytest>=6.2.5,<6.3.0 +pytest-asyncio>=0.18.0,<0.19.0 +httpx>=0.23.0,<0.24.0 \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..a796d93 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# 测试包 \ No newline at end of file diff --git a/tests/api/__init__.py b/tests/api/__init__.py new file mode 100644 index 0000000..34ebe3d --- /dev/null +++ b/tests/api/__init__.py @@ -0,0 +1 @@ +# API测试包 \ No newline at end of file diff --git a/tests/api/test_health.py b/tests/api/test_health.py new file mode 100644 index 0000000..6138abe --- /dev/null +++ b/tests/api/test_health.py @@ -0,0 +1,58 @@ +from fastapi.testclient import TestClient +import pytest +from app.main import app + +client = TestClient(app) + +def test_root_endpoint(): + """测试根端点""" + response = client.get("/") + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["code"] == 200 + assert "服务运行中" in data["message"] + +def test_health_endpoint(): + """测试健康检查端点""" + response = client.get("/health") + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert "checks" in data["data"] + assert "database" in data["data"]["checks"] + + # 验证响应格式 + assert isinstance(data, dict) + assert "success" in data + assert "code" in data + assert "message" in data + assert "data" in data + + # 验证success字段为布尔类型 + assert isinstance(data["success"], bool) + + # 验证code字段为整数 + assert isinstance(data["code"], int) + + # 验证message字段为字符串 + assert isinstance(data["message"], str) + +def test_info_endpoint(): + """测试服务信息端点""" + response = client.get("/info") + assert response.status_code == 200 + data = response.json() + + assert data["success"] is True + assert data["code"] == 200 + assert "服务信息获取成功" in data["message"] + + # 验证响应的数据部分 + assert "data" in data + assert "app_name" in data["data"] + assert "version" in data["data"] + assert "debug_mode" in data["data"] + +if __name__ == "__main__": + pytest.main(["-v", "test_health.py"]) \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..a4de799 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,59 @@ +import os +import sys +import pytest +from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + +from app.main import app +from app.database import Base, get_db + +# 使用SQLite内存数据库进行测试 +SQLALCHEMY_DATABASE_URL = "sqlite:///./test.db" + +engine = create_engine( + SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} +) +TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +@pytest.fixture(scope="session") +def db(): + """获取测试数据库会话""" + # 创建数据库表 + Base.metadata.create_all(bind=engine) + + # 创建会话 + db = TestingSessionLocal() + try: + yield db + finally: + db.close() + + # 清理:删除所有表 + Base.metadata.drop_all(bind=engine) + + # 删除测试数据库文件 + if os.path.exists("./test.db"): + os.remove("./test.db") + +@pytest.fixture +def client(db): + """创建FastAPI测试客户端""" + def override_get_db(): + try: + yield db + finally: + pass + + # 替换依赖项 + app.dependency_overrides[get_db] = override_get_db + + # 创建测试客户端 + with TestClient(app) as client: + yield client + + # 清理:恢复原始依赖项 + app.dependency_overrides = {} \ No newline at end of file