增加 merchant 相关的接口

This commit is contained in:
aaron 2025-01-06 00:47:03 +08:00
parent 87f3f1180c
commit d541f02022
6 changed files with 403 additions and 4 deletions

View File

@ -0,0 +1,142 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from typing import List, Optional
from sqlalchemy import text
from app.models.merchant import (
MerchantDB,
MerchantCreate,
MerchantUpdate,
MerchantInfo,
MerchantImageDB
)
from app.models.database import get_db
from app.api.deps import get_admin_user
from app.models.user import UserDB
from app.core.response import success_response, error_response, ResponseModel
router = APIRouter()
@router.post("", response_model=ResponseModel)
async def create_merchant(
merchant: MerchantCreate,
db: Session = Depends(get_db),
admin: UserDB = Depends(get_admin_user)
):
"""创建商家(管理员)"""
# 创建商家基本信息
merchant_data = merchant.model_dump(exclude={'images'})
db_merchant = MerchantDB(**merchant_data)
db.add(db_merchant)
try:
db.flush() # 获取商家ID
# 创建商家图片
for image in merchant.images:
db_image = MerchantImageDB(
merchant_id=db_merchant.id,
image_url=image.image_url,
sort=image.sort
)
db.add(db_image)
db.commit()
db.refresh(db_merchant)
return success_response(data=MerchantInfo.model_validate(db_merchant))
except Exception as e:
db.rollback()
return error_response(code=500, message=f"创建失败: {str(e)}")
@router.put("/{merchant_id}", response_model=ResponseModel)
async def update_merchant(
merchant_id: int,
merchant: MerchantUpdate,
db: Session = Depends(get_db),
admin: UserDB = Depends(get_admin_user)
):
"""更新商家信息(管理员)"""
db_merchant = db.query(MerchantDB).filter(
MerchantDB.id == merchant_id
).first()
if not db_merchant:
return error_response(code=404, message="商家不存在")
# 更新基本信息
update_data = merchant.model_dump(exclude={'images'}, exclude_unset=True)
for key, value in update_data.items():
setattr(db_merchant, key, value)
# 如果更新了图片
if merchant.images is not None:
# 删除原有图片
db.query(MerchantImageDB).filter(
MerchantImageDB.merchant_id == merchant_id
).delete()
# 添加新图片
for image in merchant.images:
db_image = MerchantImageDB(
merchant_id=merchant_id,
image_url=image.image_url,
sort=image.sort
)
db.add(db_image)
try:
db.commit()
db.refresh(db_merchant)
return success_response(data=MerchantInfo.model_validate(db_merchant))
except Exception as e:
db.rollback()
return error_response(code=500, message=f"更新失败: {str(e)}")
@router.get("/{merchant_id}", response_model=ResponseModel)
async def get_merchant(
merchant_id: int,
db: Session = Depends(get_db)
):
"""获取商家详情"""
merchant = db.query(MerchantDB).filter(
MerchantDB.id == merchant_id
).first()
if not merchant:
return error_response(code=404, message="商家不存在")
return success_response(data=MerchantInfo.model_validate(merchant))
@router.get("", response_model=ResponseModel)
async def list_merchants(
longitude: Optional[float] = None,
latitude: Optional[float] = None,
skip: int = 0,
limit: int = 20,
db: Session = Depends(get_db)
):
"""获取商家列表,如果提供经纬度则按距离排序,否则按创建时间排序"""
if longitude is not None and latitude is not None:
# 使用 MySQL 的 ST_Distance_Sphere 计算距离(单位:米)
merchants = db.query(
MerchantDB,
text("ST_Distance_Sphere(point(longitude, latitude), point(:lon, :lat)) as distance")
).params(
lon=longitude,
lat=latitude
).order_by(
text("distance")
).offset(skip).limit(limit).all()
return success_response(data=[{
**MerchantInfo.model_validate(m[0]).model_dump(),
"distance": round(m[1]) # 四舍五入到整数米
} for m in merchants])
else:
# 如果没有提供经纬度,按创建时间排序
merchants = db.query(MerchantDB).order_by(
MerchantDB.create_time.desc()
).offset(skip).limit(limit).all()
return success_response(data=[
MerchantInfo.model_validate(m) for m in merchants
])

View File

