This commit is contained in:
aaron 2026-04-11 23:08:50 +08:00
parent 85f6b7e42b
commit 31e7a598dc
12 changed files with 376 additions and 80 deletions

View File

@ -1,39 +1,57 @@
import json from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.auth import hash_password, verify_password, create_access_token from app.core.auth import hash_password, verify_password, create_access_token
from app.core.deps import get_current_user from app.core.deps import get_current_user
from app.db.database import get_db from app.db.database import get_db
from app.db.models import User, Class_ from app.db.models import User
from app.schemas.auth import LoginRequest, RegisterRequest, ChangePasswordRequest from app.schemas.auth import LoginRequest, RegisterRequest, ChangePasswordRequest
from app.schemas.user import TokenResponse, UserOut from app.schemas.user import TokenResponse, UserOut
from app.services.user_service import register_user from app.services.roster_service import validate_registration
router = APIRouter(prefix="/api/auth", tags=["auth"]) router = APIRouter(prefix="/api/auth", tags=["auth"])
@router.post("/register") @router.post("/register")
async def register(req: RegisterRequest, db: AsyncSession = Depends(get_db)): async def register(req: RegisterRequest, db: AsyncSession = Depends(get_db)):
# 1. Check if email is already registered
existing = await db.execute(select(User).where(User.email == req.email)) existing = await db.execute(select(User).where(User.email == req.email))
if existing.scalar_one_or_none(): if existing.scalar_one_or_none():
raise HTTPException(status_code=400, detail="Email already registered") raise HTTPException(status_code=400, detail="该邮箱已注册")
class_result = await db.execute(select(Class_).where(Class_.id == req.class_id)) # 2. Validate invite_code + student_id against roster
if class_result.scalar_one_or_none() is None: roster_entry = await validate_registration(db, req.invite_code, req.student_id)
raise HTTPException(status_code=400, detail="Class not found") if roster_entry is None:
raise HTTPException(
status_code=400, detail="邀请码或学号无效,或该学号已注册"
)
user = await register_user( # 3. Create user with approved status directly
db=db, user = User(
email=req.email, email=req.email,
password_hash=hash_password(req.password), password_hash=hash_password(req.password),
name=req.name, name=req.name,
class_id=req.class_id,
student_id=req.student_id, student_id=req.student_id,
role="student",
status="approved",
class_id=roster_entry.class_id,
) )
return {"message": "Registration submitted. Awaiting admin approval."} db.add(user)
await db.flush()
# 4. Mark roster entry as registered
roster_entry.status = "registered"
roster_entry.user_id = user.id
await db.commit()
# 5. Issue token — register and login in one step
token = create_access_token({"sub": str(user.id), "role": user.role})
return {
"message": "注册成功",
"token": token,
"user": UserOut.model_validate(user),
}
@router.post("/login", response_model=TokenResponse) @router.post("/login", response_model=TokenResponse)
@ -41,13 +59,14 @@ async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)):
result = await db.execute(select(User).where(User.email == req.email)) result = await db.execute(select(User).where(User.email == req.email))
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
if user is None or user.status != "approved": if user is None:
raise HTTPException( raise HTTPException(status_code=401, detail="邮箱或密码错误")
status_code=401, detail="Invalid credentials or account not approved"
) if user.status == "disabled":
raise HTTPException(status_code=401, detail="账号已被禁用")
if not verify_password(req.password, user.password_hash): if not verify_password(req.password, user.password_hash):
raise HTTPException(status_code=401, detail="Invalid credentials") raise HTTPException(status_code=401, detail="邮箱或密码错误")
token = create_access_token({"sub": str(user.id), "role": user.role}) token = create_access_token({"sub": str(user.id), "role": user.role})
return TokenResponse( return TokenResponse(

View File

@ -1,11 +1,15 @@
from fastapi import APIRouter, Depends, HTTPException, status import csv
import io
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.core.deps import get_current_user, require_role from app.core.deps import require_role
from app.db.database import get_db from app.db.database import get_db
from app.db.models import User from app.db.models import User
from app.schemas.class_ import ClassCreate, ClassUpdate, ClassOut from app.schemas.class_ import ClassCreate, ClassUpdate, ClassOut
from app.schemas.user import UserListItem from app.schemas.user import UserListItem
from app.schemas.roster import RosterOut, RosterImportRequest
from app.schemas.common import PageResponse from app.schemas.common import PageResponse
from app.services.class_service import ( from app.services.class_service import (
create_class, create_class,
@ -16,6 +20,14 @@ from app.services.class_service import (
get_member_count, get_member_count,
get_class_members, get_class_members,
) )
from app.services.roster_service import (
ensure_invite_code,
regenerate_invite_code,
import_roster,
get_roster,
delete_roster_entry,
clear_unregistered_roster,
)
router = APIRouter(prefix="/api/classes", tags=["classes"]) router = APIRouter(prefix="/api/classes", tags=["classes"])
@ -103,8 +115,11 @@ async def get_members(
) )
@router.get("/{class_id}/pending", response_model=PageResponse[UserListItem]) # --- Roster management ---
async def get_pending_members(
@router.get("/{class_id}/roster", response_model=PageResponse[RosterOut])
async def get_class_roster(
class_id: int, class_id: int,
page: int = 1, page: int = 1,
page_size: int = 50, page_size: int = 50,
@ -112,14 +127,144 @@ async def get_pending_members(
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
if admin.role == "class_admin" and admin.class_id != class_id: if admin.role == "class_admin" and admin.class_id != class_id:
raise HTTPException(status_code=403, detail="Access denied for this class") raise HTTPException(status_code=403, detail="Access denied")
entries, total = await get_roster(db, class_id, page, page_size)
members, total = await get_class_members(db, class_id, status="pending", page=page, page_size=page_size)
total_pages = (total + page_size - 1) // page_size total_pages = (total + page_size - 1) // page_size
return PageResponse( return PageResponse(
items=[UserListItem.model_validate(m) for m in members], items=[RosterOut.model_validate(e) for e in entries],
total=total, total=total,
page=page, page=page,
page_size=page_size, page_size=page_size,
total_pages=total_pages, total_pages=total_pages,
) )
@router.post("/{class_id}/roster/import")
async def import_class_roster(
class_id: int,
data: RosterImportRequest,
admin: User = Depends(require_role("super_admin", "class_admin")),
db: AsyncSession = Depends(get_db),
):
if admin.role == "class_admin" and admin.class_id != class_id:
raise HTTPException(status_code=403, detail="Access denied")
count = await import_roster(db, class_id, data.entries)
return {"message": f"成功导入 {count} 条记录"}
@router.post("/{class_id}/roster/upload")
async def upload_roster_file(
class_id: int,
file: UploadFile = File(...),
admin: User = Depends(require_role("super_admin", "class_admin")),
db: AsyncSession = Depends(get_db),
):
if admin.role == "class_admin" and admin.class_id != class_id:
raise HTTPException(status_code=403, detail="Access denied")
contents = await file.read()
filename = file.filename or ""
entries: list[dict] = []
if filename.endswith(".csv"):
text = contents.decode("utf-8-sig")
reader = csv.DictReader(io.StringIO(text))
for row in reader:
sid = row.get("student_id") or row.get("学号") or ""
name = row.get("name") or row.get("姓名") or ""
if sid and name:
entries.append({"student_id": sid.strip(), "name": name.strip()})
elif filename.endswith((".xlsx", ".xls")):
try:
import openpyxl
wb = openpyxl.load_workbook(io.BytesIO(contents), read_only=True)
ws = wb.active
rows = list(ws.iter_rows(values_only=True))
if len(rows) < 2:
raise HTTPException(status_code=400, detail="Excel 文件为空")
header = [str(h).strip() if h else "" for h in rows[0]]
# Find student_id and name columns
sid_col = None
name_col = None
for i, h in enumerate(header):
if h in ("student_id", "学号"):
sid_col = i
elif h in ("name", "姓名"):
name_col = i
if sid_col is None or name_col is None:
raise HTTPException(
status_code=400,
detail="Excel 需包含 '学号'(student_id) 和 '姓名'(name) 列",
)
for row in rows[1:]:
sid = str(row[sid_col]).strip() if row[sid_col] else ""
name = str(row[name_col]).strip() if row[name_col] else ""
if sid and name and sid != "None":
entries.append({"student_id": sid, "name": name})
wb.close()
except ImportError:
raise HTTPException(
status_code=400, detail="服务器未安装 openpyxl请使用 CSV 格式"
)
else:
raise HTTPException(status_code=400, detail="仅支持 CSV 或 Excel (.xlsx) 文件")
if not entries:
raise HTTPException(status_code=400, detail="未找到有效数据")
count = await import_roster(db, class_id, entries)
return {"message": f"成功导入 {count} 条记录"}
@router.delete("/{class_id}/roster/{roster_id}")
async def delete_roster_item(
class_id: int,
roster_id: int,
admin: User = Depends(require_role("super_admin", "class_admin")),
db: AsyncSession = Depends(get_db),
):
success = await delete_roster_entry(db, roster_id)
if not success:
raise HTTPException(status_code=400, detail="无法删除(已注册或不存在)")
return {"message": "已删除"}
@router.post("/{class_id}/roster/clear")
async def clear_roster(
class_id: int,
admin: User = Depends(require_role("super_admin", "class_admin")),
db: AsyncSession = Depends(get_db),
):
if admin.role == "class_admin" and admin.class_id != class_id:
raise HTTPException(status_code=403, detail="Access denied")
count = await clear_unregistered_roster(db, class_id)
return {"message": f"已清除 {count} 条未注册记录"}
# --- Invite code management ---
@router.get("/{class_id}/invite-code")
async def get_invite_code(
class_id: int,
admin: User = Depends(require_role("super_admin", "class_admin")),
db: AsyncSession = Depends(get_db),
):
code = await ensure_invite_code(db, class_id)
if not code:
raise HTTPException(status_code=404, detail="Class not found")
return {"invite_code": code}
@router.post("/{class_id}/invite-code/regenerate")
async def regenerate_invite(
class_id: int,
admin: User = Depends(require_role("super_admin", "class_admin")),
db: AsyncSession = Depends(get_db),
):
code = await regenerate_invite_code(db, class_id)
if not code:
raise HTTPException(status_code=404, detail="Class not found")
return {"invite_code": code}

View File

@ -31,7 +31,10 @@ async def update_my_profile(
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),
): ):
updated = await update_profile(db, user, data) try:
updated = await update_profile(db, user, data)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return UserOut.model_validate(updated) return UserOut.model_validate(updated)

View File

@ -35,9 +35,9 @@ async def get_current_user(
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found" status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found"
) )
if user.status != "approved": if user.status == "disabled":
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Account not approved" status_code=status.HTTP_403_FORBIDDEN, detail="Account disabled"
) )
return user return user

