diff --git a/app/api/endpoints/subscribe.py b/app/api/endpoints/subscribe.py new file mode 100644 index 0000000..eadb977 --- /dev/null +++ b/app/api/endpoints/subscribe.py @@ -0,0 +1,69 @@ +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session +from app.models.database import get_db +from app.api.deps import get_current_user +from app.models.user import UserDB +from app.core.response import success_response, error_response, ResponseModel +from app.models.subscribe import SubscribeDB, SubscribeCreate, SubscribeInfo +from sqlalchemy import and_, func + +router = APIRouter() + +@router.post("", response_model=ResponseModel) +async def subscribe_template( + subscribe: SubscribeCreate, + db: Session = Depends(get_db), + current_user: UserDB = Depends(get_current_user) +): + """订阅消息模板""" + results = [] + + try: + for template_info in subscribe.template_infos: + # 检查是否已存在订阅记录 + exists = db.query(SubscribeDB).filter( + and_( + SubscribeDB.user_id == current_user.userid, + SubscribeDB.template_id == template_info.template_id + ) + ).first() + + if exists: + # 更新动作 + exists.action = template_info.action + exists.update_time = func.now() + results.append(exists) + else: + # 创建新的订阅记录 + db_subscribe = SubscribeDB( + user_id=current_user.userid, + template_id=template_info.template_id, + action=template_info.action + ) + db.add(db_subscribe) + results.append(db_subscribe) + + db.commit() + for r in results: + db.refresh(r) + + return success_response(data=[ + SubscribeInfo.model_validate(r) for r in results + ]) + except Exception as e: + db.rollback() + return error_response(code=500, message=f"处理订阅失败: {str(e)}") + +@router.get("", response_model=ResponseModel) +async def get_subscribes( + db: Session = Depends(get_db), + current_user: UserDB = Depends(get_current_user) +): + """获取用户的订阅列表""" + subscribes = db.query(SubscribeDB).filter( + SubscribeDB.user_id == current_user.userid + ).all() + + return success_response(data=[ + SubscribeInfo.model_validate(s) for s in subscribes + ]) \ No newline at end of file diff --git a/app/main.py b/app/main.py index 81527f5..c09505f 100644 --- a/app/main.py +++ b/app/main.py @@ -1,6 +1,6 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from app.api.endpoints import wechat,user, address, community, station, order, coupon, community_building, upload, merchant, merchant_product, merchant_order, point, config, merchant_category, log, account,merchant_pay_order, message, bank_card, withdraw +from app.api.endpoints import wechat,user, address, community, station, order, coupon, community_building, upload, merchant, merchant_product, merchant_order, point, config, merchant_category, log, account,merchant_pay_order, message, bank_card, withdraw, subscribe from app.models.database import Base, engine from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse @@ -30,7 +30,8 @@ app.add_middleware( app.add_middleware(RequestLoggerMiddleware) # 添加用户路由 -app.include_router(wechat.router,prefix="/api/wechat",tags=["微信"]) +app.include_router(wechat.router,prefix="/api/wechat",tags=["微信"]) +app.include_router(subscribe.router, prefix="/api/subscribe", tags=["小程序订阅消息"]) app.include_router(user.router, prefix="/api/user", tags=["用户"]) app.include_router(bank_card.router, prefix="/api/bank-cards", tags=["用户银行卡"]) app.include_router(withdraw.router, prefix="/api/withdraw", tags=["提现"]) diff --git a/app/models/subscribe.py b/app/models/subscribe.py new file mode 100644 index 0000000..b2f2cbd --- /dev/null +++ b/app/models/subscribe.py @@ -0,0 +1,38 @@ +from sqlalchemy import Column, Integer, String, DateTime, ForeignKey, Boolean +from sqlalchemy.sql import func +from pydantic import BaseModel, Field +from typing import Optional, List +from datetime import datetime +from .database import Base + +class SubscribeDB(Base): + __tablename__ = "user_subscribes" + + id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column(Integer, ForeignKey("users.userid"), nullable=False) + template_id = Column(String(64), nullable=False) # 模板ID + action = Column(String(10), nullable=False) # accept 或 reject + create_time = Column(DateTime(timezone=True), server_default=func.now()) + update_time = Column(DateTime(timezone=True), onupdate=func.now()) + + class Config: + from_attributes = True + + +class TemplateInfo(BaseModel): + template_id: str + action: str = Field(..., pattern="^(accept|reject)$") # 只允许 accept 或 reject + +class SubscribeCreate(BaseModel): + template_infos: List[TemplateInfo] = Field(..., min_items=1) # 至少一个模板ID + +class SubscribeInfo(BaseModel): + id: int + user_id: int + template_id: str + action: str + create_time: datetime + update_time: Optional[datetime] + + class Config: + from_attributes = True \ No newline at end of file