diff --git a/app/api/endpoints/user.py b/app/api/endpoints/user.py index 8e83f61..8988a28 100644 --- a/app/api/endpoints/user.py +++ b/app/api/endpoints/user.py @@ -302,17 +302,18 @@ async def password_login( if not verify_password(login_data.password, user.password): return error_response(code=401, message="密码错误") - - if user.roles not in [UserRole.DELIVERYMAN, UserRole.MERCHANT, UserRole.ADMIN]: - return error_response(code=401, message="只有配送员、商家和管理员可以登录") - if user.roles == UserRole.MERCHANT: + if login_data.is_admin: + if UserRole.ADMIN not in user.roles: + return error_response(code=401, message="管理员账户,请先设置管理员角色") + + if UserRole.MERCHANT in user.roles: # 检查是否有商家设置了当前用户 id merchant = db.query(MerchantDB).filter(MerchantDB.user_id == user.userid).first() if not merchant: - return error_response(code=401, message="商家账户,请先关联商家") + return error_response(code=401, message="商家账户,请先关联商 家") - if user.roles == UserRole.DELIVERYMAN and not user.community_id: + if UserRole.DELIVERYMAN in user.roles and not user.community_id: return error_response(code=401, message="配送员账户,请先设置归属小区") # 生成访问令牌 diff --git a/app/models/user.py b/app/models/user.py index 9fa67b3..8b41b66 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -92,7 +92,8 @@ class UserUpdate(BaseModel): class UserPasswordLogin(BaseModel): phone: str = Field(..., pattern="^1[3-9]\d{9}$") - password: str = Field(..., min_length=6, max_length=20) + password: str = Field(..., min_length=6, max_length=20) + is_admin: bool = Field(default=False) class ChangePasswordRequest(BaseModel): verify_code: str = Field(..., min_length=6, max_length=6)