View File

@ -14,6 +14,7 @@ class Class_(Base):
name: Mapped[str] = mapped_column(String(100), nullable=False) name: Mapped[str] = mapped_column(String(100), nullable=False)
cohort_year: Mapped[int] = mapped_column(Integer, nullable=False) cohort_year: Mapped[int] = mapped_column(Integer, nullable=False)
description: Mapped[str | None] = mapped_column(Text, nullable=True) description: Mapped[str | None] = mapped_column(Text, nullable=True)
invite_code: Mapped[str | None] = mapped_column(String(20), unique=True, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now()) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
DateTime, server_default=func.now(), onupdate=func.now() DateTime, server_default=func.now(), onupdate=func.now()
@ -32,6 +33,9 @@ class Class_(Base):
resources: Mapped[list["Resource"]] = relationship( resources: Mapped[list["Resource"]] = relationship(
"Resource", back_populates="class_", cascade="all, delete-orphan" "Resource", back_populates="class_", cascade="all, delete-orphan"
) )
roster: Mapped[list["StudentRoster"]] = relationship(
"StudentRoster", back_populates="class_", cascade="all, delete-orphan"
)
class User(Base): class User(Base):
@ -199,3 +203,24 @@ class Notification(Base):
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now()) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
user: Mapped["User"] = relationship("User") user: Mapped["User"] = relationship("User")
class StudentRoster(Base):
__tablename__ = "student_rosters"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
class_id: Mapped[int] = mapped_column(
Integer, ForeignKey("classes.id"), nullable=False, index=True
)
student_id: Mapped[str] = mapped_column(String(50), nullable=False)
name: Mapped[str] = mapped_column(String(100), nullable=False)
status: Mapped[str] = mapped_column(
String(20), default="unregistered", nullable=False
)
user_id: Mapped[int | None] = mapped_column(
Integer, ForeignKey("users.id"), nullable=True
)
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
class_: Mapped["Class_"] = relationship("Class_", back_populates="roster")
user: Mapped["User | None"] = relationship("User")

