deliveryman-api/app/api/endpoints/merchant.py
2025-01-22 23:06:34 +08:00

229 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from typing import List, Optional, Dict, Any
from sqlalchemy import text
from app.models.merchant import (
MerchantDB,
MerchantCreate,
MerchantUpdate,
MerchantInfo,
MerchantImageDB
)
from app.models.merchant_category import MerchantCategoryDB
from app.models.database import get_db
from app.api.deps import get_admin_user
from app.models.user import UserDB
from app.core.response import success_response, error_response, ResponseModel
from app.models.merchant_pay_order import MerchantPayOrderDB
from sqlalchemy.sql import func
# Router holding the merchant endpoints defined below; the URL prefix is
# presumably applied where this router is included — confirm at the include site.
router: APIRouter = APIRouter()
@router.post("", response_model=ResponseModel)
async def create_merchant(
    merchant: MerchantCreate,
    db: Session = Depends(get_db),
    admin: UserDB = Depends(get_admin_user)
):
    """Create a merchant together with its images (admin only).

    If the payload names an owning user, that user is validated first, as
    ``update_merchant`` does. The detail/list endpoints inner-join merchants to
    users, so a merchant pointing at a non-existent user would otherwise be
    created but never be returned by them.

    Args:
        merchant: Merchant payload; its ``images`` are stored as separate
            ``MerchantImageDB`` rows.
        db: Database session.
        admin: Authenticated admin user (dependency enforces admin access).

    Returns:
        ``success_response`` with the created ``MerchantInfo``; a 400
        ``error_response`` when the referenced user does not exist; a 500
        ``error_response`` when persisting fails (the transaction is rolled back).
    """
    # NOTE(review): assumes MerchantCreate may carry ``user_id`` like
    # MerchantUpdate does; getattr keeps this safe if the field is absent.
    user_id = getattr(merchant, "user_id", None)
    if user_id is not None:
        user_exists = db.query(UserDB).filter(
            UserDB.userid == user_id
        ).first()
        if not user_exists:
            return error_response(code=400, message="指定的用户不存在")

    # Basic merchant fields only; images are persisted separately below.
    merchant_data = merchant.model_dump(exclude={'images'})
    db_merchant = MerchantDB(**merchant_data)
    db.add(db_merchant)
    try:
        db.flush()  # assigns db_merchant.id so the image rows can reference it
        # A missing/None image list means "no images", not an iteration error.
        for image in merchant.images or []:
            db.add(MerchantImageDB(
                merchant_id=db_merchant.id,
                image_url=image.image_url,
                sort=image.sort
            ))
        db.commit()
        db.refresh(db_merchant)
        return success_response(data=MerchantInfo.model_validate(db_merchant))
    except Exception as e:
        db.rollback()
        return error_response(code=500, message=f"创建失败: {str(e)}")
@router.put("/{merchant_id}", response_model=ResponseModel)
async def update_merchant(
    merchant_id: int,
    merchant: MerchantUpdate,
    db: Session = Depends(get_db),
    admin: UserDB = Depends(get_admin_user)
):
    """Update a merchant's fields and optionally replace its images (admin only).

    Only fields explicitly present in the payload are written. When
    ``images`` is provided (even as an empty list), all existing images of the
    merchant are deleted and replaced by the given ones.

    Args:
        merchant_id: ID of the merchant to update.
        merchant: Partial update payload.
        db: Database session.
        admin: Authenticated admin user (dependency enforces admin access).

    Returns:
        ``success_response`` with the updated merchant plus ``user_phone`` and
        ``user_nickname`` (``None`` when the merchant has no matching user);
        404 if the merchant does not exist; 400 if a new ``user_id`` refers to
        no user; 500 if persisting fails (the transaction is rolled back).
    """
    db_merchant = db.query(MerchantDB).filter(
        MerchantDB.id == merchant_id
    ).first()
    if not db_merchant:
        return error_response(code=404, message="商家不存在")

    # When the owning user is being changed, verify the target user exists.
    if merchant.user_id is not None:
        user_exists = db.query(UserDB).filter(
            UserDB.userid == merchant.user_id
        ).first()
        if not user_exists:
            return error_response(code=400, message="指定的用户不存在")

    try:
        # Apply only the fields the client actually sent. This runs inside the
        # try so that a failure here (or in the bulk delete) is rolled back.
        update_data = merchant.model_dump(exclude={'images'}, exclude_unset=True)
        for key, value in update_data.items():
            setattr(db_merchant, key, value)

        # Replace the whole image set when images were supplied.
        if merchant.images is not None:
            db.query(MerchantImageDB).filter(
                MerchantImageDB.merchant_id == merchant_id
            ).delete()
            for image in merchant.images:
                db.add(MerchantImageDB(
                    merchant_id=merchant_id,
                    image_url=image.image_url,
                    sort=image.sort
                ))

        db.commit()
        db.refresh(db_merchant)

        # Re-read together with the owner's contact details. An OUTER join is
        # used: with an inner join, a merchant without a matching user produced
        # no row, and the committed update was then reported as a 500 failure.
        updated_merchant = db.query(
            MerchantDB,
            UserDB.phone.label('user_phone'),
            UserDB.nickname.label('user_nickname')
        ).outerjoin(
            UserDB,
            MerchantDB.user_id == UserDB.userid
        ).filter(
            MerchantDB.id == merchant_id
        ).first()

        merchant_data = MerchantInfo.model_validate(updated_merchant.MerchantDB).model_dump()
        merchant_data.update({
            'user_phone': updated_merchant.user_phone,
            'user_nickname': updated_merchant.user_nickname
        })
        return success_response(data=merchant_data)
    except Exception as e:
        db.rollback()
        return error_response(code=500, message=f"更新失败: {str(e)}")
