crypto.ai/cryptoai/routes/user.py
2025-05-24 15:33:38 +08:00

512 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
用户API路由模块提供用户注册、登录和信息获取功能
"""
import logging
import hashlib
import secrets
from fastapi import APIRouter, HTTPException, status, Depends, Query
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel, EmailStr
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
import jwt
from jwt.exceptions import PyJWTError
from fastapi import Request
from cryptoai.utils.db_manager import get_db_manager
from cryptoai.utils.email_service import get_email_service
# 配置日志
logger = logging.getLogger("user_router")
# 创建路由
router = APIRouter()
# JWT配置
JWT_SECRET_KEY = "FX9Rf7YpvTRdXWmj0Osx8P9smSkUh6fW" # 实际应用中应该从环境变量或配置文件中获取
JWT_ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 0 # 用户登录后不过期
# 请求模型
class UserRegister(BaseModel):
"""用户注册请求模型"""
mail: EmailStr
nickname: str
password: str
verification_code: str
class UserLogin(BaseModel):
"""用户登录请求模型"""
mail: EmailStr
password: str
class SendVerificationCodeRequest(BaseModel):
"""发送验证码请求模型"""
mail: EmailStr
class ResetPasswordRequest(BaseModel):
"""重置密码请求模型"""
mail: EmailStr
verification_code: str
new_password: str
# 响应模型
class UserResponse(BaseModel):
"""用户信息响应模型"""
id: int
mail: str
nickname: str
level: int
points: int
create_time: datetime
class TokenResponse(BaseModel):
"""令牌响应模型"""
access_token: str
token_type: str
expires_in: int
user_info: UserResponse
# 工具函数
def hash_password(password: str) -> str:
"""对密码进行哈希处理"""
return hashlib.sha256(password.encode()).hexdigest()
def create_access_token(data: Dict[str, Any], expires_delta: timedelta = None) -> str:
"""创建访问令牌"""
to_encode = data.copy()
if expires_delta:
expire = datetime.now() + expires_delta
else:
expire = datetime.now() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
return encoded_jwt
async def get_current_user(request: Request) -> Dict[str, Any]:
"""获取当前用户"""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的身份验证凭据",
headers={"WWW-Authenticate": "Bearer"},
)
try:
token = request.headers.get("Authorization")
if not token:
raise credentials_exception
token = token.split(" ")[1]
print(f"token:{token}")
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
mail = payload.get("sub")
print(f"mail:{mail}")
if mail is None:
raise credentials_exception
except PyJWTError as e:
print(f"PyJWTError: {e}")
raise credentials_exception
db_manager = get_db_manager()
user = db_manager.user_manager.get_user_by_mail(mail)
if user is None:
raise credentials_exception
return user
@router.post("/send-verification-code", response_model=Dict[str, Any])
async def send_verification_code(request: SendVerificationCodeRequest) -> Dict[str, Any]:
"""
发送邮箱验证码
Args:
request: 发送验证码请求
Returns:
发送结果
"""
try:
# 获取邮件服务
email_service = get_email_service()
# 发送验证码邮件
result = email_service.send_verification_email(request.mail)
if not result['success']:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=result['message']
)
return {
"status": "success",
"message": "验证码已发送到您的邮箱"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"发送验证码失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"发送验证码失败: {str(e)}"
)
@router.put("/reset_password", response_model=Dict[str, Any])
async def reset_password(request: ResetPasswordRequest) -> Dict[str, Any]:
"""
修改密码
"""
try:
# 获取数据库管理器
db_manager = get_db_manager()
# 验证验证码
email_service = get_email_service()
if not email_service.verify_code(request.mail, request.verification_code):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="验证码错误或已过期"
)
# 更新密码
hashed_password = hash_password(request.new_password)
success = db_manager.user_manager.update_password(request.mail, hashed_password)
if not success:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="密码修改失败"
)
return {
"status": "success",
"message": "密码修改成功"
}
except Exception as e:
logger.error(f"修改密码失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"修改密码失败: {str(e)}"
)
@router.post("/register", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
async def register_user(user: UserRegister) -> Dict[str, Any]:
"""
注册新用户
Args:
user: 用户注册信息
Returns:
注册成功的状态信息
"""
try:
# 获取邮件服务
email_service = get_email_service()
# 验证验证码
if not email_service.verify_code(user.mail, user.verification_code):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="验证码错误或已过期"
)
# 获取数据库管理器
db_manager = get_db_manager()
# 对密码进行哈希处理
hashed_password = hash_password(user.password)
# 注册用户
success = db_manager.user_manager.register_user(
mail=user.mail,
nickname=user.nickname,
password=hashed_password,
level=0, # 默认为普通用户
points=100 # 默认初始积分为100
)
if not success:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="注册失败,邮箱可能已被使用"
)
return {
"status": "success",
"message": "用户注册成功"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"注册用户失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"注册用户失败: {str(e)}"
)
@router.post("/login", response_model=TokenResponse)
async def login(loginData: UserLogin) -> TokenResponse:
"""
用户登录
Args:
form_data: 表单数据,包含用户名(邮箱)和密码
Returns:
访问令牌和用户信息
"""
try:
# 获取数据库管理器
db_manager = get_db_manager()
# 获取用户信息
user = db_manager.user_manager.get_user_by_mail(loginData.mail)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="邮箱或密码错误",
headers={"WWW-Authenticate": "Bearer"},
)
# 验证密码
hashed_password = hash_password(loginData.password)
# 查询用户的密码哈希
session = db_manager.Session()
try:
user = db_manager.user_manager.get_user_by_mail_and_password(loginData.mail, hashed_password)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="邮箱或密码错误",
headers={"WWW-Authenticate": "Bearer"},
)
finally:
session.close()
# 创建访问令牌,不过期
access_token_expires = None
access_token = create_access_token(
data={"sub": user["mail"]}, expires_delta=access_token_expires
)
return TokenResponse(
access_token=access_token,
token_type="bearer",
expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
user_info=UserResponse(
id=user["id"],
mail=user["mail"],
nickname=user["nickname"],
level=user["level"],
points=user["points"],
create_time=user["create_time"]
)
)
except HTTPException:
raise
except Exception as e:
logger.error(f"用户登录失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"用户登录失败: {str(e)}"
)
@router.get("/me", response_model=UserResponse)
async def get_user_info(current_user: Dict[str, Any] = Depends(get_current_user)) -> UserResponse:
"""
获取当前登录用户信息
Args:
current_user: 当前用户信息,由依赖项提供
Returns:
用户信息
"""
return UserResponse(
id=current_user["id"],
mail=current_user["mail"],
nickname=current_user["nickname"],
level=current_user["level"],
points=current_user["points"],
create_time=current_user["create_time"]
)
@router.put("/level/{user_id}", response_model=Dict[str, Any])
async def update_user_level(
user_id: int,
level: int = Query(..., description="用户级别(0=普通用户,1=VIP,2=SVIP)"),
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""
更新用户级别(需要管理员权限)
Args:
user_id: 用户ID
level: 新的用户级别
current_user: 当前用户信息
Returns:
更新结果
"""
# 检查权限只有SVIP用户才能更新用户级别
if current_user.get("level", 0) < 2:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="没有足够的权限执行此操作"
)
# 获取数据库管理器
db_manager = get_db_manager()
# 更新用户级别
success = db_manager.user_manager.update_user_level(user_id, level)
if not success:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="更新用户级别失败"
)
return {
"status": "success",
"message": f"成功更新用户 {user_id} 的级别为 {level}"
}
@router.get("/points/{user_id}", response_model=Dict[str, Any])
async def get_user_points(
user_id: int,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""
获取用户积分
Args:
user_id: 用户ID
current_user: 当前用户信息
Returns:
用户积分信息
"""
# 只能查看自己的积分或者SVIP用户可以查看所有人的积分
if current_user.get("id") != user_id and current_user.get("level", 0) < 2:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="没有权限查看其他用户的积分"
)
# 获取数据库管理器
db_manager = get_db_manager()
# 获取用户信息
user = db_manager.user_manager.get_user_by_id(user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"用户ID {user_id} 不存在"
)
return {
"user_id": user_id,
"points": user.get("points", 0),
"nickname": user.get("nickname", ""),
"level": user.get("level", 0)
}
@router.post("/points/add/{user_id}", response_model=Dict[str, Any])
async def add_user_points(
user_id: int,
points: int = Query(..., gt=0, description="增加的积分数量必须大于0"),
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""
为用户增加积分(需要管理员权限)
Args:
user_id: 用户ID
points: 增加的积分数量
current_user: 当前用户信息
Returns:
操作结果
"""
# 检查权限只有SVIP用户才能添加积分
if current_user.get("level", 0) < 2:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="没有足够的权限执行此操作"
)
# 获取数据库管理器
db_manager = get_db_manager()
# 添加积分
success = db_manager.user_manager.add_user_points(user_id, points)
if not success:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="添加积分失败"
)
# 获取更新后的用户信息
user = db_manager.user_manager.get_user_by_id(user_id)
return {
"status": "success",
"message": f"成功为用户 {user_id} 增加 {points} 积分",
"current_points": user.get("points", 0)
}
@router.post("/points/consume/{user_id}", response_model=Dict[str, Any])
async def consume_user_points(
user_id: int,
points: int = Query(..., gt=0, description="消费的积分数量必须大于0"),
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""
用户消费积分
Args:
user_id: 用户ID
points: 消费的积分数量
current_user: 当前用户信息
Returns:
操作结果
"""
# 只能消费自己的积分或者SVIP用户可以操作所有人的积分
if current_user.get("id") != user_id and current_user.get("level", 0) < 2:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="没有权限消费其他用户的积分"
)
# 获取数据库管理器
db_manager = get_db_manager()
# 消费积分
success = db_manager.user_manager.consume_user_points(user_id, points)
if not success:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="积分消费失败,可能是积分不足"
)
# 获取更新后的用户信息
user = db_manager.user_manager.get_user_by_id(user_id)
return {
"status": "success",
"message": f"成功消费 {points} 积分",
"remaining_points": user.get("points", 0)
}