"""HTTP routes for classes, class members, class rosters and invite codes.

All routes are mounted under ``/api/classes``. Creating, updating and deleting
classes requires the ``super_admin`` role. Member, roster and invite-code
routes also admit ``class_admin`` users, but only for the class whose id equals
their own ``class_id``.
"""

import csv
import io
import zipfile

from fastapi import APIRouter, Depends, HTTPException, Query, UploadFile, File
from sqlalchemy.ext.asyncio import AsyncSession

from app.core.deps import require_role
from app.db.database import get_db
from app.db.models import User
from app.schemas.class_ import ClassCreate, ClassUpdate, ClassOut
from app.schemas.user import UserListItem
from app.schemas.roster import RosterOut, RosterImportRequest
from app.schemas.common import PageResponse
from app.services.class_service import (
    create_class,
    update_class,
    delete_class,
    get_class_by_id,
    list_classes,
    get_member_count,
    get_class_members,
)
from app.services.roster_service import (
    ensure_invite_code,
    regenerate_invite_code,
    import_roster,
    get_roster,
    delete_roster_entry,
    clear_unregistered_roster,
)

router = APIRouter(prefix="/api/classes", tags=["classes"])

# Accepted header spellings for each roster column, in lookup priority order.
_STUDENT_ID_HEADERS = ("student_id", "学号")
_NAME_HEADERS = ("name", "姓名")


# --- Shared helpers ---


def _ensure_class_access(
    admin: User, class_id: int, detail: str = "Access denied"
) -> None:
    """Raise 403 when a ``class_admin`` targets a class other than their own.

    Users whose role is not ``class_admin`` pass through unchanged.
    """
    if admin.role == "class_admin" and admin.class_id != class_id:
        raise HTTPException(status_code=403, detail=detail)


def _total_pages(total: int, page_size: int) -> int:
    """Return how many pages of *page_size* items are needed for *total* items."""
    return (total + page_size - 1) // page_size


def _first_value(row: dict, keys: tuple[str, ...]) -> str:
    """Return the first non-blank, stripped value among *keys* in *row* ('' if none)."""
    for key in keys:
        value = (row.get(key) or "").strip()
        if value:
            return value
    return ""


def _cell_text(value: object) -> str:
    """Convert a spreadsheet cell value to stripped text ('' for empty cells).

    Integral floats are rendered without a trailing ".0", so a numeric
    student ID read back as 2021001.0 becomes "2021001".
    """
    if value is None:
        return ""
    if isinstance(value, float) and value.is_integer():
        value = int(value)
    return str(value).strip()


def _decode_csv_bytes(contents: bytes) -> str:
    """Decode uploaded CSV bytes as UTF-8 (optional BOM), falling back to GB18030.

    CSV files saved by spreadsheet tools on Chinese-locale systems are
    commonly GBK-encoded; GB18030 is a superset of GBK. Undecodable input is
    reported as a 400 instead of an unhandled UnicodeDecodeError (500).

    Raises:
        HTTPException: 400 if neither encoding can decode the bytes.
    """
    for encoding in ("utf-8-sig", "gb18030"):
        try:
            return contents.decode(encoding)
        except UnicodeDecodeError:
            continue
    raise HTTPException(
        status_code=400, detail="无法识别 CSV 文件编码,请使用 UTF-8 或 GBK 编码"
    )


def _parse_csv_roster(contents: bytes) -> list[dict]:
    """Extract ``{"student_id", "name"}`` entries from CSV bytes.

    Columns may be headed ``student_id``/``学号`` and ``name``/``姓名``; the
    English header wins when both are present and non-blank. Rows missing
    either value are skipped.

    Raises:
        HTTPException: 400 for undecodable or malformed CSV.
    """
    reader = csv.DictReader(io.StringIO(_decode_csv_bytes(contents)))
    entries: list[dict] = []
    try:
        if reader.fieldnames:
            # Tolerate stray whitespace around header names, as the Excel path does.
            reader.fieldnames = [field.strip() for field in reader.fieldnames]
        for row in reader:
            # Strip before testing so whitespace-only cells don't yield empty entries.
            sid = _first_value(row, _STUDENT_ID_HEADERS)
            name = _first_value(row, _NAME_HEADERS)
            if sid and name:
                entries.append({"student_id": sid, "name": name})
    except csv.Error as exc:
        raise HTTPException(status_code=400, detail="CSV 文件格式错误") from exc
    return entries


