hku-class/backend/app/services/class_service.py
2026-04-27 09:21:20 +08:00

94 lines
2.7 KiB
Python

from sqlalchemy import select, func
from sqlalchemy.orm import selectinload
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import Class_, ClassMembership, User
from app.schemas.class_ import ClassCreate, ClassUpdate
async def create_class(db: AsyncSession, data: ClassCreate) -> Class_:
    """Persist a new class built from ``data`` and return the stored instance.

    Every field of the ``ClassCreate`` payload is forwarded as a keyword
    argument to ``Class_``. The session is committed, and the instance is then
    refreshed so its attributes reflect what the database holds (including
    any database-generated values).
    """
    fields = data.model_dump()
    new_class = Class_(**fields)
    db.add(new_class)
    await db.commit()
    await db.refresh(new_class)
    return new_class
async def update_class(db: AsyncSession, class_: Class_, data: ClassUpdate) -> Class_:
    """Apply a partial update to ``class_``, commit, and return it refreshed.

    Only fields explicitly set on ``data`` (``exclude_unset=True``) are copied
    onto the ORM instance; fields the caller omitted keep their current values.
    """
    changes = data.model_dump(exclude_unset=True)
    for attr_name in changes:
        setattr(class_, attr_name, changes[attr_name])
    await db.commit()
    await db.refresh(class_)
    return class_
async def delete_class(db: AsyncSession, class_: Class_):
    """Mark ``class_`` for deletion and commit the transaction immediately.

    Returns nothing. Whether related rows (e.g. memberships) are removed too
    depends on the model's cascade configuration, which is not visible here.
    """
    session = db
    await session.delete(class_)
    await session.commit()
async def get_class_by_id(db: AsyncSession, class_id: int) -> Class_ | None:
    """Fetch the class whose ``id`` equals ``class_id``.

    Returns ``None`` when no row matches. ``scalar_one_or_none`` raises if more
    than one row matches.
    """
    stmt = select(Class_).where(Class_.id == class_id)
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()
async def list_classes(
    db: AsyncSession, page: int = 1, page_size: int = 50
) -> tuple[list[Class_], int]:
    """Return one page of classes, newest cohort first, plus the total count.

    Args:
        db: Async session used for both the count and the page query.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Maximum number of classes to return.

    Returns:
        ``(classes, total)`` where ``total`` is the number of all classes,
        not just those on the returned page.
    """
    total_result = await db.execute(select(func.count(Class_.id)))
    total = total_result.scalar() or 0

    # Clamp so page <= 0 never yields a negative OFFSET, which some databases
    # (e.g. PostgreSQL) reject outright.
    offset = (max(page, 1) - 1) * page_size
    result = await db.execute(
        select(Class_)
        # cohort_year need not be unique; without the id tiebreaker the order
        # of tied rows is unspecified and pages can repeat or skip classes.
        .order_by(Class_.cohort_year.desc(), Class_.id)
        .offset(offset)
        .limit(page_size)
    )
    return list(result.scalars().all()), total
async def get_member_count(db: AsyncSession, class_id: int) -> int:
    """Count memberships in ``class_id`` whose user has status ``"approved"``.

    Membership rows are counted (``count(ClassMembership.id)``) after joining
    each one to its ``User``. Returns 0 when nothing matches.
    """
    approved_members = (
        select(func.count(ClassMembership.id))
        .join(User, User.id == ClassMembership.user_id)
        .where(ClassMembership.class_id == class_id)
        .where(User.status == "approved")
    )
    count = (await db.execute(approved_members)).scalar()
    return count if count else 0
async def get_class_members(
    db: AsyncSession,
    class_id: int,
    status: str | None = None,
    page: int = 1,
    page_size: int = 50,
) -> tuple[list[User], int]:
    """Return one page of users belonging to ``class_id`` plus the total count.

    Args:
        db: Async session used for both the count and the page query.
        class_id: Class whose members are listed.
        status: If truthy, only users whose ``User.status`` equals it are
            included (an empty string means "no filter").
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Maximum number of users to return.

    Returns:
        ``(users, total)``: users are ordered by name (then id), each with
        ``memberships`` and every membership's ``class_`` eagerly loaded;
        ``total`` counts all matching users, not just the current page.
    """
    # Test membership with an IN subquery instead of a JOIN so each user is a
    # single row: LIMIT/OFFSET and the total then count users, not
    # membership rows, and no post-hoc de-duplication is needed.
    in_class = User.id.in_(
        select(ClassMembership.user_id).where(ClassMembership.class_id == class_id)
    )
    conditions = [in_class]
    if status:
        conditions.append(User.status == status)

    total_result = await db.execute(select(func.count(User.id)).where(*conditions))
    total = total_result.scalar() or 0

    # Clamp so page <= 0 never yields a negative OFFSET.
    offset = (max(page, 1) - 1) * page_size
    result = await db.execute(
        select(User)
        # The chained option loads User.memberships as well, so a separate
        # selectinload(User.memberships) would be redundant.
        .options(selectinload(User.memberships).selectinload(ClassMembership.class_))
        .where(*conditions)
        # name need not be unique; the id tiebreaker keeps pagination stable.
        .order_by(User.name, User.id)
        .offset(offset)
        .limit(page_size)
    )
    return list(result.scalars().all()), total