diff --git a/app/api/deps.py b/app/api/deps.py index c9e63ac..7f10177 100644 --- a/app/api/deps.py +++ b/app/api/deps.py @@ -2,7 +2,7 @@ from fastapi import Depends, HTTPException, Header, Cookie from typing import Optional from sqlalchemy.orm import Session from app.models.database import get_db -from app.models.user import UserDB +from app.models.user import UserDB, UserRole from app.core.security import verify_token async def get_current_user( @@ -32,6 +32,14 @@ async def get_current_user( async def get_admin_user( current_user: UserDB = Depends(get_current_user) ) -> UserDB: - if not current_user.is_admin: + if UserRole.ADMIN not in current_user.roles: raise HTTPException(status_code=403, detail="需要管理员权限") + return current_user + +async def get_deliveryman_user( + current_user: UserDB = Depends(get_current_user) +) -> UserDB: + """验证配送员权限""" + if UserRole.DELIVERYMAN not in current_user.roles: + raise HTTPException(status_code=403, detail="需要配送员权限") return current_user \ No newline at end of file diff --git a/app/api/endpoints/user.py b/app/api/endpoints/user.py index 4a1bf0d..c0a4574 100644 --- a/app/api/endpoints/user.py +++ b/app/api/endpoints/user.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, HTTPException, Depends, Response from sqlalchemy.orm import Session -from app.models.user import UserLogin, UserInfo, VerifyCodeRequest, UserDB, UserUpdate -from app.api.deps import get_current_user +from app.models.user import UserLogin, UserInfo, VerifyCodeRequest, UserDB, UserUpdate, UserRole +from app.api.deps import get_current_user, get_admin_user from app.models.database import get_db import random import string @@ -13,6 +13,7 @@ from datetime import timedelta from app.core.security import create_access_token, set_jwt_cookie, clear_jwt_cookie from app.core.response import success_response, error_response, ResponseModel from pydantic import BaseModel, Field +from typing import List router = APIRouter() @@ -86,7 +87,8 @@ async def login( if not user: user = UserDB( username=f"user_{phone[-4:]}", - phone=phone + phone=phone, + roles=[UserRole.USER] ) db.add(user) db.commit() @@ -189,6 +191,40 @@ async def update_user_info( message="用户信息更新成功", data=UserInfo.model_validate(current_user) ) + except Exception as e: + db.rollback() + return error_response(code=500, message=f"更新失败: {str(e)}") + +@router.put("/roles", response_model=ResponseModel) +async def update_user_roles( + user_id: int, + roles: List[UserRole], + db: Session = Depends(get_db), + admin: UserDB = Depends(get_admin_user) +): + """更新用户角色(管理员)""" + user = db.query(UserDB).filter(UserDB.userid == user_id).first() + if not user: + return error_response(code=404, message="用户不存在") + + # 确保至少有一个角色 + if not roles: + return error_response(code=400, message="用户必须至少有一个角色") + + # 确保普通用户角色始终存在 + if UserRole.USER not in roles: + roles.append(UserRole.USER) + + # 更新角色 + user.roles = list(set(roles)) # 去重 + + try: + db.commit() + db.refresh(user) + return success_response( + message="用户角色更新成功", + data=UserInfo.model_validate(user) + ) except Exception as e: db.rollback() return error_response(code=500, message=f"更新失败: {str(e)}") \ No newline at end of file diff --git a/app/models/user.py b/app/models/user.py index 08d7899..6a68cb7 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -1,8 +1,15 @@ from sqlalchemy import Column, String, DateTime,Integer, Boolean from sqlalchemy.sql import func +from sqlalchemy.dialects.mysql import JSON from pydantic import BaseModel, Field from .database import Base -from typing import Optional +from typing import Optional, List +import enum + +class UserRole(str, enum.Enum): + USER = "user" + DELIVERYMAN = "deliveryman" + ADMIN = "admin" # 数据库模型 class UserDB(Base): @@ -11,7 +18,8 @@ class UserDB(Base): userid = Column(Integer, primary_key=True,autoincrement=True, index=True) username = Column(String(50)) phone = Column(String(11), unique=True, index=True) - is_admin = Column(Boolean, default=False) + avatar = Column(String(200), nullable=True) # 头像URL地址 + roles = Column(JSON, default=lambda: [UserRole.USER]) # 存储角色列表 create_time = Column(DateTime(timezone=True), server_default=func.now()) update_time = Column(DateTime(timezone=True), onupdate=func.now()) @@ -24,7 +32,8 @@ class UserInfo(BaseModel): userid: int username: str phone: str - is_admin: bool + avatar: Optional[str] = None + roles: List[UserRole] class Config: from_attributes = True