deliveryman-api/app/api/endpoints/station.py
2025-03-14 22:53:59 +08:00

149 lines
4.8 KiB
Python

from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from typing import List, Optional
from app.models.station import StationDB, StationCreate, StationUpdate, StationInfo
from app.models.community import CommunityDB
from app.models.database import get_db
from app.api.deps import get_admin_user
from app.models.user import UserDB
from app.core.response import success_response, error_response, ResponseModel
# Router holding all station endpoints below. Routes use "" / "/{station_id}"
# paths, so presumably it is included under a prefix elsewhere — confirm at
# the app's router registration.
router = APIRouter()
@router.post("", response_model=ResponseModel)
async def create_station(
    station: StationCreate,
    db: Session = Depends(get_db),
    admin: UserDB = Depends(get_admin_user)
):
    """Create a new station.

    Access is gated by the ``get_admin_user`` dependency (presumably
    admin-only — confirm in ``app.api.deps``).

    Args:
        station: Payload whose dumped fields are passed straight to ``StationDB``.
        db: Database session provided by ``get_db``.
        admin: The authenticated user returned by ``get_admin_user``; only
            used for authorization.

    Returns:
        A success response wrapping the persisted row as ``StationInfo``.
    """
    fields = station.model_dump()
    new_row = StationDB(**fields)
    db.add(new_row)
    db.commit()
    # Reload after commit so server-generated columns (e.g. the id) are set.
    db.refresh(new_row)
    payload = StationInfo.model_validate(new_row)
    return success_response(data=payload)
@router.get("/group_by_community", response_model=ResponseModel)
async def get_stations_group_by_community(
    community_id: Optional[int] = None,
    db: Session = Depends(get_db)
):
    """List stations grouped by the community they belong to.

    Only stations whose ``community_id`` matches an existing community are
    included (inner join on ``CommunityDB``).

    Args:
        community_id: When provided, restrict the result to that community.
            ``None`` means "all communities".
        db: Database session provided by ``get_db``.

    Returns:
        A success response whose data is a list of dicts, one per community:
        ``{"community_id", "community_name", "stations": [{"station_id", "station_name"}, ...]}``.
    """
    query = db.query(
        StationDB,
        CommunityDB.name.label('community_name')
    ).join(
        CommunityDB,
        StationDB.community_id == CommunityDB.id
    )
    # Explicit None check: a falsy id such as 0 must still apply the filter.
    if community_id is not None:
        query = query.filter(StationDB.community_id == community_id)
    # Stable ordering so groups and the stations inside them are deterministic.
    rows = query.order_by(StationDB.community_id, StationDB.id).all()

    # Group rows by community, preserving the ordering established above.
    grouped_results = {}
    for station, community_name in rows:
        if station.community_id not in grouped_results:
            grouped_results[station.community_id] = {
                "community_id": station.community_id,
                "community_name": community_name,
                "stations": []
            }
        grouped_results[station.community_id]["stations"].append(
            {"station_id": station.id, "station_name": station.name}
        )
    return success_response(data=list(grouped_results.values()))
@router.get("", response_model=ResponseModel)
async def get_stations(
    community_id: Optional[int] = None,
    skip: int = 0,
    limit: int = 10,
    db: Session = Depends(get_db)
):
    """Return a paginated list of stations, each annotated with its community name.

    Only stations linked to an existing community are returned (inner join).

    Args:
        community_id: When provided, restrict the result to that community.
            ``None`` means "all communities".
        skip: Number of rows to skip (OFFSET).
        limit: Maximum number of rows to return (LIMIT).
        db: Database session provided by ``get_db``.

    Returns:
        A success response with ``{"total": <matching row count>, "items": [StationInfo, ...]}``.
    """
    # Join to fetch the community name alongside each station.
    query = db.query(
        StationDB,
        CommunityDB.name.label('community_name')
    ).join(
        CommunityDB,
        StationDB.community_id == CommunityDB.id
    )
    # Explicit None check: a falsy id such as 0 must still apply the filter.
    if community_id is not None:
        query = query.filter(StationDB.community_id == community_id)

    # Total number of matching rows, independent of pagination.
    total = query.count()

    # OFFSET/LIMIT without ORDER BY gives no stable page boundaries, so
    # order by primary key before paginating.
    results = query.order_by(StationDB.id).offset(skip).limit(limit).all()

    station_list = []
    for station, community_name in results:
        station_info = StationInfo.model_validate(station)
        station_info.community_name = community_name
        station_list.append(station_info)
    return success_response(data={
        "total": total,
        "items": station_list
    })
