diff --git a/backend/app/api/auth.py b/backend/app/api/auth.py index 0f5caa6..72547c0 100644 --- a/backend/app/api/auth.py +++ b/backend/app/api/auth.py @@ -1,39 +1,57 @@ -import json - -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.core.auth import hash_password, verify_password, create_access_token from app.core.deps import get_current_user from app.db.database import get_db -from app.db.models import User, Class_ +from app.db.models import User from app.schemas.auth import LoginRequest, RegisterRequest, ChangePasswordRequest from app.schemas.user import TokenResponse, UserOut -from app.services.user_service import register_user +from app.services.roster_service import validate_registration router = APIRouter(prefix="/api/auth", tags=["auth"]) @router.post("/register") async def register(req: RegisterRequest, db: AsyncSession = Depends(get_db)): + # 1. Check if email is already registered existing = await db.execute(select(User).where(User.email == req.email)) if existing.scalar_one_or_none(): - raise HTTPException(status_code=400, detail="Email already registered") + raise HTTPException(status_code=400, detail="该邮箱已注册") - class_result = await db.execute(select(Class_).where(Class_.id == req.class_id)) - if class_result.scalar_one_or_none() is None: - raise HTTPException(status_code=400, detail="Class not found") + # 2. Validate invite_code + student_id against roster + roster_entry = await validate_registration(db, req.invite_code, req.student_id) + if roster_entry is None: + raise HTTPException( + status_code=400, detail="邀请码或学号无效,或该学号已注册" + ) - user = await register_user( - db=db, + # 3. Create user with approved status directly + user = User( email=req.email, password_hash=hash_password(req.password), name=req.name, - class_id=req.class_id, student_id=req.student_id, + role="student", + status="approved", + class_id=roster_entry.class_id, ) - return {"message": "Registration submitted. Awaiting admin approval."} + db.add(user) + await db.flush() + + # 4. Mark roster entry as registered + roster_entry.status = "registered" + roster_entry.user_id = user.id + await db.commit() + + # 5. Issue token — register and login in one step + token = create_access_token({"sub": str(user.id), "role": user.role}) + return { + "message": "注册成功", + "token": token, + "user": UserOut.model_validate(user), + } @router.post("/login", response_model=TokenResponse) @@ -41,13 +59,14 @@ async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)): result = await db.execute(select(User).where(User.email == req.email)) user = result.scalar_one_or_none() - if user is None or user.status != "approved": - raise HTTPException( - status_code=401, detail="Invalid credentials or account not approved" - ) + if user is None: + raise HTTPException(status_code=401, detail="邮箱或密码错误") + + if user.status == "disabled": + raise HTTPException(status_code=401, detail="账号已被禁用") if not verify_password(req.password, user.password_hash): - raise HTTPException(status_code=401, detail="Invalid credentials") + raise HTTPException(status_code=401, detail="邮箱或密码错误") token = create_access_token({"sub": str(user.id), "role": user.role}) return TokenResponse( diff --git a/backend/app/api/classes.py b/backend/app/api/classes.py index 112c365..cbdca3e 100644 --- a/backend/app/api/classes.py +++ b/backend/app/api/classes.py @@ -1,11 +1,15 @@ -from fastapi import APIRouter, Depends, HTTPException, status +import csv +import io + +from fastapi import APIRouter, Depends, HTTPException, UploadFile, File from sqlalchemy.ext.asyncio import AsyncSession -from app.core.deps import get_current_user, require_role +from app.core.deps import require_role from app.db.database import get_db from app.db.models import User from app.schemas.class_ import ClassCreate, ClassUpdate, ClassOut from app.schemas.user import UserListItem +from app.schemas.roster import RosterOut, RosterImportRequest from app.schemas.common import PageResponse from app.services.class_service import ( create_class, @@ -16,6 +20,14 @@ from app.services.class_service import ( get_member_count, get_class_members, ) +from app.services.roster_service import ( + ensure_invite_code, + regenerate_invite_code, + import_roster, + get_roster, + delete_roster_entry, + clear_unregistered_roster, +) router = APIRouter(prefix="/api/classes", tags=["classes"]) @@ -103,8 +115,11 @@ async def get_members( ) -@router.get("/{class_id}/pending", response_model=PageResponse[UserListItem]) -async def get_pending_members( +# --- Roster management --- + + +@router.get("/{class_id}/roster", response_model=PageResponse[RosterOut]) +async def get_class_roster( class_id: int, page: int = 1, page_size: int = 50, @@ -112,14 +127,144 @@ async def get_pending_members( db: AsyncSession = Depends(get_db), ): if admin.role == "class_admin" and admin.class_id != class_id: - raise HTTPException(status_code=403, detail="Access denied for this class") - - members, total = await get_class_members(db, class_id, status="pending", page=page, page_size=page_size) + raise HTTPException(status_code=403, detail="Access denied") + entries, total = await get_roster(db, class_id, page, page_size) total_pages = (total + page_size - 1) // page_size return PageResponse( - items=[UserListItem.model_validate(m) for m in members], + items=[RosterOut.model_validate(e) for e in entries], total=total, page=page, page_size=page_size, total_pages=total_pages, ) + + +@router.post("/{class_id}/roster/import") +async def import_class_roster( + class_id: int, + data: RosterImportRequest, + admin: User = Depends(require_role("super_admin", "class_admin")), + db: AsyncSession = Depends(get_db), +): + if admin.role == "class_admin" and admin.class_id != class_id: + raise HTTPException(status_code=403, detail="Access denied") + count = await import_roster(db, class_id, data.entries) + return {"message": f"成功导入 {count} 条记录"} + + +@router.post("/{class_id}/roster/upload") +async def upload_roster_file( + class_id: int, + file: UploadFile = File(...), + admin: User = Depends(require_role("super_admin", "class_admin")), + db: AsyncSession = Depends(get_db), +): + if admin.role == "class_admin" and admin.class_id != class_id: + raise HTTPException(status_code=403, detail="Access denied") + + contents = await file.read() + filename = file.filename or "" + + entries: list[dict] = [] + + if filename.endswith(".csv"): + text = contents.decode("utf-8-sig") + reader = csv.DictReader(io.StringIO(text)) + for row in reader: + sid = row.get("student_id") or row.get("学号") or "" + name = row.get("name") or row.get("姓名") or "" + if sid and name: + entries.append({"student_id": sid.strip(), "name": name.strip()}) + elif filename.endswith((".xlsx", ".xls")): + try: + import openpyxl + + wb = openpyxl.load_workbook(io.BytesIO(contents), read_only=True) + ws = wb.active + rows = list(ws.iter_rows(values_only=True)) + if len(rows) < 2: + raise HTTPException(status_code=400, detail="Excel 文件为空") + header = [str(h).strip() if h else "" for h in rows[0]] + # Find student_id and name columns + sid_col = None + name_col = None + for i, h in enumerate(header): + if h in ("student_id", "学号"): + sid_col = i + elif h in ("name", "姓名"): + name_col = i + if sid_col is None or name_col is None: + raise HTTPException( + status_code=400, + detail="Excel 需包含 '学号'(student_id) 和 '姓名'(name) 列", + ) + for row in rows[1:]: + sid = str(row[sid_col]).strip() if row[sid_col] else "" + name = str(row[name_col]).strip() if row[name_col] else "" + if sid and name and sid != "None": + entries.append({"student_id": sid, "name": name}) + wb.close() + except ImportError: + raise HTTPException( + status_code=400, detail="服务器未安装 openpyxl,请使用 CSV 格式" + ) + else: + raise HTTPException(status_code=400, detail="仅支持 CSV 或 Excel (.xlsx) 文件") + + if not entries: + raise HTTPException(status_code=400, detail="未找到有效数据") + + count = await import_roster(db, class_id, entries) + return {"message": f"成功导入 {count} 条记录"} + + +@router.delete("/{class_id}/roster/{roster_id}") +async def delete_roster_item( + class_id: int, + roster_id: int, + admin: User = Depends(require_role("super_admin", "class_admin")), + db: AsyncSession = Depends(get_db), +): + success = await delete_roster_entry(db, roster_id) + if not success: + raise HTTPException(status_code=400, detail="无法删除(已注册或不存在)") + return {"message": "已删除"} + + +@router.post("/{class_id}/roster/clear") +async def clear_roster( + class_id: int, + admin: User = Depends(require_role("super_admin", "class_admin")), + db: AsyncSession = Depends(get_db), +): + if admin.role == "class_admin" and admin.class_id != class_id: + raise HTTPException(status_code=403, detail="Access denied") + count = await clear_unregistered_roster(db, class_id) + return {"message": f"已清除 {count} 条未注册记录"} + + +# --- Invite code management --- + + +@router.get("/{class_id}/invite-code") +async def get_invite_code( + class_id: int, + admin: User = Depends(require_role("super_admin", "class_admin")), + db: AsyncSession = Depends(get_db), +): + code = await ensure_invite_code(db, class_id) + if not code: + raise HTTPException(status_code=404, detail="Class not found") + return {"invite_code": code} + + +@router.post("/{class_id}/invite-code/regenerate") +async def regenerate_invite( + class_id: int, + admin: User = Depends(require_role("super_admin", "class_admin")), + db: AsyncSession = Depends(get_db), +): + code = await regenerate_invite_code(db, class_id) + if not code: + raise HTTPException(status_code=404, detail="Class not found") + return {"invite_code": code} diff --git a/backend/app/api/users.py b/backend/app/api/users.py index e008035..57f594d 100644 --- a/backend/app/api/users.py +++ b/backend/app/api/users.py @@ -31,7 +31,10 @@ async def update_my_profile( user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): - updated = await update_profile(db, user, data) + try: + updated = await update_profile(db, user, data) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) return UserOut.model_validate(updated) diff --git a/backend/app/core/deps.py b/backend/app/core/deps.py index ee713f2..e1a8bba 100644 --- a/backend/app/core/deps.py +++ b/backend/app/core/deps.py @@ -35,9 +35,9 @@ async def get_current_user( raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found" ) - if user.status != "approved": + if user.status == "disabled": raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail="Account not approved" + status_code=status.HTTP_403_FORBIDDEN, detail="Account disabled" ) return user diff --git a/backend/app/db/models.py b/backend/app/db/models.py index 59f8187..47fc1b2 100644 --- a/backend/app/db/models.py +++ b/backend/app/db/models.py @@ -14,6 +14,7 @@ class Class_(Base): name: Mapped[str] = mapped_column(String(100), nullable=False) cohort_year: Mapped[int] = mapped_column(Integer, nullable=False) description: Mapped[str | None] = mapped_column(Text, nullable=True) + invite_code: Mapped[str | None] = mapped_column(String(20), unique=True, nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now()) updated_at: Mapped[datetime] = mapped_column( DateTime, server_default=func.now(), onupdate=func.now() @@ -32,6 +33,9 @@ class Class_(Base): resources: Mapped[list["Resource"]] = relationship( "Resource", back_populates="class_", cascade="all, delete-orphan" ) + roster: Mapped[list["StudentRoster"]] = relationship( + "StudentRoster", back_populates="class_", cascade="all, delete-orphan" + ) class User(Base): @@ -199,3 +203,24 @@ class Notification(Base): created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now()) user: Mapped["User"] = relationship("User") + + +class StudentRoster(Base): + __tablename__ = "student_rosters" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + class_id: Mapped[int] = mapped_column( + Integer, ForeignKey("classes.id"), nullable=False, index=True + ) + student_id: Mapped[str] = mapped_column(String(50), nullable=False) + name: Mapped[str] = mapped_column(String(100), nullable=False) + status: Mapped[str] = mapped_column( + String(20), default="unregistered", nullable=False + ) + user_id: Mapped[int | None] = mapped_column( + Integer, ForeignKey("users.id"), nullable=True + ) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now()) + + class_: Mapped["Class_"] = relationship("Class_", back_populates="roster") + user: Mapped["User | None"] = relationship("User") diff --git a/backend/app/schemas/auth.py b/backend/app/schemas/auth.py index c245337..851f3f7 100644 --- a/backend/app/schemas/auth.py +++ b/backend/app/schemas/auth.py @@ -7,11 +7,11 @@ class LoginRequest(BaseModel): class RegisterRequest(BaseModel): + invite_code: str + student_id: str + name: str email: EmailStr password: str - name: str - class_id: int - student_id: str class ChangePasswordRequest(BaseModel): diff --git a/backend/app/schemas/class_.py b/backend/app/schemas/class_.py index 3f78265..5615b24 100644 --- a/backend/app/schemas/class_.py +++ b/backend/app/schemas/class_.py @@ -20,6 +20,7 @@ class ClassOut(BaseModel): name: str cohort_year: int description: str | None + invite_code: str | None = None member_count: int = 0 created_at: datetime diff --git a/backend/app/schemas/roster.py b/backend/app/schemas/roster.py new file mode 100644 index 0000000..75640e7 --- /dev/null +++ b/backend/app/schemas/roster.py @@ -0,0 +1,15 @@ +from pydantic import BaseModel + + +class RosterOut(BaseModel): + id: int + student_id: str + name: str + status: str # "unregistered" | "registered" + user_id: int | None + + model_config = {"from_attributes": True} + + +class RosterImportRequest(BaseModel): + entries: list[dict] # [{"student_id": "...", "name": "..."}, ...] diff --git a/backend/app/schemas/user.py b/backend/app/schemas/user.py index 1e13277..eaf4184 100644 --- a/backend/app/schemas/user.py +++ b/backend/app/schemas/user.py @@ -43,7 +43,6 @@ class UserPublic(BaseModel): industry: str | None company: str | None position: str | None - skills_tags: list[str] | None wechat_id: str | None phone: str | None avatar_url: str | None @@ -67,11 +66,11 @@ class UserListItem(BaseModel): class UserUpdate(BaseModel): + email: EmailStr | None = None name: str | None = None industry: str | None = None company: str | None = None position: str | None = None - skills_tags: list[str] | None = None wechat_id: str | None = None phone: str | None = None bio: str | None = None diff --git a/backend/app/services/directory_service.py b/backend/app/services/directory_service.py index 6ef826a..72b56a5 100644 --- a/backend/app/services/directory_service.py +++ b/backend/app/services/directory_service.py @@ -70,7 +70,6 @@ def user_to_public(user: User, include_contact: bool = True) -> UserPublic: industry=user.industry, company=user.company, position=user.position, - skills_tags=user.get_skills_list(), wechat_id=user.wechat_id if include_contact else None, phone=user.phone if include_contact else None, avatar_url=user.avatar_url, diff --git a/backend/app/services/roster_service.py b/backend/app/services/roster_service.py new file mode 100644 index 0000000..b0b2fc0 --- /dev/null +++ b/backend/app/services/roster_service.py @@ -0,0 +1,125 @@ +import secrets + +from sqlalchemy import select, func +from sqlalchemy.ext.asyncio import AsyncSession + +from app.db.models import StudentRoster, Class_ + + +def generate_invite_code(length: int = 8) -> str: + chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" + return "".join(secrets.choice(chars) for _ in range(length)) + + +async def ensure_invite_code(db: AsyncSession, class_id: int) -> str: + result = await db.execute(select(Class_).where(Class_.id == class_id)) + class_ = result.scalar_one_or_none() + if class_ is None: + return "" + if not class_.invite_code: + class_.invite_code = generate_invite_code() + await db.commit() + await db.refresh(class_) + return class_.invite_code + + +async def regenerate_invite_code(db: AsyncSession, class_id: int) -> str: + result = await db.execute(select(Class_).where(Class_.id == class_id)) + class_ = result.scalar_one_or_none() + if class_ is None: + return "" + class_.invite_code = generate_invite_code() + await db.commit() + await db.refresh(class_) + return class_.invite_code + + +async def import_roster( + db: AsyncSession, class_id: int, entries: list[dict] +) -> int: + existing_ids: set[str] = set() + result = await db.execute( + select(StudentRoster.student_id).where(StudentRoster.class_id == class_id) + ) + for row in result.all(): + existing_ids.add(row[0]) + + count = 0 + for entry in entries: + sid = entry.get("student_id", "").strip() + name = entry.get("name", "").strip() + if not sid or not name or sid in existing_ids: + continue + roster = StudentRoster(class_id=class_id, student_id=sid, name=name) + db.add(roster) + existing_ids.add(sid) + count += 1 + await db.commit() + return count + + +async def get_roster( + db: AsyncSession, class_id: int, page: int = 1, page_size: int = 50 +) -> tuple[list[StudentRoster], int]: + query = select(StudentRoster).where(StudentRoster.class_id == class_id) + count_query = select(func.count(StudentRoster.id)).where( + StudentRoster.class_id == class_id + ) + + total_result = await db.execute(count_query) + total = total_result.scalar() or 0 + + result = await db.execute( + query.order_by(StudentRoster.student_id) + .offset((page - 1) * page_size) + .limit(page_size) + ) + return list(result.scalars().all()), total + + +async def validate_registration( + db: AsyncSession, invite_code: str, student_id: str +) -> StudentRoster | None: + class_result = await db.execute( + select(Class_).where(Class_.invite_code == invite_code) + ) + class_ = class_result.scalar_one_or_none() + if class_ is None: + return None + + roster_result = await db.execute( + select(StudentRoster).where( + StudentRoster.class_id == class_.id, + StudentRoster.student_id == student_id, + StudentRoster.status == "unregistered", + ) + ) + return roster_result.scalar_one_or_none() + + +async def delete_roster_entry(db: AsyncSession, roster_id: int) -> bool: + result = await db.execute( + select(StudentRoster).where(StudentRoster.id == roster_id) + ) + entry = result.scalar_one_or_none() + if entry is None: + return False + if entry.status == "registered": + return False + await db.delete(entry) + await db.commit() + return True + + +async def clear_unregistered_roster(db: AsyncSession, class_id: int) -> int: + result = await db.execute( + select(StudentRoster).where( + StudentRoster.class_id == class_id, + StudentRoster.status == "unregistered", + ) + ) + entries = list(result.scalars().all()) + for entry in entries: + await db.delete(entry) + await db.commit() + return len(entries) diff --git a/backend/app/services/user_service.py b/backend/app/services/user_service.py index c8787de..a951866 100644 --- a/backend/app/services/user_service.py +++ b/backend/app/services/user_service.py @@ -1,9 +1,8 @@ from sqlalchemy import select, func from sqlalchemy.ext.asyncio import AsyncSession -from app.db.models import User, Class_ +from app.db.models import User from app.schemas.user import UserOut, UserUpdate -from app.services.email_service import send_registration_notification async def get_user_by_email(db: AsyncSession, email: str) -> User | None: @@ -16,52 +15,18 @@ async def get_user_by_id(db: AsyncSession, user_id: int) -> User | None: return result.scalar_one_or_none() -async def register_user( - db: AsyncSession, - email: str, - password_hash: str, - name: str, - class_id: int, - student_id: str | None = None, -) -> User: - user = User( - email=email, - password_hash=password_hash, - name=name, - student_id=student_id, - role="student", - status="pending", - class_id=class_id, - ) - db.add(user) - await db.commit() - await db.refresh(user) - - # Notify class admins - admins_result = await db.execute( - select(User).where( - User.class_id == class_id, - User.role.in_(["class_admin", "super_admin"]), - User.status == "approved", - ) - ) - class_result = await db.execute(select(Class_).where(Class_.id == class_id)) - class_ = class_result.scalar_one_or_none() - class_name = class_.name if class_ else "Unknown" - - for admin in admins_result.scalars(): - await send_registration_notification(admin.email, name, class_name) - - return user - - async def update_profile(db: AsyncSession, user: User, data: UserUpdate) -> User: update_data = data.model_dump(exclude_unset=True) - if "skills_tags" in update_data and update_data["skills_tags"] is not None: - import json - user.skills_tags = json.dumps( - update_data.pop("skills_tags"), ensure_ascii=False + + # Handle email change with uniqueness check + if "email" in update_data and update_data["email"] != user.email: + existing = await db.execute( + select(User).where(User.email == update_data["email"]) ) + if existing.scalar_one_or_none(): + raise ValueError("该邮箱已被使用") + user.email = update_data.pop("email") + for field, value in update_data.items(): setattr(user, field, value) await db.commit()