crypto.ai/cryptoai/routes/user.py
2025-05-06 16:37:49 +08:00

354 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
用户API路由模块提供用户注册、登录和信息获取功能
"""
import logging
import hashlib
import secrets
from fastapi import APIRouter, HTTPException, status, Depends, Query
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from pydantic import BaseModel, EmailStr
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
import jwt
from jwt.exceptions import PyJWTError
from fastapi import Request
from cryptoai.utils.db_manager import get_db_manager
from cryptoai.utils.email_service import get_email_service
# 配置日志
logger = logging.getLogger("user_router")
# 创建路由
router = APIRouter()
# JWT配置
JWT_SECRET_KEY = "FX9Rf7YpvTRdXWmj0Osx8P9smSkUh6fW" # 实际应用中应该从环境变量或配置文件中获取
JWT_ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7天过期
# 请求模型
class UserRegister(BaseModel):
"""用户注册请求模型"""
mail: EmailStr
nickname: str
password: str
verification_code: str
class UserLogin(BaseModel):
"""用户登录请求模型"""
mail: EmailStr
password: str
class SendVerificationCodeRequest(BaseModel):
"""发送验证码请求模型"""
mail: EmailStr
# 响应模型
class UserResponse(BaseModel):
"""用户信息响应模型"""
id: int
mail: str
nickname: str
level: int
create_time: datetime
class TokenResponse(BaseModel):
"""令牌响应模型"""
access_token: str
token_type: str
expires_in: int
user_info: UserResponse
# 工具函数
def hash_password(password: str) -> str:
"""对密码进行哈希处理"""
return hashlib.sha256(password.encode()).hexdigest()
def create_access_token(data: Dict[str, Any], expires_delta: timedelta = None) -> str:
"""创建访问令牌"""
to_encode = data.copy()
if expires_delta:
expire = datetime.now() + expires_delta
else:
expire = datetime.now() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
return encoded_jwt
async def get_current_user(request: Request) -> Dict[str, Any]:
"""获取当前用户"""
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的身份验证凭据",
headers={"WWW-Authenticate": "Bearer"},
)
try:
token = request.headers.get("Authorization")
if not token:
raise credentials_exception
token = token.split(" ")[1]
print(f"token:{token}")
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
mail = payload.get("sub")
print(f"mail:{mail}")
if mail is None:
raise credentials_exception
except PyJWTError as e:
print(f"PyJWTError: {e}")
raise credentials_exception
db_manager = get_db_manager()
user = db_manager.get_user_by_mail(mail)
if user is None:
raise credentials_exception
return user
@router.post("/send-verification-code", response_model=Dict[str, Any])
async def send_verification_code(request: SendVerificationCodeRequest) -> Dict[str, Any]:
"""
发送邮箱验证码
Args:
request: 发送验证码请求
Returns:
发送结果
"""
try:
# 获取数据库管理器
db_manager = get_db_manager()
# 检查邮箱是否已被注册
user = db_manager.get_user_by_mail(request.mail)
if user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="该邮箱已被注册"
)
# 获取邮件服务
email_service = get_email_service()
# 发送验证码邮件
result = email_service.send_verification_email(request.mail)
if not result['success']:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=result['message']
)
return {
"status": "success",
"message": "验证码已发送到您的邮箱"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"发送验证码失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"发送验证码失败: {str(e)}"
)
@router.post("/register", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
async def register_user(user: UserRegister) -> Dict[str, Any]:
"""
注册新用户
Args:
user: 用户注册信息
Returns:
注册成功的状态信息
"""
try:
# 获取邮件服务
email_service = get_email_service()
# 验证验证码
if not email_service.verify_code(user.mail, user.verification_code):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="验证码错误或已过期"
)
# 获取数据库管理器
db_manager = get_db_manager()
# 对密码进行哈希处理
hashed_password = hash_password(user.password)
# 注册用户
success = db_manager.register_user(
mail=user.mail,
nickname=user.nickname,
password=hashed_password,
level=0 # 默认为普通用户
)
if not success:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="注册失败,邮箱可能已被使用"
)
return {
"status": "success",
"message": "用户注册成功"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"注册用户失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"注册用户失败: {str(e)}"
)
@router.post("/login", response_model=TokenResponse)
async def login(loginData: UserLogin) -> TokenResponse:
"""
用户登录
Args:
form_data: 表单数据,包含用户名(邮箱)和密码
Returns:
访问令牌和用户信息
"""
try:
# 获取数据库管理器
db_manager = get_db_manager()
# 获取用户信息
user = db_manager.get_user_by_mail(loginData.mail)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="邮箱或密码错误",
headers={"WWW-Authenticate": "Bearer"},
)
# 验证密码
hashed_password = hash_password(loginData.password)
# 查询用户的密码哈希
session = db_manager.Session()
try:
from cryptoai.utils.db_manager import User
db_user = session.query(User).filter(User.mail == loginData.mail).first()
if not db_user or db_user.password != hashed_password:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="邮箱或密码错误",
headers={"WWW-Authenticate": "Bearer"},
)
finally:
session.close()
# 创建访问令牌,不过期
access_token_expires = None
access_token = create_access_token(
data={"sub": user["mail"]}, expires_delta=access_token_expires
)
return TokenResponse(
access_token=access_token,
token_type="bearer",
expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
user_info=UserResponse(
id=user["id"],
mail=user["mail"],
nickname=user["nickname"],
level=user["level"],
create_time=user["create_time"]
)
)
except HTTPException:
raise
except Exception as e:
logger.error(f"用户登录失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"用户登录失败: {str(e)}"
)
@router.get("/me", response_model=UserResponse)
async def get_user_info(current_user: Dict[str, Any] = Depends(get_current_user)) -> UserResponse:
"""
获取当前登录用户信息
Args:
current_user: 当前用户信息,由依赖项提供
Returns:
用户信息
"""
return UserResponse(
id=current_user["id"],
mail=current_user["mail"],
nickname=current_user["nickname"],
level=current_user["level"],
create_time=current_user["create_time"]
)
@router.put("/level/{user_id}", response_model=Dict[str, Any])
async def update_user_level(
user_id: int,
level: int = Query(..., description="用户级别(0=普通用户,1=VIP,2=SVIP)"),
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""
更新用户级别(需要管理员权限)
Args:
user_id: 用户ID
level: 新的用户级别
current_user: 当前用户信息,由依赖项提供
Returns:
更新成功的状态信息
"""
# 简单的权限检查(实际应用中应该有更完善的权限管理)
if current_user["level"] < 2: # 假设SVIP用户有管理权限
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="没有足够的权限执行此操作"
)
try:
# 获取数据库管理器
db_manager = get_db_manager()
# 更新用户级别
success = db_manager.update_user_level(user_id, level)
if not success:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"用户ID {user_id} 不存在"
)
return {
"status": "success",
"message": f"成功更新用户级别为 {level}"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"更新用户级别失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"更新用户级别失败: {str(e)}"
)