diff --git a/app/api/endpoints/community.py b/app/api/endpoints/community.py index 8691bd5..742faa3 100644 --- a/app/api/endpoints/community.py +++ b/app/api/endpoints/community.py @@ -12,6 +12,7 @@ from app.core.response import success_response, error_response, ResponseModel from app.models.community_profit_sharing import CommunityProfitSharing from app.api.endpoints.community_profit_sharing import CommunityProfitSharingResponse from app.core import utils +from sqlalchemy import text router = APIRouter() @@ -69,22 +70,52 @@ async def get_communities( limit: int = 10, db: Session = Depends(get_db) ): - """获取社区列表""" + """获取社区列表 + + 参数: + latitude: 纬度 + longitude: 经度 + status: 社区状态 + skip: 跳过记录数 + limit: 返回记录数 + sort_by_distance: 是否按照距离排序,默认为True + """ # 构建查询, 关联社区分润 - # 使用一次查询获取所有需要的数据,减少数据库连接使用时间 + # 使用一次查询获取所有需要的数据,减少数据库连接使用时间,按照经纬度距离进行排序 query = db.query(CommunityDB).options(joinedload(CommunityDB.community_profit_sharing), joinedload(CommunityDB.admin)) - # 状态过滤 + # 如果指定了状态,添加状态过滤 if status: query = query.filter(CommunityDB.status == status) + if longitude is not None and latitude is not None: + # 添加距离计算列 + query = query.add_columns( + text("ST_Distance_Sphere(point(longitude, latitude), point(:lon, :lat)) as distance") + ).params(lon=longitude, lat=latitude) + + query = query.order_by(text("distance")) + else: + # 如果没有经纬度,则按创建时间排序 + query = query.order_by(CommunityDB.create_time.desc()) + query = query.add_columns(text("NULL as distance")) + + # 获取总数 total = query.count() # 查询数据 communities = query.offset(skip).limit(limit).all() - + community_list = [] - for community in communities: + for result in communities: + # 当使用add_columns时,结果是一个元组,第一个元素是CommunityDB对象 + if hasattr(result, '_fields'): # 检查是否是Row对象(带有额外列的结果) + community = result[0] # CommunityDB对象 + distance_value = result[1] # 距离值 + else: + community = result # 直接是CommunityDB对象 + distance_value = None + community_info = { "id": community.id, "name": community.name, @@ -97,6 +128,7 @@ async def get_communities( "base_price": float(community.base_price), "extra_package_price": float(community.extra_package_price), "extra_package_threshold": community.extra_package_threshold, + "distance": float(distance_value) if distance_value is not None else None, "admin": None if community.admin is None else { "id": community.admin.userid, "nickname": community.admin.nickname, @@ -112,20 +144,8 @@ async def get_communities( } } - # 如果提供了经纬度,则计算距离 - if latitude is not None and longitude is not None: - distance = calculate_distance( - latitude, longitude, - float(community.latitude), float(community.longitude) - ) - community_info["distance"] = distance - community_list.append(community_info) - # 如果计算了距离,则按距离排序 - if latitude is not None and longitude is not None: - community_list.sort(key=lambda x: x["distance"]) - return success_response(data={ "total": total, "items": community_list diff --git a/jobs.sqlite b/jobs.sqlite index 39a1123..0fb2f09 100644 Binary files a/jobs.sqlite and b/jobs.sqlite differ