新增商家订单等相关接口

This commit is contained in:
aaron 2025-01-06 11:56:11 +08:00
parent 79d898ea97
commit 7712456ac2
8 changed files with 451 additions and 7 deletions

View File

@ -0,0 +1,199 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import List
from app.models.merchant_order import (
MerchantOrderDB,
MerchantOrderCreate,
MerchantOrderInfo,
generate_order_id,
generate_verify_code,
OrderStatus
)
from app.models.merchant_product import MerchantProductDB
from app.models.database import get_db
from app.api.deps import get_current_user
from app.models.user import UserDB
from app.core.response import success_response, error_response, ResponseModel
from datetime import datetime, timezone
router = APIRouter()
@router.post("", response_model=ResponseModel)
async def create_order(
order: MerchantOrderCreate,
db: Session = Depends(get_db),
current_user: UserDB = Depends(get_current_user)
):
"""创建商家订单"""
# 检查商品是否存在
product = db.query(MerchantProductDB).filter(
MerchantProductDB.id == order.merchant_product_id
).first()
if not product:
return error_response(code=404, message="商品不存在")
# 生成订单号和核销码
while True:
order_id = generate_order_id()
verify_code = generate_verify_code()
# 检查是否已存在
exists = db.query(MerchantOrderDB).filter(
(MerchantOrderDB.order_id == order_id) |
(MerchantOrderDB.order_verify_code == verify_code)
).first()
if not exists:
break
# 创建订单
db_order = MerchantOrderDB(
order_id=order_id,
user_id=current_user.userid,
merchant_product_id=order.merchant_product_id,
order_amount=order.order_amount,
status=OrderStatus.CREATED, # 创建时状态为已下单
order_verify_code=verify_code
)
db.add(db_order)
try:
db.commit()
db.refresh(db_order)
return success_response(data=MerchantOrderInfo.model_validate(db_order))
except Exception as e:
db.rollback()
return error_response(code=500, message=f"创建订单失败: {str(e)}")
@router.get("/user", response_model=ResponseModel)
async def get_user_orders(
skip: int = 0,
limit: int = 20,
db: Session = Depends(get_db),
current_user: UserDB = Depends(get_current_user)
):
"""获取用户的订单列表"""
orders = db.query(MerchantOrderDB).filter(
MerchantOrderDB.user_id == current_user.userid
).order_by(
MerchantOrderDB.create_time.desc()
).offset(skip).limit(limit).all()
return success_response(data=[
MerchantOrderInfo.model_validate(o) for o in orders
])
@router.post("/{order_id}/verify", response_model=ResponseModel)
async def verify_order(
order_id: str,
verify_code: str,
db: Session = Depends(get_db),
current_user: UserDB = Depends(get_current_user)
):
"""核销订单"""
order = db.query(MerchantOrderDB).filter(
MerchantOrderDB.order_id == order_id,
MerchantOrderDB.order_verify_code == verify_code,
MerchantOrderDB.verify_time.is_(None) # 未核销
).first()
if not order:
return error_response(code=404, message="订单不存在或已核销")
# 更新核销时间和核销用户
order.verify_time = datetime.now(timezone.utc)
order.verify_user_id = current_user.userid
order.status = OrderStatus.VERIFIED # 更新为已核销状态
try:
db.commit()
return success_response(
message="核销成功",
data=MerchantOrderInfo.model_validate(order)
)
except Exception as e:
db.rollback()
return error_response(code=500, message=f"核销失败: {str(e)}")
@router.post("/{order_id}/unverify", response_model=ResponseModel)
async def set_order_unverified(
order_id: str,
db: Session = Depends(get_db),
current_user: UserDB = Depends(get_current_user)
):
"""设置订单为未核销状态"""
order = db.query(MerchantOrderDB).filter(
MerchantOrderDB.order_id == order_id,
MerchantOrderDB.status == OrderStatus.CREATED # 只有已下单状态可以设为未核销
).first()
if not order:
return error_response(code=404, message="订单不存在或状态不正确")
# 更新状态为未核销
order.status = OrderStatus.UNVERIFIED
try:
db.commit()
return success_response(
message="状态更新成功",
data=MerchantOrderInfo.model_validate(order)
)
except Exception as e:
db.rollback()
return error_response(code=500, message=f"状态更新失败: {str(e)}")
@router.post("/{order_id}/refund/apply", response_model=ResponseModel)
async def apply_refund(
order_id: str,
db: Session = Depends(get_db),
current_user: UserDB = Depends(get_current_user)
):
"""申请退款"""
order = db.query(MerchantOrderDB).filter(
MerchantOrderDB.order_id == order_id,
MerchantOrderDB.user_id == current_user.userid, # 只能申请自己的订单
MerchantOrderDB.status.in_([OrderStatus.CREATED, OrderStatus.UNVERIFIED]) # 只有未核销的订单可以退款
).first()
if not order:
return error_response(code=404, message="订单不存在或状态不允许退款")
# 更新状态为退款中
order.status = OrderStatus.REFUNDING
try:
db.commit()
return success_response(
message="退款申请成功",
data=MerchantOrderInfo.model_validate(order)
)
except Exception as e:
db.rollback()
return error_response(code=500, message=f"申请退款失败: {str(e)}")
@router.post("/{order_id}/refund/confirm", response_model=ResponseModel)
async def confirm_refund(
order_id: str,
db: Session = Depends(get_db),
current_user: UserDB = Depends(get_current_user)
):
"""确认退款(管理员)"""
order = db.query(MerchantOrderDB).filter(
MerchantOrderDB.order_id == order_id,
MerchantOrderDB.status == OrderStatus.REFUNDING # 只能确认退款中的订单
).first()
if not order:
return error_response(code=404, message="订单不存在或状态不正确")
# 更新状态为已退款
order.status = OrderStatus.REFUNDED
try:
db.commit()
return success_response(
message="退款确认成功",
data=MerchantOrderInfo.model_validate(order)
)
except Exception as e:
db.rollback()
return error_response(code=500, message=f"确认退款失败: {str(e)}")

