This commit is contained in:
aaron 2025-04-09 15:00:38 +08:00
parent d5d7f9b89a
commit 16c8c11172
17 changed files with 1043 additions and 70 deletions

View File

@ -2,8 +2,12 @@ from fastapi import APIRouter
from app.api.v1.endpoints import router as endpoints_router
from app.api.v1.users import router as users_router
from app.api.v1.auth import router as auth_router
from app.api.v1.user_images import router as user_images_router
from app.api.v1.clothing import router as clothing_router
api_router = APIRouter()
api_router.include_router(endpoints_router, prefix="")
api_router.include_router(auth_router, prefix="/auth")
api_router.include_router(users_router, prefix="/users")
api_router.include_router(auth_router, prefix="/auth")
api_router.include_router(user_images_router, prefix="/user-images")
api_router.include_router(clothing_router, prefix="/clothing")

View File

@ -1,5 +1,6 @@
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException, status, Response
from sqlalchemy.ext.asyncio import AsyncSession
import logging
from app.schemas.auth import WechatLogin, LoginResponse
from app.schemas.user import UserCreate
@ -9,6 +10,9 @@ from app.core.security import create_access_token
from app.db.database import get_db
from app.core.exceptions import BusinessError
# 创建日志记录器
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/login/wechat", response_model=LoginResponse, tags=["auth"])
@ -23,29 +27,48 @@ async def wechat_login(
- 如果用户不存在则创建新用户
- 生成JWT令牌
"""
# 调用微信API获取openid和unionid
openid, unionid = await wechat_service.code2session(login_data.code)
# 检查用户是否存在
existing_user = await user_service.get_user_by_openid(db, openid=openid)
is_new_user = existing_user is None
if is_new_user:
# 创建新用户
user_create = UserCreate(
openid=openid,
unionid=unionid
)
user = await user_service.create_user(db, user=user_create)
else:
user = existing_user
try:
logger.info(f"尝试微信登录: {login_data.code[:5]}...")
# 创建访问令牌 - 使用openid作为标识
access_token = create_access_token(subject=openid)
# 返回登录响应
return LoginResponse(
access_token=access_token,
is_new_user=is_new_user,
openid=openid
)
# 调用微信API获取openid和unionid
openid, unionid = await wechat_service.code2session(login_data.code)
logger.info(f"成功获取openid: {openid[:5]}...")
# 检查用户是否存在
existing_user = await user_service.get_user_by_openid(db, openid=openid)
is_new_user = existing_user is None
logger.info(f"用户状态: 新用户={is_new_user}")
if is_new_user:
# 创建新用户
user_create = UserCreate(
openid=openid,
unionid=unionid
)
user = await user_service.create_user(db, user=user_create)
logger.info(f"创建新用户: id={user.id}")
else:
user = existing_user
logger.info(f"现有用户登录: id={user.id}")
# 创建访问令牌 - 使用openid作为标识
access_token = create_access_token(subject=openid)
# 创建响应对象
response_data = LoginResponse(
access_token=access_token,
is_new_user=is_new_user,
openid=openid
)
# 返回标准响应
return response_data
except Exception as e:
# 记录异常
logger.error(f"登录异常: {str(e)}", exc_info=True)
# 重新抛出业务异常或转换为业务异常
if isinstance(e, BusinessError):
raise
raise BusinessError(f"登录失败: {str(e)}", code=500)

187
app/api/v1/clothing.py Normal file
View File

@ -0,0 +1,187 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List
from app.schemas.clothing import (
Clothing, ClothingCreate, ClothingUpdate, ClothingWithCategory,
ClothingCategory, ClothingCategoryCreate, ClothingCategoryUpdate
)
from app.services import clothing as clothing_service
from app.db.database import get_db
from app.api.deps import get_current_user
from app.models.users import User as UserModel
router = APIRouter()
# 衣服分类API
@router.post("/categories", response_model=ClothingCategory, tags=["clothing-categories"])
async def create_category(
category: ClothingCategoryCreate,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user)
):
"""
创建衣服分类
需要JWT令牌认证
"""
return await clothing_service.create_category(db=db, category=category)
@router.get("/categories", response_model=List[ClothingCategory], tags=["clothing-categories"])
async def read_categories(
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=100),
db: AsyncSession = Depends(get_db)
):
"""获取所有衣服分类"""
return await clothing_service.get_categories(db=db, skip=skip, limit=limit)
@router.get("/categories/{category_id}", response_model=ClothingCategory, tags=["clothing-categories"])
async def read_category(
category_id: int,
db: AsyncSession = Depends(get_db)
):
"""获取单个衣服分类"""
db_category = await clothing_service.get_category(db, category_id=category_id)
if db_category is None:
raise HTTPException(status_code=404, detail="分类不存在")
return db_category
@router.put("/categories/{category_id}", response_model=ClothingCategory, tags=["clothing-categories"])
async def update_category(
category_id: int,
category: ClothingCategoryUpdate,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user)
):
"""
更新衣服分类
需要JWT令牌认证
"""
db_category = await clothing_service.update_category(
db=db,
category_id=category_id,
category_update=category
)
if db_category is None:
raise HTTPException(status_code=404, detail="分类不存在")
return db_category
@router.delete("/categories/{category_id}", response_model=ClothingCategory, tags=["clothing-categories"])
async def delete_category(
category_id: int,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user)
):
"""
删除衣服分类会级联删除相关衣服
需要JWT令牌认证
"""
db_category = await clothing_service.delete_category(db=db, category_id=category_id)
if db_category is None:
raise HTTPException(status_code=404, detail="分类不存在")
return db_category
# 衣服API
@router.post("/", response_model=Clothing, tags=["clothing"])
async def create_clothing(
clothing: ClothingCreate,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user)
):
"""
创建衣服
需要JWT令牌认证
"""
# 检查分类是否存在
category = await clothing_service.get_category(db, category_id=clothing.clothing_category_id)
if category is None:
raise HTTPException(status_code=404, detail="指定的分类不存在")
return await clothing_service.create_clothing(db=db, clothing=clothing)
@router.get("/", response_model=List[Clothing], tags=["clothing"])
async def read_clothes(
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=100),
db: AsyncSession = Depends(get_db)
):
"""获取所有衣服"""
return await clothing_service.get_clothes(db=db, skip=skip, limit=limit)
@router.get("/by-category/{category_id}", response_model=List[Clothing], tags=["clothing"])
async def read_clothes_by_category(
category_id: int,
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=100),
db: AsyncSession = Depends(get_db)
):
"""根据分类获取衣服"""
# 检查分类是否存在
category = await clothing_service.get_category(db, category_id=category_id)
if category is None:
raise HTTPException(status_code=404, detail="分类不存在")
return await clothing_service.get_clothes_by_category(
db=db,
category_id=category_id,
skip=skip,
limit=limit
)
@router.get("/{clothing_id}", response_model=Clothing, tags=["clothing"])
async def read_clothing(
clothing_id: int,
db: AsyncSession = Depends(get_db)
):
"""获取单个衣服"""
db_clothing = await clothing_service.get_clothing(db, clothing_id=clothing_id)
if db_clothing is None:
raise HTTPException(status_code=404, detail="衣服不存在")
return db_clothing
@router.put("/{clothing_id}", response_model=Clothing, tags=["clothing"])
async def update_clothing(
clothing_id: int,
clothing: ClothingUpdate,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user)
):
"""
更新衣服
需要JWT令牌认证
"""
# 检查分类是否存在
if clothing.clothing_category_id is not None:
category = await clothing_service.get_category(db, category_id=clothing.clothing_category_id)
if category is None:
raise HTTPException(status_code=404, detail="指定的分类不存在")
db_clothing = await clothing_service.update_clothing(
db=db,
clothing_id=clothing_id,
clothing_update=clothing
)
if db_clothing is None:
raise HTTPException(status_code=404, detail="衣服不存在")
return db_clothing
@router.delete("/{clothing_id}", response_model=Clothing, tags=["clothing"])
async def delete_clothing(
clothing_id: int,
db: AsyncSession = Depends(get_db),
current_user: UserModel = Depends(get_current_user)
):
"""
删除衣服
需要JWT令牌认证
"""
db_clothing = await clothing_service.delete_clothing(db=db, clothing_id=clothing_id)
if db_clothing is None:
raise HTTPException(status_code=404, detail="衣服不存在")
return db_clothing

114
app/api/v1/user_images.py Normal file
View File

@ -0,0 +1,114 @@
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List
from app.schemas.user_image import UserImage, UserImageCreate, UserImageUpdate
from app.services import user_image as user_image_service
from app.db.database import get_db
from app.api.deps import get_current_user
from app.models.users import User as UserModel
from app.core.exceptions import BusinessError
router = APIRouter()
@router.post("/", response_model=UserImage, tags=["user-images"])
async def create_image(
image: UserImageCreate,
current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
创建用户形象图片
需要JWT令牌认证
"""
return await user_image_service.create_user_image(
db=db,
user_id=current_user.id,
image=image
)
@router.get("/", response_model=List[UserImage], tags=["user-images"])
async def read_images(
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=100),
current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
获取当前用户的所有形象图片
需要JWT令牌认证
"""
return await user_image_service.get_user_images(
db=db,
user_id=current_user.id,
skip=skip,
limit=limit
)
@router.get("/{image_id}", response_model=UserImage, tags=["user-images"])
async def read_image(
image_id: int,
current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
获取指定形象图片
需要JWT令牌认证
"""
image = await user_image_service.get_user_image(db, image_id=image_id)
if image is None:
raise HTTPException(status_code=404, detail="图片不存在")
if image.user_id != current_user.id:
raise HTTPException(status_code=403, detail="没有权限访问此图片")
return image
@router.put("/{image_id}", response_model=UserImage, tags=["user-images"])
async def update_image(
image_id: int,
image_update: UserImageUpdate,
current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
更新形象图片
需要JWT令牌认证
"""
# 先检查图片是否存在
db_image = await user_image_service.get_user_image(db, image_id=image_id)
if db_image is None:
raise HTTPException(status_code=404, detail="图片不存在")
if db_image.user_id != current_user.id:
raise HTTPException(status_code=403, detail="没有权限更新此图片")
# 执行更新
updated_image = await user_image_service.update_user_image(
db=db,
image_id=image_id,
image_update=image_update
)
return updated_image
@router.delete("/{image_id}", response_model=UserImage, tags=["user-images"])
async def delete_image(
image_id: int,
current_user: UserModel = Depends(get_current_user),
db: AsyncSession = Depends(get_db)
):
"""
删除形象图片
需要JWT令牌认证
"""
# 先检查图片是否存在
db_image = await user_image_service.get_user_image(db, image_id=image_id)
if db_image is None:
raise HTTPException(status_code=404, detail="图片不存在")
if db_image.user_id != current_user.id:
raise HTTPException(status_code=403, detail="没有权限删除此图片")
# 执行删除
return await user_image_service.delete_user_image(db=db, image_id=image_id)

View File

@ -23,7 +23,8 @@ class Settings(BaseSettings):
DB_ECHO: bool = os.getenv("DB_ECHO", "False").lower() == "true"
# DashScope API密钥
DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "sk-caa199589f1c451aaac471fad2986e28")
DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "")
DASHSCOPE_MODEL: str = os.getenv("DASHSCOPE_MODEL", "qwen-max")
# 安全设置
SECRET_KEY: str = os.getenv("SECRET_KEY", "your-secret-key-for-jwt-please-change-in-production")
@ -34,6 +35,15 @@ class Settings(BaseSettings):
WECHAT_APP_ID: str = os.getenv("WECHAT_APP_ID", "")
WECHAT_APP_SECRET: str = os.getenv("WECHAT_APP_SECRET", "")
@property
def cors_origins(self):
"""获取CORS来源列表"""
if isinstance(self.BACKEND_CORS_ORIGINS, str) and self.BACKEND_CORS_ORIGINS == "*":
return ["*"]
elif isinstance(self.BACKEND_CORS_ORIGINS, list):
return self.BACKEND_CORS_ORIGINS
return []
class Config:
env_file = ".env"
case_sensitive = True

View File

@ -19,49 +19,42 @@ class ResponseMiddleware(BaseHTTPMiddleware):
exclude_paths = ["/docs", "/redoc", "/openapi.json"]
if any(request.url.path.startswith(path) for path in exclude_paths):
return await call_next(request)
# 获取原始响应
response = await call_next(request)
# 如果是HTTPException直接返回不进行封装
if response.status_code >= 400:
# 如果不是200系列响应或不是JSON响应直接返回
if response.status_code >= 300 or response.headers.get("content-type") != "application/json":
return response
# 正常响应进行封装
if response.status_code < 300 and response.headers.get("content-type") == "application/json":
response_body = [chunk async for chunk in response.body_iterator]
response_body = b"".join(response_body)
if response_body:
try:
data = json.loads(response_body)
# 已经是统一格式的响应,不再封装
if isinstance(data, dict) and "code" in data and ("data" in data or "message" in data):
return Response(
content=response_body,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type
)
# 封装为标准响应格式
result = StandardResponse(code=200, data=data).model_dump()
return JSONResponse(
content=result,
status_code=response.status_code,
headers=dict(response.headers)
)
except json.JSONDecodeError:
# 非JSON响应直接返回
return Response(
content=response_body,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type
)
return response
# 读取响应内容
try:
# 使用JSONResponse的_render方法获取内容
if isinstance(response, JSONResponse):
# 获取原始数据
raw_data = response.body.decode("utf-8")
data = json.loads(raw_data)
# 已经是标准格式,不再封装
if isinstance(data, dict) and "code" in data and ("data" in data or "message" in data):
return response
# 创建新的标准响应
std_response = StandardResponse(code=200, data=data)
# 创建新的JSONResponse
return JSONResponse(
content=std_response.model_dump(),
status_code=response.status_code,
headers=dict(response.headers)
)
# 处理流式响应或其他类型响应
else:
return response
except Exception as e:
# 出现异常,返回原始响应
return response
def add_response_middleware(app: FastAPI) -> None:
"""添加响应处理中间件到FastAPI应用"""

View File

@ -1,6 +1,8 @@
import asyncio
from app.db.database import Base, engine
from app.models.users import User
from app.models.user_images import UserImage
from app.models.clothing import ClothingCategory, Clothing
# 创建所有表格
async def init_db():

View File

@ -1,19 +1,29 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from contextlib import asynccontextmanager
import logging
from app.core.config import settings
from app.api.v1.api import api_router
from app.core.middleware import add_response_middleware
from app.core.exceptions import add_exception_handlers
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
# 在应用启动时执行
logger.info("应用启动,初始化数据库...")
from app.db.init_db import init_db
await init_db()
logger.info("数据库初始化完成")
yield
# 在应用关闭时执行
# 清理代码可以放在这里
logger.info("应用关闭")
app = FastAPI(
title=settings.PROJECT_NAME,
@ -25,7 +35,7 @@ app = FastAPI(
# 配置CORS
app.add_middleware(
CORSMiddleware,
allow_origins=settings.BACKEND_CORS_ORIGINS,
allow_origins=settings.cors_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],

33
app/models/clothing.py Normal file
View File

@ -0,0 +1,33 @@
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from app.db.database import Base
class ClothingCategory(Base):
"""衣服分类表"""
__tablename__ = "clothing_category"
id = Column(Integer, primary_key=True, autoincrement=True, index=True)
name = Column(String(50), nullable=False, comment="分类名称")
create_time = Column(DateTime, default=func.now(), comment="创建时间")
# 关系
clothes = relationship("Clothing", back_populates="category", cascade="all, delete-orphan")
def __repr__(self):
return f"<ClothingCategory(id={self.id}, name={self.name})>"
class Clothing(Base):
"""衣服表"""
__tablename__ = "clothing"
id = Column(Integer, primary_key=True, autoincrement=True, index=True)
clothing_category_id = Column(Integer, ForeignKey("clothing_category.id"), nullable=False, index=True, comment="分类ID")
image_url = Column(String(500), nullable=False, comment="图片URL")
create_time = Column(DateTime, default=func.now(), comment="创建时间")
# 关系
category = relationship("ClothingCategory", back_populates="clothes")
def __repr__(self):
return f"<Clothing(id={self.id}, category_id={self.clothing_category_id})>"

20
app/models/user_images.py Normal file
View File

@ -0,0 +1,20 @@
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, Boolean
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from app.db.database import Base
class UserImage(Base):
"""用户个人形象库模型"""
__tablename__ = "user_images"
id = Column(Integer, primary_key=True, autoincrement=True, index=True)
user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True, comment="用户ID")
image_url = Column(String(500), nullable=False, comment="图片URL")
is_default = Column(Boolean, default=False, nullable=False, comment="是否为默认形象")
create_time = Column(DateTime, default=func.now(), comment="创建时间")
# 关系
user = relationship("User", back_populates="images")
def __repr__(self):
return f"<UserImage(id={self.id}, user_id={self.user_id})>"

View File

@ -1,5 +1,6 @@
from sqlalchemy import Column, Integer, String, DateTime
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from app.db.database import Base
class User(Base):
@ -12,5 +13,8 @@ class User(Base):
nickname = Column(String(50), nullable=True, comment="昵称")
create_time = Column(DateTime, default=func.now(), comment="创建时间")
# 关系
images = relationship("UserImage", back_populates="user", cascade="all, delete-orphan")
def __repr__(self):
return f"<User(id={self.id}, nickname={self.nickname})>"

63
app/schemas/clothing.py Normal file
View File

@ -0,0 +1,63 @@
from pydantic import BaseModel, Field
from typing import Optional, List
from datetime import datetime
# 分类模式
class ClothingCategoryBase(BaseModel):
"""衣服分类基础模式"""
name: str = Field(..., description="分类名称")
class ClothingCategoryCreate(ClothingCategoryBase):
"""创建衣服分类请求"""
pass
class ClothingCategoryInDB(ClothingCategoryBase):
"""数据库中的衣服分类数据"""
id: int
create_time: datetime
class Config:
from_attributes = True
class ClothingCategory(ClothingCategoryInDB):
"""衣服分类响应模式"""
pass
class ClothingCategoryUpdate(BaseModel):
"""更新衣服分类请求"""
name: Optional[str] = Field(None, description="分类名称")
# 衣服模式
class ClothingBase(BaseModel):
"""衣服基础模式"""
clothing_category_id: int = Field(..., description="分类ID")
image_url: str = Field(..., description="图片URL")
class ClothingCreate(ClothingBase):
"""创建衣服请求"""
pass
class ClothingInDB(ClothingBase):
"""数据库中的衣服数据"""
id: int
create_time: datetime
class Config:
from_attributes = True
class Clothing(ClothingInDB):
"""衣服响应模式"""
pass
class ClothingUpdate(BaseModel):
"""更新衣服请求"""
clothing_category_id: Optional[int] = Field(None, description="分类ID")
image_url: Optional[str] = Field(None, description="图片URL")
# 带有分类信息的衣服响应
class ClothingWithCategory(Clothing):
"""包含分类信息的衣服响应"""
category: ClothingCategory
class Config:
from_attributes = True

30
app/schemas/user_image.py Normal file
View File

@ -0,0 +1,30 @@
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime
class UserImageBase(BaseModel):
"""用户形象基础模式"""
image_url: str = Field(..., description="图片URL")
is_default: bool = Field(False, description="是否为默认形象")
class UserImageCreate(UserImageBase):
"""创建用户形象请求"""
pass
class UserImageInDB(UserImageBase):
"""数据库中的用户形象数据"""
id: int
user_id: int
create_time: datetime
class Config:
from_attributes = True
class UserImage(UserImageInDB):
"""用户形象响应模式"""
pass
class UserImageUpdate(BaseModel):
"""更新用户形象请求"""
image_url: Optional[str] = Field(None, description="图片URL")
is_default: Optional[bool] = Field(None, description="是否为默认形象")

116
app/services/clothing.py Normal file
View File

@ -0,0 +1,116 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import update, delete
from typing import List, Optional
from app.models.clothing import Clothing, ClothingCategory
from app.schemas.clothing import ClothingCreate, ClothingUpdate
from app.schemas.clothing import ClothingCategoryCreate, ClothingCategoryUpdate
# 衣服分类服务函数
async def get_category(db: AsyncSession, category_id: int):
"""获取单个衣服分类"""
result = await db.execute(select(ClothingCategory).filter(ClothingCategory.id == category_id))
return result.scalars().first()
async def get_categories(db: AsyncSession, skip: int = 0, limit: int = 100):
"""获取所有衣服分类"""
result = await db.execute(
select(ClothingCategory)
.offset(skip)
.limit(limit)
)
return result.scalars().all()
async def create_category(db: AsyncSession, category: ClothingCategoryCreate):
"""创建衣服分类"""
db_category = ClothingCategory(
name=category.name
)
db.add(db_category)
await db.commit()
await db.refresh(db_category)
return db_category
async def update_category(db: AsyncSession, category_id: int, category_update: ClothingCategoryUpdate):
"""更新衣服分类"""
db_category = await get_category(db, category_id)
if not db_category:
return None
if category_update.name:
db_category.name = category_update.name
await db.commit()
await db.refresh(db_category)
return db_category
async def delete_category(db: AsyncSession, category_id: int):
"""删除衣服分类(会级联删除相关衣服)"""
db_category = await get_category(db, category_id)
if db_category:
await db.delete(db_category)
await db.commit()
return db_category
# 衣服服务函数
async def get_clothing(db: AsyncSession, clothing_id: int):
"""获取单个衣服"""
result = await db.execute(select(Clothing).filter(Clothing.id == clothing_id))
return result.scalars().first()
async def get_clothes(db: AsyncSession, skip: int = 0, limit: int = 100):
"""获取所有衣服"""
result = await db.execute(
select(Clothing)
.order_by(Clothing.create_time.desc())
.offset(skip)
.limit(limit)
)
return result.scalars().all()
async def get_clothes_by_category(db: AsyncSession, category_id: int, skip: int = 0, limit: int = 100):
"""根据分类获取衣服"""
result = await db.execute(
select(Clothing)
.filter(Clothing.clothing_category_id == category_id)
.order_by(Clothing.create_time.desc())
.offset(skip)
.limit(limit)
)
return result.scalars().all()
async def create_clothing(db: AsyncSession, clothing: ClothingCreate):
"""创建衣服"""
db_clothing = Clothing(
clothing_category_id=clothing.clothing_category_id,
image_url=clothing.image_url
)
db.add(db_clothing)
await db.commit()
await db.refresh(db_clothing)
return db_clothing
async def update_clothing(db: AsyncSession, clothing_id: int, clothing_update: ClothingUpdate):
"""更新衣服"""
db_clothing = await get_clothing(db, clothing_id)
if not db_clothing:
return None
if clothing_update.clothing_category_id is not None:
db_clothing.clothing_category_id = clothing_update.clothing_category_id
if clothing_update.image_url:
db_clothing.image_url = clothing_update.image_url
await db.commit()
await db.refresh(db_clothing)
return db_clothing
async def delete_clothing(db: AsyncSession, clothing_id: int):
"""删除衣服"""
db_clothing = await get_clothing(db, clothing_id)
if db_clothing:
await db.delete(db_clothing)
await db.commit()
return db_clothing

View File

@ -0,0 +1,272 @@
import os
import logging
import dashscope
from dashscope import Generation
# 修改导入语句dashscope的API响应可能改变了结构
from typing import List, Dict, Any, Optional
import asyncio
import httpx
from app.utils.config import get_settings
logger = logging.getLogger(__name__)
class DashScopeService:
"""DashScope服务类提供对DashScope API的调用封装"""
def __init__(self):
settings = get_settings()
self.api_key = settings.dashscope_api_key
# 配置DashScope
dashscope.api_key = self.api_key
# 配置API URL
self.image_synthesis_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2image/image-synthesis"
async def chat_completion(
self,
messages: List[Dict[str, str]],
model: str = "qwen-max",
max_tokens: int = 2048,
temperature: float = 0.7,
stream: bool = False
):
"""
调用DashScope的大模型API进行对话
Args:
messages: 对话历史记录
model: 模型名称
max_tokens: 最大生成token数
temperature: 温度参数控制随机性
stream: 是否流式输出
Returns:
ApiResponse: DashScope的API响应
"""
try:
# 为了不阻塞FastAPI的异步性能我们使用run_in_executor运行同步API
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(
None,
lambda: Generation.call(
model=model,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
result_format='message',
stream=stream,
)
)
if response.status_code != 200:
logger.error(f"DashScope API请求失败状态码{response.status_code}, 错误信息:{response.message}")
raise Exception(f"API调用失败: {response.message}")
return response
except Exception as e:
logger.error(f"DashScope聊天API调用出错: {str(e)}")
raise e
async def generate_image(
self,
prompt: str,
negative_prompt: Optional[str] = None,
model: str = "stable-diffusion-xl",
n: int = 1,
size: str = "1024*1024"
):
"""
调用DashScope的图像生成API
Args:
prompt: 生成图像的文本描述
negative_prompt: 负面提示词
model: 模型名称
n: 生成图像数量
size: 图像尺寸
Returns:
ApiResponse: DashScope的API响应
"""
try:
# 构建请求参数
params = {
"model": model,
"prompt": prompt,
"n": n,
"size": size,
}
if negative_prompt:
params["negative_prompt"] = negative_prompt
# 异步调用图像生成API
loop = asyncio.get_event_loop()
response = await loop.run_in_executor(
None,
lambda: dashscope.ImageSynthesis.call(**params)
)
if response.status_code != 200:
logger.error(f"DashScope 图像生成API请求失败状态码{response.status_code}, 错误信息:{response.message}")
raise Exception(f"图像生成API调用失败: {response.message}")
return response
except Exception as e:
logger.error(f"DashScope图像生成API调用出错: {str(e)}")
raise e
async def generate_tryon(
self,
person_image_url: str,
top_garment_url: Optional[str] = None,
bottom_garment_url: Optional[str] = None,
resolution: int = -1,
restore_face: bool = True
):
"""
调用阿里百炼平台的试衣服务
Args:
person_image_url: 人物图片URL
top_garment_url: 上衣图片URL
bottom_garment_url: 下衣图片URL
resolution: 分辨率-1表示自动
restore_face: 是否修复面部
Returns:
Dict: 包含任务ID和请求ID的响应
"""
try:
# 验证参数
if not person_image_url:
raise ValueError("人物图片URL不能为空")
if not top_garment_url and not bottom_garment_url:
raise ValueError("上衣和下衣图片至少需要提供一个")
# 构建请求数据
request_data = {
"model": "aitryon",
"input": {
"person_image_url": person_image_url
},
"parameters": {
"resolution": resolution,
"restore_face": restore_face
}
}
# 添加可选字段
if top_garment_url:
request_data["input"]["top_garment_url"] = top_garment_url
if bottom_garment_url:
request_data["input"]["bottom_garment_url"] = bottom_garment_url
# 构建请求头
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"X-DashScope-Async": "enable"
}
# 发送请求
async with httpx.AsyncClient() as client:
logger.info(f"发送试穿请求: {request_data}")
response = await client.post(
self.image_synthesis_url,
json=request_data,
headers=headers,
timeout=30.0
)
response_data = response.json()
logger.info(f"试穿API响应: {response_data}")
if response.status_code == 200 or response.status_code == 202: # 202表示异步任务已接受
# 提取任务ID适应不同的API响应格式
task_id = None
if 'output' in response_data and 'task_id' in response_data['output']:
task_id = response_data['output']['task_id']
elif 'task_id' in response_data:
task_id = response_data['task_id']
if task_id:
logger.info(f"试穿请求发送成功任务ID: {task_id}")
return {
"task_id": task_id,
"request_id": response_data.get('request_id'),
"status": "processing"
}
else:
# 如果没有任务ID这可能是同步响应
logger.info("收到同步响应没有任务ID")
return {
"status": "completed",
"result": response_data.get('output', {}),
"request_id": response_data.get('request_id')
}
else:
error_msg = f"试穿请求失败: {response.status_code} - {response.text}"
logger.error(error_msg)
raise Exception(error_msg)
except Exception as e:
logger.error(f"DashScope试穿API调用出错: {str(e)}")
raise e
async def check_tryon_status(self, task_id: str):
"""
检查试穿任务状态
Args:
task_id: 任务ID
Returns:
Dict: 任务状态信息
"""
try:
# 构建请求头
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
# 构建请求URL
status_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}"
# 发送请求
async with httpx.AsyncClient() as client:
response = await client.get(
status_url,
headers=headers,
timeout=30.0
)
if response.status_code == 200:
response_data = response.json()
print(response_data)
status = response_data.get('output', {}).get('task_status', '')
logger.info(f"试穿任务状态查询成功: {status}")
# 检查是否完成并返回结果
if status.lower() == 'succeeded':
image_url = response_data.get('output', {}).get('image_url')
if image_url:
logger.info(f"试穿任务完成结果URL: {image_url}")
return {
"status": status,
"task_id": task_id,
"image_url": image_url,
"result": response_data.get('output', {})
}
return {
"status": status,
"task_id": task_id
}
else:
error_msg = f"试穿任务状态查询失败: {response.status_code} - {response.text}"
logger.error(error_msg)
raise Exception(error_msg)
except Exception as e:
logger.error(f"查询试穿任务状态出错: {str(e)}")
raise e

View File

@ -0,0 +1,91 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy import delete, update
from app.models.user_images import UserImage
from app.schemas.user_image import UserImageCreate, UserImageUpdate
from typing import List, Optional
async def get_user_image(db: AsyncSession, image_id: int):
"""获取单个用户形象"""
result = await db.execute(select(UserImage).filter(UserImage.id == image_id))
return result.scalars().first()
async def get_user_images(db: AsyncSession, user_id: int, skip: int = 0, limit: int = 100):
"""获取用户的所有形象图片"""
result = await db.execute(
select(UserImage)
.filter(UserImage.user_id == user_id)
.order_by(UserImage.create_time.desc())
.offset(skip)
.limit(limit)
)
return result.scalars().all()
async def create_user_image(db: AsyncSession, user_id: int, image: UserImageCreate):
"""创建用户形象"""
# 如果设置为默认形象,先重置用户的所有形象为非默认
if image.is_default:
await reset_default_images(db, user_id)
db_image = UserImage(
user_id=user_id,
image_url=image.image_url,
is_default=image.is_default
)
db.add(db_image)
await db.commit()
await db.refresh(db_image)
return db_image
async def update_user_image(db: AsyncSession, image_id: int, image_update: UserImageUpdate):
"""更新用户形象"""
db_image = await get_user_image(db, image_id)
if not db_image:
return None
# 处理默认形象逻辑
if image_update.is_default is not None and image_update.is_default and not db_image.is_default:
# 如果设置为默认且之前不是默认,重置其他形象
await reset_default_images(db, db_image.user_id)
db_image.is_default = True
elif image_update.is_default is not None:
db_image.is_default = image_update.is_default
# 更新图片URL
if image_update.image_url:
db_image.image_url = image_update.image_url
await db.commit()
await db.refresh(db_image)
return db_image
async def delete_user_image(db: AsyncSession, image_id: int):
"""删除用户形象"""
db_image = await get_user_image(db, image_id)
if db_image:
await db.delete(db_image)
await db.commit()
return db_image
async def delete_user_images(db: AsyncSession, user_id: int):
"""删除用户所有形象图片"""
stmt = delete(UserImage).where(UserImage.user_id == user_id)
await db.execute(stmt)
await db.commit()
return True
async def reset_default_images(db: AsyncSession, user_id: int):
"""重置用户所有形象为非默认"""
stmt = update(UserImage).where(
UserImage.user_id == user_id,
UserImage.is_default == True
).values(is_default=False)
await db.execute(stmt)
async def get_default_image(db: AsyncSession, user_id: int):
"""获取用户的默认形象"""
result = await db.execute(
select(UserImage)
.filter(UserImage.user_id == user_id, UserImage.is_default == True)
)
return result.scalars().first()

View File

@ -8,4 +8,5 @@ aiomysql==0.2.0
greenlet==2.0.2
python-jose[cryptography]==3.3.0
passlib==1.7.4
httpx==0.24.1
httpx==0.24.1
dashscope==1.10.0