This commit is contained in:
aaron 2025-04-09 21:45:42 +08:00
parent 36450d0d6f
commit c79cf5186c
11 changed files with 188 additions and 121 deletions

View File

@ -1,35 +1,61 @@
from fastapi import Depends, HTTPException, status
from typing import Generator, Optional
from fastapi import Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession
from jose import jwt
from pydantic import ValidationError
from sqlalchemy.ext.asyncio import AsyncSession
from app.core import security
from app.core.config import settings
from app.models.users import User
from app.schemas.user import User as UserSchema
from app.core.security import verify_token
from app.db.database import get_db
from app.services import user as user_service
from app.db.database import AsyncSessionLocal
from sqlalchemy import select
import logging
# OAuth2密码Bearer - JWT token的位置
oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/login")
reusable_oauth2 = OAuth2PasswordBearer(
tokenUrl=f"{settings.API_V1_STR}/login/access-token"
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
async def get_db():
"""获取数据库会话"""
async with AsyncSessionLocal() as session:
try:
yield session
finally:
await session.close()
# 获取当前用户的依赖项
async def get_current_user(
db: AsyncSession = Depends(get_db),
token: str = Depends(oauth2_scheme)
):
credentials_exception = HTTPException(
request: Request,
db: AsyncSession = Depends(get_db)
) -> UserSchema:
"""获取当前登录用户"""
UNAUTHORIZED = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的认证凭证",
detail="未提供认证信息",
headers={"WWW-Authenticate": "Bearer"},
)
# 验证token
openid = verify_token(token)
if not openid:
raise credentials_exception
# 根据openid获取用户
user = await user_service.get_user_by_openid(db, openid=openid)
if user is None:
raise credentials_exception
auth_header = request.headers.get("Authorization")
if auth_header:
access_token = auth_header.split(" ")[1]
else:
access_token = request.session.get("access_token")
if not access_token:
raise UNAUTHORIZED
try:
sub = verify_token(access_token)
if sub:
user = (await db.execute(select(User).filter(User.openid == sub))).scalars().first()
if user:
return user
raise UNAUTHORIZED
except Exception as e:
logger.error(f"获取当前登录用户失败: {e}")
raise UNAUTHORIZED

View File

