deliveryman-api/app/api/endpoints/community.py
2025-03-18 09:30:09 +08:00

242 lines
9.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session, joinedload
from typing import List, Optional
from app.models.community import (
CommunityDB, CommunityCreate, CommunityUpdate,
CommunityInfo, CommunityStatus
)
from app.models.database import get_db
from app.api.deps import get_admin_user
from app.models.user import UserDB
from app.core.response import success_response, error_response, ResponseModel
from app.models.community_profit_sharing import CommunityProfitSharing
from app.api.endpoints.community_profit_sharing import CommunityProfitSharingResponse
from app.core import utils
from sqlalchemy import text
from app.models.community_timeperiod import CommunityTimePeriodDB
# Router on which every community endpoint declared in this module is registered.
router = APIRouter()
@router.post("", response_model=ResponseModel)
async def create_community(
    community: CommunityCreate,
    db: Session = Depends(get_db),
    admin: UserDB = Depends(get_admin_user),
):
    """Create a community from the request payload.

    Args:
        community: Validated creation payload; its dumped fields are passed
            straight to the ``CommunityDB`` constructor.
        db: Database session supplied by ``get_db``.
        admin: User resolved by ``get_admin_user``. Not read in the body;
            presumably present only to restrict access (confirm in deps).

    Returns:
        A success response wrapping the stored community as ``CommunityInfo``.
    """
    payload = community.model_dump()
    record = CommunityDB(**payload)
    db.add(record)
    db.commit()
    # Reload the row so values assigned during the commit are visible.
    db.refresh(record)
    created = CommunityInfo.model_validate(record)
    return success_response(data=created)
@router.get("/{community_id}/qrcode", response_model=ResponseModel)
async def get_community_qrcode(
    community_id: int,
    latitude: Optional[float] = None,
    longitude: Optional[float] = None,
    db: Session = Depends(get_db)
):
    """Return a community plus a flag telling whether its QR code may be shown.

    Args:
        community_id: Primary key of the community to look up.
        latitude: Caller's latitude, optional.
        longitude: Caller's longitude, optional.
        db: Database session supplied by ``get_db``.

    Returns:
        A 404 error response when no community has that id. Otherwise a
        success response with the community and ``show_qrcode``, which is
        True only when both coordinates were supplied and the computed
        distance to the community is below 1000 (metres).
    """
    record = db.query(CommunityDB).filter(CommunityDB.id == community_id).first()
    if not record:
        return error_response(code=404, message="社区不存在")

    show_qrcode = False
    has_position = latitude is not None and longitude is not None
    if has_position:
        metres = calculate_distance(
            latitude, longitude,
            record.latitude, record.longitude
        )
        # Attach the distance to the loaded object so it can be serialised.
        record.distance = metres
        # The QR code is only revealed within one kilometre.
        show_qrcode = metres < 1000

    payload = {
        "community": CommunityInfo.model_validate(record),
        "show_qrcode": show_qrcode
    }
    return success_response(data=payload)
def _serialize_community(community, distance_value):
    """Build the list-item dictionary for one community and its distance."""
    admin = community.admin
    sharing = community.community_profit_sharing
    return {
        "id": community.id,
        "name": community.name,
        "address": community.address,
        "latitude": float(community.latitude),
        "longitude": float(community.longitude),
        "status": community.status,
        "qy_group_qrcode": community.qy_group_qrcode,
        "webot_webhook": community.webot_webhook,
        "base_price": float(community.base_price),
        "extra_package_price": float(community.extra_package_price),
        "extra_package_threshold": community.extra_package_threshold,
        "more_station_price": float(community.more_station_price),
        "weekdays": community.weekdays,
        "distance": float(distance_value) if distance_value is not None else None,
        "admin": None if admin is None else {
            "id": admin.userid,
            "nickname": admin.nickname,
            # The phone number is passed through the project's masking helper.
            "phone": utils.CommonUtils.desensitize_phone(admin.phone),
            "avatar": admin.avatar
        },
        "profit_sharing": None if sharing is None else {
            "id": sharing.id,
            "platform_rate": float(sharing.platform_rate),
            "partner_rate": float(sharing.partner_rate),
            "admin_rate": float(sharing.admin_rate),
            "delivery_rate": float(sharing.delivery_rate)
        }
    }


