80 lines
2.2 KiB
Python
80 lines
2.2 KiB
Python
from datetime import datetime, timezone
|
|
|
|
from sqlalchemy import select, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.db.models import Schedule
|
|
from app.schemas.schedule import ScheduleCreate, ScheduleUpdate
|
|
|
|
|
|
async def create_schedule(
    db: AsyncSession, class_id: int, data: ScheduleCreate
) -> Schedule:
    """Create a schedule for ``class_id`` from ``data`` and return it.

    Every field produced by ``data.model_dump()`` is passed to ``Schedule``
    as a keyword argument alongside ``class_id``. The session is committed
    and the new instance refreshed before it is returned.
    """
    # NOTE(review): if ScheduleCreate ever gains a ``class_id`` field, the
    # call below raises TypeError (duplicate keyword argument) — confirm.
    payload = data.model_dump()
    schedule = Schedule(class_id=class_id, **payload)

    db.add(schedule)
    await db.commit()
    # Reload from the database so values assigned on insert (such as the
    # generated id) are available on the returned object.
    await db.refresh(schedule)
    return schedule
|
|
|
|
|
|
async def update_schedule(
    db: AsyncSession, item: Schedule, data: ScheduleUpdate
) -> Schedule:
    """Apply a partial update from ``data`` to ``item`` and return it.

    Only fields explicitly provided on ``data`` (``exclude_unset=True``) are
    written; fields the caller omitted keep their current values. A field
    explicitly set to ``None`` is written as ``None``. The session is
    committed and ``item`` refreshed afterwards.
    """
    changes = data.model_dump(exclude_unset=True)
    for attr_name in changes:
        setattr(item, attr_name, changes[attr_name])

    await db.commit()
    await db.refresh(item)
    return item
|
|
|
|
|
|
async def delete_schedule(db: AsyncSession, item: Schedule) -> None:
    """Delete ``item`` and commit the transaction immediately.

    ``AsyncSession.delete`` is a coroutine, so it is awaited before the
    commit that makes the removal permanent.
    """
    await db.delete(item)
    await db.commit()
|
|
|
|
|
|
async def get_schedule_by_id(db: AsyncSession, schedule_id: int) -> Schedule | None:
    """Return the schedule whose ``id`` equals ``schedule_id``, or ``None``.

    Uses ``scalar_one_or_none``, which raises if more than one row matches.
    """
    stmt = select(Schedule).where(Schedule.id == schedule_id)
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()
|
|
|
|
|
|
async def list_schedules(
    db: AsyncSession,
    class_id: int,
    schedule_type: str | None = None,
    page: int = 1,
    page_size: int = 50,
) -> tuple[list[Schedule], int]:
    """Return one page of a class's schedules plus the total matching count.

    Args:
        db: Async database session.
        class_id: Only schedules with this ``class_id`` are included.
        schedule_type: Optional filter on ``Schedule.type``. A falsy value
            (``None`` or ``""``) disables the filter.
        page: 1-based page number. Values below 1 are treated as 1, so the
            computed OFFSET is never negative.
        page_size: Maximum number of rows in the page (assumed positive;
            presumably validated by the caller — confirm).

    Returns:
        ``(items, total)``. ``items`` is the requested page, ordered by
        ``start_time`` descending with ties broken by ``id`` descending.
        ``total`` is the number of rows matching the filters across all pages.
    """
    # Build the filter once so the count and the page query cannot drift apart.
    conditions = [Schedule.class_id == class_id]
    if schedule_type:
        conditions.append(Schedule.type == schedule_type)

    total_result = await db.execute(
        select(func.count(Schedule.id)).where(*conditions)
    )
    total = total_result.scalar() or 0

    # Some databases (e.g. PostgreSQL) reject a negative OFFSET outright.
    offset = (max(page, 1) - 1) * page_size

    result = await db.execute(
        select(Schedule)
        .where(*conditions)
        # start_time is not guaranteed unique. Without a unique tiebreaker
        # the row order within equal start_times is unspecified, and
        # OFFSET-based pages can repeat or skip rows.
        .order_by(Schedule.start_time.desc(), Schedule.id.desc())
        .offset(offset)
        .limit(page_size)
    )
    items = list(result.scalars().all())
    return items, total
|
|
|
|
|
|
async def get_upcoming_schedules(
    db: AsyncSession, class_id: int, limit: int = 10
) -> list[Schedule]:
    """Return up to ``limit`` schedules of ``class_id`` that start now or later.

    Results are ordered by ``start_time`` ascending, so the soonest schedule
    comes first.
    """
    # NOTE(review): the cutoff is timezone-aware (UTC); this assumes
    # ``Schedule.start_time`` is a timezone-aware column — confirm, since
    # comparing aware and naive timestamps fails on some drivers.
    cutoff = datetime.now(timezone.utc)
    stmt = (
        select(Schedule)
        .where(Schedule.class_id == class_id)
        .where(Schedule.start_time >= cutoff)
        .order_by(Schedule.start_time.asc())
        .limit(limit)
    )
    rows = await db.execute(stmt)
    return [schedule for schedule in rows.scalars().all()]
|