hku-class/backend/app/services/notification_service.py
2026-04-27 22:36:48 +08:00

217 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import logging
from sqlalchemy import select, func, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import ClassMembership, Notification, User
from app.services.email_service import send_class_notification_email
logger = logging.getLogger(__name__)
EMAIL_VISUALS: dict[str, dict[str, str | None]] = {
"announcement": {
"eyebrow": "Announcement",
"action_label": "查看公告",
"summary_prefix": "公告摘要",
},
"timeline": {
"eyebrow": "Class Feed",
"action_label": "查看动态",
"summary_prefix": "动态摘要",
},
"vote": {
"eyebrow": "Voting Open",
"action_label": "立即投票",
"summary_prefix": "投票说明",
},
"schedule": {
"eyebrow": "Schedule Update",
"action_label": "查看排期",
"summary_prefix": "时间信息",
},
"assignment": {
"eyebrow": "Assignment Posted",
"action_label": "查看作业",
"summary_prefix": "作业说明",
},
"resource": {
"eyebrow": "Resource Added",
"action_label": "查看资源",
"summary_prefix": "资源说明",
},
"fund": {
"eyebrow": "Fund Update",
"action_label": "查看班费",
"summary_prefix": "账务说明",
},
}
def _build_email_meta(type_: str, email_body: str | None) -> tuple[str, str | None, str | None]:
    """Resolve the email presentation fields for a notification type.

    Args:
        type_: Notification type used to look up ``EMAIL_VISUALS``.
        email_body: Body text for the summary block; may be empty or None.

    Returns:
        ``(eyebrow, action_label, summary_html)`` where:

        * ``eyebrow`` is the configured eyebrow, or ``"Class Update"`` when
          the type is unknown or has none;
        * ``action_label`` is the configured label, or ``None``;
        * ``summary_html`` is ``None`` for an empty/None body. Otherwise it
          is the body, preceded by ``<strong>{summary_prefix}</strong>``
          when the type has a prefix.

    NOTE(review): ``email_body`` is interpolated into HTML without escaping;
    confirm callers only pass trusted or pre-escaped text.
    """
    settings_for_type = EMAIL_VISUALS.get(type_, {})
    eyebrow = str(settings_for_type.get("eyebrow") or "Class Update")
    action_label = settings_for_type.get("action_label")
    prefix = settings_for_type.get("summary_prefix")

    # Guard clause: nothing to summarize.
    if not email_body:
        return eyebrow, action_label, None

    summary_html = f"<strong>{prefix}</strong>{email_body}" if prefix else email_body
    return eyebrow, action_label, summary_html
async def create_notification(
    db: AsyncSession,
    user_id: int,
    type: str,
    title: str,
    content: str | None = None,
    related_id: int | None = None,
) -> Notification:
    """Persist one in-app notification for a single user and return it.

    The session is committed immediately. The object is then refreshed, so
    values assigned by the database (presumably the primary key and any
    server defaults — confirm against the model) are loaded on it.

    Args:
        db: Async session used for the insert and commit.
        user_id: Recipient user id.
        type: Notification type string.
        title: Notification title.
        content: Optional body text.
        related_id: Optional id of the related entity.

    Returns:
        The committed and refreshed ``Notification``.
    """
    fields = {
        "user_id": user_id,
        "type": type,
        "title": title,
        "content": content,
        "related_id": related_id,
    }
    record = Notification(**fields)
    db.add(record)
    await db.commit()
    await db.refresh(record)
    return record
# Strong references to in-flight fire-and-forget email tasks. The event loop
# keeps only weak references to tasks, so a task whose handle is dropped can
# be garbage-collected before it completes. Each task removes itself when done.
_background_tasks: set[asyncio.Task] = set()


async def create_notifications_for_class(
    db: AsyncSession,
    class_id: int,
    type: str,
    title: str,
    content: str | None = None,
    related_id: int | None = None,
    exclude_user_id: int | None = None,
    email_subject: str | None = None,
    email_body: str | None = None,
    email_action_path: str | None = None,
):
    """Create in-app notifications and email all approved members of a class.

    One ``Notification`` row is added per class member whose
    ``User.status`` is ``"approved"``, and the batch is committed once.
    After the commit, if ``email_subject`` is set and at least one recipient
    has an email address, a single email send is scheduled as a background
    task. Failures there are logged by ``_safe_send_emails`` and do not
    affect the committed rows.

    Args:
        db: Async session used for the recipient query and the commit.
        class_id: Class whose members are notified.
        type: Notification type; also selects the email visuals.
        title: Notification title, reused as the email title.
        content: Optional in-app body; used as the email body when
            ``email_body`` is not given.
        related_id: Optional id of the related entity.
        exclude_user_id: If given, this user receives neither an in-app
            notification nor an email.
        email_subject: Email subject; no email is sent when falsy.
        email_body: Optional email-specific body text.
        email_action_path: Optional path appended to
            ``settings.frontend_url`` to build the action link.
    """
    recipient_query = (
        select(User.id, User.email)
        .join(ClassMembership, ClassMembership.user_id == User.id)
        .where(
            ClassMembership.class_id == class_id,
            User.status == "approved",
        )
    )
    if exclude_user_id is not None:
        # Honour the exclusion for both the in-app row and the email.
        recipient_query = recipient_query.where(User.id != exclude_user_id)
    result = await db.execute(recipient_query)

    emails: list[str] = []
    for uid, email in result.all():
        db.add(
            Notification(
                user_id=uid,
                type=type,
                title=title,
                content=content,
                related_id=related_id,
            )
        )
        # Skip members without an address rather than handing None/"" to the mailer.
        if email:
            emails.append(email)
    await db.commit()

    # Send email notification in background (fire-and-forget).
    if email_subject and emails:
        # Local import kept from the original (presumably to avoid an import
        # cycle with app.config — confirm).
        from app.config import settings

        body = email_body or content or ""
        action_url = f"{settings.frontend_url}{email_action_path}" if email_action_path else None
        eyebrow, action_label, summary_html = _build_email_meta(type, body)
        task = asyncio.create_task(
            _safe_send_emails(
                emails,
                email_subject,
                title,
                body,
                action_url,
                eyebrow,
                action_label,
                summary_html,
            )
        )
        # Keep the task alive until it finishes, then drop the reference.
        _background_tasks.add(task)
        task.add_done_callback(_background_tasks.discard)
