167 lines
5.0 KiB
Python
167 lines
5.0 KiB
Python
from fastapi import Depends, HTTPException, status
|
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import selectinload
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.auth import decode_access_token
|
|
from app.db.database import get_db
|
|
from app.db.models import ClassMembership, User
|
|
|
|
# Bearer-token extractor injected into get_current_user. auto_error=True is
# FastAPI's default and is spelled out here only for clarity: requests without a
# usable "Authorization: Bearer ..." header are rejected by FastAPI itself.
security = HTTPBearer(auto_error=True)
|
|
|
|
# Every permission name that can be checked per class. ensure_class_permission
# treats any name outside this set as a programming error (HTTP 500).
CLASS_PERMISSIONS = {
    "member_view",
    "member_manage",
    "committee_manage",
    "announcement_manage",
    "timeline_manage",
    "vote_manage",
    "schedule_manage",
    "resource_manage",
    "assignment_manage",
    "fund_manage",
    "module_manage",
}

# Permissions every teacher holds in any class, on top of membership-scoped grants.
# Equal to CLASS_PERMISSIONS without "fund_manage". It is kept as an explicit list
# (not derived from CLASS_PERMISSIONS) so that adding a new class permission does
# not implicitly grant it to all teachers.
TEACHER_DEFAULT_PERMISSIONS = {
    "member_view",
    "member_manage",
    "committee_manage",
    "announcement_manage",
    "timeline_manage",
    "vote_manage",
    "schedule_manage",
    "resource_manage",
    "assignment_manage",
    "module_manage",
}
|
|
|
|
|
|
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the authenticated user from the request's bearer token.

    Decodes the access token, loads the user whose id is the token's ``sub``
    claim (eager-loading the user's memberships and each membership's class),
    and rejects inactive or disabled accounts.

    Raises:
        HTTPException 401: the token cannot be decoded, has no ``sub`` claim,
            has a ``sub`` that is not an integer id, or names a user that does
            not exist.
        HTTPException 403: the account status is "inactive" or "disabled".
    """
    payload = decode_access_token(credentials.credentials)
    if payload is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token invalid or expired",
        )

    raw_user_id = payload.get("sub")
    if raw_user_id is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token format",
        )
    # A non-numeric "sub" used to escape as an unhandled ValueError (HTTP 500);
    # it is a malformed token and is reported as such.
    try:
        user_id = int(raw_user_id)
    except (TypeError, ValueError):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token format",
        ) from None

    # Eager-load memberships and their classes, presumably so the synchronous
    # permission helpers can read them without lazy loading on an AsyncSession.
    # The chained option loads User.memberships as well, so no separate
    # selectinload(User.memberships) is needed.
    result = await db.execute(
        select(User)
        .options(selectinload(User.memberships).selectinload(ClassMembership.class_))
        .where(User.id == user_id)
    )
    user = result.scalar_one_or_none()

    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found"
        )
    if user.status == "inactive":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Account inactive"
        )
    if user.status == "disabled":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Account disabled"
        )

    return user
|
|
|
|
|
|
def require_role(*roles: str):
    """Build a FastAPI dependency that only admits users whose role is in *roles*.

    The returned coroutine depends on get_current_user. It returns the
    authenticated user when ``user.role`` is one of *roles*, and otherwise
    raises HTTP 403 "Insufficient permissions".
    """

    async def _check(user: User = Depends(get_current_user)) -> User:
        if user.role in roles:
            return user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Insufficient permissions",
        )

    return _check
|
|
|
|
|
|
def get_membership_for_class(user: User, class_id: int | None) -> ClassMembership | None:
    """Return *user*'s membership in class *class_id* via ``User.get_membership``.

    What ``User.get_membership`` returns for ``class_id=None`` or for a class the
    user is not in is defined on the model (not visible here). Callers in this
    module treat a ``None`` result as "no membership".
    """
    found = user.get_membership(class_id)
    return found
|
|
|
|
|
|
def get_active_membership(
    user: User, class_id: int | None = None
) -> ClassMembership | None:
    """Return the user's membership for *class_id*, else their default membership.

    ``user.get_default_membership()`` is only consulted when no membership is
    found for *class_id*. That includes the case where *class_id* names a class
    the user does not belong to.
    """
    scoped = get_membership_for_class(user, class_id)
    return scoped if scoped is not None else user.get_default_membership()
|
|
|
|
|
|
def get_effective_class_permissions(user: User, class_id: int | None = None) -> set[str]:
    """Return the class permissions *user* holds in class *class_id*.

    - super_admin: every permission in CLASS_PERMISSIONS.
    - teacher: TEACHER_DEFAULT_PERMISSIONS plus the permissions granted on the
      relevant membership.
    - any other role: the permissions granted on the relevant membership, minus
      "member_view", "member_manage" and "committee_manage" (these are never
      granted to other roles).

    The relevant membership is the user's membership in *class_id* when a class
    is given. Only when *class_id* is None is the user's default membership used.
    A fresh set is returned on every call.
    """
    if user.role == "super_admin":
        return set(CLASS_PERMISSIONS)

    if class_id is None:
        membership = get_active_membership(user, None)
    else:
        # Deliberately no fallback to the default membership: that would hand the
        # user the grants of a *different* class. Teachers (who pass
        # can_access_class for any class) could otherwise pick up e.g. a
        # "fund_manage" grant from their default class in every class.
        membership = get_membership_for_class(user, class_id)

    scoped_permissions = set(membership.get_class_permissions()) if membership else set()

    if user.role == "teacher":
        return set(TEACHER_DEFAULT_PERMISSIONS) | scoped_permissions

    return scoped_permissions - {"member_view", "member_manage", "committee_manage"}
|
|
|
|
|
|
def can_access_class(user: User, class_id: int | None) -> bool:
    """Tell whether *user* may access class *class_id*.

    A missing class (``None``) is never accessible. Super admins and teachers
    may access any class; everyone else needs a membership in it.
    """
    if class_id is None:
        return False
    is_privileged = user.role in {"super_admin", "teacher"}
    return is_privileged or get_membership_for_class(user, class_id) is not None
|
|
|
|
|
|
def resolve_class_id_for_user(user: User, requested_class_id: int | None) -> int | None:
    """Choose the class id a request from *user* should operate on.

    Super admins and teachers get *requested_class_id* back unchanged (even
    None). Any other user gets it back only if they can access that class.
    Otherwise they silently get the class of their default membership, or None
    when they have no default membership.
    """
    is_privileged = user.role in {"super_admin", "teacher"}
    if is_privileged or (
        requested_class_id is not None and can_access_class(user, requested_class_id)
    ):
        return requested_class_id

    fallback = user.get_default_membership()
    if fallback:
        return fallback.class_id
    return None
|
|
|
|
|
|
def ensure_class_access(user: User, class_id: int | None):
    """Raise HTTP 403 "Access denied for this class" unless *user* can access *class_id*.

    The decision is delegated to can_access_class, so a None class_id is
    always denied.
    """
    if can_access_class(user, class_id):
        return
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Access denied for this class",
    )
|
|
|
|
|
|
def ensure_class_permission(user: User, permission: str, class_id: int | None = None):
    """Raise unless *user* holds *permission*, in class *class_id* when one is given.

    Raises:
        HTTPException 500: *permission* is not a name in CLASS_PERMISSIONS
            (a caller bug, reported as "Unknown permission: ...").
        HTTPException 403: *class_id* is given and the user may not access it
            ("Access denied for this class"), or the user lacks *permission*
            ("Insufficient permissions"). Super admins always pass the
            permission check.
    """
    if permission not in CLASS_PERMISSIONS:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Unknown permission: {permission}",
        )

    if class_id is not None:
        ensure_class_access(user, class_id)

    # Short-circuit: super admins skip computing effective permissions entirely.
    if user.role == "super_admin" or permission in get_effective_class_permissions(
        user, class_id
    ):
        return

    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="Insufficient permissions",
    )
|