from datetime import datetime

from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.db.models import Vote, VoteOption, VoteResponse, User
from app.schemas.vote import VoteCreate
from app.services.notification_service import create_notifications_for_class


async def create_vote(
    db: AsyncSession, class_id: int, creator_id: int, data: VoteCreate
) -> Vote:
    """Create a vote with its options, commit it, and notify the class.

    The vote and its options are committed together. The vote is then re-read
    with ``options`` eagerly loaded. Finally, a "vote" notification (with an
    e-mail subject/body) is sent to the class via
    ``create_notifications_for_class``.

    Args:
        db: Async database session; this function commits it.
        class_id: Class the vote belongs to.
        creator_id: User creating the vote.
        data: Vote payload. Each entry of ``data.options`` becomes a
            ``VoteOption`` whose ``sort_order`` is its position in the list.

    Returns:
        The persisted ``Vote`` with ``options`` loaded.
    """
    vote = Vote(
        class_id=class_id,
        creator_id=creator_id,
        title=data.title,
        description=data.description,
        vote_type=data.vote_type,
        is_anonymous=data.is_anonymous,
        max_choices=data.max_choices,
        deadline=data.deadline,
    )
    db.add(vote)
    # Flush so the database assigns vote.id before the options reference it.
    await db.flush()
    # Capture the primary key now. By default commit() expires the instance,
    # and touching an expired attribute through an AsyncSession would need
    # implicit IO; holding the id avoids an extra refresh round trip.
    vote_id = vote.id

    db.add_all(
        [
            VoteOption(vote_id=vote_id, content=opt_text, sort_order=i)
            for i, opt_text in enumerate(data.options)
        ]
    )
    await db.commit()

    # Re-read the vote with its options eagerly loaded for the caller.
    result = await db.execute(
        select(Vote)
        .options(selectinload(Vote.options))
        .where(Vote.id == vote_id)
    )
    vote = result.scalar_one()

    await create_notifications_for_class(
        db,
        class_id,
        "vote",
        f"新投票: {data.title}",
        content=data.description,
        related_id=vote_id,
        email_subject=f"HKU ICB - 新投票: {data.title}",
        # NOTE(review): confirm the intended e-mail body template; any markup
        # around the description may have been lost.
        email_body=f"\n{data.description or data.title}\n",
        email_action_path="/votes",
    )
    return vote


async def get_vote_by_id(db: AsyncSession, vote_id: int) -> Vote | None:
    """Return the vote with ``vote_id``, or ``None`` if it does not exist.

    The creator, the options, each option's responses, and each response's
    voter are eagerly loaded.
    """
    result = await db.execute(
        select(Vote)
        .options(
            selectinload(Vote.creator),
            selectinload(Vote.options)
            .selectinload(VoteOption.responses)
            .selectinload(VoteResponse.voter),
        )
        .where(Vote.id == vote_id)
    )
    return result.scalar_one_or_none()


async def list_votes(
    db: AsyncSession, class_id: int, page: int = 1, page_size: int = 20
) -> tuple[list[Vote], int]:
    """Return one page of a class's votes (newest first) and the total count.

    Args:
        db: Async database session.
        class_id: Class whose votes are listed.
        page: 1-based page number.
        page_size: Maximum number of votes per page.

    Returns:
        ``(votes, total)``: ``votes`` has the same eager loads as
        ``get_vote_by_id``; ``total`` counts all votes in the class.
    """
    total_result = await db.execute(
        select(func.count(Vote.id)).where(Vote.class_id == class_id)
    )
    total = total_result.scalar() or 0

    result = await db.execute(
        select(Vote)
        .options(
            selectinload(Vote.creator),
            selectinload(Vote.options)
            .selectinload(VoteOption.responses)
            .selectinload(VoteResponse.voter),
        )
        .where(Vote.class_id == class_id)
        .order_by(Vote.created_at.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    return list(result.scalars().all()), total


async def submit_vote(
    db: AsyncSession, vote_id: int, voter_id: int, option_ids: list[int]
) -> None:
    """Record ``voter_id``'s ballot for ``vote_id`` and commit.

    One ``VoteResponse`` row is inserted per selected option.

    Raises:
        ValueError: if the vote does not exist, is not open, or is past its
            deadline; if the voter already voted; if the selection has the
            wrong size for the vote type, is empty, or contains duplicates;
            or if an option id does not belong to this vote.
    """
    vote_result = await db.execute(select(Vote).where(Vote.id == vote_id))
    vote = vote_result.scalar_one_or_none()
    if vote is None:
        raise ValueError("投票不存在")
    if vote.status != "open":
        raise ValueError("投票已关闭")
    # NOTE(review): datetime.now() is naive local time; this assumes
    # vote.deadline is stored naive in the same timezone — confirm.
    if vote.deadline and datetime.now() > vote.deadline:
        raise ValueError("投票已过截止日期")

    # A multiple-choice ballot is stored as several rows, so only ask whether
    # any row exists. Without LIMIT 1, scalar_one_or_none() raises
    # MultipleResultsFound instead of the intended error.
    # NOTE(review): this check-then-insert is not atomic; concurrent
    # submissions can both pass unless a DB constraint prevents it.
    existing = await db.execute(
        select(VoteResponse)
        .where(
            VoteResponse.vote_id == vote_id,
            VoteResponse.voter_id == voter_id,
        )
        .limit(1)
    )
    if existing.scalar_one_or_none() is not None:
        raise ValueError("你已经投过票了")

    # Validate the selection size for the vote type.
    if vote.vote_type == "single" and len(option_ids) != 1:
        raise ValueError("单选投票只能选择一个选项")
    if vote.vote_type == "multiple" and len(option_ids) > vote.max_choices:
        raise ValueError(f"最多选择 {vote.max_choices} 个选项")
    # An empty ballot would insert nothing and leave the voter free to vote
    # again later, so reject it explicitly.
    if not option_ids:
        raise ValueError("请至少选择一个选项")
    # Duplicate ids would store several responses for the same option,
    # counting this voter more than once.
    if len(set(option_ids)) != len(option_ids):
        raise ValueError("选项不能重复")

    # Check all option ids in one query (instead of one query per option),
    # then report the first id that does not belong to this vote.
    valid_result = await db.execute(
        select(VoteOption.id).where(
            VoteOption.vote_id == vote_id,
            VoteOption.id.in_(option_ids),
        )
    )
    valid_ids = set(valid_result.scalars().all())
    for oid in option_ids:
        if oid not in valid_ids:
            raise ValueError(f"选项 {oid} 不属于此投票")

    db.add_all(
        [
            VoteResponse(vote_id=vote_id, option_id=oid, voter_id=voter_id)
            for oid in option_ids
        ]
    )
    await db.commit()


async def close_vote(db: AsyncSession, vote: Vote) -> Vote:
    """Set ``vote.status`` to ``"closed"``, commit, and return the refreshed vote."""
    vote.status = "closed"
    await db.commit()
    await db.refresh(vote)
    return vote


async def delete_vote(db: AsyncSession, vote: Vote) -> None:
    """Delete ``vote`` and commit."""
    await db.delete(vote)
    await db.commit()