This commit is contained in:
aaron 2026-04-11 23:08:50 +08:00
parent 85f6b7e42b
commit 31e7a598dc
12 changed files with 376 additions and 80 deletions

View File

@ -1,39 +1,57 @@
import json
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import hash_password, verify_password, create_access_token
from app.core.deps import get_current_user
from app.db.database import get_db
from app.db.models import User, Class_
from app.db.models import User
from app.schemas.auth import LoginRequest, RegisterRequest, ChangePasswordRequest
from app.schemas.user import TokenResponse, UserOut
from app.services.user_service import register_user
from app.services.roster_service import validate_registration
router = APIRouter(prefix="/api/auth", tags=["auth"])
@router.post("/register")
async def register(req: RegisterRequest, db: AsyncSession = Depends(get_db)):
# 1. Check if email is already registered
existing = await db.execute(select(User).where(User.email == req.email))
if existing.scalar_one_or_none():
raise HTTPException(status_code=400, detail="Email already registered")
raise HTTPException(status_code=400, detail="该邮箱已注册")
class_result = await db.execute(select(Class_).where(Class_.id == req.class_id))
if class_result.scalar_one_or_none() is None:
raise HTTPException(status_code=400, detail="Class not found")
# 2. Validate invite_code + student_id against roster
roster_entry = await validate_registration(db, req.invite_code, req.student_id)
if roster_entry is None:
raise HTTPException(
status_code=400, detail="邀请码或学号无效,或该学号已注册"
)
user = await register_user(
db=db,
# 3. Create user with approved status directly
user = User(
email=req.email,
password_hash=hash_password(req.password),
name=req.name,
class_id=req.class_id,
student_id=req.student_id,
role="student",
status="approved",
class_id=roster_entry.class_id,
)
return {"message": "Registration submitted. Awaiting admin approval."}
db.add(user)
await db.flush()
# 4. Mark roster entry as registered
roster_entry.status = "registered"
roster_entry.user_id = user.id
await db.commit()
# 5. Issue token — register and login in one step
token = create_access_token({"sub": str(user.id), "role": user.role})
return {
"message": "注册成功",
"token": token,
"user": UserOut.model_validate(user),
}
@router.post("/login", response_model=TokenResponse)
@ -41,13 +59,14 @@ async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(User).where(User.email == req.email))
user = result.scalar_one_or_none()
if user is None or user.status != "approved":
raise HTTPException(
status_code=401, detail="Invalid credentials or account not approved"
)
if user is None:
raise HTTPException(status_code=401, detail="邮箱或密码错误")
if user.status == "disabled":
raise HTTPException(status_code=401, detail="账号已被禁用")
if not verify_password(req.password, user.password_hash):
raise HTTPException(status_code=401, detail="Invalid credentials")
raise HTTPException(status_code=401, detail="邮箱或密码错误")
token = create_access_token({"sub": str(user.id), "role": user.role})
return TokenResponse(

View File

@ -1,11 +1,15 @@
from fastapi import APIRouter, Depends, HTTPException, status
import csv
import io
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import get_current_user, require_role
from app.core.deps import require_role
from app.db.database import get_db
from app.db.models import User
from app.schemas.class_ import ClassCreate, ClassUpdate, ClassOut
from app.schemas.user import UserListItem
from app.schemas.roster import RosterOut, RosterImportRequest
from app.schemas.common import PageResponse
from app.services.class_service import (
create_class,
@ -16,6 +20,14 @@ from app.services.class_service import (
get_member_count,
get_class_members,
)
from app.services.roster_service import (
ensure_invite_code,
regenerate_invite_code,
import_roster,
get_roster,
delete_roster_entry,
clear_unregistered_roster,
)
router = APIRouter(prefix="/api/classes", tags=["classes"])
@ -103,8 +115,11 @@ async def get_members(
)
@router.get("/{class_id}/pending", response_model=PageResponse[UserListItem])
async def get_pending_members(
# --- Roster management ---
@router.get("/{class_id}/roster", response_model=PageResponse[RosterOut])
async def get_class_roster(
class_id: int,
page: int = 1,
page_size: int = 50,
@ -112,14 +127,144 @@ async def get_pending_members(
db: AsyncSession = Depends(get_db),
):
if admin.role == "class_admin" and admin.class_id != class_id:
raise HTTPException(status_code=403, detail="Access denied for this class")
members, total = await get_class_members(db, class_id, status="pending", page=page, page_size=page_size)
raise HTTPException(status_code=403, detail="Access denied")
entries, total = await get_roster(db, class_id, page, page_size)
total_pages = (total + page_size - 1) // page_size
return PageResponse(
items=[UserListItem.model_validate(m) for m in members],
items=[RosterOut.model_validate(e) for e in entries],
total=total,
page=page,
page_size=page_size,
total_pages=total_pages,
)
@router.post("/{class_id}/roster/import")
async def import_class_roster(
class_id: int,
data: RosterImportRequest,
admin: User = Depends(require_role("super_admin", "class_admin")),
db: AsyncSession = Depends(get_db),
):
if admin.role == "class_admin" and admin.class_id != class_id:
raise HTTPException(status_code=403, detail="Access denied")
count = await import_roster(db, class_id, data.entries)
return {"message": f"成功导入 {count} 条记录"}
@router.post("/{class_id}/roster/upload")
async def upload_roster_file(
class_id: int,
file: UploadFile = File(...),
admin: User = Depends(require_role("super_admin", "class_admin")),
db: AsyncSession = Depends(get_db),
):
if admin.role == "class_admin" and admin.class_id != class_id:
raise HTTPException(status_code=403, detail="Access denied")
contents = await file.read()
filename = file.filename or ""
entries: list[dict] = []
if filename.endswith(".csv"):
text = contents.decode("utf-8-sig")
reader = csv.DictReader(io.StringIO(text))
for row in reader:
sid = row.get("student_id") or row.get("学号") or ""
name = row.get("name") or row.get("姓名") or ""
if sid and name:
entries.append({"student_id": sid.strip(), "name": name.strip()})
elif filename.endswith((".xlsx", ".xls")):
try:
import openpyxl
wb = openpyxl.load_workbook(io.BytesIO(contents), read_only=True)
ws = wb.active
rows = list(ws.iter_rows(values_only=True))
if len(rows) < 2:
raise HTTPException(status_code=400, detail="Excel 文件为空")
header = [str(h).strip() if h else "" for h in rows[0]]
# Find student_id and name columns
sid_col = None
name_col = None
for i, h in enumerate(header):
if h in ("student_id", "学号"):
sid_col = i
elif h in ("name", "姓名"):
name_col = i
if sid_col is None or name_col is None:
raise HTTPException(
status_code=400,
detail="Excel 需包含 '学号'(student_id) 和 '姓名'(name) 列",
)
for row in rows[1:]:
sid = str(row[sid_col]).strip() if row[sid_col] else ""
name = str(row[name_col]).strip() if row[name_col] else ""
if sid and name and sid != "None":
entries.append({"student_id": sid, "name": name})
wb.close()
except ImportError:
raise HTTPException(
status_code=400, detail="服务器未安装 openpyxl请使用 CSV 格式"
)
else:
raise HTTPException(status_code=400, detail="仅支持 CSV 或 Excel (.xlsx) 文件")
if not entries:
raise HTTPException(status_code=400, detail="未找到有效数据")
count = await import_roster(db, class_id, entries)
return {"message": f"成功导入 {count} 条记录"}
@router.delete("/{class_id}/roster/{roster_id}")
async def delete_roster_item(
class_id: int,
roster_id: int,
admin: User = Depends(require_role("super_admin", "class_admin")),
db: AsyncSession = Depends(get_db),
):
success = await delete_roster_entry(db, roster_id)
if not success:
raise HTTPException(status_code=400, detail="无法删除(已注册或不存在)")
return {"message": "已删除"}
@router.post("/{class_id}/roster/clear")
async def clear_roster(
class_id: int,
admin: User = Depends(require_role("super_admin", "class_admin")),
db: AsyncSession = Depends(get_db),
):
if admin.role == "class_admin" and admin.class_id != class_id:
raise HTTPException(status_code=403, detail="Access denied")
count = await clear_unregistered_roster(db, class_id)
return {"message": f"已清除 {count} 条未注册记录"}
# --- Invite code management ---
@router.get("/{class_id}/invite-code")
async def get_invite_code(
class_id: int,
admin: User = Depends(require_role("super_admin", "class_admin")),
db: AsyncSession = Depends(get_db),
):
code = await ensure_invite_code(db, class_id)
if not code:
raise HTTPException(status_code=404, detail="Class not found")
return {"invite_code": code}
@router.post("/{class_id}/invite-code/regenerate")
async def regenerate_invite(
class_id: int,
admin: User = Depends(require_role("super_admin", "class_admin")),
db: AsyncSession = Depends(get_db),
):
code = await regenerate_invite_code(db, class_id)
if not code:
raise HTTPException(status_code=404, detail="Class not found")
return {"invite_code": code}

View File

@ -31,7 +31,10 @@ async def update_my_profile(
user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
updated = await update_profile(db, user, data)
try:
updated = await update_profile(db, user, data)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return UserOut.model_validate(updated)

View File

@ -35,9 +35,9 @@ async def get_current_user(
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found"
)
if user.status != "approved":
if user.status == "disabled":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Account not approved"
status_code=status.HTTP_403_FORBIDDEN, detail="Account disabled"
)
return user

View File

@ -14,6 +14,7 @@ class Class_(Base):
name: Mapped[str] = mapped_column(String(100), nullable=False)
cohort_year: Mapped[int] = mapped_column(Integer, nullable=False)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
invite_code: Mapped[str | None] = mapped_column(String(20), unique=True, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.now(), onupdate=func.now()
@ -32,6 +33,9 @@ class Class_(Base):
resources: Mapped[list["Resource"]] = relationship(
"Resource", back_populates="class_", cascade="all, delete-orphan"
)
roster: Mapped[list["StudentRoster"]] = relationship(
"StudentRoster", back_populates="class_", cascade="all, delete-orphan"
)
class User(Base):
@ -199,3 +203,24 @@ class Notification(Base):
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
user: Mapped["User"] = relationship("User")
class StudentRoster(Base):
__tablename__ = "student_rosters"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
class_id: Mapped[int] = mapped_column(
Integer, ForeignKey("classes.id"), nullable=False, index=True
)
student_id: Mapped[str] = mapped_column(String(50), nullable=False)
name: Mapped[str] = mapped_column(String(100), nullable=False)
status: Mapped[str] = mapped_column(
String(20), default="unregistered", nullable=False
)
user_id: Mapped[int | None] = mapped_column(
Integer, ForeignKey("users.id"), nullable=True
)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
class_: Mapped["Class_"] = relationship("Class_", back_populates="roster")
user: Mapped["User | None"] = relationship("User")

View File

@ -7,11 +7,11 @@ class LoginRequest(BaseModel):
class RegisterRequest(BaseModel):
invite_code: str
student_id: str
name: str
email: EmailStr
password: str
name: str
class_id: int
student_id: str
class ChangePasswordRequest(BaseModel):

View File

@ -20,6 +20,7 @@ class ClassOut(BaseModel):
name: str
cohort_year: int
description: str | None
invite_code: str | None = None
member_count: int = 0
created_at: datetime

View File

@ -0,0 +1,15 @@
from pydantic import BaseModel
class RosterOut(BaseModel):
id: int
student_id: str
name: str
status: str # "unregistered" | "registered"
user_id: int | None
model_config = {"from_attributes": True}
class RosterImportRequest(BaseModel):
entries: list[dict] # [{"student_id": "...", "name": "..."}, ...]

View File

@ -43,7 +43,6 @@ class UserPublic(BaseModel):
industry: str | None
company: str | None
position: str | None
skills_tags: list[str] | None
wechat_id: str | None
phone: str | None
avatar_url: str | None
@ -67,11 +66,11 @@ class UserListItem(BaseModel):
class UserUpdate(BaseModel):
email: EmailStr | None = None
name: str | None = None
industry: str | None = None
company: str | None = None
position: str | None = None
skills_tags: list[str] | None = None
wechat_id: str | None = None
phone: str | None = None
bio: str | None = None

View File

@ -70,7 +70,6 @@ def user_to_public(user: User, include_contact: bool = True) -> UserPublic:
industry=user.industry,
company=user.company,
position=user.position,
skills_tags=user.get_skills_list(),
wechat_id=user.wechat_id if include_contact else None,
phone=user.phone if include_contact else None,
avatar_url=user.avatar_url,

View File

@ -0,0 +1,125 @@
import secrets
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import StudentRoster, Class_
def generate_invite_code(length: int = 8) -> str:
chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
return "".join(secrets.choice(chars) for _ in range(length))
async def ensure_invite_code(db: AsyncSession, class_id: int) -> str:
result = await db.execute(select(Class_).where(Class_.id == class_id))
class_ = result.scalar_one_or_none()
if class_ is None:
return ""
if not class_.invite_code:
class_.invite_code = generate_invite_code()
await db.commit()
await db.refresh(class_)
return class_.invite_code
async def regenerate_invite_code(db: AsyncSession, class_id: int) -> str:
result = await db.execute(select(Class_).where(Class_.id == class_id))
class_ = result.scalar_one_or_none()
if class_ is None:
return ""
class_.invite_code = generate_invite_code()
await db.commit()
await db.refresh(class_)
return class_.invite_code
async def import_roster(
db: AsyncSession, class_id: int, entries: list[dict]
) -> int:
existing_ids: set[str] = set()
result = await db.execute(
select(StudentRoster.student_id).where(StudentRoster.class_id == class_id)
)
for row in result.all():
existing_ids.add(row[0])
count = 0
for entry in entries:
sid = entry.get("student_id", "").strip()
name = entry.get("name", "").strip()
if not sid or not name or sid in existing_ids:
continue
roster = StudentRoster(class_id=class_id, student_id=sid, name=name)
db.add(roster)
existing_ids.add(sid)
count += 1
await db.commit()
return count
async def get_roster(
db: AsyncSession, class_id: int, page: int = 1, page_size: int = 50
) -> tuple[list[StudentRoster], int]:
query = select(StudentRoster).where(StudentRoster.class_id == class_id)
count_query = select(func.count(StudentRoster.id)).where(
StudentRoster.class_id == class_id
)
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
result = await db.execute(
query.order_by(StudentRoster.student_id)
.offset((page - 1) * page_size)
.limit(page_size)
)
return list(result.scalars().all()), total
async def validate_registration(
db: AsyncSession, invite_code: str, student_id: str
) -> StudentRoster | None:
class_result = await db.execute(
select(Class_).where(Class_.invite_code == invite_code)
)
class_ = class_result.scalar_one_or_none()
if class_ is None:
return None
roster_result = await db.execute(
select(StudentRoster).where(
StudentRoster.class_id == class_.id,
StudentRoster.student_id == student_id,
StudentRoster.status == "unregistered",
)
)
return roster_result.scalar_one_or_none()
async def delete_roster_entry(db: AsyncSession, roster_id: int) -> bool:
result = await db.execute(
select(StudentRoster).where(StudentRoster.id == roster_id)
)
entry = result.scalar_one_or_none()
if entry is None:
return False
if entry.status == "registered":
return False
await db.delete(entry)
await db.commit()
return True
async def clear_unregistered_roster(db: AsyncSession, class_id: int) -> int:
result = await db.execute(
select(StudentRoster).where(
StudentRoster.class_id == class_id,
StudentRoster.status == "unregistered",
)
)
entries = list(result.scalars().all())
for entry in entries:
await db.delete(entry)
await db.commit()
return len(entries)

View File

@ -1,9 +1,8 @@
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import User, Class_
from app.db.models import User
from app.schemas.user import UserOut, UserUpdate
from app.services.email_service import send_registration_notification
async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
@ -16,52 +15,18 @@ async def get_user_by_id(db: AsyncSession, user_id: int) -> User | None:
return result.scalar_one_or_none()
async def register_user(
db: AsyncSession,
email: str,
password_hash: str,
name: str,
class_id: int,
student_id: str | None = None,
) -> User:
user = User(
email=email,
password_hash=password_hash,
name=name,
student_id=student_id,
role="student",
status="pending",
class_id=class_id,
)
db.add(user)
await db.commit()
await db.refresh(user)
# Notify class admins
admins_result = await db.execute(
select(User).where(
User.class_id == class_id,
User.role.in_(["class_admin", "super_admin"]),
User.status == "approved",
)
)
class_result = await db.execute(select(Class_).where(Class_.id == class_id))
class_ = class_result.scalar_one_or_none()
class_name = class_.name if class_ else "Unknown"
for admin in admins_result.scalars():
await send_registration_notification(admin.email, name, class_name)
return user
async def update_profile(db: AsyncSession, user: User, data: UserUpdate) -> User:
update_data = data.model_dump(exclude_unset=True)
if "skills_tags" in update_data and update_data["skills_tags"] is not None:
import json
user.skills_tags = json.dumps(
update_data.pop("skills_tags"), ensure_ascii=False
# Handle email change with uniqueness check
if "email" in update_data and update_data["email"] != user.email:
existing = await db.execute(
select(User).where(User.email == update_data["email"])
)
if existing.scalar_one_or_none():
raise ValueError("该邮箱已被使用")
user.email = update_data.pop("email")
for field, value in update_data.items():
setattr(user, field, value)
await db.commit()