314 lines
10 KiB
Python
314 lines
10 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
from typing import Dict, Any, List, Optional
|
||
from datetime import datetime
|
||
from fastapi import HTTPException, status
|
||
from sqlalchemy import Column, Integer, String, DateTime, Index
|
||
from sqlalchemy.orm import relationship
|
||
from sqlalchemy.orm import Session
|
||
from cryptoai.models.base import Base, logger
|
||
from cryptoai.models.user_subscription import UserSubscriptionManager
|
||
from datetime import datetime, timedelta
|
||
import uuid
|
||
|
||
# 定义用户数据模型
|
||
class User(Base):
|
||
"""用户数据表模型"""
|
||
__tablename__ = 'users'
|
||
|
||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||
mail = Column(String(100), nullable=False, unique=True, comment='邮箱')
|
||
nickname = Column(String(50), nullable=False, comment='昵称')
|
||
password = Column(String(100), nullable=False, comment='密码')
|
||
level = Column(Integer, nullable=False, default=0, comment='用户级别(0=普通用户,1=VIP,2=SVIP)')
|
||
points = Column(Integer, nullable=False, default=0, comment='用户积分')
|
||
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||
|
||
# 关系
|
||
questions = relationship("UserQuestion", back_populates="user")
|
||
analysis_histories = relationship("AnalysisHistory", back_populates="user")
|
||
|
||
# 索引和表属性
|
||
__table_args__ = (
|
||
Index('idx_mail', 'mail'),
|
||
Index('idx_level', 'level'),
|
||
Index('idx_create_time', 'create_time'),
|
||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||
)
|
||
|
||
class UserManager:
|
||
"""用户管理类"""
|
||
|
||
def __init__(self, session: Session = None):
|
||
self.session = session
|
||
|
||
def register_user(self, mail: str, nickname: str, password: str, level: int = 0, points: int = 0) -> bool:
|
||
"""
|
||
注册新用户
|
||
|
||
Args:
|
||
mail: 邮箱
|
||
nickname: 昵称
|
||
password: 密码
|
||
level: 用户级别,默认为0(普通用户)
|
||
points: 初始积分,默认为0
|
||
|
||
Returns:
|
||
注册是否成功
|
||
"""
|
||
try:
|
||
# 检查邮箱是否已存在
|
||
existing_user = self.session.query(User).filter(User.mail == mail).first()
|
||
if existing_user:
|
||
logger.warning(f"邮箱 {mail} 已被注册")
|
||
return False
|
||
|
||
# 创建新用户
|
||
new_user = User(
|
||
mail=mail,
|
||
nickname=nickname,
|
||
password=password, # 实际应用中应该对密码进行哈希处理
|
||
level=level,
|
||
points=points,
|
||
create_time=datetime.now()
|
||
)
|
||
|
||
# 添加并提交
|
||
self.session.add(new_user)
|
||
self.session.commit()
|
||
self.session.refresh(new_user)
|
||
|
||
# 增加初始订阅信息
|
||
user_subscription_manager = UserSubscriptionManager(self.session)
|
||
|
||
random_order_id = str(uuid.uuid4())
|
||
expire_time = datetime.now().replace(hour=23, minute=59, second=59) + timedelta(days=7)
|
||
|
||
user_subscription_manager.create_subscription(new_user.id, 1, 1, random_order_id, expire_time)
|
||
|
||
logger.info(f"成功注册用户: {mail},初始积分: {points}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
self.session.rollback()
|
||
logger.error(f"注册用户失败: {e}")
|
||
return False
|
||
|
||
def get_user_count(self) -> int:
|
||
"""
|
||
获取用户数量
|
||
"""
|
||
try:
|
||
# 查询用户数量
|
||
user_count = self.session.query(User).count()
|
||
|
||
return user_count
|
||
except Exception as e:
|
||
logger.error(f"获取用户数量失败: {e}")
|
||
return 0
|
||
|
||
def get_user_by_id(self, user_id: int) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
通过ID获取用户信息
|
||
|
||
Args:
|
||
user_id: 用户ID
|
||
|
||
Returns:
|
||
用户信息,如果用户不存在则返回None
|
||
"""
|
||
try:
|
||
# 查询用户
|
||
user = self.session.query(User).filter(User.id == user_id).first()
|
||
|
||
if user:
|
||
# 转换为字典
|
||
return {
|
||
'id': user.id,
|
||
'mail': user.mail,
|
||
'nickname': user.nickname,
|
||
'level': user.level,
|
||
'points': user.points,
|
||
'create_time': user.create_time
|
||
}
|
||
else:
|
||
return None
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取用户信息失败: {e}")
|
||
return None
|
||
|
||
def get_user_by_mail_and_password(self, mail: str, password: str) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
登录
|
||
"""
|
||
|
||
user = self.session.query(User).filter(User.mail == mail).first()
|
||
|
||
if not user:
|
||
return None
|
||
|
||
if user.password != password:
|
||
return None
|
||
|
||
return {'id': user.id, 'mail': user.mail, 'nickname': user.nickname, 'level': user.level, 'points': user.points, 'create_time': user.create_time}
|
||
|
||
def get_user_by_mail(self, mail: str) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
通过邮箱获取用户信息
|
||
|
||
Args:
|
||
mail: 邮箱
|
||
|
||
Returns:
|
||
用户信息,如果用户不存在则返回None
|
||
"""
|
||
try:
|
||
# 查询用户
|
||
user = self.session.query(User).filter(User.mail == mail).first()
|
||
|
||
if user:
|
||
# 转换为字典
|
||
return {
|
||
'id': user.id,
|
||
'mail': user.mail,
|
||
'nickname': user.nickname,
|
||
'level': user.level,
|
||
'points': user.points,
|
||
'create_time': user.create_time
|
||
}
|
||
else:
|
||
return None
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取用户信息失败: {e}")
|
||
return None
|
||
|
||
def update_password(self, mail: str, password: str) -> bool:
|
||
"""
|
||
更新用户密码
|
||
"""
|
||
try:
|
||
user = self.session.query(User).filter(User.mail == mail).first()
|
||
if not user:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="用户不存在"
|
||
)
|
||
|
||
user.password = password
|
||
self.session.commit()
|
||
return True
|
||
except Exception as e:
|
||
logger.error(f"更新用户密码失败: {e}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail=f"更新用户密码失败: {e}"
|
||
)
|
||
|
||
def update_user_level(self, user_id: int, level: int) -> bool:
|
||
"""
|
||
更新用户级别
|
||
|
||
Args:
|
||
user_id: 用户ID
|
||
level: 新的用户级别
|
||
|
||
Returns:
|
||
更新是否成功
|
||
"""
|
||
try:
|
||
# 查询用户
|
||
user = self.session.query(User).filter(User.id == user_id).first()
|
||
|
||
if not user:
|
||
logger.warning(f"用户ID {user_id} 不存在")
|
||
return False
|
||
|
||
# 更新级别
|
||
user.level = level
|
||
self.session.commit()
|
||
|
||
logger.info(f"成功更新用户 {user.mail} 的级别为 {level}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
self.session.rollback()
|
||
logger.error(f"更新用户级别失败: {e}")
|
||
return False
|
||
|
||
def add_user_points(self, user_id: int, points: int) -> bool:
|
||
"""
|
||
为用户增加积分
|
||
|
||
Args:
|
||
user_id: 用户ID
|
||
points: 增加的积分数量(正数)
|
||
|
||
Returns:
|
||
操作是否成功
|
||
"""
|
||
if points <= 0:
|
||
logger.warning(f"增加的积分必须是正数: {points}")
|
||
return False
|
||
|
||
try:
|
||
# 查询用户
|
||
user = self.session.query(User).filter(User.id == user_id).first()
|
||
|
||
if not user:
|
||
logger.warning(f"用户ID {user_id} 不存在")
|
||
return False
|
||
|
||
# 增加积分
|
||
user.points += points
|
||
self.session.commit()
|
||
|
||
logger.info(f"成功为用户 {user.mail} 增加 {points} 积分,当前积分: {user.points}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
self.session.rollback()
|
||
logger.error(f"增加用户积分失败: {e}")
|
||
return False
|
||
|
||
def consume_user_points(self, user_id: int, points: int) -> bool:
|
||
"""
|
||
用户消费积分
|
||
|
||
Args:
|
||
user_id: 用户ID
|
||
points: 消费的积分数量(正数)
|
||
|
||
Returns:
|
||
操作是否成功
|
||
"""
|
||
if points <= 0:
|
||
logger.warning(f"消费的积分必须是正数: {points}")
|
||
return False
|
||
|
||
try:
|
||
# 查询用户
|
||
user = self.session.query(User).filter(User.id == user_id).first()
|
||
|
||
if not user:
|
||
logger.warning(f"用户ID {user_id} 不存在")
|
||
return False
|
||
|
||
# 检查积分是否足够
|
||
if user.points < points:
|
||
logger.warning(f"用户 {user.mail} 积分不足,当前积分: {user.points},需要消费: {points}")
|
||
return False
|
||
|
||
# 消费积分
|
||
user.points -= points
|
||
self.session.commit()
|
||
|
||
logger.info(f"成功从用户 {user.mail} 消费 {points} 积分,剩余积分: {user.points}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
self.session.rollback()
|
||
logger.error(f"消费用户积分失败: {e}")
|
||
return False |