@ -1,25 +1,51 @@
from fastapi import APIRouter, Depends, HTTPException, status, Response
from sqlalchemy.ext.asyncio import AsyncSession
import logging
from fastapi.security import OAuth2PasswordRequestForm
from app.api import deps
from app.core import security
from app.core.config import settings
from app.schemas.response import StandardResponse
from app.schemas.auth import WechatLogin, LoginResponse
from app.schemas.user import UserCreate
from app.services import wechat as wechat_service
from app.services import user as user_service
from app.core.security import create_access_token
from app.db.database import get_db
from app.core.exceptions import BusinessError
from app.schemas.response import StandardResponse
from app.schemas.user import User
from fastapi import Request
# 创建日志记录器
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
router = APIRouter()
@router.get("/login/{user_id}", tags=["auth"])
async def login(
user_id: int,
request: Request,
db: AsyncSession = Depends(deps.get_db)
):
"""
登录接口
"""
user = await user_service.get_user(db, user_id=user_id)
logger.info(f"登录用户: {user.openid}")
if not user:
raise BusinessError("用户不存在", code=404)
# 创建访问令牌 - 使用openid作为标识
access_token = security.create_access_token(subject=user.openid)
# 将access_token存储在session中
request.session["access_token"] = access_token
return StandardResponse(code=200, message="登录成功", data=User.model_validate(user))
@router.post("/login/wechat", tags=["auth"])
async def wechat_login(
login_data: WechatLogin,
db: AsyncSession = Depends(get_db)
request: Request,
db: AsyncSession = Depends(deps.get_db)
):
"""
微信登录接口
@ -53,7 +79,10 @@ async def wechat_login(
logger.info(f"现有用户登录: id={user.id}")
# 创建访问令牌 - 使用openid作为标识
access_token = create_access_token(subject=openid)
access_token = security.create_access_token(subject=openid)
# 将access_token存储在session中
request.session["access_token"] = access_token
# 创建响应对象
response_data = LoginResponse(
@ -68,8 +97,4 @@ async def wechat_login(
except Exception as e:
# 记录异常
logger.error(f"登录异常: {str(e)}", exc_info=True)
# 重新抛出业务异常或转换为业务异常
if isinstance(e, BusinessError):
raise
raise BusinessError(f"登录失败: {str(e)}", code=500)

View File

@ -12,6 +12,7 @@ from app.api.deps import get_current_user
from app.models.users import User as UserModel
from app.core.exceptions import BusinessError
from app.schemas.response import StandardResponse
from app.schemas.clothing import ClothingCategory
import logging
# 创建日志记录器
@ -31,10 +32,8 @@ async def create_category(
需要JWT令牌认证
"""
result = await clothing_service.create_category(db=db, category=category)
logger.info(f"创建分类成功: id={result.id}, name={result.name}")
# 手动返回标准响应格式
return StandardResponse(code=200, data=result)
category = await clothing_service.create_category(db=db, category=category)
return StandardResponse(code=200, data=ClothingCategory.model_validate(category))
@router.get("/categories", tags=["clothing-categories"])
async def read_categories(
@ -43,10 +42,8 @@ async def read_categories(
db: AsyncSession = Depends(get_db)
):
"""获取所有衣服分类"""
result = await clothing_service.get_categories(db=db, skip=skip, limit=limit)
logger.info(f"获取分类列表: 共{len(result)}")
# 手动返回标准响应格式
return StandardResponse(code=200, data=result)
categories = await clothing_service.get_categories(db=db, skip=skip, limit=limit)
return StandardResponse(code=200, data=[ClothingCategory.model_validate(category) for category in categories])
@router.get("/categories/{category_id}", tags=["clothing-categories"])
async def read_category(
@ -54,11 +51,11 @@ async def read_category(
db: AsyncSession = Depends(get_db)
):
"""获取单个衣服分类"""
db_category = await clothing_service.get_category(db, category_id=category_id)
if db_category is None:
category = await clothing_service.get_category(db, category_id=category_id)
if category is None:
raise BusinessError("分类不存在", code=404)
# 手动返回标准响应格式
return StandardResponse(code=200, data=db_category)
return StandardResponse(code=200, data=ClothingCategory.model_validate(category))
@router.put("/categories/{category_id}", tags=["clothing-categories"])
async def update_category(
@ -72,15 +69,15 @@ async def update_category(
需要JWT令牌认证
"""
db_category = await clothing_service.update_category(
category = await clothing_service.update_category(
db=db,
category_id=category_id,
category_update=category
)
if db_category is None:
if category is None:
raise BusinessError("分类不存在", code=404)
# 手动返回标准响应格式
return StandardResponse(code=200, data=db_category)
return StandardResponse(code=200, data=ClothingCategory.model_validate(category))
@router.delete("/categories/{category_id}", tags=["clothing-categories"])
async def delete_category(
@ -93,11 +90,11 @@ async def delete_category(
需要JWT令牌认证
"""
db_category = await clothing_service.delete_category(db=db, category_id=category_id)
if db_category is None:
category = await clothing_service.delete_category(db=db, category_id=category_id)
if category is None:
raise BusinessError("分类不存在", code=404)
# 手动返回标准响应格式
return StandardResponse(code=200, data=db_category)
return StandardResponse(code=200, data=ClothingCategory.model_validate(category))
# 衣服API
@router.post("", tags=["clothing"])
@ -116,10 +113,10 @@ async def create_clothing(
if category is None:
raise BusinessError("指定的分类不存在", code=404)
result = await clothing_service.create_clothing(db=db, clothing=clothing)
logger.info(f"创建衣服成功: id={result.id}")
clothing = await clothing_service.create_clothing(db=db, clothing=clothing)
logger.info(f"创建衣服成功: id={clothing.id}")
# 手动返回标准响应格式
return StandardResponse(code=200, data=result)
return StandardResponse(code=200, data=Clothing.model_validate(clothing))
@router.get("", tags=["clothing"])
async def read_clothes(
@ -128,9 +125,9 @@ async def read_clothes(
db: AsyncSession = Depends(get_db)
):
"""获取所有衣服"""
result = await clothing_service.get_clothes(db=db, skip=skip, limit=limit)
clothes = await clothing_service.get_clothes(db=db, skip=skip, limit=limit)
# 手动返回标准响应格式
return StandardResponse(code=200, data=result)
return StandardResponse(code=200, data=[Clothing.model_validate(clothing) for clothing in clothes])
@router.get("/by-category/{category_id}", tags=["clothing"])
async def read_clothes_by_category(
@ -145,14 +142,14 @@ async def read_clothes_by_category(
if category is None:
raise BusinessError("分类不存在", code=404)
result = await clothing_service.get_clothes_by_category(
clothes = await clothing_service.get_clothes_by_category(
db=db,
category_id=category_id,
skip=skip,
limit=limit
)
# 手动返回标准响应格式
return StandardResponse(code=200, data=result)
return StandardResponse(code=200, data=[Clothing.model_validate(clothing) for clothing in clothes])
@router.get("/{clothing_id}", tags=["clothing"])
async def read_clothing(
@ -160,11 +157,11 @@ async def read_clothing(
db: AsyncSession = Depends(get_db)
):
"""获取单个衣服"""
db_clothing = await clothing_service.get_clothing(db, clothing_id=clothing_id)
if db_clothing is None:
clothing = await clothing_service.get_clothing(db, clothing_id=clothing_id)
if clothing is None:
raise BusinessError("衣服不存在", code=404)
# 手动返回标准响应格式
return StandardResponse(code=200, data=db_clothing)
return StandardResponse(code=200, data=Clothing.model_validate(clothing))
@router.put("/{clothing_id}", tags=["clothing"])
async def update_clothing(
@ -184,15 +181,15 @@ async def update_clothing(
if category is None:
raise BusinessError("指定的分类不存在", code=404)
db_clothing = await clothing_service.update_clothing(
clothing = await clothing_service.update_clothing(
db=db,
clothing_id=clothing_id,
clothing_update=clothing
)
if db_clothing is None:
if clothing is None:
raise BusinessError("衣服不存在", code=404)
# 手动返回标准响应格式
return StandardResponse(code=200, data=db_clothing)
return StandardResponse(code=200, data=Clothing.model_validate(clothing))
@router.delete("/{clothing_id}", tags=["clothing"])
async def delete_clothing(
@ -205,8 +202,12 @@ async def delete_clothing(
需要JWT令牌认证
"""
db_clothing = await clothing_service.delete_clothing(db=db, clothing_id=clothing_id)
if db_clothing is None:
clothing = await clothing_service.delete_clothing(db=db, clothing_id=clothing_id)
if clothing is None:
raise BusinessError("衣服不存在", code=404)
if clothing.user_id != current_user.id:
raise BusinessError("没有权限删除此衣服", code=403)
# 手动返回标准响应格式
return StandardResponse(code=200, data=db_clothing)
return StandardResponse(code=200, data=Clothing.model_validate(clothing))

View File

@ -7,71 +7,52 @@ from app.schemas.person_image import (
PersonImageCreate,
PersonImageUpdate
)
from app.services import person_image as user_image_service
from app.services import person_image as person_image_service
from app.models.users import User
from app.schemas.response import StandardResponse
router = APIRouter()
@router.get("", response_model=List[PersonImage])
@router.get("", tags=["person_images"])
async def get_person_images(
db: AsyncSession = Depends(deps.get_db),
current_user = Depends(deps.get_current_user),
current_user: User = Depends(deps.get_current_user),
skip: int = 0,
limit: int = 100
):
"""获取当前用户的所有人物形象"""
return await user_image_service.get_person_images_by_user(
images = await person_image_service.get_person_images_by_user(
db=db,
user_id=current_user.id,
skip=skip,
limit=limit
)
return StandardResponse(code=200, message="获取人物形象成功", data=images)
@router.post("", response_model=PersonImage, status_code=status.HTTP_201_CREATED)
@router.post("", response_model=PersonImage, tags=["person_images"])
async def create_person_image(
*,
db: AsyncSession = Depends(deps.get_db),
current_user = Depends(deps.get_current_user),
image_in: PersonImageCreate
current_user: User = Depends(deps.get_current_user),
image_in: PersonImage
):
"""创建新的人物形象"""
image_in.user_id = current_user.id
return await user_image_service.create_person_image(db=db, image=image_in)
image = await person_image_service.create_person_image(db=db, image=image_in)
@router.put("/{image_id}", response_model=PersonImage)
async def update_person_image(
*,
db: AsyncSession = Depends(deps.get_db),
current_user = Depends(deps.get_current_user),
image_id: int,
image_in: PersonImageUpdate
):
"""更新人物形象"""
image = await user_image_service.get_person_image(db=db, image_id=image_id)
if not image:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="人物形象不存在"
)
if image.user_id != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="没有权限修改此人物形象"
)
return await user_image_service.update_person_image(
db=db,
image_id=image_id,
image=image_in
)
return StandardResponse(code=200, message="创建人物形象成功", data=image)
@router.delete("/{image_id}", status_code=status.HTTP_204_NO_CONTENT)
@router.delete("/{image_id}", tags=["person_images"])
async def delete_person_image(
*,
db: AsyncSession = Depends(deps.get_db),
current_user = Depends(deps.get_current_user),
current_user: User = Depends(deps.get_current_user),
image_id: int
):
"""删除人物形象"""
image = await user_image_service.get_person_image(db=db, image_id=image_id)
image = await person_image_service.get_person_image(db=db, image_id=image_id)
if not image:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@ -82,5 +63,6 @@ async def delete_person_image(
status_code=status.HTTP_403_FORBIDDEN,
detail="没有权限删除此人物形象"
)
await user_image_service.delete_person_image(db=db, image_id=image_id)
return None
await person_image_service.delete_person_image(db=db, image_id=image_id)
return StandardResponse(code=200, message="删除人物形象成功")

View File

@ -7,6 +7,7 @@ from app.api.v1.api import api_router
from app.core.exceptions import add_exception_handlers
from app.core.middleware import add_response_middleware
from app.schemas.response import StandardResponse
from starlette.middleware.sessions import SessionMiddleware
# 配置日志
logging.basicConfig(
@ -30,7 +31,8 @@ app = FastAPI(
title=settings.PROJECT_NAME,
description=settings.PROJECT_DESCRIPTION,
version=settings.PROJECT_VERSION,
lifespan=lifespan
lifespan=lifespan,
openapi_url=f"{settings.API_V1_STR}/openapi.json"
)
logger.info("开始配置应用中间件和路由...")
@ -38,13 +40,21 @@ logger.info("开始配置应用中间件和路由...")
# 配置CORS
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins,
allow_origins=settings.BACKEND_CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
logger.info("CORS中间件已添加")
# 添加Session中间件
app.add_middleware(
SessionMiddleware,
secret_key=settings.SECRET_KEY,
session_cookie="session",
max_age=1800 # 30分钟
)
# 添加异常处理器
add_exception_handlers(app)
logger.info("异常处理器已添加")

View File

@ -23,11 +23,13 @@ class Clothing(Base):
id = Column(Integer, primary_key=True, autoincrement=True, index=True)
clothing_category_id = Column(Integer, ForeignKey("clothing_category.id"), nullable=False, index=True, comment="分类ID")
user_id = Column(Integer, ForeignKey("users.id"), nullable=False)
image_url = Column(String(500), nullable=False, comment="图片URL")
create_time = Column(DateTime, default=func.now(), comment="创建时间")
# 关系
category = relationship("ClothingCategory", back_populates="clothes")
user = relationship("User", back_populates="clothings")
def __repr__(self):
return f"<Clothing(id={self.id}, category_id={self.clothing_category_id})>"

View File

@ -2,6 +2,7 @@ from sqlalchemy import Column, Integer, String, DateTime
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship
from app.db.database import Base
from pydantic import BaseModel
class User(Base):
"""用户数据模型"""
@ -16,6 +17,18 @@ class User(Base):
# 关系
person_images = relationship("PersonImage", back_populates="user", cascade="all, delete-orphan")
clothings = relationship("Clothing", back_populates="user", cascade="all, delete-orphan")
def __repr__(self):
return f"<User(id={self.id}, nickname={self.nickname})>"
def to_dict(self):
"""将模型转换为字典"""
return {
"id": self.id,
"openid": self.openid,
"unionid": self.unionid,
"avatar": self.avatar,
"nickname": self.nickname,
"create_time": self.create_time
}

View File

@ -4,6 +4,7 @@ from datetime import datetime
class PersonImageBase(BaseModel):
"""人物形象基础模型"""
user_id: int
image_url: str
is_default: bool = False

View File

@ -3,24 +3,27 @@ from datetime import datetime
from typing import Optional
class UserBase(BaseModel):
"""用户基础模型"""
openid: str
unionid: Optional[str] = None
avatar: Optional[str] = None
nickname: Optional[str] = None
class UserCreate(UserBase):
"""创建用户请求模型"""
pass
class UserUpdate(BaseModel):
class UserUpdate(UserBase):
"""更新用户请求模型"""
openid: Optional[str] = None
unionid: Optional[str] = None
avatar: Optional[str] = None
nickname: Optional[str] = None
class UserInDB(UserBase):
class User(UserBase):
"""用户响应模型"""
id: int
create_time: datetime
class Config:
from_attributes = True
class User(UserInDB):
pass

View File

@ -6,6 +6,8 @@ from typing import List, Optional
from app.models.clothing import Clothing, ClothingCategory
from app.schemas.clothing import ClothingCreate, ClothingUpdate
from app.schemas.clothing import ClothingCategoryCreate, ClothingCategoryUpdate
from app.models.users import User
from fastapi import Depends
# 衣服分类服务函数
async def get_category(db: AsyncSession, category_id: int):
@ -71,10 +73,11 @@ async def get_clothes(db: AsyncSession, skip: int = 0, limit: int = 100):
async def get_clothes_by_category(db: AsyncSession, category_id: int, skip: int = 0, limit: int = 100):
"""根据分类获取衣服"""
query = select(Clothing).order_by(Clothing.create_time.desc())
if category_id > 0:
query = query.filter(Clothing.clothing_category_id == category_id)
result = await db.execute(
select(Clothing)
.filter(Clothing.clothing_category_id == category_id)
.order_by(Clothing.create_time.desc())
query
.offset(skip)
.limit(limit)
)

View File

@ -10,3 +10,4 @@ python-jose[cryptography]==3.3.0
passlib==1.7.4
httpx==0.24.1
dashscope==1.10.0
itsdangerous==2.2.0