支持商户用户归属的功能。

This commit is contained in:
aaron 2025-01-19 15:39:37 +08:00
parent 91e56c9eb3
commit 60011e3d9a
2 changed files with 59 additions and 5 deletions

View File

@ -63,6 +63,14 @@ async def update_merchant(
if not db_merchant:
return error_response(code=404, message="商家不存在")
# 如果要更新用户ID先验证用户是否存在
if merchant.user_id is not None:
user_exists = db.query(UserDB).filter(
UserDB.userid == merchant.user_id
).first()
if not user_exists:
return error_response(code=400, message="指定的用户不存在")
# 更新基本信息
update_data = merchant.model_dump(exclude={'images'}, exclude_unset=True)
for key, value in update_data.items():
@ -87,7 +95,28 @@ async def update_merchant(
try:
db.commit()
db.refresh(db_merchant)
return success_response(data=MerchantInfo.model_validate(db_merchant))
# 获取更新后的完整信息(包括用户信息)
updated_merchant = db.query(
MerchantDB,
UserDB.phone.label('user_phone'),
UserDB.username.label('user_username')
).join(
UserDB,
MerchantDB.user_id == UserDB.userid
).filter(
MerchantDB.id == merchant_id
).first()
# 构建返回数据
merchant_info = MerchantInfo.model_validate(updated_merchant.MerchantDB)
merchant_data = merchant_info.model_dump()
merchant_data.update({
'user_phone': updated_merchant.user_phone,
'user_username': updated_merchant.user_username
})
return success_response(data=merchant_data)
except Exception as e:
db.rollback()
return error_response(code=500, message=f"更新失败: {str(e)}")
@ -98,14 +127,29 @@ async def get_merchant(
db: Session = Depends(get_db)
):
"""获取商家详情"""
merchant = db.query(MerchantDB).filter(
merchant = db.query(
MerchantDB,
UserDB.phone.label('user_phone'),
UserDB.username.label('user_username')
).join(
UserDB,
MerchantDB.user_id == UserDB.userid
).filter(
MerchantDB.id == merchant_id
).first()
if not merchant:
return error_response(code=404, message="商家不存在")
return success_response(data=MerchantInfo.model_validate(merchant))
# 构建返回数据
merchant_info = MerchantInfo.model_validate(merchant.MerchantDB)
merchant_data = merchant_info.model_dump()
merchant_data.update({
'user_phone': merchant.user_phone,
'user_username': merchant.user_username
})
return success_response(data=merchant_data)
@router.get("", response_model=ResponseModel)
async def list_merchants(
@ -119,10 +163,15 @@ async def list_merchants(
"""获取商家列表,支持经纬度排序和分类过滤"""
query = db.query(
MerchantDB,
MerchantCategoryDB.name.label('category_name')
MerchantCategoryDB.name.label('category_name'),
UserDB.phone.label('user_phone'),
UserDB.username.label('user_username')
).outerjoin(
MerchantCategoryDB,
MerchantDB.category_id == MerchantCategoryDB.id
).join(
UserDB,
MerchantDB.user_id == UserDB.userid
)
# 添加分类过滤
@ -143,7 +192,9 @@ async def list_merchants(
merchant_list = [{
**MerchantInfo.model_validate(m[0]).model_dump(),
"category_name": m[1],
"distance": round(m[2]) if longitude is not None and latitude is not None else None
"user_phone": m[2],
"user_username": m[3],
"distance": round(m[4]) if longitude is not None and latitude is not None else None
} for m in merchants]
# 获取总数(需要考虑分类过滤)

View File

@ -61,6 +61,7 @@ class MerchantCreate(BaseModel):
category_id: Optional[int] = None
class MerchantUpdate(BaseModel):
user_id: Optional[int] = None
name: Optional[str] = Field(None, max_length=100)
business_hours: Optional[str] = Field(None, max_length=100)
address: Optional[str] = Field(None, max_length=200)
@ -73,6 +74,8 @@ class MerchantUpdate(BaseModel):
class MerchantInfo(BaseModel):
id: int
user_id: int
user_phone: Optional[str] = None
user_username: Optional[str] = None
name: str
business_hours: str
address: str