This commit is contained in:
aaron 2026-04-27 10:51:58 +08:00
parent 8c345805dd
commit 541a1c5311
5 changed files with 351 additions and 2 deletions

View File

@ -15,4 +15,4 @@ RUN mkdir -p /app/data
EXPOSE 8000
CMD ["sh", "-c", "alembic upgrade head && uvicorn app.main:app --host 0.0.0.0 --port 8000"]
CMD ["sh", "-c", "alembic upgrade head && python repair_class_memberships.py && uvicorn app.main:app --host 0.0.0.0 --port 8000"]

View File

@ -9,6 +9,8 @@ from __future__ import annotations
from alembic import op
import sqlalchemy as sa
from app.db.base import Base
from app.db import models as _models # noqa: F401 ensure metadata registered
revision = "20260426_remove_legacy"
@ -17,15 +19,137 @@ branch_labels = None
depends_on = None
def _ensure_class_memberships_table(inspector: sa.Inspector) -> None:
tables = set(inspector.get_table_names())
if "class_memberships" in tables:
return
op.create_table(
"class_memberships",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("class_id", sa.Integer(), nullable=False),
sa.Column("membership_role", sa.String(length=20), nullable=False, server_default="student"),
sa.Column("committee_role", sa.String(length=50), nullable=True),
sa.Column("class_permissions", sa.Text(), nullable=True),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.now(), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["users.id"]),
sa.ForeignKeyConstraint(["class_id"], ["classes.id"]),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("user_id", "class_id", name="uq_class_membership_user_class"),
)
op.create_index("ix_class_memberships_user_id", "class_memberships", ["user_id"])
op.create_index("ix_class_memberships_class_id", "class_memberships", ["class_id"])
def _existing_memberships(bind) -> set[tuple[int, int]]:
rows = bind.execute(
sa.text("SELECT user_id, class_id FROM class_memberships")
).fetchall()
return {(int(user_id), int(class_id)) for user_id, class_id in rows}
def _insert_membership(
bind,
existing: set[tuple[int, int]],
*,
user_id: int,
class_id: int,
membership_role: str,
committee_role: str | None,
class_permissions: str | None,
created_at,
updated_at,
) -> None:
key = (user_id, class_id)
if key in existing:
return
bind.execute(
sa.text(
"""
INSERT INTO class_memberships (
user_id, class_id, membership_role, committee_role, class_permissions, created_at, updated_at
) VALUES (
:user_id, :class_id, :membership_role, :committee_role, :class_permissions, :created_at, :updated_at
)
"""
),
{
"user_id": user_id,
"class_id": class_id,
"membership_role": membership_role,
"committee_role": committee_role,
"class_permissions": class_permissions,
"created_at": created_at,
"updated_at": updated_at,
},
)
existing.add(key)
def upgrade() -> None:
bind = op.get_bind()
Base.metadata.create_all(bind=bind, checkfirst=True)
inspector = sa.inspect(bind)
_ensure_class_memberships_table(inspector)
tables = set(inspector.get_table_names())
existing = _existing_memberships(bind)
user_columns = {column["name"] for column in inspector.get_columns("users")}
if "class_id" in user_columns:
legacy_rows = bind.execute(
sa.text(
"""
SELECT id, class_id, role, committee_role, class_permissions, created_at, updated_at
FROM users
WHERE class_id IS NOT NULL
"""
)
).mappings()
for row in legacy_rows:
if row["role"] == "super_admin":
continue
_insert_membership(
bind,
existing,
user_id=int(row["id"]),
class_id=int(row["class_id"]),
membership_role="teacher" if row["role"] == "teacher" else "student",
committee_role=row["committee_role"],
class_permissions=row["class_permissions"],
created_at=row["created_at"],
updated_at=row["updated_at"],
)
if "student_rosters" in tables:
roster_rows = bind.execute(
sa.text(
"""
SELECT user_id, class_id, created_at
FROM student_rosters
WHERE user_id IS NOT NULL
"""
)
).mappings()
for row in roster_rows:
_insert_membership(
bind,
existing,
user_id=int(row["user_id"]),
class_id=int(row["class_id"]),
membership_role="student",
committee_role=None,
class_permissions=None,
created_at=row["created_at"],
updated_at=row["created_at"],
)
if "student_rosters" in tables:
op.drop_table("student_rosters")
user_columns = {column["name"] for column in inspector.get_columns("users")}
legacy_columns = {"class_id", "committee_role", "class_permissions"}
if legacy_columns.intersection(user_columns):
with op.batch_alter_table("users") as batch_op:
@ -38,6 +162,17 @@ def upgrade() -> None:
def downgrade() -> None:
bind = op.get_bind()
inspector = sa.inspect(bind)
tables = set(inspector.get_table_names())
if "class_memberships" in tables:
indexes = {index["name"] for index in inspector.get_indexes("class_memberships")}
if "ix_class_memberships_class_id" in indexes:
op.drop_index("ix_class_memberships_class_id", table_name="class_memberships")
if "ix_class_memberships_user_id" in indexes:
op.drop_index("ix_class_memberships_user_id", table_name="class_memberships")
op.drop_table("class_memberships")
with op.batch_alter_table("users") as batch_op:
batch_op.add_column(sa.Column("class_id", sa.Integer(), nullable=True))
batch_op.add_column(sa.Column("committee_role", sa.String(length=50), nullable=True))

View File