def _parse_excel_roster(contents: bytes) -> list[dict]:
    """Extract ``{"student_id", "name"}`` entries from the active sheet of a workbook.

    The first row is the header and must contain a ``学号``/``student_id``
    column and a ``姓名``/``name`` column. Rows missing either value are skipped.

    Raises:
        HTTPException: 400 if openpyxl is unavailable, the file cannot be read
            as an .xlsx workbook, it has no data rows, or a required column is missing.
    """
    try:
        import openpyxl
        from openpyxl.utils.exceptions import InvalidFileException
    except ImportError as exc:
        raise HTTPException(
            status_code=400, detail="服务器未安装 openpyxl,请使用 CSV 格式"
        ) from exc

    try:
        wb = openpyxl.load_workbook(io.BytesIO(contents), read_only=True)
    except (zipfile.BadZipFile, InvalidFileException, KeyError, ValueError, OSError) as exc:
        # Legacy .xls files and corrupt uploads are not zip-based .xlsx packages.
        raise HTTPException(
            status_code=400, detail="无法读取 Excel 文件,请使用 .xlsx 格式"
        ) from exc
    try:
        rows = list(wb.active.iter_rows(values_only=True))
    finally:
        # Read-only workbooks must be closed explicitly, even on error paths.
        wb.close()

    if len(rows) < 2:
        raise HTTPException(status_code=400, detail="Excel 文件为空")

    header = [_cell_text(cell) for cell in rows[0]]
    sid_col = None
    name_col = None
    for index, title in enumerate(header):
        if title in _STUDENT_ID_HEADERS:
            sid_col = index
        elif title in _NAME_HEADERS:
            name_col = index
    if sid_col is None or name_col is None:
        raise HTTPException(
            status_code=400,
            detail="Excel 需包含 '学号'(student_id) 和 '姓名'(name) 列",
        )

    entries: list[dict] = []
    for row in rows[1:]:
        # Guard against rows shorter than the header so a ragged sheet cannot IndexError.
        sid = _cell_text(row[sid_col]) if sid_col < len(row) else ""
        name = _cell_text(row[name_col]) if name_col < len(row) else ""
        if sid and name:
            entries.append({"student_id": sid, "name": name})
    return entries


# --- Class CRUD ---


@router.get("/", response_model=PageResponse[ClassOut])
async def get_classes(
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1),
    db: AsyncSession = Depends(get_db),
):
    """List classes page by page, each annotated with its member count.

    ``page`` and ``page_size`` must be >= 1; ``page_size=0`` previously
    crashed the page-count computation with ZeroDivisionError.

    NOTE(review): this route declares no role dependency of its own -- confirm
    that unauthenticated listing is intended. Member counts are fetched with
    one ``get_member_count`` call per class on the page.
    """
    classes, total = await list_classes(db, page, page_size)
    items = []
    for class_ in classes:
        out = ClassOut.model_validate(class_)
        out.member_count = await get_member_count(db, class_.id)
        items.append(out)
    return PageResponse(
        items=items,
        total=total,
        page=page,
        page_size=page_size,
        total_pages=_total_pages(total, page_size),
    )


@router.post("/", response_model=ClassOut)
async def create_new_class(
    data: ClassCreate,
    admin: User = Depends(require_role("super_admin")),
    db: AsyncSession = Depends(get_db),
):
    """Create a class (super_admin only); the new class reports zero members."""
    class_ = await create_class(db, data)
    out = ClassOut.model_validate(class_)
    out.member_count = 0
    return out


@router.put("/{class_id}", response_model=ClassOut)
async def update_existing_class(
    class_id: int,
    data: ClassUpdate,
    admin: User = Depends(require_role("super_admin")),
    db: AsyncSession = Depends(get_db),
):
    """Update a class (super_admin only) and return it with its member count.

    Raises:
        HTTPException: 404 if the class does not exist.
    """
    class_ = await get_class_by_id(db, class_id)
    if class_ is None:
        raise HTTPException(status_code=404, detail="Class not found")
    updated = await update_class(db, class_, data)
    out = ClassOut.model_validate(updated)
    out.member_count = await get_member_count(db, class_id)
    return out


@router.delete("/{class_id}")
async def delete_existing_class(
    class_id: int,
    admin: User = Depends(require_role("super_admin")),
    db: AsyncSession = Depends(get_db),
):
    """Delete a class (super_admin only).

    Raises:
        HTTPException: 404 if the class does not exist.
    """
    class_ = await get_class_by_id(db, class_id)
    if class_ is None:
        raise HTTPException(status_code=404, detail="Class not found")
    await delete_class(db, class_)
    return {"message": "Class deleted"}


@router.get("/{class_id}/members", response_model=PageResponse[UserListItem])
async def get_members(
    class_id: int,
    status: str | None = None,
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1),
    admin: User = Depends(require_role("super_admin", "class_admin")),
    db: AsyncSession = Depends(get_db),
):
    """List a class's members page by page, optionally filtered by ``status``.

    Raises:
        HTTPException: 403 if a class_admin requests another class.
    """
    _ensure_class_access(admin, class_id, "Access denied for this class")
    members, total = await get_class_members(db, class_id, status, page, page_size)
    return PageResponse(
        items=[UserListItem.model_validate(m) for m in members],
        total=total,
        page=page,
        page_size=page_size,
        total_pages=_total_pages(total, page_size),
    )


# --- Roster management ---