@ -0,0 +1,115 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from typing import List
from app.models.merchant_product import (
MerchantProductDB,
MerchantProductCreate,
MerchantProductUpdate,
MerchantProductInfo
)
from app.models.database import get_db
from app.api.deps import get_admin_user
from app.models.user import UserDB
from app.core.response import success_response, error_response, ResponseModel
router = APIRouter()
@router.post("", response_model=ResponseModel)
async def create_product(
product: MerchantProductCreate,
db: Session = Depends(get_db),
admin: UserDB = Depends(get_admin_user)
):
"""创建商家产品(管理员)"""
db_product = MerchantProductDB(**product.model_dump())
db.add(db_product)
try:
db.commit()
db.refresh(db_product)
return success_response(data=MerchantProductInfo.model_validate(db_product))
except Exception as e:
db.rollback()
return error_response(code=500, message=f"创建失败: {str(e)}")
@router.put("/{product_id}", response_model=ResponseModel)
async def update_product(
product_id: int,
product: MerchantProductUpdate,
db: Session = Depends(get_db),
admin: UserDB = Depends(get_admin_user)
):
"""更新商家产品(管理员)"""
db_product = db.query(MerchantProductDB).filter(
MerchantProductDB.id == product_id
).first()
if not db_product:
return error_response(code=404, message="产品不存在")
update_data = product.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(db_product, key, value)
try:
db.commit()
db.refresh(db_product)
return success_response(data=MerchantProductInfo.model_validate(db_product))
except Exception as e:
db.rollback()
return error_response(code=500, message=f"更新失败: {str(e)}")
@router.get("/merchant/{merchant_id}", response_model=ResponseModel)
async def list_merchant_products(
merchant_id: int,
skip: int = 0,
limit: int = 20,
db: Session = Depends(get_db)
):
"""获取商家的产品列表"""
products = db.query(MerchantProductDB).filter(
MerchantProductDB.merchant_id == merchant_id
).order_by(
MerchantProductDB.create_time.desc()
).offset(skip).limit(limit).all()
return success_response(data=[
MerchantProductInfo.model_validate(p) for p in products
])
@router.get("/{product_id}", response_model=ResponseModel)
async def get_product(
product_id: int,
db: Session = Depends(get_db)
):
"""获取产品详情"""
product = db.query(MerchantProductDB).filter(
MerchantProductDB.id == product_id
).first()
if not product:
return error_response(code=404, message="产品不存在")
return success_response(data=MerchantProductInfo.model_validate(product))
@router.delete("/{product_id}", response_model=ResponseModel)
async def delete_product(
product_id: int,
db: Session = Depends(get_db),
admin: UserDB = Depends(get_admin_user)
):
"""删除商家产品(管理员)"""
db_product = db.query(MerchantProductDB).filter(
MerchantProductDB.id == product_id
).first()
if not db_product:
return error_response(code=404, message="产品不存在")
try:
db.delete(db_product)
db.commit()
return success_response(message="删除成功")
except Exception as e:
db.rollback()
return error_response(code=500, message=f"删除失败: {str(e)}")

View File

@ -1,6 +1,6 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api.endpoints import user, address, community, station, order, coupon, community_building, upload
from app.api.endpoints import user, address, community, station, order, coupon, community_building, upload, merchant, merchant_product
from app.models.database import Base, engine
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
@ -32,8 +32,10 @@ app.include_router(community.router, prefix="/api/community", tags=["社区"])
app.include_router(community_building.router, prefix="/api/community/building", tags=["社区楼栋"])
app.include_router(station.router, prefix="/api/station", tags=["驿站"])
app.include_router(order.router, prefix="/api/order", tags=["订单"])
app.include_router(coupon.router, prefix="/api/coupon", tags=["优惠"])
app.include_router(coupon.router, prefix="/api/coupon", tags=["跑腿"])
app.include_router(upload.router, prefix="/api/upload", tags=["文件上传"])
app.include_router(merchant.router, prefix="/api/merchant", tags=["商家"])
app.include_router(merchant_product.router, prefix="/api/merchant/product", tags=["商家产品"])
@app.get("/")
async def root():

View File

@ -21,13 +21,13 @@ class CommunityCreate(BaseModel):
name: str = Field(..., max_length=100)
address: str = Field(..., max_length=200)
longitude: float = Field(..., ge=-180, le=180)
latitude: float = Field(..., ge=-180, le=180)
latitude: float = Field(..., ge=-90, le=90)
class CommunityUpdate(BaseModel):
name: Optional[str] = Field(None, max_length=100)
address: Optional[str] = Field(None, max_length=200)
longitude: Optional[float] = Field(None, ge=-180, le=180)
latitude: Optional[float] = Field(None, ge=-180, le=180)
latitude: Optional[float] = Field(None, ge=-90, le=90)
class CommunityInfo(BaseModel):
id: int

82
app/models/merchant.py Normal file
View File

