211 lines
6.2 KiB
Python
211 lines
6.2 KiB
Python
import secrets
|
|
|
|
from sqlalchemy import select, func
|
|
from sqlalchemy.orm import selectinload
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.auth import hash_password
|
|
from app.db.models import ClassMembership, Class_, User
|
|
|
|
|
|
def generate_invite_code(length: int = 8) -> str:
    """Return a random invite code made of ``length`` characters.

    Characters are drawn with :func:`secrets.choice`, so codes are not
    predictable. The alphabet is upper-case letters and digits without
    ``I``, ``O``, ``0`` and ``1`` (presumably to avoid look-alike glyphs).
    A non-positive ``length`` yields an empty string.
    """
    alphabet = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
    picks = [secrets.choice(alphabet) for _ in range(length)]
    return "".join(picks)
|
|
|
|
|
|
async def ensure_invite_code(db: AsyncSession, class_id: int) -> str:
    """Return the class's invite code, creating one first if it has none.

    Returns an empty string when no class with ``class_id`` exists. When the
    stored code is falsy (``None`` or empty), a new code is generated,
    committed and re-read from the database before being returned.
    """
    lookup = select(Class_).where(Class_.id == class_id)
    target = (await db.execute(lookup)).scalar_one_or_none()
    if target is None:
        return ""

    # Fast path: an existing code is returned untouched, with no commit.
    if target.invite_code:
        return target.invite_code

    target.invite_code = generate_invite_code()
    await db.commit()
    await db.refresh(target)
    return target.invite_code
|
|
|
|
|
|
async def regenerate_invite_code(db: AsyncSession, class_id: int) -> str:
    """Overwrite the class's invite code with a freshly generated one.

    The previous code (if any) is replaced unconditionally. The change is
    committed and the new code is returned. Returns an empty string when no
    class with ``class_id`` exists.
    """
    lookup = select(Class_).where(Class_.id == class_id)
    target = (await db.execute(lookup)).scalar_one_or_none()
    if target is None:
        return ""

    target.invite_code = generate_invite_code()
    await db.commit()
    await db.refresh(target)
    return target.invite_code
|
|
|
|
|
|
def _clean_import_value(value: object) -> str:
    """Coerce one imported cell to a stripped string (``None`` becomes ``""``)."""
    # Non-string cells (e.g. numeric ids from a spreadsheet parser) are
    # stringified instead of crashing on ``.strip()``.
    return "" if value is None else str(value).strip()


def _normalize_import_entries(entries: list[dict]) -> list[tuple[str, str]]:
    """Return cleaned ``(student_id, name)`` pairs in input order.

    Rows with a blank/missing id or name are dropped, and only the first
    valid row for each student id is kept.
    """
    pairs: dict[str, str] = {}  # dict keeps insertion order -> first row wins
    for entry in entries:
        sid = _clean_import_value(entry.get("student_id"))
        name = _clean_import_value(entry.get("name"))
        if sid and name and sid not in pairs:
            pairs[sid] = name
    return list(pairs.items())


async def import_members(
    db: AsyncSession, class_id: int, entries: list[dict]
) -> int:
    """Enroll the students listed in ``entries`` into class ``class_id``.

    Each entry is read for its ``"student_id"`` and ``"name"`` values. Rows
    where either is missing or blank are skipped, and only the first row for
    a given student id is used. For each remaining id:

    * an existing user that is already a member of the class is left alone;
    * an existing user that is not a member gets a ``"student"`` membership,
      and its name is overwritten with the imported one only when the user's
      status is ``"inactive"``;
    * an unknown id gets a new ``"inactive"`` placeholder user (synthetic
      email, password hash of a random discarded token) plus a ``"student"``
      membership.

    Everything is committed once at the end.

    Returns:
        The number of memberships created.
    """
    incoming = _normalize_import_entries(entries)
    if not incoming:
        return 0

    student_ids = {sid for sid, _ in incoming}
    result = await db.execute(
        select(User)
        # The chained option also loads User.memberships itself.
        .options(selectinload(User.memberships).selectinload(ClassMembership.class_))
        .where(User.student_id.in_(student_ids))
    )
    existing_users = {
        user.student_id: user
        for user in result.scalars().unique().all()
        if user.student_id
    }

    count = 0
    new_users: list[User] = []
    for sid, name in incoming:
        user = existing_users.get(sid)
        if user is None:
            new_users.append(
                User(
                    email=f"inactive+{class_id}.{sid}@member.local",
                    # Keep inactive imports compatible with older schemas where
                    # password_hash is NOT NULL.
                    # NOTE(review): hash_password runs once per new user; if it is
                    # a deliberately slow KDF this dominates large imports.
                    password_hash=hash_password(secrets.token_urlsafe(24)),
                    name=name,
                    student_id=sid,
                    role="student",
                    status="inactive",
                )
            )
            continue

        if user.get_membership(class_id) is not None:
            continue
        if user.status == "inactive":
            user.name = name
        db.add(
            ClassMembership(
                user_id=user.id,
                class_id=class_id,
                membership_role="student",
            )
        )
        count += 1

    if new_users:
        db.add_all(new_users)
        # One flush assigns primary keys to every new user, so the
        # memberships below can reference them (previously one flush per user).
        await db.flush()
        for new_user in new_users:
            db.add(
                ClassMembership(
                    user_id=new_user.id,
                    class_id=class_id,
                    membership_role="student",
                )
            )
        count += len(new_users)

    await db.commit()
    return count
