deliveryman-api/app/api/endpoints/message.py
2025-01-23 12:38:13 +08:00

107 lines
3.2 KiB
Python

from typing import List, Optional

from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session

from app.models.message import MessageDB, MessageCreate, MessageInfo
from app.models.database import get_db
from app.api.deps import get_current_user, get_admin_user
from app.models.user import UserDB
from app.core.response import success_response, error_response, ResponseModel
# Router holding the message endpoints defined below; it is mounted by the
# application elsewhere.
router: APIRouter = APIRouter()
@router.post("", response_model=ResponseModel)
async def create_message(
    message: MessageCreate,
    db: Session = Depends(get_db),
    admin: UserDB = Depends(get_admin_user),  # only admins may create messages (enforced by the dependency)
):
    """Create a message (admin only).

    Builds a ``MessageDB`` row from the request payload, commits it and returns
    the stored row serialized as ``MessageInfo``.

    Returns:
        ``success_response`` carrying the created message, or a 500
        ``error_response`` if persisting fails (the transaction is rolled back).

    NOTE(review): this is an ``async def`` that calls a synchronous SQLAlchemy
    ``Session``; the DB calls block the event loop while they run. Consider a
    plain ``def`` endpoint (run in FastAPI's threadpool) — confirm no caller
    awaits this function directly first.
    """
    db_message = MessageDB(**message.model_dump())
    try:
        db.add(db_message)
        db.commit()
        db.refresh(db_message)
    except Exception as e:
        # Roll back so the session is usable again after the failed transaction.
        db.rollback()
        # NOTE(review): str(e) is echoed to the client and may include SQL or
        # parameter details; consider logging it and returning a generic text.
        return error_response(code=500, message=f"创建消息失败: {str(e)}")
    # Serialize only after the commit succeeded and outside the try: a
    # serialization problem must not be reported as "creation failed" (the row
    # already exists) nor trigger a rollback of a committed transaction.
    return success_response(
        message="消息创建成功",
        data=MessageInfo.model_validate(db_message),
    )
@router.post("/{message_id}/read", response_model=ResponseModel)
async def mark_message_read(
    message_id: int,
    db: Session = Depends(get_db),
    current_user: UserDB = Depends(get_current_user),
):
    """Mark one of the current user's messages as read.

    The lookup is filtered by both ``message_id`` and the caller's ``userid``,
    so a message belonging to another user is reported as not found and left
    untouched.

    Returns:
        ``success_response`` with the updated message, a 404 ``error_response``
        if no such message exists for this user, or a 500 ``error_response`` if
        the update fails (the transaction is rolled back).
    """
    message = (
        db.query(MessageDB)
        .filter(
            MessageDB.id == message_id,
            MessageDB.user_id == current_user.userid,
        )
        .first()
    )
    if message is None:
        return error_response(code=404, message="消息不存在")

    try:
        message.is_read = True
        db.commit()
        db.refresh(message)
    except Exception as e:
        # Roll back so the session is usable again after the failed transaction.
        db.rollback()
        # NOTE(review): str(e) is echoed to the client and may include SQL or
        # parameter details; consider logging it and returning a generic text.
        return error_response(code=500, message=f"标记已读失败: {str(e)}")

    # Build the response outside the try: once committed, the message IS read,
    # so a serialization error must not be reported as "marking failed".
    return success_response(
        message="消息已标记为已读",
        data=MessageInfo.model_validate(message),
    )
@router.get("/list", response_model=ResponseModel)
async def get_user_messages(
    unread_only: bool = False,
    # Reject negative paging values up front (422) instead of passing them to
    # OFFSET/LIMIT, where backends either error out or treat them as "no limit".
    # NOTE(review): consider an upper bound (le=...) on limit to cap page size.
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=0),
    db: Session = Depends(get_db),
    current_user: UserDB = Depends(get_current_user),
):
    """List the current user's messages, newest first, with offset pagination.

    Args:
        unread_only: When true, only messages with ``is_read`` false are listed.
        skip: Number of matching messages to skip (must be >= 0).
        limit: Maximum number of messages to return (must be >= 0).

    Returns:
        ``success_response`` whose data is ``{"total": <count of all matching
        messages, ignoring paging>, "items": [MessageInfo, ...]}``.
    """
    query = db.query(MessageDB).filter(MessageDB.user_id == current_user.userid)
    if unread_only:
        # `== False` builds a SQL expression; `not MessageDB.is_read` would not.
        query = query.filter(MessageDB.is_read == False)  # noqa: E712

    # Total is counted before paging so clients can compute page counts.
    total = query.count()

    # The id tiebreaker gives a total order: with create_time alone, rows that
    # share a timestamp can be skipped or repeated across pages.
    messages = (
        query.order_by(MessageDB.create_time.desc(), MessageDB.id.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )
    return success_response(data={
        "total": total,
        "items": [MessageInfo.model_validate(m) for m in messages],
    })
@router.get("/latest-unread", response_model=ResponseModel)
async def get_latest_unread_message(
    db: Session = Depends(get_db),
    current_user: UserDB = Depends(get_current_user),
):
    """Return the current user's most recent unread message.

    Returns:
        ``success_response`` whose data is the newest unread message as
        ``MessageInfo``, or ``None`` when the user has no unread messages.
    """
    message = (
        db.query(MessageDB)
        .filter(
            MessageDB.user_id == current_user.userid,
            MessageDB.is_read == False,  # noqa: E712 - SQL expression, not a Python bool test
        )
        # Secondary sort on id keeps the pick deterministic when create_time
        # values tie (presumably ids grow with insertion — confirm), matching
        # the ordering used by the list endpoint.
        .order_by(MessageDB.create_time.desc(), MessageDB.id.desc())
        .first()
    )
    if message is None:
        return success_response(data=None)
    return success_response(data=MessageInfo.model_validate(message))