View File

@ -7,11 +7,11 @@ class LoginRequest(BaseModel):
class RegisterRequest(BaseModel): class RegisterRequest(BaseModel):
invite_code: str
student_id: str
name: str
email: EmailStr email: EmailStr
password: str password: str
name: str
class_id: int
student_id: str
class ChangePasswordRequest(BaseModel): class ChangePasswordRequest(BaseModel):

View File

@ -20,6 +20,7 @@ class ClassOut(BaseModel):
name: str name: str
cohort_year: int cohort_year: int
description: str | None description: str | None
invite_code: str | None = None
member_count: int = 0 member_count: int = 0
created_at: datetime created_at: datetime

View File

@ -0,0 +1,15 @@
from pydantic import BaseModel
class RosterOut(BaseModel):
id: int
student_id: str
name: str
status: str # "unregistered" | "registered"
user_id: int | None
model_config = {"from_attributes": True}
class RosterImportRequest(BaseModel):
entries: list[dict] # [{"student_id": "...", "name": "..."}, ...]

View File

@ -43,7 +43,6 @@ class UserPublic(BaseModel):
industry: str | None industry: str | None
company: str | None company: str | None
position: str | None position: str | None
skills_tags: list[str] | None
wechat_id: str | None wechat_id: str | None
phone: str | None phone: str | None
avatar_url: str | None avatar_url: str | None
@ -67,11 +66,11 @@ class UserListItem(BaseModel):
class UserUpdate(BaseModel): class UserUpdate(BaseModel):
email: EmailStr | None = None
name: str | None = None name: str | None = None
industry: str | None = None industry: str | None = None
company: str | None = None company: str | None = None
position: str | None = None position: str | None = None
skills_tags: list[str] | None = None
wechat_id: str | None = None wechat_id: str | None = None
phone: str | None = None phone: str | None = None
bio: str | None = None bio: str | None = None

