"""Message API routes: create a message, mark one as read, list a user's messages.

NOTE(review): the handlers are ``async def`` but use a synchronous SQLAlchemy
``Session``; consider plain ``def`` handlers so the framework can run them in
its thread pool -- confirm no caller awaits these functions directly first.
"""
import logging
from typing import List, Optional

from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session

from app.models.message import MessageDB, MessageCreate, MessageInfo
from app.models.database import get_db
from app.api.deps import get_current_user, get_admin_user
from app.models.user import UserDB
from app.core.response import success_response, error_response, ResponseModel

logger = logging.getLogger(__name__)

router = APIRouter()


@router.post("", response_model=ResponseModel)
async def create_message(
    message: MessageCreate,
    db: Session = Depends(get_db),
    admin: UserDB = Depends(get_admin_user),  # only administrators may create messages
):
    """Create a message from the request payload.

    Args:
        message: Validated payload; its dumped fields become the columns of
            the new ``MessageDB`` row.
        db: Database session supplied by ``get_db``.
        admin: Resolved by ``get_admin_user``; not used in the body, it is
            declared only so the dependency runs for this route.

    Returns:
        A success response carrying the stored message as ``MessageInfo``,
        or a 500 error response if adding, committing, refreshing or
        serialising the row raises (the session is rolled back first).
    """
    db_message = MessageDB(**message.model_dump())
    try:
        db.add(db_message)
        db.commit()
        # Reload the row so database-populated attributes are present.
        db.refresh(db_message)
        return success_response(
            message="消息创建成功",
            data=MessageInfo.model_validate(db_message),
        )
    except Exception as e:
        db.rollback()
        # Keep a server-side record with traceback; previously the failure was
        # only visible in the response body.
        logger.exception("Failed to create message")
        return error_response(code=500, message=f"创建消息失败: {str(e)}")


@router.post("/{message_id}/read", response_model=ResponseModel)
async def mark_message_read(
    message_id: int,
    db: Session = Depends(get_db),
    current_user: UserDB = Depends(get_current_user),
):
    """Mark one of the current user's messages as read.

    Args:
        message_id: Primary key of the message, taken from the URL path.
        db: Database session supplied by ``get_db``.
        current_user: User resolved by ``get_current_user``.

    Returns:
        A success response carrying the updated message, a 404 error response
        when no message with that id belongs to the current user, or a 500
        error response if the update fails (the session is rolled back first).
    """
    # Filtering on user_id as well as id means another user's message is
    # reported as "not found" rather than being modified.
    message = (
        db.query(MessageDB)
        .filter(
            MessageDB.id == message_id,
            MessageDB.user_id == current_user.userid,
        )
        .first()
    )
    if not message:
        return error_response(code=404, message="消息不存在")

    try:
        message.is_read = True
        db.commit()
        db.refresh(message)
        return success_response(
            message="消息已标记为已读",
            data=MessageInfo.model_validate(message),
        )
    except Exception as e:
        db.rollback()
        logger.exception("Failed to mark message %s as read", message_id)
        return error_response(code=500, message=f"标记已读失败: {str(e)}")


@router.get("/list", response_model=ResponseModel)
async def get_user_messages(
    unread_only: bool = False,
    skip: int = 0,
    limit: int = 20,
    db: Session = Depends(get_db),
    current_user: UserDB = Depends(get_current_user),
):
    """List the current user's messages, newest first, with pagination.

    Args:
        unread_only: When true, only messages whose ``is_read`` is false are
            counted and returned.
        skip: Number of rows to skip; negative values are treated as 0.
        limit: Maximum number of rows to return; negative values are treated
            as 0.
        db: Database session supplied by ``get_db``.
        current_user: User resolved by ``get_current_user``.

    Returns:
        A success response whose data holds ``total`` (count of all matching
        messages, ignoring pagination) and ``items`` (the requested page as
        ``MessageInfo`` objects).
    """
    # Negative OFFSET/LIMIT values are an error or mean "no limit" depending
    # on the database backend, so normalise them before building the query.
    skip = max(skip, 0)
    limit = max(limit, 0)

    # Base query: only messages that belong to the current user.
    query = db.query(MessageDB).filter(
        MessageDB.user_id == current_user.userid
    )

    if unread_only:
        # `== False` is intentional: it builds a SQL expression, not a Python
        # comparison.
        query = query.filter(MessageDB.is_read == False)  # noqa: E712

    # Count before pagination so `total` reflects every matching row.
    total = query.count()

    # create_time alone is not unique; the id tie-breaker keeps the order
    # deterministic so consecutive pages neither repeat nor skip rows.
    messages = (
        query.order_by(MessageDB.create_time.desc(), MessageDB.id.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )

    return success_response(data={
        "total": total,
        "items": [MessageInfo.model_validate(m) for m in messages],
    })