update
This commit is contained in:
parent
d5d7f9b89a
commit
16c8c11172
@ -2,8 +2,12 @@ from fastapi import APIRouter
|
|||||||
from app.api.v1.endpoints import router as endpoints_router
|
from app.api.v1.endpoints import router as endpoints_router
|
||||||
from app.api.v1.users import router as users_router
|
from app.api.v1.users import router as users_router
|
||||||
from app.api.v1.auth import router as auth_router
|
from app.api.v1.auth import router as auth_router
|
||||||
|
from app.api.v1.user_images import router as user_images_router
|
||||||
|
from app.api.v1.clothing import router as clothing_router
|
||||||
|
|
||||||
api_router = APIRouter()
|
api_router = APIRouter()
|
||||||
api_router.include_router(endpoints_router, prefix="")
|
api_router.include_router(endpoints_router, prefix="")
|
||||||
api_router.include_router(auth_router, prefix="/auth")
|
|
||||||
api_router.include_router(users_router, prefix="/users")
|
api_router.include_router(users_router, prefix="/users")
|
||||||
|
api_router.include_router(auth_router, prefix="/auth")
|
||||||
|
api_router.include_router(user_images_router, prefix="/user-images")
|
||||||
|
api_router.include_router(clothing_router, prefix="/clothing")
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status, Response
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
import logging
|
||||||
|
|
||||||
from app.schemas.auth import WechatLogin, LoginResponse
|
from app.schemas.auth import WechatLogin, LoginResponse
|
||||||
from app.schemas.user import UserCreate
|
from app.schemas.user import UserCreate
|
||||||
@ -9,6 +10,9 @@ from app.core.security import create_access_token
|
|||||||
from app.db.database import get_db
|
from app.db.database import get_db
|
||||||
from app.core.exceptions import BusinessError
|
from app.core.exceptions import BusinessError
|
||||||
|
|
||||||
|
# 创建日志记录器
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@router.post("/login/wechat", response_model=LoginResponse, tags=["auth"])
|
@router.post("/login/wechat", response_model=LoginResponse, tags=["auth"])
|
||||||
@ -23,12 +27,17 @@ async def wechat_login(
|
|||||||
- 如果用户不存在,则创建新用户
|
- 如果用户不存在,则创建新用户
|
||||||
- 生成JWT令牌
|
- 生成JWT令牌
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(f"尝试微信登录: {login_data.code[:5]}...")
|
||||||
|
|
||||||
# 调用微信API获取openid和unionid
|
# 调用微信API获取openid和unionid
|
||||||
openid, unionid = await wechat_service.code2session(login_data.code)
|
openid, unionid = await wechat_service.code2session(login_data.code)
|
||||||
|
logger.info(f"成功获取openid: {openid[:5]}...")
|
||||||
|
|
||||||
# 检查用户是否存在
|
# 检查用户是否存在
|
||||||
existing_user = await user_service.get_user_by_openid(db, openid=openid)
|
existing_user = await user_service.get_user_by_openid(db, openid=openid)
|
||||||
is_new_user = existing_user is None
|
is_new_user = existing_user is None
|
||||||
|
logger.info(f"用户状态: 新用户={is_new_user}")
|
||||||
|
|
||||||
if is_new_user:
|
if is_new_user:
|
||||||
# 创建新用户
|
# 创建新用户
|
||||||
@ -37,15 +46,29 @@ async def wechat_login(
|
|||||||
unionid=unionid
|
unionid=unionid
|
||||||
)
|
)
|
||||||
user = await user_service.create_user(db, user=user_create)
|
user = await user_service.create_user(db, user=user_create)
|
||||||
|
logger.info(f"创建新用户: id={user.id}")
|
||||||
else:
|
else:
|
||||||
user = existing_user
|
user = existing_user
|
||||||
|
logger.info(f"现有用户登录: id={user.id}")
|
||||||
|
|
||||||
# 创建访问令牌 - 使用openid作为标识
|
# 创建访问令牌 - 使用openid作为标识
|
||||||
access_token = create_access_token(subject=openid)
|
access_token = create_access_token(subject=openid)
|
||||||
|
|
||||||
# 返回登录响应
|
# 创建响应对象
|
||||||
return LoginResponse(
|
response_data = LoginResponse(
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
is_new_user=is_new_user,
|
is_new_user=is_new_user,
|
||||||
openid=openid
|
openid=openid
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 返回标准响应
|
||||||
|
return response_data
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# 记录异常
|
||||||
|
logger.error(f"登录异常: {str(e)}", exc_info=True)
|
||||||
|
|
||||||
|
# 重新抛出业务异常或转换为业务异常
|
||||||
|
if isinstance(e, BusinessError):
|
||||||
|
raise
|
||||||
|
raise BusinessError(f"登录失败: {str(e)}", code=500)
|
||||||
187
app/api/v1/clothing.py
Normal file
187
app/api/v1/clothing.py
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from app.schemas.clothing import (
|
||||||
|
Clothing, ClothingCreate, ClothingUpdate, ClothingWithCategory,
|
||||||
|
ClothingCategory, ClothingCategoryCreate, ClothingCategoryUpdate
|
||||||
|
)
|
||||||
|
from app.services import clothing as clothing_service
|
||||||
|
from app.db.database import get_db
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.models.users import User as UserModel
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
# 衣服分类API
|
||||||
|
@router.post("/categories", response_model=ClothingCategory, tags=["clothing-categories"])
|
||||||
|
async def create_category(
|
||||||
|
category: ClothingCategoryCreate,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: UserModel = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
创建衣服分类
|
||||||
|
|
||||||
|
需要JWT令牌认证
|
||||||
|
"""
|
||||||
|
return await clothing_service.create_category(db=db, category=category)
|
||||||
|
|
||||||
|
@router.get("/categories", response_model=List[ClothingCategory], tags=["clothing-categories"])
|
||||||
|
async def read_categories(
|
||||||
|
skip: int = Query(0, ge=0),
|
||||||
|
limit: int = Query(100, ge=1, le=100),
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""获取所有衣服分类"""
|
||||||
|
return await clothing_service.get_categories(db=db, skip=skip, limit=limit)
|
||||||
|
|
||||||
|
@router.get("/categories/{category_id}", response_model=ClothingCategory, tags=["clothing-categories"])
|
||||||
|
async def read_category(
|
||||||
|
category_id: int,
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""获取单个衣服分类"""
|
||||||
|
db_category = await clothing_service.get_category(db, category_id=category_id)
|
||||||
|
if db_category is None:
|
||||||
|
raise HTTPException(status_code=404, detail="分类不存在")
|
||||||
|
return db_category
|
||||||
|
|
||||||
|
@router.put("/categories/{category_id}", response_model=ClothingCategory, tags=["clothing-categories"])
|
||||||
|
async def update_category(
|
||||||
|
category_id: int,
|
||||||
|
category: ClothingCategoryUpdate,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: UserModel = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
更新衣服分类
|
||||||
|
|
||||||
|
需要JWT令牌认证
|
||||||
|
"""
|
||||||
|
db_category = await clothing_service.update_category(
|
||||||
|
db=db,
|
||||||
|
category_id=category_id,
|
||||||
|
category_update=category
|
||||||
|
)
|
||||||
|
if db_category is None:
|
||||||
|
raise HTTPException(status_code=404, detail="分类不存在")
|
||||||
|
return db_category
|
||||||
|
|
||||||
|
@router.delete("/categories/{category_id}", response_model=ClothingCategory, tags=["clothing-categories"])
|
||||||
|
async def delete_category(
|
||||||
|
category_id: int,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: UserModel = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
删除衣服分类(会级联删除相关衣服)
|
||||||
|
|
||||||
|
需要JWT令牌认证
|
||||||
|
"""
|
||||||
|
db_category = await clothing_service.delete_category(db=db, category_id=category_id)
|
||||||
|
if db_category is None:
|
||||||
|
raise HTTPException(status_code=404, detail="分类不存在")
|
||||||
|
return db_category
|
||||||
|
|
||||||
|
# 衣服API
|
||||||
|
@router.post("/", response_model=Clothing, tags=["clothing"])
|
||||||
|
async def create_clothing(
|
||||||
|
clothing: ClothingCreate,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: UserModel = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
创建衣服
|
||||||
|
|
||||||
|
需要JWT令牌认证
|
||||||
|
"""
|
||||||
|
# 检查分类是否存在
|
||||||
|
category = await clothing_service.get_category(db, category_id=clothing.clothing_category_id)
|
||||||
|
if category is None:
|
||||||
|
raise HTTPException(status_code=404, detail="指定的分类不存在")
|
||||||
|
|
||||||
|
return await clothing_service.create_clothing(db=db, clothing=clothing)
|
||||||
|
|
||||||
|
@router.get("/", response_model=List[Clothing], tags=["clothing"])
|
||||||
|
async def read_clothes(
|
||||||
|
skip: int = Query(0, ge=0),
|
||||||
|
limit: int = Query(100, ge=1, le=100),
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""获取所有衣服"""
|
||||||
|
return await clothing_service.get_clothes(db=db, skip=skip, limit=limit)
|
||||||
|
|
||||||
|
@router.get("/by-category/{category_id}", response_model=List[Clothing], tags=["clothing"])
|
||||||
|
async def read_clothes_by_category(
|
||||||
|
category_id: int,
|
||||||
|
skip: int = Query(0, ge=0),
|
||||||
|
limit: int = Query(100, ge=1, le=100),
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""根据分类获取衣服"""
|
||||||
|
# 检查分类是否存在
|
||||||
|
category = await clothing_service.get_category(db, category_id=category_id)
|
||||||
|
if category is None:
|
||||||
|
raise HTTPException(status_code=404, detail="分类不存在")
|
||||||
|
|
||||||
|
return await clothing_service.get_clothes_by_category(
|
||||||
|
db=db,
|
||||||
|
category_id=category_id,
|
||||||
|
skip=skip,
|
||||||
|
limit=limit
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("/{clothing_id}", response_model=Clothing, tags=["clothing"])
|
||||||
|
async def read_clothing(
|
||||||
|
clothing_id: int,
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""获取单个衣服"""
|
||||||
|
db_clothing = await clothing_service.get_clothing(db, clothing_id=clothing_id)
|
||||||
|
if db_clothing is None:
|
||||||
|
raise HTTPException(status_code=404, detail="衣服不存在")
|
||||||
|
return db_clothing
|
||||||
|
|
||||||
|
@router.put("/{clothing_id}", response_model=Clothing, tags=["clothing"])
|
||||||
|
async def update_clothing(
|
||||||
|
clothing_id: int,
|
||||||
|
clothing: ClothingUpdate,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: UserModel = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
更新衣服
|
||||||
|
|
||||||
|
需要JWT令牌认证
|
||||||
|
"""
|
||||||
|
# 检查分类是否存在
|
||||||
|
if clothing.clothing_category_id is not None:
|
||||||
|
category = await clothing_service.get_category(db, category_id=clothing.clothing_category_id)
|
||||||
|
if category is None:
|
||||||
|
raise HTTPException(status_code=404, detail="指定的分类不存在")
|
||||||
|
|
||||||
|
db_clothing = await clothing_service.update_clothing(
|
||||||
|
db=db,
|
||||||
|
clothing_id=clothing_id,
|
||||||
|
clothing_update=clothing
|
||||||
|
)
|
||||||
|
if db_clothing is None:
|
||||||
|
raise HTTPException(status_code=404, detail="衣服不存在")
|
||||||
|
return db_clothing
|
||||||
|
|
||||||
|
@router.delete("/{clothing_id}", response_model=Clothing, tags=["clothing"])
|
||||||
|
async def delete_clothing(
|
||||||
|
clothing_id: int,
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
current_user: UserModel = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
删除衣服
|
||||||
|
|
||||||
|
需要JWT令牌认证
|
||||||
|
"""
|
||||||
|
db_clothing = await clothing_service.delete_clothing(db=db, clothing_id=clothing_id)
|
||||||
|
if db_clothing is None:
|
||||||
|
raise HTTPException(status_code=404, detail="衣服不存在")
|
||||||
|
return db_clothing
|
||||||
114
app/api/v1/user_images.py
Normal file
114
app/api/v1/user_images.py
Normal file
@ -0,0 +1,114 @@
|
|||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from app.schemas.user_image import UserImage, UserImageCreate, UserImageUpdate
|
||||||
|
from app.services import user_image as user_image_service
|
||||||
|
from app.db.database import get_db
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.models.users import User as UserModel
|
||||||
|
from app.core.exceptions import BusinessError
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.post("/", response_model=UserImage, tags=["user-images"])
|
||||||
|
async def create_image(
|
||||||
|
image: UserImageCreate,
|
||||||
|
current_user: UserModel = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
创建用户形象图片
|
||||||
|
|
||||||
|
需要JWT令牌认证
|
||||||
|
"""
|
||||||
|
return await user_image_service.create_user_image(
|
||||||
|
db=db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
image=image
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("/", response_model=List[UserImage], tags=["user-images"])
|
||||||
|
async def read_images(
|
||||||
|
skip: int = Query(0, ge=0),
|
||||||
|
limit: int = Query(100, ge=1, le=100),
|
||||||
|
current_user: UserModel = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取当前用户的所有形象图片
|
||||||
|
|
||||||
|
需要JWT令牌认证
|
||||||
|
"""
|
||||||
|
return await user_image_service.get_user_images(
|
||||||
|
db=db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
skip=skip,
|
||||||
|
limit=limit
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.get("/{image_id}", response_model=UserImage, tags=["user-images"])
|
||||||
|
async def read_image(
|
||||||
|
image_id: int,
|
||||||
|
current_user: UserModel = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取指定形象图片
|
||||||
|
|
||||||
|
需要JWT令牌认证
|
||||||
|
"""
|
||||||
|
image = await user_image_service.get_user_image(db, image_id=image_id)
|
||||||
|
if image is None:
|
||||||
|
raise HTTPException(status_code=404, detail="图片不存在")
|
||||||
|
if image.user_id != current_user.id:
|
||||||
|
raise HTTPException(status_code=403, detail="没有权限访问此图片")
|
||||||
|
return image
|
||||||
|
|
||||||
|
@router.put("/{image_id}", response_model=UserImage, tags=["user-images"])
|
||||||
|
async def update_image(
|
||||||
|
image_id: int,
|
||||||
|
image_update: UserImageUpdate,
|
||||||
|
current_user: UserModel = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
更新形象图片
|
||||||
|
|
||||||
|
需要JWT令牌认证
|
||||||
|
"""
|
||||||
|
# 先检查图片是否存在
|
||||||
|
db_image = await user_image_service.get_user_image(db, image_id=image_id)
|
||||||
|
if db_image is None:
|
||||||
|
raise HTTPException(status_code=404, detail="图片不存在")
|
||||||
|
if db_image.user_id != current_user.id:
|
||||||
|
raise HTTPException(status_code=403, detail="没有权限更新此图片")
|
||||||
|
|
||||||
|
# 执行更新
|
||||||
|
updated_image = await user_image_service.update_user_image(
|
||||||
|
db=db,
|
||||||
|
image_id=image_id,
|
||||||
|
image_update=image_update
|
||||||
|
)
|
||||||
|
return updated_image
|
||||||
|
|
||||||
|
@router.delete("/{image_id}", response_model=UserImage, tags=["user-images"])
|
||||||
|
async def delete_image(
|
||||||
|
image_id: int,
|
||||||
|
current_user: UserModel = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
删除形象图片
|
||||||
|
|
||||||
|
需要JWT令牌认证
|
||||||
|
"""
|
||||||
|
# 先检查图片是否存在
|
||||||
|
db_image = await user_image_service.get_user_image(db, image_id=image_id)
|
||||||
|
if db_image is None:
|
||||||
|
raise HTTPException(status_code=404, detail="图片不存在")
|
||||||
|
if db_image.user_id != current_user.id:
|
||||||
|
raise HTTPException(status_code=403, detail="没有权限删除此图片")
|
||||||
|
|
||||||
|
# 执行删除
|
||||||
|
return await user_image_service.delete_user_image(db=db, image_id=image_id)
|
||||||
@ -23,7 +23,8 @@ class Settings(BaseSettings):
|
|||||||
DB_ECHO: bool = os.getenv("DB_ECHO", "False").lower() == "true"
|
DB_ECHO: bool = os.getenv("DB_ECHO", "False").lower() == "true"
|
||||||
|
|
||||||
# DashScope API密钥
|
# DashScope API密钥
|
||||||
DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "sk-caa199589f1c451aaac471fad2986e28")
|
DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "")
|
||||||
|
DASHSCOPE_MODEL: str = os.getenv("DASHSCOPE_MODEL", "qwen-max")
|
||||||
|
|
||||||
# 安全设置
|
# 安全设置
|
||||||
SECRET_KEY: str = os.getenv("SECRET_KEY", "your-secret-key-for-jwt-please-change-in-production")
|
SECRET_KEY: str = os.getenv("SECRET_KEY", "your-secret-key-for-jwt-please-change-in-production")
|
||||||
@ -34,6 +35,15 @@ class Settings(BaseSettings):
|
|||||||
WECHAT_APP_ID: str = os.getenv("WECHAT_APP_ID", "")
|
WECHAT_APP_ID: str = os.getenv("WECHAT_APP_ID", "")
|
||||||
WECHAT_APP_SECRET: str = os.getenv("WECHAT_APP_SECRET", "")
|
WECHAT_APP_SECRET: str = os.getenv("WECHAT_APP_SECRET", "")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cors_origins(self):
|
||||||
|
"""获取CORS来源列表"""
|
||||||
|
if isinstance(self.BACKEND_CORS_ORIGINS, str) and self.BACKEND_CORS_ORIGINS == "*":
|
||||||
|
return ["*"]
|
||||||
|
elif isinstance(self.BACKEND_CORS_ORIGINS, list):
|
||||||
|
return self.BACKEND_CORS_ORIGINS
|
||||||
|
return []
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file = ".env"
|
env_file = ".env"
|
||||||
case_sensitive = True
|
case_sensitive = True
|
||||||
|
|||||||
@ -20,47 +20,40 @@ class ResponseMiddleware(BaseHTTPMiddleware):
|
|||||||
if any(request.url.path.startswith(path) for path in exclude_paths):
|
if any(request.url.path.startswith(path) for path in exclude_paths):
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
|
# 获取原始响应
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
|
|
||||||
# 如果是HTTPException,直接返回,不进行封装
|
# 如果不是200系列响应或不是JSON响应,直接返回
|
||||||
if response.status_code >= 400:
|
if response.status_code >= 300 or response.headers.get("content-type") != "application/json":
|
||||||
return response
|
return response
|
||||||
|
|
||||||
# 正常响应进行封装
|
# 读取响应内容
|
||||||
if response.status_code < 300 and response.headers.get("content-type") == "application/json":
|
|
||||||
response_body = [chunk async for chunk in response.body_iterator]
|
|
||||||
response_body = b"".join(response_body)
|
|
||||||
|
|
||||||
if response_body:
|
|
||||||
try:
|
try:
|
||||||
data = json.loads(response_body)
|
# 使用JSONResponse的_render方法获取内容
|
||||||
|
if isinstance(response, JSONResponse):
|
||||||
|
# 获取原始数据
|
||||||
|
raw_data = response.body.decode("utf-8")
|
||||||
|
data = json.loads(raw_data)
|
||||||
|
|
||||||
# 已经是统一格式的响应,不再封装
|
# 已经是标准格式,不再封装
|
||||||
if isinstance(data, dict) and "code" in data and ("data" in data or "message" in data):
|
if isinstance(data, dict) and "code" in data and ("data" in data or "message" in data):
|
||||||
return Response(
|
return response
|
||||||
content=response_body,
|
|
||||||
status_code=response.status_code,
|
|
||||||
headers=dict(response.headers),
|
|
||||||
media_type=response.media_type
|
|
||||||
)
|
|
||||||
|
|
||||||
# 封装为标准响应格式
|
# 创建新的标准响应
|
||||||
result = StandardResponse(code=200, data=data).model_dump()
|
std_response = StandardResponse(code=200, data=data)
|
||||||
|
|
||||||
|
# 创建新的JSONResponse
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content=result,
|
content=std_response.model_dump(),
|
||||||
status_code=response.status_code,
|
status_code=response.status_code,
|
||||||
headers=dict(response.headers)
|
headers=dict(response.headers)
|
||||||
)
|
)
|
||||||
|
# 处理流式响应或其他类型响应
|
||||||
|
else:
|
||||||
|
return response
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
except Exception as e:
|
||||||
# 非JSON响应,直接返回
|
# 出现异常,返回原始响应
|
||||||
return Response(
|
|
||||||
content=response_body,
|
|
||||||
status_code=response.status_code,
|
|
||||||
headers=dict(response.headers),
|
|
||||||
media_type=response.media_type
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def add_response_middleware(app: FastAPI) -> None:
|
def add_response_middleware(app: FastAPI) -> None:
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from app.db.database import Base, engine
|
from app.db.database import Base, engine
|
||||||
from app.models.users import User
|
from app.models.users import User
|
||||||
|
from app.models.user_images import UserImage
|
||||||
|
from app.models.clothing import ClothingCategory, Clothing
|
||||||
|
|
||||||
# 创建所有表格
|
# 创建所有表格
|
||||||
async def init_db():
|
async def init_db():
|
||||||
|
|||||||
14
app/main.py
14
app/main.py
@ -1,19 +1,29 @@
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
import logging
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.api.v1.api import api_router
|
from app.api.v1.api import api_router
|
||||||
from app.core.middleware import add_response_middleware
|
from app.core.middleware import add_response_middleware
|
||||||
from app.core.exceptions import add_exception_handlers
|
from app.core.exceptions import add_exception_handlers
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# 在应用启动时执行
|
# 在应用启动时执行
|
||||||
|
logger.info("应用启动,初始化数据库...")
|
||||||
from app.db.init_db import init_db
|
from app.db.init_db import init_db
|
||||||
await init_db()
|
await init_db()
|
||||||
|
logger.info("数据库初始化完成")
|
||||||
yield
|
yield
|
||||||
# 在应用关闭时执行
|
# 在应用关闭时执行
|
||||||
# 清理代码可以放在这里
|
logger.info("应用关闭")
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title=settings.PROJECT_NAME,
|
title=settings.PROJECT_NAME,
|
||||||
@ -25,7 +35,7 @@ app = FastAPI(
|
|||||||
# 配置CORS
|
# 配置CORS
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=settings.BACKEND_CORS_ORIGINS,
|
allow_origins=settings.cors_origins,
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
|
|||||||
33
app/models/clothing.py
Normal file
33
app/models/clothing.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
||||||
|
from sqlalchemy.sql import func
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
from app.db.database import Base
|
||||||
|
|
||||||
|
class ClothingCategory(Base):
|
||||||
|
"""衣服分类表"""
|
||||||
|
__tablename__ = "clothing_category"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True, index=True)
|
||||||
|
name = Column(String(50), nullable=False, comment="分类名称")
|
||||||
|
create_time = Column(DateTime, default=func.now(), comment="创建时间")
|
||||||
|
|
||||||
|
# 关系
|
||||||
|
clothes = relationship("Clothing", back_populates="category", cascade="all, delete-orphan")
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<ClothingCategory(id={self.id}, name={self.name})>"
|
||||||
|
|
||||||
|
class Clothing(Base):
|
||||||
|
"""衣服表"""
|
||||||
|
__tablename__ = "clothing"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True, index=True)
|
||||||
|
clothing_category_id = Column(Integer, ForeignKey("clothing_category.id"), nullable=False, index=True, comment="分类ID")
|
||||||
|
image_url = Column(String(500), nullable=False, comment="图片URL")
|
||||||
|
create_time = Column(DateTime, default=func.now(), comment="创建时间")
|
||||||
|
|
||||||
|
# 关系
|
||||||
|
category = relationship("ClothingCategory", back_populates="clothes")
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<Clothing(id={self.id}, category_id={self.clothing_category_id})>"
|
||||||
20
app/models/user_images.py
Normal file
20
app/models/user_images.py
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, Boolean
|
||||||
|
from sqlalchemy.sql import func
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
|
from app.db.database import Base
|
||||||
|
|
||||||
|
class UserImage(Base):
|
||||||
|
"""用户个人形象库模型"""
|
||||||
|
__tablename__ = "user_images"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True, index=True)
|
||||||
|
user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True, comment="用户ID")
|
||||||
|
image_url = Column(String(500), nullable=False, comment="图片URL")
|
||||||
|
is_default = Column(Boolean, default=False, nullable=False, comment="是否为默认形象")
|
||||||
|
create_time = Column(DateTime, default=func.now(), comment="创建时间")
|
||||||
|
|
||||||
|
# 关系
|
||||||
|
user = relationship("User", back_populates="images")
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"<UserImage(id={self.id}, user_id={self.user_id})>"
|
||||||
@ -1,5 +1,6 @@
|
|||||||
from sqlalchemy import Column, Integer, String, DateTime
|
from sqlalchemy import Column, Integer, String, DateTime
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
|
from sqlalchemy.orm import relationship
|
||||||
from app.db.database import Base
|
from app.db.database import Base
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
@ -12,5 +13,8 @@ class User(Base):
|
|||||||
nickname = Column(String(50), nullable=True, comment="昵称")
|
nickname = Column(String(50), nullable=True, comment="昵称")
|
||||||
create_time = Column(DateTime, default=func.now(), comment="创建时间")
|
create_time = Column(DateTime, default=func.now(), comment="创建时间")
|
||||||
|
|
||||||
|
# 关系
|
||||||
|
images = relationship("UserImage", back_populates="user", cascade="all, delete-orphan")
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<User(id={self.id}, nickname={self.nickname})>"
|
return f"<User(id={self.id}, nickname={self.nickname})>"
|
||||||
63
app/schemas/clothing.py
Normal file
63
app/schemas/clothing.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Optional, List
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# 分类模式
|
||||||
|
class ClothingCategoryBase(BaseModel):
|
||||||
|
"""衣服分类基础模式"""
|
||||||
|
name: str = Field(..., description="分类名称")
|
||||||
|
|
||||||
|
class ClothingCategoryCreate(ClothingCategoryBase):
|
||||||
|
"""创建衣服分类请求"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ClothingCategoryInDB(ClothingCategoryBase):
|
||||||
|
"""数据库中的衣服分类数据"""
|
||||||
|
id: int
|
||||||
|
create_time: datetime
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
class ClothingCategory(ClothingCategoryInDB):
|
||||||
|
"""衣服分类响应模式"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ClothingCategoryUpdate(BaseModel):
|
||||||
|
"""更新衣服分类请求"""
|
||||||
|
name: Optional[str] = Field(None, description="分类名称")
|
||||||
|
|
||||||
|
# 衣服模式
|
||||||
|
class ClothingBase(BaseModel):
|
||||||
|
"""衣服基础模式"""
|
||||||
|
clothing_category_id: int = Field(..., description="分类ID")
|
||||||
|
image_url: str = Field(..., description="图片URL")
|
||||||
|
|
||||||
|
class ClothingCreate(ClothingBase):
|
||||||
|
"""创建衣服请求"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ClothingInDB(ClothingBase):
|
||||||
|
"""数据库中的衣服数据"""
|
||||||
|
id: int
|
||||||
|
create_time: datetime
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
class Clothing(ClothingInDB):
|
||||||
|
"""衣服响应模式"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ClothingUpdate(BaseModel):
|
||||||
|
"""更新衣服请求"""
|
||||||
|
clothing_category_id: Optional[int] = Field(None, description="分类ID")
|
||||||
|
image_url: Optional[str] = Field(None, description="图片URL")
|
||||||
|
|
||||||
|
# 带有分类信息的衣服响应
|
||||||
|
class ClothingWithCategory(Clothing):
|
||||||
|
"""包含分类信息的衣服响应"""
|
||||||
|
category: ClothingCategory
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
30
app/schemas/user_image.py
Normal file
30
app/schemas/user_image.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Optional
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
class UserImageBase(BaseModel):
|
||||||
|
"""用户形象基础模式"""
|
||||||
|
image_url: str = Field(..., description="图片URL")
|
||||||
|
is_default: bool = Field(False, description="是否为默认形象")
|
||||||
|
|
||||||
|
class UserImageCreate(UserImageBase):
|
||||||
|
"""创建用户形象请求"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class UserImageInDB(UserImageBase):
|
||||||
|
"""数据库中的用户形象数据"""
|
||||||
|
id: int
|
||||||
|
user_id: int
|
||||||
|
create_time: datetime
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
class UserImage(UserImageInDB):
|
||||||
|
"""用户形象响应模式"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class UserImageUpdate(BaseModel):
|
||||||
|
"""更新用户形象请求"""
|
||||||
|
image_url: Optional[str] = Field(None, description="图片URL")
|
||||||
|
is_default: Optional[bool] = Field(None, description="是否为默认形象")
|
||||||
116
app/services/clothing.py
Normal file
116
app/services/clothing.py
Normal file
@ -0,0 +1,116 @@
|
|||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
from sqlalchemy import update, delete
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from app.models.clothing import Clothing, ClothingCategory
|
||||||
|
from app.schemas.clothing import ClothingCreate, ClothingUpdate
|
||||||
|
from app.schemas.clothing import ClothingCategoryCreate, ClothingCategoryUpdate
|
||||||
|
|
||||||
|
# 衣服分类服务函数
|
||||||
|
async def get_category(db: AsyncSession, category_id: int):
|
||||||
|
"""获取单个衣服分类"""
|
||||||
|
result = await db.execute(select(ClothingCategory).filter(ClothingCategory.id == category_id))
|
||||||
|
return result.scalars().first()
|
||||||
|
|
||||||
|
async def get_categories(db: AsyncSession, skip: int = 0, limit: int = 100):
|
||||||
|
"""获取所有衣服分类"""
|
||||||
|
result = await db.execute(
|
||||||
|
select(ClothingCategory)
|
||||||
|
.offset(skip)
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def create_category(db: AsyncSession, category: ClothingCategoryCreate):
|
||||||
|
"""创建衣服分类"""
|
||||||
|
db_category = ClothingCategory(
|
||||||
|
name=category.name
|
||||||
|
)
|
||||||
|
db.add(db_category)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_category)
|
||||||
|
return db_category
|
||||||
|
|
||||||
|
async def update_category(db: AsyncSession, category_id: int, category_update: ClothingCategoryUpdate):
|
||||||
|
"""更新衣服分类"""
|
||||||
|
db_category = await get_category(db, category_id)
|
||||||
|
if not db_category:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if category_update.name:
|
||||||
|
db_category.name = category_update.name
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_category)
|
||||||
|
return db_category
|
||||||
|
|
||||||
|
async def delete_category(db: AsyncSession, category_id: int):
|
||||||
|
"""删除衣服分类(会级联删除相关衣服)"""
|
||||||
|
db_category = await get_category(db, category_id)
|
||||||
|
if db_category:
|
||||||
|
await db.delete(db_category)
|
||||||
|
await db.commit()
|
||||||
|
return db_category
|
||||||
|
|
||||||
|
# 衣服服务函数
|
||||||
|
async def get_clothing(db: AsyncSession, clothing_id: int):
|
||||||
|
"""获取单个衣服"""
|
||||||
|
result = await db.execute(select(Clothing).filter(Clothing.id == clothing_id))
|
||||||
|
return result.scalars().first()
|
||||||
|
|
||||||
|
async def get_clothes(db: AsyncSession, skip: int = 0, limit: int = 100):
|
||||||
|
"""获取所有衣服"""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Clothing)
|
||||||
|
.order_by(Clothing.create_time.desc())
|
||||||
|
.offset(skip)
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def get_clothes_by_category(db: AsyncSession, category_id: int, skip: int = 0, limit: int = 100):
|
||||||
|
"""根据分类获取衣服"""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Clothing)
|
||||||
|
.filter(Clothing.clothing_category_id == category_id)
|
||||||
|
.order_by(Clothing.create_time.desc())
|
||||||
|
.offset(skip)
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def create_clothing(db: AsyncSession, clothing: ClothingCreate):
|
||||||
|
"""创建衣服"""
|
||||||
|
db_clothing = Clothing(
|
||||||
|
clothing_category_id=clothing.clothing_category_id,
|
||||||
|
image_url=clothing.image_url
|
||||||
|
)
|
||||||
|
db.add(db_clothing)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_clothing)
|
||||||
|
return db_clothing
|
||||||
|
|
||||||
|
async def update_clothing(db: AsyncSession, clothing_id: int, clothing_update: ClothingUpdate):
|
||||||
|
"""更新衣服"""
|
||||||
|
db_clothing = await get_clothing(db, clothing_id)
|
||||||
|
if not db_clothing:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if clothing_update.clothing_category_id is not None:
|
||||||
|
db_clothing.clothing_category_id = clothing_update.clothing_category_id
|
||||||
|
|
||||||
|
if clothing_update.image_url:
|
||||||
|
db_clothing.image_url = clothing_update.image_url
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_clothing)
|
||||||
|
return db_clothing
|
||||||
|
|
||||||
|
async def delete_clothing(db: AsyncSession, clothing_id: int):
|
||||||
|
"""删除衣服"""
|
||||||
|
db_clothing = await get_clothing(db, clothing_id)
|
||||||
|
if db_clothing:
|
||||||
|
await db.delete(db_clothing)
|
||||||
|
await db.commit()
|
||||||
|
return db_clothing
|
||||||
272
app/services/dashscope_service.py
Normal file
272
app/services/dashscope_service.py
Normal file
@ -0,0 +1,272 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import dashscope
|
||||||
|
from dashscope import Generation
|
||||||
|
# 修改导入语句,dashscope的API响应可能改变了结构
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
import asyncio
|
||||||
|
import httpx
|
||||||
|
from app.utils.config import get_settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class DashScopeService:
|
||||||
|
"""DashScope服务类,提供对DashScope API的调用封装"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
settings = get_settings()
|
||||||
|
self.api_key = settings.dashscope_api_key
|
||||||
|
# 配置DashScope
|
||||||
|
dashscope.api_key = self.api_key
|
||||||
|
# 配置API URL
|
||||||
|
self.image_synthesis_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2image/image-synthesis"
|
||||||
|
|
||||||
|
async def chat_completion(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
model: str = "qwen-max",
|
||||||
|
max_tokens: int = 2048,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
stream: bool = False
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
调用DashScope的大模型API进行对话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: 对话历史记录
|
||||||
|
model: 模型名称
|
||||||
|
max_tokens: 最大生成token数
|
||||||
|
temperature: 温度参数,控制随机性
|
||||||
|
stream: 是否流式输出
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse: DashScope的API响应
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 为了不阻塞FastAPI的异步性能,我们使用run_in_executor运行同步API
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
response = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: Generation.call(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
result_format='message',
|
||||||
|
stream=stream,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"DashScope API请求失败,状态码:{response.status_code}, 错误信息:{response.message}")
|
||||||
|
raise Exception(f"API调用失败: {response.message}")
|
||||||
|
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"DashScope聊天API调用出错: {str(e)}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def generate_image(
|
||||||
|
self,
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: Optional[str] = None,
|
||||||
|
model: str = "stable-diffusion-xl",
|
||||||
|
n: int = 1,
|
||||||
|
size: str = "1024*1024"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
调用DashScope的图像生成API
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: 生成图像的文本描述
|
||||||
|
negative_prompt: 负面提示词
|
||||||
|
model: 模型名称
|
||||||
|
n: 生成图像数量
|
||||||
|
size: 图像尺寸
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse: DashScope的API响应
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 构建请求参数
|
||||||
|
params = {
|
||||||
|
"model": model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"n": n,
|
||||||
|
"size": size,
|
||||||
|
}
|
||||||
|
|
||||||
|
if negative_prompt:
|
||||||
|
params["negative_prompt"] = negative_prompt
|
||||||
|
|
||||||
|
# 异步调用图像生成API
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
response = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: dashscope.ImageSynthesis.call(**params)
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"DashScope 图像生成API请求失败,状态码:{response.status_code}, 错误信息:{response.message}")
|
||||||
|
raise Exception(f"图像生成API调用失败: {response.message}")
|
||||||
|
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"DashScope图像生成API调用出错: {str(e)}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def generate_tryon(
|
||||||
|
self,
|
||||||
|
person_image_url: str,
|
||||||
|
top_garment_url: Optional[str] = None,
|
||||||
|
bottom_garment_url: Optional[str] = None,
|
||||||
|
resolution: int = -1,
|
||||||
|
restore_face: bool = True
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
调用阿里百炼平台的试衣服务
|
||||||
|
|
||||||
|
Args:
|
||||||
|
person_image_url: 人物图片URL
|
||||||
|
top_garment_url: 上衣图片URL
|
||||||
|
bottom_garment_url: 下衣图片URL
|
||||||
|
resolution: 分辨率,-1表示自动
|
||||||
|
restore_face: 是否修复面部
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: 包含任务ID和请求ID的响应
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 验证参数
|
||||||
|
if not person_image_url:
|
||||||
|
raise ValueError("人物图片URL不能为空")
|
||||||
|
|
||||||
|
if not top_garment_url and not bottom_garment_url:
|
||||||
|
raise ValueError("上衣和下衣图片至少需要提供一个")
|
||||||
|
|
||||||
|
# 构建请求数据
|
||||||
|
request_data = {
|
||||||
|
"model": "aitryon",
|
||||||
|
"input": {
|
||||||
|
"person_image_url": person_image_url
|
||||||
|
},
|
||||||
|
"parameters": {
|
||||||
|
"resolution": resolution,
|
||||||
|
"restore_face": restore_face
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# 添加可选字段
|
||||||
|
if top_garment_url:
|
||||||
|
request_data["input"]["top_garment_url"] = top_garment_url
|
||||||
|
if bottom_garment_url:
|
||||||
|
request_data["input"]["bottom_garment_url"] = bottom_garment_url
|
||||||
|
|
||||||
|
# 构建请求头
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"X-DashScope-Async": "enable"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 发送请求
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
logger.info(f"发送试穿请求: {request_data}")
|
||||||
|
response = await client.post(
|
||||||
|
self.image_synthesis_url,
|
||||||
|
json=request_data,
|
||||||
|
headers=headers,
|
||||||
|
timeout=30.0
|
||||||
|
)
|
||||||
|
|
||||||
|
response_data = response.json()
|
||||||
|
logger.info(f"试穿API响应: {response_data}")
|
||||||
|
|
||||||
|
if response.status_code == 200 or response.status_code == 202: # 202表示异步任务已接受
|
||||||
|
# 提取任务ID,适应不同的API响应格式
|
||||||
|
task_id = None
|
||||||
|
if 'output' in response_data and 'task_id' in response_data['output']:
|
||||||
|
task_id = response_data['output']['task_id']
|
||||||
|
elif 'task_id' in response_data:
|
||||||
|
task_id = response_data['task_id']
|
||||||
|
|
||||||
|
if task_id:
|
||||||
|
logger.info(f"试穿请求发送成功,任务ID: {task_id}")
|
||||||
|
return {
|
||||||
|
"task_id": task_id,
|
||||||
|
"request_id": response_data.get('request_id'),
|
||||||
|
"status": "processing"
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# 如果没有任务ID,这可能是同步响应
|
||||||
|
logger.info("收到同步响应,没有任务ID")
|
||||||
|
return {
|
||||||
|
"status": "completed",
|
||||||
|
"result": response_data.get('output', {}),
|
||||||
|
"request_id": response_data.get('request_id')
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
error_msg = f"试穿请求失败: {response.status_code} - {response.text}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
raise Exception(error_msg)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"DashScope试穿API调用出错: {str(e)}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
async def check_tryon_status(self, task_id: str):
|
||||||
|
"""
|
||||||
|
检查试穿任务状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: 任务ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict: 任务状态信息
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 构建请求头
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建请求URL
|
||||||
|
status_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}"
|
||||||
|
|
||||||
|
# 发送请求
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(
|
||||||
|
status_url,
|
||||||
|
headers=headers,
|
||||||
|
timeout=30.0
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
response_data = response.json()
|
||||||
|
print(response_data)
|
||||||
|
status = response_data.get('output', {}).get('task_status', '')
|
||||||
|
logger.info(f"试穿任务状态查询成功: {status}")
|
||||||
|
|
||||||
|
# 检查是否完成并返回结果
|
||||||
|
if status.lower() == 'succeeded':
|
||||||
|
image_url = response_data.get('output', {}).get('image_url')
|
||||||
|
if image_url:
|
||||||
|
logger.info(f"试穿任务完成,结果URL: {image_url}")
|
||||||
|
return {
|
||||||
|
"status": status,
|
||||||
|
"task_id": task_id,
|
||||||
|
"image_url": image_url,
|
||||||
|
"result": response_data.get('output', {})
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": status,
|
||||||
|
"task_id": task_id
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
error_msg = f"试穿任务状态查询失败: {response.status_code} - {response.text}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
raise Exception(error_msg)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"查询试穿任务状态出错: {str(e)}")
|
||||||
|
raise e
|
||||||
91
app/services/user_image.py
Normal file
91
app/services/user_image.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.future import select
|
||||||
|
from sqlalchemy import delete, update
|
||||||
|
from app.models.user_images import UserImage
|
||||||
|
from app.schemas.user_image import UserImageCreate, UserImageUpdate
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
async def get_user_image(db: AsyncSession, image_id: int):
|
||||||
|
"""获取单个用户形象"""
|
||||||
|
result = await db.execute(select(UserImage).filter(UserImage.id == image_id))
|
||||||
|
return result.scalars().first()
|
||||||
|
|
||||||
|
async def get_user_images(db: AsyncSession, user_id: int, skip: int = 0, limit: int = 100):
|
||||||
|
"""获取用户的所有形象图片"""
|
||||||
|
result = await db.execute(
|
||||||
|
select(UserImage)
|
||||||
|
.filter(UserImage.user_id == user_id)
|
||||||
|
.order_by(UserImage.create_time.desc())
|
||||||
|
.offset(skip)
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def create_user_image(db: AsyncSession, user_id: int, image: UserImageCreate):
|
||||||
|
"""创建用户形象"""
|
||||||
|
# 如果设置为默认形象,先重置用户的所有形象为非默认
|
||||||
|
if image.is_default:
|
||||||
|
await reset_default_images(db, user_id)
|
||||||
|
|
||||||
|
db_image = UserImage(
|
||||||
|
user_id=user_id,
|
||||||
|
image_url=image.image_url,
|
||||||
|
is_default=image.is_default
|
||||||
|
)
|
||||||
|
db.add(db_image)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_image)
|
||||||
|
return db_image
|
||||||
|
|
||||||
|
async def update_user_image(db: AsyncSession, image_id: int, image_update: UserImageUpdate):
|
||||||
|
"""更新用户形象"""
|
||||||
|
db_image = await get_user_image(db, image_id)
|
||||||
|
if not db_image:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 处理默认形象逻辑
|
||||||
|
if image_update.is_default is not None and image_update.is_default and not db_image.is_default:
|
||||||
|
# 如果设置为默认且之前不是默认,重置其他形象
|
||||||
|
await reset_default_images(db, db_image.user_id)
|
||||||
|
db_image.is_default = True
|
||||||
|
elif image_update.is_default is not None:
|
||||||
|
db_image.is_default = image_update.is_default
|
||||||
|
|
||||||
|
# 更新图片URL
|
||||||
|
if image_update.image_url:
|
||||||
|
db_image.image_url = image_update.image_url
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_image)
|
||||||
|
return db_image
|
||||||
|
|
||||||
|
async def delete_user_image(db: AsyncSession, image_id: int):
|
||||||
|
"""删除用户形象"""
|
||||||
|
db_image = await get_user_image(db, image_id)
|
||||||
|
if db_image:
|
||||||
|
await db.delete(db_image)
|
||||||
|
await db.commit()
|
||||||
|
return db_image
|
||||||
|
|
||||||
|
async def delete_user_images(db: AsyncSession, user_id: int):
|
||||||
|
"""删除用户所有形象图片"""
|
||||||
|
stmt = delete(UserImage).where(UserImage.user_id == user_id)
|
||||||
|
await db.execute(stmt)
|
||||||
|
await db.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def reset_default_images(db: AsyncSession, user_id: int):
|
||||||
|
"""重置用户所有形象为非默认"""
|
||||||
|
stmt = update(UserImage).where(
|
||||||
|
UserImage.user_id == user_id,
|
||||||
|
UserImage.is_default == True
|
||||||
|
).values(is_default=False)
|
||||||
|
await db.execute(stmt)
|
||||||
|
|
||||||
|
async def get_default_image(db: AsyncSession, user_id: int):
|
||||||
|
"""获取用户的默认形象"""
|
||||||
|
result = await db.execute(
|
||||||
|
select(UserImage)
|
||||||
|
.filter(UserImage.user_id == user_id, UserImage.is_default == True)
|
||||||
|
)
|
||||||
|
return result.scalars().first()
|
||||||
@ -9,3 +9,4 @@ greenlet==2.0.2
|
|||||||
python-jose[cryptography]==3.3.0
|
python-jose[cryptography]==3.3.0
|
||||||
passlib==1.7.4
|
passlib==1.7.4
|
||||||
httpx==0.24.1
|
httpx==0.24.1
|
||||||
|
dashscope==1.10.0
|
||||||
Loading…
Reference in New Issue
Block a user