@router.get("", response_model=ResponseModel)
async def get_communities(
    latitude: Optional[float] = None,
    longitude: Optional[float] = None,
    status: Optional[CommunityStatus] = None,
    skip: int = 0,
    limit: int = 10,
    db: Session = Depends(get_db)
):
    """Return one page of communities together with the total match count.

    When both ``latitude`` and ``longitude`` are supplied, each row carries a
    ``distance`` computed in the database by ``ST_Distance_Sphere`` (metres,
    per the MySQL documentation) and rows are ordered nearest first.
    Otherwise rows are ordered by creation time, newest first, and
    ``distance`` is ``None``.

    Args:
        latitude: Caller's latitude, optional.
        longitude: Caller's longitude, optional.
        status: Only return communities with this status when given.
        skip: Number of rows to skip.
        limit: Maximum number of rows to return.
        db: Database session supplied by ``get_db``.

    Returns:
        A success response whose data holds ``total`` and ``items``.
    """
    # Eager-load profit sharing and admin in the same query that fetches
    # the communities.
    query = db.query(CommunityDB).options(
        joinedload(CommunityDB.community_profit_sharing),
        joinedload(CommunityDB.admin),
    )

    # Compare against None so that a falsy enum member still filters.
    if status is not None:
        query = query.filter(CommunityDB.status == status)

    if longitude is not None and latitude is not None:
        # Extra column: spherical distance from the caller to the community.
        query = query.add_columns(
            text("ST_Distance_Sphere(point(longitude, latitude), point(:lon, :lat)) as distance")
        ).params(lon=longitude, lat=latitude)
        query = query.order_by(text("distance"))
    else:
        # No position supplied: newest first, with a NULL placeholder so both
        # branches yield rows of the same shape.
        query = query.order_by(CommunityDB.create_time.desc())
        query = query.add_columns(text("NULL as distance"))

    total = query.count()
    rows = query.offset(skip).limit(limit).all()

    # Both branches add a column, so every row is (CommunityDB, distance).
    community_list = [_serialize_community(row[0], row[1]) for row in rows]

    return success_response(data={
        "total": total,
        "items": community_list
    })
