hku-class/backend/repair_class_memberships.py
2026-04-27 10:51:58 +08:00

105 lines
3.4 KiB
Python

"""
Repair missing class_memberships data for databases that were partially migrated.
Usage:
python repair_class_memberships.py
Or inside Docker:
docker compose exec backend python repair_class_memberships.py
"""
from __future__ import annotations
import asyncio
from collections import defaultdict
from sqlalchemy import select, text
from app.db.database import async_session, engine
from app.db.base import Base
from app.db.models import (
Announcement,
Assignment,
AssignmentSubmission,
ClassMembership,
FundRecord,
Resource,
Timeline,
User,
Vote,
)
async def ensure_class_memberships_table() -> None:
    """Create the table backing ``ClassMembership`` if it is not present yet.

    Only that single table is handed to ``create_all``, so no other table in
    ``Base.metadata`` is created here. ``create_all`` checks for existence
    first by default, so re-running this on a migrated database is harmless.
    """
    membership_table = ClassMembership.__table__
    async with engine.begin() as connection:
        # create_all is a synchronous API; run it on the async connection.
        await connection.run_sync(
            Base.metadata.create_all,
            tables=[membership_table],
        )
def _normalize_role(role: object) -> str | None:
    """Return *role* as a plain string, unwrapping Enum-like values.

    NOTE(review): the ``User`` model is not visible here, so ``User.role`` may
    be a plain string or an Enum member. Enum members are unwrapped via
    ``.value`` so that comparisons against ``"teacher"`` still work. Anything
    that is not a string after unwrapping (e.g. ``None``) yields ``None``.
    """
    value = getattr(role, "value", role)
    return value if isinstance(value, str) else None


def _merge_role(
    inferred: dict[tuple[int, int], str],
    user_id: int | None,
    class_id: int | None,
    role: object,
) -> None:
    """Record that *user_id* takes part in *class_id*, keeping the strongest role.

    ``"teacher"`` is sticky: once a pair is inferred as teacher it is never
    downgraded. Every other role maps to ``"student"``. Rows with a missing
    (``None``) user or class id are ignored. A literal id of ``0`` is a valid
    key and is kept.
    """
    if user_id is None or class_id is None:
        return
    key = (user_id, class_id)
    if _normalize_role(role) == "teacher":
        inferred[key] = "teacher"
    else:
        inferred.setdefault(key, "student")


async def repair_memberships() -> None:
    """Insert ``ClassMembership`` rows inferred from existing class activity.

    A (user, class) pair is inferred from:
      * authorship of announcements / timeline entries, resource uploads,
        vote and assignment creation, and fund records;
      * assignment submissions, joined to the assignment's class.

    The membership role is ``"teacher"`` when the user's global ``User.role``
    is teacher, otherwise ``"student"``. Users absent from the ``User`` table
    default to student. Pairs that already have a membership row are skipped;
    existing rows are never modified (a student row is not upgraded).
    All inserts are committed in a single transaction, and the number of
    created rows is printed.
    """
    await ensure_class_memberships_table()
    async with async_session() as db:
        existing_result = await db.execute(
            select(ClassMembership.user_id, ClassMembership.class_id)
        )
        existing = {(user_id, class_id) for user_id, class_id in existing_result.all()}

        user_rows = await db.execute(select(User.id, User.role))
        user_roles = {user_id: role for user_id, role in user_rows.all()}

        inferred: dict[tuple[int, int], str] = {}

        # (user column, class column) pairs on tables that carry class_id directly.
        activity_columns = [
            (Announcement.author_id, Announcement.class_id),
            (Timeline.author_id, Timeline.class_id),
            (Resource.uploader_id, Resource.class_id),
            (Vote.creator_id, Vote.class_id),
            (Assignment.creator_id, Assignment.class_id),
            (FundRecord.recorder_id, FundRecord.class_id),
        ]
        for user_col, class_col in activity_columns:
            # DISTINCT + NULL filtering in SQL avoids shipping every activity
            # row to Python when only the unique pairs matter.
            rows = await db.execute(
                select(user_col, class_col)
                .where(user_col.is_not(None), class_col.is_not(None))
                .distinct()
            )
            for user_id, class_id in rows.all():
                _merge_role(
                    inferred, user_id, class_id, user_roles.get(user_id, "student")
                )

        # Submissions have no class_id of their own; take it from the assignment.
        submission_rows = await db.execute(
            select(AssignmentSubmission.student_id, Assignment.class_id)
            .join(Assignment, Assignment.id == AssignmentSubmission.assignment_id)
            .where(
                AssignmentSubmission.student_id.is_not(None),
                Assignment.class_id.is_not(None),
            )
            .distinct()
        )
        for user_id, class_id in submission_rows.all():
            _merge_role(
                inferred, user_id, class_id, user_roles.get(user_id, "student")
            )

        # Sorted for a deterministic insert order across runs.
        missing = sorted(
            (key, role) for key, role in inferred.items() if key not in existing
        )
        db.add_all(
            [
                ClassMembership(
                    user_id=user_id,
                    class_id=class_id,
                    membership_role=role,
                )
                for (user_id, class_id), role in missing
            ]
        )
        await db.commit()
        created = len(missing)
    print(f"repair_class_memberships: created {created} memberships")
if __name__ == "__main__":
asyncio.run(repair_memberships())