|
|
|
|
|
|
async def get_inactive_members(
    db: AsyncSession, class_id: int, page: int = 1, page_size: int = 50
) -> tuple[list[User], int]:
    """Return one page of a class's inactive members plus the overall total.

    Args:
        db: Async session used for both queries.
        class_id: Class whose memberships are searched.
        page: 1-based page number; values below 1 are treated as page 1
            (a negative OFFSET would otherwise be rejected by the database).
        page_size: Maximum number of users returned for the page.

    Returns:
        ``(users, total)``. ``users`` holds the inactive members on the
        requested page, ordered by student id (ties broken by user id), with
        their memberships and each membership's class eagerly loaded.
        ``total`` is the number of distinct inactive members in the class.
    """
    filters = (
        ClassMembership.class_id == class_id,
        User.status == "inactive",
    )

    # DISTINCT keeps the total consistent with the de-duplicated rows
    # returned by ``.unique()`` below.
    count_query = (
        select(func.count(User.id.distinct()))
        .join(ClassMembership)
        .where(*filters)
    )
    total = (await db.execute(count_query)).scalar() or 0

    page = max(page, 1)
    query = (
        select(User)
        # The chained option also loads User.memberships itself.
        .options(selectinload(User.memberships).selectinload(ClassMembership.class_))
        .join(ClassMembership)
        .where(*filters)
        # User.id tie-breaker makes pagination deterministic when student ids repeat.
        .order_by(User.student_id, User.id)
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    result = await db.execute(query)
    return list(result.scalars().unique().all()), total
|
|
|
|
|
|
async def validate_registration(
    db: AsyncSession, invite_code: str, student_id: str
) -> tuple[User, int] | None:
    """Match an invite code and student id to a pending (inactive) member.

    Looks up the class whose invite code equals ``invite_code`` and, within
    that class, an inactive user with ``student_id``. On success, calls
    ``user.set_active_membership(class_id)`` and returns
    ``(user, class_id)``; this function itself does not commit. Returns
    ``None`` when either input is blank, the code matches no class, or no
    matching inactive member exists.
    """
    invite_code = (invite_code or "").strip()
    # Imported student ids are stored stripped, so strip the lookup key too.
    student_id = (student_id or "").strip()
    # A blank code must never match: ``== None`` renders as ``IS NULL`` and
    # ``== ""`` matches empty codes, i.e. classes that never got a real code.
    if not invite_code or not student_id:
        return None

    class_result = await db.execute(
        select(Class_).where(Class_.invite_code == invite_code)
    )
    class_ = class_result.scalar_one_or_none()
    if class_ is None:
        return None

    user_result = await db.execute(
        select(User)
        # The chained option also loads User.memberships itself.
        .options(selectinload(User.memberships).selectinload(ClassMembership.class_))
        .join(ClassMembership)
        .where(
            ClassMembership.class_id == class_.id,
            User.student_id == student_id,
            User.status == "inactive",
        )
    )
    user = user_result.scalar_one_or_none()
    if user is None:
        return None

    user.set_active_membership(class_.id)
    return user, class_.id
|
|
|
|
|
|
async def delete_inactive_member(db: AsyncSession, class_id: int, user_id: int) -> bool:
    """Remove an inactive user's membership in ``class_id``.

    The user row is deleted as well when that membership was the user's only
    one. Returns ``False`` without deleting anything when the user does not
    exist, is not ``"inactive"``, or is not a member of the class. Otherwise
    it commits and returns ``True``.
    """
    lookup = (
        select(User)
        .options(selectinload(User.memberships))
        .where(User.id == user_id)
    )
    candidate = (await db.execute(lookup)).scalar_one_or_none()
    if candidate is None or candidate.status != "inactive":
        return False

    target_membership = candidate.get_membership(class_id)
    if target_membership is None:
        return False

    # Decide up front, while the loaded collection still reflects every class.
    keeps_other_classes = any(
        item.class_id != class_id for item in candidate.memberships
    )

    await db.delete(target_membership)
    await db.flush()
    if not keeps_other_classes:
        await db.delete(candidate)
    await db.commit()
    return True
|
|
|
|
|
|
async def clear_inactive_members(db: AsyncSession, class_id: int) -> int:
    """Remove every inactive member of a class and return how many were removed.

    Each user is handled by ``delete_inactive_member``, which commits per
    user. A failure part-way through therefore leaves earlier removals
    committed.
    """
    id_query = (
        select(User.id)
        .join(ClassMembership)
        .where(ClassMembership.class_id == class_id, User.status == "inactive")
    )
    candidate_ids = (await db.execute(id_query)).scalars().all()

    # Deletions run sequentially on the shared session; each call yields a bool.
    outcomes = [
        await delete_inactive_member(db, class_id, uid) for uid in candidate_ids
    ]
    return sum(outcomes)
|