105 lines
3.4 KiB
Python
105 lines
3.4 KiB
Python
"""
Backfill missing class_memberships rows for databases that were only partially migrated.

Usage:

    python repair_class_memberships.py

Or inside Docker:

    docker compose exec backend python repair_class_memberships.py
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from collections import defaultdict
|
|
|
|
from sqlalchemy import select, text
|
|
|
|
from app.db.database import async_session, engine
|
|
from app.db.base import Base
|
|
from app.db.models import (
|
|
Announcement,
|
|
Assignment,
|
|
AssignmentSubmission,
|
|
ClassMembership,
|
|
FundRecord,
|
|
Resource,
|
|
Timeline,
|
|
User,
|
|
Vote,
|
|
)
|
|
|
|
|
|
async def ensure_class_memberships_table() -> None:
    """Make sure the ClassMembership table exists before it is queried.

    ``Base.metadata.create_all`` is a synchronous API, so it is executed via
    ``run_sync`` on the async connection. Only the ClassMembership table is
    passed in, so no other table in the metadata is created here.
    """
    membership_table = ClassMembership.__table__
    # engine.begin() opens a transaction that is committed when the block exits.
    async with engine.begin() as connection:
        await connection.run_sync(
            Base.metadata.create_all,
            tables=[membership_table],
        )
|
|
|
|
|
|
def _merge_role(current: str | None, role: str) -> str:
    """Combine the role already inferred for a (user, class) pair with a new one.

    "teacher" wins over everything: once any observation (or the existing
    inference) says "teacher", the result is "teacher". Otherwise the
    existing inference is kept, defaulting to "student" when there is none.
    """
    if current == "teacher" or role == "teacher":
        return "teacher"
    return current or "student"


async def repair_memberships() -> None:
    """Backfill ClassMembership rows that can be inferred from existing content.

    A (user, class) pair is inferred whenever the user appears as the
    author/uploader/creator/recorder of an Announcement, Timeline, Resource,
    Vote, Assignment or FundRecord in that class, or as the submitter of an
    AssignmentSubmission for an Assignment in that class. A membership row is
    inserted for every inferred pair that does not already have one; existing
    memberships are never modified. The inserted ``membership_role`` is
    "teacher" when the user's ``User.role`` equals "teacher" (users missing
    from the users table are treated as "student"), otherwise "student".
    """
    await ensure_class_memberships_table()

    async with async_session() as db:
        existing_result = await db.execute(
            select(ClassMembership.user_id, ClassMembership.class_id)
        )
        existing = {(user_id, class_id) for user_id, class_id in existing_result.all()}

        user_rows = await db.execute(select(User.id, User.role))
        # NOTE(review): assumes User.role compares equal to the plain string
        # "teacher" (str column or str-based enum) — confirm against the model.
        user_roles = {user_id: role for user_id, role in user_rows.all()}

        inferred: dict[tuple[int, int], str] = {}

        def remember(user_id: int | None, class_id: int | None) -> None:
            # A NULL user or class reference carries no membership information.
            # Compare against None explicitly so a legitimate id of 0 is kept.
            if user_id is None or class_id is None:
                return
            key = (user_id, class_id)
            role = user_roles.get(user_id, "student")
            inferred[key] = _merge_role(inferred.get(key), role)

        # (user column, class column) pairs on tables that record class activity.
        statement_specs = [
            (Announcement.author_id, Announcement.class_id),
            (Timeline.author_id, Timeline.class_id),
            (Resource.uploader_id, Resource.class_id),
            (Vote.creator_id, Vote.class_id),
            (Assignment.creator_id, Assignment.class_id),
            (FundRecord.recorder_id, FundRecord.class_id),
        ]
        # DISTINCT: only unique (user, class) pairs matter, and the inferred role
        # depends solely on user_id, so duplicate content rows add nothing.
        sources = [
            select(user_col, class_col).distinct()
            for user_col, class_col in statement_specs
        ]
        # Submissions carry no class_id directly; resolve it through Assignment.
        sources.append(
            select(AssignmentSubmission.student_id, Assignment.class_id)
            .join(Assignment, Assignment.id == AssignmentSubmission.assignment_id)
            .distinct()
        )

        for statement in sources:
            rows = await db.execute(statement)
            for user_id, class_id in rows.all():
                remember(user_id, class_id)

        # Inferred roles are already exactly "teacher" or "student".
        new_memberships = [
            ClassMembership(
                user_id=user_id,
                class_id=class_id,
                membership_role=role,
            )
            for (user_id, class_id), role in inferred.items()
            if (user_id, class_id) not in existing
        ]
        db.add_all(new_memberships)
        await db.commit()

        print(f"repair_class_memberships: created {len(new_memberships)} memberships")
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the one-off repair on a fresh asyncio event loop.
    repair_coroutine = repair_memberships()
    asyncio.run(repair_coroutine)
|