@router.get("/{class_id}/roster", response_model=PageResponse[RosterOut])
async def get_class_roster(
    class_id: int,
    page: int = Query(1, ge=1),
    page_size: int = Query(50, ge=1),
    admin: User = Depends(require_role("super_admin", "class_admin")),
    db: AsyncSession = Depends(get_db),
):
    """List a class's roster entries page by page.

    Raises:
        HTTPException: 403 if a class_admin requests another class.
    """
    _ensure_class_access(admin, class_id)
    entries, total = await get_roster(db, class_id, page, page_size)
    return PageResponse(
        items=[RosterOut.model_validate(e) for e in entries],
        total=total,
        page=page,
        page_size=page_size,
        total_pages=_total_pages(total, page_size),
    )


@router.post("/{class_id}/roster/import")
async def import_class_roster(
    class_id: int,
    data: RosterImportRequest,
    admin: User = Depends(require_role("super_admin", "class_admin")),
    db: AsyncSession = Depends(get_db),
):
    """Import roster entries supplied in the JSON request body.

    Raises:
        HTTPException: 403 if a class_admin targets another class.
    """
    _ensure_class_access(admin, class_id)
    count = await import_roster(db, class_id, data.entries)
    return {"message": f"成功导入 {count} 条记录"}


@router.post("/{class_id}/roster/upload")
async def upload_roster_file(
    class_id: int,
    file: UploadFile = File(...),
    admin: User = Depends(require_role("super_admin", "class_admin")),
    db: AsyncSession = Depends(get_db),
):
    """Import roster entries from an uploaded ``.csv`` or ``.xlsx`` file.

    The extension check is case-insensitive. ``.xls`` passes the extension
    check but is rejected with a 400 when openpyxl cannot open it.

    Raises:
        HTTPException: 403 if a class_admin targets another class; 400 for an
            unsupported, unreadable or malformed file, or one with no valid rows.
    """
    _ensure_class_access(admin, class_id)

    filename = (file.filename or "").lower()
    if filename.endswith(".csv"):
        parse = _parse_csv_roster
    elif filename.endswith((".xlsx", ".xls")):
        parse = _parse_excel_roster
    else:
        raise HTTPException(status_code=400, detail="仅支持 CSV 或 Excel (.xlsx) 文件")

    entries = parse(await file.read())
    if not entries:
        raise HTTPException(status_code=400, detail="未找到有效数据")

    count = await import_roster(db, class_id, entries)
    return {"message": f"成功导入 {count} 条记录"}


@router.delete("/{class_id}/roster/{roster_id}")
async def delete_roster_item(
    class_id: int,
    roster_id: int,
    admin: User = Depends(require_role("super_admin", "class_admin")),
    db: AsyncSession = Depends(get_db),
):
    """Delete one roster entry.

    Raises:
        HTTPException: 403 if a class_admin targets another class; 400 when
            ``delete_roster_entry`` reports failure.
    """
    # Previously missing: any class_admin could delete entries of any class.
    _ensure_class_access(admin, class_id)
    # NOTE(review): delete_roster_entry receives only roster_id, so nothing here
    # verifies the entry belongs to class_id -- confirm the service enforces it.
    success = await delete_roster_entry(db, roster_id)
    if not success:
        raise HTTPException(status_code=400, detail="无法删除(已注册或不存在)")
    return {"message": "已删除"}


@router.post("/{class_id}/roster/clear")
async def clear_roster(
    class_id: int,
    admin: User = Depends(require_role("super_admin", "class_admin")),
    db: AsyncSession = Depends(get_db),
):
    """Remove the class's roster entries via ``clear_unregistered_roster``.

    Raises:
        HTTPException: 403 if a class_admin targets another class.
    """
    _ensure_class_access(admin, class_id)
    count = await clear_unregistered_roster(db, class_id)
    return {"message": f"已清除 {count} 条未注册记录"}


# --- Invite code management ---


@router.get("/{class_id}/invite-code")
async def get_invite_code(
    class_id: int,
    admin: User = Depends(require_role("super_admin", "class_admin")),
    db: AsyncSession = Depends(get_db),
):
    """Return the class invite code obtained from ``ensure_invite_code``.

    Raises:
        HTTPException: 403 if a class_admin requests another class; 404 when
            the service returns a falsy code.
    """
    # Previously missing: any class_admin could read other classes' invite codes.
    _ensure_class_access(admin, class_id)
    code = await ensure_invite_code(db, class_id)
    if not code:
        raise HTTPException(status_code=404, detail="Class not found")
    return {"invite_code": code}


@router.post("/{class_id}/invite-code/regenerate")
async def regenerate_invite(
    class_id: int,
    admin: User = Depends(require_role("super_admin", "class_admin")),
    db: AsyncSession = Depends(get_db),
):
    """Regenerate the class invite code and return the new value.

    Raises:
        HTTPException: 403 if a class_admin targets another class; 404 when
            the service returns a falsy code.
    """
    # Previously missing: any class_admin could rotate other classes' invite codes.
    _ensure_class_access(admin, class_id)
    code = await regenerate_invite_code(db, class_id)
    if not code:
        raise HTTPException(status_code=404, detail="Class not found")
    return {"invite_code": code}