354 lines
10 KiB
Python
354 lines
10 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
用户API路由模块,提供用户注册、登录和信息获取功能
|
||
"""
|
||
|
||
import logging
|
||
import hashlib
|
||
import secrets
|
||
from fastapi import APIRouter, HTTPException, status, Depends, Query
|
||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||
from pydantic import BaseModel, EmailStr
|
||
from typing import Dict, Any, List, Optional
|
||
from datetime import datetime, timedelta
|
||
import jwt
|
||
from jwt.exceptions import PyJWTError
|
||
from fastapi import Request
|
||
|
||
from cryptoai.utils.db_manager import get_db_manager
|
||
from cryptoai.utils.email_service import get_email_service
|
||
|
||
# 配置日志
|
||
logger = logging.getLogger("user_router")
|
||
|
||
# 创建路由
|
||
router = APIRouter()
|
||
|
||
# JWT配置
|
||
JWT_SECRET_KEY = "FX9Rf7YpvTRdXWmj0Osx8P9smSkUh6fW" # 实际应用中应该从环境变量或配置文件中获取
|
||
JWT_ALGORITHM = "HS256"
|
||
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7天过期
|
||
|
||
# 请求模型
|
||
class UserRegister(BaseModel):
|
||
"""用户注册请求模型"""
|
||
mail: EmailStr
|
||
nickname: str
|
||
password: str
|
||
verification_code: str
|
||
|
||
class UserLogin(BaseModel):
|
||
"""用户登录请求模型"""
|
||
mail: EmailStr
|
||
password: str
|
||
|
||
class SendVerificationCodeRequest(BaseModel):
|
||
"""发送验证码请求模型"""
|
||
mail: EmailStr
|
||
|
||
# 响应模型
|
||
class UserResponse(BaseModel):
|
||
"""用户信息响应模型"""
|
||
id: int
|
||
mail: str
|
||
nickname: str
|
||
level: int
|
||
create_time: datetime
|
||
|
||
class TokenResponse(BaseModel):
|
||
"""令牌响应模型"""
|
||
access_token: str
|
||
token_type: str
|
||
expires_in: int
|
||
user_info: UserResponse
|
||
|
||
# 工具函数
|
||
def hash_password(password: str) -> str:
|
||
"""对密码进行哈希处理"""
|
||
return hashlib.sha256(password.encode()).hexdigest()
|
||
|
||
def create_access_token(data: Dict[str, Any], expires_delta: timedelta = None) -> str:
|
||
"""创建访问令牌"""
|
||
to_encode = data.copy()
|
||
if expires_delta:
|
||
expire = datetime.now() + expires_delta
|
||
else:
|
||
expire = datetime.now() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||
to_encode.update({"exp": expire})
|
||
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||
return encoded_jwt
|
||
|
||
async def get_current_user(request: Request) -> Dict[str, Any]:
|
||
"""获取当前用户"""
|
||
credentials_exception = HTTPException(
|
||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||
detail="无效的身份验证凭据",
|
||
headers={"WWW-Authenticate": "Bearer"},
|
||
)
|
||
try:
|
||
token = request.headers.get("Authorization")
|
||
if not token:
|
||
raise credentials_exception
|
||
token = token.split(" ")[1]
|
||
print(f"token:{token}")
|
||
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||
mail = payload.get("sub")
|
||
print(f"mail:{mail}")
|
||
if mail is None:
|
||
raise credentials_exception
|
||
except PyJWTError as e:
|
||
print(f"PyJWTError: {e}")
|
||
raise credentials_exception
|
||
|
||
db_manager = get_db_manager()
|
||
user = db_manager.get_user_by_mail(mail)
|
||
if user is None:
|
||
raise credentials_exception
|
||
return user
|
||
|
||
@router.post("/send-verification-code", response_model=Dict[str, Any])
|
||
async def send_verification_code(request: SendVerificationCodeRequest) -> Dict[str, Any]:
|
||
"""
|
||
发送邮箱验证码
|
||
|
||
Args:
|
||
request: 发送验证码请求
|
||
|
||
Returns:
|
||
发送结果
|
||
"""
|
||
try:
|
||
# 获取数据库管理器
|
||
db_manager = get_db_manager()
|
||
|
||
# 检查邮箱是否已被注册
|
||
user = db_manager.get_user_by_mail(request.mail)
|
||
if user:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="该邮箱已被注册"
|
||
)
|
||
|
||
# 获取邮件服务
|
||
email_service = get_email_service()
|
||
|
||
# 发送验证码邮件
|
||
result = email_service.send_verification_email(request.mail)
|
||
|
||
if not result['success']:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail=result['message']
|
||
)
|
||
|
||
return {
|
||
"status": "success",
|
||
"message": "验证码已发送到您的邮箱"
|
||
}
|
||
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"发送验证码失败: {str(e)}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail=f"发送验证码失败: {str(e)}"
|
||
)
|
||
|
||
@router.post("/register", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
|
||
async def register_user(user: UserRegister) -> Dict[str, Any]:
|
||
"""
|
||
注册新用户
|
||
|
||
Args:
|
||
user: 用户注册信息
|
||
|
||
Returns:
|
||
注册成功的状态信息
|
||
"""
|
||
try:
|
||
# 获取邮件服务
|
||
email_service = get_email_service()
|
||
|
||
# 验证验证码
|
||
if not email_service.verify_code(user.mail, user.verification_code):
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="验证码错误或已过期"
|
||
)
|
||
|
||
# 获取数据库管理器
|
||
db_manager = get_db_manager()
|
||
|
||
# 对密码进行哈希处理
|
||
hashed_password = hash_password(user.password)
|
||
|
||
# 注册用户
|
||
success = db_manager.register_user(
|
||
mail=user.mail,
|
||
nickname=user.nickname,
|
||
password=hashed_password,
|
||
level=0 # 默认为普通用户
|
||
)
|
||
|
||
if not success:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="注册失败,邮箱可能已被使用"
|
||
)
|
||
|
||
return {
|
||
"status": "success",
|
||
"message": "用户注册成功"
|
||
}
|
||
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"注册用户失败: {str(e)}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail=f"注册用户失败: {str(e)}"
|
||
)
|
||
|
||
@router.post("/login", response_model=TokenResponse)
|
||
async def login(loginData: UserLogin) -> TokenResponse:
|
||
"""
|
||
用户登录
|
||
|
||
Args:
|
||
form_data: 表单数据,包含用户名(邮箱)和密码
|
||
|
||
Returns:
|
||
访问令牌和用户信息
|
||
"""
|
||
try:
|
||
# 获取数据库管理器
|
||
db_manager = get_db_manager()
|
||
|
||
# 获取用户信息
|
||
user = db_manager.get_user_by_mail(loginData.mail)
|
||
|
||
if not user:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||
detail="邮箱或密码错误",
|
||
headers={"WWW-Authenticate": "Bearer"},
|
||
)
|
||
|
||
# 验证密码
|
||
hashed_password = hash_password(loginData.password)
|
||
|
||
# 查询用户的密码哈希
|
||
session = db_manager.Session()
|
||
try:
|
||
from cryptoai.utils.db_manager import User
|
||
db_user = session.query(User).filter(User.mail == loginData.mail).first()
|
||
if not db_user or db_user.password != hashed_password:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||
detail="邮箱或密码错误",
|
||
headers={"WWW-Authenticate": "Bearer"},
|
||
)
|
||
finally:
|
||
session.close()
|
||
|
||
# 创建访问令牌,不过期
|
||
access_token_expires = None
|
||
access_token = create_access_token(
|
||
data={"sub": user["mail"]}, expires_delta=access_token_expires
|
||
)
|
||
|
||
return TokenResponse(
|
||
access_token=access_token,
|
||
token_type="bearer",
|
||
expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||
user_info=UserResponse(
|
||
id=user["id"],
|
||
mail=user["mail"],
|
||
nickname=user["nickname"],
|
||
level=user["level"],
|
||
create_time=user["create_time"]
|
||
)
|
||
)
|
||
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"用户登录失败: {str(e)}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail=f"用户登录失败: {str(e)}"
|
||
)
|
||
|
||
@router.get("/me", response_model=UserResponse)
|
||
async def get_user_info(current_user: Dict[str, Any] = Depends(get_current_user)) -> UserResponse:
|
||
"""
|
||
获取当前登录用户信息
|
||
|
||
Args:
|
||
current_user: 当前用户信息,由依赖项提供
|
||
|
||
Returns:
|
||
用户信息
|
||
"""
|
||
return UserResponse(
|
||
id=current_user["id"],
|
||
mail=current_user["mail"],
|
||
nickname=current_user["nickname"],
|
||
level=current_user["level"],
|
||
create_time=current_user["create_time"]
|
||
)
|
||
|
||
@router.put("/level/{user_id}", response_model=Dict[str, Any])
|
||
async def update_user_level(
|
||
user_id: int,
|
||
level: int = Query(..., description="用户级别(0=普通用户,1=VIP,2=SVIP)"),
|
||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
更新用户级别(需要管理员权限)
|
||
|
||
Args:
|
||
user_id: 用户ID
|
||
level: 新的用户级别
|
||
current_user: 当前用户信息,由依赖项提供
|
||
|
||
Returns:
|
||
更新成功的状态信息
|
||
"""
|
||
# 简单的权限检查(实际应用中应该有更完善的权限管理)
|
||
if current_user["level"] < 2: # 假设SVIP用户有管理权限
|
||
raise HTTPException(
|
||
status_code=status.HTTP_403_FORBIDDEN,
|
||
detail="没有足够的权限执行此操作"
|
||
)
|
||
|
||
try:
|
||
# 获取数据库管理器
|
||
db_manager = get_db_manager()
|
||
|
||
# 更新用户级别
|
||
success = db_manager.update_user_level(user_id, level)
|
||
|
||
if not success:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_404_NOT_FOUND,
|
||
detail=f"用户ID {user_id} 不存在"
|
||
)
|
||
|
||
return {
|
||
"status": "success",
|
||
"message": f"成功更新用户级别为 {level}"
|
||
}
|
||
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"更新用户级别失败: {str(e)}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail=f"更新用户级别失败: {str(e)}"
|
||
) |