#!/usr/bin/env python # -*- coding: utf-8 -*- """ 用户API路由模块,提供用户注册、登录和信息获取功能 """ import logging import hashlib import secrets from fastapi import APIRouter, HTTPException, status, Depends, Query from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from pydantic import BaseModel, EmailStr from typing import Dict, Any, List, Optional from datetime import datetime, timedelta import jwt from jwt.exceptions import PyJWTError from fastapi import Request from sqlalchemy.orm import Session from cryptoai.utils.db_manager import get_db from cryptoai.utils.email_service import get_email_service from cryptoai.models.user import UserManager from cryptoai.models.user_subscription import UserSubscriptionManager # 配置日志 logger = logging.getLogger("user_router") # 创建路由 router = APIRouter() # JWT配置 JWT_SECRET_KEY = "FX9Rf7YpvTRdXWmj0Osx8P9smSkUh6fW" # 实际应用中应该从环境变量或配置文件中获取 JWT_ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 180 * 24 * 60 * 60 # 180天 # 请求模型 class UserRegister(BaseModel): """用户注册请求模型""" mail: EmailStr nickname: str password: str verification_code: str class UserLogin(BaseModel): """用户登录请求模型""" mail: EmailStr password: str class SendVerificationCodeRequest(BaseModel): """发送验证码请求模型""" mail: EmailStr class ResetPasswordRequest(BaseModel): """重置密码请求模型""" mail: EmailStr verification_code: str new_password: str # 响应模型 class UserResponse(BaseModel): """用户信息响应模型""" id: int mail: str nickname: str level: int points: int create_time: datetime is_subscribed: bool member_name: str expire_time: datetime = None class TokenResponse(BaseModel): """令牌响应模型""" access_token: str token_type: str expires_in: int user_info: UserResponse # 工具函数 def hash_password(password: str) -> str: """对密码进行哈希处理""" return hashlib.sha256(password.encode()).hexdigest() def create_access_token(data: Dict[str, Any], expires_delta: timedelta = None) -> str: """创建访问令牌""" to_encode = data.copy() if expires_delta: expire = datetime.now() + expires_delta else: expire = datetime.now() + timedelta(days=180) to_encode.update({"exp": expire}) encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM) return encoded_jwt async def get_current_user(request: Request, session: Session = Depends(get_db)) -> Dict[str, Any]: """获取当前用户""" credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="无效的身份验证凭据", headers={"WWW-Authenticate": "Bearer"}, ) try: token = request.headers.get("Authorization") if not token: raise credentials_exception token = token.split(" ")[1] print(f"token:{token}") payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM]) mail = payload.get("sub") print(f"mail:{mail}") if mail is None: raise credentials_exception except PyJWTError as e: print(f"PyJWTError: {e}") raise credentials_exception manager = UserManager(session) user = manager.get_user_by_mail(mail) if user is None: raise credentials_exception return user @router.post("/send-verification-code", response_model=Dict[str, Any]) async def send_verification_code(request: SendVerificationCodeRequest) -> Dict[str, Any]: """ 发送邮箱验证码 Args: request: 发送验证码请求 Returns: 发送结果 """ try: # 获取邮件服务 email_service = get_email_service() # 发送验证码邮件 result = email_service.send_verification_email(request.mail) if not result['success']: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=result['message'] ) return { "status": "success", "message": "验证码已发送到您的邮箱" } except HTTPException: raise except Exception as e: logger.error(f"发送验证码失败: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"发送验证码失败: {str(e)}" ) @router.put("/reset_password", response_model=Dict[str, Any]) async def reset_password(request: ResetPasswordRequest, session: Session = Depends(get_db)) -> Dict[str, Any]: """ 修改密码 """ try: # 获取数据库管理器 # 验证验证码 email_service = get_email_service() if not email_service.verify_code(request.mail, request.verification_code): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="验证码错误或已过期" ) # 更新密码 hashed_password = hash_password(request.new_password) manager = UserManager(session) success = manager.update_password(request.mail, hashed_password) if not success: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="密码修改失败" ) return { "status": "success", "message": "密码修改成功" } except Exception as e: logger.error(f"修改密码失败: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"修改密码失败: {str(e)}" ) @router.post("/register", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED) async def register_user(user: UserRegister, session: Session = Depends(get_db)) -> Dict[str, Any]: """ 注册新用户 Args: user: 用户注册信息 Returns: 注册成功的状态信息 """ try: # 获取邮件服务 email_service = get_email_service() # 验证验证码 if not email_service.verify_code(user.mail, user.verification_code): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="验证码错误或已过期" ) # 获取数据库管理器 manager = UserManager(session) # 对密码进行哈希处理 hashed_password = hash_password(user.password) # 注册用户 success = manager.register_user( mail=user.mail, nickname=user.nickname, password=hashed_password, level=0, # 默认为普通用户 points=1 # 默认初始积分为100 ) if not success: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="注册失败,邮箱可能已被使用" ) return { "status": "success", "message": "用户注册成功" } except HTTPException: raise except Exception as e: logger.error(f"注册用户失败: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"注册用户失败: {str(e)}" ) @router.post("/login", response_model=TokenResponse) async def login(loginData: UserLogin, session: Session = Depends(get_db)) -> TokenResponse: """ 用户登录 Args: form_data: 表单数据,包含用户名(邮箱)和密码 Returns: 访问令牌和用户信息 """ try: # 获取数据库管理器 manager = UserManager(session) # 获取用户信息 user = manager.get_user_by_mail(loginData.mail) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="邮箱或密码错误", headers={"WWW-Authenticate": "Bearer"}, ) # 验证密码 hashed_password = hash_password(loginData.password) # 查询用户的密码哈希 user = manager.get_user_by_mail_and_password(loginData.mail, hashed_password) if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="邮箱或密码错误", headers={"WWW-Authenticate": "Bearer"}, ) # 创建访问令牌,不过期 access_token = create_access_token(data={"sub": user["mail"]}) user_subscription_manager = UserSubscriptionManager(session) user_subscription = user_subscription_manager.get_subscription_by_user_id(user["id"]) is_subscribed = False expire_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") if user_subscription and user_subscription["expire_time"] > datetime.now(): member_name = "SVIP会员" if user_subscription["time_type"] == 2 else "VIP会员" is_subscribed = True expire_time = user_subscription["expire_time"].strftime("%Y-%m-%d %H:%M:%S") else: member_name = "普通会员" return TokenResponse( access_token=access_token, token_type="bearer", expires_in=ACCESS_TOKEN_EXPIRE_MINUTES, user_info=UserResponse( id=user["id"], mail=user["mail"], nickname=user["nickname"], level=user["level"], points=user["points"], create_time=user["create_time"], is_subscribed=is_subscribed, member_name=member_name, expire_time=expire_time ) ) except HTTPException: raise except Exception as e: logger.error(f"用户登录失败: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"用户登录失败: {str(e)}" ) @router.get("/me", response_model=UserResponse) async def get_user_info(current_user: Dict[str, Any] = Depends(get_current_user), session: Session = Depends(get_db)) -> UserResponse: """ 获取当前登录用户信息 Args: current_user: 当前用户信息,由依赖项提供 Returns: 用户信息 """ user_subscription_manager = UserSubscriptionManager(session) user_subscription = user_subscription_manager.get_subscription_by_user_id(current_user["id"]) is_subscribed = False expire_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") if user_subscription and user_subscription["expire_time"] > datetime.now(): member_name = "SVIP会员" if user_subscription["time_type"] == 2 else "VIP会员" is_subscribed = True expire_time = user_subscription["expire_time"].strftime("%Y-%m-%d %H:%M:%S") else: member_name = "普通会员" user = UserResponse( id=current_user["id"], mail=current_user["mail"], nickname=current_user["nickname"], level=current_user["level"], points=current_user["points"], create_time=current_user["create_time"], is_subscribed=is_subscribed, member_name=member_name ) if is_subscribed and expire_time: user.expire_time = expire_time return user @router.put("/level/{user_id}", response_model=Dict[str, Any]) async def update_user_level( user_id: int, level: int = Query(..., description="用户级别(0=普通用户,1=VIP,2=SVIP)"), current_user: Dict[str, Any] = Depends(get_current_user), session: Session = Depends(get_db) ) -> Dict[str, Any]: """ 更新用户级别(需要管理员权限) Args: user_id: 用户ID level: 新的用户级别 current_user: 当前用户信息 Returns: 更新结果 """ # 检查权限(只有SVIP用户才能更新用户级别) if current_user.get("level", 0) < 2: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="没有足够的权限执行此操作" ) # 获取数据库管理器 manager = UserManager(session) # 更新用户级别 success = manager.update_user_level(user_id, level) if not success: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="更新用户级别失败" ) return { "status": "success", "message": f"成功更新用户 {user_id} 的级别为 {level}" } @router.get("/points/{user_id}", response_model=Dict[str, Any]) async def get_user_points( user_id: int, current_user: Dict[str, Any] = Depends(get_current_user), session: Session = Depends(get_db) ) -> Dict[str, Any]: """ 获取用户积分 Args: user_id: 用户ID current_user: 当前用户信息 Returns: 用户积分信息 """ # 只能查看自己的积分,或者SVIP用户可以查看所有人的积分 if current_user.get("id") != user_id and current_user.get("level", 0) < 2: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="没有权限查看其他用户的积分" ) # 获取数据库管理器 manager = UserManager(session) # 获取用户信息 user = manager.get_user_by_id(user_id) if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail=f"用户ID {user_id} 不存在" ) return { "user_id": user_id, "points": user.get("points", 0), "nickname": user.get("nickname", ""), "level": user.get("level", 0) } @router.post("/points/add/{user_id}", response_model=Dict[str, Any]) async def add_user_points( user_id: int, points: int = Query(..., gt=0, description="增加的积分数量(必须大于0)"), current_user: Dict[str, Any] = Depends(get_current_user), session: Session = Depends(get_db) ) -> Dict[str, Any]: """ 为用户增加积分(需要管理员权限) Args: user_id: 用户ID points: 增加的积分数量 current_user: 当前用户信息 Returns: 操作结果 """ # 检查权限(只有SVIP用户才能添加积分) if current_user.get("level", 0) < 2: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="没有足够的权限执行此操作" ) # 获取数据库管理器 manager = UserManager(session) # 添加积分 success = manager.add_user_points(user_id, points) if not success: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="添加积分失败" ) # 获取更新后的用户信息 user = manager.get_user_by_id(user_id) return { "status": "success", "message": f"成功为用户 {user_id} 增加 {points} 积分", "current_points": user.get("points", 0) } @router.post("/points/consume/{user_id}", response_model=Dict[str, Any]) async def consume_user_points( user_id: int, points: int = Query(..., gt=0, description="消费的积分数量(必须大于0)"), current_user: Dict[str, Any] = Depends(get_current_user), session: Session = Depends(get_db) ) -> Dict[str, Any]: """ 用户消费积分 Args: user_id: 用户ID points: 消费的积分数量 current_user: 当前用户信息 Returns: 操作结果 """ # 只能消费自己的积分,或者SVIP用户可以操作所有人的积分 if current_user.get("id") != user_id and current_user.get("level", 0) < 2: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="没有权限消费其他用户的积分" ) # 获取数据库管理器 manager = UserManager(session) # 消费积分 success = manager.consume_user_points(user_id, points) if not success: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="积分消费失败,可能是积分不足" ) # 获取更新后的用户信息 user = manager.get_user_by_id(user_id) return { "status": "success", "message": f"成功消费 {points} 积分", "remaining_points": user.get("points", 0) }