126 lines
3.7 KiB
Python
126 lines
3.7 KiB
Python
import secrets
|
|
|
|
from sqlalchemy import select, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.db.models import StudentRoster, Class_
|
|
|
|
|
|
def generate_invite_code(length: int = 8) -> str:
    """Return a random invite code made of *length* characters.

    The alphabet is uppercase letters and digits with the easily confused
    glyphs ``I``, ``O``, ``0`` and ``1`` left out. Characters are drawn with
    :func:`secrets.choice`, so codes are suitable as hard-to-guess tokens.
    A *length* of zero or less yields an empty string.
    """
    alphabet = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
    picked = [secrets.choice(alphabet) for _ in range(length)]
    return "".join(picked)
|
|
|
|
|
|
async def ensure_invite_code(db: AsyncSession, class_id: int) -> str:
    """Return a class's invite code, creating and saving one if it has none.

    A missing or empty code is replaced by a freshly generated code. Before
    saving, the code is checked against the codes of existing classes, so
    two classes do not end up sharing one. ``validate_registration`` looks a
    class up by code and expects at most one match. A concurrent writer
    could still race this check; only a unique constraint on the column
    guarantees uniqueness.

    Args:
        db: Async database session; committed when a new code is stored.
        class_id: Primary key of the class.

    Returns:
        The class's invite code, or ``""`` if no class has ``class_id``.
    """
    result = await db.execute(select(Class_).where(Class_.id == class_id))
    class_ = result.scalar_one_or_none()
    if class_ is None:
        return ""

    if not class_.invite_code:
        # Draw codes until one is not used by any class.
        while True:
            candidate = generate_invite_code()
            clash = await db.execute(
                select(Class_.id).where(Class_.invite_code == candidate)
            )
            if clash.first() is None:
                break
        class_.invite_code = candidate
        await db.commit()
        await db.refresh(class_)

    return class_.invite_code
|
|
|
|
|
|
async def regenerate_invite_code(db: AsyncSession, class_id: int) -> str:
    """Replace a class's invite code with a new one and return it.

    Each candidate code is checked against every class's current code,
    including this class's own. The new code therefore differs from the
    previous one and does not collide with another class's code at the time
    of the check. A concurrent writer could still race this check; only a
    unique constraint on the column guarantees uniqueness.

    Args:
        db: Async database session; committed after the code is replaced.
        class_id: Primary key of the class.

    Returns:
        The new invite code, or ``""`` if no class has ``class_id``.
    """
    result = await db.execute(select(Class_).where(Class_.id == class_id))
    class_ = result.scalar_one_or_none()
    if class_ is None:
        return ""

    # Draw codes until one is not used by any class (this one included).
    while True:
        candidate = generate_invite_code()
        clash = await db.execute(
            select(Class_.id).where(Class_.invite_code == candidate)
        )
        if clash.first() is None:
            break

    class_.invite_code = candidate
    await db.commit()
    await db.refresh(class_)
    return class_.invite_code
|
|
|
|
|
|
def _clean_roster_field(value: object) -> str:
    """Return *value* as a whitespace-stripped string, or ``""`` for ``None``.

    Entry values are not guaranteed to be strings: a key may map to ``None``
    or to a non-string such as an int. Calling ``.strip()`` on those
    directly would raise ``AttributeError``.
    """
    if value is None:
        return ""
    return str(value).strip()


async def import_roster(
    db: AsyncSession, class_id: int, entries: list[dict]
) -> int:
    """Add new roster rows for a class and return how many were inserted.

    Each entry is read for its ``"student_id"`` and ``"name"`` values. Both
    are converted to stripped strings. An entry is skipped when either value
    is missing or blank, or when its student ID is already on the class's
    roster or appeared earlier in *entries*.

    Args:
        db: Async database session; committed once after all inserts.
        class_id: Class the roster rows belong to.
        entries: Mappings holding ``"student_id"`` and ``"name"`` keys.

    Returns:
        Number of ``StudentRoster`` rows added.
    """
    result = await db.execute(
        select(StudentRoster.student_id).where(StudentRoster.class_id == class_id)
    )
    # IDs already on this class's roster. The set grows as rows are added,
    # so duplicates within *entries* are skipped too.
    existing_ids: set[str] = set(result.scalars().all())

    count = 0
    for entry in entries:
        sid = _clean_roster_field(entry.get("student_id"))
        name = _clean_roster_field(entry.get("name"))
        if not sid or not name or sid in existing_ids:
            continue
        db.add(StudentRoster(class_id=class_id, student_id=sid, name=name))
        existing_ids.add(sid)
        count += 1

    await db.commit()
    return count
