diff --git a/app/api/endpoints/subscribe.py b/app/api/endpoints/subscribe.py index eadb977..a807fad 100644 --- a/app/api/endpoints/subscribe.py +++ b/app/api/endpoints/subscribe.py @@ -66,4 +66,4 @@ async def get_subscribes( return success_response(data=[ SubscribeInfo.model_validate(s) for s in subscribes - ]) \ No newline at end of file + ]) \ No newline at end of file diff --git a/app/api/endpoints/user.py b/app/api/endpoints/user.py index 6852ba6..03b6398 100644 --- a/app/api/endpoints/user.py +++ b/app/api/endpoints/user.py @@ -23,7 +23,7 @@ from app.models.user_auth import UserAuthDB, UserAuthCreate, UserAuthInfo from app.core.qcloud import qcloud_manager from app.models.merchant import MerchantDB from app.models.address import AddressDB, AddressInfo - +from app.models.subscribe import SubscribeDB router = APIRouter() # Redis 连接 @@ -120,11 +120,7 @@ async def get_user_info( current_user: UserDB = Depends(get_current_user) ): """获取用户信息""" - # 查询用户未使用的优惠券数量 - coupon_count = db.query(UserCouponDB).filter( - UserCouponDB.user_id == current_user.userid, - UserCouponDB.status == CouponStatus.UNUSED - ).count() + # 获取用户默认地址 default_address = db.query(AddressDB, CommunityDB.name.label('community_name')).join( @@ -155,7 +151,19 @@ async def get_user_info( } user_data['default_address'] = AddressInfo(**address_data) + # 查询用户未使用的优惠券数量 + coupon_count = db.query(UserCouponDB).filter( + UserCouponDB.user_id == current_user.userid, + UserCouponDB.status == CouponStatus.UNUSED + ).count() user_data['coupon_count'] = coupon_count + + # 查询当前用户是否订阅的模板 + subscribe_count = db.query(SubscribeDB).filter( + SubscribeDB.user_id == current_user.userid + ).count() + + user_data['is_subscribe'] = subscribe_count > 0 return success_response(data=user_data)