108 lines
2.9 KiB
Python
108 lines
2.9 KiB
Python
from sqlalchemy import select, func, update
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.db.models import Notification, User
|
|
|
|
|
|
async def create_notification(
    db: AsyncSession,
    user_id: int,
    type: str,
    title: str,
    content: str | None = None,
    related_id: int | None = None,
) -> Notification:
    """Store a single notification and return the persisted instance.

    Args:
        db: Async session used to add, commit and refresh the row.
        user_id: Value stored in the notification's ``user_id`` field.
        type: Value stored in the notification's ``type`` field.
        title: Value stored in the notification's ``title`` field.
        content: Optional value for the ``content`` field.
        related_id: Optional value for the ``related_id`` field.

    Returns:
        The ``Notification`` instance after it has been committed and
        refreshed from the database.
    """
    fields = {
        "user_id": user_id,
        "type": type,
        "title": title,
        "content": content,
        "related_id": related_id,
    }
    record = Notification(**fields)

    db.add(record)
    await db.commit()
    # Re-load the instance's attributes from the database after the commit.
    await db.refresh(record)
    return record
|
|
|
|
|
|
async def create_notifications_for_class(
    db: AsyncSession,
    class_id: int,
    type: str,
    title: str,
    content: str | None = None,
    related_id: int | None = None,
    exclude_user_id: int | None = None,
) -> None:
    """Create notifications for all approved users in a class.

    One ``Notification`` row is added per matching user, and all of them are
    committed together in a single ``commit`` call.

    Args:
        db: Async session used to look up recipients and store the rows.
        class_id: Only users whose ``class_id`` equals this value are notified.
        type: Value stored in each notification's ``type`` field.
        title: Value stored in each notification's ``title`` field.
        content: Optional value for each notification's ``content`` field.
        related_id: Optional value for each notification's ``related_id`` field.
        exclude_user_id: If not ``None``, the user with this id receives no
            notification.
    """
    result = await db.execute(
        select(User.id).where(
            User.class_id == class_id,
            User.status == "approved",
        )
    )

    # Compare against None explicitly: a plain truthiness test would treat an
    # exclude_user_id of 0 as "nothing to exclude".
    recipient_ids = [
        uid
        for uid in result.scalars().all()
        if exclude_user_id is None or uid != exclude_user_id
    ]

    db.add_all(
        [
            Notification(
                user_id=uid,
                type=type,
                title=title,
                content=content,
                related_id=related_id,
            )
            for uid in recipient_ids
        ]
    )

    # Commit unconditionally, as before, even when there are no recipients.
    await db.commit()
|
|
|
|
|
|
async def list_notifications(
    db: AsyncSession, user_id: int, page: int = 1, page_size: int = 20
) -> tuple[list[Notification], int]:
    """Return one page of a user's notifications and the user's total count.

    Args:
        db: Async session used to run both queries.
        user_id: Only notifications whose ``user_id`` equals this are returned.
        page: 1-based page number; values below 1 are treated as page 1.
        page_size: Maximum number of notifications in the returned page.

    Returns:
        A ``(notifications, total)`` tuple: the requested page ordered from
        newest to oldest by ``created_at``, and the count of all of the
        user's notifications regardless of paging.
    """
    total_result = await db.execute(
        select(func.count(Notification.id)).where(Notification.user_id == user_id)
    )
    total = total_result.scalar() or 0

    # A page below 1 would produce a negative OFFSET; fall back to page 1.
    offset = max(page - 1, 0) * page_size

    result = await db.execute(
        select(Notification)
        .where(Notification.user_id == user_id)
        # Rows may share a created_at value; the id tie-breaker gives a total
        # order so rows are not repeated or skipped between pages.
        .order_by(Notification.created_at.desc(), Notification.id.desc())
        .offset(offset)
        .limit(page_size)
    )
    notifications = list(result.scalars().all())
    return notifications, total
|
|
|
|
|
|
async def get_unread_count(db: AsyncSession, user_id: int) -> int:
    """Count the notifications of *user_id* whose ``is_read`` flag is false.

    Returns 0 when the query yields no value.
    """
    unread_for_user = (
        Notification.user_id == user_id,
        # SQL expression, not a Python boolean test, hence ``== False``.
        Notification.is_read == False,  # noqa: E712
    )
    stmt = select(func.count(Notification.id)).where(*unread_for_user)

    outcome = await db.execute(stmt)
    unread = outcome.scalar()
    return unread if unread else 0
|
|
|
|
|
|
async def mark_as_read(db: AsyncSession, notification_id: int, user_id: int) -> bool:
    """Set ``is_read`` on one notification, restricted to rows of *user_id*.

    The UPDATE is committed before returning.

    Returns:
        ``True`` if the UPDATE reported at least one matched row, otherwise
        ``False``.
    """
    owned_notification = (
        Notification.id == notification_id,
        Notification.user_id == user_id,
    )
    stmt = update(Notification).where(*owned_notification).values(is_read=True)

    outcome = await db.execute(stmt)
    await db.commit()

    matched = outcome.rowcount
    return matched > 0
|
|
|
|
|
|
async def mark_all_as_read(db: AsyncSession, user_id: int) -> int:
    """Set ``is_read`` on every unread notification belonging to *user_id*.

    The UPDATE is committed before returning.

    Returns:
        The row count reported by the UPDATE statement.
    """
    stmt = (
        update(Notification)
        .where(Notification.user_id == user_id)
        # SQL expression, not a Python boolean test, hence ``== False``.
        .where(Notification.is_read == False)  # noqa: E712
        .values(is_read=True)
    )

    outcome = await db.execute(stmt)
    await db.commit()
    return outcome.rowcount
|