1
This commit is contained in:
parent
85f6b7e42b
commit
31e7a598dc
@ -1,39 +1,57 @@
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.auth import hash_password, verify_password, create_access_token
|
||||
from app.core.deps import get_current_user
|
||||
from app.db.database import get_db
|
||||
from app.db.models import User, Class_
|
||||
from app.db.models import User
|
||||
from app.schemas.auth import LoginRequest, RegisterRequest, ChangePasswordRequest
|
||||
from app.schemas.user import TokenResponse, UserOut
|
||||
from app.services.user_service import register_user
|
||||
from app.services.roster_service import validate_registration
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/register")
|
||||
async def register(req: RegisterRequest, db: AsyncSession = Depends(get_db)):
|
||||
# 1. Check if email is already registered
|
||||
existing = await db.execute(select(User).where(User.email == req.email))
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=400, detail="Email already registered")
|
||||
raise HTTPException(status_code=400, detail="该邮箱已注册")
|
||||
|
||||
class_result = await db.execute(select(Class_).where(Class_.id == req.class_id))
|
||||
if class_result.scalar_one_or_none() is None:
|
||||
raise HTTPException(status_code=400, detail="Class not found")
|
||||
# 2. Validate invite_code + student_id against roster
|
||||
roster_entry = await validate_registration(db, req.invite_code, req.student_id)
|
||||
if roster_entry is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="邀请码或学号无效,或该学号已注册"
|
||||
)
|
||||
|
||||
user = await register_user(
|
||||
db=db,
|
||||
# 3. Create user with approved status directly
|
||||
user = User(
|
||||
email=req.email,
|
||||
password_hash=hash_password(req.password),
|
||||
name=req.name,
|
||||
class_id=req.class_id,
|
||||
student_id=req.student_id,
|
||||
role="student",
|
||||
status="approved",
|
||||
class_id=roster_entry.class_id,
|
||||
)
|
||||
return {"message": "Registration submitted. Awaiting admin approval."}
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
|
||||
# 4. Mark roster entry as registered
|
||||
roster_entry.status = "registered"
|
||||
roster_entry.user_id = user.id
|
||||
await db.commit()
|
||||
|
||||
# 5. Issue token — register and login in one step
|
||||
token = create_access_token({"sub": str(user.id), "role": user.role})
|
||||
return {
|
||||
"message": "注册成功",
|
||||
"token": token,
|
||||
"user": UserOut.model_validate(user),
|
||||
}
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
@ -41,13 +59,14 @@ async def login(req: LoginRequest, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(User).where(User.email == req.email))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user is None or user.status != "approved":
|
||||
raise HTTPException(
|
||||
status_code=401, detail="Invalid credentials or account not approved"
|
||||
)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="邮箱或密码错误")
|
||||
|
||||
if user.status == "disabled":
|
||||
raise HTTPException(status_code=401, detail="账号已被禁用")
|
||||
|
||||
if not verify_password(req.password, user.password_hash):
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
raise HTTPException(status_code=401, detail="邮箱或密码错误")
|
||||
|
||||
token = create_access_token({"sub": str(user.id), "role": user.role})
|
||||
return TokenResponse(
|
||||
|
||||
@ -1,11 +1,15 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
import csv
|
||||
import io
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.deps import get_current_user, require_role
|
||||
from app.core.deps import require_role
|
||||
from app.db.database import get_db
|
||||
from app.db.models import User
|
||||
from app.schemas.class_ import ClassCreate, ClassUpdate, ClassOut
|
||||
from app.schemas.user import UserListItem
|
||||
from app.schemas.roster import RosterOut, RosterImportRequest
|
||||
from app.schemas.common import PageResponse
|
||||
from app.services.class_service import (
|
||||
create_class,
|
||||
@ -16,6 +20,14 @@ from app.services.class_service import (
|
||||
get_member_count,
|
||||
get_class_members,
|
||||
)
|
||||
from app.services.roster_service import (
|
||||
ensure_invite_code,
|
||||
regenerate_invite_code,
|
||||
import_roster,
|
||||
get_roster,
|
||||
delete_roster_entry,
|
||||
clear_unregistered_roster,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/classes", tags=["classes"])
|
||||
|
||||
@ -103,8 +115,11 @@ async def get_members(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{class_id}/pending", response_model=PageResponse[UserListItem])
|
||||
async def get_pending_members(
|
||||
# --- Roster management ---
|
||||
|
||||
|
||||
@router.get("/{class_id}/roster", response_model=PageResponse[RosterOut])
|
||||
async def get_class_roster(
|
||||
class_id: int,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
@ -112,14 +127,144 @@ async def get_pending_members(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if admin.role == "class_admin" and admin.class_id != class_id:
|
||||
raise HTTPException(status_code=403, detail="Access denied for this class")
|
||||
|
||||
members, total = await get_class_members(db, class_id, status="pending", page=page, page_size=page_size)
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
entries, total = await get_roster(db, class_id, page, page_size)
|
||||
total_pages = (total + page_size - 1) // page_size
|
||||
return PageResponse(
|
||||
items=[UserListItem.model_validate(m) for m in members],
|
||||
items=[RosterOut.model_validate(e) for e in entries],
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{class_id}/roster/import")
|
||||
async def import_class_roster(
|
||||
class_id: int,
|
||||
data: RosterImportRequest,
|
||||
admin: User = Depends(require_role("super_admin", "class_admin")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if admin.role == "class_admin" and admin.class_id != class_id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
count = await import_roster(db, class_id, data.entries)
|
||||
return {"message": f"成功导入 {count} 条记录"}
|
||||
|
||||
|
||||
@router.post("/{class_id}/roster/upload")
|
||||
async def upload_roster_file(
|
||||
class_id: int,
|
||||
file: UploadFile = File(...),
|
||||
admin: User = Depends(require_role("super_admin", "class_admin")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if admin.role == "class_admin" and admin.class_id != class_id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
contents = await file.read()
|
||||
filename = file.filename or ""
|
||||
|
||||
entries: list[dict] = []
|
||||
|
||||
if filename.endswith(".csv"):
|
||||
text = contents.decode("utf-8-sig")
|
||||
reader = csv.DictReader(io.StringIO(text))
|
||||
for row in reader:
|
||||
sid = row.get("student_id") or row.get("学号") or ""
|
||||
name = row.get("name") or row.get("姓名") or ""
|
||||
if sid and name:
|
||||
entries.append({"student_id": sid.strip(), "name": name.strip()})
|
||||
elif filename.endswith((".xlsx", ".xls")):
|
||||
try:
|
||||
import openpyxl
|
||||
|
||||
wb = openpyxl.load_workbook(io.BytesIO(contents), read_only=True)
|
||||
ws = wb.active
|
||||
rows = list(ws.iter_rows(values_only=True))
|
||||
if len(rows) < 2:
|
||||
raise HTTPException(status_code=400, detail="Excel 文件为空")
|
||||
header = [str(h).strip() if h else "" for h in rows[0]]
|
||||
# Find student_id and name columns
|
||||
sid_col = None
|
||||
name_col = None
|
||||
for i, h in enumerate(header):
|
||||
if h in ("student_id", "学号"):
|
||||
sid_col = i
|
||||
elif h in ("name", "姓名"):
|
||||
name_col = i
|
||||
if sid_col is None or name_col is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Excel 需包含 '学号'(student_id) 和 '姓名'(name) 列",
|
||||
)
|
||||
for row in rows[1:]:
|
||||
sid = str(row[sid_col]).strip() if row[sid_col] else ""
|
||||
name = str(row[name_col]).strip() if row[name_col] else ""
|
||||
if sid and name and sid != "None":
|
||||
entries.append({"student_id": sid, "name": name})
|
||||
wb.close()
|
||||
except ImportError:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="服务器未安装 openpyxl,请使用 CSV 格式"
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="仅支持 CSV 或 Excel (.xlsx) 文件")
|
||||
|
||||
if not entries:
|
||||
raise HTTPException(status_code=400, detail="未找到有效数据")
|
||||
|
||||
count = await import_roster(db, class_id, entries)
|
||||
return {"message": f"成功导入 {count} 条记录"}
|
||||
|
||||
|
||||
@router.delete("/{class_id}/roster/{roster_id}")
|
||||
async def delete_roster_item(
|
||||
class_id: int,
|
||||
roster_id: int,
|
||||
admin: User = Depends(require_role("super_admin", "class_admin")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
success = await delete_roster_entry(db, roster_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=400, detail="无法删除(已注册或不存在)")
|
||||
return {"message": "已删除"}
|
||||
|
||||
|
||||
@router.post("/{class_id}/roster/clear")
|
||||
async def clear_roster(
|
||||
class_id: int,
|
||||
admin: User = Depends(require_role("super_admin", "class_admin")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
if admin.role == "class_admin" and admin.class_id != class_id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
count = await clear_unregistered_roster(db, class_id)
|
||||
return {"message": f"已清除 {count} 条未注册记录"}
|
||||
|
||||
|
||||
# --- Invite code management ---
|
||||
|
||||
|
||||
@router.get("/{class_id}/invite-code")
|
||||
async def get_invite_code(
|
||||
class_id: int,
|
||||
admin: User = Depends(require_role("super_admin", "class_admin")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
code = await ensure_invite_code(db, class_id)
|
||||
if not code:
|
||||
raise HTTPException(status_code=404, detail="Class not found")
|
||||
return {"invite_code": code}
|
||||
|
||||
|
||||
@router.post("/{class_id}/invite-code/regenerate")
|
||||
async def regenerate_invite(
|
||||
class_id: int,
|
||||
admin: User = Depends(require_role("super_admin", "class_admin")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
code = await regenerate_invite_code(db, class_id)
|
||||
if not code:
|
||||
raise HTTPException(status_code=404, detail="Class not found")
|
||||
return {"invite_code": code}
|
||||
|
||||
@ -31,7 +31,10 @@ async def update_my_profile(
|
||||
user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
updated = await update_profile(db, user, data)
|
||||
try:
|
||||
updated = await update_profile(db, user, data)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return UserOut.model_validate(updated)
|
||||
|
||||
|
||||
|
||||
@ -35,9 +35,9 @@ async def get_current_user(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found"
|
||||
)
|
||||
if user.status != "approved":
|
||||
if user.status == "disabled":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Account not approved"
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="Account disabled"
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
@ -14,6 +14,7 @@ class Class_(Base):
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
cohort_year: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
invite_code: Mapped[str | None] = mapped_column(String(20), unique=True, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, server_default=func.now(), onupdate=func.now()
|
||||
@ -32,6 +33,9 @@ class Class_(Base):
|
||||
resources: Mapped[list["Resource"]] = relationship(
|
||||
"Resource", back_populates="class_", cascade="all, delete-orphan"
|
||||
)
|
||||
roster: Mapped[list["StudentRoster"]] = relationship(
|
||||
"StudentRoster", back_populates="class_", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class User(Base):
|
||||
@ -199,3 +203,24 @@ class Notification(Base):
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
|
||||
user: Mapped["User"] = relationship("User")
|
||||
|
||||
|
||||
class StudentRoster(Base):
|
||||
__tablename__ = "student_rosters"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
class_id: Mapped[int] = mapped_column(
|
||||
Integer, ForeignKey("classes.id"), nullable=False, index=True
|
||||
)
|
||||
student_id: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(20), default="unregistered", nullable=False
|
||||
)
|
||||
user_id: Mapped[int | None] = mapped_column(
|
||||
Integer, ForeignKey("users.id"), nullable=True
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now())
|
||||
|
||||
class_: Mapped["Class_"] = relationship("Class_", back_populates="roster")
|
||||
user: Mapped["User | None"] = relationship("User")
|
||||
|
||||
@ -7,11 +7,11 @@ class LoginRequest(BaseModel):
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
invite_code: str
|
||||
student_id: str
|
||||
name: str
|
||||
email: EmailStr
|
||||
password: str
|
||||
name: str
|
||||
class_id: int
|
||||
student_id: str
|
||||
|
||||
|
||||
class ChangePasswordRequest(BaseModel):
|
||||
|
||||
@ -20,6 +20,7 @@ class ClassOut(BaseModel):
|
||||
name: str
|
||||
cohort_year: int
|
||||
description: str | None
|
||||
invite_code: str | None = None
|
||||
member_count: int = 0
|
||||
created_at: datetime
|
||||
|
||||
|
||||
15
backend/app/schemas/roster.py
Normal file
15
backend/app/schemas/roster.py
Normal file
@ -0,0 +1,15 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RosterOut(BaseModel):
|
||||
id: int
|
||||
student_id: str
|
||||
name: str
|
||||
status: str # "unregistered" | "registered"
|
||||
user_id: int | None
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class RosterImportRequest(BaseModel):
|
||||
entries: list[dict] # [{"student_id": "...", "name": "..."}, ...]
|
||||
@ -43,7 +43,6 @@ class UserPublic(BaseModel):
|
||||
industry: str | None
|
||||
company: str | None
|
||||
position: str | None
|
||||
skills_tags: list[str] | None
|
||||
wechat_id: str | None
|
||||
phone: str | None
|
||||
avatar_url: str | None
|
||||
@ -67,11 +66,11 @@ class UserListItem(BaseModel):
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
|
||||
email: EmailStr | None = None
|
||||
name: str | None = None
|
||||
industry: str | None = None
|
||||
company: str | None = None
|
||||
position: str | None = None
|
||||
skills_tags: list[str] | None = None
|
||||
wechat_id: str | None = None
|
||||
phone: str | None = None
|
||||
bio: str | None = None
|
||||
|
||||
@ -70,7 +70,6 @@ def user_to_public(user: User, include_contact: bool = True) -> UserPublic:
|
||||
industry=user.industry,
|
||||
company=user.company,
|
||||
position=user.position,
|
||||
skills_tags=user.get_skills_list(),
|
||||
wechat_id=user.wechat_id if include_contact else None,
|
||||
phone=user.phone if include_contact else None,
|
||||
avatar_url=user.avatar_url,
|
||||
|
||||
125
backend/app/services/roster_service.py
Normal file
125
backend/app/services/roster_service.py
Normal file
@ -0,0 +1,125 @@
|
||||
import secrets
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models import StudentRoster, Class_
|
||||
|
||||
|
||||
def generate_invite_code(length: int = 8) -> str:
|
||||
chars = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
|
||||
return "".join(secrets.choice(chars) for _ in range(length))
|
||||
|
||||
|
||||
async def ensure_invite_code(db: AsyncSession, class_id: int) -> str:
|
||||
result = await db.execute(select(Class_).where(Class_.id == class_id))
|
||||
class_ = result.scalar_one_or_none()
|
||||
if class_ is None:
|
||||
return ""
|
||||
if not class_.invite_code:
|
||||
class_.invite_code = generate_invite_code()
|
||||
await db.commit()
|
||||
await db.refresh(class_)
|
||||
return class_.invite_code
|
||||
|
||||
|
||||
async def regenerate_invite_code(db: AsyncSession, class_id: int) -> str:
|
||||
result = await db.execute(select(Class_).where(Class_.id == class_id))
|
||||
class_ = result.scalar_one_or_none()
|
||||
if class_ is None:
|
||||
return ""
|
||||
class_.invite_code = generate_invite_code()
|
||||
await db.commit()
|
||||
await db.refresh(class_)
|
||||
return class_.invite_code
|
||||
|
||||
|
||||
async def import_roster(
|
||||
db: AsyncSession, class_id: int, entries: list[dict]
|
||||
) -> int:
|
||||
existing_ids: set[str] = set()
|
||||
result = await db.execute(
|
||||
select(StudentRoster.student_id).where(StudentRoster.class_id == class_id)
|
||||
)
|
||||
for row in result.all():
|
||||
existing_ids.add(row[0])
|
||||
|
||||
count = 0
|
||||
for entry in entries:
|
||||
sid = entry.get("student_id", "").strip()
|
||||
name = entry.get("name", "").strip()
|
||||
if not sid or not name or sid in existing_ids:
|
||||
continue
|
||||
roster = StudentRoster(class_id=class_id, student_id=sid, name=name)
|
||||
db.add(roster)
|
||||
existing_ids.add(sid)
|
||||
count += 1
|
||||
await db.commit()
|
||||
return count
|
||||
|
||||
|
||||
async def get_roster(
|
||||
db: AsyncSession, class_id: int, page: int = 1, page_size: int = 50
|
||||
) -> tuple[list[StudentRoster], int]:
|
||||
query = select(StudentRoster).where(StudentRoster.class_id == class_id)
|
||||
count_query = select(func.count(StudentRoster.id)).where(
|
||||
StudentRoster.class_id == class_id
|
||||
)
|
||||
|
||||
total_result = await db.execute(count_query)
|
||||
total = total_result.scalar() or 0
|
||||
|
||||
result = await db.execute(
|
||||
query.order_by(StudentRoster.student_id)
|
||||
.offset((page - 1) * page_size)
|
||||
.limit(page_size)
|
||||
)
|
||||
return list(result.scalars().all()), total
|
||||
|
||||
|
||||
async def validate_registration(
|
||||
db: AsyncSession, invite_code: str, student_id: str
|
||||
) -> StudentRoster | None:
|
||||
class_result = await db.execute(
|
||||
select(Class_).where(Class_.invite_code == invite_code)
|
||||
)
|
||||
class_ = class_result.scalar_one_or_none()
|
||||
if class_ is None:
|
||||
return None
|
||||
|
||||
roster_result = await db.execute(
|
||||
select(StudentRoster).where(
|
||||
StudentRoster.class_id == class_.id,
|
||||
StudentRoster.student_id == student_id,
|
||||
StudentRoster.status == "unregistered",
|
||||
)
|
||||
)
|
||||
return roster_result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def delete_roster_entry(db: AsyncSession, roster_id: int) -> bool:
|
||||
result = await db.execute(
|
||||
select(StudentRoster).where(StudentRoster.id == roster_id)
|
||||
)
|
||||
entry = result.scalar_one_or_none()
|
||||
if entry is None:
|
||||
return False
|
||||
if entry.status == "registered":
|
||||
return False
|
||||
await db.delete(entry)
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
|
||||
async def clear_unregistered_roster(db: AsyncSession, class_id: int) -> int:
|
||||
result = await db.execute(
|
||||
select(StudentRoster).where(
|
||||
StudentRoster.class_id == class_id,
|
||||
StudentRoster.status == "unregistered",
|
||||
)
|
||||
)
|
||||
entries = list(result.scalars().all())
|
||||
for entry in entries:
|
||||
await db.delete(entry)
|
||||
await db.commit()
|
||||
return len(entries)
|
||||
@ -1,9 +1,8 @@
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db.models import User, Class_
|
||||
from app.db.models import User
|
||||
from app.schemas.user import UserOut, UserUpdate
|
||||
from app.services.email_service import send_registration_notification
|
||||
|
||||
|
||||
async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
|
||||
@ -16,52 +15,18 @@ async def get_user_by_id(db: AsyncSession, user_id: int) -> User | None:
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
async def register_user(
|
||||
db: AsyncSession,
|
||||
email: str,
|
||||
password_hash: str,
|
||||
name: str,
|
||||
class_id: int,
|
||||
student_id: str | None = None,
|
||||
) -> User:
|
||||
user = User(
|
||||
email=email,
|
||||
password_hash=password_hash,
|
||||
name=name,
|
||||
student_id=student_id,
|
||||
role="student",
|
||||
status="pending",
|
||||
class_id=class_id,
|
||||
)
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
# Notify class admins
|
||||
admins_result = await db.execute(
|
||||
select(User).where(
|
||||
User.class_id == class_id,
|
||||
User.role.in_(["class_admin", "super_admin"]),
|
||||
User.status == "approved",
|
||||
)
|
||||
)
|
||||
class_result = await db.execute(select(Class_).where(Class_.id == class_id))
|
||||
class_ = class_result.scalar_one_or_none()
|
||||
class_name = class_.name if class_ else "Unknown"
|
||||
|
||||
for admin in admins_result.scalars():
|
||||
await send_registration_notification(admin.email, name, class_name)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def update_profile(db: AsyncSession, user: User, data: UserUpdate) -> User:
|
||||
update_data = data.model_dump(exclude_unset=True)
|
||||
if "skills_tags" in update_data and update_data["skills_tags"] is not None:
|
||||
import json
|
||||
user.skills_tags = json.dumps(
|
||||
update_data.pop("skills_tags"), ensure_ascii=False
|
||||
|
||||
# Handle email change with uniqueness check
|
||||
if "email" in update_data and update_data["email"] != user.email:
|
||||
existing = await db.execute(
|
||||
select(User).where(User.email == update_data["email"])
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
raise ValueError("该邮箱已被使用")
|
||||
user.email = update_data.pop("email")
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(user, field, value)
|
||||
await db.commit()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user