update
This commit is contained in:
parent
d5d7f9b89a
commit
16c8c11172
@ -2,8 +2,12 @@ from fastapi import APIRouter
|
||||
from app.api.v1.endpoints import router as endpoints_router
|
||||
from app.api.v1.users import router as users_router
|
||||
from app.api.v1.auth import router as auth_router
|
||||
from app.api.v1.user_images import router as user_images_router
|
||||
from app.api.v1.clothing import router as clothing_router
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(endpoints_router, prefix="")
|
||||
api_router.include_router(auth_router, prefix="/auth")
|
||||
api_router.include_router(users_router, prefix="/users")
|
||||
api_router.include_router(auth_router, prefix="/auth")
|
||||
api_router.include_router(user_images_router, prefix="/user-images")
|
||||
api_router.include_router(clothing_router, prefix="/clothing")
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Response
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
import logging
|
||||
|
||||
from app.schemas.auth import WechatLogin, LoginResponse
|
||||
from app.schemas.user import UserCreate
|
||||
@ -9,6 +10,9 @@ from app.core.security import create_access_token
|
||||
from app.db.database import get_db
|
||||
from app.core.exceptions import BusinessError
|
||||
|
||||
# 创建日志记录器
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/login/wechat", response_model=LoginResponse, tags=["auth"])
|
||||
@ -23,29 +27,48 @@ async def wechat_login(
|
||||
- 如果用户不存在,则创建新用户
|
||||
- 生成JWT令牌
|
||||
"""
|
||||
# 调用微信API获取openid和unionid
|
||||
openid, unionid = await wechat_service.code2session(login_data.code)
|
||||
|
||||
# 检查用户是否存在
|
||||
existing_user = await user_service.get_user_by_openid(db, openid=openid)
|
||||
is_new_user = existing_user is None
|
||||
|
||||
if is_new_user:
|
||||
# 创建新用户
|
||||
user_create = UserCreate(
|
||||
openid=openid,
|
||||
unionid=unionid
|
||||
)
|
||||
user = await user_service.create_user(db, user=user_create)
|
||||
else:
|
||||
user = existing_user
|
||||
try:
|
||||
logger.info(f"尝试微信登录: {login_data.code[:5]}...")
|
||||
|
||||
# 创建访问令牌 - 使用openid作为标识
|
||||
access_token = create_access_token(subject=openid)
|
||||
|
||||
# 返回登录响应
|
||||
return LoginResponse(
|
||||
access_token=access_token,
|
||||
is_new_user=is_new_user,
|
||||
openid=openid
|
||||
)
|
||||
# 调用微信API获取openid和unionid
|
||||
openid, unionid = await wechat_service.code2session(login_data.code)
|
||||
logger.info(f"成功获取openid: {openid[:5]}...")
|
||||
|
||||
# 检查用户是否存在
|
||||
existing_user = await user_service.get_user_by_openid(db, openid=openid)
|
||||
is_new_user = existing_user is None
|
||||
logger.info(f"用户状态: 新用户={is_new_user}")
|
||||
|
||||
if is_new_user:
|
||||
# 创建新用户
|
||||
user_create = UserCreate(
|
||||
openid=openid,
|
||||
unionid=unionid
|
||||
)
|
||||
user = await user_service.create_user(db, user=user_create)
|
||||
logger.info(f"创建新用户: id={user.id}")
|
||||
else:
|
||||
user = existing_user
|
||||
logger.info(f"现有用户登录: id={user.id}")
|
||||
|
||||
# 创建访问令牌 - 使用openid作为标识
|
||||
access_token = create_access_token(subject=openid)
|
||||
|
||||
# 创建响应对象
|
||||
response_data = LoginResponse(
|
||||
access_token=access_token,
|
||||
is_new_user=is_new_user,
|
||||
openid=openid
|
||||
)
|
||||
|
||||
# 返回标准响应
|
||||
return response_data
|
||||
|
||||
except Exception as e:
|
||||
# 记录异常
|
||||
logger.error(f"登录异常: {str(e)}", exc_info=True)
|
||||
|
||||
# 重新抛出业务异常或转换为业务异常
|
||||
if isinstance(e, BusinessError):
|
||||
raise
|
||||
raise BusinessError(f"登录失败: {str(e)}", code=500)
|
||||
187
app/api/v1/clothing.py
Normal file
187
app/api/v1/clothing.py
Normal file
@ -0,0 +1,187 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List
|
||||
|
||||
from app.schemas.clothing import (
|
||||
Clothing, ClothingCreate, ClothingUpdate, ClothingWithCategory,
|
||||
ClothingCategory, ClothingCategoryCreate, ClothingCategoryUpdate
|
||||
)
|
||||
from app.services import clothing as clothing_service
|
||||
from app.db.database import get_db
|
||||
from app.api.deps import get_current_user
|
||||
from app.models.users import User as UserModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# 衣服分类API
|
||||
@router.post("/categories", response_model=ClothingCategory, tags=["clothing-categories"])
|
||||
async def create_category(
|
||||
category: ClothingCategoryCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
创建衣服分类
|
||||
|
||||
需要JWT令牌认证
|
||||
"""
|
||||
return await clothing_service.create_category(db=db, category=category)
|
||||
|
||||
@router.get("/categories", response_model=List[ClothingCategory], tags=["clothing-categories"])
|
||||
async def read_categories(
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=100),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取所有衣服分类"""
|
||||
return await clothing_service.get_categories(db=db, skip=skip, limit=limit)
|
||||
|
||||
@router.get("/categories/{category_id}", response_model=ClothingCategory, tags=["clothing-categories"])
|
||||
async def read_category(
|
||||
category_id: int,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取单个衣服分类"""
|
||||
db_category = await clothing_service.get_category(db, category_id=category_id)
|
||||
if db_category is None:
|
||||
raise HTTPException(status_code=404, detail="分类不存在")
|
||||
return db_category
|
||||
|
||||
@router.put("/categories/{category_id}", response_model=ClothingCategory, tags=["clothing-categories"])
|
||||
async def update_category(
|
||||
category_id: int,
|
||||
category: ClothingCategoryUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
更新衣服分类
|
||||
|
||||
需要JWT令牌认证
|
||||
"""
|
||||
db_category = await clothing_service.update_category(
|
||||
db=db,
|
||||
category_id=category_id,
|
||||
category_update=category
|
||||
)
|
||||
if db_category is None:
|
||||
raise HTTPException(status_code=404, detail="分类不存在")
|
||||
return db_category
|
||||
|
||||
@router.delete("/categories/{category_id}", response_model=ClothingCategory, tags=["clothing-categories"])
|
||||
async def delete_category(
|
||||
category_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
删除衣服分类(会级联删除相关衣服)
|
||||
|
||||
需要JWT令牌认证
|
||||
"""
|
||||
db_category = await clothing_service.delete_category(db=db, category_id=category_id)
|
||||
if db_category is None:
|
||||
raise HTTPException(status_code=404, detail="分类不存在")
|
||||
return db_category
|
||||
|
||||
# 衣服API
|
||||
@router.post("/", response_model=Clothing, tags=["clothing"])
|
||||
async def create_clothing(
|
||||
clothing: ClothingCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
创建衣服
|
||||
|
||||
需要JWT令牌认证
|
||||
"""
|
||||
# 检查分类是否存在
|
||||
category = await clothing_service.get_category(db, category_id=clothing.clothing_category_id)
|
||||
if category is None:
|
||||
raise HTTPException(status_code=404, detail="指定的分类不存在")
|
||||
|
||||
return await clothing_service.create_clothing(db=db, clothing=clothing)
|
||||
|
||||
@router.get("/", response_model=List[Clothing], tags=["clothing"])
|
||||
async def read_clothes(
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=100),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取所有衣服"""
|
||||
return await clothing_service.get_clothes(db=db, skip=skip, limit=limit)
|
||||
|
||||
@router.get("/by-category/{category_id}", response_model=List[Clothing], tags=["clothing"])
|
||||
async def read_clothes_by_category(
|
||||
category_id: int,
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=100),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""根据分类获取衣服"""
|
||||
# 检查分类是否存在
|
||||
category = await clothing_service.get_category(db, category_id=category_id)
|
||||
if category is None:
|
||||
raise HTTPException(status_code=404, detail="分类不存在")
|
||||
|
||||
return await clothing_service.get_clothes_by_category(
|
||||
db=db,
|
||||
category_id=category_id,
|
||||
skip=skip,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
@router.get("/{clothing_id}", response_model=Clothing, tags=["clothing"])
|
||||
async def read_clothing(
|
||||
clothing_id: int,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""获取单个衣服"""
|
||||
db_clothing = await clothing_service.get_clothing(db, clothing_id=clothing_id)
|
||||
if db_clothing is None:
|
||||
raise HTTPException(status_code=404, detail="衣服不存在")
|
||||
return db_clothing
|
||||
|
||||
@router.put("/{clothing_id}", response_model=Clothing, tags=["clothing"])
|
||||
async def update_clothing(
|
||||
clothing_id: int,
|
||||
clothing: ClothingUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
更新衣服
|
||||
|
||||
需要JWT令牌认证
|
||||
"""
|
||||
# 检查分类是否存在
|
||||
if clothing.clothing_category_id is not None:
|
||||
category = await clothing_service.get_category(db, category_id=clothing.clothing_category_id)
|
||||
if category is None:
|
||||
raise HTTPException(status_code=404, detail="指定的分类不存在")
|
||||
|
||||
db_clothing = await clothing_service.update_clothing(
|
||||
db=db,
|
||||
clothing_id=clothing_id,
|
||||
clothing_update=clothing
|
||||
)
|
||||
if db_clothing is None:
|
||||
raise HTTPException(status_code=404, detail="衣服不存在")
|
||||
return db_clothing
|
||||
|
||||
@router.delete("/{clothing_id}", response_model=Clothing, tags=["clothing"])
|
||||
async def delete_clothing(
|
||||
clothing_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: UserModel = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
删除衣服
|
||||
|
||||
需要JWT令牌认证
|
||||
"""
|
||||
db_clothing = await clothing_service.delete_clothing(db=db, clothing_id=clothing_id)
|
||||
if db_clothing is None:
|
||||
raise HTTPException(status_code=404, detail="衣服不存在")
|
||||
return db_clothing
|
||||
114
app/api/v1/user_images.py
Normal file
114
app/api/v1/user_images.py
Normal file
@ -0,0 +1,114 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List
|
||||
|
||||
from app.schemas.user_image import UserImage, UserImageCreate, UserImageUpdate
|
||||
from app.services import user_image as user_image_service
|
||||
from app.db.database import get_db
|
||||
from app.api.deps import get_current_user
|
||||
from app.models.users import User as UserModel
|
||||
from app.core.exceptions import BusinessError
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/", response_model=UserImage, tags=["user-images"])
|
||||
async def create_image(
|
||||
image: UserImageCreate,
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
创建用户形象图片
|
||||
|
||||
需要JWT令牌认证
|
||||
"""
|
||||
return await user_image_service.create_user_image(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
image=image
|
||||
)
|
||||
|
||||
@router.get("/", response_model=List[UserImage], tags=["user-images"])
|
||||
async def read_images(
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=100),
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取当前用户的所有形象图片
|
||||
|
||||
需要JWT令牌认证
|
||||
"""
|
||||
return await user_image_service.get_user_images(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
skip=skip,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
@router.get("/{image_id}", response_model=UserImage, tags=["user-images"])
|
||||
async def read_image(
|
||||
image_id: int,
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取指定形象图片
|
||||
|
||||
需要JWT令牌认证
|
||||
"""
|
||||
image = await user_image_service.get_user_image(db, image_id=image_id)
|
||||
if image is None:
|
||||
raise HTTPException(status_code=404, detail="图片不存在")
|
||||
if image.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="没有权限访问此图片")
|
||||
return image
|
||||
|
||||
@router.put("/{image_id}", response_model=UserImage, tags=["user-images"])
|
||||
async def update_image(
|
||||
image_id: int,
|
||||
image_update: UserImageUpdate,
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
更新形象图片
|
||||
|
||||
需要JWT令牌认证
|
||||
"""
|
||||
# 先检查图片是否存在
|
||||
db_image = await user_image_service.get_user_image(db, image_id=image_id)
|
||||
if db_image is None:
|
||||
raise HTTPException(status_code=404, detail="图片不存在")
|
||||
if db_image.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="没有权限更新此图片")
|
||||
|
||||
# 执行更新
|
||||
updated_image = await user_image_service.update_user_image(
|
||||
db=db,
|
||||
image_id=image_id,
|
||||
image_update=image_update
|
||||
)
|
||||
return updated_image
|
||||
|
||||
@router.delete("/{image_id}", response_model=UserImage, tags=["user-images"])
|
||||
async def delete_image(
|
||||
image_id: int,
|
||||
current_user: UserModel = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
删除形象图片
|
||||
|
||||
需要JWT令牌认证
|
||||
"""
|
||||
# 先检查图片是否存在
|
||||
db_image = await user_image_service.get_user_image(db, image_id=image_id)
|
||||
if db_image is None:
|
||||
raise HTTPException(status_code=404, detail="图片不存在")
|
||||
if db_image.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="没有权限删除此图片")
|
||||
|
||||
# 执行删除
|
||||
return await user_image_service.delete_user_image(db=db, image_id=image_id)
|
||||
@ -23,7 +23,8 @@ class Settings(BaseSettings):
|
||||
DB_ECHO: bool = os.getenv("DB_ECHO", "False").lower() == "true"
|
||||
|
||||
# DashScope API密钥
|
||||
DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "sk-caa199589f1c451aaac471fad2986e28")
|
||||
DASHSCOPE_API_KEY: str = os.getenv("DASHSCOPE_API_KEY", "")
|
||||
DASHSCOPE_MODEL: str = os.getenv("DASHSCOPE_MODEL", "qwen-max")
|
||||
|
||||
# 安全设置
|
||||
SECRET_KEY: str = os.getenv("SECRET_KEY", "your-secret-key-for-jwt-please-change-in-production")
|
||||
@ -34,6 +35,15 @@ class Settings(BaseSettings):
|
||||
WECHAT_APP_ID: str = os.getenv("WECHAT_APP_ID", "")
|
||||
WECHAT_APP_SECRET: str = os.getenv("WECHAT_APP_SECRET", "")
|
||||
|
||||
@property
|
||||
def cors_origins(self):
|
||||
"""获取CORS来源列表"""
|
||||
if isinstance(self.BACKEND_CORS_ORIGINS, str) and self.BACKEND_CORS_ORIGINS == "*":
|
||||
return ["*"]
|
||||
elif isinstance(self.BACKEND_CORS_ORIGINS, list):
|
||||
return self.BACKEND_CORS_ORIGINS
|
||||
return []
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
case_sensitive = True
|
||||
|
||||
@ -19,49 +19,42 @@ class ResponseMiddleware(BaseHTTPMiddleware):
|
||||
exclude_paths = ["/docs", "/redoc", "/openapi.json"]
|
||||
if any(request.url.path.startswith(path) for path in exclude_paths):
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
# 获取原始响应
|
||||
response = await call_next(request)
|
||||
|
||||
# 如果是HTTPException,直接返回,不进行封装
|
||||
if response.status_code >= 400:
|
||||
# 如果不是200系列响应或不是JSON响应,直接返回
|
||||
if response.status_code >= 300 or response.headers.get("content-type") != "application/json":
|
||||
return response
|
||||
|
||||
# 正常响应进行封装
|
||||
if response.status_code < 300 and response.headers.get("content-type") == "application/json":
|
||||
response_body = [chunk async for chunk in response.body_iterator]
|
||||
response_body = b"".join(response_body)
|
||||
|
||||
if response_body:
|
||||
try:
|
||||
data = json.loads(response_body)
|
||||
|
||||
# 已经是统一格式的响应,不再封装
|
||||
if isinstance(data, dict) and "code" in data and ("data" in data or "message" in data):
|
||||
return Response(
|
||||
content=response_body,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.media_type
|
||||
)
|
||||
|
||||
# 封装为标准响应格式
|
||||
result = StandardResponse(code=200, data=data).model_dump()
|
||||
return JSONResponse(
|
||||
content=result,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers)
|
||||
)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# 非JSON响应,直接返回
|
||||
return Response(
|
||||
content=response_body,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.media_type
|
||||
)
|
||||
|
||||
return response
|
||||
# 读取响应内容
|
||||
try:
|
||||
# 使用JSONResponse的_render方法获取内容
|
||||
if isinstance(response, JSONResponse):
|
||||
# 获取原始数据
|
||||
raw_data = response.body.decode("utf-8")
|
||||
data = json.loads(raw_data)
|
||||
|
||||
# 已经是标准格式,不再封装
|
||||
if isinstance(data, dict) and "code" in data and ("data" in data or "message" in data):
|
||||
return response
|
||||
|
||||
# 创建新的标准响应
|
||||
std_response = StandardResponse(code=200, data=data)
|
||||
|
||||
# 创建新的JSONResponse
|
||||
return JSONResponse(
|
||||
content=std_response.model_dump(),
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers)
|
||||
)
|
||||
# 处理流式响应或其他类型响应
|
||||
else:
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
# 出现异常,返回原始响应
|
||||
return response
|
||||
|
||||
def add_response_middleware(app: FastAPI) -> None:
|
||||
"""添加响应处理中间件到FastAPI应用"""
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
from app.db.database import Base, engine
|
||||
from app.models.users import User
|
||||
from app.models.user_images import UserImage
|
||||
from app.models.clothing import ClothingCategory, Clothing
|
||||
|
||||
# 创建所有表格
|
||||
async def init_db():
|
||||
|
||||
14
app/main.py
14
app/main.py
@ -1,19 +1,29 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
from app.core.config import settings
|
||||
from app.api.v1.api import api_router
|
||||
from app.core.middleware import add_response_middleware
|
||||
from app.core.exceptions import add_exception_handlers
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# 在应用启动时执行
|
||||
logger.info("应用启动,初始化数据库...")
|
||||
from app.db.init_db import init_db
|
||||
await init_db()
|
||||
logger.info("数据库初始化完成")
|
||||
yield
|
||||
# 在应用关闭时执行
|
||||
# 清理代码可以放在这里
|
||||
logger.info("应用关闭")
|
||||
|
||||
app = FastAPI(
|
||||
title=settings.PROJECT_NAME,
|
||||
@ -25,7 +35,7 @@ app = FastAPI(
|
||||
# 配置CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.BACKEND_CORS_ORIGINS,
|
||||
allow_origins=settings.cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
|
||||
33
app/models/clothing.py
Normal file
33
app/models/clothing.py
Normal file
@ -0,0 +1,33 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db.database import Base
|
||||
|
||||
class ClothingCategory(Base):
|
||||
"""衣服分类表"""
|
||||
__tablename__ = "clothing_category"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True, index=True)
|
||||
name = Column(String(50), nullable=False, comment="分类名称")
|
||||
create_time = Column(DateTime, default=func.now(), comment="创建时间")
|
||||
|
||||
# 关系
|
||||
clothes = relationship("Clothing", back_populates="category", cascade="all, delete-orphan")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<ClothingCategory(id={self.id}, name={self.name})>"
|
||||
|
||||
class Clothing(Base):
|
||||
"""衣服表"""
|
||||
__tablename__ = "clothing"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True, index=True)
|
||||
clothing_category_id = Column(Integer, ForeignKey("clothing_category.id"), nullable=False, index=True, comment="分类ID")
|
||||
image_url = Column(String(500), nullable=False, comment="图片URL")
|
||||
create_time = Column(DateTime, default=func.now(), comment="创建时间")
|
||||
|
||||
# 关系
|
||||
category = relationship("ClothingCategory", back_populates="clothes")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Clothing(id={self.id}, category_id={self.clothing_category_id})>"
|
||||
20
app/models/user_images.py
Normal file
20
app/models/user_images.py
Normal file
@ -0,0 +1,20 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, Boolean
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db.database import Base
|
||||
|
||||
class UserImage(Base):
|
||||
"""用户个人形象库模型"""
|
||||
__tablename__ = "user_images"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True, index=True)
|
||||
user_id = Column(Integer, ForeignKey("users.id"), nullable=False, index=True, comment="用户ID")
|
||||
image_url = Column(String(500), nullable=False, comment="图片URL")
|
||||
is_default = Column(Boolean, default=False, nullable=False, comment="是否为默认形象")
|
||||
create_time = Column(DateTime, default=func.now(), comment="创建时间")
|
||||
|
||||
# 关系
|
||||
user = relationship("User", back_populates="images")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<UserImage(id={self.id}, user_id={self.user_id})>"
|
||||
@ -1,5 +1,6 @@
|
||||
from sqlalchemy import Column, Integer, String, DateTime
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db.database import Base
|
||||
|
||||
class User(Base):
|
||||
@ -12,5 +13,8 @@ class User(Base):
|
||||
nickname = Column(String(50), nullable=True, comment="昵称")
|
||||
create_time = Column(DateTime, default=func.now(), comment="创建时间")
|
||||
|
||||
# 关系
|
||||
images = relationship("UserImage", back_populates="user", cascade="all, delete-orphan")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<User(id={self.id}, nickname={self.nickname})>"
|
||||
63
app/schemas/clothing.py
Normal file
63
app/schemas/clothing.py
Normal file
@ -0,0 +1,63 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, List
|
||||
from datetime import datetime
|
||||
|
||||
# 分类模式
|
||||
class ClothingCategoryBase(BaseModel):
|
||||
"""衣服分类基础模式"""
|
||||
name: str = Field(..., description="分类名称")
|
||||
|
||||
class ClothingCategoryCreate(ClothingCategoryBase):
|
||||
"""创建衣服分类请求"""
|
||||
pass
|
||||
|
||||
class ClothingCategoryInDB(ClothingCategoryBase):
|
||||
"""数据库中的衣服分类数据"""
|
||||
id: int
|
||||
create_time: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class ClothingCategory(ClothingCategoryInDB):
|
||||
"""衣服分类响应模式"""
|
||||
pass
|
||||
|
||||
class ClothingCategoryUpdate(BaseModel):
|
||||
"""更新衣服分类请求"""
|
||||
name: Optional[str] = Field(None, description="分类名称")
|
||||
|
||||
# 衣服模式
|
||||
class ClothingBase(BaseModel):
|
||||
"""衣服基础模式"""
|
||||
clothing_category_id: int = Field(..., description="分类ID")
|
||||
image_url: str = Field(..., description="图片URL")
|
||||
|
||||
class ClothingCreate(ClothingBase):
|
||||
"""创建衣服请求"""
|
||||
pass
|
||||
|
||||
class ClothingInDB(ClothingBase):
|
||||
"""数据库中的衣服数据"""
|
||||
id: int
|
||||
create_time: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class Clothing(ClothingInDB):
|
||||
"""衣服响应模式"""
|
||||
pass
|
||||
|
||||
class ClothingUpdate(BaseModel):
|
||||
"""更新衣服请求"""
|
||||
clothing_category_id: Optional[int] = Field(None, description="分类ID")
|
||||
image_url: Optional[str] = Field(None, description="图片URL")
|
||||
|
||||
# 带有分类信息的衣服响应
|
||||
class ClothingWithCategory(Clothing):
|
||||
"""包含分类信息的衣服响应"""
|
||||
category: ClothingCategory
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
30
app/schemas/user_image.py
Normal file
30
app/schemas/user_image.py
Normal file
@ -0,0 +1,30 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
class UserImageBase(BaseModel):
|
||||
"""用户形象基础模式"""
|
||||
image_url: str = Field(..., description="图片URL")
|
||||
is_default: bool = Field(False, description="是否为默认形象")
|
||||
|
||||
class UserImageCreate(UserImageBase):
|
||||
"""创建用户形象请求"""
|
||||
pass
|
||||
|
||||
class UserImageInDB(UserImageBase):
|
||||
"""数据库中的用户形象数据"""
|
||||
id: int
|
||||
user_id: int
|
||||
create_time: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class UserImage(UserImageInDB):
|
||||
"""用户形象响应模式"""
|
||||
pass
|
||||
|
||||
class UserImageUpdate(BaseModel):
|
||||
"""更新用户形象请求"""
|
||||
image_url: Optional[str] = Field(None, description="图片URL")
|
||||
is_default: Optional[bool] = Field(None, description="是否为默认形象")
|
||||
116
app/services/clothing.py
Normal file
116
app/services/clothing.py
Normal file
@ -0,0 +1,116 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy import update, delete
|
||||
from typing import List, Optional
|
||||
|
||||
from app.models.clothing import Clothing, ClothingCategory
|
||||
from app.schemas.clothing import ClothingCreate, ClothingUpdate
|
||||
from app.schemas.clothing import ClothingCategoryCreate, ClothingCategoryUpdate
|
||||
|
||||
# 衣服分类服务函数
|
||||
async def get_category(db: AsyncSession, category_id: int):
|
||||
"""获取单个衣服分类"""
|
||||
result = await db.execute(select(ClothingCategory).filter(ClothingCategory.id == category_id))
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_categories(db: AsyncSession, skip: int = 0, limit: int = 100):
|
||||
"""获取所有衣服分类"""
|
||||
result = await db.execute(
|
||||
select(ClothingCategory)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def create_category(db: AsyncSession, category: ClothingCategoryCreate):
|
||||
"""创建衣服分类"""
|
||||
db_category = ClothingCategory(
|
||||
name=category.name
|
||||
)
|
||||
db.add(db_category)
|
||||
await db.commit()
|
||||
await db.refresh(db_category)
|
||||
return db_category
|
||||
|
||||
async def update_category(db: AsyncSession, category_id: int, category_update: ClothingCategoryUpdate):
|
||||
"""更新衣服分类"""
|
||||
db_category = await get_category(db, category_id)
|
||||
if not db_category:
|
||||
return None
|
||||
|
||||
if category_update.name:
|
||||
db_category.name = category_update.name
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_category)
|
||||
return db_category
|
||||
|
||||
async def delete_category(db: AsyncSession, category_id: int):
|
||||
"""删除衣服分类(会级联删除相关衣服)"""
|
||||
db_category = await get_category(db, category_id)
|
||||
if db_category:
|
||||
await db.delete(db_category)
|
||||
await db.commit()
|
||||
return db_category
|
||||
|
||||
# 衣服服务函数
|
||||
async def get_clothing(db: AsyncSession, clothing_id: int):
|
||||
"""获取单个衣服"""
|
||||
result = await db.execute(select(Clothing).filter(Clothing.id == clothing_id))
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_clothes(db: AsyncSession, skip: int = 0, limit: int = 100):
|
||||
"""获取所有衣服"""
|
||||
result = await db.execute(
|
||||
select(Clothing)
|
||||
.order_by(Clothing.create_time.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_clothes_by_category(db: AsyncSession, category_id: int, skip: int = 0, limit: int = 100):
|
||||
"""根据分类获取衣服"""
|
||||
result = await db.execute(
|
||||
select(Clothing)
|
||||
.filter(Clothing.clothing_category_id == category_id)
|
||||
.order_by(Clothing.create_time.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def create_clothing(db: AsyncSession, clothing: ClothingCreate):
|
||||
"""创建衣服"""
|
||||
db_clothing = Clothing(
|
||||
clothing_category_id=clothing.clothing_category_id,
|
||||
image_url=clothing.image_url
|
||||
)
|
||||
db.add(db_clothing)
|
||||
await db.commit()
|
||||
await db.refresh(db_clothing)
|
||||
return db_clothing
|
||||
|
||||
async def update_clothing(db: AsyncSession, clothing_id: int, clothing_update: ClothingUpdate):
|
||||
"""更新衣服"""
|
||||
db_clothing = await get_clothing(db, clothing_id)
|
||||
if not db_clothing:
|
||||
return None
|
||||
|
||||
if clothing_update.clothing_category_id is not None:
|
||||
db_clothing.clothing_category_id = clothing_update.clothing_category_id
|
||||
|
||||
if clothing_update.image_url:
|
||||
db_clothing.image_url = clothing_update.image_url
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_clothing)
|
||||
return db_clothing
|
||||
|
||||
async def delete_clothing(db: AsyncSession, clothing_id: int):
|
||||
"""删除衣服"""
|
||||
db_clothing = await get_clothing(db, clothing_id)
|
||||
if db_clothing:
|
||||
await db.delete(db_clothing)
|
||||
await db.commit()
|
||||
return db_clothing
|
||||
272
app/services/dashscope_service.py
Normal file
272
app/services/dashscope_service.py
Normal file
@ -0,0 +1,272 @@
|
||||
import os
|
||||
import logging
|
||||
import dashscope
|
||||
from dashscope import Generation
|
||||
# 修改导入语句,dashscope的API响应可能改变了结构
|
||||
from typing import List, Dict, Any, Optional
|
||||
import asyncio
|
||||
import httpx
|
||||
from app.utils.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DashScopeService:
|
||||
"""DashScope服务类,提供对DashScope API的调用封装"""
|
||||
|
||||
def __init__(self):
|
||||
settings = get_settings()
|
||||
self.api_key = settings.dashscope_api_key
|
||||
# 配置DashScope
|
||||
dashscope.api_key = self.api_key
|
||||
# 配置API URL
|
||||
self.image_synthesis_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image2image/image-synthesis"
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
model: str = "qwen-max",
|
||||
max_tokens: int = 2048,
|
||||
temperature: float = 0.7,
|
||||
stream: bool = False
|
||||
):
|
||||
"""
|
||||
调用DashScope的大模型API进行对话
|
||||
|
||||
Args:
|
||||
messages: 对话历史记录
|
||||
model: 模型名称
|
||||
max_tokens: 最大生成token数
|
||||
temperature: 温度参数,控制随机性
|
||||
stream: 是否流式输出
|
||||
|
||||
Returns:
|
||||
ApiResponse: DashScope的API响应
|
||||
"""
|
||||
try:
|
||||
# 为了不阻塞FastAPI的异步性能,我们使用run_in_executor运行同步API
|
||||
loop = asyncio.get_event_loop()
|
||||
response = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: Generation.call(
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
result_format='message',
|
||||
stream=stream,
|
||||
)
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"DashScope API请求失败,状态码:{response.status_code}, 错误信息:{response.message}")
|
||||
raise Exception(f"API调用失败: {response.message}")
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"DashScope聊天API调用出错: {str(e)}")
|
||||
raise e
|
||||
|
||||
async def generate_image(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: Optional[str] = None,
|
||||
model: str = "stable-diffusion-xl",
|
||||
n: int = 1,
|
||||
size: str = "1024*1024"
|
||||
):
|
||||
"""
|
||||
调用DashScope的图像生成API
|
||||
|
||||
Args:
|
||||
prompt: 生成图像的文本描述
|
||||
negative_prompt: 负面提示词
|
||||
model: 模型名称
|
||||
n: 生成图像数量
|
||||
size: 图像尺寸
|
||||
|
||||
Returns:
|
||||
ApiResponse: DashScope的API响应
|
||||
"""
|
||||
try:
|
||||
# 构建请求参数
|
||||
params = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
"n": n,
|
||||
"size": size,
|
||||
}
|
||||
|
||||
if negative_prompt:
|
||||
params["negative_prompt"] = negative_prompt
|
||||
|
||||
# 异步调用图像生成API
|
||||
loop = asyncio.get_event_loop()
|
||||
response = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: dashscope.ImageSynthesis.call(**params)
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"DashScope 图像生成API请求失败,状态码:{response.status_code}, 错误信息:{response.message}")
|
||||
raise Exception(f"图像生成API调用失败: {response.message}")
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"DashScope图像生成API调用出错: {str(e)}")
|
||||
raise e
|
||||
|
||||
async def generate_tryon(
|
||||
self,
|
||||
person_image_url: str,
|
||||
top_garment_url: Optional[str] = None,
|
||||
bottom_garment_url: Optional[str] = None,
|
||||
resolution: int = -1,
|
||||
restore_face: bool = True
|
||||
):
|
||||
"""
|
||||
调用阿里百炼平台的试衣服务
|
||||
|
||||
Args:
|
||||
person_image_url: 人物图片URL
|
||||
top_garment_url: 上衣图片URL
|
||||
bottom_garment_url: 下衣图片URL
|
||||
resolution: 分辨率,-1表示自动
|
||||
restore_face: 是否修复面部
|
||||
|
||||
Returns:
|
||||
Dict: 包含任务ID和请求ID的响应
|
||||
"""
|
||||
try:
|
||||
# 验证参数
|
||||
if not person_image_url:
|
||||
raise ValueError("人物图片URL不能为空")
|
||||
|
||||
if not top_garment_url and not bottom_garment_url:
|
||||
raise ValueError("上衣和下衣图片至少需要提供一个")
|
||||
|
||||
# 构建请求数据
|
||||
request_data = {
|
||||
"model": "aitryon",
|
||||
"input": {
|
||||
"person_image_url": person_image_url
|
||||
},
|
||||
"parameters": {
|
||||
"resolution": resolution,
|
||||
"restore_face": restore_face
|
||||
}
|
||||
}
|
||||
|
||||
# 添加可选字段
|
||||
if top_garment_url:
|
||||
request_data["input"]["top_garment_url"] = top_garment_url
|
||||
if bottom_garment_url:
|
||||
request_data["input"]["bottom_garment_url"] = bottom_garment_url
|
||||
|
||||
# 构建请求头
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"X-DashScope-Async": "enable"
|
||||
}
|
||||
|
||||
# 发送请求
|
||||
async with httpx.AsyncClient() as client:
|
||||
logger.info(f"发送试穿请求: {request_data}")
|
||||
response = await client.post(
|
||||
self.image_synthesis_url,
|
||||
json=request_data,
|
||||
headers=headers,
|
||||
timeout=30.0
|
||||
)
|
||||
|
||||
response_data = response.json()
|
||||
logger.info(f"试穿API响应: {response_data}")
|
||||
|
||||
if response.status_code == 200 or response.status_code == 202: # 202表示异步任务已接受
|
||||
# 提取任务ID,适应不同的API响应格式
|
||||
task_id = None
|
||||
if 'output' in response_data and 'task_id' in response_data['output']:
|
||||
task_id = response_data['output']['task_id']
|
||||
elif 'task_id' in response_data:
|
||||
task_id = response_data['task_id']
|
||||
|
||||
if task_id:
|
||||
logger.info(f"试穿请求发送成功,任务ID: {task_id}")
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"request_id": response_data.get('request_id'),
|
||||
"status": "processing"
|
||||
}
|
||||
else:
|
||||
# 如果没有任务ID,这可能是同步响应
|
||||
logger.info("收到同步响应,没有任务ID")
|
||||
return {
|
||||
"status": "completed",
|
||||
"result": response_data.get('output', {}),
|
||||
"request_id": response_data.get('request_id')
|
||||
}
|
||||
else:
|
||||
error_msg = f"试穿请求失败: {response.status_code} - {response.text}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
except Exception as e:
|
||||
logger.error(f"DashScope试穿API调用出错: {str(e)}")
|
||||
raise e
|
||||
|
||||
async def check_tryon_status(self, task_id: str):
|
||||
"""
|
||||
检查试穿任务状态
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
|
||||
Returns:
|
||||
Dict: 任务状态信息
|
||||
"""
|
||||
try:
|
||||
# 构建请求头
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# 构建请求URL
|
||||
status_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}"
|
||||
|
||||
# 发送请求
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
status_url,
|
||||
headers=headers,
|
||||
timeout=30.0
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
print(response_data)
|
||||
status = response_data.get('output', {}).get('task_status', '')
|
||||
logger.info(f"试穿任务状态查询成功: {status}")
|
||||
|
||||
# 检查是否完成并返回结果
|
||||
if status.lower() == 'succeeded':
|
||||
image_url = response_data.get('output', {}).get('image_url')
|
||||
if image_url:
|
||||
logger.info(f"试穿任务完成,结果URL: {image_url}")
|
||||
return {
|
||||
"status": status,
|
||||
"task_id": task_id,
|
||||
"image_url": image_url,
|
||||
"result": response_data.get('output', {})
|
||||
}
|
||||
|
||||
return {
|
||||
"status": status,
|
||||
"task_id": task_id
|
||||
}
|
||||
else:
|
||||
error_msg = f"试穿任务状态查询失败: {response.status_code} - {response.text}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
except Exception as e:
|
||||
logger.error(f"查询试穿任务状态出错: {str(e)}")
|
||||
raise e
|
||||
91
app/services/user_image.py
Normal file
91
app/services/user_image.py
Normal file
@ -0,0 +1,91 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy import delete, update
|
||||
from app.models.user_images import UserImage
|
||||
from app.schemas.user_image import UserImageCreate, UserImageUpdate
|
||||
from typing import List, Optional
|
||||
|
||||
async def get_user_image(db: AsyncSession, image_id: int):
|
||||
"""获取单个用户形象"""
|
||||
result = await db.execute(select(UserImage).filter(UserImage.id == image_id))
|
||||
return result.scalars().first()
|
||||
|
||||
async def get_user_images(db: AsyncSession, user_id: int, skip: int = 0, limit: int = 100):
|
||||
"""获取用户的所有形象图片"""
|
||||
result = await db.execute(
|
||||
select(UserImage)
|
||||
.filter(UserImage.user_id == user_id)
|
||||
.order_by(UserImage.create_time.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def create_user_image(db: AsyncSession, user_id: int, image: UserImageCreate):
|
||||
"""创建用户形象"""
|
||||
# 如果设置为默认形象,先重置用户的所有形象为非默认
|
||||
if image.is_default:
|
||||
await reset_default_images(db, user_id)
|
||||
|
||||
db_image = UserImage(
|
||||
user_id=user_id,
|
||||
image_url=image.image_url,
|
||||
is_default=image.is_default
|
||||
)
|
||||
db.add(db_image)
|
||||
await db.commit()
|
||||
await db.refresh(db_image)
|
||||
return db_image
|
||||
|
||||
async def update_user_image(db: AsyncSession, image_id: int, image_update: UserImageUpdate):
|
||||
"""更新用户形象"""
|
||||
db_image = await get_user_image(db, image_id)
|
||||
if not db_image:
|
||||
return None
|
||||
|
||||
# 处理默认形象逻辑
|
||||
if image_update.is_default is not None and image_update.is_default and not db_image.is_default:
|
||||
# 如果设置为默认且之前不是默认,重置其他形象
|
||||
await reset_default_images(db, db_image.user_id)
|
||||
db_image.is_default = True
|
||||
elif image_update.is_default is not None:
|
||||
db_image.is_default = image_update.is_default
|
||||
|
||||
# 更新图片URL
|
||||
if image_update.image_url:
|
||||
db_image.image_url = image_update.image_url
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(db_image)
|
||||
return db_image
|
||||
|
||||
async def delete_user_image(db: AsyncSession, image_id: int):
|
||||
"""删除用户形象"""
|
||||
db_image = await get_user_image(db, image_id)
|
||||
if db_image:
|
||||
await db.delete(db_image)
|
||||
await db.commit()
|
||||
return db_image
|
||||
|
||||
async def delete_user_images(db: AsyncSession, user_id: int):
|
||||
"""删除用户所有形象图片"""
|
||||
stmt = delete(UserImage).where(UserImage.user_id == user_id)
|
||||
await db.execute(stmt)
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
async def reset_default_images(db: AsyncSession, user_id: int):
|
||||
"""重置用户所有形象为非默认"""
|
||||
stmt = update(UserImage).where(
|
||||
UserImage.user_id == user_id,
|
||||
UserImage.is_default == True
|
||||
).values(is_default=False)
|
||||
await db.execute(stmt)
|
||||
|
||||
async def get_default_image(db: AsyncSession, user_id: int):
|
||||
"""获取用户的默认形象"""
|
||||
result = await db.execute(
|
||||
select(UserImage)
|
||||
.filter(UserImage.user_id == user_id, UserImage.is_default == True)
|
||||
)
|
||||
return result.scalars().first()
|
||||
@ -8,4 +8,5 @@ aiomysql==0.2.0
|
||||
greenlet==2.0.2
|
||||
python-jose[cryptography]==3.3.0
|
||||
passlib==1.7.4
|
||||
httpx==0.24.1
|
||||
httpx==0.24.1
|
||||
dashscope==1.10.0
|
||||
Loading…
Reference in New Issue
Block a user