diff --git a/app/api/endpoints/merchant.py b/app/api/endpoints/merchant.py index 9959a06..2b2fb07 100644 --- a/app/api/endpoints/merchant.py +++ b/app/api/endpoints/merchant.py @@ -9,6 +9,7 @@ from app.models.merchant import ( MerchantInfo, MerchantImageDB ) +from app.models.merchant_category import MerchantCategoryDB from app.models.database import get_db from app.api.deps import get_admin_user from app.models.user import UserDB @@ -110,38 +111,47 @@ async def get_merchant( async def list_merchants( longitude: Optional[float] = None, latitude: Optional[float] = None, + category_id: Optional[int] = None, skip: int = 0, limit: int = 20, db: Session = Depends(get_db) ): - """获取商家列表,如果提供经纬度则按距离排序,否则按创建时间排序""" + """获取商家列表,支持经纬度排序和分类过滤""" + query = db.query( + MerchantDB, + MerchantCategoryDB.name.label('category_name') + ).outerjoin( + MerchantCategoryDB, + MerchantDB.category_id == MerchantCategoryDB.id + ) + + # 添加分类过滤 + if category_id is not None: + query = query.filter(MerchantDB.category_id == category_id) + + # 根据经纬度排序 if longitude is not None and latitude is not None: - # 使用 MySQL 的 ST_Distance_Sphere 计算距离(单位:米) - merchants = db.query( - MerchantDB, + query = query.add_columns( text("ST_Distance_Sphere(point(longitude, latitude), point(:lon, :lat)) as distance") - ).params( - lon=longitude, - lat=latitude - ).order_by( - text("distance") - ).offset(skip).limit(limit).all() - - merchant_list = [{ - **MerchantInfo.model_validate(m[0]).model_dump(), - "distance": round(m[1]) # 四舍五入到整数米 - } for m in merchants] + ).params(lon=longitude, lat=latitude).order_by(text("distance")) else: - # 如果没有提供经纬度,按创建时间排序 - merchants = db.query(MerchantDB).order_by( - MerchantDB.create_time.desc() - ).offset(skip).limit(limit).all() - - merchant_list = [MerchantInfo.model_validate(m) for m in merchants] - - # 获取总数 - total = db.query(MerchantDB).count() - + query = query.order_by(MerchantDB.create_time.desc()) + + merchants = query.offset(skip).limit(limit).all() + + # 处理返回结果 + merchant_list = [{ + **MerchantInfo.model_validate(m[0]).model_dump(), + "category_name": m[1], + "distance": round(m[2]) if longitude is not None and latitude is not None else None + } for m in merchants] + + # 获取总数(需要考虑分类过滤) + total_query = db.query(MerchantDB) + if category_id is not None: + total_query = total_query.filter(MerchantDB.category_id == category_id) + total = total_query.count() + return success_response(data={ "total": total, "items": merchant_list diff --git a/app/api/endpoints/merchant_category.py b/app/api/endpoints/merchant_category.py new file mode 100644 index 0000000..338770f --- /dev/null +++ b/app/api/endpoints/merchant_category.py @@ -0,0 +1,71 @@ +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session +from app.models.database import get_db +from app.models.merchant_category import MerchantCategoryDB, MerchantCategoryCreate, MerchantCategoryUpdate, MerchantCategory +from app.core.response import success_response, error_response, ResponseModel +from app.api.deps import get_admin_user +from typing import List + +router = APIRouter() + +@router.post("", response_model=ResponseModel) +async def create_category( + category: MerchantCategoryCreate, + db: Session = Depends(get_db), + admin = Depends(get_admin_user) +): + """创建商家分类(管理员)""" + db_category = MerchantCategoryDB(**category.model_dump()) + db.add(db_category) + try: + db.commit() + db.refresh(db_category) + return success_response(data=MerchantCategory.model_validate(db_category)) + except Exception as e: + db.rollback() + return error_response(code=500, message=f"创建失败: {str(e)}") + +@router.put("/{category_id}", response_model=ResponseModel) +async def update_category( + category_id: int, + category: MerchantCategoryUpdate, + db: Session = Depends(get_db), + admin = Depends(get_admin_user) +): + """更新商家分类(管理员)""" + db_category = db.query(MerchantCategoryDB).filter( + MerchantCategoryDB.id == category_id + ).first() + + if not db_category: + return error_response(code=404, message="分类不存在") + + update_data = category.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(db_category, key, value) + + try: + db.commit() + db.refresh(db_category) + return success_response(data=MerchantCategory.model_validate(db_category)) + except Exception as e: + db.rollback() + return error_response(code=500, message=f"更新失败: {str(e)}") + +@router.get("", response_model=ResponseModel) +async def list_categories( + db: Session = Depends(get_db), + skip: int = 0, + limit: int = 100 +): + """获取商家分类列表""" + categories = db.query(MerchantCategoryDB).order_by( + MerchantCategoryDB.sort.desc() + ).offset(skip).limit(limit).all() + + total = db.query(MerchantCategoryDB).count() + + return success_response(data={ + "total": total, + "items": [MerchantCategory.model_validate(c) for c in categories] + }) \ No newline at end of file diff --git a/app/api/endpoints/merchant_product_category.py b/app/api/endpoints/merchant_product_category.py deleted file mode 100644 index 6da8592..0000000 --- a/app/api/endpoints/merchant_product_category.py +++ /dev/null @@ -1,98 +0,0 @@ -from fastapi import APIRouter, Depends -from sqlalchemy.orm import Session -from typing import List -from app.models.merchant_product import ( - MerchantProductCategoryDB, - ProductCategoryCreate, - ProductCategoryUpdate, - ProductCategoryInfo -) -from app.models.database import get_db -from app.api.deps import get_admin_user -from app.models.user import UserDB -from app.core.response import success_response, error_response, ResponseModel - -router = APIRouter() - -@router.post("", response_model=ResponseModel) -async def create_category( - category: ProductCategoryCreate, - db: Session = Depends(get_db), - admin: UserDB = Depends(get_admin_user) -): - """创建商品分类""" - db_category = MerchantProductCategoryDB(**category.model_dump()) - db.add(db_category) - - try: - db.commit() - db.refresh(db_category) - return success_response(data=ProductCategoryInfo.model_validate(db_category)) - except Exception as e: - db.rollback() - return error_response(code=500, message=f"创建失败: {str(e)}") - -@router.get("/merchant/{merchant_id}", response_model=ResponseModel) -async def list_categories( - merchant_id: int, - db: Session = Depends(get_db) -): - """获取商家的所有分类""" - categories = db.query(MerchantProductCategoryDB).filter( - MerchantProductCategoryDB.merchant_id == merchant_id - ).order_by( - MerchantProductCategoryDB.sort - ).all() - - return success_response(data=[ - ProductCategoryInfo.model_validate(c) for c in categories - ]) - -@router.put("/{category_id}", response_model=ResponseModel) -async def update_category( - category_id: int, - category: ProductCategoryUpdate, - db: Session = Depends(get_db), - admin: UserDB = Depends(get_admin_user) -): - """更新分类""" - db_category = db.query(MerchantProductCategoryDB).filter( - MerchantProductCategoryDB.id == category_id - ).first() - - if not db_category: - return error_response(code=404, message="分类不存在") - - update_data = category.model_dump(exclude_unset=True) - for key, value in update_data.items(): - setattr(db_category, key, value) - - try: - db.commit() - db.refresh(db_category) - return success_response(data=ProductCategoryInfo.model_validate(db_category)) - except Exception as e: - db.rollback() - return error_response(code=500, message=f"更新失败: {str(e)}") - -@router.delete("/{category_id}", response_model=ResponseModel) -async def delete_category( - category_id: int, - db: Session = Depends(get_db), - admin: UserDB = Depends(get_admin_user) -): - """删除分类""" - db_category = db.query(MerchantProductCategoryDB).filter( - MerchantProductCategoryDB.id == category_id - ).first() - - if not db_category: - return error_response(code=404, message="分类不存在") - - try: - db.delete(db_category) - db.commit() - return success_response(message="删除成功") - except Exception as e: - db.rollback() - return error_response(code=500, message=f"删除失败: {str(e)}") \ No newline at end of file diff --git a/app/main.py b/app/main.py index 66b34c2..98c7c3d 100644 --- a/app/main.py +++ b/app/main.py @@ -1,6 +1,6 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from app.api.endpoints import user, address, community, station, order, coupon, community_building, upload, merchant, merchant_product, merchant_product_category, merchant_order, point, config +from app.api.endpoints import user, address, community, station, order, coupon, community_building, upload, merchant, merchant_product, merchant_order, point, config, merchant_category from app.models.database import Base, engine from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse @@ -35,12 +35,13 @@ app.include_router(order.router, prefix="/api/order", tags=["订单"]) app.include_router(coupon.router, prefix="/api/coupon", tags=["跑腿券"]) app.include_router(upload.router, prefix="/api/upload", tags=["文件上传"]) app.include_router(merchant.router, prefix="/api/merchant", tags=["商家"]) +app.include_router(merchant_category.router, prefix="/api/merchant-categories", tags=["商家分类"]) app.include_router(merchant_product.router, prefix="/api/merchant/product", tags=["商家产品"]) -app.include_router(merchant_product_category.router, prefix="/api/merchant/category", tags=["商品分类"]) app.include_router(merchant_order.router, prefix="/api/merchant/order", tags=["商家订单"]) app.include_router(point.router, prefix="/api/point", tags=["用户积分"]) app.include_router(config.router, prefix="/api/config", tags=["系统配置"]) + @app.get("/") async def root(): return {"message": "欢迎使用 FastAPI!"} diff --git a/app/models/merchant.py b/app/models/merchant.py index 5f2ca23..4a32977 100644 --- a/app/models/merchant.py +++ b/app/models/merchant.py @@ -33,6 +33,7 @@ class MerchantDB(Base): phone = Column(String(20), nullable=False) create_time = Column(DateTime(timezone=True), server_default=func.now()) update_time = Column(DateTime(timezone=True), onupdate=func.now()) + category_id = Column(Integer, ForeignKey("merchant_categories.id"), nullable=True) # 关联图片 images = relationship("MerchantImageDB", @@ -55,6 +56,7 @@ class MerchantCreate(BaseModel): latitude: float = Field(..., ge=-90, le=90, description="纬度") phone: str = Field(..., max_length=20, pattern=r'^\d+$') images: List[MerchantImage] = [] + category_id: Optional[int] = None class MerchantUpdate(BaseModel): name: Optional[str] = Field(None, max_length=100) @@ -64,6 +66,7 @@ class MerchantUpdate(BaseModel): latitude: Optional[float] = Field(None, ge=-90, le=90, description="纬度") phone: Optional[str] = Field(None, max_length=20, pattern=r'^\d+$') images: Optional[List[MerchantImage]] = None + category_id: Optional[int] = None class MerchantInfo(BaseModel): id: int @@ -77,6 +80,8 @@ class MerchantInfo(BaseModel): create_time: datetime update_time: Optional[datetime] distance: Optional[int] = None # 距离(米) + category_id: Optional[int] = None + category_name: Optional[str] = None # 用于关联查询显示分类名称 class Config: from_attributes = True \ No newline at end of file diff --git a/app/models/merchant_category.py b/app/models/merchant_category.py new file mode 100644 index 0000000..d74d0fc --- /dev/null +++ b/app/models/merchant_category.py @@ -0,0 +1,28 @@ +from sqlalchemy import Column, Integer, String +from app.models.database import Base +from pydantic import BaseModel +from typing import Optional + +class MerchantCategoryDB(Base): + __tablename__ = "merchant_categories" + + id = Column(Integer, primary_key=True, index=True, autoincrement=True) + name = Column(String(50), nullable=False, comment="分类名称") + sort = Column(Integer, default=0, comment="排序") + + +class MerchantCategoryCreate(BaseModel): + name: str + sort: int = 0 + +class MerchantCategoryUpdate(BaseModel): + name: Optional[str] = None + sort: Optional[int] = None + +class MerchantCategory(BaseModel): + id: int + name: str + sort: int + + class Config: + from_attributes = True \ No newline at end of file diff --git a/app/models/merchant_product.py b/app/models/merchant_product.py index 0b86b75..286835a 100644 --- a/app/models/merchant_product.py +++ b/app/models/merchant_product.py @@ -6,16 +6,6 @@ from typing import Optional, List from datetime import datetime from .database import Base -# 商品分类表 -class MerchantProductCategoryDB(Base): - __tablename__ = "merchant_product_categories" - - id = Column(Integer, primary_key=True, autoincrement=True) - merchant_id = Column(Integer, ForeignKey("merchants.id", ondelete="CASCADE"), index=True) - name = Column(String(50), nullable=False) - sort = Column(Integer, nullable=False, default=0) - create_time = Column(DateTime(timezone=True), server_default=func.now()) - update_time = Column(DateTime(timezone=True), onupdate=func.now()) class MerchantProductDB(Base): __tablename__ = "merchant_products" @@ -33,26 +23,6 @@ class MerchantProductDB(Base): create_time = Column(DateTime(timezone=True), server_default=func.now()) update_time = Column(DateTime(timezone=True), onupdate=func.now()) -# Pydantic 模型 - 分类 -class ProductCategoryCreate(BaseModel): - merchant_id: int - name: str = Field(..., max_length=50) - sort: int = Field(0, ge=0) - -class ProductCategoryUpdate(BaseModel): - name: Optional[str] = Field(None, max_length=50) - sort: Optional[int] = Field(None, ge=0) - -class ProductCategoryInfo(BaseModel): - id: int - merchant_id: int - name: str - sort: int - create_time: datetime - update_time: Optional[datetime] - - class Config: - from_attributes = True # Pydantic 模型 class MerchantProductCreate(BaseModel):