281 lines
10 KiB
Python
281 lines
10 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
from typing import Dict, Any, List, Optional
|
||
from datetime import datetime
|
||
from fastapi import HTTPException, status
|
||
from sqlalchemy import Column, Integer, String, DateTime, Index
|
||
from sqlalchemy.orm import Session
|
||
from cryptoai.models.base import Base, logger
|
||
|
||
# User subscription data model.
class UserSubscription(Base):
    """ORM model for the ``user_subscriptions`` table.

    Each row records one subscription order of a user: membership tier,
    billing period, order ID, expiry time and creation time.
    """

    __tablename__ = 'user_subscriptions'

    # Auto-incrementing surrogate primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Owning user's ID (plain integer column; no ForeignKey is declared here).
    user_id = Column(Integer, nullable=False, comment='用户ID')
    # Membership tier code, per the column comment: 1 = VIP, 2 = SVIP.
    member_type = Column(Integer, nullable=False, comment='会员类型(1=VIP,2=SVIP)')
    # Billing period code, per the column comment: 1 = monthly, 2 = yearly.
    time_type = Column(Integer, nullable=False, comment='时间类型(1=包月,2=包年)')
    # Subscription order ID. NOTE(review): only a plain (non-unique) index is
    # declared below, so the database itself does not reject duplicate IDs.
    sub_order_id = Column(String(100), nullable=False, comment='订阅订单ID')
    # Expiry timestamp; a naive DateTime column (no timezone=True).
    expire_time = Column(DateTime, nullable=False, comment='过期时间')
    # Creation timestamp; Python-side default is the naive local time at insert.
    create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')

    # Secondary indexes plus MySQL-specific table options (utf8mb4 charset).
    __table_args__ = (
        Index('idx_user_id', 'user_id'),
        Index('idx_member_type', 'member_type'),
        Index('idx_time_type', 'time_type'),
        Index('idx_sub_order_id', 'sub_order_id'),
        Index('idx_expire_time', 'expire_time'),
        Index('idx_create_time', 'create_time'),
        {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
    )
|
||
|
||
class UserSubscriptionManager:
    """Data-access helper for reading and writing ``UserSubscription`` rows.

    Every public method swallows exceptions: it logs the problem and returns a
    neutral value (``False``, ``None``, ``[]`` or ``0``) instead of raising, so
    callers must inspect the return value.

    "Active" means ``expire_time`` is strictly later than the naive local
    ``datetime.now()`` at query time.
    """

    # Accepted ``member_type`` codes (1=VIP, 2=SVIP) and ``time_type`` codes
    # (1=monthly, 2=yearly). Class attributes so a subclass can extend them.
    VALID_MEMBER_TYPES = (1, 2)
    VALID_TIME_TYPES = (1, 2)

    def __init__(self, session: Optional[Session] = None):
        """
        Args:
            session: SQLAlchemy session used by every method. Kept optional for
                backward compatibility; without a session every method logs
                and returns its neutral value instead of querying.
        """
        self.session = session

    @staticmethod
    def _to_dict(subscription: "UserSubscription") -> Dict[str, Any]:
        """Return the columns of one ``UserSubscription`` row as a plain dict."""
        return {
            'id': subscription.id,
            'user_id': subscription.user_id,
            'member_type': subscription.member_type,
            'time_type': subscription.time_type,
            'sub_order_id': subscription.sub_order_id,
            'expire_time': subscription.expire_time,
            'create_time': subscription.create_time
        }

    def _rollback(self) -> None:
        """Roll back the session without ever raising.

        Used inside ``except`` handlers: a missing session or a failing
        rollback (e.g. a dropped connection) must not turn the documented
        ``return False`` into an escaping exception.
        """
        if self.session is None:
            return
        try:
            self.session.rollback()
        except Exception as e:
            logger.error(f"回滚会话失败: {e}")

    def create_subscription(self, user_id: int, member_type: int, time_type: int,
                            sub_order_id: str, expire_time: datetime) -> bool:
        """
        Create and commit a new user subscription.

        Args:
            user_id: User ID.
            member_type: Membership tier code (1=VIP, 2=SVIP).
            time_type: Billing period code (1=monthly, 2=yearly).
            sub_order_id: Subscription order ID; must not already exist.
            expire_time: Expiry timestamp.

        Returns:
            True if the row was committed. False if a code is invalid, the
            order ID already exists, or any error occurred (the session is
            rolled back in that case).
        """
        try:
            # Validate the enumerated codes before touching the database.
            if member_type not in self.VALID_MEMBER_TYPES:
                logger.warning(f"无效的会员类型: {member_type}")
                return False

            if time_type not in self.VALID_TIME_TYPES:
                logger.warning(f"无效的时间类型: {time_type}")
                return False

            # Reject duplicate order IDs. NOTE(review): this check-then-insert
            # is not atomic and sub_order_id has no unique index, so two
            # concurrent calls could still both insert the same order ID.
            existing_subscription = self.session.query(UserSubscription).filter(
                UserSubscription.sub_order_id == sub_order_id
            ).first()
            if existing_subscription is not None:
                logger.warning(f"订阅订单ID {sub_order_id} 已存在")
                return False

            new_subscription = UserSubscription(
                user_id=user_id,
                member_type=member_type,
                time_type=time_type,
                sub_order_id=sub_order_id,
                expire_time=expire_time,
                create_time=datetime.now()
            )

            self.session.add(new_subscription)
            self.session.commit()

            logger.info(f"成功创建用户订阅: 用户 {user_id},订单 {sub_order_id}")
            return True

        except Exception as e:
            # Leave the session usable for the caller; never re-raise from here.
            self._rollback()
            logger.error(f"创建用户订阅失败: {e}")
            return False

    def get_subscription_by_user_id(self, user_id: int) -> Optional[Dict[str, Any]]:
        """
        Get the user's currently active subscription.

        Args:
            user_id: User ID.

        Returns:
            The unexpired subscription with the latest ``expire_time`` as a
            dict, or None if there is none (or on error).
        """
        try:
            subscription = self.session.query(UserSubscription).filter(
                UserSubscription.user_id == user_id,
                UserSubscription.expire_time > datetime.now()
            ).order_by(UserSubscription.expire_time.desc()).first()

            return self._to_dict(subscription) if subscription is not None else None

        except Exception as e:
            logger.error(f"获取用户订阅失败: {e}")
            return None

    def get_all_subscriptions_by_user_id(self, user_id: int) -> List[Dict[str, Any]]:
        """
        Get every subscription record of a user, expired or not.

        Args:
            user_id: User ID.

        Returns:
            Subscription dicts ordered by ``create_time`` descending (newest
            first); an empty list if none exist or on error.
        """
        try:
            subscriptions = self.session.query(UserSubscription).filter(
                UserSubscription.user_id == user_id
            ).order_by(UserSubscription.create_time.desc()).all()

            return [self._to_dict(subscription) for subscription in subscriptions]

        except Exception as e:
            logger.error(f"获取用户订阅列表失败: {e}")
            return []

    def get_subscription_by_order_id(self, sub_order_id: str) -> Optional[Dict[str, Any]]:
        """
        Get a subscription by its order ID.

        Args:
            sub_order_id: Subscription order ID.

        Returns:
            The first matching subscription as a dict, or None if it does not
            exist (or on error).
        """
        try:
            subscription = self.session.query(UserSubscription).filter(
                UserSubscription.sub_order_id == sub_order_id
            ).first()

            return self._to_dict(subscription) if subscription is not None else None

        except Exception as e:
            logger.error(f"获取订阅信息失败: {e}")
            return None

    def is_user_subscribed(self, user_id: int) -> bool:
        """
        Check whether the user has at least one unexpired subscription.

        Args:
            user_id: User ID.

        Returns:
            True if an active subscription exists; False otherwise or on error.
        """
        try:
            subscription = self.session.query(UserSubscription).filter(
                UserSubscription.user_id == user_id,
                UserSubscription.expire_time > datetime.now()
            ).first()

            return subscription is not None

        except Exception as e:
            logger.error(f"检查用户订阅状态失败: {e}")
            return False

    def get_expired_subscriptions(self) -> List[Dict[str, Any]]:
        """
        Get all expired subscriptions (``expire_time`` <= now).

        Returns:
            Subscription dicts ordered by ``expire_time`` descending (most
            recently expired first); an empty list on error.
        """
        try:
            expired_subscriptions = self.session.query(UserSubscription).filter(
                UserSubscription.expire_time <= datetime.now()
            ).order_by(UserSubscription.expire_time.desc()).all()

            return [self._to_dict(subscription) for subscription in expired_subscriptions]

        except Exception as e:
            logger.error(f"获取过期订阅列表失败: {e}")
            return []

    def get_subscription_count(self) -> int:
        """
        Return the total number of subscription rows (0 on error).
        """
        try:
            return self.session.query(UserSubscription).count()
        except Exception as e:
            logger.error(f"获取订阅数量失败: {e}")
            return 0

    def get_active_subscription_count(self) -> int:
        """
        Return the number of unexpired subscription rows (0 on error).

        Counts rows, not distinct users: a user holding several unexpired
        subscriptions is counted once per row.
        """
        try:
            return self.session.query(UserSubscription).filter(
                UserSubscription.expire_time > datetime.now()
            ).count()
        except Exception as e:
            logger.error(f"获取有效订阅数量失败: {e}")
            return 0