@ -0,0 +1,82 @@
from sqlalchemy import Column, String, Integer, Float, DateTime, JSON, ForeignKey
from sqlalchemy.dialects.mysql import DECIMAL
from sqlalchemy.sql import func, select
from sqlalchemy.orm import relationship
from pydantic import BaseModel, Field
from typing import Optional, List
from datetime import datetime
from .database import Base
# 商家图片表
class MerchantImageDB(Base):
__tablename__ = "merchant_images"
id = Column(Integer, primary_key=True, autoincrement=True)
merchant_id = Column(Integer, ForeignKey("merchants.id", ondelete="CASCADE"), index=True)
image_url = Column(String(500), nullable=False)
sort = Column(Integer, nullable=False, default=0)
create_time = Column(DateTime(timezone=True), server_default=func.now())
class Config:
unique_together = [("merchant_id", "sort")]
# 数据库模型
class MerchantDB(Base):
__tablename__ = "merchants"
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String(100), nullable=False)
business_hours = Column(String(100), nullable=False) # 营业时间,如 "09:00-22:00"
address = Column(String(200), nullable=False)
longitude = Column(DECIMAL(9, 6), nullable=False) # 经度精确到小数点后6位
latitude = Column(DECIMAL(9, 6), nullable=False) # 纬度精确到小数点后6位
phone = Column(String(20), nullable=False)
create_time = Column(DateTime(timezone=True), server_default=func.now())
update_time = Column(DateTime(timezone=True), onupdate=func.now())
# 关联图片
images = relationship("MerchantImageDB",
order_by="MerchantImageDB.sort",
cascade="all, delete-orphan")
# Pydantic 模型
class MerchantImage(BaseModel):
image_url: str = Field(..., max_length=500)
sort: int = Field(..., ge=0) # 排序序号从0开始
class Config:
from_attributes = True
class MerchantCreate(BaseModel):
name: str = Field(..., max_length=100)
business_hours: str = Field(..., max_length=100)
address: str = Field(..., max_length=200)
longitude: float = Field(..., ge=-180, le=180, description="经度")
latitude: float = Field(..., ge=-180, le=180, description="纬度")
phone: str = Field(..., max_length=20, pattern=r'^\d+$')
images: List[MerchantImage] = []
class MerchantUpdate(BaseModel):
name: Optional[str] = Field(None, max_length=100)
business_hours: Optional[str] = Field(None, max_length=100)
address: Optional[str] = Field(None, max_length=200)
longitude: Optional[float] = Field(None, ge=-180, le=180, description="经度")
latitude: Optional[float] = Field(None, ge=-180, le=180, description="纬度")
phone: Optional[str] = Field(None, max_length=20, pattern=r'^\d+$')
images: Optional[List[MerchantImage]] = None
class MerchantInfo(BaseModel):
id: int
name: str
business_hours: str
address: str
longitude: float
latitude: float
phone: str
images: List[MerchantImage]
create_time: datetime
update_time: Optional[datetime]
distance: Optional[int] = None # 距离(米)
class Config:
from_attributes = True

View File

@ -0,0 +1,58 @@
from sqlalchemy import Column, String, Integer, Float, DateTime, ForeignKey
from sqlalchemy.dialects.mysql import DECIMAL
from sqlalchemy.sql import func
from pydantic import BaseModel, Field
from typing import Optional, List
from datetime import datetime
from .database import Base
class MerchantProductDB(Base):
__tablename__ = "merchant_products"
id = Column(Integer, primary_key=True, autoincrement=True)
merchant_id = Column(Integer, ForeignKey("merchants.id", ondelete="CASCADE"), index=True)
name = Column(String(100), nullable=False)
image_url = Column(String(500), nullable=False)
product_price = Column(Float, nullable=False) # 原价
sale_price = Column(Float, nullable=False) # 售价
settlement_amount = Column(DECIMAL(10,2), nullable=False) # 商家结算金额
tags = Column(String(200)) # 标签,逗号分隔
max_deduct_points = Column(DECIMAL(10,2), default=0) # 最高可抵扣积分
create_time = Column(DateTime(timezone=True), server_default=func.now())
update_time = Column(DateTime(timezone=True), onupdate=func.now())
# Pydantic 模型
class MerchantProductCreate(BaseModel):
merchant_id: int
name: str = Field(..., max_length=100)
image_url: str = Field(..., max_length=500)
product_price: float = Field(..., gt=0)
sale_price: float = Field(..., gt=0)
settlement_amount: float = Field(..., gt=0)
tags: str = Field("", max_length=200)
max_deduct_points: float = Field(0.0, ge=0)
class MerchantProductUpdate(BaseModel):
name: Optional[str] = Field(None, max_length=100)
image_url: Optional[str] = Field(None, max_length=500)
product_price: Optional[float] = Field(None, gt=0)
sale_price: Optional[float] = Field(None, gt=0)
settlement_amount: Optional[float] = Field(None, gt=0)
tags: Optional[str] = Field(None, max_length=200)
max_deduct_points: Optional[float] = Field(None, ge=0)
class MerchantProductInfo(BaseModel):
id: int
merchant_id: int
name: str
image_url: str
product_price: float
sale_price: float
settlement_amount: float
tags: str
max_deduct_points: float
create_time: datetime
update_time: Optional[datetime]
class Config:
from_attributes = True