from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models import Class_, User
from app.schemas.class_ import ClassCreate, ClassUpdate


def _page_offset(page: int, page_size: int) -> int:
    """Return the row offset for a 1-based ``page``.

    Pages below 1 are treated as the first page so the query never receives
    a negative OFFSET.
    """
    return max(page - 1, 0) * page_size


async def create_class(db: AsyncSession, data: ClassCreate) -> Class_:
    """Insert a new class built from ``data`` and return the refreshed row.

    Args:
        db: Session used to add, commit and refresh the new object.
        data: Payload whose ``model_dump()`` output is passed to ``Class_``
            as keyword arguments.

    Returns:
        The ``Class_`` instance after commit and refresh.
    """
    class_ = Class_(**data.model_dump())
    db.add(class_)
    await db.commit()
    await db.refresh(class_)
    return class_


async def update_class(
    db: AsyncSession, class_: Class_, data: ClassUpdate
) -> Class_:
    """Apply the fields of ``data`` to ``class_`` and persist the change.

    Only the entries returned by ``data.model_dump(exclude_unset=True)`` are
    written, so attributes absent from that mapping keep their current value.

    Args:
        db: Session used to commit and refresh.
        class_: The object to mutate in place.
        data: Update payload.

    Returns:
        The same ``class_`` instance after commit and refresh.
    """
    for field, value in data.model_dump(exclude_unset=True).items():
        setattr(class_, field, value)
    await db.commit()
    await db.refresh(class_)
    return class_


async def delete_class(db: AsyncSession, class_: Class_) -> None:
    """Delete ``class_`` through ``db`` and commit."""
    await db.delete(class_)
    await db.commit()


async def get_class_by_id(db: AsyncSession, class_id: int) -> Class_ | None:
    """Return the class whose ``id`` equals ``class_id``, or ``None``."""
    result = await db.execute(select(Class_).where(Class_.id == class_id))
    return result.scalar_one_or_none()


async def list_classes(
    db: AsyncSession, page: int = 1, page_size: int = 50
) -> tuple[list[Class_], int]:
    """Return one page of classes plus the total class count.

    Classes are ordered by ``cohort_year`` descending, then by ``id``.

    Args:
        db: Session used to run the count and page queries.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Maximum number of rows in the page.

    Returns:
        ``(classes, total)`` where ``total`` counts all classes, not just
        the ones on this page.
    """
    total_result = await db.execute(select(func.count(Class_.id)))
    total = total_result.scalar() or 0

    result = await db.execute(
        select(Class_)
        # cohort_year is not unique; the id tie-breaker gives a total order
        # so consecutive pages neither repeat nor skip rows.
        .order_by(Class_.cohort_year.desc(), Class_.id)
        .offset(_page_offset(page, page_size))
        .limit(page_size)
    )
    classes = list(result.scalars().all())
    return classes, total


async def get_member_count(db: AsyncSession, class_id: int) -> int:
    """Count users in class ``class_id`` whose status is ``"approved"``."""
    result = await db.execute(
        select(func.count(User.id)).where(
            User.class_id == class_id, User.status == "approved"
        )
    )
    return result.scalar() or 0


async def get_class_members(
    db: AsyncSession,
    class_id: int,
    status: str | None = None,
    page: int = 1,
    page_size: int = 50,
) -> tuple[list[User], int]:
    """Return one page of users in a class plus the matching total.

    Users are ordered by ``name``, then by ``id``.

    Args:
        db: Session used to run the count and page queries.
        class_id: Only users with this ``class_id`` are considered.
        status: When truthy, restrict both the page and the total to users
            with this status. ``None`` or an empty string applies no filter.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Maximum number of rows in the page.

    Returns:
        ``(members, total)`` where ``total`` counts every user matching the
        filters, not just the ones on this page.
    """
    query = select(User).where(User.class_id == class_id)
    count_query = select(func.count(User.id)).where(User.class_id == class_id)
    if status:
        # Keep the page query and the count query filtered identically.
        query = query.where(User.status == status)
        count_query = count_query.where(User.status == status)

    total_result = await db.execute(count_query)
    total = total_result.scalar() or 0

    result = await db.execute(
        # Names can collide; the id tie-breaker keeps page boundaries stable.
        query.order_by(User.name, User.id)
        .offset(_page_offset(page, page_size))
        .limit(page_size)
    )
    return list(result.scalars().all()), total