View File

@ -0,0 +1,98 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from typing import List
from app.models.merchant_product import (
MerchantProductCategoryDB,
ProductCategoryCreate,
ProductCategoryUpdate,
ProductCategoryInfo
)
from app.models.database import get_db
from app.api.deps import get_admin_user
from app.models.user import UserDB
from app.core.response import success_response, error_response, ResponseModel
router = APIRouter()
@router.post("", response_model=ResponseModel)
async def create_category(
category: ProductCategoryCreate,
db: Session = Depends(get_db),
admin: UserDB = Depends(get_admin_user)
):
"""创建商品分类"""
db_category = MerchantProductCategoryDB(**category.model_dump())
db.add(db_category)
try:
db.commit()
db.refresh(db_category)
return success_response(data=ProductCategoryInfo.model_validate(db_category))
except Exception as e:
db.rollback()
return error_response(code=500, message=f"创建失败: {str(e)}")
@router.get("/merchant/{merchant_id}", response_model=ResponseModel)
async def list_categories(
merchant_id: int,
db: Session = Depends(get_db)
):
"""获取商家的所有分类"""
categories = db.query(MerchantProductCategoryDB).filter(
MerchantProductCategoryDB.merchant_id == merchant_id
).order_by(
MerchantProductCategoryDB.sort
).all()
return success_response(data=[
ProductCategoryInfo.model_validate(c) for c in categories
])
@router.put("/{category_id}", response_model=ResponseModel)
async def update_category(
category_id: int,
category: ProductCategoryUpdate,
db: Session = Depends(get_db),
admin: UserDB = Depends(get_admin_user)
):
"""更新分类"""
db_category = db.query(MerchantProductCategoryDB).filter(
MerchantProductCategoryDB.id == category_id
).first()
if not db_category:
return error_response(code=404, message="分类不存在")
update_data = category.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(db_category, key, value)
try:
db.commit()
db.refresh(db_category)
return success_response(data=ProductCategoryInfo.model_validate(db_category))
except Exception as e:
db.rollback()
return error_response(code=500, message=f"更新失败: {str(e)}")
@router.delete("/{category_id}", response_model=ResponseModel)
async def delete_category(
category_id: int,
db: Session = Depends(get_db),
admin: UserDB = Depends(get_admin_user)
):
"""删除分类"""
db_category = db.query(MerchantProductCategoryDB).filter(
MerchantProductCategoryDB.id == category_id
).first()
if not db_category:
return error_response(code=404, message="分类不存在")
try:
db.delete(db_category)
db.commit()
return success_response(message="删除成功")
except Exception as e:
db.rollback()
return error_response(code=500, message=f"删除失败: {str(e)}")

View File