@ -0,0 +1,54 @@
"""create class memberships table if missing
Revision ID: 20260427_create_memberships
Revises: 20260426_remove_legacy
Create Date: 2026-04-27 10:45:00
"""
from __future__ import annotations
from alembic import op
import sqlalchemy as sa
revision = "20260427_create_memberships"
down_revision = "20260426_remove_legacy"
branch_labels = None
depends_on = None
def upgrade() -> None:
bind = op.get_bind()
inspector = sa.inspect(bind)
if "class_memberships" not in set(inspector.get_table_names()):
op.create_table(
"class_memberships",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("class_id", sa.Integer(), nullable=False),
sa.Column("membership_role", sa.String(length=20), nullable=False, server_default="student"),
sa.Column("committee_role", sa.String(length=50), nullable=True),
sa.Column("class_permissions", sa.Text(), nullable=True),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False),
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.now(), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["users.id"]),
sa.ForeignKeyConstraint(["class_id"], ["classes.id"]),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("user_id", "class_id", name="uq_class_membership_user_class"),
)
op.create_index("ix_class_memberships_user_id", "class_memberships", ["user_id"])
op.create_index("ix_class_memberships_class_id", "class_memberships", ["class_id"])
def downgrade() -> None:
bind = op.get_bind()
inspector = sa.inspect(bind)
tables = set(inspector.get_table_names())
if "class_memberships" in tables:
indexes = {index["name"] for index in inspector.get_indexes("class_memberships")}
if "ix_class_memberships_class_id" in indexes:
op.drop_index("ix_class_memberships_class_id", table_name="class_memberships")
if "ix_class_memberships_user_id" in indexes:
op.drop_index("ix_class_memberships_user_id", table_name="class_memberships")
op.drop_table("class_memberships")

View File

@ -0,0 +1,56 @@
"""
Check class membership health after migration / repair.
Usage:
python check_membership_health.py
Or inside Docker:
docker compose exec backend python check_membership_health.py
"""
from __future__ import annotations
import asyncio
from sqlalchemy import func, select
from app.db.database import async_session
from app.db.models import ClassMembership, Class_, User
async def main() -> None:
async with async_session() as db:
total_users = (await db.execute(select(func.count(User.id)))).scalar() or 0
total_classes = (await db.execute(select(func.count(Class_.id)))).scalar() or 0
total_memberships = (await db.execute(select(func.count(ClassMembership.id)))).scalar() or 0
users_without_membership = (
await db.execute(
select(User.id, User.name, User.email, User.role)
.outerjoin(ClassMembership, ClassMembership.user_id == User.id)
.where(User.role != "super_admin")
.group_by(User.id)
.having(func.count(ClassMembership.id) == 0)
)
).all()
classes_without_members = (
await db.execute(
select(Class_.id, Class_.name)
.outerjoin(ClassMembership, ClassMembership.class_id == Class_.id)
.group_by(Class_.id)
.having(func.count(ClassMembership.id) == 0)
)
).all()
print(f"users={total_users} classes={total_classes} memberships={total_memberships}")
print(f"users_without_membership={len(users_without_membership)}")
for row in users_without_membership[:20]:
print(" user_without_membership", row)
print(f"classes_without_members={len(classes_without_members)}")
for row in classes_without_members[:20]:
print(" class_without_members", row)
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1,104 @@
"""
Repair missing class_memberships data for databases that were partially migrated.
Usage:
python repair_class_memberships.py
Or inside Docker:
docker compose exec backend python repair_class_memberships.py
"""
from __future__ import annotations
import asyncio
from collections import defaultdict
from sqlalchemy import select, text
from app.db.database import async_session, engine
from app.db.base import Base
from app.db.models import (
Announcement,
Assignment,
AssignmentSubmission,
ClassMembership,
FundRecord,
Resource,
Timeline,
User,
Vote,
)
async def ensure_class_memberships_table() -> None:
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all, tables=[ClassMembership.__table__])
async def repair_memberships() -> None:
await ensure_class_memberships_table()
async with async_session() as db:
existing_result = await db.execute(
select(ClassMembership.user_id, ClassMembership.class_id)
)
existing = {(row[0], row[1]) for row in existing_result.all()}
inferred: dict[tuple[int, int], str] = {}
def remember(user_id: int | None, class_id: int | None, role: str) -> None:
if not user_id or not class_id:
return
key = (user_id, class_id)
current = inferred.get(key)
if current == "teacher":
return
inferred[key] = "teacher" if role == "teacher" else (current or "student")
user_rows = await db.execute(select(User.id, User.role))
user_roles = {user_id: role for user_id, role in user_rows.all()}
statement_specs = [
(Announcement.author_id, Announcement.class_id),
(Timeline.author_id, Timeline.class_id),
(Resource.uploader_id, Resource.class_id),
(Vote.creator_id, Vote.class_id),
(Assignment.creator_id, Assignment.class_id),
(FundRecord.recorder_id, FundRecord.class_id),
]
for user_col, class_col in statement_specs:
rows = await db.execute(select(user_col, class_col))
for user_id, class_id in rows.all():
remember(user_id, class_id, user_roles.get(user_id, "student"))
submission_rows = await db.execute(
select(AssignmentSubmission.student_id, Assignment.class_id).join(
Assignment, Assignment.id == AssignmentSubmission.assignment_id
)
)
for user_id, class_id in submission_rows.all():
remember(user_id, class_id, user_roles.get(user_id, "student"))
grouped = defaultdict(list)
for (user_id, class_id), role in inferred.items():
if (user_id, class_id) not in existing:
grouped[role].append((user_id, class_id))
created = 0
for role, pairs in grouped.items():
for user_id, class_id in pairs:
db.add(
ClassMembership(
user_id=user_id,
class_id=class_id,
membership_role="teacher" if role == "teacher" else "student",
)
)
created += 1
await db.commit()
print(f"repair_class_memberships: created {created} memberships")
if __name__ == "__main__":
asyncio.run(repair_memberships())