diff --git a/app/api/deps.py b/app/api/deps.py index f200d30..9966cbb 100644 --- a/app/api/deps.py +++ b/app/api/deps.py @@ -1,35 +1,61 @@ -from fastapi import Depends, HTTPException, status +from typing import Generator, Optional +from fastapi import Depends, HTTPException, status, Request from fastapi.security import OAuth2PasswordBearer -from sqlalchemy.ext.asyncio import AsyncSession from jose import jwt +from pydantic import ValidationError +from sqlalchemy.ext.asyncio import AsyncSession +from app.core import security from app.core.config import settings +from app.models.users import User +from app.schemas.user import User as UserSchema from app.core.security import verify_token -from app.db.database import get_db -from app.services import user as user_service - +from app.db.database import AsyncSessionLocal +from sqlalchemy import select +import logging # OAuth2密码Bearer - JWT token的位置 -oauth2_scheme = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/auth/login") +reusable_oauth2 = OAuth2PasswordBearer( + tokenUrl=f"{settings.API_V1_STR}/login/access-token" +) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +async def get_db(): + """获取数据库会话""" + async with AsyncSessionLocal() as session: + try: + yield session + finally: + await session.close() -# 获取当前用户的依赖项 async def get_current_user( - db: AsyncSession = Depends(get_db), - token: str = Depends(oauth2_scheme) -): - credentials_exception = HTTPException( + request: Request, + db: AsyncSession = Depends(get_db) +) -> UserSchema: + """获取当前登录用户""" + + UNAUTHORIZED = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="无效的认证凭证", + detail="未提供认证信息", headers={"WWW-Authenticate": "Bearer"}, ) + + auth_header = request.headers.get("Authorization") + if auth_header: + access_token = auth_header.split(" ")[1] + else: + access_token = request.session.get("access_token") + if not access_token: + raise UNAUTHORIZED - # 验证token - openid = verify_token(token) - if not openid: - raise credentials_exception - - # 根据openid获取用户 - user = await user_service.get_user_by_openid(db, openid=openid) - if user is None: - raise credentials_exception - - return user \ No newline at end of file + try: + sub = verify_token(access_token) + if sub: + user = (await db.execute(select(User).filter(User.openid == sub))).scalars().first() + if user: + return user + raise UNAUTHORIZED + except Exception as e: + logger.error(f"获取当前登录用户失败: {e}") + raise UNAUTHORIZED diff --git a/app/api/v1/auth.py b/app/api/v1/auth.py index 08d67f8..84975bb 100644 --- a/app/api/v1/auth.py +++ b/app/api/v1/auth.py @@ -1,25 +1,51 @@ from fastapi import APIRouter, Depends, HTTPException, status, Response from sqlalchemy.ext.asyncio import AsyncSession import logging +from fastapi.security import OAuth2PasswordRequestForm +from app.api import deps +from app.core import security +from app.core.config import settings +from app.schemas.response import StandardResponse from app.schemas.auth import WechatLogin, LoginResponse from app.schemas.user import UserCreate from app.services import wechat as wechat_service from app.services import user as user_service -from app.core.security import create_access_token -from app.db.database import get_db from app.core.exceptions import BusinessError -from app.schemas.response import StandardResponse - +from app.schemas.user import User +from fastapi import Request # 创建日志记录器 logger = logging.getLogger(__name__) - +logger.setLevel(logging.DEBUG) router = APIRouter() +@router.get("/login/{user_id}", tags=["auth"]) +async def login( + user_id: int, + request: Request, + db: AsyncSession = Depends(deps.get_db) +): + """ + 登录接口 + """ + user = await user_service.get_user(db, user_id=user_id) + logger.info(f"登录用户: {user.openid}") + if not user: + raise BusinessError("用户不存在", code=404) + + # 创建访问令牌 - 使用openid作为标识 + access_token = security.create_access_token(subject=user.openid) + + # 将access_token存储在session中 + request.session["access_token"] = access_token + + return StandardResponse(code=200, message="登录成功", data=User.model_validate(user)) + @router.post("/login/wechat", tags=["auth"]) async def wechat_login( login_data: WechatLogin, - db: AsyncSession = Depends(get_db) + request: Request, + db: AsyncSession = Depends(deps.get_db) ): """ 微信登录接口 @@ -53,7 +79,10 @@ async def wechat_login( logger.info(f"现有用户登录: id={user.id}") # 创建访问令牌 - 使用openid作为标识 - access_token = create_access_token(subject=openid) + access_token = security.create_access_token(subject=openid) + + # 将access_token存储在session中 + request.session["access_token"] = access_token # 创建响应对象 response_data = LoginResponse( @@ -68,8 +97,4 @@ async def wechat_login( except Exception as e: # 记录异常 logger.error(f"登录异常: {str(e)}", exc_info=True) - - # 重新抛出业务异常或转换为业务异常 - if isinstance(e, BusinessError): - raise - raise BusinessError(f"登录失败: {str(e)}", code=500) \ No newline at end of file + raise BusinessError(f"登录失败: {str(e)}", code=500) \ No newline at end of file diff --git a/app/api/v1/clothing.py b/app/api/v1/clothing.py index c57c5b5..7638b52 100644 --- a/app/api/v1/clothing.py +++ b/app/api/v1/clothing.py @@ -12,6 +12,7 @@ from app.api.deps import get_current_user from app.models.users import User as UserModel from app.core.exceptions import BusinessError from app.schemas.response import StandardResponse +from app.schemas.clothing import ClothingCategory import logging # 创建日志记录器 @@ -31,10 +32,8 @@ async def create_category( 需要JWT令牌认证 """ - result = await clothing_service.create_category(db=db, category=category) - logger.info(f"创建分类成功: id={result.id}, name={result.name}") - # 手动返回标准响应格式 - return StandardResponse(code=200, data=result) + category = await clothing_service.create_category(db=db, category=category) + return StandardResponse(code=200, data=ClothingCategory.model_validate(category)) @router.get("/categories", tags=["clothing-categories"]) async def read_categories( @@ -43,10 +42,8 @@ async def read_categories( db: AsyncSession = Depends(get_db) ): """获取所有衣服分类""" - result = await clothing_service.get_categories(db=db, skip=skip, limit=limit) - logger.info(f"获取分类列表: 共{len(result)}条") - # 手动返回标准响应格式 - return StandardResponse(code=200, data=result) + categories = await clothing_service.get_categories(db=db, skip=skip, limit=limit) + return StandardResponse(code=200, data=[ClothingCategory.model_validate(category) for category in categories]) @router.get("/categories/{category_id}", tags=["clothing-categories"]) async def read_category( @@ -54,11 +51,11 @@ async def read_category( db: AsyncSession = Depends(get_db) ): """获取单个衣服分类""" - db_category = await clothing_service.get_category(db, category_id=category_id) - if db_category is None: + category = await clothing_service.get_category(db, category_id=category_id) + if category is None: raise BusinessError("分类不存在", code=404) # 手动返回标准响应格式 - return StandardResponse(code=200, data=db_category) + return StandardResponse(code=200, data=ClothingCategory.model_validate(category)) @router.put("/categories/{category_id}", tags=["clothing-categories"]) async def update_category( @@ -72,15 +69,15 @@ async def update_category( 需要JWT令牌认证 """ - db_category = await clothing_service.update_category( + category = await clothing_service.update_category( db=db, category_id=category_id, category_update=category ) - if db_category is None: + if category is None: raise BusinessError("分类不存在", code=404) # 手动返回标准响应格式 - return StandardResponse(code=200, data=db_category) + return StandardResponse(code=200, data=ClothingCategory.model_validate(category)) @router.delete("/categories/{category_id}", tags=["clothing-categories"]) async def delete_category( @@ -93,11 +90,11 @@ async def delete_category( 需要JWT令牌认证 """ - db_category = await clothing_service.delete_category(db=db, category_id=category_id) - if db_category is None: + category = await clothing_service.delete_category(db=db, category_id=category_id) + if category is None: raise BusinessError("分类不存在", code=404) # 手动返回标准响应格式 - return StandardResponse(code=200, data=db_category) + return StandardResponse(code=200, data=ClothingCategory.model_validate(category)) # 衣服API @router.post("", tags=["clothing"]) @@ -116,10 +113,10 @@ async def create_clothing( if category is None: raise BusinessError("指定的分类不存在", code=404) - result = await clothing_service.create_clothing(db=db, clothing=clothing) - logger.info(f"创建衣服成功: id={result.id}") + clothing = await clothing_service.create_clothing(db=db, clothing=clothing) + logger.info(f"创建衣服成功: id={clothing.id}") # 手动返回标准响应格式 - return StandardResponse(code=200, data=result) + return StandardResponse(code=200, data=Clothing.model_validate(clothing)) @router.get("", tags=["clothing"]) async def read_clothes( @@ -128,9 +125,9 @@ async def read_clothes( db: AsyncSession = Depends(get_db) ): """获取所有衣服""" - result = await clothing_service.get_clothes(db=db, skip=skip, limit=limit) + clothes = await clothing_service.get_clothes(db=db, skip=skip, limit=limit) # 手动返回标准响应格式 - return StandardResponse(code=200, data=result) + return StandardResponse(code=200, data=[Clothing.model_validate(clothing) for clothing in clothes]) @router.get("/by-category/{category_id}", tags=["clothing"]) async def read_clothes_by_category( @@ -145,14 +142,14 @@ async def read_clothes_by_category( if category is None: raise BusinessError("分类不存在", code=404) - result = await clothing_service.get_clothes_by_category( + clothes = await clothing_service.get_clothes_by_category( db=db, category_id=category_id, skip=skip, limit=limit ) # 手动返回标准响应格式 - return StandardResponse(code=200, data=result) + return StandardResponse(code=200, data=[Clothing.model_validate(clothing) for clothing in clothes]) @router.get("/{clothing_id}", tags=["clothing"]) async def read_clothing( @@ -160,11 +157,11 @@ async def read_clothing( db: AsyncSession = Depends(get_db) ): """获取单个衣服""" - db_clothing = await clothing_service.get_clothing(db, clothing_id=clothing_id) - if db_clothing is None: + clothing = await clothing_service.get_clothing(db, clothing_id=clothing_id) + if clothing is None: raise BusinessError("衣服不存在", code=404) # 手动返回标准响应格式 - return StandardResponse(code=200, data=db_clothing) + return StandardResponse(code=200, data=Clothing.model_validate(clothing)) @router.put("/{clothing_id}", tags=["clothing"]) async def update_clothing( @@ -184,15 +181,15 @@ async def update_clothing( if category is None: raise BusinessError("指定的分类不存在", code=404) - db_clothing = await clothing_service.update_clothing( + clothing = await clothing_service.update_clothing( db=db, clothing_id=clothing_id, clothing_update=clothing ) - if db_clothing is None: + if clothing is None: raise BusinessError("衣服不存在", code=404) # 手动返回标准响应格式 - return StandardResponse(code=200, data=db_clothing) + return StandardResponse(code=200, data=Clothing.model_validate(clothing)) @router.delete("/{clothing_id}", tags=["clothing"]) async def delete_clothing( @@ -205,8 +202,12 @@ async def delete_clothing( 需要JWT令牌认证 """ - db_clothing = await clothing_service.delete_clothing(db=db, clothing_id=clothing_id) - if db_clothing is None: + clothing = await clothing_service.delete_clothing(db=db, clothing_id=clothing_id) + if clothing is None: raise BusinessError("衣服不存在", code=404) + + if clothing.user_id != current_user.id: + raise BusinessError("没有权限删除此衣服", code=403) + # 手动返回标准响应格式 - return StandardResponse(code=200, data=db_clothing) \ No newline at end of file + return StandardResponse(code=200, data=Clothing.model_validate(clothing)) \ No newline at end of file diff --git a/app/api/v1/person_images.py b/app/api/v1/person_images.py index 20a1b5a..89cfc6e 100644 --- a/app/api/v1/person_images.py +++ b/app/api/v1/person_images.py @@ -7,71 +7,52 @@ from app.schemas.person_image import ( PersonImageCreate, PersonImageUpdate ) -from app.services import person_image as user_image_service - +from app.services import person_image as person_image_service +from app.models.users import User +from app.schemas.response import StandardResponse router = APIRouter() -@router.get("", response_model=List[PersonImage]) +@router.get("", tags=["person_images"]) async def get_person_images( db: AsyncSession = Depends(deps.get_db), - current_user = Depends(deps.get_current_user), + current_user: User = Depends(deps.get_current_user), skip: int = 0, limit: int = 100 ): """获取当前用户的所有人物形象""" - return await user_image_service.get_person_images_by_user( + + images = await person_image_service.get_person_images_by_user( db=db, user_id=current_user.id, skip=skip, limit=limit ) + return StandardResponse(code=200, message="获取人物形象成功", data=images) + -@router.post("", response_model=PersonImage, status_code=status.HTTP_201_CREATED) +@router.post("", response_model=PersonImage, tags=["person_images"]) async def create_person_image( *, db: AsyncSession = Depends(deps.get_db), - current_user = Depends(deps.get_current_user), - image_in: PersonImageCreate + current_user: User = Depends(deps.get_current_user), + image_in: PersonImage ): """创建新的人物形象""" image_in.user_id = current_user.id - return await user_image_service.create_person_image(db=db, image=image_in) + image = await person_image_service.create_person_image(db=db, image=image_in) -@router.put("/{image_id}", response_model=PersonImage) -async def update_person_image( - *, - db: AsyncSession = Depends(deps.get_db), - current_user = Depends(deps.get_current_user), - image_id: int, - image_in: PersonImageUpdate -): - """更新人物形象""" - image = await user_image_service.get_person_image(db=db, image_id=image_id) - if not image: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="人物形象不存在" - ) - if image.user_id != current_user.id: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="没有权限修改此人物形象" - ) - return await user_image_service.update_person_image( - db=db, - image_id=image_id, - image=image_in - ) + return StandardResponse(code=200, message="创建人物形象成功", data=image) -@router.delete("/{image_id}", status_code=status.HTTP_204_NO_CONTENT) + +@router.delete("/{image_id}", tags=["person_images"]) async def delete_person_image( *, db: AsyncSession = Depends(deps.get_db), - current_user = Depends(deps.get_current_user), + current_user: User = Depends(deps.get_current_user), image_id: int ): """删除人物形象""" - image = await user_image_service.get_person_image(db=db, image_id=image_id) + image = await person_image_service.get_person_image(db=db, image_id=image_id) if not image: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -82,5 +63,6 @@ async def delete_person_image( status_code=status.HTTP_403_FORBIDDEN, detail="没有权限删除此人物形象" ) - await user_image_service.delete_person_image(db=db, image_id=image_id) - return None \ No newline at end of file + await person_image_service.delete_person_image(db=db, image_id=image_id) + + return StandardResponse(code=200, message="删除人物形象成功") \ No newline at end of file diff --git a/app/main.py b/app/main.py index 490b77e..3ca5c34 100644 --- a/app/main.py +++ b/app/main.py @@ -7,6 +7,7 @@ from app.api.v1.api import api_router from app.core.exceptions import add_exception_handlers from app.core.middleware import add_response_middleware from app.schemas.response import StandardResponse +from starlette.middleware.sessions import SessionMiddleware # 配置日志 logging.basicConfig( @@ -30,7 +31,8 @@ app = FastAPI( title=settings.PROJECT_NAME, description=settings.PROJECT_DESCRIPTION, version=settings.PROJECT_VERSION, - lifespan=lifespan + lifespan=lifespan, + openapi_url=f"{settings.API_V1_STR}/openapi.json" ) logger.info("开始配置应用中间件和路由...") @@ -38,13 +40,21 @@ logger.info("开始配置应用中间件和路由...") # 配置CORS app.add_middleware( CORSMiddleware, - allow_origins=settings.cors_origins, + allow_origins=settings.BACKEND_CORS_ORIGINS, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) logger.info("CORS中间件已添加") +# 添加Session中间件 +app.add_middleware( + SessionMiddleware, + secret_key=settings.SECRET_KEY, + session_cookie="session", + max_age=1800 # 30分钟 +) + # 添加异常处理器 add_exception_handlers(app) logger.info("异常处理器已添加") diff --git a/app/models/clothing.py b/app/models/clothing.py index 06f3b5c..eacd882 100644 --- a/app/models/clothing.py +++ b/app/models/clothing.py @@ -23,11 +23,13 @@ class Clothing(Base): id = Column(Integer, primary_key=True, autoincrement=True, index=True) clothing_category_id = Column(Integer, ForeignKey("clothing_category.id"), nullable=False, index=True, comment="分类ID") + user_id = Column(Integer, ForeignKey("users.id"), nullable=False) image_url = Column(String(500), nullable=False, comment="图片URL") create_time = Column(DateTime, default=func.now(), comment="创建时间") # 关系 category = relationship("ClothingCategory", back_populates="clothes") + user = relationship("User", back_populates="clothings") def __repr__(self): return f"" \ No newline at end of file diff --git a/app/models/users.py b/app/models/users.py index 72ea699..d79bbdf 100644 --- a/app/models/users.py +++ b/app/models/users.py @@ -2,6 +2,7 @@ from sqlalchemy import Column, Integer, String, DateTime from sqlalchemy.sql import func from sqlalchemy.orm import relationship from app.db.database import Base +from pydantic import BaseModel class User(Base): """用户数据模型""" @@ -16,6 +17,18 @@ class User(Base): # 关系 person_images = relationship("PersonImage", back_populates="user", cascade="all, delete-orphan") + clothings = relationship("Clothing", back_populates="user", cascade="all, delete-orphan") def __repr__(self): - return f"" \ No newline at end of file + return f"" + + def to_dict(self): + """将模型转换为字典""" + return { + "id": self.id, + "openid": self.openid, + "unionid": self.unionid, + "avatar": self.avatar, + "nickname": self.nickname, + "create_time": self.create_time + } \ No newline at end of file diff --git a/app/schemas/person_image.py b/app/schemas/person_image.py index 5880e14..3d13944 100644 --- a/app/schemas/person_image.py +++ b/app/schemas/person_image.py @@ -4,6 +4,7 @@ from datetime import datetime class PersonImageBase(BaseModel): """人物形象基础模型""" + user_id: int image_url: str is_default: bool = False diff --git a/app/schemas/user.py b/app/schemas/user.py index e7d6120..91a8de9 100644 --- a/app/schemas/user.py +++ b/app/schemas/user.py @@ -3,24 +3,27 @@ from datetime import datetime from typing import Optional class UserBase(BaseModel): + """用户基础模型""" openid: str unionid: Optional[str] = None avatar: Optional[str] = None nickname: Optional[str] = None class UserCreate(UserBase): + """创建用户请求模型""" pass -class UserUpdate(BaseModel): +class UserUpdate(UserBase): + """更新用户请求模型""" + openid: Optional[str] = None + unionid: Optional[str] = None avatar: Optional[str] = None nickname: Optional[str] = None -class UserInDB(UserBase): +class User(UserBase): + """用户响应模型""" id: int create_time: datetime - - class Config: - from_attributes = True -class User(UserInDB): - pass \ No newline at end of file + class Config: + from_attributes = True \ No newline at end of file diff --git a/app/services/clothing.py b/app/services/clothing.py index 9da3df1..5435bdc 100644 --- a/app/services/clothing.py +++ b/app/services/clothing.py @@ -6,6 +6,8 @@ from typing import List, Optional from app.models.clothing import Clothing, ClothingCategory from app.schemas.clothing import ClothingCreate, ClothingUpdate from app.schemas.clothing import ClothingCategoryCreate, ClothingCategoryUpdate +from app.models.users import User +from fastapi import Depends # 衣服分类服务函数 async def get_category(db: AsyncSession, category_id: int): @@ -71,10 +73,11 @@ async def get_clothes(db: AsyncSession, skip: int = 0, limit: int = 100): async def get_clothes_by_category(db: AsyncSession, category_id: int, skip: int = 0, limit: int = 100): """根据分类获取衣服""" + query = select(Clothing).order_by(Clothing.create_time.desc()) + if category_id > 0: + query = query.filter(Clothing.clothing_category_id == category_id) result = await db.execute( - select(Clothing) - .filter(Clothing.clothing_category_id == category_id) - .order_by(Clothing.create_time.desc()) + query .offset(skip) .limit(limit) ) diff --git a/requirements.txt b/requirements.txt index 1a5221a..0ca6872 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,5 @@ greenlet==2.0.2 python-jose[cryptography]==3.3.0 passlib==1.7.4 httpx==0.24.1 -dashscope==1.10.0 \ No newline at end of file +dashscope==1.10.0 +itsdangerous==2.2.0 \ No newline at end of file