|
|
|
|
|
|
async def get_roster(
    db: AsyncSession, class_id: int, page: int = 1, page_size: int = 50
) -> tuple[list[StudentRoster], int]:
    """Return one page of a class's roster and the roster's total size.

    Args:
        db: Async database session.
        class_id: Class whose roster is listed.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Maximum rows per page; negative values are treated as 0.

    Returns:
        ``(entries, total)``. *entries* are ordered by ``student_id``, with
        ties broken by primary key so paging is stable. *total* counts every
        roster row of the class, regardless of paging.
    """
    # Clamp so that bad paging input can never produce a negative OFFSET or
    # LIMIT; databases reject those or interpret them unexpectedly.
    page = max(page, 1)
    page_size = max(page_size, 0)

    by_class = StudentRoster.class_id == class_id

    total_result = await db.execute(
        select(func.count(StudentRoster.id)).where(by_class)
    )
    total = total_result.scalar() or 0

    result = await db.execute(
        select(StudentRoster)
        .where(by_class)
        .order_by(StudentRoster.student_id, StudentRoster.id)
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    return list(result.scalars().all()), total
|
|
|
|
|
|
async def validate_registration(
    db: AsyncSession, invite_code: str, student_id: str
) -> StudentRoster | None:
    """Find the unregistered roster entry a student may register against.

    Args:
        db: Async database session.
        invite_code: Class invite code supplied by the student; surrounding
            whitespace is ignored.
        student_id: Student ID supplied by the student; surrounding
            whitespace is ignored (roster IDs are stored stripped).

    Returns:
        The matching ``StudentRoster`` row whose status is
        ``"unregistered"``, or ``None`` if the code or ID is blank, no class
        has this invite code, or the class has no such unregistered student.
    """
    code = invite_code.strip() if invite_code else ""
    sid = student_id.strip() if student_id else ""
    # Never look up a blank code. Classes without a code may store an empty
    # value, which would otherwise match one or even several classes.
    if not code or not sid:
        return None

    class_result = await db.execute(
        select(Class_).where(Class_.invite_code == code)
    )
    class_ = class_result.scalar_one_or_none()
    if class_ is None:
        return None

    roster_result = await db.execute(
        select(StudentRoster).where(
            StudentRoster.class_id == class_.id,
            StudentRoster.student_id == sid,
            StudentRoster.status == "unregistered",
        )
    )
    return roster_result.scalar_one_or_none()
|
|
|
|
|
|
async def delete_roster_entry(db: AsyncSession, roster_id: int) -> bool:
    """Delete one roster row unless it is marked as registered.

    Args:
        db: Async database session; committed after a successful delete.
        roster_id: Primary key of the ``StudentRoster`` row.

    Returns:
        ``True`` if the row existed and was deleted. ``False`` if no row has
        ``roster_id`` or the row's status is ``"registered"``; nothing is
        changed in that case.
    """
    lookup = await db.execute(
        select(StudentRoster).where(StudentRoster.id == roster_id)
    )
    target = lookup.scalar_one_or_none()

    # Rows whose status is "registered" are protected from deletion.
    if target is None or target.status == "registered":
        return False

    await db.delete(target)
    await db.commit()
    return True
|
|
|
|
|
|
async def clear_unregistered_roster(db: AsyncSession, class_id: int) -> int:
    """Delete every roster row of a class whose status is ``"unregistered"``.

    The matching rows are loaded and deleted one at a time through the
    session, then committed together.

    Args:
        db: Async database session; committed after the deletes.
        class_id: Class whose unregistered roster rows are removed.

    Returns:
        Number of rows deleted.
    """
    pending_query = select(StudentRoster).where(
        StudentRoster.class_id == class_id,
        StudentRoster.status == "unregistered",
    )
    pending_result = await db.execute(pending_query)
    pending = list(pending_result.scalars().all())

    for row in pending:
        await db.delete(row)
    await db.commit()

    return len(pending)
|