View File

@ -70,7 +70,6 @@ def user_to_public(user: User, include_contact: bool = True) -> UserPublic:
industry=user.industry, industry=user.industry,
company=user.company, company=user.company,
position=user.position, position=user.position,
skills_tags=user.get_skills_list(),
wechat_id=user.wechat_id if include_contact else None, wechat_id=user.wechat_id if include_contact else None,
phone=user.phone if include_contact else None, phone=user.phone if include_contact else None,
avatar_url=user.avatar_url, avatar_url=user.avatar_url,

View File

@ -0,0 +1,125 @@
import secrets
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import StudentRoster, Class_
def generate_invite_code(length: int = 8) -> str:
chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
return "".join(secrets.choice(chars) for _ in range(length))
async def ensure_invite_code(db: AsyncSession, class_id: int) -> str:
result = await db.execute(select(Class_).where(Class_.id == class_id))
class_ = result.scalar_one_or_none()
if class_ is None:
return ""
if not class_.invite_code:
class_.invite_code = generate_invite_code()
await db.commit()
await db.refresh(class_)
return class_.invite_code
async def regenerate_invite_code(db: AsyncSession, class_id: int) -> str:
result = await db.execute(select(Class_).where(Class_.id == class_id))
class_ = result.scalar_one_or_none()
if class_ is None:
return ""
class_.invite_code = generate_invite_code()
await db.commit()
await db.refresh(class_)
return class_.invite_code
async def import_roster(
db: AsyncSession, class_id: int, entries: list[dict]
) -> int:
existing_ids: set[str] = set()
result = await db.execute(
select(StudentRoster.student_id).where(StudentRoster.class_id == class_id)
)
for row in result.all():
existing_ids.add(row[0])
count = 0
for entry in entries:
sid = entry.get("student_id", "").strip()
name = entry.get("name", "").strip()
if not sid or not name or sid in existing_ids:
continue
roster = StudentRoster(class_id=class_id, student_id=sid, name=name)
db.add(roster)
existing_ids.add(sid)
count += 1
await db.commit()
return count
async def get_roster(
db: AsyncSession, class_id: int, page: int = 1, page_size: int = 50
) -> tuple[list[StudentRoster], int]:
query = select(StudentRoster).where(StudentRoster.class_id == class_id)
count_query = select(func.count(StudentRoster.id)).where(
StudentRoster.class_id == class_id
)
total_result = await db.execute(count_query)
total = total_result.scalar() or 0
result = await db.execute(
query.order_by(StudentRoster.student_id)
.offset((page - 1) * page_size)
.limit(page_size)
)
return list(result.scalars().all()), total
async def validate_registration(
db: AsyncSession, invite_code: str, student_id: str
) -> StudentRoster | None:
class_result = await db.execute(
select(Class_).where(Class_.invite_code == invite_code)
)
class_ = class_result.scalar_one_or_none()
if class_ is None:
return None
roster_result = await db.execute(
select(StudentRoster).where(
StudentRoster.class_id == class_.id,
StudentRoster.student_id == student_id,
StudentRoster.status == "unregistered",
)
)
return roster_result.scalar_one_or_none()
async def delete_roster_entry(db: AsyncSession, roster_id: int) -> bool:
result = await db.execute(
select(StudentRoster).where(StudentRoster.id == roster_id)
)
entry = result.scalar_one_or_none()
if entry is None:
return False
if entry.status == "registered":
return False
await db.delete(entry)
await db.commit()
return True
async def clear_unregistered_roster(db: AsyncSession, class_id: int) -> int:
result = await db.execute(
select(StudentRoster).where(
StudentRoster.class_id == class_id,
StudentRoster.status == "unregistered",
)
)
entries = list(result.scalars().all())
for entry in entries:
await db.delete(entry)
await db.commit()
return len(entries)

