From 3dbd54335784279ac5ad10934f21331a7c5df2c8 Mon Sep 17 00:00:00 2001 From: aaron <> Date: Mon, 3 Feb 2025 11:10:08 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=20bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/endpoints/merchant.py | 63 ++++++++++++---------------------- app/models/merchant.py | 34 +++--------------- app/models/merchant_product.py | 8 +++-- 3 files changed, 32 insertions(+), 73 deletions(-) diff --git a/app/api/endpoints/merchant.py b/app/api/endpoints/merchant.py index 957cdfb..79274f9 100644 --- a/app/api/endpoints/merchant.py +++ b/app/api/endpoints/merchant.py @@ -6,9 +6,7 @@ from app.models.merchant import ( MerchantDB, MerchantCreate, MerchantUpdate, - MerchantInfo, - MerchantImageDB -) + MerchantInfo) from app.models.merchant_category import MerchantCategoryDB from app.models.database import get_db from app.api.deps import get_admin_user @@ -16,7 +14,7 @@ from app.models.user import UserDB from app.core.response import success_response, error_response, ResponseModel from app.models.merchant_pay_order import MerchantPayOrderDB from sqlalchemy.sql import func, desc -from app.models.merchant_product import MerchantProductDB +from app.models.merchant_product import MerchantProductDB, ProductStatus router = APIRouter() @@ -28,22 +26,11 @@ async def create_merchant( ): """创建商家(管理员)""" # 创建商家基本信息 - merchant_data = merchant.model_dump(exclude={'images'}) + merchant_data = merchant.model_dump() db_merchant = MerchantDB(**merchant_data) db.add(db_merchant) try: - db.flush() # 获取商家ID - - # 创建商家图片 - for image in merchant.images: - db_image = MerchantImageDB( - merchant_id=db_merchant.id, - image_url=image.image_url, - sort=image.sort - ) - db.add(db_image) - db.commit() db.refresh(db_merchant) return success_response(data=MerchantInfo.model_validate(db_merchant)) @@ -74,27 +61,12 @@ async def update_merchant( if not user_exists: return error_response(code=400, message="指定的用户不存在") - # 更新基本信息 - update_data = merchant.model_dump(exclude={'images'}, exclude_unset=True) + # 只更新传入的非空字段 + update_data = merchant.model_dump(exclude_unset=True) for key, value in update_data.items(): - setattr(db_merchant, key, value) - - # 如果更新了图片 - if merchant.images is not None: - # 删除原有图片 - db.query(MerchantImageDB).filter( - MerchantImageDB.merchant_id == merchant_id - ).delete() - - # 添加新图片 - for image in merchant.images: - db_image = MerchantImageDB( - merchant_id=merchant_id, - image_url=image.image_url, - sort=image.sort - ) - db.add(db_image) - + if value is not None: # 只更新非空值 + setattr(db_merchant, key, value) + try: db.commit() db.refresh(db_merchant) @@ -211,21 +183,30 @@ async def list_merchants( # 获取商家最新或限购商品 merchant_products = {} for merchant_id in merchant_ids: + # 先查询有限购的商品 product = db.query(MerchantProductDB).filter( MerchantProductDB.merchant_id == merchant_id, - MerchantProductDB.status == True # 只查询上架商品 + MerchantProductDB.status == ProductStatus.LISTING, + MerchantProductDB.purchase_limit > 0 ).order_by( - # 优先选择有限购的商品,其次按创建时间倒序 - desc(MerchantProductDB.purchase_limit > 0), - desc(MerchantProductDB.create_time) + MerchantProductDB.create_time.desc() ).first() + # 如果没有限购商品,则查询最新上架的商品 + if not product: + product = db.query(MerchantProductDB).filter( + MerchantProductDB.merchant_id == merchant_id, + MerchantProductDB.status == ProductStatus.LISTING + ).order_by( + MerchantProductDB.create_time.desc() + ).first() + if product: merchant_products[merchant_id] = { "product_id": product.id, "product_name": product.name, - "product_price": float(product.sale_price), "product_image": product.image_url, + "product_price": float(product.sale_price), "purchase_limit": product.purchase_limit } diff --git a/app/models/merchant.py b/app/models/merchant.py index 3a48d12..23da523 100644 --- a/app/models/merchant.py +++ b/app/models/merchant.py @@ -1,25 +1,11 @@ from sqlalchemy import Column, String, Integer, Float, DateTime, JSON, ForeignKey from sqlalchemy.dialects.mysql import DECIMAL from sqlalchemy.sql import func, select -from sqlalchemy.orm import relationship from pydantic import BaseModel, Field from typing import Optional, List from datetime import datetime from .database import Base -# 商家图片表 -class MerchantImageDB(Base): - __tablename__ = "merchant_images" - - id = Column(Integer, primary_key=True, autoincrement=True) - merchant_id = Column(Integer, ForeignKey("merchants.id", ondelete="CASCADE"), index=True) - image_url = Column(String(500), nullable=False) - sort = Column(Integer, nullable=False, default=0) - create_time = Column(DateTime(timezone=True), server_default=func.now()) - - class Config: - unique_together = [("merchant_id", "sort")] - # 数据库模型 class MerchantDB(Base): __tablename__ = "merchants" @@ -36,19 +22,7 @@ class MerchantDB(Base): create_time = Column(DateTime(timezone=True), server_default=func.now()) update_time = Column(DateTime(timezone=True), onupdate=func.now()) category_id = Column(Integer, ForeignKey("merchant_categories.id"), nullable=True) - - # 关联图片 - images = relationship("MerchantImageDB", - order_by="MerchantImageDB.sort", - cascade="all, delete-orphan") - -# Pydantic 模型 -class MerchantImage(BaseModel): - image_url: str = Field(..., max_length=500) - sort: int = Field(..., ge=0) # 排序序号,从0开始 - - class Config: - from_attributes = True + pay_gift_points_rate = Column(DECIMAL(4,2), nullable=False, default=0.00) # 支付赠送积分比例,默认0% class MerchantCreate(BaseModel): user_id: int @@ -58,8 +32,8 @@ class MerchantCreate(BaseModel): longitude: float = Field(..., ge=-180, le=180, description="经度") latitude: float = Field(..., ge=-90, le=90, description="纬度") phone: str = Field(..., max_length=20, pattern=r'^\d+$') + pay_gift_points_rate: Optional[float] = Field(10.00, ge=0, le=100) # 支付赠送积分比例 brand_image_url: Optional[str] = Field(None, max_length=200) - images: List[MerchantImage] = [] category_id: Optional[int] = None class MerchantUpdate(BaseModel): @@ -70,8 +44,8 @@ class MerchantUpdate(BaseModel): longitude: Optional[float] = Field(None, ge=-180, le=180, description="经度") latitude: Optional[float] = Field(None, ge=-90, le=90, description="纬度") phone: Optional[str] = Field(None, max_length=20, pattern=r'^\d+$') + pay_gift_points_rate: Optional[float] = Field(None, ge=0, le=100) # 支付赠送积分比例 brand_image_url: Optional[str] = Field(None, max_length=200) - images: Optional[List[MerchantImage]] = None category_id: Optional[int] = None class MerchantInfo(BaseModel): @@ -86,8 +60,8 @@ class MerchantInfo(BaseModel): longitude: float latitude: float phone: str + pay_gift_points_rate: float brand_image_url: Optional[str] = None - images: List[MerchantImage] create_time: datetime update_time: Optional[datetime] distance: Optional[int] = None # 距离(米) diff --git a/app/models/merchant_product.py b/app/models/merchant_product.py index 4a3b0bd..62792b9 100644 --- a/app/models/merchant_product.py +++ b/app/models/merchant_product.py @@ -16,14 +16,15 @@ class MerchantProductDB(Base): __tablename__ = "merchant_products" id = Column(Integer, primary_key=True, autoincrement=True) - merchant_id = Column(Integer, ForeignKey("merchants.id", ondelete="CASCADE"), index=True) + merchant_id = Column(Integer, ForeignKey("merchants.id"), nullable=False) name = Column(String(100), nullable=False) image_url = Column(String(500), nullable=False) product_price = Column(Float, nullable=False) # 原价 - sale_price = Column(Float, nullable=False) # 售价 + sale_price = Column(DECIMAL(10,2), nullable=False) # 售价 settlement_amount = Column(DECIMAL(10,2), nullable=False) # 商家结算金额 tags = Column(String(200)) # 标签,逗号分隔 purchase_limit = Column(Integer, nullable=False, default=0) # 限购次数,0表示不限购 + gift_points_rate = Column(DECIMAL(4,2), nullable=False, default=0.00) # 购买赠送积分比例,默认0% create_time = Column(DateTime(timezone=True), server_default=func.now()) update_time = Column(DateTime(timezone=True), onupdate=func.now()) status = Column(Enum(ProductStatus), nullable=False, default=ProductStatus.UNLISTING) @@ -40,6 +41,7 @@ class MerchantProductCreate(BaseModel): tags: str = Field("", max_length=200) purchase_limit: int = Field(0, ge=0) # 限购次数,默认0表示不限购 status: ProductStatus = ProductStatus.UNLISTING + gift_points_rate: Optional[float] = Field(10.00, ge=0, le=100) # 购买赠送积分比例 class MerchantProductUpdate(BaseModel): name: Optional[str] = Field(None, max_length=100) @@ -50,6 +52,7 @@ class MerchantProductUpdate(BaseModel): tags: Optional[str] = Field(None, max_length=200) purchase_limit: Optional[int] = Field(None, ge=0) # 限购次数,可选字段 status: Optional[ProductStatus] = None + gift_points_rate: Optional[float] = Field(None, ge=0, le=100) # 购买赠送积分比例 class MerchantProductInfo(BaseModel): id: int @@ -62,6 +65,7 @@ class MerchantProductInfo(BaseModel): settlement_amount: float tags: str purchase_limit: int # 限购次数 + gift_points_rate: float create_time: datetime update_time: Optional[datetime] status: ProductStatus