554 lines
17 KiB
Python
554 lines
17 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
用户API路由模块,提供用户注册、登录和信息获取功能
|
||
"""
|
||
|
||
import logging
|
||
import hashlib
|
||
import secrets
|
||
from fastapi import APIRouter, HTTPException, status, Depends, Query
|
||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||
from pydantic import BaseModel, EmailStr
|
||
from typing import Dict, Any, List, Optional
|
||
from datetime import datetime, timedelta
|
||
import jwt
|
||
from jwt.exceptions import PyJWTError
|
||
from fastapi import Request
|
||
from sqlalchemy.orm import Session
|
||
from cryptoai.utils.db_manager import get_db
|
||
from cryptoai.utils.email_service import get_email_service
|
||
from cryptoai.models.user import UserManager
|
||
from cryptoai.models.user_subscription import UserSubscriptionManager
|
||
|
||
# 配置日志
|
||
logger = logging.getLogger("user_router")
|
||
|
||
# 创建路由
|
||
router = APIRouter()
|
||
|
||
# JWT配置
|
||
JWT_SECRET_KEY = "FX9Rf7YpvTRdXWmj0Osx8P9smSkUh6fW" # 实际应用中应该从环境变量或配置文件中获取
|
||
JWT_ALGORITHM = "HS256"
|
||
ACCESS_TOKEN_EXPIRE_MINUTES = 180 * 24 * 60 * 60 # 180天
|
||
|
||
# 请求模型
|
||
class UserRegister(BaseModel):
|
||
"""用户注册请求模型"""
|
||
mail: EmailStr
|
||
nickname: str
|
||
password: str
|
||
verification_code: str
|
||
|
||
class UserLogin(BaseModel):
|
||
"""用户登录请求模型"""
|
||
mail: EmailStr
|
||
password: str
|
||
|
||
class SendVerificationCodeRequest(BaseModel):
|
||
"""发送验证码请求模型"""
|
||
mail: EmailStr
|
||
|
||
class ResetPasswordRequest(BaseModel):
|
||
"""重置密码请求模型"""
|
||
mail: EmailStr
|
||
verification_code: str
|
||
new_password: str
|
||
|
||
# 响应模型
|
||
class UserResponse(BaseModel):
|
||
"""用户信息响应模型"""
|
||
id: int
|
||
mail: str
|
||
nickname: str
|
||
level: int
|
||
points: int
|
||
create_time: datetime
|
||
is_subscribed: bool
|
||
member_name: str
|
||
expire_time: datetime = None
|
||
|
||
class TokenResponse(BaseModel):
|
||
"""令牌响应模型"""
|
||
access_token: str
|
||
token_type: str
|
||
expires_in: int
|
||
user_info: UserResponse
|
||
|
||
# 工具函数
|
||
def hash_password(password: str) -> str:
|
||
"""对密码进行哈希处理"""
|
||
return hashlib.sha256(password.encode()).hexdigest()
|
||
|
||
def create_access_token(data: Dict[str, Any], expires_delta: timedelta = None) -> str:
|
||
"""创建访问令牌"""
|
||
to_encode = data.copy()
|
||
if expires_delta:
|
||
expire = datetime.now() + expires_delta
|
||
else:
|
||
expire = datetime.now() + timedelta(days=180)
|
||
to_encode.update({"exp": expire})
|
||
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||
return encoded_jwt
|
||
|
||
async def get_current_user(request: Request, session: Session = Depends(get_db)) -> Dict[str, Any]:
|
||
"""获取当前用户"""
|
||
credentials_exception = HTTPException(
|
||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||
detail="无效的身份验证凭据",
|
||
headers={"WWW-Authenticate": "Bearer"},
|
||
)
|
||
try:
|
||
token = request.headers.get("Authorization")
|
||
if not token:
|
||
raise credentials_exception
|
||
token = token.split(" ")[1]
|
||
print(f"token:{token}")
|
||
payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=[JWT_ALGORITHM])
|
||
mail = payload.get("sub")
|
||
print(f"mail:{mail}")
|
||
if mail is None:
|
||
raise credentials_exception
|
||
except PyJWTError as e:
|
||
print(f"PyJWTError: {e}")
|
||
raise credentials_exception
|
||
|
||
manager = UserManager(session)
|
||
user = manager.get_user_by_mail(mail)
|
||
if user is None:
|
||
raise credentials_exception
|
||
return user
|
||
|
||
@router.post("/send-verification-code", response_model=Dict[str, Any])
|
||
async def send_verification_code(request: SendVerificationCodeRequest) -> Dict[str, Any]:
|
||
"""
|
||
发送邮箱验证码
|
||
|
||
Args:
|
||
request: 发送验证码请求
|
||
|
||
Returns:
|
||
发送结果
|
||
"""
|
||
try:
|
||
# 获取邮件服务
|
||
email_service = get_email_service()
|
||
|
||
# 发送验证码邮件
|
||
result = email_service.send_verification_email(request.mail)
|
||
|
||
if not result['success']:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail=result['message']
|
||
)
|
||
|
||
return {
|
||
"status": "success",
|
||
"message": "验证码已发送到您的邮箱"
|
||
}
|
||
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"发送验证码失败: {str(e)}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail=f"发送验证码失败: {str(e)}"
|
||
)
|
||
|
||
@router.put("/reset_password", response_model=Dict[str, Any])
|
||
async def reset_password(request: ResetPasswordRequest, session: Session = Depends(get_db)) -> Dict[str, Any]:
|
||
"""
|
||
修改密码
|
||
"""
|
||
try:
|
||
# 获取数据库管理器
|
||
|
||
|
||
# 验证验证码
|
||
email_service = get_email_service()
|
||
if not email_service.verify_code(request.mail, request.verification_code):
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="验证码错误或已过期"
|
||
)
|
||
|
||
# 更新密码
|
||
hashed_password = hash_password(request.new_password)
|
||
|
||
manager = UserManager(session)
|
||
success = manager.update_password(request.mail, hashed_password)
|
||
if not success:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="密码修改失败"
|
||
)
|
||
|
||
return {
|
||
"status": "success",
|
||
"message": "密码修改成功"
|
||
}
|
||
except Exception as e:
|
||
logger.error(f"修改密码失败: {str(e)}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail=f"修改密码失败: {str(e)}"
|
||
)
|
||
|
||
|
||
@router.post("/register", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
|
||
async def register_user(user: UserRegister, session: Session = Depends(get_db)) -> Dict[str, Any]:
|
||
"""
|
||
注册新用户
|
||
|
||
Args:
|
||
user: 用户注册信息
|
||
|
||
Returns:
|
||
注册成功的状态信息
|
||
"""
|
||
try:
|
||
# 获取邮件服务
|
||
email_service = get_email_service()
|
||
|
||
# 验证验证码
|
||
if not email_service.verify_code(user.mail, user.verification_code):
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="验证码错误或已过期"
|
||
)
|
||
|
||
# 获取数据库管理器
|
||
manager = UserManager(session)
|
||
|
||
# 对密码进行哈希处理
|
||
hashed_password = hash_password(user.password)
|
||
|
||
# 注册用户
|
||
success = manager.register_user(
|
||
mail=user.mail,
|
||
nickname=user.nickname,
|
||
password=hashed_password,
|
||
level=0, # 默认为普通用户
|
||
points=1 # 默认初始积分为100
|
||
)
|
||
|
||
if not success:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="注册失败,邮箱可能已被使用"
|
||
)
|
||
|
||
return {
|
||
"status": "success",
|
||
"message": "用户注册成功"
|
||
}
|
||
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"注册用户失败: {str(e)}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail=f"注册用户失败: {str(e)}"
|
||
)
|
||
|
||
@router.post("/login", response_model=TokenResponse)
|
||
async def login(loginData: UserLogin, session: Session = Depends(get_db)) -> TokenResponse:
|
||
"""
|
||
用户登录
|
||
|
||
Args:
|
||
form_data: 表单数据,包含用户名(邮箱)和密码
|
||
|
||
Returns:
|
||
访问令牌和用户信息
|
||
"""
|
||
try:
|
||
# 获取数据库管理器
|
||
manager = UserManager(session)
|
||
|
||
# 获取用户信息
|
||
user = manager.get_user_by_mail(loginData.mail)
|
||
|
||
if not user:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||
detail="邮箱或密码错误",
|
||
headers={"WWW-Authenticate": "Bearer"},
|
||
)
|
||
|
||
# 验证密码
|
||
hashed_password = hash_password(loginData.password)
|
||
|
||
# 查询用户的密码哈希
|
||
user = manager.get_user_by_mail_and_password(loginData.mail, hashed_password)
|
||
if not user:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||
detail="邮箱或密码错误",
|
||
headers={"WWW-Authenticate": "Bearer"},
|
||
)
|
||
|
||
|
||
# 创建访问令牌,不过期
|
||
access_token = create_access_token(data={"sub": user["mail"]})
|
||
|
||
user_subscription_manager = UserSubscriptionManager(session)
|
||
user_subscription = user_subscription_manager.get_subscription_by_user_id(user["id"])
|
||
|
||
is_subscribed = False
|
||
expire_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||
|
||
if user_subscription and user_subscription["expire_time"] > datetime.now():
|
||
member_name = "SVIP会员" if user_subscription["time_type"] == 2 else "VIP会员"
|
||
is_subscribed = True
|
||
expire_time = user_subscription["expire_time"].strftime("%Y-%m-%d %H:%M:%S")
|
||
else:
|
||
member_name = "普通会员"
|
||
|
||
return TokenResponse(
|
||
access_token=access_token,
|
||
token_type="bearer",
|
||
expires_in=ACCESS_TOKEN_EXPIRE_MINUTES,
|
||
user_info=UserResponse(
|
||
id=user["id"],
|
||
mail=user["mail"],
|
||
nickname=user["nickname"],
|
||
level=user["level"],
|
||
points=user["points"],
|
||
create_time=user["create_time"],
|
||
is_subscribed=is_subscribed,
|
||
member_name=member_name,
|
||
expire_time=expire_time
|
||
)
|
||
)
|
||
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"用户登录失败: {str(e)}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail=f"用户登录失败: {str(e)}"
|
||
)
|
||
|
||
@router.get("/me", response_model=UserResponse)
|
||
async def get_user_info(current_user: Dict[str, Any] = Depends(get_current_user), session: Session = Depends(get_db)) -> UserResponse:
|
||
"""
|
||
获取当前登录用户信息
|
||
|
||
Args:
|
||
current_user: 当前用户信息,由依赖项提供
|
||
|
||
Returns:
|
||
用户信息
|
||
"""
|
||
|
||
user_subscription_manager = UserSubscriptionManager(session)
|
||
user_subscription = user_subscription_manager.get_subscription_by_user_id(current_user["id"])
|
||
|
||
is_subscribed = False
|
||
expire_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||
|
||
if user_subscription and user_subscription["expire_time"] > datetime.now():
|
||
member_name = "SVIP会员" if user_subscription["time_type"] == 2 else "VIP会员"
|
||
is_subscribed = True
|
||
expire_time = user_subscription["expire_time"].strftime("%Y-%m-%d %H:%M:%S")
|
||
else:
|
||
member_name = "普通会员"
|
||
|
||
|
||
user = UserResponse(
|
||
id=current_user["id"],
|
||
mail=current_user["mail"],
|
||
nickname=current_user["nickname"],
|
||
level=current_user["level"],
|
||
points=current_user["points"],
|
||
create_time=current_user["create_time"],
|
||
is_subscribed=is_subscribed,
|
||
member_name=member_name
|
||
)
|
||
|
||
if is_subscribed and expire_time:
|
||
user.expire_time = expire_time
|
||
|
||
return user
|
||
|
||
@router.put("/level/{user_id}", response_model=Dict[str, Any])
|
||
async def update_user_level(
|
||
user_id: int,
|
||
level: int = Query(..., description="用户级别(0=普通用户,1=VIP,2=SVIP)"),
|
||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||
session: Session = Depends(get_db)
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
更新用户级别(需要管理员权限)
|
||
|
||
Args:
|
||
user_id: 用户ID
|
||
level: 新的用户级别
|
||
current_user: 当前用户信息
|
||
|
||
Returns:
|
||
更新结果
|
||
"""
|
||
# 检查权限(只有SVIP用户才能更新用户级别)
|
||
if current_user.get("level", 0) < 2:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_403_FORBIDDEN,
|
||
detail="没有足够的权限执行此操作"
|
||
)
|
||
|
||
# 获取数据库管理器
|
||
manager = UserManager(session)
|
||
|
||
# 更新用户级别
|
||
success = manager.update_user_level(user_id, level)
|
||
|
||
if not success:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="更新用户级别失败"
|
||
)
|
||
|
||
return {
|
||
"status": "success",
|
||
"message": f"成功更新用户 {user_id} 的级别为 {level}"
|
||
}
|
||
|
||
@router.get("/points/{user_id}", response_model=Dict[str, Any])
|
||
async def get_user_points(
|
||
user_id: int,
|
||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||
session: Session = Depends(get_db)
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
获取用户积分
|
||
|
||
Args:
|
||
user_id: 用户ID
|
||
current_user: 当前用户信息
|
||
|
||
Returns:
|
||
用户积分信息
|
||
"""
|
||
# 只能查看自己的积分,或者SVIP用户可以查看所有人的积分
|
||
if current_user.get("id") != user_id and current_user.get("level", 0) < 2:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_403_FORBIDDEN,
|
||
detail="没有权限查看其他用户的积分"
|
||
)
|
||
|
||
# 获取数据库管理器
|
||
manager = UserManager(session)
|
||
|
||
# 获取用户信息
|
||
user = manager.get_user_by_id(user_id)
|
||
|
||
if not user:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_404_NOT_FOUND,
|
||
detail=f"用户ID {user_id} 不存在"
|
||
)
|
||
|
||
return {
|
||
"user_id": user_id,
|
||
"points": user.get("points", 0),
|
||
"nickname": user.get("nickname", ""),
|
||
"level": user.get("level", 0)
|
||
}
|
||
|
||
@router.post("/points/add/{user_id}", response_model=Dict[str, Any])
|
||
async def add_user_points(
|
||
user_id: int,
|
||
points: int = Query(..., gt=0, description="增加的积分数量(必须大于0)"),
|
||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||
session: Session = Depends(get_db)
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
为用户增加积分(需要管理员权限)
|
||
|
||
Args:
|
||
user_id: 用户ID
|
||
points: 增加的积分数量
|
||
current_user: 当前用户信息
|
||
|
||
Returns:
|
||
操作结果
|
||
"""
|
||
# 检查权限(只有SVIP用户才能添加积分)
|
||
if current_user.get("level", 0) < 2:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_403_FORBIDDEN,
|
||
detail="没有足够的权限执行此操作"
|
||
)
|
||
|
||
# 获取数据库管理器
|
||
manager = UserManager(session)
|
||
|
||
# 添加积分
|
||
success = manager.add_user_points(user_id, points)
|
||
|
||
if not success:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="添加积分失败"
|
||
)
|
||
|
||
# 获取更新后的用户信息
|
||
user = manager.get_user_by_id(user_id)
|
||
|
||
return {
|
||
"status": "success",
|
||
"message": f"成功为用户 {user_id} 增加 {points} 积分",
|
||
"current_points": user.get("points", 0)
|
||
}
|
||
|
||
@router.post("/points/consume/{user_id}", response_model=Dict[str, Any])
|
||
async def consume_user_points(
|
||
user_id: int,
|
||
points: int = Query(..., gt=0, description="消费的积分数量(必须大于0)"),
|
||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||
session: Session = Depends(get_db)
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
用户消费积分
|
||
|
||
Args:
|
||
user_id: 用户ID
|
||
points: 消费的积分数量
|
||
current_user: 当前用户信息
|
||
|
||
Returns:
|
||
操作结果
|
||
"""
|
||
# 只能消费自己的积分,或者SVIP用户可以操作所有人的积分
|
||
if current_user.get("id") != user_id and current_user.get("level", 0) < 2:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_403_FORBIDDEN,
|
||
detail="没有权限消费其他用户的积分"
|
||
)
|
||
|
||
# 获取数据库管理器
|
||
manager = UserManager(session)
|
||
|
||
# 消费积分
|
||
success = manager.consume_user_points(user_id, points)
|
||
|
||
if not success:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_400_BAD_REQUEST,
|
||
detail="积分消费失败,可能是积分不足"
|
||
)
|
||
|
||
# 获取更新后的用户信息
|
||
user = manager.get_user_by_id(user_id)
|
||
|
||
return {
|
||
"status": "success",
|
||
"message": f"成功消费 {points} 积分",
|
||
"remaining_points": user.get("points", 0)
|
||
} |