142 lines
4.3 KiB
Python
142 lines
4.3 KiB
Python
import html
from datetime import datetime

from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.db.models import Vote, VoteOption, VoteResponse, User
from app.schemas.vote import VoteCreate
from app.services.notification_service import create_notifications_for_class
|
|
|
|
|
|
async def create_vote(
    db: AsyncSession, class_id: int, creator_id: int, data: VoteCreate
) -> Vote:
    """Create a vote with its options, then notify the class about it.

    The vote row and one ``VoteOption`` per entry of ``data.options`` are
    committed together. The vote is then re-read with ``options``
    eager-loaded, and a ``"vote"`` notification is created for the class
    via ``create_notifications_for_class``.

    Args:
        db: Async session used for all reads and writes.
        class_id: Class the vote belongs to.
        creator_id: User id of the vote's creator.
        data: Creation payload; ``data.options`` supplies the option texts,
            stored with ``sort_order`` equal to their position.

    Returns:
        The persisted ``Vote`` with its ``options`` relationship loaded.
    """
    vote = Vote(
        class_id=class_id,
        creator_id=creator_id,
        title=data.title,
        description=data.description,
        vote_type=data.vote_type,
        is_anonymous=data.is_anonymous,
        max_choices=data.max_choices,
        deadline=data.deadline,
    )
    db.add(vote)
    # Flush so the database assigns vote.id before the options reference it.
    await db.flush()
    # Capture the key now: the commit below may expire the instance, and
    # touching an expired attribute would trigger implicit I/O on an
    # AsyncSession.
    vote_id = vote.id

    for i, opt_text in enumerate(data.options):
        db.add(VoteOption(vote_id=vote_id, content=opt_text, sort_order=i))

    await db.commit()

    # Re-read the vote with its options eager-loaded. populate_existing
    # overwrites the identity-mapped instance with the fresh row, so no
    # separate refresh() round-trip is needed.
    result = await db.execute(
        select(Vote)
        .options(selectinload(Vote.options))
        .where(Vote.id == vote_id)
        .execution_options(populate_existing=True)
    )
    vote = result.scalar_one()

    # Title/description come from the request payload. Escape them before
    # embedding them in the HTML email body so they cannot inject markup
    # into every class member's mail.
    email_text = html.escape(data.description or data.title)
    # NOTE(review): if create_notifications_for_class commits, the returned
    # instance may be expired again — confirm callers don't rely on it.
    await create_notifications_for_class(
        db,
        class_id,
        "vote",
        f"新投票: {data.title}",
        content=data.description,
        related_id=vote_id,
        email_subject=f"HKU ICB - 新投票: {data.title}",
        email_body=f"<p>{email_text}</p>",
        email_action_path="/votes",
    )

    return vote
|
|
|
|
|
|
async def get_vote_by_id(db: AsyncSession, vote_id: int) -> Vote | None:
    """Look up one vote by primary key with its related data eager-loaded.

    The creator is loaded, along with every option, each option's responses,
    and each response's voter.

    Returns:
        The matching ``Vote``, or ``None`` when no vote has id ``vote_id``.
    """
    options_with_voters = (
        selectinload(Vote.options)
        .selectinload(VoteOption.responses)
        .selectinload(VoteResponse.voter)
    )
    stmt = (
        select(Vote)
        .where(Vote.id == vote_id)
        .options(selectinload(Vote.creator), options_with_voters)
    )
    rows = await db.execute(stmt)
    return rows.scalar_one_or_none()