@router.get("/{merchant_id}", response_model=ResponseModel)
async def get_merchant(
    merchant_id: int,
    db: Session = Depends(get_db)
):
    """Return a single merchant with its owner's phone number and nickname.

    The merchant is inner-joined to its user. A merchant with no matching
    user therefore yields no row and is reported as not found (404).
    """
    row = (
        db.query(
            MerchantDB,
            UserDB.phone.label('user_phone'),
            UserDB.nickname.label('user_nickname'),
        )
        .join(UserDB, MerchantDB.user_id == UserDB.userid)
        .filter(MerchantDB.id == merchant_id)
        .first()
    )
    if row is None:
        return error_response(code=404, message="商家不存在")

    # Serialized merchant fields, followed by the owner's contact details.
    payload = {
        **MerchantInfo.model_validate(row.MerchantDB).model_dump(),
        'user_phone': row.user_phone,
        'user_nickname': row.user_nickname,
    }
    return success_response(data=payload)
@router.get("", response_model=ResponseModel)
async def list_merchants(
    longitude: Optional[float] = None,
    latitude: Optional[float] = None,
    category_id: Optional[int] = None,
    skip: int = 0,
    limit: int = 20,
    db: Session = Depends(get_db)
):
    """List merchants, optionally filtered by category and sorted by distance.

    When both ``longitude`` and ``latitude`` are given, results are ordered by
    ``ST_Distance_Sphere`` from that point (nearest first), and each item's
    ``distance`` is that value rounded. The value is ``None`` when the database
    returns NULL, e.g. for a merchant without coordinates. Otherwise results
    are ordered by ``create_time`` descending and ``distance`` is ``None``.

    Args:
        longitude: Longitude of the reference point for distance sorting.
        latitude: Latitude of the reference point for distance sorting.
        category_id: Only return merchants of this category when given.
        skip: Number of rows to skip (pagination offset).
        limit: Maximum number of rows to return.
        db: Database session.

    Returns:
        ``success_response`` with ``total`` and ``items``. ``total`` is the
        number of merchants matching the same joins and filters as the page.
    """
    has_location = longitude is not None and latitude is not None

    query = db.query(
        MerchantDB,
        MerchantCategoryDB.name.label('category_name'),
        UserDB.phone.label('user_phone'),
        UserDB.nickname.label('user_nickname')
    ).outerjoin(
        MerchantCategoryDB,
        MerchantDB.category_id == MerchantCategoryDB.id
    ).join(
        UserDB,
        MerchantDB.user_id == UserDB.userid
    )
    # Category filter
    if category_id is not None:
        query = query.filter(MerchantDB.category_id == category_id)

    if has_location:
        # Distance from the reference point, used for ordering (nearest first).
        query = query.add_columns(
            text("ST_Distance_Sphere(point(longitude, latitude), point(:lon, :lat)) as distance")
        ).params(lon=longitude, lat=latitude)
        query = query.order_by(text("distance"))
    else:
        # No reference point: newest first. An empty distance column keeps the
        # row shape identical to the distance branch.
        query = query.order_by(MerchantDB.create_time.desc())
        query = query.add_columns(text("NULL as distance"))

    merchants = query.offset(skip).limit(limit).all()

    # Online pay-order count per merchant on this page (skip the query for an empty page).
    merchant_ids = [m[0].id for m in merchants]
    pay_order_counts = {}
    if merchant_ids:
        pay_order_counts = dict(
            db.query(
                MerchantPayOrderDB.merchant_id,
                func.count(MerchantPayOrderDB.id).label('count')
            ).filter(
                MerchantPayOrderDB.merchant_id.in_(merchant_ids)
            ).group_by(MerchantPayOrderDB.merchant_id).all()
        )

    merchant_list = [{
        **MerchantInfo.model_validate(m[0]).model_dump(),
        "category_name": m[1],
        "user_phone": m[2],
        "user_nickname": m[3],
        "online_pay_count": pay_order_counts.get(m[0].id, 0),
        # ST_Distance_Sphere yields NULL when the merchant has no coordinates;
        # round(None) would raise TypeError and fail the whole request.
        "distance": round(m[4]) if has_location and m[4] is not None else None
    } for m in merchants]

    # Count with the same user join and category filter as the page query, so
    # merchants excluded by the inner join on users are not counted in total.
    total_query = db.query(MerchantDB).join(
        UserDB,
        MerchantDB.user_id == UserDB.userid
    )
    if category_id is not None:
        total_query = total_query.filter(MerchantDB.category_id == category_id)
    total = total_query.count()

    return success_response(data={
        "total": total,
        "items": merchant_list
    })