80 lines
2.3 KiB
Python
80 lines
2.3 KiB
Python
from sqlalchemy import select, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.db.models import Class_, User
|
|
from app.schemas.class_ import ClassCreate, ClassUpdate
|
|
|
|
|
|
async def create_class(db: AsyncSession, data: ClassCreate) -> Class_:
    """Create, persist and return a new class.

    Every field produced by ``data.model_dump()`` is passed as a keyword
    argument to the ``Class_`` constructor. The new object is added to the
    session, the session is committed, and the object is refreshed before
    being returned.

    Args:
        db: Async session used for the insert and the refresh.
        data: Validated creation payload.

    Returns:
        The newly persisted ``Class_`` instance.
    """
    payload = data.model_dump()
    new_class = Class_(**payload)
    db.add(new_class)
    await db.commit()
    # Reload the row so the returned object reflects what was actually stored.
    await db.refresh(new_class)
    return new_class
|
|
|
|
|
|
async def update_class(db: AsyncSession, class_: Class_, data: ClassUpdate) -> Class_:
    """Apply a partial update to ``class_`` and return it refreshed.

    Only the fields the caller explicitly set on ``data``
    (``model_dump(exclude_unset=True)``) are copied onto the ORM object.
    Omitted fields keep their current values. The change is committed and
    the object refreshed.

    Args:
        db: Async session that owns ``class_``.
        class_: The class instance to modify in place.
        data: Partial update payload.

    Returns:
        The same ``class_`` object after commit and refresh.
    """
    changes = data.model_dump(exclude_unset=True)
    for attr_name in changes:
        setattr(class_, attr_name, changes[attr_name])
    await db.commit()
    await db.refresh(class_)
    return class_
|
|
|
|
|
|
async def delete_class(db: AsyncSession, class_: Class_):
    """Delete ``class_`` and commit the transaction immediately.

    NOTE(review): what happens to rows that reference this class (e.g. users
    whose ``class_id`` points at it) depends on FK/cascade configuration
    that is not visible here. Confirm before relying on it.

    Args:
        db: Async session that owns ``class_``.
        class_: The class instance to remove.
    """
    await db.delete(class_)
    # Commit here so the deletion is durable by the time this returns.
    await db.commit()
    return None
|
|
|
|
|
|
async def get_class_by_id(db: AsyncSession, class_id: int) -> Class_ | None:
    """Fetch the class whose ``id`` equals ``class_id``.

    Args:
        db: Async session used for the query.
        class_id: Identifier to look up.

    Returns:
        The matching ``Class_``, or ``None`` when no row matches.
        (``scalar_one_or_none`` raises if more than one row matches.)
    """
    stmt = select(Class_).where(Class_.id == class_id)
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()
|
|
|
|
|
|
async def list_classes(
    db: AsyncSession, page: int = 1, page_size: int = 50
) -> tuple[list[Class_], int]:
    """Return one page of classes together with the total class count.

    Args:
        db: Async session used for both the count and the page query.
        page: 1-based page number. Values below 1 are treated as 1 so the
            computed OFFSET is never negative.
        page_size: Maximum number of classes returned (SQL LIMIT).

    Returns:
        ``(classes, total)``. ``classes`` is ordered by ``cohort_year``
        descending, with ties broken by ``id``. ``total`` counts every class,
        not just the current page.
    """
    total_result = await db.execute(select(func.count(Class_.id)))
    total = total_result.scalar() or 0

    offset = (max(page, 1) - 1) * page_size
    result = await db.execute(
        select(Class_)
        # ``cohort_year`` is presumably shared by several classes. Without a
        # unique tie-breaker the DB may order equal-year rows differently on
        # each query, making rows repeat or disappear across pages.
        .order_by(Class_.cohort_year.desc(), Class_.id)
        .offset(offset)
        .limit(page_size)
    )
    classes = list(result.scalars().all())
    return classes, total
|
|
|
|
|
|
async def get_member_count(
    db: AsyncSession, class_id: int, status: str | None = "approved"
) -> int:
    """Count the users that belong to class ``class_id``.

    Args:
        db: Async session used for the query.
        class_id: Class whose members are counted.
        status: Only users with this ``status`` are counted. Defaults to
            ``"approved"``, the filter this function always applied before,
            so existing callers are unaffected. Pass ``None`` (or an empty
            string) to count members of every status.

    Returns:
        The number of matching users (0 when none match).
    """
    query = select(func.count(User.id)).where(User.class_id == class_id)
    if status:
        # Chained ``where`` clauses are AND-ed, matching the original filter.
        query = query.where(User.status == status)
    result = await db.execute(query)
    return result.scalar() or 0
|
|
|
|
|
|
async def get_class_members(
    db: AsyncSession,
    class_id: int,
    status: str | None = None,
    page: int = 1,
    page_size: int = 50,
) -> tuple[list[User], int]:
    """Return one page of a class's members plus the total matching count.

    Args:
        db: Async session used for both queries.
        class_id: Class whose members are listed.
        status: Optional ``User.status`` filter. ``None`` or an empty string
            means no status filter.
        page: 1-based page number. Values below 1 are treated as 1 so the
            computed OFFSET is never negative.
        page_size: Maximum number of users returned (SQL LIMIT).

    Returns:
        ``(users, total)``. ``users`` is ordered by ``name``, with ties broken
        by ``id``. ``total`` counts every user matching the filters, not just
        the current page.
    """
    # Build the filter once so the count and the page query cannot drift apart.
    conditions = [User.class_id == class_id]
    if status:
        conditions.append(User.status == status)

    total_result = await db.execute(select(func.count(User.id)).where(*conditions))
    total = total_result.scalar() or 0

    offset = (max(page, 1) - 1) * page_size
    result = await db.execute(
        select(User)
        .where(*conditions)
        # Names are not unique; ``id`` makes the order total so pagination is
        # stable (no duplicated or skipped members between pages).
        .order_by(User.name, User.id)
        .offset(offset)
        .limit(page_size)
    )
    return list(result.scalars().all()), total
|