diff --git a/cryptoai/routes/user.py b/cryptoai/routes/user.py index cd5b627..9ace636 100644 --- a/cryptoai/routes/user.py +++ b/cryptoai/routes/user.py @@ -55,6 +55,7 @@ class UserResponse(BaseModel): mail: str nickname: str level: int + points: int create_time: datetime class TokenResponse(BaseModel): @@ -190,7 +191,8 @@ async def register_user(user: UserRegister) -> Dict[str, Any]: mail=user.mail, nickname=user.nickname, password=hashed_password, - level=0 # 默认为普通用户 + level=0, # 默认为普通用户 + points=100 # 默认初始积分为100 ) if not success: @@ -270,6 +272,7 @@ async def login(loginData: UserLogin) -> TokenResponse: mail=user["mail"], nickname=user["nickname"], level=user["level"], + points=user["points"], create_time=user["create_time"] ) ) @@ -299,6 +302,7 @@ async def get_user_info(current_user: Dict[str, Any] = Depends(get_current_user) mail=current_user["mail"], nickname=current_user["nickname"], level=current_user["level"], + points=current_user["points"], create_time=current_user["create_time"] ) @@ -314,41 +318,162 @@ async def update_user_level( Args: user_id: 用户ID level: 新的用户级别 - current_user: 当前用户信息,由依赖项提供 + current_user: 当前用户信息 Returns: - 更新成功的状态信息 + 更新结果 """ - # 简单的权限检查(实际应用中应该有更完善的权限管理) - if current_user["level"] < 2: # 假设SVIP用户有管理权限 + # 检查权限(只有SVIP用户才能更新用户级别) + if current_user.get("level", 0) < 2: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="没有足够的权限执行此操作" ) - try: - # 获取数据库管理器 - db_manager = get_db_manager() - - # 更新用户级别 - success = db_manager.update_user_level(user_id, level) - - if not success: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail=f"用户ID {user_id} 不存在" - ) - - return { - "status": "success", - "message": f"成功更新用户级别为 {level}" - } - - except HTTPException: - raise - except Exception as e: - logger.error(f"更新用户级别失败: {str(e)}") + # 获取数据库管理器 + db_manager = get_db_manager() + + # 更新用户级别 + success = db_manager.update_user_level(user_id, level) + + if not success: raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"更新用户级别失败: {str(e)}" - ) \ No newline at end of file + status_code=status.HTTP_400_BAD_REQUEST, + detail="更新用户级别失败" + ) + + return { + "status": "success", + "message": f"成功更新用户 {user_id} 的级别为 {level}" + } + +@router.get("/points/{user_id}", response_model=Dict[str, Any]) +async def get_user_points( + user_id: int, + current_user: Dict[str, Any] = Depends(get_current_user) +) -> Dict[str, Any]: + """ + 获取用户积分 + + Args: + user_id: 用户ID + current_user: 当前用户信息 + + Returns: + 用户积分信息 + """ + # 只能查看自己的积分,或者SVIP用户可以查看所有人的积分 + if current_user.get("id") != user_id and current_user.get("level", 0) < 2: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="没有权限查看其他用户的积分" + ) + + # 获取数据库管理器 + db_manager = get_db_manager() + + # 获取用户信息 + user = db_manager.get_user_by_id(user_id) + + if not user: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"用户ID {user_id} 不存在" + ) + + return { + "user_id": user_id, + "points": user.get("points", 0), + "nickname": user.get("nickname", ""), + "level": user.get("level", 0) + } + +@router.post("/points/add/{user_id}", response_model=Dict[str, Any]) +async def add_user_points( + user_id: int, + points: int = Query(..., gt=0, description="增加的积分数量(必须大于0)"), + current_user: Dict[str, Any] = Depends(get_current_user) +) -> Dict[str, Any]: + """ + 为用户增加积分(需要管理员权限) + + Args: + user_id: 用户ID + points: 增加的积分数量 + current_user: 当前用户信息 + + Returns: + 操作结果 + """ + # 检查权限(只有SVIP用户才能添加积分) + if current_user.get("level", 0) < 2: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="没有足够的权限执行此操作" + ) + + # 获取数据库管理器 + db_manager = get_db_manager() + + # 添加积分 + success = db_manager.add_user_points(user_id, points) + + if not success: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="添加积分失败" + ) + + # 获取更新后的用户信息 + user = db_manager.get_user_by_id(user_id) + + return { + "status": "success", + "message": f"成功为用户 {user_id} 增加 {points} 积分", + "current_points": user.get("points", 0) + } + +@router.post("/points/consume/{user_id}", response_model=Dict[str, Any]) +async def consume_user_points( + user_id: int, + points: int = Query(..., gt=0, description="消费的积分数量(必须大于0)"), + current_user: Dict[str, Any] = Depends(get_current_user) +) -> Dict[str, Any]: + """ + 用户消费积分 + + Args: + user_id: 用户ID + points: 消费的积分数量 + current_user: 当前用户信息 + + Returns: + 操作结果 + """ + # 只能消费自己的积分,或者SVIP用户可以操作所有人的积分 + if current_user.get("id") != user_id and current_user.get("level", 0) < 2: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="没有权限消费其他用户的积分" + ) + + # 获取数据库管理器 + db_manager = get_db_manager() + + # 消费积分 + success = db_manager.consume_user_points(user_id, points) + + if not success: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="积分消费失败,可能是积分不足" + ) + + # 获取更新后的用户信息 + user = db_manager.get_user_by_id(user_id) + + return { + "status": "success", + "message": f"成功消费 {points} 积分", + "remaining_points": user.get("points", 0) + } \ No newline at end of file diff --git a/cryptoai/utils/db_manager.py b/cryptoai/utils/db_manager.py index a7e6ab8..29864ba 100644 --- a/cryptoai/utils/db_manager.py +++ b/cryptoai/utils/db_manager.py @@ -74,6 +74,7 @@ class User(Base): nickname = Column(String(50), nullable=False, comment='昵称') password = Column(String(100), nullable=False, comment='密码') level = Column(Integer, nullable=False, default=0, comment='用户级别(0=普通用户,1=VIP,2=SVIP)') + points = Column(Integer, nullable=False, default=0, comment='用户积分') create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') # 关系 @@ -300,7 +301,7 @@ class DBManager: pass return False - def register_user(self, mail: str, nickname: str, password: str, level: int = 0) -> bool: + def register_user(self, mail: str, nickname: str, password: str, level: int = 0, points: int = 0) -> bool: """ 注册新用户 @@ -309,6 +310,7 @@ class DBManager: nickname: 昵称 password: 密码 level: 用户级别,默认为0(普通用户) + points: 初始积分,默认为0 Returns: 注册是否成功 @@ -337,6 +339,7 @@ class DBManager: nickname=nickname, password=password, # 实际应用中应该对密码进行哈希处理 level=level, + points=points, create_time=datetime.now() ) @@ -344,7 +347,7 @@ class DBManager: session.add(new_user) session.commit() - logger.info(f"成功注册用户: {mail}") + logger.info(f"成功注册用户: {mail},初始积分: {points}") return True except Exception as e: @@ -364,50 +367,6 @@ class DBManager: pass return False - def get_user_by_mail(self, mail: str) -> Optional[Dict[str, Any]]: - """ - 通过邮箱获取用户信息 - - Args: - mail: 邮箱 - - Returns: - 用户信息,如果用户不存在则返回None - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return None - - try: - # 创建会话 - session = self.Session() - - try: - # 查询用户 - user = session.query(User).filter(User.mail == mail).first() - - if user: - # 转换为字典 - return { - 'id': user.id, - 'mail': user.mail, - 'nickname': user.nickname, - 'level': user.level, - 'create_time': user.create_time - } - else: - return None - - finally: - session.close() - - except Exception as e: - logger.error(f"获取用户信息失败: {e}") - return None - def get_user_by_id(self, user_id: int) -> Optional[Dict[str, Any]]: """ 通过ID获取用户信息 @@ -440,6 +399,52 @@ class DBManager: 'mail': user.mail, 'nickname': user.nickname, 'level': user.level, + 'points': user.points, + 'create_time': user.create_time + } + else: + return None + + finally: + session.close() + + except Exception as e: + logger.error(f"获取用户信息失败: {e}") + return None + + def get_user_by_mail(self, mail: str) -> Optional[Dict[str, Any]]: + """ + 通过邮箱获取用户信息 + + Args: + mail: 邮箱 + + Returns: + 用户信息,如果用户不存在则返回None + """ + if not self.engine: + try: + self._init_db() + except Exception as e: + logger.error(f"重新连接数据库失败: {e}") + return None + + try: + # 创建会话 + session = self.Session() + + try: + # 查询用户 + user = session.query(User).filter(User.mail == mail).first() + + if user: + # 转换为字典 + return { + 'id': user.id, + 'mail': user.mail, + 'nickname': user.nickname, + 'level': user.level, + 'points': user.points, 'create_time': user.create_time } else: @@ -501,6 +506,117 @@ class DBManager: logger.error(f"创建数据库会话失败: {e}") return False + def add_user_points(self, user_id: int, points: int) -> bool: + """ + 为用户增加积分 + + Args: + user_id: 用户ID + points: 增加的积分数量(正数) + + Returns: + 操作是否成功 + """ + if points <= 0: + logger.warning(f"增加的积分必须是正数: {points}") + return False + + if not self.engine: + try: + self._init_db() + except Exception as e: + logger.error(f"重新连接数据库失败: {e}") + return False + + try: + # 创建会话 + session = self.Session() + + try: + # 查询用户 + user = session.query(User).filter(User.id == user_id).first() + + if not user: + logger.warning(f"用户ID {user_id} 不存在") + return False + + # 增加积分 + user.points += points + session.commit() + + logger.info(f"成功为用户 {user.mail} 增加 {points} 积分,当前积分: {user.points}") + return True + + except Exception as e: + session.rollback() + logger.error(f"增加用户积分失败: {e}") + return False + + finally: + session.close() + + except Exception as e: + logger.error(f"创建数据库会话失败: {e}") + return False + + def consume_user_points(self, user_id: int, points: int) -> bool: + """ + 用户消费积分 + + Args: + user_id: 用户ID + points: 消费的积分数量(正数) + + Returns: + 操作是否成功 + """ + if points <= 0: + logger.warning(f"消费的积分必须是正数: {points}") + return False + + if not self.engine: + try: + self._init_db() + except Exception as e: + logger.error(f"重新连接数据库失败: {e}") + return False + + try: + # 创建会话 + session = self.Session() + + try: + # 查询用户 + user = session.query(User).filter(User.id == user_id).first() + + if not user: + logger.warning(f"用户ID {user_id} 不存在") + return False + + # 检查积分是否足够 + if user.points < points: + logger.warning(f"用户 {user.mail} 积分不足,当前积分: {user.points},需要消费: {points}") + return False + + # 消费积分 + user.points -= points + session.commit() + + logger.info(f"成功从用户 {user.mail} 消费 {points} 积分,剩余积分: {user.points}") + return True + + except Exception as e: + session.rollback() + logger.error(f"消费用户积分失败: {e}") + return False + + finally: + session.close() + + except Exception as e: + logger.error(f"创建数据库会话失败: {e}") + return False + def get_agent_feeds(self, agent_name: Optional[str] = None, limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]: """ 获取AI Agent信息流 diff --git a/docker-compose.yml b/docker-compose.yml index 44ab693..499db9f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -29,7 +29,7 @@ services: cryptoai-api: build: . container_name: cryptoai-api - image: cryptoai-api:0.0.11 + image: cryptoai-api:0.0.12 restart: always ports: - "8000:8000"