"""Database helpers for creating, listing and marking user notifications."""

from sqlalchemy import false, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models import Notification, User


def _unread_for_user(user_id: int):
    """Return the WHERE criteria selecting ``user_id``'s unread notifications."""
    # `false()` renders the dialect-appropriate boolean literal (e.g. `false`
    # or `0`) and avoids the `== False` comparison linters flag (E712).
    return (Notification.user_id == user_id, Notification.is_read == false())


async def create_notification(
    db: AsyncSession,
    user_id: int,
    type: str,
    title: str,
    content: str | None = None,
    related_id: int | None = None,
) -> Notification:
    """Create and commit one notification for ``user_id`` and return it.

    Args:
        db: Active async session; it is committed by this call.
        user_id: Recipient user id.
        type: Notification type value stored on the row.
        title: Notification title.
        content: Optional body text.
        related_id: Optional id of a related object.

    Returns:
        The persisted ``Notification``. It is refreshed after the commit so
        database-generated values are loaded.
    """
    notification = Notification(
        user_id=user_id,
        type=type,
        title=title,
        content=content,
        related_id=related_id,
    )
    db.add(notification)
    await db.commit()
    await db.refresh(notification)
    return notification


async def create_notifications_for_class(
    db: AsyncSession,
    class_id: int,
    type: str,
    title: str,
    content: str | None = None,
    related_id: int | None = None,
    exclude_user_id: int | None = None,
) -> int:
    """Create notifications for all approved users in a class.

    Args:
        db: Active async session; it is committed by this call.
        class_id: Class whose users (``User.class_id``) are notified.
        type: Notification type value stored on each row.
        title: Notification title.
        content: Optional body text.
        related_id: Optional id of a related object.
        exclude_user_id: Optional user id to skip (e.g. the actor who
            triggered the notification).

    Returns:
        The number of notifications created.
    """
    query = select(User.id).where(
        User.class_id == class_id,
        User.status == "approved",
    )
    # Compare against None explicitly so a falsy id (e.g. 0) is still
    # excluded, and do the filtering in SQL rather than in Python.
    if exclude_user_id is not None:
        query = query.where(User.id != exclude_user_id)

    user_ids = (await db.execute(query)).scalars().all()
    notifications = [
        Notification(
            user_id=uid,
            type=type,
            title=title,
            content=content,
            related_id=related_id,
        )
        for uid in user_ids
    ]
    db.add_all(notifications)
    await db.commit()
    return len(notifications)


async def list_notifications(
    db: AsyncSession, user_id: int, page: int = 1, page_size: int = 20
) -> tuple[list[Notification], int]:
    """Return one page of ``user_id``'s notifications plus their total count.

    Args:
        db: Active async session.
        user_id: Owner of the notifications.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Maximum rows per page; negative values are treated as 0.

    Returns:
        ``(notifications, total)``, where ``notifications`` is ordered newest
        first and ``total`` counts all of the user's notifications.
    """
    # A page below 1 would yield a negative OFFSET, which some databases
    # reject outright.
    page = max(page, 1)
    page_size = max(page_size, 0)

    total_result = await db.execute(
        select(func.count(Notification.id)).where(Notification.user_id == user_id)
    )
    total = total_result.scalar_one() or 0

    result = await db.execute(
        select(Notification)
        .where(Notification.user_id == user_id)
        # `id` breaks ties between equal timestamps so pagination is stable
        # (no row repeated or skipped between pages).
        .order_by(Notification.created_at.desc(), Notification.id.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    notifications = list(result.scalars().all())
    return notifications, total


async def get_unread_count(db: AsyncSession, user_id: int) -> int:
    """Return how many of ``user_id``'s notifications are unread."""
    result = await db.execute(
        select(func.count(Notification.id)).where(*_unread_for_user(user_id))
    )
    return result.scalar_one() or 0


async def mark_as_read(db: AsyncSession, notification_id: int, user_id: int) -> bool:
    """Mark one notification as read, only if it belongs to ``user_id``.

    Returns:
        ``True`` if the UPDATE reported at least one affected row.
        NOTE(review): `rowcount` semantics (matched vs. changed rows) are
        driver-dependent; confirm for the configured backend.
    """
    result = await db.execute(
        update(Notification)
        # Scoping by user_id prevents marking another user's notification.
        .where(Notification.id == notification_id, Notification.user_id == user_id)
        .values(is_read=True)
    )
    await db.commit()
    return result.rowcount > 0


async def mark_all_as_read(db: AsyncSession, user_id: int) -> int:
    """Mark every unread notification of ``user_id`` as read.

    Returns:
        The UPDATE's reported row count.
    """
    result = await db.execute(
        update(Notification)
        .where(*_unread_for_user(user_id))
        .values(is_read=True)
    )
    await db.commit()
    return result.rowcount