81 lines
2.5 KiB
Python
81 lines
2.5 KiB
Python
from sqlalchemy import select, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.db.models import User
|
|
from app.schemas.user import UserOut, UserUpdate
|
|
|
|
|
|
async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
    """Look up a single user by exact email match.

    Args:
        db: Async session used to run the query.
        email: Value compared for equality against ``User.email``.

    Returns:
        The matching ``User``, or ``None`` when no row matches.
    """
    stmt = select(User).where(User.email == email)
    rows = await db.execute(stmt)
    # scalar_one_or_none() raises if more than one row matches.
    return rows.scalar_one_or_none()
|
|
|
|
|
|
async def get_user_by_id(db: AsyncSession, user_id: int) -> User | None:
    """Look up a single user by ``User.id``.

    Args:
        db: Async session used to run the query.
        user_id: Value compared for equality against ``User.id``.

    Returns:
        The matching ``User``, or ``None`` when no row matches.
    """
    stmt = select(User).where(User.id == user_id)
    rows = await db.execute(stmt)
    # scalar_one_or_none() raises if more than one row matches.
    return rows.scalar_one_or_none()
|
|
|
|
|
|
async def update_profile(db: AsyncSession, user: User, data: UserUpdate) -> User:
    """Apply the explicitly-set fields of ``data`` to ``user`` and persist them.

    Only fields the caller actually set are considered
    (``model_dump(exclude_unset=True)``). When the payload carries an email
    that differs from the user's current one, the new address is first
    checked against existing users.

    Args:
        db: Async session used for the uniqueness lookup, commit and refresh.
        user: The user instance to mutate.
        data: Partial update payload.

    Returns:
        The same ``user`` instance, refreshed after the commit.

    Raises:
        ValueError: If another user already has the requested email.
    """
    changes = data.model_dump(exclude_unset=True)

    # An email identical to the current one is not treated as a change; it
    # stays in ``changes`` and is simply re-assigned by the loop below.
    email_is_changing = "email" in changes and changes["email"] != user.email
    if email_is_changing:
        new_email = changes.pop("email")
        lookup = await db.execute(select(User).where(User.email == new_email))
        if lookup.scalar_one_or_none():
            raise ValueError("该邮箱已被使用")
        user.email = new_email

    for attr_name, new_value in changes.items():
        setattr(user, attr_name, new_value)

    await db.commit()
    await db.refresh(user)
    return user
|
|
|
|
|
|
async def update_user_status(
    db: AsyncSession, user_id: int, status: str, role: str | None = None
) -> User | None:
    """Set a user's status, and optionally their role, then persist the change.

    Args:
        db: Async session used for the lookup, commit and refresh.
        user_id: Identifier of the user to modify.
        status: New value assigned to ``user.status``.
        role: New value for ``user.role``; left untouched when ``None``.

    Returns:
        The refreshed ``User``, or ``None`` if no user has ``user_id``
        (in which case nothing is committed).
    """
    target = await get_user_by_id(db, user_id)
    if target is None:
        return None

    # Status is always written; role only when the caller supplied one.
    pending = {"status": status}
    if role is not None:
        pending["role"] = role
    for attr_name, new_value in pending.items():
        setattr(target, attr_name, new_value)

    await db.commit()
    await db.refresh(target)
    return target
|
|
|
|
|
|
async def list_users(
    db: AsyncSession,
    page: int = 1,
    page_size: int = 20,
    class_id: int | None = None,
    status: str | None = None,
    role: str | None = None,
) -> tuple[list[User], int]:
    """Return one page of users together with the total number of matches.

    Args:
        db: Async session used to run the count and page queries.
        page: 1-based page number. Values below 1 are treated as 1 so the
            computed OFFSET can never be negative.
        page_size: Maximum number of rows per page. Negative values are
            treated as 0 so the LIMIT can never be negative.
        class_id: When given, keep only users with this ``class_id``.
        status: When given, keep only users with this ``status``.
        role: When given, keep only users with this ``role``.

    Returns:
        ``(users, total)`` where ``users`` is the requested page, newest
        ``created_at`` first, and ``total`` is the count of all users
        matching the filters (independent of pagination).
    """
    # A negative OFFSET/LIMIT is either a database error or silently means
    # "no limit" depending on the backend, so normalize the inputs up front.
    page = max(page, 1)
    page_size = max(page_size, 0)

    # Build the optional filters once and apply them to both statements so
    # the page contents and the reported total always agree.
    filters = []
    if class_id is not None:
        filters.append(User.class_id == class_id)
    if status is not None:
        filters.append(User.status == status)
    if role is not None:
        filters.append(User.role == role)

    count_query = select(func.count(User.id)).where(*filters)
    total_result = await db.execute(count_query)
    total = total_result.scalar() or 0

    query = (
        select(User)
        .where(*filters)
        # ``id`` breaks ties between rows sharing a ``created_at`` value;
        # without it the order of such rows is unspecified and a row could
        # show up on two pages or on none.
        .order_by(User.created_at.desc(), User.id.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    result = await db.execute(query)
    users = list(result.scalars().all())

    return users, total
|