hku-class-hub/backend/app/services/notification_service.py
2026-04-11 17:08:59 +08:00

108 lines
2.9 KiB
Python

from sqlalchemy import select, func, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import Notification, User
async def create_notification(
    db: AsyncSession,
    user_id: int,
    type: str,
    title: str,
    content: str | None = None,
    related_id: int | None = None,
) -> Notification:
    """Persist a single notification for one user and return it.

    The new row is added to ``db``, the session is committed, and the
    instance is refreshed so its attributes are loaded again after the
    commit.

    Args:
        db: Async session used for the insert; this function commits it.
        user_id: Recipient user id.
        type: Notification category string, stored as-is.
        title: Headline text.
        content: Optional body text.
        related_id: Optional id of an entity the notification refers to.

    Returns:
        The refreshed ``Notification`` instance.
    """
    fields = dict(
        user_id=user_id,
        type=type,
        title=title,
        content=content,
        related_id=related_id,
    )
    record = Notification(**fields)
    db.add(record)
    await db.commit()
    # Reload the row after commit so the returned object carries the
    # values as stored in the database.
    await db.refresh(record)
    return record
async def create_notifications_for_class(
    db: AsyncSession,
    class_id: int,
    type: str,
    title: str,
    content: str | None = None,
    related_id: int | None = None,
    exclude_user_id: int | None = None,
):
    """Create one notification per approved user in a class, then commit once.

    Args:
        db: Async session; this function commits it after adding all rows.
        class_id: Class whose approved members are notified.
        type: Notification category string, stored as-is on every row.
        title: Headline text shared by every notification.
        content: Optional body text shared by every notification.
        related_id: Optional id of an entity the notifications refer to.
        exclude_user_id: Optional user id that should not receive a
            notification. ``None`` means nobody is excluded.
    """
    stmt = select(User.id).where(
        User.class_id == class_id,
        User.status == "approved",
    )
    # Compare against None explicitly: the previous truthiness test silently
    # ignored an exclusion id of 0. Filtering in SQL also avoids fetching the
    # excluded id at all.
    if exclude_user_id is not None:
        stmt = stmt.where(User.id != exclude_user_id)

    user_ids = (await db.execute(stmt)).scalars().all()
    db.add_all(
        [
            Notification(
                user_id=uid,
                type=type,
                title=title,
                content=content,
                related_id=related_id,
            )
            for uid in user_ids
        ]
    )
    await db.commit()
async def list_notifications(
    db: AsyncSession, user_id: int, page: int = 1, page_size: int = 20
) -> tuple[list[Notification], int]:
    """Return one page of a user's notifications, newest first, plus the total.

    Args:
        db: Async session used for the two read queries.
        user_id: Owner of the notifications.
        page: 1-based page number; values below 1 behave like page 1.
        page_size: Maximum number of notifications returned for the page.

    Returns:
        ``(notifications, total)``, where ``total`` counts all of the user's
        notifications regardless of pagination (0 if the count is empty).
    """
    owned = Notification.user_id == user_id

    total_result = await db.execute(
        select(func.count(Notification.id)).where(owned)
    )
    total = total_result.scalar() or 0

    # Clamp so page <= 0 cannot produce a negative OFFSET.
    offset = max(page - 1, 0) * page_size
    result = await db.execute(
        select(Notification)
        .where(owned)
        # created_at is not guaranteed unique (rows inserted together may share
        # a timestamp, depending on how the column is populated); tie-break on
        # id so page boundaries are deterministic and rows are not repeated or
        # skipped between pages.
        .order_by(Notification.created_at.desc(), Notification.id.desc())
        .offset(offset)
        .limit(page_size)
    )
    notifications = list(result.scalars().all())
    return notifications, total
async def get_unread_count(db: AsyncSession, user_id: int) -> int:
    """Count the notifications owned by ``user_id`` that are still unread.

    Returns 0 when the count query yields no value.
    """
    unread_stmt = select(func.count(Notification.id)).where(
        Notification.user_id == user_id,
        # ``== False`` is deliberate: SQLAlchemy overloads ``==`` to build a
        # SQL comparison, whereas ``is False`` would be evaluated by Python.
        Notification.is_read == False,  # noqa: E712
    )
    count = (await db.execute(unread_stmt)).scalar()
    return count or 0
async def mark_as_read(db: AsyncSession, notification_id: int, user_id: int) -> bool:
    """Mark a single notification as read and commit.

    The UPDATE is restricted to rows owned by ``user_id``, so a caller cannot
    flag another user's notification.

    Returns:
        True if the UPDATE's reported ``rowcount`` is positive; False when
        no row matched (unknown id, or the notification belongs to another
        user).
    """
    stmt = (
        update(Notification)
        .where(Notification.id == notification_id)
        .where(Notification.user_id == user_id)
        .values(is_read=True)
    )
    outcome = await db.execute(stmt)
    await db.commit()
    matched = outcome.rowcount
    return matched > 0
async def mark_all_as_read(db: AsyncSession, user_id: int) -> int:
    """Set ``is_read`` on every unread notification of ``user_id`` and commit.

    Returns:
        The UPDATE's reported ``rowcount``, i.e. how many notifications were
        flipped from unread to read.
    """
    unread_of_user = (
        Notification.user_id == user_id,
        # SQL comparison built by SQLAlchemy's overloaded ``==``, not identity.
        Notification.is_read == False,  # noqa: E712
    )
    outcome = await db.execute(
        update(Notification).where(*unread_of_user).values(is_read=True)
    )
    await db.commit()
    return outcome.rowcount