删除商品分类,增加商家分类。

This commit is contained in:
aaron 2025-01-10 09:56:57 +08:00
parent be75d8d3fa
commit e8ca3c6684
7 changed files with 142 additions and 155 deletions

View File

@ -9,6 +9,7 @@ from app.models.merchant import (
MerchantInfo, MerchantInfo,
MerchantImageDB MerchantImageDB
) )
from app.models.merchant_category import MerchantCategoryDB
from app.models.database import get_db from app.models.database import get_db
from app.api.deps import get_admin_user from app.api.deps import get_admin_user
from app.models.user import UserDB from app.models.user import UserDB
@ -110,37 +111,46 @@ async def get_merchant(
async def list_merchants( async def list_merchants(
longitude: Optional[float] = None, longitude: Optional[float] = None,
latitude: Optional[float] = None, latitude: Optional[float] = None,
category_id: Optional[int] = None,
skip: int = 0, skip: int = 0,
limit: int = 20, limit: int = 20,
db: Session = Depends(get_db) db: Session = Depends(get_db)
): ):
"""获取商家列表,如果提供经纬度则按距离排序,否则按创建时间排序""" """获取商家列表,支持经纬度排序和分类过滤"""
if longitude is not None and latitude is not None: query = db.query(
# 使用 MySQL 的 ST_Distance_Sphere 计算距离(单位:米)
merchants = db.query(
MerchantDB, MerchantDB,
text("ST_Distance_Sphere(point(longitude, latitude), point(:lon, :lat)) as distance") MerchantCategoryDB.name.label('category_name')
).params( ).outerjoin(
lon=longitude, MerchantCategoryDB,
lat=latitude MerchantDB.category_id == MerchantCategoryDB.id
).order_by( )
text("distance")
).offset(skip).limit(limit).all()
# 添加分类过滤
if category_id is not None:
query = query.filter(MerchantDB.category_id == category_id)
# 根据经纬度排序
if longitude is not None and latitude is not None:
query = query.add_columns(
text("ST_Distance_Sphere(point(longitude, latitude), point(:lon, :lat)) as distance")
).params(lon=longitude, lat=latitude).order_by(text("distance"))
else:
query = query.order_by(MerchantDB.create_time.desc())
merchants = query.offset(skip).limit(limit).all()
# 处理返回结果
merchant_list = [{ merchant_list = [{
**MerchantInfo.model_validate(m[0]).model_dump(), **MerchantInfo.model_validate(m[0]).model_dump(),
"distance": round(m[1]) # 四舍五入到整数米 "category_name": m[1],
"distance": round(m[2]) if longitude is not None and latitude is not None else None
} for m in merchants] } for m in merchants]
else:
# 如果没有提供经纬度,按创建时间排序
merchants = db.query(MerchantDB).order_by(
MerchantDB.create_time.desc()
).offset(skip).limit(limit).all()
merchant_list = [MerchantInfo.model_validate(m) for m in merchants] # 获取总数(需要考虑分类过滤)
total_query = db.query(MerchantDB)
# 获取总数 if category_id is not None:
total = db.query(MerchantDB).count() total_query = total_query.filter(MerchantDB.category_id == category_id)
total = total_query.count()
return success_response(data={ return success_response(data={
"total": total, "total": total,

View File

@ -0,0 +1,71 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.models.database import get_db
from app.models.merchant_category import MerchantCategoryDB, MerchantCategoryCreate, MerchantCategoryUpdate, MerchantCategory
from app.core.response import success_response, error_response, ResponseModel
from app.api.deps import get_admin_user
from typing import List
router = APIRouter()
@router.post("", response_model=ResponseModel)
async def create_category(
category: MerchantCategoryCreate,
db: Session = Depends(get_db),
admin = Depends(get_admin_user)
):
"""创建商家分类(管理员)"""
db_category = MerchantCategoryDB(**category.model_dump())
db.add(db_category)
try:
db.commit()
db.refresh(db_category)
return success_response(data=MerchantCategory.model_validate(db_category))
except Exception as e:
db.rollback()
return error_response(code=500, message=f"创建失败: {str(e)}")
@router.put("/{category_id}", response_model=ResponseModel)
async def update_category(
category_id: int,
category: MerchantCategoryUpdate,
db: Session = Depends(get_db),
admin = Depends(get_admin_user)
):
"""更新商家分类(管理员)"""
db_category = db.query(MerchantCategoryDB).filter(
MerchantCategoryDB.id == category_id
).first()
if not db_category:
return error_response(code=404, message="分类不存在")
update_data = category.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(db_category, key, value)
try:
db.commit()
db.refresh(db_category)
return success_response(data=MerchantCategory.model_validate(db_category))
except Exception as e:
db.rollback()
return error_response(code=500, message=f"更新失败: {str(e)}")
@router.get("", response_model=ResponseModel)
async def list_categories(
db: Session = Depends(get_db),
skip: int = 0,
limit: int = 100
):
"""获取商家分类列表"""
categories = db.query(MerchantCategoryDB).order_by(
MerchantCategoryDB.sort.desc()
).offset(skip).limit(limit).all()
total = db.query(MerchantCategoryDB).count()
return success_response(data={
"total": total,
"items": [MerchantCategory.model_validate(c) for c in categories]
})

View File

@ -1,98 +0,0 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from typing import List
from app.models.merchant_product import (
MerchantProductCategoryDB,
ProductCategoryCreate,
ProductCategoryUpdate,
ProductCategoryInfo
)
from app.models.database import get_db
from app.api.deps import get_admin_user
from app.models.user import UserDB
from app.core.response import success_response, error_response, ResponseModel
router = APIRouter()
@router.post("", response_model=ResponseModel)
async def create_category(
category: ProductCategoryCreate,
db: Session = Depends(get_db),
admin: UserDB = Depends(get_admin_user)
):
"""创建商品分类"""
db_category = MerchantProductCategoryDB(**category.model_dump())
db.add(db_category)
try:
db.commit()
db.refresh(db_category)
return success_response(data=ProductCategoryInfo.model_validate(db_category))
except Exception as e:
db.rollback()
return error_response(code=500, message=f"创建失败: {str(e)}")
@router.get("/merchant/{merchant_id}", response_model=ResponseModel)
async def list_categories(
merchant_id: int,
db: Session = Depends(get_db)
):
"""获取商家的所有分类"""
categories = db.query(MerchantProductCategoryDB).filter(
MerchantProductCategoryDB.merchant_id == merchant_id
).order_by(
MerchantProductCategoryDB.sort
).all()
return success_response(data=[
ProductCategoryInfo.model_validate(c) for c in categories
])
@router.put("/{category_id}", response_model=ResponseModel)
async def update_category(
category_id: int,
category: ProductCategoryUpdate,
db: Session = Depends(get_db),
admin: UserDB = Depends(get_admin_user)
):
"""更新分类"""
db_category = db.query(MerchantProductCategoryDB).filter(
MerchantProductCategoryDB.id == category_id
).first()
if not db_category:
return error_response(code=404, message="分类不存在")
update_data = category.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(db_category, key, value)
try:
db.commit()
db.refresh(db_category)
return success_response(data=ProductCategoryInfo.model_validate(db_category))
except Exception as e:
db.rollback()
return error_response(code=500, message=f"更新失败: {str(e)}")
@router.delete("/{category_id}", response_model=ResponseModel)
async def delete_category(
category_id: int,
db: Session = Depends(get_db),
admin: UserDB = Depends(get_admin_user)
):
"""删除分类"""
db_category = db.query(MerchantProductCategoryDB).filter(
MerchantProductCategoryDB.id == category_id
).first()
if not db_category:
return error_response(code=404, message="分类不存在")
try:
db.delete(db_category)
db.commit()
return success_response(message="删除成功")
except Exception as e:
db.rollback()
return error_response(code=500, message=f"删除失败: {str(e)}")

View File

@ -1,6 +1,6 @@
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from app.api.endpoints import user, address, community, station, order, coupon, community_building, upload, merchant, merchant_product, merchant_product_category, merchant_order, point, config from app.api.endpoints import user, address, community, station, order, coupon, community_building, upload, merchant, merchant_product, merchant_order, point, config, merchant_category
from app.models.database import Base, engine from app.models.database import Base, engine
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
@ -35,12 +35,13 @@ app.include_router(order.router, prefix="/api/order", tags=["订单"])
app.include_router(coupon.router, prefix="/api/coupon", tags=["跑腿券"]) app.include_router(coupon.router, prefix="/api/coupon", tags=["跑腿券"])
app.include_router(upload.router, prefix="/api/upload", tags=["文件上传"]) app.include_router(upload.router, prefix="/api/upload", tags=["文件上传"])
app.include_router(merchant.router, prefix="/api/merchant", tags=["商家"]) app.include_router(merchant.router, prefix="/api/merchant", tags=["商家"])
app.include_router(merchant_category.router, prefix="/api/merchant-categories", tags=["商家分类"])
app.include_router(merchant_product.router, prefix="/api/merchant/product", tags=["商家产品"]) app.include_router(merchant_product.router, prefix="/api/merchant/product", tags=["商家产品"])
app.include_router(merchant_product_category.router, prefix="/api/merchant/category", tags=["商品分类"])
app.include_router(merchant_order.router, prefix="/api/merchant/order", tags=["商家订单"]) app.include_router(merchant_order.router, prefix="/api/merchant/order", tags=["商家订单"])
app.include_router(point.router, prefix="/api/point", tags=["用户积分"]) app.include_router(point.router, prefix="/api/point", tags=["用户积分"])
app.include_router(config.router, prefix="/api/config", tags=["系统配置"]) app.include_router(config.router, prefix="/api/config", tags=["系统配置"])
@app.get("/") @app.get("/")
async def root(): async def root():
return {"message": "欢迎使用 FastAPI!"} return {"message": "欢迎使用 FastAPI!"}

View File

@ -33,6 +33,7 @@ class MerchantDB(Base):
phone = Column(String(20), nullable=False) phone = Column(String(20), nullable=False)
create_time = Column(DateTime(timezone=True), server_default=func.now()) create_time = Column(DateTime(timezone=True), server_default=func.now())
update_time = Column(DateTime(timezone=True), onupdate=func.now()) update_time = Column(DateTime(timezone=True), onupdate=func.now())
category_id = Column(Integer, ForeignKey("merchant_categories.id"), nullable=True)
# 关联图片 # 关联图片
images = relationship("MerchantImageDB", images = relationship("MerchantImageDB",
@ -55,6 +56,7 @@ class MerchantCreate(BaseModel):
latitude: float = Field(..., ge=-90, le=90, description="纬度") latitude: float = Field(..., ge=-90, le=90, description="纬度")
phone: str = Field(..., max_length=20, pattern=r'^\d+$') phone: str = Field(..., max_length=20, pattern=r'^\d+$')
images: List[MerchantImage] = [] images: List[MerchantImage] = []
category_id: Optional[int] = None
class MerchantUpdate(BaseModel): class MerchantUpdate(BaseModel):
name: Optional[str] = Field(None, max_length=100) name: Optional[str] = Field(None, max_length=100)
@ -64,6 +66,7 @@ class MerchantUpdate(BaseModel):
latitude: Optional[float] = Field(None, ge=-90, le=90, description="纬度") latitude: Optional[float] = Field(None, ge=-90, le=90, description="纬度")
phone: Optional[str] = Field(None, max_length=20, pattern=r'^\d+$') phone: Optional[str] = Field(None, max_length=20, pattern=r'^\d+$')
images: Optional[List[MerchantImage]] = None images: Optional[List[MerchantImage]] = None
category_id: Optional[int] = None
class MerchantInfo(BaseModel): class MerchantInfo(BaseModel):
id: int id: int
@ -77,6 +80,8 @@ class MerchantInfo(BaseModel):
create_time: datetime create_time: datetime
update_time: Optional[datetime] update_time: Optional[datetime]
distance: Optional[int] = None # 距离(米) distance: Optional[int] = None # 距离(米)
category_id: Optional[int] = None
category_name: Optional[str] = None # 用于关联查询显示分类名称
class Config: class Config:
from_attributes = True from_attributes = True

View File

@ -0,0 +1,28 @@
from sqlalchemy import Column, Integer, String
from app.models.database import Base
from pydantic import BaseModel
from typing import Optional
class MerchantCategoryDB(Base):
__tablename__ = "merchant_categories"
id = Column(Integer, primary_key=True, index=True, autoincrement=True)
name = Column(String(50), nullable=False, comment="分类名称")
sort = Column(Integer, default=0, comment="排序")
class MerchantCategoryCreate(BaseModel):
name: str
sort: int = 0
class MerchantCategoryUpdate(BaseModel):
name: Optional[str] = None
sort: Optional[int] = None
class MerchantCategory(BaseModel):
id: int
name: str
sort: int
class Config:
from_attributes = True

View File

@ -6,16 +6,6 @@ from typing import Optional, List
from datetime import datetime from datetime import datetime
from .database import Base from .database import Base
# 商品分类表
class MerchantProductCategoryDB(Base):
__tablename__ = "merchant_product_categories"
id = Column(Integer, primary_key=True, autoincrement=True)
merchant_id = Column(Integer, ForeignKey("merchants.id", ondelete="CASCADE"), index=True)
name = Column(String(50), nullable=False)
sort = Column(Integer, nullable=False, default=0)
create_time = Column(DateTime(timezone=True), server_default=func.now())
update_time = Column(DateTime(timezone=True), onupdate=func.now())
class MerchantProductDB(Base): class MerchantProductDB(Base):
__tablename__ = "merchant_products" __tablename__ = "merchant_products"
@ -33,26 +23,6 @@ class MerchantProductDB(Base):
create_time = Column(DateTime(timezone=True), server_default=func.now()) create_time = Column(DateTime(timezone=True), server_default=func.now())
update_time = Column(DateTime(timezone=True), onupdate=func.now()) update_time = Column(DateTime(timezone=True), onupdate=func.now())
# Pydantic 模型 - 分类
class ProductCategoryCreate(BaseModel):
merchant_id: int
name: str = Field(..., max_length=50)
sort: int = Field(0, ge=0)
class ProductCategoryUpdate(BaseModel):
name: Optional[str] = Field(None, max_length=50)
sort: Optional[int] = Field(None, ge=0)
class ProductCategoryInfo(BaseModel):
id: int
merchant_id: int
name: str
sort: int
create_time: datetime
update_time: Optional[datetime]
class Config:
from_attributes = True
# Pydantic 模型 # Pydantic 模型
class MerchantProductCreate(BaseModel): class MerchantProductCreate(BaseModel):