@ -1,6 +1,6 @@
from fastapi import APIRouter, HTTPException, Depends, Response
from fastapi import APIRouter, HTTPException, Depends, Response, Body
from sqlalchemy.orm import Session
from app.models.user import UserLogin, UserInfo, VerifyCodeRequest, UserDB, UserUpdate, UserRole
from app.models.user import UserLogin, UserInfo, VerifyCodeRequest, UserDB, UserUpdate, UserRole, UserPasswordLogin
from app.api.deps import get_current_user, get_admin_user
from app.models.database import get_db
import random
@ -10,7 +10,7 @@ from app.core.config import settings
from unisdk.sms import UniSMS
from unisdk.exception import UniException
from datetime import timedelta
from app.core.security import create_access_token, set_jwt_cookie, clear_jwt_cookie
from app.core.security import create_access_token, set_jwt_cookie, clear_jwt_cookie, get_password_hash, verify_password
from app.core.response import success_response, error_response, ResponseModel
from pydantic import BaseModel, Field
from typing import List
@ -227,4 +227,31 @@ async def update_user_roles(
)
except Exception as e:
db.rollback()
return error_response(code=500, message=f"更新失败: {str(e)}")
return error_response(code=500, message=f"更新失败: {str(e)}")
@router.post("/password_login", response_model=ResponseModel)
async def password_login(
login_data: UserPasswordLogin,
db: Session = Depends(get_db)
):
"""密码登录"""
user = db.query(UserDB).filter(UserDB.phone == login_data.phone).first()
if not user:
return error_response(code=401, message="用户不存在")
if not user.password:
return error_response(code=401, message="请先设置密码")
if not verify_password(login_data.password, user.password):
return error_response(code=401, message="密码错误")
# 生成访问令牌
access_token = create_access_token(user.phone)
return success_response(
data={
"access_token": f"Bearer {access_token}",
"user": UserInfo.model_validate(user)
}
)

View File

