1
This commit is contained in:
parent
85f6b7e42b
commit
31e7a598dc
@ -1,39 +1,57 @@
|
|||||||
import json
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.auth import hash_password, verify_password, create_access_token
|
from app.core.auth import hash_password, verify_password, create_access_token
|
||||||
from app.core.deps import get_current_user
|
from app.core.deps import get_current_user
|
||||||
from app.db.database import get_db
|
from app.db.database import get_db
|
||||||
from app.db.models import User, Class_
|
from app.db.models import User
|
||||||
from app.schemas.auth import LoginRequest, RegisterRequest, ChangePasswordRequest
|
from app.schemas.auth import LoginRequest, RegisterRequest, ChangePasswordRequest
|
||||||
from app.schemas.user import TokenResponse, UserOut
|
from app.schemas.user import TokenResponse, UserOut
|
||||||
from app.services.user_service import register_user
|
from app.services.roster_service import validate_registration
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
@router.post("/register")
|
@router.post("/register")
|
||||||
async def register(req: RegisterRequest, db: AsyncSession = Depends(get_db)):
|
async def register(req: RegisterRequest, db: AsyncSession = Depends(get_db)):
|
||||||
|
# 1. Check if email is already registered
|
||||||
existing = await db.execute(select(User).where(User.email == req.email))
|
existing = await db.execute(select(User).where(User.email == req.email))
|
||||||
if existing.scalar_one_or_none():
|
if existing.scalar_one_or_none():
|
||||||
raise HTTPException(status_code=400, detail="Email already registered")
|
raise HTTPException(status_code=400, detail="该邮箱已注册")
|
||||||
|
|
||||||
class_result = await db.execute(select(Class_).where(Class_.id == req.class_id))
|
# 2. Validate invite_code + student_id against roster
|
||||||
if class_result.scalar_one_or_none() is None:
|
roster_entry = await validate_registration(db, req.invite_code, req.student_id)
|
||||||
raise HTTPException(status_code=400, detail="Class not found")
|
if roster_entry is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="邀请码或学号无效,或该学号已注册"
|
||||||
|
)
|
||||||
|
|
||||||
user = await register_user(
|
# 3. Create user with approved status directly
|
||||||
db=db,
|
user = User(
|
||||||
email=req.email,
|
email=req.email,
|
||||||
password_hash=hash_password(req.password),
|
password_hash=hash_password(req.password),
|
||||||
name=req.name,
|
name=req.name,
|
||||||
class_id=req.class_id,
|
|
||||||
student_id=req.student_id,
|
student_id=req.student_id,
|
||||||
|
role="student",
|
||||||
|
status="approved",
|
||||||
|
class_id=roster_entry.class_id,
|
||||||
)
|
)
|
||||||
return {"message": "Registration submitted. Awaiting admin approval."}
|
db.add(user)
|
||||||
|
await db.flush()
|
||||||
|
|
||||||
|
# 4. Mark roster entry as registered
|
||||||
|
roster_entry.status = "registered"
|
||||||
|
roster_entry.user_id = user.id
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# 5. Issue token — register and login in one step
|
||||||
|
token = create_access_token({"sub": str(user.id), "role": user.role})
|
||||||
|
return {
|
||||||
|
"message": "注册成功",
|
||||||
|
"token": token,
|
||||||
|
"user": UserOut.model_validate(user),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/login", response_model=TokenResponse)
|
@router.post("/login", response_model=TokenResponse)
|
||||||
@ -41,13 +59,14 @@ async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)):
|
|||||||
result = await db.execute(select(User).where(User.email == req.email))
|
result = await db.execute(select(User).where(User.email == req.email))
|
||||||
user = result.scalar_one_or_none()
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
if user is None or user.status != "approved":
|
if user is None:
|
||||||
raise HTTPException(
|
raise HTTPException(status_code=401, detail="邮箱或密码错误")
|
||||||
status_code=401, detail="Invalid credentials or account not approved"
|
|
||||||
)
|
if user.status == "disabled":
|
||||||
|
raise HTTPException(status_code=401, detail="账号已被禁用")
|
||||||
|
|
||||||
if not verify_password(req.password, user.password_hash):
|
if not verify_password(req.password, user.password_hash):
|
||||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
raise HTTPException(status_code=401, detail="邮箱或密码错误")
|
||||||
|
|
||||||
token = create_access_token({"sub": str(user.id), "role": user.role})
|
token = create_access_token({"sub": str(user.id), "role": user.role})
|
||||||
return TokenResponse(
|
return TokenResponse(
|
||||||
|
|||||||
@ -1,11 +1,15 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
import csv
|
||||||
|
import io
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.core.deps import get_current_user, require_role
|
from app.core.deps import require_role
|
||||||
from app.db.database import get_db
|
from app.db.database import get_db
|
||||||
from app.db.models import User
|
from app.db.models import User
|
||||||
from app.schemas.class_ import ClassCreate, ClassUpdate, ClassOut
|
from app.schemas.class_ import ClassCreate, ClassUpdate, ClassOut
|
||||||
from app.schemas.user import UserListItem
|
from app.schemas.user import UserListItem
|
||||||
|
from app.schemas.roster import RosterOut, RosterImportRequest
|
||||||
from app.schemas.common import PageResponse
|
from app.schemas.common import PageResponse
|
||||||
from app.services.class_service import (
|
from app.services.class_service import (
|
||||||
create_class,
|
create_class,
|
||||||
@ -16,6 +20,14 @@ from app.services.class_service import (
|
|||||||
get_member_count,
|
get_member_count,
|
||||||
get_class_members,
|
get_class_members,
|
||||||
)
|
)
|
||||||
|
from app.services.roster_service import (
|
||||||
|
ensure_invite_code,
|
||||||
|
regenerate_invite_code,
|
||||||
|
import_roster,
|
||||||
|
get_roster,
|
||||||
|
delete_roster_entry,
|
||||||
|
clear_unregistered_roster,
|
||||||
|
)
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/classes", tags=["classes"])
|
router = APIRouter(prefix="/api/classes", tags=["classes"])
|
||||||
|
|
||||||
@ -103,8 +115,11 @@ async def get_members(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{class_id}/pending", response_model=PageResponse[UserListItem])
|
# --- Roster management ---
|
||||||
async def get_pending_members(
|
|
||||||
|
|
||||||
|
@router.get("/{class_id}/roster", response_model=PageResponse[RosterOut])
|
||||||
|
async def get_class_roster(
|
||||||
class_id: int,
|
class_id: int,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
page_size: int = 50,
|
page_size: int = 50,
|
||||||
@ -112,14 +127,144 @@ async def get_pending_members(
|
|||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
if admin.role == "class_admin" and admin.class_id != class_id:
|
if admin.role == "class_admin" and admin.class_id != class_id:
|
||||||
raise HTTPException(status_code=403, detail="Access denied for this class")
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
entries, total = await get_roster(db, class_id, page, page_size)
|
||||||
members, total = await get_class_members(db, class_id, status="pending", page=page, page_size=page_size)
|
|
||||||
total_pages = (total + page_size - 1) // page_size
|
total_pages = (total + page_size - 1) // page_size
|
||||||
return PageResponse(
|
return PageResponse(
|
||||||
items=[UserListItem.model_validate(m) for m in members],
|
items=[RosterOut.model_validate(e) for e in entries],
|
||||||
total=total,
|
total=total,
|
||||||
page=page,
|
page=page,
|
||||||
page_size=page_size,
|
page_size=page_size,
|
||||||
total_pages=total_pages,
|
total_pages=total_pages,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{class_id}/roster/import")
|
||||||
|
async def import_class_roster(
|
||||||
|
class_id: int,
|
||||||
|
data: RosterImportRequest,
|
||||||
|
admin: User = Depends(require_role("super_admin", "class_admin")),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
if admin.role == "class_admin" and admin.class_id != class_id:
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
count = await import_roster(db, class_id, data.entries)
|
||||||
|
return {"message": f"成功导入 {count} 条记录"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{class_id}/roster/upload")
|
||||||
|
async def upload_roster_file(
|
||||||
|
class_id: int,
|
||||||
|
file: UploadFile = File(...),
|
||||||
|
admin: User = Depends(require_role("super_admin", "class_admin")),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
if admin.role == "class_admin" and admin.class_id != class_id:
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
|
||||||
|
contents = await file.read()
|
||||||
|
filename = file.filename or ""
|
||||||
|
|
||||||
|
entries: list[dict] = []
|
||||||
|
|
||||||
|
if filename.endswith(".csv"):
|
||||||
|
text = contents.decode("utf-8-sig")
|
||||||
|
reader = csv.DictReader(io.StringIO(text))
|
||||||
|
for row in reader:
|
||||||
|
sid = row.get("student_id") or row.get("学号") or ""
|
||||||
|
name = row.get("name") or row.get("姓名") or ""
|
||||||
|
if sid and name:
|
||||||
|
entries.append({"student_id": sid.strip(), "name": name.strip()})
|
||||||
|
elif filename.endswith((".xlsx", ".xls")):
|
||||||
|
try:
|
||||||
|
import openpyxl
|
||||||
|
|
||||||
|
wb = openpyxl.load_workbook(io.BytesIO(contents), read_only=True)
|
||||||
|
ws = wb.active
|
||||||
|
rows = list(ws.iter_rows(values_only=True))
|
||||||
|
if len(rows) < 2:
|
||||||
|
raise HTTPException(status_code=400, detail="Excel 文件为空")
|
||||||
|
header = [str(h).strip() if h else "" for h in rows[0]]
|
||||||
|
# Find student_id and name columns
|
||||||
|
sid_col = None
|
||||||
|
name_col = None
|
||||||
|
for i, h in enumerate(header):
|
||||||
|
if h in ("student_id", "学号"):
|
||||||
|
sid_col = i
|
||||||
|
elif h in ("name", "姓名"):
|
||||||
|
name_col = i
|
||||||
|
if sid_col is None or name_col is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Excel 需包含 '学号'(student_id) 和 '姓名'(name) 列",
|
||||||
|
)
|
||||||
|
for row in rows[1:]:
|
||||||
|
sid = str(row[sid_col]).strip() if row[sid_col] else ""
|
||||||
|
name = str(row[name_col]).strip() if row[name_col] else ""
|
||||||
|
if sid and name and sid != "None":
|
||||||
|
entries.append({"student_id": sid, "name": name})
|
||||||
|
wb.close()
|
||||||
|
except ImportError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="服务器未安装 openpyxl,请使用 CSV 格式"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=400, detail="仅支持 CSV 或 Excel (.xlsx) 文件")
|
||||||
|
|
||||||
|
if not entries:
|
||||||
|
raise HTTPException(status_code=400, detail="未找到有效数据")
|
||||||
|
|
||||||
|
count = await import_roster(db, class_id, entries)
|
||||||
|
return {"message": f"成功导入 {count} 条记录"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{class_id}/roster/{roster_id}")
|
||||||
|
async def delete_roster_item(
|
||||||
|
class_id: int,
|
||||||
|
roster_id: int,
|
||||||
|
admin: User = Depends(require_role("super_admin", "class_admin")),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
success = await delete_roster_entry(db, roster_id)
|
||||||
|
if not success:
|
||||||
|
raise HTTPException(status_code=400, detail="无法删除(已注册或不存在)")
|
||||||
|
return {"message": "已删除"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{class_id}/roster/clear")
|
||||||
|
async def clear_roster(
|
||||||
|
class_id: int,
|
||||||
|
admin: User = Depends(require_role("super_admin", "class_admin")),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
if admin.role == "class_admin" and admin.class_id != class_id:
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
count = await clear_unregistered_roster(db, class_id)
|
||||||
|
return {"message": f"已清除 {count} 条未注册记录"}
|
||||||
|
|
||||||
|
|
||||||
|
# --- Invite code management ---
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{class_id}/invite-code")
|
||||||
|
async def get_invite_code(
|
||||||
|
class_id: int,
|
||||||
|
admin: User = Depends(require_role("super_admin", "class_admin")),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
code = await ensure_invite_code(db, class_id)
|
||||||
|
if not code:
|
||||||
|
raise HTTPException(status_code=404, detail="Class not found")
|
||||||
|
return {"invite_code": code}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{class_id}/invite-code/regenerate")
|
||||||
|
async def regenerate_invite(
|
||||||
|
class_id: int,
|
||||||
|
admin: User = Depends(require_role("super_admin", "class_admin")),
|
||||||
|
db: AsyncSession = Depends(get_db),
|
||||||
|
):
|
||||||
|
code = await regenerate_invite_code(db, class_id)
|
||||||
|
if not code:
|
||||||
|
raise HTTPException(status_code=404, detail="Class not found")
|
||||||
|
return {"invite_code": code}
|
||||||
|
|||||||
@ -31,7 +31,10 @@ async def update_my_profile(
|
|||||||
user: User = Depends(get_current_user),
|
user: User = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
):
|
):
|
||||||
updated = await update_profile(db, user, data)
|
try:
|
||||||
|
updated = await update_profile(db, user, data)
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
return UserOut.model_validate(updated)
|
return UserOut.model_validate(updated)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -35,9 +35,9 @@ async def get_current_user(
|
|||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found"
|
status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found"
|
||||||
)
|
)
|
||||||
if user.status != "approved":
|
if user.status == "disabled":
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN, detail="Account not approved"
|
status_code=status.HTTP_403_FORBIDDEN, detail="Account disabled"
|
||||||
)
|
)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|||||||
@ -14,6 +14,7 @@ class Class_(Base):
|
|||||||
name: Mapped[str] = mapped_column(String(100), nullable=False)
|
name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||||
cohort_year: Mapped[int] = mapped_column(Integer, nullable=False)
|
cohort_year: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
invite_code: Mapped[str | None] = mapped_column(String(20), unique=True, nullable=True)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime, server_default=func.now(), onupdate=func.now()
|
DateTime, server_default=func.now(), onupdate=func.now()
|
||||||
@ -32,6 +33,9 @@ class Class_(Base):
|
|||||||
resources: Mapped[list["Resource"]] = relationship(
|
resources: Mapped[list["Resource"]] = relationship(
|
||||||
"Resource", back_populates="class_", cascade="all, delete-orphan"
|
"Resource", back_populates="class_", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
|
roster: Mapped[list["StudentRoster"]] = relationship(
|
||||||
|
"StudentRoster", back_populates="class_", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
@ -199,3 +203,24 @@ class Notification(Base):
|
|||||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||||
|
|
||||||
user: Mapped["User"] = relationship("User")
|
user: Mapped["User"] = relationship("User")
|
||||||
|
|
||||||
|
|
||||||
|
class StudentRoster(Base):
|
||||||
|
__tablename__ = "student_rosters"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
class_id: Mapped[int] = mapped_column(
|
||||||
|
Integer, ForeignKey("classes.id"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
student_id: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||||
|
name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||||
|
status: Mapped[str] = mapped_column(
|
||||||
|
String(20), default="unregistered", nullable=False
|
||||||
|
)
|
||||||
|
user_id: Mapped[int | None] = mapped_column(
|
||||||
|
Integer, ForeignKey("users.id"), nullable=True
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||||
|
|
||||||
|
class_: Mapped["Class_"] = relationship("Class_", back_populates="roster")
|
||||||
|
user: Mapped["User | None"] = relationship("User")
|
||||||
|
|||||||
@ -7,11 +7,11 @@ class LoginRequest(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class RegisterRequest(BaseModel):
|
class RegisterRequest(BaseModel):
|
||||||
|
invite_code: str
|
||||||
|
student_id: str
|
||||||
|
name: str
|
||||||
email: EmailStr
|
email: EmailStr
|
||||||
password: str
|
password: str
|
||||||
name: str
|
|
||||||
class_id: int
|
|
||||||
student_id: str
|
|
||||||
|
|
||||||
|
|
||||||
class ChangePasswordRequest(BaseModel):
|
class ChangePasswordRequest(BaseModel):
|
||||||
|
|||||||
@ -20,6 +20,7 @@ class ClassOut(BaseModel):
|
|||||||
name: str
|
name: str
|
||||||
cohort_year: int
|
cohort_year: int
|
||||||
description: str | None
|
description: str | None
|
||||||
|
invite_code: str | None = None
|
||||||
member_count: int = 0
|
member_count: int = 0
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
|
||||||
|
|||||||
15
backend/app/schemas/roster.py
Normal file
15
backend/app/schemas/roster.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class RosterOut(BaseModel):
|
||||||
|
id: int
|
||||||
|
student_id: str
|
||||||
|
name: str
|
||||||
|
status: str # "unregistered" | "registered"
|
||||||
|
user_id: int | None
|
||||||
|
|
||||||
|
model_config = {"from_attributes": True}
|
||||||
|
|
||||||
|
|
||||||
|
class RosterImportRequest(BaseModel):
|
||||||
|
entries: list[dict] # [{"student_id": "...", "name": "..."}, ...]
|
||||||
@ -43,7 +43,6 @@ class UserPublic(BaseModel):
|
|||||||
industry: str | None
|
industry: str | None
|
||||||
company: str | None
|
company: str | None
|
||||||
position: str | None
|
position: str | None
|
||||||
skills_tags: list[str] | None
|
|
||||||
wechat_id: str | None
|
wechat_id: str | None
|
||||||
phone: str | None
|
phone: str | None
|
||||||
avatar_url: str | None
|
avatar_url: str | None
|
||||||
@ -67,11 +66,11 @@ class UserListItem(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class UserUpdate(BaseModel):
|
class UserUpdate(BaseModel):
|
||||||
|
email: EmailStr | None = None
|
||||||
name: str | None = None
|
name: str | None = None
|
||||||
industry: str | None = None
|
industry: str | None = None
|
||||||
company: str | None = None
|
company: str | None = None
|
||||||
position: str | None = None
|
position: str | None = None
|
||||||
skills_tags: list[str] | None = None
|
|
||||||
wechat_id: str | None = None
|
wechat_id: str | None = None
|
||||||
phone: str | None = None
|
phone: str | None = None
|
||||||
bio: str | None = None
|
bio: str | None = None
|
||||||
|
|||||||
@ -70,7 +70,6 @@ def user_to_public(user: User, include_contact: bool = True) -> UserPublic:
|
|||||||
industry=user.industry,
|
industry=user.industry,
|
||||||
company=user.company,
|
company=user.company,
|
||||||
position=user.position,
|
position=user.position,
|
||||||
skills_tags=user.get_skills_list(),
|
|
||||||
wechat_id=user.wechat_id if include_contact else None,
|
wechat_id=user.wechat_id if include_contact else None,
|
||||||
phone=user.phone if include_contact else None,
|
phone=user.phone if include_contact else None,
|
||||||
avatar_url=user.avatar_url,
|
avatar_url=user.avatar_url,
|
||||||
|
|||||||
125
backend/app/services/roster_service.py
Normal file
125
backend/app/services/roster_service.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
import secrets
|
||||||
|
|
||||||
|
from sqlalchemy import select, func
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.db.models import StudentRoster, Class_
|
||||||
|
|
||||||
|
|
||||||
|
def generate_invite_code(length: int = 8) -> str:
|
||||||
|
chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
|
||||||
|
return "".join(secrets.choice(chars) for _ in range(length))
|
||||||
|
|
||||||
|
|
||||||
|
async def ensure_invite_code(db: AsyncSession, class_id: int) -> str:
|
||||||
|
result = await db.execute(select(Class_).where(Class_.id == class_id))
|
||||||
|
class_ = result.scalar_one_or_none()
|
||||||
|
if class_ is None:
|
||||||
|
return ""
|
||||||
|
if not class_.invite_code:
|
||||||
|
class_.invite_code = generate_invite_code()
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(class_)
|
||||||
|
return class_.invite_code
|
||||||
|
|
||||||
|
|
||||||
|
async def regenerate_invite_code(db: AsyncSession, class_id: int) -> str:
|
||||||
|
result = await db.execute(select(Class_).where(Class_.id == class_id))
|
||||||
|
class_ = result.scalar_one_or_none()
|
||||||
|
if class_ is None:
|
||||||
|
return ""
|
||||||
|
class_.invite_code = generate_invite_code()
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(class_)
|
||||||
|
return class_.invite_code
|
||||||
|
|
||||||
|
|
||||||
|
async def import_roster(
|
||||||
|
db: AsyncSession, class_id: int, entries: list[dict]
|
||||||
|
) -> int:
|
||||||
|
existing_ids: set[str] = set()
|
||||||
|
result = await db.execute(
|
||||||
|
select(StudentRoster.student_id).where(StudentRoster.class_id == class_id)
|
||||||
|
)
|
||||||
|
for row in result.all():
|
||||||
|
existing_ids.add(row[0])
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for entry in entries:
|
||||||
|
sid = entry.get("student_id", "").strip()
|
||||||
|
name = entry.get("name", "").strip()
|
||||||
|
if not sid or not name or sid in existing_ids:
|
||||||
|
continue
|
||||||
|
roster = StudentRoster(class_id=class_id, student_id=sid, name=name)
|
||||||
|
db.add(roster)
|
||||||
|
existing_ids.add(sid)
|
||||||
|
count += 1
|
||||||
|
await db.commit()
|
||||||
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
async def get_roster(
|
||||||
|
db: AsyncSession, class_id: int, page: int = 1, page_size: int = 50
|
||||||
|
) -> tuple[list[StudentRoster], int]:
|
||||||
|
query = select(StudentRoster).where(StudentRoster.class_id == class_id)
|
||||||
|
count_query = select(func.count(StudentRoster.id)).where(
|
||||||
|
StudentRoster.class_id == class_id
|
||||||
|
)
|
||||||
|
|
||||||
|
total_result = await db.execute(count_query)
|
||||||
|
total = total_result.scalar() or 0
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
query.order_by(StudentRoster.student_id)
|
||||||
|
.offset((page - 1) * page_size)
|
||||||
|
.limit(page_size)
|
||||||
|
)
|
||||||
|
return list(result.scalars().all()), total
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_registration(
|
||||||
|
db: AsyncSession, invite_code: str, student_id: str
|
||||||
|
) -> StudentRoster | None:
|
||||||
|
class_result = await db.execute(
|
||||||
|
select(Class_).where(Class_.invite_code == invite_code)
|
||||||
|
)
|
||||||
|
class_ = class_result.scalar_one_or_none()
|
||||||
|
if class_ is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
roster_result = await db.execute(
|
||||||
|
select(StudentRoster).where(
|
||||||
|
StudentRoster.class_id == class_.id,
|
||||||
|
StudentRoster.student_id == student_id,
|
||||||
|
StudentRoster.status == "unregistered",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return roster_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_roster_entry(db: AsyncSession, roster_id: int) -> bool:
|
||||||
|
result = await db.execute(
|
||||||
|
select(StudentRoster).where(StudentRoster.id == roster_id)
|
||||||
|
)
|
||||||
|
entry = result.scalar_one_or_none()
|
||||||
|
if entry is None:
|
||||||
|
return False
|
||||||
|
if entry.status == "registered":
|
||||||
|
return False
|
||||||
|
await db.delete(entry)
|
||||||
|
await db.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def clear_unregistered_roster(db: AsyncSession, class_id: int) -> int:
|
||||||
|
result = await db.execute(
|
||||||
|
select(StudentRoster).where(
|
||||||
|
StudentRoster.class_id == class_id,
|
||||||
|
StudentRoster.status == "unregistered",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
entries = list(result.scalars().all())
|
||||||
|
for entry in entries:
|
||||||
|
await db.delete(entry)
|
||||||
|
await db.commit()
|
||||||
|
return len(entries)
|
||||||
@ -1,9 +1,8 @@
|
|||||||
from sqlalchemy import select, func
|
from sqlalchemy import select, func
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.db.models import User, Class_
|
from app.db.models import User
|
||||||
from app.schemas.user import UserOut, UserUpdate
|
from app.schemas.user import UserOut, UserUpdate
|
||||||
from app.services.email_service import send_registration_notification
|
|
||||||
|
|
||||||
|
|
||||||
async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
|
async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
|
||||||
@ -16,52 +15,18 @@ async def get_user_by_id(db: AsyncSession, user_id: int) -> User | None:
|
|||||||
return result.scalar_one_or_none()
|
return result.scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
async def register_user(
|
|
||||||
db: AsyncSession,
|
|
||||||
email: str,
|
|
||||||
password_hash: str,
|
|
||||||
name: str,
|
|
||||||
class_id: int,
|
|
||||||
student_id: str | None = None,
|
|
||||||
) -> User:
|
|
||||||
user = User(
|
|
||||||
email=email,
|
|
||||||
password_hash=password_hash,
|
|
||||||
name=name,
|
|
||||||
student_id=student_id,
|
|
||||||
role="student",
|
|
||||||
status="pending",
|
|
||||||
class_id=class_id,
|
|
||||||
)
|
|
||||||
db.add(user)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(user)
|
|
||||||
|
|
||||||
# Notify class admins
|
|
||||||
admins_result = await db.execute(
|
|
||||||
select(User).where(
|
|
||||||
User.class_id == class_id,
|
|
||||||
User.role.in_(["class_admin", "super_admin"]),
|
|
||||||
User.status == "approved",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
class_result = await db.execute(select(Class_).where(Class_.id == class_id))
|
|
||||||
class_ = class_result.scalar_one_or_none()
|
|
||||||
class_name = class_.name if class_ else "Unknown"
|
|
||||||
|
|
||||||
for admin in admins_result.scalars():
|
|
||||||
await send_registration_notification(admin.email, name, class_name)
|
|
||||||
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
async def update_profile(db: AsyncSession, user: User, data: UserUpdate) -> User:
|
async def update_profile(db: AsyncSession, user: User, data: UserUpdate) -> User:
|
||||||
update_data = data.model_dump(exclude_unset=True)
|
update_data = data.model_dump(exclude_unset=True)
|
||||||
if "skills_tags" in update_data and update_data["skills_tags"] is not None:
|
|
||||||
import json
|
# Handle email change with uniqueness check
|
||||||
user.skills_tags = json.dumps(
|
if "email" in update_data and update_data["email"] != user.email:
|
||||||
update_data.pop("skills_tags"), ensure_ascii=False
|
existing = await db.execute(
|
||||||
|
select(User).where(User.email == update_data["email"])
|
||||||
)
|
)
|
||||||
|
if existing.scalar_one_or_none():
|
||||||
|
raise ValueError("该邮箱已被使用")
|
||||||
|
user.email = update_data.pop("email")
|
||||||
|
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
setattr(user, field, value)
|
setattr(user, field, value)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user