@router.get("/{station_id}", response_model=ResponseModel)
async def get_station(
    station_id: int,
    db: Session = Depends(get_db)
):
    """Fetch a single station by id.

    Args:
        station_id: Primary key of the station.
        db: Database session provided by ``get_db``.

    Returns:
        A success response with the station as ``StationInfo``, or a 404
        error response when no station has this id.
    """
    found = (
        db.query(StationDB)
        .filter(StationDB.id == station_id)
        .first()
    )
    if not found:
        return error_response(code=404, message="驿站不存在")
    payload = StationInfo.model_validate(found)
    return success_response(data=payload)
@router.put("/{station_id}", response_model=ResponseModel)
async def update_station(
    station_id: int,
    station: StationUpdate,
    db: Session = Depends(get_db),
    admin: UserDB = Depends(get_admin_user)
):
    """Partially update a station.

    Only fields explicitly present in the request body are written
    (``exclude_unset=True``); omitted fields keep their stored values.

    Args:
        station_id: Primary key of the station to update.
        station: Update payload.
        db: Database session provided by ``get_db``.
        admin: The authenticated user returned by ``get_admin_user``; only
            used for authorization.

    Returns:
        A success response with the updated station as ``StationInfo``, or a
        404 error response when no station has this id.
    """
    target = db.query(StationDB).filter(StationDB.id == station_id).first()
    if not target:
        return error_response(code=404, message="驿站不存在")

    changes = station.model_dump(exclude_unset=True)
    for field_name in changes:
        setattr(target, field_name, changes[field_name])

    db.commit()
    db.refresh(target)
    return success_response(data=StationInfo.model_validate(target))
@router.delete("/{station_id}", response_model=ResponseModel)
async def delete_station(
    station_id: int,
    db: Session = Depends(get_db),
    admin: UserDB = Depends(get_admin_user)
):
    """Delete a station by id.

    Args:
        station_id: Primary key of the station to delete.
        db: Database session provided by ``get_db``.
        admin: The authenticated user returned by ``get_admin_user``; only
            used for authorization.

    Returns:
        A success response on deletion, or a 404 error response when no row
        matched (in which case nothing is committed).
    """
    # NOTE: Query.delete() issues a bulk DELETE; ORM-level relationship
    # cascades are not applied — confirm dependent rows are handled by the DB.
    deleted_count = (
        db.query(StationDB)
        .filter(StationDB.id == station_id)
        .delete()
    )
    if not deleted_count:
        return error_response(code=404, message="驿站不存在")
    db.commit()
    return success_response(message="驿站已删除")
@router.get("/community/{community_id}", response_model=ResponseModel)
async def get_stations_by_community(
    community_id: int,
    db: Session = Depends(get_db)
):
    """List all stations belonging to one community.

    Args:
        community_id: Id of the community whose stations are requested.
        db: Database session provided by ``get_db``.

    Returns:
        A success response with a list of ``StationInfo``. A community with no
        stations (or an unknown id) yields an empty list, not an error.
    """
    rows = db.query(StationDB).filter(StationDB.community_id == community_id).all()
    # An empty result naturally maps to an empty list here.
    items = [StationInfo.model_validate(row) for row in rows]
    return success_response(data=items)