|
|
|
|
|
|
async def list_votes(
    db: AsyncSession, class_id: int, page: int = 1, page_size: int = 20
) -> tuple[list[Vote], int]:
    """Return one page of a class's votes (newest first) plus the total count.

    Each returned vote has its creator, options, and the options' responses
    eager-loaded.

    Args:
        db: Async session.
        class_id: Class whose votes are listed.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Maximum number of votes per page.

    Returns:
        A ``(votes, total)`` tuple. ``votes`` holds the requested page and
        ``total`` is the number of votes in the class.
    """
    # Clamp so a bad page number cannot produce a negative OFFSET.
    page = max(page, 1)

    total_result = await db.execute(
        select(func.count(Vote.id)).where(Vote.class_id == class_id)
    )
    total = total_result.scalar() or 0

    result = await db.execute(
        select(Vote)
        .options(
            selectinload(Vote.creator),
            selectinload(Vote.options).selectinload(VoteOption.responses),
        )
        .where(Vote.class_id == class_id)
        # Vote.id breaks ties between equal created_at values so that
        # consecutive pages never repeat or skip a vote.
        .order_by(Vote.created_at.desc(), Vote.id.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    return list(result.scalars().all()), total
|
|
|
|
|
|
async def submit_vote(
    db: AsyncSession, vote_id: int, voter_id: int, option_ids: list[int]
) -> None:
    """Record ``voter_id``'s ballot for ``vote_id`` and commit it.

    One ``VoteResponse`` row is inserted per selected option.

    Args:
        db: Async session used for validation queries and inserts.
        vote_id: Vote being answered.
        voter_id: User casting the ballot.
        option_ids: Ids of the selected options. They must be non-empty,
            free of duplicates, and all belong to ``vote_id``.

    Raises:
        ValueError: if the vote does not exist, is not open, is past its
            deadline, or the voter has already voted. Also raised if the
            selection breaks the single/multiple-choice rules, is empty or
            contains duplicates, or names an option of another vote.
    """
    vote_result = await db.execute(select(Vote).where(Vote.id == vote_id))
    vote = vote_result.scalar_one_or_none()
    if vote is None:
        raise ValueError("投票不存在")
    if vote.status != "open":
        raise ValueError("投票已关闭")
    if vote.deadline:
        # datetime.now(tzinfo) gives an aware "now" for an aware deadline and
        # a naive local "now" for a naive one. This matches the previous
        # behaviour and never mixes naive/aware values, which would raise
        # TypeError.
        if datetime.now(vote.deadline.tzinfo) > vote.deadline:
            raise ValueError("投票已过截止日期")

    # A multiple-choice ballot is stored as several rows, so only ask whether
    # any row exists. scalar_one_or_none() over all rows would raise
    # MultipleResultsFound for such a voter.
    # NOTE(review): concurrent submissions can both pass this check; only a
    # DB unique constraint (not visible here) would prevent double ballots.
    existing = await db.execute(
        select(VoteResponse)
        .where(
            VoteResponse.vote_id == vote_id,
            VoteResponse.voter_id == voter_id,
        )
        .limit(1)
    )
    if existing.scalars().first() is not None:
        raise ValueError("你已经投过票了")

    # Validate the shape of the selection.
    if vote.vote_type == "single" and len(option_ids) != 1:
        raise ValueError("单选投票只能选择一个选项")
    if vote.vote_type == "multiple" and len(option_ids) > vote.max_choices:
        raise ValueError(f"最多选择 {vote.max_choices} 个选项")
    if not option_ids:
        # Otherwise nothing is inserted and the call silently "succeeds".
        raise ValueError("请至少选择一个选项")
    if len(set(option_ids)) != len(option_ids):
        # A repeated id would insert two responses for the same option.
        raise ValueError("不能重复选择同一选项")

    # Validate that every option belongs to this vote with a single query.
    valid_result = await db.execute(
        select(VoteOption.id).where(
            VoteOption.vote_id == vote_id,
            VoteOption.id.in_(option_ids),
        )
    )
    valid_ids = set(valid_result.scalars().all())
    for oid in option_ids:
        if oid not in valid_ids:
            raise ValueError(f"选项 {oid} 不属于此投票")

    db.add_all(
        [
            VoteResponse(vote_id=vote_id, option_id=oid, voter_id=voter_id)
            for oid in option_ids
        ]
    )
    await db.commit()
|
|
|
|
|
|
async def close_vote(db: AsyncSession, vote: Vote) -> Vote:
    """Close ``vote`` by setting its status to ``"closed"`` and committing.

    The instance is refreshed after the commit, so the object handed back
    reflects the persisted row.

    Returns:
        The same ``vote`` instance, refreshed.
    """
    vote.status = "closed"
    await db.commit()
    # Reload the attributes from the database before handing the object back.
    await db.refresh(vote)
    return vote
|
|
|
|
|
|
async def delete_vote(db: AsyncSession, vote: Vote):
    """Delete ``vote`` through the session and commit immediately.

    NOTE(review): whether the vote's options and responses are removed with
    it depends on the model's cascade / FK configuration, which is not
    visible here.
    """
    await db.delete(vote)
    await db.commit()
|