diff --git a/backend/Dockerfile b/backend/Dockerfile index d0080e5..2c9718c 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -15,4 +15,4 @@ RUN mkdir -p /app/data EXPOSE 8000 -CMD ["sh", "-c", "alembic upgrade head && uvicorn app.main:app --host 0.0.0.0 --port 8000"] +CMD ["sh", "-c", "alembic upgrade head && python repair_class_memberships.py && uvicorn app.main:app --host 0.0.0.0 --port 8000"] diff --git a/backend/alembic/versions/20260426_remove_legacy_roster_and_user_columns.py b/backend/alembic/versions/20260426_remove_legacy_roster_and_user_columns.py index 2db9249..7dbfd9f 100644 --- a/backend/alembic/versions/20260426_remove_legacy_roster_and_user_columns.py +++ b/backend/alembic/versions/20260426_remove_legacy_roster_and_user_columns.py @@ -9,6 +9,8 @@ from __future__ import annotations from alembic import op import sqlalchemy as sa +from app.db.base import Base +from app.db import models as _models # noqa: F401 ensure metadata registered revision = "20260426_remove_legacy" @@ -17,15 +19,137 @@ branch_labels = None depends_on = None +def _ensure_class_memberships_table(inspector: sa.Inspector) -> None: + tables = set(inspector.get_table_names()) + if "class_memberships" in tables: + return + + op.create_table( + "class_memberships", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("class_id", sa.Integer(), nullable=False), + sa.Column("membership_role", sa.String(length=20), nullable=False, server_default="student"), + sa.Column("committee_role", sa.String(length=50), nullable=True), + sa.Column("class_permissions", sa.Text(), nullable=True), + sa.Column("created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.func.now(), nullable=False), + sa.ForeignKeyConstraint(["user_id"], ["users.id"]), + sa.ForeignKeyConstraint(["class_id"], ["classes.id"]), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("user_id", "class_id", name="uq_class_membership_user_class"), + ) + op.create_index("ix_class_memberships_user_id", "class_memberships", ["user_id"]) + op.create_index("ix_class_memberships_class_id", "class_memberships", ["class_id"]) + + +def _existing_memberships(bind) -> set[tuple[int, int]]: + rows = bind.execute( + sa.text("SELECT user_id, class_id FROM class_memberships") + ).fetchall() + return {(int(user_id), int(class_id)) for user_id, class_id in rows} + + +def _insert_membership( + bind, + existing: set[tuple[int, int]], + *, + user_id: int, + class_id: int, + membership_role: str, + committee_role: str | None, + class_permissions: str | None, + created_at, + updated_at, +) -> None: + key = (user_id, class_id) + if key in existing: + return + + bind.execute( + sa.text( + """ + INSERT INTO class_memberships ( + user_id, class_id, membership_role, committee_role, class_permissions, created_at, updated_at + ) VALUES ( + :user_id, :class_id, :membership_role, :committee_role, :class_permissions, :created_at, :updated_at + ) + """ + ), + { + "user_id": user_id, + "class_id": class_id, + "membership_role": membership_role, + "committee_role": committee_role, + "class_permissions": class_permissions, + "created_at": created_at, + "updated_at": updated_at, + }, + ) + existing.add(key) + + def upgrade() -> None: bind = op.get_bind() + Base.metadata.create_all(bind=bind, checkfirst=True) inspector = sa.inspect(bind) + _ensure_class_memberships_table(inspector) tables = set(inspector.get_table_names()) + existing = _existing_memberships(bind) + + user_columns = {column["name"] for column in inspector.get_columns("users")} + if "class_id" in user_columns: + legacy_rows = bind.execute( + sa.text( + """ + SELECT id, class_id, role, committee_role, class_permissions, created_at, updated_at + FROM users + WHERE class_id IS NOT NULL + """ + ) + ).mappings() + for row in legacy_rows: + if row["role"] == "super_admin": + continue + _insert_membership( + bind, + existing, + user_id=int(row["id"]), + class_id=int(row["class_id"]), + membership_role="teacher" if row["role"] == "teacher" else "student", + committee_role=row["committee_role"], + class_permissions=row["class_permissions"], + created_at=row["created_at"], + updated_at=row["updated_at"], + ) + + if "student_rosters" in tables: + roster_rows = bind.execute( + sa.text( + """ + SELECT user_id, class_id, created_at + FROM student_rosters + WHERE user_id IS NOT NULL + """ + ) + ).mappings() + for row in roster_rows: + _insert_membership( + bind, + existing, + user_id=int(row["user_id"]), + class_id=int(row["class_id"]), + membership_role="student", + committee_role=None, + class_permissions=None, + created_at=row["created_at"], + updated_at=row["created_at"], + ) + if "student_rosters" in tables: op.drop_table("student_rosters") - user_columns = {column["name"] for column in inspector.get_columns("users")} legacy_columns = {"class_id", "committee_role", "class_permissions"} if legacy_columns.intersection(user_columns): with op.batch_alter_table("users") as batch_op: @@ -38,6 +162,17 @@ def upgrade() -> None: def downgrade() -> None: + bind = op.get_bind() + inspector = sa.inspect(bind) + tables = set(inspector.get_table_names()) + if "class_memberships" in tables: + indexes = {index["name"] for index in inspector.get_indexes("class_memberships")} + if "ix_class_memberships_class_id" in indexes: + op.drop_index("ix_class_memberships_class_id", table_name="class_memberships") + if "ix_class_memberships_user_id" in indexes: + op.drop_index("ix_class_memberships_user_id", table_name="class_memberships") + op.drop_table("class_memberships") + with op.batch_alter_table("users") as batch_op: batch_op.add_column(sa.Column("class_id", sa.Integer(), nullable=True)) batch_op.add_column(sa.Column("committee_role", sa.String(length=50), nullable=True)) diff --git a/backend/alembic/versions/20260427_create_class_memberships_if_missing.py b/backend/alembic/versions/20260427_create_class_memberships_if_missing.py new file mode 100644 index 0000000..3ac98a2 --- /dev/null +++ b/backend/alembic/versions/20260427_create_class_memberships_if_missing.py @@ -0,0 +1,54 @@ +"""create class memberships table if missing + +Revision ID: 20260427_create_memberships +Revises: 20260426_remove_legacy +Create Date: 2026-04-27 10:45:00 +""" + +from __future__ import annotations + +from alembic import op +import sqlalchemy as sa + + +revision = "20260427_create_memberships" +down_revision = "20260426_remove_legacy" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + bind = op.get_bind() + inspector = sa.inspect(bind) + + if "class_memberships" not in set(inspector.get_table_names()): + op.create_table( + "class_memberships", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("class_id", sa.Integer(), nullable=False), + sa.Column("membership_role", sa.String(length=20), nullable=False, server_default="student"), + sa.Column("committee_role", sa.String(length=50), nullable=True), + sa.Column("class_permissions", sa.Text(), nullable=True), + sa.Column("created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.func.now(), nullable=False), + sa.ForeignKeyConstraint(["user_id"], ["users.id"]), + sa.ForeignKeyConstraint(["class_id"], ["classes.id"]), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("user_id", "class_id", name="uq_class_membership_user_class"), + ) + op.create_index("ix_class_memberships_user_id", "class_memberships", ["user_id"]) + op.create_index("ix_class_memberships_class_id", "class_memberships", ["class_id"]) + + +def downgrade() -> None: + bind = op.get_bind() + inspector = sa.inspect(bind) + tables = set(inspector.get_table_names()) + if "class_memberships" in tables: + indexes = {index["name"] for index in inspector.get_indexes("class_memberships")} + if "ix_class_memberships_class_id" in indexes: + op.drop_index("ix_class_memberships_class_id", table_name="class_memberships") + if "ix_class_memberships_user_id" in indexes: + op.drop_index("ix_class_memberships_user_id", table_name="class_memberships") + op.drop_table("class_memberships") diff --git a/backend/check_membership_health.py b/backend/check_membership_health.py new file mode 100644 index 0000000..d44af0e --- /dev/null +++ b/backend/check_membership_health.py @@ -0,0 +1,56 @@ +""" +Check class membership health after migration / repair. + +Usage: + python check_membership_health.py +Or inside Docker: + docker compose exec backend python check_membership_health.py +""" + +from __future__ import annotations + +import asyncio + +from sqlalchemy import func, select + +from app.db.database import async_session +from app.db.models import ClassMembership, Class_, User + + +async def main() -> None: + async with async_session() as db: + total_users = (await db.execute(select(func.count(User.id)))).scalar() or 0 + total_classes = (await db.execute(select(func.count(Class_.id)))).scalar() or 0 + total_memberships = (await db.execute(select(func.count(ClassMembership.id)))).scalar() or 0 + + users_without_membership = ( + await db.execute( + select(User.id, User.name, User.email, User.role) + .outerjoin(ClassMembership, ClassMembership.user_id == User.id) + .where(User.role != "super_admin") + .group_by(User.id) + .having(func.count(ClassMembership.id) == 0) + ) + ).all() + + classes_without_members = ( + await db.execute( + select(Class_.id, Class_.name) + .outerjoin(ClassMembership, ClassMembership.class_id == Class_.id) + .group_by(Class_.id) + .having(func.count(ClassMembership.id) == 0) + ) + ).all() + + print(f"users={total_users} classes={total_classes} memberships={total_memberships}") + print(f"users_without_membership={len(users_without_membership)}") + for row in users_without_membership[:20]: + print(" user_without_membership", row) + + print(f"classes_without_members={len(classes_without_members)}") + for row in classes_without_members[:20]: + print(" class_without_members", row) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/backend/repair_class_memberships.py b/backend/repair_class_memberships.py new file mode 100644 index 0000000..8e0322b --- /dev/null +++ b/backend/repair_class_memberships.py @@ -0,0 +1,104 @@ +""" +Repair missing class_memberships data for databases that were partially migrated. + +Usage: + python repair_class_memberships.py +Or inside Docker: + docker compose exec backend python repair_class_memberships.py +""" + +from __future__ import annotations + +import asyncio +from collections import defaultdict + +from sqlalchemy import select, text + +from app.db.database import async_session, engine +from app.db.base import Base +from app.db.models import ( + Announcement, + Assignment, + AssignmentSubmission, + ClassMembership, + FundRecord, + Resource, + Timeline, + User, + Vote, +) + + +async def ensure_class_memberships_table() -> None: + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all, tables=[ClassMembership.__table__]) + + +async def repair_memberships() -> None: + await ensure_class_memberships_table() + + async with async_session() as db: + existing_result = await db.execute( + select(ClassMembership.user_id, ClassMembership.class_id) + ) + existing = {(row[0], row[1]) for row in existing_result.all()} + + inferred: dict[tuple[int, int], str] = {} + + def remember(user_id: int | None, class_id: int | None, role: str) -> None: + if not user_id or not class_id: + return + key = (user_id, class_id) + current = inferred.get(key) + if current == "teacher": + return + inferred[key] = "teacher" if role == "teacher" else (current or "student") + + user_rows = await db.execute(select(User.id, User.role)) + user_roles = {user_id: role for user_id, role in user_rows.all()} + + statement_specs = [ + (Announcement.author_id, Announcement.class_id), + (Timeline.author_id, Timeline.class_id), + (Resource.uploader_id, Resource.class_id), + (Vote.creator_id, Vote.class_id), + (Assignment.creator_id, Assignment.class_id), + (FundRecord.recorder_id, FundRecord.class_id), + ] + + for user_col, class_col in statement_specs: + rows = await db.execute(select(user_col, class_col)) + for user_id, class_id in rows.all(): + remember(user_id, class_id, user_roles.get(user_id, "student")) + + submission_rows = await db.execute( + select(AssignmentSubmission.student_id, Assignment.class_id).join( + Assignment, Assignment.id == AssignmentSubmission.assignment_id + ) + ) + for user_id, class_id in submission_rows.all(): + remember(user_id, class_id, user_roles.get(user_id, "student")) + + grouped = defaultdict(list) + for (user_id, class_id), role in inferred.items(): + if (user_id, class_id) not in existing: + grouped[role].append((user_id, class_id)) + + created = 0 + for role, pairs in grouped.items(): + for user_id, class_id in pairs: + db.add( + ClassMembership( + user_id=user_id, + class_id=class_id, + membership_role="teacher" if role == "teacher" else "student", + ) + ) + created += 1 + + await db.commit() + print(f"repair_class_memberships: created {created} memberships") + + +if __name__ == "__main__": + asyncio.run(repair_memberships())