def calculate_distance(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Return the great-circle distance between two points, in metres.

    Uses the haversine formula on a sphere of radius 6,371,000 m.

    Args:
        lat1: Latitude of the first point, in degrees.
        lon1: Longitude of the first point, in degrees.
        lat2: Latitude of the second point, in degrees.
        lon2: Longitude of the second point, in degrees.

    Returns:
        The distance in metres, rounded to two decimal places.
    """
    from math import radians, sin, cos, sqrt, atan2

    R = 6371000  # Earth radius in metres

    # The trigonometric functions below work in radians.
    lat1, lon1, lat2, lon2 = map(radians, [lat1, lon1, lat2, lon2])

    # Haversine formula
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    a = sin(dlat / 2) ** 2 + cos(lat1) * cos(lat2) * sin(dlon / 2) ** 2
    # Rounding can push `a` marginally outside [0, 1] for (near-)antipodal
    # points, which would make sqrt(1 - a) raise ValueError.
    a = min(1.0, max(0.0, a))
    c = 2 * atan2(sqrt(a), sqrt(1 - a))

    return round(R * c, 2)
@router.get("/{community_id}", response_model=ResponseModel)
async def get_community(
    community_id: int,
    db: Session = Depends(get_db)
):
    """Return the details of a single community.

    Args:
        community_id: Primary key of the community to look up.
        db: Database session supplied by ``get_db``.

    Returns:
        A success response wrapping the community as ``CommunityInfo``, or a
        404 error response when no community has that id.
    """
    # Profit sharing and admin are loaded together with the community row.
    eager_options = (
        joinedload(CommunityDB.community_profit_sharing),
        joinedload(CommunityDB.admin),
    )
    record = (
        db.query(CommunityDB)
        .options(*eager_options)
        .filter(CommunityDB.id == community_id)
        .first()
    )
    if record:
        return success_response(data=CommunityInfo.model_validate(record))
    return error_response(code=404, message="社区不存在")
@router.put("/{community_id}", response_model=ResponseModel)
async def update_community(
    community_id: int,
    community: CommunityUpdate,
    db: Session = Depends(get_db),
    admin: UserDB = Depends(get_admin_user)
):
    """Apply a partial update to a community.

    Switching the status to ``CommunityStatus.OPENING`` is only allowed once
    the community has a profit-sharing record and at least one delivery time
    period.

    Args:
        community_id: Primary key of the community to update.
        community: Update payload; only fields the client actually sent are
            written.
        db: Database session supplied by ``get_db``.
        admin: User resolved by ``get_admin_user``. Not read in the body;
            presumably present only to restrict access (confirm in deps).

    Returns:
        A success response wrapping the updated community, a 404 error
        response when the community does not exist, or a 400 error response
        when a prerequisite for opening is missing.
    """
    db_community = db.query(CommunityDB).filter(CommunityDB.id == community_id).first()
    if not db_community:
        return error_response(code=404, message="社区不存在")

    # A missing status (None) never equals OPENING, so no separate None check.
    if community.status == CommunityStatus.OPENING:
        # Profit sharing must be configured first.
        has_profit_sharing = (
            db.query(CommunityProfitSharing)
            .filter(CommunityProfitSharing.community_id == community_id)
            .first()
        )
        if not has_profit_sharing:
            return error_response(code=400, message="请先设置分润")

        # Only existence matters, so fetch a single row instead of all of them.
        has_time_period = (
            db.query(CommunityTimePeriodDB)
            .filter(CommunityTimePeriodDB.community_id == community_id)
            .first()
        )
        if not has_time_period:
            return error_response(code=400, message="请先设置配送时段")

    # exclude_unset keeps omitted fields from overwriting stored values.
    update_data = community.model_dump(exclude_unset=True)
    for key, value in update_data.items():
        setattr(db_community, key, value)

    db.commit()
    db.refresh(db_community)
    return success_response(data=CommunityInfo.model_validate(db_community))
@router.delete("/{community_id}", response_model=ResponseModel)
async def delete_community(
    community_id: int,
    db: Session = Depends(get_db),
    admin: UserDB = Depends(get_admin_user)
):
    """Delete a community by id.

    Args:
        community_id: Primary key of the community to delete.
        db: Database session supplied by ``get_db``.
        admin: User resolved by ``get_admin_user``. Not read in the body;
            presumably present only to restrict access (confirm in deps).

    Returns:
        A success response when a row was deleted, otherwise a 404 error
        response.
    """
    # Query.delete() reports how many rows matched the filter.
    deleted_rows = db.query(CommunityDB).filter(CommunityDB.id == community_id).delete()
    if deleted_rows:
        db.commit()
        return success_response(message="社区已删除")
    return error_response(code=404, message="社区不存在")
# Search communities by name
@router.get("/search_by_name/{name}", response_model=ResponseModel)
async def search_community_by_name(
    name: str,
    db: Session = Depends(get_db)
):
    """Return every community whose name contains ``name``.

    The match is a case-insensitive substring search. ``%`` and ``_`` in the
    search term are escaped so they match literally instead of acting as
    LIKE wildcards.

    Args:
        name: Text to look for inside community names.
        db: Database session supplied by ``get_db``.

    Returns:
        A success response wrapping a list of ``CommunityInfo`` objects.
    """
    # Escape the escape character first so the slashes inserted for the
    # wildcards are not themselves doubled.
    escaped = name.replace("/", "//").replace("%", "/%").replace("_", "/_")
    pattern = f"%{escaped}%"
    communities = (
        db.query(CommunityDB)
        .filter(CommunityDB.name.ilike(pattern, escape="/"))
        .all()
    )
    results = [CommunityInfo.model_validate(community) for community in communities]
    return success_response(data=results)