async def _safe_send_emails(
    emails: list[str],
    subject: str,
    title: str,
    body: str,
    action_url: str | None,
    eyebrow: str,
    action_label: str | None,
    summary_html: str | None,
):
    """Send a class notification email, logging (not raising) on failure.

    Designed to run as a fire-and-forget task, where an unhandled exception
    would surface only as an "exception was never retrieved" warning. Any
    ``Exception`` from ``send_class_notification_email`` is therefore caught
    and logged. ``asyncio.CancelledError`` is a ``BaseException`` and still
    propagates.

    Args:
        emails: Recipient addresses.
        subject: Email subject line.
        title: Title shown in the email.
        body: Plain body text.
        action_url: Optional absolute link for the call-to-action.
        eyebrow: Eyebrow label for the template.
        action_label: Optional call-to-action label.
        summary_html: Optional pre-built HTML summary block.
    """
    try:
        await send_class_notification_email(
            emails,
            subject,
            title,
            body,
            action_url,
            eyebrow=eyebrow,
            action_label=action_label,
            summary_html=summary_html,
        )
    except Exception as exc:
        # logger.exception records the traceback, which a message-only
        # logger.error call would discard; lazy %-args avoid eager formatting.
        logger.exception("Failed to send class notification emails: %s", exc)
async def list_notifications(
    db: AsyncSession, user_id: int, page: int = 1, page_size: int = 20
) -> tuple[list[Notification], int]:
    """Return one page of a user's notifications (newest first) and the total.

    Args:
        db: Async session to query with.
        user_id: Owner of the notifications.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Maximum number of notifications per page.

    Returns:
        ``(notifications, total)``. ``total`` counts all of the user's
        notifications, independent of pagination.
    """
    # A page below 1 would yield a negative OFFSET, which some databases reject.
    page = max(page, 1)

    total_result = await db.execute(
        select(func.count(Notification.id)).where(Notification.user_id == user_id)
    )
    total = total_result.scalar() or 0

    result = await db.execute(
        select(Notification)
        .where(Notification.user_id == user_id)
        # id breaks ties between equal created_at values so pages are stable
        # (no row repeated on or skipped between consecutive pages).
        .order_by(Notification.created_at.desc(), Notification.id.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    notifications = list(result.scalars().all())
    return notifications, total
async def get_unread_count(db: AsyncSession, user_id: int) -> int:
    """Count the notifications of ``user_id`` whose ``is_read`` flag is False.

    Returns 0 if the count query yields no value.
    """
    unread_stmt = select(func.count(Notification.id)).where(
        Notification.user_id == user_id,
        # `== False` is deliberate: it builds a SQL comparison expression,
        # which a Python `is False` / `not` check would not.
        Notification.is_read == False,  # noqa: E712
    )
    unread = (await db.execute(unread_stmt)).scalar()
    return unread or 0
async def mark_as_read(db: AsyncSession, notification_id: int, user_id: int) -> bool:
    """Mark a single notification as read, but only if ``user_id`` owns it.

    Commits immediately. Returns ``True`` when the UPDATE's reported
    ``rowcount`` is positive, i.e. the notification exists and belongs to
    the user; otherwise ``False``.
    """
    stmt = (
        update(Notification)
        .where(Notification.id == notification_id, Notification.user_id == user_id)
        .values(is_read=True)
    )
    outcome = await db.execute(stmt)
    await db.commit()
    affected = outcome.rowcount
    return affected > 0
async def mark_all_as_read(db: AsyncSession, user_id: int) -> int:
    """Flag every unread notification of ``user_id`` as read.

    Commits immediately and returns the UPDATE's ``rowcount``, i.e. the
    number of that user's notifications that were unread.
    """
    # SQL comparison expression, hence `== False` rather than `is False`.
    still_unread = Notification.is_read == False  # noqa: E712
    stmt = (
        update(Notification)
        .where(Notification.user_id == user_id, still_unread)
        .values(is_read=True)
    )
    outcome = await db.execute(stmt)
    await db.commit()
    return outcome.rowcount