修改 bug

This commit is contained in:
aaron 2025-02-03 11:10:08 +08:00
parent 124613789c
commit 3dbd543357
3 changed files with 32 additions and 73 deletions

View File

@ -6,9 +6,7 @@ from app.models.merchant import (
MerchantDB, MerchantDB,
MerchantCreate, MerchantCreate,
MerchantUpdate, MerchantUpdate,
MerchantInfo, MerchantInfo)
MerchantImageDB
)
from app.models.merchant_category import MerchantCategoryDB from app.models.merchant_category import MerchantCategoryDB
from app.models.database import get_db from app.models.database import get_db
from app.api.deps import get_admin_user from app.api.deps import get_admin_user
@ -16,7 +14,7 @@ from app.models.user import UserDB
from app.core.response import success_response, error_response, ResponseModel from app.core.response import success_response, error_response, ResponseModel
from app.models.merchant_pay_order import MerchantPayOrderDB from app.models.merchant_pay_order import MerchantPayOrderDB
from sqlalchemy.sql import func, desc from sqlalchemy.sql import func, desc
from app.models.merchant_product import MerchantProductDB from app.models.merchant_product import MerchantProductDB, ProductStatus
router = APIRouter() router = APIRouter()
@ -28,22 +26,11 @@ async def create_merchant(
): ):
"""创建商家(管理员)""" """创建商家(管理员)"""
# 创建商家基本信息 # 创建商家基本信息
merchant_data = merchant.model_dump(exclude={'images'}) merchant_data = merchant.model_dump()
db_merchant = MerchantDB(**merchant_data) db_merchant = MerchantDB(**merchant_data)
db.add(db_merchant) db.add(db_merchant)
try: try:
db.flush() # 获取商家ID
# 创建商家图片
for image in merchant.images:
db_image = MerchantImageDB(
merchant_id=db_merchant.id,
image_url=image.image_url,
sort=image.sort
)
db.add(db_image)
db.commit() db.commit()
db.refresh(db_merchant) db.refresh(db_merchant)
return success_response(data=MerchantInfo.model_validate(db_merchant)) return success_response(data=MerchantInfo.model_validate(db_merchant))
@ -74,26 +61,11 @@ async def update_merchant(
if not user_exists: if not user_exists:
return error_response(code=400, message="指定的用户不存在") return error_response(code=400, message="指定的用户不存在")
# 更新基本信息 # 只更新传入的非空字段
update_data = merchant.model_dump(exclude={'images'}, exclude_unset=True) update_data = merchant.model_dump(exclude_unset=True)
for key, value in update_data.items(): for key, value in update_data.items():
setattr(db_merchant, key, value) if value is not None: # 只更新非空值
setattr(db_merchant, key, value)
# 如果更新了图片
if merchant.images is not None:
# 删除原有图片
db.query(MerchantImageDB).filter(
MerchantImageDB.merchant_id == merchant_id
).delete()
# 添加新图片
for image in merchant.images:
db_image = MerchantImageDB(
merchant_id=merchant_id,
image_url=image.image_url,
sort=image.sort
)
db.add(db_image)
try: try:
db.commit() db.commit()
@ -211,21 +183,30 @@ async def list_merchants(
# 获取商家最新或限购商品 # 获取商家最新或限购商品
merchant_products = {} merchant_products = {}
for merchant_id in merchant_ids: for merchant_id in merchant_ids:
# 先查询有限购的商品
product = db.query(MerchantProductDB).filter( product = db.query(MerchantProductDB).filter(
MerchantProductDB.merchant_id == merchant_id, MerchantProductDB.merchant_id == merchant_id,
MerchantProductDB.status == True # 只查询上架商品 MerchantProductDB.status == ProductStatus.LISTING,
MerchantProductDB.purchase_limit > 0
).order_by( ).order_by(
# 优先选择有限购的商品,其次按创建时间倒序 MerchantProductDB.create_time.desc()
desc(MerchantProductDB.purchase_limit > 0),
desc(MerchantProductDB.create_time)
).first() ).first()
# 如果没有限购商品,则查询最新上架的商品
if not product:
product = db.query(MerchantProductDB).filter(
MerchantProductDB.merchant_id == merchant_id,
MerchantProductDB.status == ProductStatus.LISTING
).order_by(
MerchantProductDB.create_time.desc()
).first()
if product: if product:
merchant_products[merchant_id] = { merchant_products[merchant_id] = {
"product_id": product.id, "product_id": product.id,
"product_name": product.name, "product_name": product.name,
"product_price": float(product.sale_price),
"product_image": product.image_url, "product_image": product.image_url,
"product_price": float(product.sale_price),
"purchase_limit": product.purchase_limit "purchase_limit": product.purchase_limit
} }

View File

@ -1,25 +1,11 @@
from sqlalchemy import Column, String, Integer, Float, DateTime, JSON, ForeignKey from sqlalchemy import Column, String, Integer, Float, DateTime, JSON, ForeignKey
from sqlalchemy.dialects.mysql import DECIMAL from sqlalchemy.dialects.mysql import DECIMAL
from sqlalchemy.sql import func, select from sqlalchemy.sql import func, select
from sqlalchemy.orm import relationship
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Optional, List from typing import Optional, List
from datetime import datetime from datetime import datetime
from .database import Base from .database import Base
# 商家图片表
class MerchantImageDB(Base):
__tablename__ = "merchant_images"
id = Column(Integer, primary_key=True, autoincrement=True)
merchant_id = Column(Integer, ForeignKey("merchants.id", ondelete="CASCADE"), index=True)
image_url = Column(String(500), nullable=False)
sort = Column(Integer, nullable=False, default=0)
create_time = Column(DateTime(timezone=True), server_default=func.now())
class Config:
unique_together = [("merchant_id", "sort")]
# 数据库模型 # 数据库模型
class MerchantDB(Base): class MerchantDB(Base):
__tablename__ = "merchants" __tablename__ = "merchants"
@ -36,19 +22,7 @@ class MerchantDB(Base):
create_time = Column(DateTime(timezone=True), server_default=func.now()) create_time = Column(DateTime(timezone=True), server_default=func.now())
update_time = Column(DateTime(timezone=True), onupdate=func.now()) update_time = Column(DateTime(timezone=True), onupdate=func.now())
category_id = Column(Integer, ForeignKey("merchant_categories.id"), nullable=True) category_id = Column(Integer, ForeignKey("merchant_categories.id"), nullable=True)
pay_gift_points_rate = Column(DECIMAL(4,2), nullable=False, default=0.00) # 支付赠送积分比例默认0%
# 关联图片
images = relationship("MerchantImageDB",
order_by="MerchantImageDB.sort",
cascade="all, delete-orphan")
# Pydantic 模型
class MerchantImage(BaseModel):
image_url: str = Field(..., max_length=500)
sort: int = Field(..., ge=0) # 排序序号从0开始
class Config:
from_attributes = True
class MerchantCreate(BaseModel): class MerchantCreate(BaseModel):
user_id: int user_id: int
@ -58,8 +32,8 @@ class MerchantCreate(BaseModel):
longitude: float = Field(..., ge=-180, le=180, description="经度") longitude: float = Field(..., ge=-180, le=180, description="经度")
latitude: float = Field(..., ge=-90, le=90, description="纬度") latitude: float = Field(..., ge=-90, le=90, description="纬度")
phone: str = Field(..., max_length=20, pattern=r'^\d+$') phone: str = Field(..., max_length=20, pattern=r'^\d+$')
pay_gift_points_rate: Optional[float] = Field(10.00, ge=0, le=100) # 支付赠送积分比例
brand_image_url: Optional[str] = Field(None, max_length=200) brand_image_url: Optional[str] = Field(None, max_length=200)
images: List[MerchantImage] = []
category_id: Optional[int] = None category_id: Optional[int] = None
class MerchantUpdate(BaseModel): class MerchantUpdate(BaseModel):
@ -70,8 +44,8 @@ class MerchantUpdate(BaseModel):
longitude: Optional[float] = Field(None, ge=-180, le=180, description="经度") longitude: Optional[float] = Field(None, ge=-180, le=180, description="经度")
latitude: Optional[float] = Field(None, ge=-90, le=90, description="纬度") latitude: Optional[float] = Field(None, ge=-90, le=90, description="纬度")
phone: Optional[str] = Field(None, max_length=20, pattern=r'^\d+$') phone: Optional[str] = Field(None, max_length=20, pattern=r'^\d+$')
pay_gift_points_rate: Optional[float] = Field(None, ge=0, le=100) # 支付赠送积分比例
brand_image_url: Optional[str] = Field(None, max_length=200) brand_image_url: Optional[str] = Field(None, max_length=200)
images: Optional[List[MerchantImage]] = None
category_id: Optional[int] = None category_id: Optional[int] = None
class MerchantInfo(BaseModel): class MerchantInfo(BaseModel):
@ -86,8 +60,8 @@ class MerchantInfo(BaseModel):
longitude: float longitude: float
latitude: float latitude: float
phone: str phone: str
pay_gift_points_rate: float
brand_image_url: Optional[str] = None brand_image_url: Optional[str] = None
images: List[MerchantImage]
create_time: datetime create_time: datetime
update_time: Optional[datetime] update_time: Optional[datetime]
distance: Optional[int] = None # 距离(米) distance: Optional[int] = None # 距离(米)

View File

@ -16,14 +16,15 @@ class MerchantProductDB(Base):
__tablename__ = "merchant_products" __tablename__ = "merchant_products"
id = Column(Integer, primary_key=True, autoincrement=True) id = Column(Integer, primary_key=True, autoincrement=True)
merchant_id = Column(Integer, ForeignKey("merchants.id", ondelete="CASCADE"), index=True) merchant_id = Column(Integer, ForeignKey("merchants.id"), nullable=False)
name = Column(String(100), nullable=False) name = Column(String(100), nullable=False)
image_url = Column(String(500), nullable=False) image_url = Column(String(500), nullable=False)
product_price = Column(Float, nullable=False) # 原价 product_price = Column(Float, nullable=False) # 原价
sale_price = Column(Float, nullable=False) # 售价 sale_price = Column(DECIMAL(10,2), nullable=False) # 售价
settlement_amount = Column(DECIMAL(10,2), nullable=False) # 商家结算金额 settlement_amount = Column(DECIMAL(10,2), nullable=False) # 商家结算金额
tags = Column(String(200)) # 标签,逗号分隔 tags = Column(String(200)) # 标签,逗号分隔
purchase_limit = Column(Integer, nullable=False, default=0) # 限购次数0表示不限购 purchase_limit = Column(Integer, nullable=False, default=0) # 限购次数0表示不限购
gift_points_rate = Column(DECIMAL(4,2), nullable=False, default=0.00) # 购买赠送积分比例默认0%
create_time = Column(DateTime(timezone=True), server_default=func.now()) create_time = Column(DateTime(timezone=True), server_default=func.now())
update_time = Column(DateTime(timezone=True), onupdate=func.now()) update_time = Column(DateTime(timezone=True), onupdate=func.now())
status = Column(Enum(ProductStatus), nullable=False, default=ProductStatus.UNLISTING) status = Column(Enum(ProductStatus), nullable=False, default=ProductStatus.UNLISTING)
@ -40,6 +41,7 @@ class MerchantProductCreate(BaseModel):
tags: str = Field("", max_length=200) tags: str = Field("", max_length=200)
purchase_limit: int = Field(0, ge=0) # 限购次数默认0表示不限购 purchase_limit: int = Field(0, ge=0) # 限购次数默认0表示不限购
status: ProductStatus = ProductStatus.UNLISTING status: ProductStatus = ProductStatus.UNLISTING
gift_points_rate: Optional[float] = Field(10.00, ge=0, le=100) # 购买赠送积分比例
class MerchantProductUpdate(BaseModel): class MerchantProductUpdate(BaseModel):
name: Optional[str] = Field(None, max_length=100) name: Optional[str] = Field(None, max_length=100)
@ -50,6 +52,7 @@ class MerchantProductUpdate(BaseModel):
tags: Optional[str] = Field(None, max_length=200) tags: Optional[str] = Field(None, max_length=200)
purchase_limit: Optional[int] = Field(None, ge=0) # 限购次数,可选字段 purchase_limit: Optional[int] = Field(None, ge=0) # 限购次数,可选字段
status: Optional[ProductStatus] = None status: Optional[ProductStatus] = None
gift_points_rate: Optional[float] = Field(None, ge=0, le=100) # 购买赠送积分比例
class MerchantProductInfo(BaseModel): class MerchantProductInfo(BaseModel):
id: int id: int
@ -62,6 +65,7 @@ class MerchantProductInfo(BaseModel):
settlement_amount: float settlement_amount: float
tags: str tags: str
purchase_limit: int # 限购次数 purchase_limit: int # 限购次数
gift_points_rate: float
create_time: datetime create_time: datetime
update_time: Optional[datetime] update_time: Optional[datetime]
status: ProductStatus status: ProductStatus