diff --git a/app/api/endpoints/merchant.py b/app/api/endpoints/merchant.py new file mode 100644 index 0000000..7f95215 --- /dev/null +++ b/app/api/endpoints/merchant.py @@ -0,0 +1,142 @@ +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session +from typing import List, Optional +from sqlalchemy import text +from app.models.merchant import ( + MerchantDB, + MerchantCreate, + MerchantUpdate, + MerchantInfo, + MerchantImageDB +) +from app.models.database import get_db +from app.api.deps import get_admin_user +from app.models.user import UserDB +from app.core.response import success_response, error_response, ResponseModel + +router = APIRouter() + +@router.post("", response_model=ResponseModel) +async def create_merchant( + merchant: MerchantCreate, + db: Session = Depends(get_db), + admin: UserDB = Depends(get_admin_user) +): + """创建商家(管理员)""" + # 创建商家基本信息 + merchant_data = merchant.model_dump(exclude={'images'}) + db_merchant = MerchantDB(**merchant_data) + db.add(db_merchant) + + try: + db.flush() # 获取商家ID + + # 创建商家图片 + for image in merchant.images: + db_image = MerchantImageDB( + merchant_id=db_merchant.id, + image_url=image.image_url, + sort=image.sort + ) + db.add(db_image) + + db.commit() + db.refresh(db_merchant) + return success_response(data=MerchantInfo.model_validate(db_merchant)) + except Exception as e: + db.rollback() + return error_response(code=500, message=f"创建失败: {str(e)}") + +@router.put("/{merchant_id}", response_model=ResponseModel) +async def update_merchant( + merchant_id: int, + merchant: MerchantUpdate, + db: Session = Depends(get_db), + admin: UserDB = Depends(get_admin_user) +): + """更新商家信息(管理员)""" + db_merchant = db.query(MerchantDB).filter( + MerchantDB.id == merchant_id + ).first() + + if not db_merchant: + return error_response(code=404, message="商家不存在") + + # 更新基本信息 + update_data = merchant.model_dump(exclude={'images'}, exclude_unset=True) + for key, value in update_data.items(): + setattr(db_merchant, key, value) + + # 如果更新了图片 + if merchant.images is not None: + # 删除原有图片 + db.query(MerchantImageDB).filter( + MerchantImageDB.merchant_id == merchant_id + ).delete() + + # 添加新图片 + for image in merchant.images: + db_image = MerchantImageDB( + merchant_id=merchant_id, + image_url=image.image_url, + sort=image.sort + ) + db.add(db_image) + + try: + db.commit() + db.refresh(db_merchant) + return success_response(data=MerchantInfo.model_validate(db_merchant)) + except Exception as e: + db.rollback() + return error_response(code=500, message=f"更新失败: {str(e)}") + +@router.get("/{merchant_id}", response_model=ResponseModel) +async def get_merchant( + merchant_id: int, + db: Session = Depends(get_db) +): + """获取商家详情""" + merchant = db.query(MerchantDB).filter( + MerchantDB.id == merchant_id + ).first() + + if not merchant: + return error_response(code=404, message="商家不存在") + + return success_response(data=MerchantInfo.model_validate(merchant)) + +@router.get("", response_model=ResponseModel) +async def list_merchants( + longitude: Optional[float] = None, + latitude: Optional[float] = None, + skip: int = 0, + limit: int = 20, + db: Session = Depends(get_db) +): + """获取商家列表,如果提供经纬度则按距离排序,否则按创建时间排序""" + if longitude is not None and latitude is not None: + # 使用 MySQL 的 ST_Distance_Sphere 计算距离(单位:米) + merchants = db.query( + MerchantDB, + text("ST_Distance_Sphere(point(longitude, latitude), point(:lon, :lat)) as distance") + ).params( + lon=longitude, + lat=latitude + ).order_by( + text("distance") + ).offset(skip).limit(limit).all() + + return success_response(data=[{ + **MerchantInfo.model_validate(m[0]).model_dump(), + "distance": round(m[1]) # 四舍五入到整数米 + } for m in merchants]) + else: + # 如果没有提供经纬度,按创建时间排序 + merchants = db.query(MerchantDB).order_by( + MerchantDB.create_time.desc() + ).offset(skip).limit(limit).all() + + return success_response(data=[ + MerchantInfo.model_validate(m) for m in merchants + ]) \ No newline at end of file diff --git a/app/api/endpoints/merchant_product.py b/app/api/endpoints/merchant_product.py new file mode 100644 index 0000000..3af4927 --- /dev/null +++ b/app/api/endpoints/merchant_product.py @@ -0,0 +1,115 @@ +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session +from typing import List +from app.models.merchant_product import ( + MerchantProductDB, + MerchantProductCreate, + MerchantProductUpdate, + MerchantProductInfo +) +from app.models.database import get_db +from app.api.deps import get_admin_user +from app.models.user import UserDB +from app.core.response import success_response, error_response, ResponseModel + +router = APIRouter() + +@router.post("", response_model=ResponseModel) +async def create_product( + product: MerchantProductCreate, + db: Session = Depends(get_db), + admin: UserDB = Depends(get_admin_user) +): + """创建商家产品(管理员)""" + db_product = MerchantProductDB(**product.model_dump()) + db.add(db_product) + + try: + db.commit() + db.refresh(db_product) + return success_response(data=MerchantProductInfo.model_validate(db_product)) + except Exception as e: + db.rollback() + return error_response(code=500, message=f"创建失败: {str(e)}") + +@router.put("/{product_id}", response_model=ResponseModel) +async def update_product( + product_id: int, + product: MerchantProductUpdate, + db: Session = Depends(get_db), + admin: UserDB = Depends(get_admin_user) +): + """更新商家产品(管理员)""" + db_product = db.query(MerchantProductDB).filter( + MerchantProductDB.id == product_id + ).first() + + if not db_product: + return error_response(code=404, message="产品不存在") + + update_data = product.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(db_product, key, value) + + try: + db.commit() + db.refresh(db_product) + return success_response(data=MerchantProductInfo.model_validate(db_product)) + except Exception as e: + db.rollback() + return error_response(code=500, message=f"更新失败: {str(e)}") + +@router.get("/merchant/{merchant_id}", response_model=ResponseModel) +async def list_merchant_products( + merchant_id: int, + skip: int = 0, + limit: int = 20, + db: Session = Depends(get_db) +): + """获取商家的产品列表""" + products = db.query(MerchantProductDB).filter( + MerchantProductDB.merchant_id == merchant_id + ).order_by( + MerchantProductDB.create_time.desc() + ).offset(skip).limit(limit).all() + + return success_response(data=[ + MerchantProductInfo.model_validate(p) for p in products + ]) + +@router.get("/{product_id}", response_model=ResponseModel) +async def get_product( + product_id: int, + db: Session = Depends(get_db) +): + """获取产品详情""" + product = db.query(MerchantProductDB).filter( + MerchantProductDB.id == product_id + ).first() + + if not product: + return error_response(code=404, message="产品不存在") + + return success_response(data=MerchantProductInfo.model_validate(product)) + +@router.delete("/{product_id}", response_model=ResponseModel) +async def delete_product( + product_id: int, + db: Session = Depends(get_db), + admin: UserDB = Depends(get_admin_user) +): + """删除商家产品(管理员)""" + db_product = db.query(MerchantProductDB).filter( + MerchantProductDB.id == product_id + ).first() + + if not db_product: + return error_response(code=404, message="产品不存在") + + try: + db.delete(db_product) + db.commit() + return success_response(message="删除成功") + except Exception as e: + db.rollback() + return error_response(code=500, message=f"删除失败: {str(e)}") \ No newline at end of file diff --git a/app/main.py b/app/main.py index 753e4dc..be78b48 100644 --- a/app/main.py +++ b/app/main.py @@ -1,6 +1,6 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from app.api.endpoints import user, address, community, station, order, coupon, community_building, upload +from app.api.endpoints import user, address, community, station, order, coupon, community_building, upload, merchant, merchant_product from app.models.database import Base, engine from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse @@ -32,8 +32,10 @@ app.include_router(community.router, prefix="/api/community", tags=["社区"]) app.include_router(community_building.router, prefix="/api/community/building", tags=["社区楼栋"]) app.include_router(station.router, prefix="/api/station", tags=["驿站"]) app.include_router(order.router, prefix="/api/order", tags=["订单"]) -app.include_router(coupon.router, prefix="/api/coupon", tags=["优惠券"]) +app.include_router(coupon.router, prefix="/api/coupon", tags=["跑腿券"]) app.include_router(upload.router, prefix="/api/upload", tags=["文件上传"]) +app.include_router(merchant.router, prefix="/api/merchant", tags=["商家"]) +app.include_router(merchant_product.router, prefix="/api/merchant/product", tags=["商家产品"]) @app.get("/") async def root(): diff --git a/app/models/community.py b/app/models/community.py index e97dc5e..e286336 100644 --- a/app/models/community.py +++ b/app/models/community.py @@ -21,13 +21,13 @@ class CommunityCreate(BaseModel): name: str = Field(..., max_length=100) address: str = Field(..., max_length=200) longitude: float = Field(..., ge=-180, le=180) - latitude: float = Field(..., ge=-180, le=180) + latitude: float = Field(..., ge=-90, le=90) class CommunityUpdate(BaseModel): name: Optional[str] = Field(None, max_length=100) address: Optional[str] = Field(None, max_length=200) longitude: Optional[float] = Field(None, ge=-180, le=180) - latitude: Optional[float] = Field(None, ge=-180, le=180) + latitude: Optional[float] = Field(None, ge=-90, le=90) class CommunityInfo(BaseModel): id: int diff --git a/app/models/merchant.py b/app/models/merchant.py new file mode 100644 index 0000000..eb29341 --- /dev/null +++ b/app/models/merchant.py @@ -0,0 +1,82 @@ +from sqlalchemy import Column, String, Integer, Float, DateTime, JSON, ForeignKey +from sqlalchemy.dialects.mysql import DECIMAL +from sqlalchemy.sql import func, select +from sqlalchemy.orm import relationship +from pydantic import BaseModel, Field +from typing import Optional, List +from datetime import datetime +from .database import Base + +# 商家图片表 +class MerchantImageDB(Base): + __tablename__ = "merchant_images" + + id = Column(Integer, primary_key=True, autoincrement=True) + merchant_id = Column(Integer, ForeignKey("merchants.id", ondelete="CASCADE"), index=True) + image_url = Column(String(500), nullable=False) + sort = Column(Integer, nullable=False, default=0) + create_time = Column(DateTime(timezone=True), server_default=func.now()) + + class Config: + unique_together = [("merchant_id", "sort")] + +# 数据库模型 +class MerchantDB(Base): + __tablename__ = "merchants" + + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String(100), nullable=False) + business_hours = Column(String(100), nullable=False) # 营业时间,如 "09:00-22:00" + address = Column(String(200), nullable=False) + longitude = Column(DECIMAL(9, 6), nullable=False) # 经度,精确到小数点后6位 + latitude = Column(DECIMAL(9, 6), nullable=False) # 纬度,精确到小数点后6位 + phone = Column(String(20), nullable=False) + create_time = Column(DateTime(timezone=True), server_default=func.now()) + update_time = Column(DateTime(timezone=True), onupdate=func.now()) + + # 关联图片 + images = relationship("MerchantImageDB", + order_by="MerchantImageDB.sort", + cascade="all, delete-orphan") + +# Pydantic 模型 +class MerchantImage(BaseModel): + image_url: str = Field(..., max_length=500) + sort: int = Field(..., ge=0) # 排序序号,从0开始 + + class Config: + from_attributes = True + +class MerchantCreate(BaseModel): + name: str = Field(..., max_length=100) + business_hours: str = Field(..., max_length=100) + address: str = Field(..., max_length=200) + longitude: float = Field(..., ge=-180, le=180, description="经度") + latitude: float = Field(..., ge=-180, le=180, description="纬度") + phone: str = Field(..., max_length=20, pattern=r'^\d+$') + images: List[MerchantImage] = [] + +class MerchantUpdate(BaseModel): + name: Optional[str] = Field(None, max_length=100) + business_hours: Optional[str] = Field(None, max_length=100) + address: Optional[str] = Field(None, max_length=200) + longitude: Optional[float] = Field(None, ge=-180, le=180, description="经度") + latitude: Optional[float] = Field(None, ge=-180, le=180, description="纬度") + phone: Optional[str] = Field(None, max_length=20, pattern=r'^\d+$') + images: Optional[List[MerchantImage]] = None + +class MerchantInfo(BaseModel): + id: int + name: str + business_hours: str + address: str + longitude: float + latitude: float + phone: str + images: List[MerchantImage] + create_time: datetime + update_time: Optional[datetime] + distance: Optional[int] = None # 距离(米) + + class Config: + from_attributes = True \ No newline at end of file diff --git a/app/models/merchant_product.py b/app/models/merchant_product.py new file mode 100644 index 0000000..d73cd41 --- /dev/null +++ b/app/models/merchant_product.py @@ -0,0 +1,58 @@ +from sqlalchemy import Column, String, Integer, Float, DateTime, ForeignKey +from sqlalchemy.dialects.mysql import DECIMAL +from sqlalchemy.sql import func +from pydantic import BaseModel, Field +from typing import Optional, List +from datetime import datetime +from .database import Base + +class MerchantProductDB(Base): + __tablename__ = "merchant_products" + + id = Column(Integer, primary_key=True, autoincrement=True) + merchant_id = Column(Integer, ForeignKey("merchants.id", ondelete="CASCADE"), index=True) + name = Column(String(100), nullable=False) + image_url = Column(String(500), nullable=False) + product_price = Column(Float, nullable=False) # 原价 + sale_price = Column(Float, nullable=False) # 售价 + settlement_amount = Column(DECIMAL(10,2), nullable=False) # 商家结算金额 + tags = Column(String(200)) # 标签,逗号分隔 + max_deduct_points = Column(DECIMAL(10,2), default=0) # 最高可抵扣积分 + create_time = Column(DateTime(timezone=True), server_default=func.now()) + update_time = Column(DateTime(timezone=True), onupdate=func.now()) + +# Pydantic 模型 +class MerchantProductCreate(BaseModel): + merchant_id: int + name: str = Field(..., max_length=100) + image_url: str = Field(..., max_length=500) + product_price: float = Field(..., gt=0) + sale_price: float = Field(..., gt=0) + settlement_amount: float = Field(..., gt=0) + tags: str = Field("", max_length=200) + max_deduct_points: float = Field(0.0, ge=0) + +class MerchantProductUpdate(BaseModel): + name: Optional[str] = Field(None, max_length=100) + image_url: Optional[str] = Field(None, max_length=500) + product_price: Optional[float] = Field(None, gt=0) + sale_price: Optional[float] = Field(None, gt=0) + settlement_amount: Optional[float] = Field(None, gt=0) + tags: Optional[str] = Field(None, max_length=200) + max_deduct_points: Optional[float] = Field(None, ge=0) + +class MerchantProductInfo(BaseModel): + id: int + merchant_id: int + name: str + image_url: str + product_price: float + sale_price: float + settlement_amount: float + tags: str + max_deduct_points: float + create_time: datetime + update_time: Optional[datetime] + + class Config: + from_attributes = True \ No newline at end of file