View File

@ -1,9 +1,8 @@
from sqlalchemy import select, func from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.db.models import User, Class_ from app.db.models import User
from app.schemas.user import UserOut, UserUpdate from app.schemas.user import UserOut, UserUpdate
from app.services.email_service import send_registration_notification
async def get_user_by_email(db: AsyncSession, email: str) -> User | None: async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
@ -16,52 +15,18 @@ async def get_user_by_id(db: AsyncSession, user_id: int) -> User | None:
return result.scalar_one_or_none() return result.scalar_one_or_none()
async def register_user(
db: AsyncSession,
email: str,
password_hash: str,
name: str,
class_id: int,
student_id: str | None = None,
) -> User:
user = User(
email=email,
password_hash=password_hash,
name=name,
student_id=student_id,
role="student",
status="pending",
class_id=class_id,
)
db.add(user)
await db.commit()
await db.refresh(user)
# Notify class admins
admins_result = await db.execute(
select(User).where(
User.class_id == class_id,
User.role.in_(["class_admin", "super_admin"]),
User.status == "approved",
)
)
class_result = await db.execute(select(Class_).where(Class_.id == class_id))
class_ = class_result.scalar_one_or_none()
class_name = class_.name if class_ else "Unknown"
for admin in admins_result.scalars():
await send_registration_notification(admin.email, name, class_name)
return user
async def update_profile(db: AsyncSession, user: User, data: UserUpdate) -> User: async def update_profile(db: AsyncSession, user: User, data: UserUpdate) -> User:
update_data = data.model_dump(exclude_unset=True) update_data = data.model_dump(exclude_unset=True)
if "skills_tags" in update_data and update_data["skills_tags"] is not None:
import json # Handle email change with uniqueness check
user.skills_tags = json.dumps( if "email" in update_data and update_data["email"] != user.email:
update_data.pop("skills_tags"), ensure_ascii=False existing = await db.execute(
select(User).where(User.email == update_data["email"])
) )
if existing.scalar_one_or_none():
raise ValueError("该邮箱已被使用")
user.email = update_data.pop("email")
for field, value in update_data.items(): for field, value in update_data.items():
setattr(user, field, value) setattr(user, field, value)
await db.commit() await db.commit()