This commit is contained in:
aaron 2025-05-24 15:33:38 +08:00
parent 1debbc7dce
commit ea94081617
3 changed files with 73 additions and 17 deletions

View File

@ -3,7 +3,7 @@
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
from datetime import datetime from datetime import datetime
from fastapi import HTTPException, status
from sqlalchemy import Column, Integer, String, DateTime, Index from sqlalchemy import Column, Integer, String, DateTime, Index
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
@ -127,7 +127,7 @@ class UserManager:
logger.error(f"获取用户信息失败: {e}") logger.error(f"获取用户信息失败: {e}")
return None return None
def login(self, mail: str, password: str) -> Optional[Dict[str, Any]]: def get_user_by_mail_and_password(self, mail: str, password: str) -> Optional[Dict[str, Any]]:
""" """
登录 登录
""" """
@ -172,6 +172,28 @@ class UserManager:
except Exception as e: except Exception as e:
logger.error(f"获取用户信息失败: {e}") logger.error(f"获取用户信息失败: {e}")
return None return None
def update_password(self, mail: str, password: str) -> bool:
"""
更新用户密码
"""
try:
user = self.session.query(User).filter(User.mail == mail).first()
if not user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="用户不存在"
)
user.password = password
self.session.commit()
return True
except Exception as e:
logger.error(f"更新用户密码失败: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"更新用户密码失败: {e}"
)
def update_user_level(self, user_id: int, level: int) -> bool: def update_user_level(self, user_id: int, level: int) -> bool:
""" """

View File

@ -29,7 +29,7 @@ router = APIRouter()
# JWT配置 # JWT配置
JWT_SECRET_KEY = "FX9Rf7YpvTRdXWmj0Osx8P9smSkUh6fW" # 实际应用中应该从环境变量或配置文件中获取 JWT_SECRET_KEY = "FX9Rf7YpvTRdXWmj0Osx8P9smSkUh6fW" # 实际应用中应该从环境变量或配置文件中获取
JWT_ALGORITHM = "HS256" JWT_ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 * 7 # 7天过期 ACCESS_TOKEN_EXPIRE_MINUTES = 0 # 用户登录后不过期
# 请求模型 # 请求模型
class UserRegister(BaseModel): class UserRegister(BaseModel):
@ -48,6 +48,12 @@ class SendVerificationCodeRequest(BaseModel):
"""发送验证码请求模型""" """发送验证码请求模型"""
mail: EmailStr mail: EmailStr
class ResetPasswordRequest(BaseModel):
"""重置密码请求模型"""
mail: EmailStr
verification_code: str
new_password: str
# 响应模型 # 响应模型
class UserResponse(BaseModel): class UserResponse(BaseModel):
"""用户信息响应模型""" """用户信息响应模型"""
@ -121,17 +127,6 @@ async def send_verification_code(request: SendVerificationCodeRequest) -> Dict[s
发送结果 发送结果
""" """
try: try:
# 获取数据库管理器
db_manager = get_db_manager()
# 检查邮箱是否已被注册
user = db_manager.user_manager.get_user_by_mail(request.mail)
if user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="该邮箱已被注册"
)
# 获取邮件服务 # 获取邮件服务
email_service = get_email_service() email_service = get_email_service()
@ -157,6 +152,45 @@ async def send_verification_code(request: SendVerificationCodeRequest) -> Dict[s
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"发送验证码失败: {str(e)}" detail=f"发送验证码失败: {str(e)}"
) )
@router.put("/reset_password", response_model=Dict[str, Any])
async def reset_password(request: ResetPasswordRequest) -> Dict[str, Any]:
"""
修改密码
"""
try:
# 获取数据库管理器
db_manager = get_db_manager()
# 验证验证码
email_service = get_email_service()
if not email_service.verify_code(request.mail, request.verification_code):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="验证码错误或已过期"
)
# 更新密码
hashed_password = hash_password(request.new_password)
success = db_manager.user_manager.update_password(request.mail, hashed_password)
if not success:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="密码修改失败"
)
return {
"status": "success",
"message": "密码修改成功"
}
except Exception as e:
logger.error(f"修改密码失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"修改密码失败: {str(e)}"
)
@router.post("/register", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED) @router.post("/register", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
async def register_user(user: UserRegister) -> Dict[str, Any]: async def register_user(user: UserRegister) -> Dict[str, Any]:
@ -187,7 +221,7 @@ async def register_user(user: UserRegister) -> Dict[str, Any]:
hashed_password = hash_password(user.password) hashed_password = hash_password(user.password)
# 注册用户 # 注册用户
success = db_manager.register_user( success = db_manager.user_manager.register_user(
mail=user.mail, mail=user.mail,
nickname=user.nickname, nickname=user.nickname,
password=hashed_password, password=hashed_password,
@ -246,7 +280,7 @@ async def login(loginData: UserLogin) -> TokenResponse:
# 查询用户的密码哈希 # 查询用户的密码哈希
session = db_manager.Session() session = db_manager.Session()
try: try:
user = db_manager.user_manager.login(loginData.mail, hashed_password) user = db_manager.user_manager.get_user_by_mail_and_password(loginData.mail, hashed_password)
if not user: if not user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,

View File

@ -29,7 +29,7 @@ services:
cryptoai-api: cryptoai-api:
build: . build: .
container_name: cryptoai-api container_name: cryptoai-api
image: cryptoai-api:0.1.26 image: cryptoai-api:0.1.27
restart: always restart: always
ports: ports:
- "8000:8000" - "8000:8000"