@ -3,6 +3,10 @@ from typing import Optional
from jose import JWTError, jwt
from app.core.config import settings
from fastapi import Response
from passlib.context import CryptContext
# 密码加密上下文
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
to_encode = data.copy()
@ -42,4 +46,12 @@ def verify_token(token: str) -> Optional[str]:
return None
return phone
except JWTError:
return None
return None
def get_password_hash(password: str) -> str:
"""获取密码哈希值"""
return pwd_context.hash(password)
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证密码"""
return pwd_context.verify(plain_password, hashed_password)

View File

@ -1,6 +1,6 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api.endpoints import user, address, community, station, order, coupon, community_building, upload, merchant, merchant_product
from app.api.endpoints import user, address, community, station, order, coupon, community_building, upload, merchant, merchant_product, merchant_product_category, merchant_order
from app.models.database import Base, engine
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
@ -36,6 +36,8 @@ app.include_router(coupon.router, prefix="/api/coupon", tags=["跑腿券"])
app.include_router(upload.router, prefix="/api/upload", tags=["文件上传"])
app.include_router(merchant.router, prefix="/api/merchant", tags=["商家"])
app.include_router(merchant_product.router, prefix="/api/merchant/product", tags=["商家产品"])
app.include_router(merchant_product_category.router, prefix="/api/merchant/category", tags=["商品分类"])
app.include_router(merchant_order.router, prefix="/api/merchant/order", tags=["商家订单"])
@app.get("/")
async def root():

View File

@ -0,0 +1,64 @@
from sqlalchemy import Column, String, Integer, DateTime, ForeignKey
from sqlalchemy.dialects.mysql import DECIMAL
from sqlalchemy.sql import func
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime
from .database import Base
import random
import time
import enum
class OrderStatus(str, enum.Enum):
CREATED = "created" # 已下单
UNVERIFIED = "unverified" # 未核销
VERIFIED = "verified" # 已核销
REFUNDING = "refunding" # 退款中
REFUNDED = "refunded" # 已退款
class MerchantOrderDB(Base):
__tablename__ = "merchant_orders"
id = Column(Integer, primary_key=True, autoincrement=True)
order_id = Column(String(15), unique=True, nullable=False)
user_id = Column(Integer, ForeignKey("users.userid"), nullable=False)
merchant_product_id = Column(Integer, ForeignKey("merchant_products.id"), nullable=False)
order_amount = Column(DECIMAL(10,2), nullable=False)
status = Column(Enum(OrderStatus), nullable=False, default=OrderStatus.CREATED)
order_verify_code = Column(String(21), unique=True, nullable=False)
verify_time = Column(DateTime(timezone=True), nullable=True)
verify_user_id = Column(Integer, ForeignKey("users.userid"), nullable=True)
create_time = Column(DateTime(timezone=True), server_default=func.now())
update_time = Column(DateTime(timezone=True), onupdate=func.now())
class MerchantOrderCreate(BaseModel):
merchant_product_id: int
order_amount: float = Field(..., gt=0)
class MerchantOrderInfo(BaseModel):
id: int
order_id: str
user_id: int
merchant_product_id: int
order_amount: float
status: OrderStatus
order_verify_code: str
verify_time: Optional[datetime]
verify_user_id: Optional[int]
create_time: datetime
update_time: Optional[datetime]
class Config:
from_attributes = True
def generate_order_id() -> str:
"""生成订单号8位日期 + 7位时间戳"""
now = datetime.now()
date_str = now.strftime('%Y%m%d')
# 取时间戳后7位
timestamp = str(int(time.time() * 1000))[-7:]
return f"{date_str}{timestamp}"
def generate_verify_code() -> str:
"""生成21位数字核销码"""
return ''.join(random.choices('0123456789', k=21))

View File

@ -6,12 +6,24 @@ from typing import Optional, List
from datetime import datetime
from .database import Base
# 商品分类表
class MerchantProductCategoryDB(Base):
__tablename__ = "merchant_product_categories"
id = Column(Integer, primary_key=True, autoincrement=True)
merchant_id = Column(Integer, ForeignKey("merchants.id", ondelete="CASCADE"), index=True)
name = Column(String(50), nullable=False)
sort = Column(Integer, nullable=False, default=0)
create_time = Column(DateTime(timezone=True), server_default=func.now())
update_time = Column(DateTime(timezone=True), onupdate=func.now())
class MerchantProductDB(Base):
__tablename__ = "merchant_products"
id = Column(Integer, primary_key=True, autoincrement=True)
merchant_id = Column(Integer, ForeignKey("merchants.id", ondelete="CASCADE"), index=True)
name = Column(String(100), nullable=False)
category_id = Column(Integer, ForeignKey("merchant_product_categories.id"), nullable=False)
image_url = Column(String(500), nullable=False)
product_price = Column(Float, nullable=False) # 原价
sale_price = Column(Float, nullable=False) # 售价
@ -21,10 +33,32 @@ class MerchantProductDB(Base):
create_time = Column(DateTime(timezone=True), server_default=func.now())
update_time = Column(DateTime(timezone=True), onupdate=func.now())
# Pydantic 模型 - 分类
class ProductCategoryCreate(BaseModel):
merchant_id: int
name: str = Field(..., max_length=50)
sort: int = Field(0, ge=0)
class ProductCategoryUpdate(BaseModel):
name: Optional[str] = Field(None, max_length=50)
sort: Optional[int] = Field(None, ge=0)
class ProductCategoryInfo(BaseModel):
id: int
merchant_id: int
name: str
sort: int
create_time: datetime
update_time: Optional[datetime]
class Config:
from_attributes = True
# Pydantic 模型
class MerchantProductCreate(BaseModel):
merchant_id: int
name: str = Field(..., max_length=100)
category_id: int
image_url: str = Field(..., max_length=500)
product_price: float = Field(..., gt=0)
sale_price: float = Field(..., gt=0)
@ -34,6 +68,7 @@ class MerchantProductCreate(BaseModel):
class MerchantProductUpdate(BaseModel):
name: Optional[str] = Field(None, max_length=100)
category_id: Optional[int] = None
image_url: Optional[str] = Field(None, max_length=500)
product_price: Optional[float] = Field(None, gt=0)
sale_price: Optional[float] = Field(None, gt=0)
@ -45,6 +80,8 @@ class MerchantProductInfo(BaseModel):
id: int
merchant_id: int
name: str
category_id: int
category_name: str # 通过关联查询获取
image_url: str
product_price: float
sale_price: float

View File

@ -18,6 +18,7 @@ class UserDB(Base):
userid = Column(Integer, primary_key=True,autoincrement=True, index=True)
username = Column(String(50))
phone = Column(String(11), unique=True, index=True)
password = Column(String(128), nullable=True) # 加密后的密码
avatar = Column(String(200), nullable=True) # 头像URL地址
roles = Column(JSON, default=lambda: [UserRole.USER]) # 存储角色列表
create_time = Column(DateTime(timezone=True), server_default=func.now())
@ -46,4 +47,8 @@ class UserUpdate(BaseModel):
avatar: Optional[str] = Field(None, max_length=200)
class Config:
extra = "forbid" # 禁止额外字段
extra = "forbid" # 禁止额外字段
class UserPasswordLogin(BaseModel):
phone: str = Field(..., pattern="^1[3-9]\d{9}$")
password: str = Field(..., min_length=6, max_length=20)