from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.db.models import User
from app.schemas.user import UserOut, UserUpdate


async def get_user_by_email(db: AsyncSession, email: str) -> User | None:
    """Return the user whose ``email`` equals *email*, or ``None`` if absent.

    ``scalar_one_or_none`` raises ``MultipleResultsFound`` if more than one
    row matches.
    """
    result = await db.execute(select(User).where(User.email == email))
    return result.scalar_one_or_none()


async def get_user_by_id(db: AsyncSession, user_id: int) -> User | None:
    """Return the user with primary key *user_id*, or ``None`` if absent."""
    result = await db.execute(select(User).where(User.id == user_id))
    return result.scalar_one_or_none()


async def update_profile(db: AsyncSession, user: User, data: UserUpdate) -> User:
    """Apply the fields explicitly set in *data* to *user*, then commit.

    Only fields the client actually sent are applied
    (``model_dump(exclude_unset=True)``), so omitted fields keep their values.

    Args:
        db: Session the user is attached to; it is committed by this call.
        user: The ORM instance to modify in place.
        data: Partial update payload.

    Returns:
        The same *user* instance, refreshed from the database after commit.

    Raises:
        ValueError: If the requested new email already belongs to a user.
    """
    update_data = data.model_dump(exclude_unset=True)

    # An email change needs a uniqueness check before it is applied.
    # NOTE(review): this check-then-write is not atomic; presumably a unique
    # constraint on User.email is the real guarantee — confirm in the model.
    if "email" in update_data and update_data["email"] != user.email:
        if await get_user_by_email(db, update_data["email"]):
            raise ValueError("该邮箱已被使用")
        user.email = update_data.pop("email")

    for field, value in update_data.items():
        setattr(user, field, value)

    await db.commit()
    await db.refresh(user)
    return user


async def update_user_status(
    db: AsyncSession, user_id: int, status: str, role: str | None = None
) -> User | None:
    """Set the status (and optionally the role) of the user with *user_id*.

    Args:
        db: Session used for the lookup; it is committed by this call.
        user_id: Primary key of the user to modify.
        status: New value for ``User.status``.
        role: New value for ``User.role``; left unchanged when ``None``.

    Returns:
        The refreshed user, or ``None`` if no user has *user_id*.
    """
    user = await get_user_by_id(db, user_id)
    if user is None:
        return None

    user.status = status
    if role is not None:
        user.role = role

    await db.commit()
    await db.refresh(user)
    return user


async def list_users(
    db: AsyncSession,
    page: int = 1,
    page_size: int = 20,
    class_id: int | None = None,
    status: str | None = None,
    role: str | None = None,
) -> tuple[list[User], int]:
    """Return one page of users plus the total count matching the filters.

    Each filter that is not ``None`` is applied as an equality test. Results
    are ordered newest first (``created_at`` descending).

    Args:
        db: Session to query with.
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Rows per page; values below 1 are treated as 1.
        class_id: Optional ``User.class_id`` filter.
        status: Optional ``User.status`` filter.
        role: Optional ``User.role`` filter.

    Returns:
        ``(users_on_page, total_matching)``.
    """
    # Build the predicates once so the page query and the count query can
    # never disagree about which rows match.
    filters = []
    if class_id is not None:
        filters.append(User.class_id == class_id)
    if status is not None:
        filters.append(User.status == status)
    if role is not None:
        filters.append(User.role == role)

    count_query = select(func.count(User.id)).where(*filters)
    total_result = await db.execute(count_query)
    total = total_result.scalar() or 0

    # Guard against a negative OFFSET/LIMIT, which the database would reject.
    page = max(page, 1)
    page_size = max(page_size, 1)

    query = (
        select(User)
        .where(*filters)
        # Add id as a tiebreaker: rows sharing a created_at value would
        # otherwise have no stable order, so OFFSET paging could skip or
        # repeat them between pages.
        .order_by(User.created_at.desc(), User.id.desc())
        .offset((page - 1) * page_size)
        .limit(page_size)
    )
    result = await db.execute(query)
    users = list(result.scalars().all())
    return users, total