hku-class/backend/app/services/roster_service.py
2026-04-12 18:15:38 +08:00

126 lines
3.7 KiB
Python

import secrets
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import StudentRoster, Class_
def generate_invite_code(length: int = 8) -> str:
    """Return a random invite code made of ``length`` characters.

    Each character is picked independently with ``secrets.choice`` from an
    alphabet of upper-case letters and digits that omits ``I``, ``O``, ``0``
    and ``1`` (presumably to avoid look-alike characters).

    Args:
        length: Number of characters to generate. A value of zero or less
            yields an empty string.

    Returns:
        The generated code.
    """
    alphabet = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
    picked = [secrets.choice(alphabet) for _ in range(length)]
    return "".join(picked)
async def ensure_invite_code(db: AsyncSession, class_id: int) -> str:
    """Return the invite code of a class, creating one if none is set.

    Args:
        db: Async session used to load and, when needed, update the class.
        class_id: Value matched against ``Class_.id``.

    Returns:
        The class's invite code, or an empty string when no class with
        ``class_id`` exists.
    """
    lookup = select(Class_).where(Class_.id == class_id)
    target = (await db.execute(lookup)).scalar_one_or_none()
    if target is None:
        return ""

    if target.invite_code:
        # A code is already present; nothing needs to be written.
        return target.invite_code

    target.invite_code = generate_invite_code()
    await db.commit()
    await db.refresh(target)
    return target.invite_code
async def regenerate_invite_code(db: AsyncSession, class_id: int) -> str:
    """Replace a class's invite code with a freshly generated one.

    Unlike ``ensure_invite_code``, any existing code is overwritten.

    Args:
        db: Async session used to load and update the class.
        class_id: Value matched against ``Class_.id``.

    Returns:
        The new invite code, or an empty string when no class with
        ``class_id`` exists.
    """
    lookup = select(Class_).where(Class_.id == class_id)
    target = (await db.execute(lookup)).scalar_one_or_none()
    if target is not None:
        target.invite_code = generate_invite_code()
        await db.commit()
        await db.refresh(target)
        return target.invite_code
    return ""
async def import_roster(
    db: AsyncSession, class_id: int, entries: list[dict]
) -> int:
    """Add new students to a class roster, skipping blanks and duplicates.

    An entry is skipped when its ``student_id`` or ``name`` is missing or
    blank after trimming, or when its ``student_id`` is already on the
    class's roster (including one added earlier in the same call).

    Args:
        db: Async session used to read existing rows and insert new ones.
        class_id: Class whose roster is being extended.
        entries: Dicts read via the ``"student_id"`` and ``"name"`` keys.
            Values may be ``None`` or non-string (e.g. a numeric ID); they
            are converted with ``str()`` before trimming.

    Returns:
        The number of roster rows added.
    """

    def _clean(value: object) -> str:
        # Calling .strip() directly on the raw value raised AttributeError
        # for None or numeric values; normalise to a trimmed string first.
        return "" if value is None else str(value).strip()

    result = await db.execute(
        select(StudentRoster.student_id).where(StudentRoster.class_id == class_id)
    )
    # Student IDs already on this class's roster; grows as rows are added so
    # duplicates inside ``entries`` are rejected too.
    existing_ids: set[str] = {row[0] for row in result.all()}

    count = 0
    for entry in entries:
        sid = _clean(entry.get("student_id"))
        name = _clean(entry.get("name"))
        if not sid or not name or sid in existing_ids:
            continue
        db.add(StudentRoster(class_id=class_id, student_id=sid, name=name))
        existing_ids.add(sid)
        count += 1

    await db.commit()
    return count
async def get_roster(
    db: AsyncSession, class_id: int, page: int = 1, page_size: int = 50
) -> tuple[list[StudentRoster], int]:
    """Return one page of a class's roster together with the total row count.

    Rows are ordered by ``StudentRoster.student_id``.

    Args:
        db: Async session used to run the queries.
        class_id: Class whose roster is listed.
        page: 1-based page number. Values below 1 are treated as the first
            page.
        page_size: Maximum number of rows to return for the page.

    Returns:
        A ``(rows, total)`` tuple: the roster rows of the requested page and
        the number of roster rows in the class across all pages.
    """
    in_class = StudentRoster.class_id == class_id

    total_result = await db.execute(
        select(func.count(StudentRoster.id)).where(in_class)
    )
    total = total_result.scalar() or 0

    # A page below 1 (or a negative page_size) would yield a negative OFFSET,
    # which some databases reject outright; clamp so the query stays valid.
    offset = max((page - 1) * page_size, 0)

    page_result = await db.execute(
        select(StudentRoster)
        .where(in_class)
        .order_by(StudentRoster.student_id)
        .offset(offset)
        .limit(page_size)
    )
    return list(page_result.scalars().all()), total
async def validate_registration(
    db: AsyncSession, invite_code: str, student_id: str
) -> StudentRoster | None:
    """Find the unregistered roster row matching an invite code and student ID.

    Args:
        db: Async session used to run the lookups.
        invite_code: Code identifying the class the student wants to join.
        student_id: Student ID to look for on that class's roster.

    Returns:
        The matching ``StudentRoster`` row whose status is ``"unregistered"``,
        or ``None`` when either argument is empty, no class has the invite
        code, or the roster has no such unregistered student.
    """
    # An empty or None invite code must never match: it would be compared
    # against classes that have no code set (None compiles to IS NULL), which
    # would let a registration through without a real code.
    if not invite_code or not student_id:
        return None

    class_result = await db.execute(
        select(Class_).where(Class_.invite_code == invite_code)
    )
    class_ = class_result.scalar_one_or_none()
    if class_ is None:
        return None

    roster_result = await db.execute(
        select(StudentRoster).where(
            StudentRoster.class_id == class_.id,
            StudentRoster.student_id == student_id,
            StudentRoster.status == "unregistered",
        )
    )
    return roster_result.scalar_one_or_none()
async def delete_roster_entry(db: AsyncSession, roster_id: int) -> bool:
    """Delete a single roster row unless it belongs to a registered student.

    Args:
        db: Async session used to load and delete the row.
        roster_id: Value matched against ``StudentRoster.id``.

    Returns:
        ``True`` when the row was deleted and committed; ``False`` when no
        such row exists or its status is ``"registered"``.
    """
    found = await db.execute(
        select(StudentRoster).where(StudentRoster.id == roster_id)
    )
    row = found.scalar_one_or_none()
    # Missing rows and registered students are both refused.
    if row is None or row.status == "registered":
        return False
    await db.delete(row)
    await db.commit()
    return True
async def clear_unregistered_roster(db: AsyncSession, class_id: int) -> int:
    """Delete every roster row of a class whose status is ``"unregistered"``.

    Args:
        db: Async session used to load and delete the rows.
        class_id: Class whose unregistered roster rows are removed.

    Returns:
        The number of rows deleted.
    """
    stmt = select(StudentRoster).where(
        StudentRoster.class_id == class_id,
        StudentRoster.status == "unregistered",
    )
    pending = (await db.execute(stmt)).scalars().all()

    removed = 0
    for row in pending:
        await db.delete(row)
        removed += 1

    # One commit covers all deletions.
    await db.commit()
    return removed