"""
Repair required tables and missing class_memberships data for databases
that were partially migrated.

Usage:
    python repair_class_memberships.py

Or inside Docker:
    docker compose exec backend python repair_class_memberships.py
"""

from __future__ import annotations

import asyncio
from collections import defaultdict

from sqlalchemy import select, text

from app.db.database import async_session, engine
from app.db.base import Base
from app.db.models import (
    Announcement,
    Assignment,
    AssignmentSubmission,
    ClassMembership,
    EmailVerificationCode,
    FundRecord,
    Resource,
    Timeline,
    User,
    Vote,
)


async def ensure_class_memberships_table() -> None:
    """Create the tables this script depends on.

    Runs ``Base.metadata.create_all`` restricted to the tables mapped by
    ``ClassMembership`` and ``EmailVerificationCode``. ``create_all`` checks
    for existence first by default, so tables that already exist are left
    untouched.
    """
    async with engine.begin() as conn:
        await conn.run_sync(
            Base.metadata.create_all,
            tables=[
                ClassMembership.__table__,
                EmailVerificationCode.__table__,
            ],
        )


def _merge_role(current: str | None, role: str) -> str:
    """Combine a previously inferred role with a new one; "teacher" wins."""
    if current == "teacher" or role == "teacher":
        return "teacher"
    # Any non-teacher role collapses to "student".
    return current or "student"


async def _infer_memberships(db) -> dict[tuple[int, int], str]:
    """Infer ``(user_id, class_id) -> role`` from content tied to a class.

    A user is assumed to belong to every class in which they authored an
    announcement, timeline entry, resource, vote, assignment or fund record,
    or submitted an assignment. The membership role is "teacher" when the
    user's ``User.role`` is "teacher", otherwise "student".

    Args:
        db: An open session obtained from ``async_session``.

    Returns:
        Mapping of ``(user_id, class_id)`` to ``"teacher"`` or ``"student"``.
    """
    inferred: dict[tuple[int, int], str] = {}

    user_rows = await db.execute(select(User.id, User.role))
    user_roles = {user_id: role for user_id, role in user_rows.all()}

    def remember(user_id: int | None, class_id: int | None) -> None:
        # Rows with a missing (or falsy) user or class id carry no usable
        # membership information.
        if not user_id or not class_id:
            return
        key = (user_id, class_id)
        # Users absent from the users table default to "student".
        # NOTE(review): such orphaned ids would still get a membership row;
        # confirm that is intended if the table enforces foreign keys.
        role = user_roles.get(user_id, "student")
        inferred[key] = _merge_role(inferred.get(key), role)

    # (user column, class column) pairs for tables that record who created
    # a piece of class content.
    statement_specs = [
        (Announcement.author_id, Announcement.class_id),
        (Timeline.author_id, Timeline.class_id),
        (Resource.uploader_id, Resource.class_id),
        (Vote.creator_id, Vote.class_id),
        (Assignment.creator_id, Assignment.class_id),
        (FundRecord.recorder_id, FundRecord.class_id),
    ]
    for user_col, class_col in statement_specs:
        rows = await db.execute(select(user_col, class_col))
        for user_id, class_id in rows.all():
            remember(user_id, class_id)

    # Submissions do not store the class directly; resolve it through the
    # assignment they belong to.
    submission_rows = await db.execute(
        select(AssignmentSubmission.student_id, Assignment.class_id).join(
            Assignment, Assignment.id == AssignmentSubmission.assignment_id
        )
    )
    for user_id, class_id in submission_rows.all():
        remember(user_id, class_id)

    return inferred


async def repair_memberships() -> None:
    """Create any class memberships that can be inferred but do not exist.

    Ensures the required tables exist, infers memberships from existing
    class content, inserts the ones not already present in
    ``ClassMembership`` in a single commit, and prints how many rows were
    created.
    """
    await ensure_class_memberships_table()

    async with async_session() as db:
        existing_result = await db.execute(
            select(ClassMembership.user_id, ClassMembership.class_id)
        )
        existing = {
            (user_id, class_id) for user_id, class_id in existing_result.all()
        }

        inferred = await _infer_memberships(db)

        created = 0
        for (user_id, class_id), role in inferred.items():
            if (user_id, class_id) in existing:
                continue
            db.add(
                ClassMembership(
                    user_id=user_id,
                    class_id=class_id,
                    membership_role=role,
                )
            )
            created += 1

        await db.commit()

    print(f"repair_class_memberships: created {created} memberships")


async def _main() -> None:
    """Run the repair as a script, then release the engine's connections."""
    try:
        await repair_memberships()
    finally:
        # Dispose while the event loop is still running; otherwise pooled
        # async connections are left for the garbage collector to clean up
        # after asyncio.run() has closed the loop.
        await engine.dispose()


if __name__ == "__main__":
    asyncio.run(_main())