diff --git a/app/api/v1/api.py b/app/api/v1/api.py index dfc451d..0d2f31e 100644 --- a/app/api/v1/api.py +++ b/app/api/v1/api.py @@ -2,8 +2,12 @@ from fastapi import APIRouter from app.api.v1.endpoints import router as endpoints_router from app.api.v1.users import router as users_router from app.api.v1.auth import router as auth_router +from app.api.v1.user_images import router as user_images_router +from app.api.v1.clothing import router as clothing_router api_router = APIRouter() api_router.include_router(endpoints_router, prefix="") -api_router.include_router(auth_router, prefix="/auth") api_router.include_router(users_router, prefix="/users") +api_router.include_router(auth_router, prefix="/auth") +api_router.include_router(user_images_router, prefix="/user-images") +api_router.include_router(clothing_router, prefix="/clothing") diff --git a/app/api/v1/auth.py b/app/api/v1/auth.py index 4574b57..491a5dc 100644 --- a/app/api/v1/auth.py +++ b/app/api/v1/auth.py @@ -1,5 +1,6 @@ -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, status, Response from sqlalchemy.ext.asyncio import AsyncSession +import logging from app.schemas.auth import WechatLogin, LoginResponse from app.schemas.user import UserCreate @@ -9,6 +10,9 @@ from app.core.security import create_access_token from app.db.database import get_db from app.core.exceptions import BusinessError +# 创建日志记录器 +logger = logging.getLogger(__name__) + router = APIRouter() @router.post("/login/wechat", response_model=LoginResponse, tags=["auth"]) @@ -23,29 +27,48 @@ async def wechat_login( - 如果用户不存在,则创建新用户 - 生成JWT令牌 """ - # 调用微信API获取openid和unionid - openid, unionid = await wechat_service.code2session(login_data.code) - - # 检查用户是否存在 - existing_user = await user_service.get_user_by_openid(db, openid=openid) - is_new_user = existing_user is None - - if is_new_user: - # 创建新用户 - user_create = UserCreate( - openid=openid, - unionid=unionid - ) - user = await user_service.create_user(db, user=user_create) - else: - user = existing_user + try: + logger.info(f"尝试微信登录: {login_data.code[:5]}...") - # 创建访问令牌 - 使用openid作为标识 - access_token = create_access_token(subject=openid) - - # 返回登录响应 - return LoginResponse( - access_token=access_token, - is_new_user=is_new_user, - openid=openid - ) \ No newline at end of file + # 调用微信API获取openid和unionid + openid, unionid = await wechat_service.code2session(login_data.code) + logger.info(f"成功获取openid: {openid[:5]}...") + + # 检查用户是否存在 + existing_user = await user_service.get_user_by_openid(db, openid=openid) + is_new_user = existing_user is None + logger.info(f"用户状态: 新用户={is_new_user}") + + if is_new_user: + # 创建新用户 + user_create = UserCreate( + openid=openid, + unionid=unionid + ) + user = await user_service.create_user(db, user=user_create) + logger.info(f"创建新用户: id={user.id}") + else: + user = existing_user + logger.info(f"现有用户登录: id={user.id}") + + # 创建访问令牌 - 使用openid作为标识 + access_token = create_access_token(subject=openid) + + # 创建响应对象 + response_data = LoginResponse( + access_token=access_token, + is_new_user=is_new_user, + openid=openid + ) + + # 返回标准响应 + return response_data + + except Exception as e: + # 记录异常 + logger.error(f"登录异常: {str(e)}", exc_info=True) + + # 重新抛出业务异常或转换为业务异常 + if isinstance(e, BusinessError): + raise + raise BusinessError(f"登录失败: {str(e)}", code=500) \ No newline at end of file diff --git a/app/api/v1/clothing.py b/app/api/v1/clothing.py new file mode 100644 index 0000000..3948698 --- /dev/null +++ b/app/api/v1/clothing.py @@ -0,0 +1,187 @@ +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy.ext.asyncio import AsyncSession +from typing import List + +from app.schemas.clothing import ( + Clothing, ClothingCreate, ClothingUpdate, ClothingWithCategory, + ClothingCategory, ClothingCategoryCreate, ClothingCategoryUpdate +) +from app.services import clothing as clothing_service +from app.db.database import get_db +from app.api.deps import get_current_user +from app.models.users import User as UserModel + +router = APIRouter() + +# 衣服分类API +@router.post("/categories", response_model=ClothingCategory, tags=["clothing-categories"]) +async def create_category( + category: ClothingCategoryCreate, + db: AsyncSession = Depends(get_db), + current_user: UserModel = Depends(get_current_user) +): + """ + 创建衣服分类 + + 需要JWT令牌认证 + """ + return await clothing_service.create_category(db=db, category=category) + +@router.get("/categories", response_model=List[ClothingCategory], tags=["clothing-categories"]) +async def read_categories( + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=100), + db: AsyncSession = Depends(get_db) +): + """获取所有衣服分类""" + return await clothing_service.get_categories(db=db, skip=skip, limit=limit) + +@router.get("/categories/{category_id}", response_model=ClothingCategory, tags=["clothing-categories"]) +async def read_category( + category_id: int, + db: AsyncSession = Depends(get_db) +): + """获取单个衣服分类""" + db_category = await clothing_service.get_category(db, category_id=category_id) + if db_category is None: + raise HTTPException(status_code=404, detail="分类不存在") + return db_category + +@router.put("/categories/{category_id}", response_model=ClothingCategory, tags=["clothing-categories"]) +async def update_category( + category_id: int, + category: ClothingCategoryUpdate, + db: AsyncSession = Depends(get_db), + current_user: UserModel = Depends(get_current_user) +): + """ + 更新衣服分类 + + 需要JWT令牌认证 + """ + db_category = await clothing_service.update_category( + db=db, + category_id=category_id, + category_update=category + ) + if db_category is None: + raise HTTPException(status_code=404, detail="分类不存在") + return db_category + +@router.delete("/categories/{category_id}", response_model=ClothingCategory, tags=["clothing-categories"]) +async def delete_category( + category_id: int, + db: AsyncSession = Depends(get_db), + current_user: UserModel = Depends(get_current_user) +): + """ + 删除衣服分类(会级联删除相关衣服) + + 需要JWT令牌认证 + """ + db_category = await clothing_service.delete_category(db=db, category_id=category_id) + if db_category is None: + raise HTTPException(status_code=404, detail="分类不存在") + return db_category + +# 衣服API +@router.post("/", response_model=Clothing, tags=["clothing"]) +async def create_clothing( + clothing: ClothingCreate, + db: AsyncSession = Depends(get_db), + current_user: UserModel = Depends(get_current_user) +): + """ + 创建衣服 + + 需要JWT令牌认证 + """ + # 检查分类是否存在 + category = await clothing_service.get_category(db, category_id=clothing.clothing_category_id) + if category is None: + raise HTTPException(status_code=404, detail="指定的分类不存在") + + return await clothing_service.create_clothing(db=db, clothing=clothing) + +@router.get("/", response_model=List[Clothing], tags=["clothing"]) +async def read_clothes( + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=100), + db: AsyncSession = Depends(get_db) +): + """获取所有衣服""" + return await clothing_service.get_clothes(db=db, skip=skip, limit=limit) + +@router.get("/by-category/{category_id}", response_model=List[Clothing], tags=["clothing"]) +async def read_clothes_by_category( + category_id: int, + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=100), + db: AsyncSession = Depends(get_db) +): + """根据分类获取衣服""" + # 检查分类是否存在 + category = await clothing_service.get_category(db, category_id=category_id) + if category is None: + raise HTTPException(status_code=404, detail="分类不存在") + + return await clothing_service.get_clothes_by_category( + db=db, + category_id=category_id, + skip=skip, + limit=limit + ) + +@router.get("/{clothing_id}", response_model=Clothing, tags=["clothing"]) +async def read_clothing( + clothing_id: int, + db: AsyncSession = Depends(get_db) +): + """获取单个衣服""" + db_clothing = await clothing_service.get_clothing(db, clothing_id=clothing_id) + if db_clothing is None: + raise HTTPException(status_code=404, detail="衣服不存在") + return db_clothing + +@router.put("/{clothing_id}", response_model=Clothing, tags=["clothing"]) +async def update_clothing( + clothing_id: int, + clothing: ClothingUpdate, + db: AsyncSession = Depends(get_db), + current_user: UserModel = Depends(get_current_user) +): + """ + 更新衣服 + + 需要JWT令牌认证 + """ + # 检查分类是否存在 + if clothing.clothing_category_id is not None: + category = await clothing_service.get_category(db, category_id=clothing.clothing_category_id) + if category is None: + raise HTTPException(status_code=404, detail="指定的分类不存在") + + db_clothing = await clothing_service.update_clothing( + db=db, + clothing_id=clothing_id, + clothing_update=clothing + ) + if db_clothing is None: + raise HTTPException(status_code=404, detail="衣服不存在") + return db_clothing + +@router.delete("/{clothing_id}", response_model=Clothing, tags=["clothing"]) +async def delete_clothing( + clothing_id: int, + db: AsyncSession = Depends(get_db), + current_user: UserModel = Depends(get_current_user) +): + """ + 删除衣服 + + 需要JWT令牌认证 + """ + db_clothing = await clothing_service.delete_clothing(db=db, clothing_id=clothing_id) + if db_clothing is None: + raise HTTPException(status_code=404, detail="衣服不存在") + return db_clothing \ No newline at end of file diff --git a/app/api/v1/user_images.py b/app/api/v1/user_images.py new file mode 100644 index 0000000..c3cfb0a --- /dev/null +++ b/app/api/v1/user_images.py @@ -0,0 +1,114 @@ +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlalchemy.ext.asyncio import AsyncSession +from typing import List + +from app.schemas.user_image import UserImage, UserImageCreate, UserImageUpdate +from app.services import user_image as user_image_service +from app.db.database import get_db +from app.api.deps import get_current_user +from app.models.users import User as UserModel +from app.core.exceptions import BusinessError + +router = APIRouter() + +@router.post("/", response_model=UserImage, tags=["user-images"]) +async def create_image( + image: UserImageCreate, + current_user: UserModel = Depends(get_current_user), + db: AsyncSession = Depends(get_db) +): + """ + 创建用户形象图片 + + 需要JWT令牌认证 + """ + return await user_image_service.create_user_image( + db=db, + user_id=current_user.id, + image=image + ) + +@router.get("/", response_model=List[UserImage], tags=["user-images"]) +async def read_images( + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=100), + current_user: UserModel = Depends(get_current_user), + db: AsyncSession = Depends(get_db) +): + """ + 获取当前用户的所有形象图片 + + 需要JWT令牌认证 + """ + return await user_image_service.get_user_images( + db=db, + user_id=current_user.id, + skip=skip, + limit=limit + ) + +@router.get("/{image_id}", response_model=UserImage, tags=["user-images"]) +async def read_image( + image_id: int, + current_user: UserModel = Depends(get_current_user), + db: AsyncSession = Depends(get_db) +): + """ + 获取指定形象图片 + + 需要JWT令牌认证 + """ + image = await user_image_service.get_user_image(db, image_id=image_id) + if image is None: + raise HTTPException(status_code=404, detail="图片不存在") + if image.user_id != current_user.id: + raise HTTPException(status_code=403, detail="没有权限访问此图片") + return image + +@router.put("/{image_id}", response_model=UserImage, tags=["user-images"]) +async def update_image( + image_id: int, + image_update: UserImageUpdate, + current_user: UserModel = Depends(get_current_user), + db: AsyncSession = Depends(get_db) +): + """ + 更新形象图片 + + 需要JWT令牌认证 + """ + # 先检查图片是否存在 + db_image = await user_image_service.get_user_image(db, image_id=image_id) + if db_image is None: + raise HTTPException(status_code=404, detail="图片不存在") + if db_image.user_id != current_user.id: + raise HTTPException(status_code=403, detail="没有权限更新此图片") + + # 执行更新 + updated_image = await user_image_service.update_user_image( + db=db, + image_id=image_id, + image_update=image_update + ) + return updated_image + +@router.delete("/{image_id}", response_model=UserImage, tags=["user-images"]) +async def delete_image( + image_id: int, + current_user: UserModel = Depends(get_current_user), + db: AsyncSession = Depends(get_db) +): + """ + 删除形象图片 + + 需要JWT令牌认证 + """ + # 先检查图片是否存在 + db_image = await user_image_service.get_user_image(db, image_id=image_id) + if db_image is None: + raise HTTPException(status_code=404, detail="图片不存在") + if db_image.user_id != current_user.id: + raise HTTPException(status_code=403, detail="没有权限删除此图片") + + # 执行删除 + return await user_image_service.delete_user_image(db=db, image_id=image_id) \ No newline at end of file diff --git a/app/core/config.py b/app/core/config.py index f55f162..8bec645 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -23,7 +23,8 @@ class Settings(BaseSettings): DB_ECHO: bool = os.getenv("DB_ECHO", "False").lower() == "true" # DashScope API密钥 - DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "sk-caa199589f1c451aaac471fad2986e28") + DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "") + DASHSCOPE_MODEL: str = os.getenv("DASHSCOPE_MODEL", "qwen-max") # 安全设置 SECRET_KEY: str = os.getenv("SECRET_KEY", "your-secret-key-for-jwt-please-change-in-production") @@ -34,6 +35,15 @@ class Settings(BaseSettings): WECHAT_APP_ID: str = os.getenv("WECHAT_APP_ID", "") WECHAT_APP_SECRET: str = os.getenv("WECHAT_APP_SECRET", "") + @property + def cors_origins(self): + """获取CORS来源列表""" + if isinstance(self.BACKEND_CORS_ORIGINS, str) and self.BACKEND_CORS_ORIGINS == "*": + return ["*"] + elif isinstance(self.BACKEND_CORS_ORIGINS, list): + return self.BACKEND_CORS_ORIGINS + return [] + class Config: env_file = ".env" case_sensitive = True diff --git a/app/core/middleware.py b/app/core/middleware.py index 73c5610..e56ec18 100644 --- a/app/core/middleware.py +++ b/app/core/middleware.py @@ -19,49 +19,42 @@ class ResponseMiddleware(BaseHTTPMiddleware): exclude_paths = ["/docs", "/redoc", "/openapi.json"] if any(request.url.path.startswith(path) for path in exclude_paths): return await call_next(request) - + + # 获取原始响应 response = await call_next(request) - # 如果是HTTPException,直接返回,不进行封装 - if response.status_code >= 400: + # 如果不是200系列响应或不是JSON响应,直接返回 + if response.status_code >= 300 or response.headers.get("content-type") != "application/json": return response - # 正常响应进行封装 - if response.status_code < 300 and response.headers.get("content-type") == "application/json": - response_body = [chunk async for chunk in response.body_iterator] - response_body = b"".join(response_body) - - if response_body: - try: - data = json.loads(response_body) - - # 已经是统一格式的响应,不再封装 - if isinstance(data, dict) and "code" in data and ("data" in data or "message" in data): - return Response( - content=response_body, - status_code=response.status_code, - headers=dict(response.headers), - media_type=response.media_type - ) - - # 封装为标准响应格式 - result = StandardResponse(code=200, data=data).model_dump() - return JSONResponse( - content=result, - status_code=response.status_code, - headers=dict(response.headers) - ) - - except json.JSONDecodeError: - # 非JSON响应,直接返回 - return Response( - content=response_body, - status_code=response.status_code, - headers=dict(response.headers), - media_type=response.media_type - ) - - return response + # 读取响应内容 + try: + # 使用JSONResponse的_render方法获取内容 + if isinstance(response, JSONResponse): + # 获取原始数据 + raw_data = response.body.decode("utf-8") + data = json.loads(raw_data) + + # 已经是标准格式,不再封装 + if isinstance(data, dict) and "code" in data and ("data" in data or "message" in data): + return response + + # 创建新的标准响应 + std_response = StandardResponse(code=200, data=data) + + # 创建新的JSONResponse + return JSONResponse( + content=std_response.model_dump(), + status_code=response.status_code, + headers=dict(response.headers) + ) + # 处理流式响应或其他类型响应 + else: + return response + + except Exception as e: + # 出现异常,返回原始响应 + return response def add_response_middleware(app: FastAPI) -> None: """添加响应处理中间件到FastAPI应用""" diff --git a/app/db/init_db.py b/app/db/init_db.py index 8e79a25..4cc7cfa 100644 --- a/app/db/init_db.py +++ b/app/db/init_db.py @@ -1,6 +1,8 @@ import asyncio from app.db.database import Base, engine from app.models.users import User +from app.models.user_images import UserImage +from app.models.clothing import ClothingCategory, Clothing # 创建所有表格 async def init_db(): diff --git a/app/main.py b/app/main.py index 01f5c93..d5e4df6 100644 --- a/app/main.py +++ b/app/main.py @@ -1,19 +1,29 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from contextlib import asynccontextmanager +import logging from app.core.config import settings from app.api.v1.api import api_router from app.core.middleware import add_response_middleware from app.core.exceptions import add_exception_handlers +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', +) +logger = logging.getLogger(__name__) + @asynccontextmanager async def lifespan(app: FastAPI): # 在应用启动时执行 + logger.info("应用启动,初始化数据库...") from app.db.init_db import init_db await init_db() + logger.info("数据库初始化完成") yield # 在应用关闭时执行 - # 清理代码可以放在这里 + logger.info("应用关闭") app = FastAPI( title=settings.PROJECT_NAME, @@ -25,7 +35,7 @@ app = FastAPI( # 配置CORS app.add_middleware( CORSMiddleware, - allow_origins=settings.BACKEND_CORS_ORIGINS, + allow_origins=settings.cors_origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], diff --git a/app/models/clothing.py b/app/models/clothing.py new file mode 100644 index 0000000..06f3b5c --- /dev/null +++ b/app/models/clothing.py @@ -0,0 +1,33 @@ +from sqlalchemy import Column, Integer, String, DateTime, ForeignKey +from sqlalchemy.sql import func +from sqlalchemy.orm import relationship +from app.db.database import Base + +class ClothingCategory(Base): + """衣服分类表""" + __tablename__ = "clothing_category" + + id = Column(Integer, primary_key=True, autoincrement=True, index=True) + name = Column(String(50), nullable=False, comment="分类名称") + create_time = Column(DateTime, default=func.now(), comment="创建时间") + + # 关系 + clothes = relationship("Clothing", back_populates="category", cascade="all, delete-orphan") + + def __repr__(self): + return f"" + +class Clothing(Base): + """衣服表""" + __tablename__ = "clothing" + + id = Column(Integer, primary_key=True, autoincrement=True, index=True) + clothing_category_id = Column(Integer, ForeignKey("clothing_category.id"), nullable=False, index=True, comment="分类ID") + image_url = Column(String(500), nullable=False, comment="图片URL") + create_time = Column(DateTime, default=func.now(), comment="创建时间") + + # 关系 + category = relationship("ClothingCategory", back_populates="clothes") + + def __repr__(self): + return f"" \ No newline at end of file diff --git a/app/models/user_images.py b/app/models/user_images.py new file mode 100644 index 0000000..98e140e --- /dev/null +++ b/app/models/user_images.py @@ -0,0 +1,20 @@ +from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, Boolean +from sqlalchemy.sql import func +from sqlalchemy.orm import relationship +from app.db.database import Base + +class UserImage(Base): + """用户个人形象库模型""" + __tablename__ = "user_images" + + id = Column(Integer, primary_key=True, autoincrement=True, index=True) + user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True, comment="用户ID") + image_url = Column(String(500), nullable=False, comment="图片URL") + is_default = Column(Boolean, default=False, nullable=False, comment="是否为默认形象") + create_time = Column(DateTime, default=func.now(), comment="创建时间") + + # 关系 + user = relationship("User", back_populates="images") + + def __repr__(self): + return f"" \ No newline at end of file diff --git a/app/models/users.py b/app/models/users.py index 00fa385..93ee79e 100644 --- a/app/models/users.py +++ b/app/models/users.py @@ -1,5 +1,6 @@ from sqlalchemy import Column, Integer, String, DateTime from sqlalchemy.sql import func +from sqlalchemy.orm import relationship from app.db.database import Base class User(Base): @@ -12,5 +13,8 @@ class User(Base): nickname = Column(String(50), nullable=True, comment="昵称") create_time = Column(DateTime, default=func.now(), comment="创建时间") + # 关系 + images = relationship("UserImage", back_populates="user", cascade="all, delete-orphan") + def __repr__(self): return f"" \ No newline at end of file diff --git a/app/schemas/clothing.py b/app/schemas/clothing.py new file mode 100644 index 0000000..dd51442 --- /dev/null +++ b/app/schemas/clothing.py @@ -0,0 +1,63 @@ +from pydantic import BaseModel, Field +from typing import Optional, List +from datetime import datetime + +# 分类模式 +class ClothingCategoryBase(BaseModel): + """衣服分类基础模式""" + name: str = Field(..., description="分类名称") + +class ClothingCategoryCreate(ClothingCategoryBase): + """创建衣服分类请求""" + pass + +class ClothingCategoryInDB(ClothingCategoryBase): + """数据库中的衣服分类数据""" + id: int + create_time: datetime + + class Config: + from_attributes = True + +class ClothingCategory(ClothingCategoryInDB): + """衣服分类响应模式""" + pass + +class ClothingCategoryUpdate(BaseModel): + """更新衣服分类请求""" + name: Optional[str] = Field(None, description="分类名称") + +# 衣服模式 +class ClothingBase(BaseModel): + """衣服基础模式""" + clothing_category_id: int = Field(..., description="分类ID") + image_url: str = Field(..., description="图片URL") + +class ClothingCreate(ClothingBase): + """创建衣服请求""" + pass + +class ClothingInDB(ClothingBase): + """数据库中的衣服数据""" + id: int + create_time: datetime + + class Config: + from_attributes = True + +class Clothing(ClothingInDB): + """衣服响应模式""" + pass + +class ClothingUpdate(BaseModel): + """更新衣服请求""" + clothing_category_id: Optional[int] = Field(None, description="分类ID") + image_url: Optional[str] = Field(None, description="图片URL") + +# 带有分类信息的衣服响应 +class ClothingWithCategory(Clothing): + """包含分类信息的衣服响应""" + category: ClothingCategory + + class Config: + from_attributes = True \ No newline at end of file diff --git a/app/schemas/user_image.py b/app/schemas/user_image.py new file mode 100644 index 0000000..680f5a4 --- /dev/null +++ b/app/schemas/user_image.py @@ -0,0 +1,30 @@ +from pydantic import BaseModel, Field +from typing import Optional +from datetime import datetime + +class UserImageBase(BaseModel): + """用户形象基础模式""" + image_url: str = Field(..., description="图片URL") + is_default: bool = Field(False, description="是否为默认形象") + +class UserImageCreate(UserImageBase): + """创建用户形象请求""" + pass + +class UserImageInDB(UserImageBase): + """数据库中的用户形象数据""" + id: int + user_id: int + create_time: datetime + + class Config: + from_attributes = True + +class UserImage(UserImageInDB): + """用户形象响应模式""" + pass + +class UserImageUpdate(BaseModel): + """更新用户形象请求""" + image_url: Optional[str] = Field(None, description="图片URL") + is_default: Optional[bool] = Field(None, description="是否为默认形象") \ No newline at end of file diff --git a/app/services/clothing.py b/app/services/clothing.py new file mode 100644 index 0000000..838692b --- /dev/null +++ b/app/services/clothing.py @@ -0,0 +1,116 @@ +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy import update, delete +from typing import List, Optional + +from app.models.clothing import Clothing, ClothingCategory +from app.schemas.clothing import ClothingCreate, ClothingUpdate +from app.schemas.clothing import ClothingCategoryCreate, ClothingCategoryUpdate + +# 衣服分类服务函数 +async def get_category(db: AsyncSession, category_id: int): + """获取单个衣服分类""" + result = await db.execute(select(ClothingCategory).filter(ClothingCategory.id == category_id)) + return result.scalars().first() + +async def get_categories(db: AsyncSession, skip: int = 0, limit: int = 100): + """获取所有衣服分类""" + result = await db.execute( + select(ClothingCategory) + .offset(skip) + .limit(limit) + ) + return result.scalars().all() + +async def create_category(db: AsyncSession, category: ClothingCategoryCreate): + """创建衣服分类""" + db_category = ClothingCategory( + name=category.name + ) + db.add(db_category) + await db.commit() + await db.refresh(db_category) + return db_category + +async def update_category(db: AsyncSession, category_id: int, category_update: ClothingCategoryUpdate): + """更新衣服分类""" + db_category = await get_category(db, category_id) + if not db_category: + return None + + if category_update.name: + db_category.name = category_update.name + + await db.commit() + await db.refresh(db_category) + return db_category + +async def delete_category(db: AsyncSession, category_id: int): + """删除衣服分类(会级联删除相关衣服)""" + db_category = await get_category(db, category_id) + if db_category: + await db.delete(db_category) + await db.commit() + return db_category + +# 衣服服务函数 +async def get_clothing(db: AsyncSession, clothing_id: int): + """获取单个衣服""" + result = await db.execute(select(Clothing).filter(Clothing.id == clothing_id)) + return result.scalars().first() + +async def get_clothes(db: AsyncSession, skip: int = 0, limit: int = 100): + """获取所有衣服""" + result = await db.execute( + select(Clothing) + .order_by(Clothing.create_time.desc()) + .offset(skip) + .limit(limit) + ) + return result.scalars().all() + +async def get_clothes_by_category(db: AsyncSession, category_id: int, skip: int = 0, limit: int = 100): + """根据分类获取衣服""" + result = await db.execute( + select(Clothing) + .filter(Clothing.clothing_category_id == category_id) + .order_by(Clothing.create_time.desc()) + .offset(skip) + .limit(limit) + ) + return result.scalars().all() + +async def create_clothing(db: AsyncSession, clothing: ClothingCreate): + """创建衣服""" + db_clothing = Clothing( + clothing_category_id=clothing.clothing_category_id, + image_url=clothing.image_url + ) + db.add(db_clothing) + await db.commit() + await db.refresh(db_clothing) + return db_clothing + +async def update_clothing(db: AsyncSession, clothing_id: int, clothing_update: ClothingUpdate): + """更新衣服""" + db_clothing = await get_clothing(db, clothing_id) + if not db_clothing: + return None + + if clothing_update.clothing_category_id is not None: + db_clothing.clothing_category_id = clothing_update.clothing_category_id + + if clothing_update.image_url: + db_clothing.image_url = clothing_update.image_url + + await db.commit() + await db.refresh(db_clothing) + return db_clothing + +async def delete_clothing(db: AsyncSession, clothing_id: int): + """删除衣服""" + db_clothing = await get_clothing(db, clothing_id) + if db_clothing: + await db.delete(db_clothing) + await db.commit() + return db_clothing \ No newline at end of file diff --git a/app/services/dashscope_service.py b/app/services/dashscope_service.py new file mode 100644 index 0000000..3d5877c --- /dev/null +++ b/app/services/dashscope_service.py @@ -0,0 +1,272 @@ +import os +import logging +import dashscope +from dashscope import Generation +# 修改导入语句,dashscope的API响应可能改变了结构 +from typing import List, Dict, Any, Optional +import asyncio +import httpx +from app.utils.config import get_settings + +logger = logging.getLogger(__name__) + +class DashScopeService: + """DashScope服务类,提供对DashScope API的调用封装""" + + def __init__(self): + settings = get_settings() + self.api_key = settings.dashscope_api_key + # 配置DashScope + dashscope.api_key = self.api_key + # 配置API URL + self.image_synthesis_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2image/image-synthesis" + + async def chat_completion( + self, + messages: List[Dict[str, str]], + model: str = "qwen-max", + max_tokens: int = 2048, + temperature: float = 0.7, + stream: bool = False + ): + """ + 调用DashScope的大模型API进行对话 + + Args: + messages: 对话历史记录 + model: 模型名称 + max_tokens: 最大生成token数 + temperature: 温度参数,控制随机性 + stream: 是否流式输出 + + Returns: + ApiResponse: DashScope的API响应 + """ + try: + # 为了不阻塞FastAPI的异步性能,我们使用run_in_executor运行同步API + loop = asyncio.get_event_loop() + response = await loop.run_in_executor( + None, + lambda: Generation.call( + model=model, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + result_format='message', + stream=stream, + ) + ) + + if response.status_code != 200: + logger.error(f"DashScope API请求失败,状态码:{response.status_code}, 错误信息:{response.message}") + raise Exception(f"API调用失败: {response.message}") + + return response + except Exception as e: + logger.error(f"DashScope聊天API调用出错: {str(e)}") + raise e + + async def generate_image( + self, + prompt: str, + negative_prompt: Optional[str] = None, + model: str = "stable-diffusion-xl", + n: int = 1, + size: str = "1024*1024" + ): + """ + 调用DashScope的图像生成API + + Args: + prompt: 生成图像的文本描述 + negative_prompt: 负面提示词 + model: 模型名称 + n: 生成图像数量 + size: 图像尺寸 + + Returns: + ApiResponse: DashScope的API响应 + """ + try: + # 构建请求参数 + params = { + "model": model, + "prompt": prompt, + "n": n, + "size": size, + } + + if negative_prompt: + params["negative_prompt"] = negative_prompt + + # 异步调用图像生成API + loop = asyncio.get_event_loop() + response = await loop.run_in_executor( + None, + lambda: dashscope.ImageSynthesis.call(**params) + ) + + if response.status_code != 200: + logger.error(f"DashScope 图像生成API请求失败,状态码:{response.status_code}, 错误信息:{response.message}") + raise Exception(f"图像生成API调用失败: {response.message}") + + return response + except Exception as e: + logger.error(f"DashScope图像生成API调用出错: {str(e)}") + raise e + + async def generate_tryon( + self, + person_image_url: str, + top_garment_url: Optional[str] = None, + bottom_garment_url: Optional[str] = None, + resolution: int = -1, + restore_face: bool = True + ): + """ + 调用阿里百炼平台的试衣服务 + + Args: + person_image_url: 人物图片URL + top_garment_url: 上衣图片URL + bottom_garment_url: 下衣图片URL + resolution: 分辨率,-1表示自动 + restore_face: 是否修复面部 + + Returns: + Dict: 包含任务ID和请求ID的响应 + """ + try: + # 验证参数 + if not person_image_url: + raise ValueError("人物图片URL不能为空") + + if not top_garment_url and not bottom_garment_url: + raise ValueError("上衣和下衣图片至少需要提供一个") + + # 构建请求数据 + request_data = { + "model": "aitryon", + "input": { + "person_image_url": person_image_url + }, + "parameters": { + "resolution": resolution, + "restore_face": restore_face + } + } + + # 添加可选字段 + if top_garment_url: + request_data["input"]["top_garment_url"] = top_garment_url + if bottom_garment_url: + request_data["input"]["bottom_garment_url"] = bottom_garment_url + + # 构建请求头 + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + "X-DashScope-Async": "enable" + } + + # 发送请求 + async with httpx.AsyncClient() as client: + logger.info(f"发送试穿请求: {request_data}") + response = await client.post( + self.image_synthesis_url, + json=request_data, + headers=headers, + timeout=30.0 + ) + + response_data = response.json() + logger.info(f"试穿API响应: {response_data}") + + if response.status_code == 200 or response.status_code == 202: # 202表示异步任务已接受 + # 提取任务ID,适应不同的API响应格式 + task_id = None + if 'output' in response_data and 'task_id' in response_data['output']: + task_id = response_data['output']['task_id'] + elif 'task_id' in response_data: + task_id = response_data['task_id'] + + if task_id: + logger.info(f"试穿请求发送成功,任务ID: {task_id}") + return { + "task_id": task_id, + "request_id": response_data.get('request_id'), + "status": "processing" + } + else: + # 如果没有任务ID,这可能是同步响应 + logger.info("收到同步响应,没有任务ID") + return { + "status": "completed", + "result": response_data.get('output', {}), + "request_id": response_data.get('request_id') + } + else: + error_msg = f"试穿请求失败: {response.status_code} - {response.text}" + logger.error(error_msg) + raise Exception(error_msg) + except Exception as e: + logger.error(f"DashScope试穿API调用出错: {str(e)}") + raise e + + async def check_tryon_status(self, task_id: str): + """ + 检查试穿任务状态 + + Args: + task_id: 任务ID + + Returns: + Dict: 任务状态信息 + """ + try: + # 构建请求头 + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + } + + # 构建请求URL + status_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}" + + # 发送请求 + async with httpx.AsyncClient() as client: + response = await client.get( + status_url, + headers=headers, + timeout=30.0 + ) + + if response.status_code == 200: + response_data = response.json() + print(response_data) + status = response_data.get('output', {}).get('task_status', '') + logger.info(f"试穿任务状态查询成功: {status}") + + # 检查是否完成并返回结果 + if status.lower() == 'succeeded': + image_url = response_data.get('output', {}).get('image_url') + if image_url: + logger.info(f"试穿任务完成,结果URL: {image_url}") + return { + "status": status, + "task_id": task_id, + "image_url": image_url, + "result": response_data.get('output', {}) + } + + return { + "status": status, + "task_id": task_id + } + else: + error_msg = f"试穿任务状态查询失败: {response.status_code} - {response.text}" + logger.error(error_msg) + raise Exception(error_msg) + except Exception as e: + logger.error(f"查询试穿任务状态出错: {str(e)}") + raise e \ No newline at end of file diff --git a/app/services/user_image.py b/app/services/user_image.py new file mode 100644 index 0000000..fbdd5c4 --- /dev/null +++ b/app/services/user_image.py @@ -0,0 +1,91 @@ +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.future import select +from sqlalchemy import delete, update +from app.models.user_images import UserImage +from app.schemas.user_image import UserImageCreate, UserImageUpdate +from typing import List, Optional + +async def get_user_image(db: AsyncSession, image_id: int): + """获取单个用户形象""" + result = await db.execute(select(UserImage).filter(UserImage.id == image_id)) + return result.scalars().first() + +async def get_user_images(db: AsyncSession, user_id: int, skip: int = 0, limit: int = 100): + """获取用户的所有形象图片""" + result = await db.execute( + select(UserImage) + .filter(UserImage.user_id == user_id) + .order_by(UserImage.create_time.desc()) + .offset(skip) + .limit(limit) + ) + return result.scalars().all() + +async def create_user_image(db: AsyncSession, user_id: int, image: UserImageCreate): + """创建用户形象""" + # 如果设置为默认形象,先重置用户的所有形象为非默认 + if image.is_default: + await reset_default_images(db, user_id) + + db_image = UserImage( + user_id=user_id, + image_url=image.image_url, + is_default=image.is_default + ) + db.add(db_image) + await db.commit() + await db.refresh(db_image) + return db_image + +async def update_user_image(db: AsyncSession, image_id: int, image_update: UserImageUpdate): + """更新用户形象""" + db_image = await get_user_image(db, image_id) + if not db_image: + return None + + # 处理默认形象逻辑 + if image_update.is_default is not None and image_update.is_default and not db_image.is_default: + # 如果设置为默认且之前不是默认,重置其他形象 + await reset_default_images(db, db_image.user_id) + db_image.is_default = True + elif image_update.is_default is not None: + db_image.is_default = image_update.is_default + + # 更新图片URL + if image_update.image_url: + db_image.image_url = image_update.image_url + + await db.commit() + await db.refresh(db_image) + return db_image + +async def delete_user_image(db: AsyncSession, image_id: int): + """删除用户形象""" + db_image = await get_user_image(db, image_id) + if db_image: + await db.delete(db_image) + await db.commit() + return db_image + +async def delete_user_images(db: AsyncSession, user_id: int): + """删除用户所有形象图片""" + stmt = delete(UserImage).where(UserImage.user_id == user_id) + await db.execute(stmt) + await db.commit() + return True + +async def reset_default_images(db: AsyncSession, user_id: int): + """重置用户所有形象为非默认""" + stmt = update(UserImage).where( + UserImage.user_id == user_id, + UserImage.is_default == True + ).values(is_default=False) + await db.execute(stmt) + +async def get_default_image(db: AsyncSession, user_id: int): + """获取用户的默认形象""" + result = await db.execute( + select(UserImage) + .filter(UserImage.user_id == user_id, UserImage.is_default == True) + ) + return result.scalars().first() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 773c678..1a5221a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,5 @@ aiomysql==0.2.0 greenlet==2.0.2 python-jose[cryptography]==3.3.0 passlib==1.7.4 -httpx==0.24.1 \ No newline at end of file +httpx==0.24.1 +dashscope==1.10.0 \ No newline at end of file