257 lines
7.5 KiB
Python
257 lines
7.5 KiB
Python
from fastapi import APIRouter, HTTPException, Depends, Response, Body
|
||
from sqlalchemy.orm import Session
|
||
from app.models.user import UserLogin, UserInfo, VerifyCodeRequest, UserDB, UserUpdate, UserRole, UserPasswordLogin
|
||
from app.api.deps import get_current_user, get_admin_user
|
||
from app.models.database import get_db
|
||
import random
|
||
import string
|
||
import redis
|
||
from app.core.config import settings
|
||
from unisdk.sms import UniSMS
|
||
from unisdk.exception import UniException
|
||
from datetime import timedelta
|
||
from app.core.security import create_access_token, set_jwt_cookie, clear_jwt_cookie, get_password_hash, verify_password
|
||
from app.core.response import success_response, error_response, ResponseModel
|
||
from pydantic import BaseModel, Field
|
||
from typing import List
|
||
|
||
router = APIRouter()

# Shared Redis connection used by this module to cache SMS verification codes.
# decode_responses=True asks redis-py to return str values instead of bytes,
# which lets callers compare cached values directly with request strings.
redis_client = redis.Redis(
    host=settings.REDIS_HOST,
    port=settings.REDIS_PORT,
    db=settings.REDIS_DB,
    password=settings.REDIS_PASSWORD,
    decode_responses=True,
)

# SMS client, initialised once at import time with the configured app id.
client = UniSMS(settings.UNI_APP_ID)
|
||
|
||
# 添加 Mock 登录请求模型
|
||
class MockLoginRequest(BaseModel):
    """Request body for the mock-login endpoint."""

    # Eleven digits: a leading 1, a second digit in 3-9, then nine more digits.
    # The pattern is a raw string so that ``\d`` reaches the regex engine
    # verbatim instead of being parsed as an (invalid) string escape.
    phone: str = Field(..., pattern=r"^1[3-9]\d{9}$")
|
||
|
||
@router.post("/send-code")
async def send_verify_code(request: VerifyCodeRequest):
    """Send a 6-digit SMS verification code and cache it in Redis.

    The code is sent through the module-level SMS ``client``. Only after the
    provider reports success (``res.code == "0"``) is the code stored under
    ``verify_code:<phone>`` with a TTL of
    ``settings.VERIFICATION_CODE_EXPIRE_SECONDS``.

    Args:
        request: Request body carrying the target ``phone``.

    Returns:
        ``success_response`` when the SMS was accepted and the code cached,
        otherwise ``error_response`` describing the provider failure.
    """
    # Function-local import keeps this block self-contained. The code is a
    # login credential, so it must come from a CSPRNG rather than ``random``.
    import secrets

    phone = request.phone
    code = ''.join(secrets.choice(string.digits) for _ in range(6))

    try:
        # Send the SMS.
        res = client.send({
            "to": phone,
            "signature": settings.UNI_SMS_SIGN,
            "templateId": settings.UNI_SMS_TEMPLATE_ID,
            "templateData": {
                "code": code
            }
        })

        if res.code != "0":
            return error_response(message=f"短信发送失败: {res.message}")

        # Cache the code so the login endpoint can verify it later.
        redis_client.setex(
            f"verify_code:{phone}",
            settings.VERIFICATION_CODE_EXPIRE_SECONDS,
            code
        )

        return success_response(message="验证码已发送")

    except UniException as e:
        return error_response(message=f"发送验证码失败: {str(e)}")
|
||
|
||
@router.post("/login")
async def login(
    user_login: UserLogin,
    db: Session = Depends(get_db),
    response: Response = None
):
    """Log a user in with phone number + SMS verification code.

    The submitted code is checked against the value cached in Redis under
    ``verify_code:<phone>``. On a match the cached code is deleted, the user
    row is looked up by phone (and created with the USER role if absent), and
    an access token whose ``sub`` is the phone number is issued.

    Args:
        user_login: Request body with ``phone`` and ``verify_code``.
        db: Database session from the ``get_db`` dependency.
        response: Outgoing response; when truthy the JWT cookie is set on it.

    Returns:
        ``error_response`` if the code is missing or does not match, otherwise
        ``success_response`` with the user info, token and token type.
    """
    phone = user_login.phone
    submitted_code = user_login.verify_code
    code_key = f"verify_code:{phone}"

    # Reject when nothing is cached (expired / never sent) or the codes differ.
    cached_code = redis_client.get(code_key)
    if not (cached_code and cached_code == submitted_code):
        return error_response(message="验证码错误或已过期")

    # Drop the cached code once it has been accepted.
    redis_client.delete(code_key)

    # Find the account for this phone number, creating it on first login.
    account = db.query(UserDB).filter(UserDB.phone == phone).first()
    if not account:
        account = UserDB(
            username=f"user_{phone[-4:]}",
            phone=phone,
            roles=[UserRole.USER],
        )
        db.add(account)
        db.commit()
        db.refresh(account)

    # Issue the access token and, when a response object is present, set the cookie.
    access_token = create_access_token(data={"sub": account.phone})
    if response:
        set_jwt_cookie(response, access_token)

    payload = {
        "user": UserInfo.model_validate(account),
        "access_token": access_token,
        "token_type": "bearer",
    }
    return success_response(message="登录成功", data=payload)
|
||
|
||
@router.get("/info", response_model=ResponseModel)
async def get_user_info(
    current_user: UserDB = Depends(get_current_user)
):
    """Return the current user's info wrapped in a success response.

    Args:
        current_user: User resolved by the ``get_current_user`` dependency.

    Returns:
        ``success_response`` whose data is ``UserInfo`` validated from the user.
    """
    profile = UserInfo.model_validate(current_user)
    return success_response(data=profile)
|
||
|
||
@router.post("/mock-login", response_model=ResponseModel)
async def mock_login(
    request: MockLoginRequest,
    db: Session = Depends(get_db),
    response: Response = None
):
    """Mock login endpoint (development/testing only).

    Skips SMS verification entirely: the user is looked up by phone (and
    created if absent) and an access token is issued. The endpoint refuses to
    run unless ``settings.DEBUG`` is truthy.

    Args:
        request: Request body carrying the ``phone`` to log in as.
        db: Database session from the ``get_db`` dependency.
        response: Outgoing response; when truthy the JWT cookie is set on it.

    Returns:
        ``error_response`` with code 403 outside debug mode, otherwise
        ``success_response`` with the user info, token and token type.
    """
    if not settings.DEBUG:
        return error_response(code=403, message="该接口仅在开发环境可用")

    # Find the account for this phone number, creating it on first use.
    user = db.query(UserDB).filter(UserDB.phone == request.phone).first()
    if not user:
        user = UserDB(
            username=f"user_{request.phone[-4:]}",
            phone=request.phone,
            # Match the regular login endpoint: new accounts start with the
            # USER role, which the role-update endpoint also always keeps.
            roles=[UserRole.USER],
        )
        db.add(user)
        db.commit()
        db.refresh(user)

    # Issue the access token.
    access_token = create_access_token(
        data={"sub": user.phone}
    )

    # Set the JWT cookie when a response object is available.
    if response:
        set_jwt_cookie(response, access_token)

    return success_response(
        message="登录成功",
        data={
            "user": UserInfo.model_validate(user),
            "access_token": access_token,
            "token_type": "bearer"
        }
    )
|
||
|
||
@router.post("/logout", response_model=ResponseModel)
async def logout(
    response: Response,
    current_user: UserDB = Depends(get_current_user)
):
    """Log the current user out by clearing the JWT cookie.

    Args:
        response: Outgoing response on which the cookie is cleared.
        current_user: Declared so the ``get_current_user`` dependency runs;
            not otherwise used in the body.

    Returns:
        ``success_response`` confirming the logout.
    """
    clear_jwt_cookie(response)
    result = success_response(message="退出登录成功")
    return result
|
||
|
||
@router.put("/update", response_model=ResponseModel)
async def update_user_info(
    update_data: UserUpdate,
    db: Session = Depends(get_db),
    current_user: UserDB = Depends(get_current_user)
):
    """Update the current user's own profile fields.

    Only fields present in the request (``exclude_unset=True``) are copied
    onto the user row; the change is then committed.

    Args:
        update_data: Partial update payload.
        db: Database session from the ``get_db`` dependency.
        current_user: User resolved by the ``get_current_user`` dependency.

    Returns:
        ``error_response`` with code 400 when no field was supplied, code 500
        (after a rollback) when persisting fails, otherwise
        ``success_response`` with the refreshed user info.
    """
    changes = update_data.model_dump(exclude_unset=True)
    if not changes:
        return error_response(code=400, message="没有提供要更新的字段")

    # NOTE(review): values are copied verbatim. If UserUpdate exposes a
    # password field, confirm it is hashed before reaching this point.
    for attr_name in changes:
        setattr(current_user, attr_name, changes[attr_name])

    try:
        db.commit()
        db.refresh(current_user)
        refreshed = UserInfo.model_validate(current_user)
        return success_response(message="用户信息更新成功", data=refreshed)
    except Exception as e:
        # Undo the pending changes before reporting the failure.
        db.rollback()
        return error_response(code=500, message=f"更新失败: {str(e)}")
|
||
|
||
@router.put("/roles", response_model=ResponseModel)
async def update_user_roles(
    user_id: int,
    roles: List[UserRole],
    db: Session = Depends(get_db),
    admin: UserDB = Depends(get_admin_user)
):
    """Replace a user's roles (admin only).

    The stored role list is de-duplicated, keeps the order in which roles were
    supplied, and always contains ``UserRole.USER``.

    Args:
        user_id: ``userid`` of the user whose roles are replaced.
        roles: New roles; must not be empty.
        db: Database session from the ``get_db`` dependency.
        admin: Declared so the ``get_admin_user`` dependency runs; not
            otherwise used in the body.

    Returns:
        ``error_response`` with code 404 when the user does not exist, 400
        when ``roles`` is empty, 500 (after a rollback) when persisting
        fails, otherwise ``success_response`` with the refreshed user info.
    """
    user = db.query(UserDB).filter(UserDB.userid == user_id).first()
    if not user:
        return error_response(code=404, message="用户不存在")

    # At least one role is required.
    if not roles:
        return error_response(code=400, message="用户必须至少有一个角色")

    # De-duplicate into a new list: dict.fromkeys keeps first-seen order
    # (a set would not), and the caller-supplied list is left untouched.
    new_roles = list(dict.fromkeys(roles))

    # The plain USER role must always be present.
    if UserRole.USER not in new_roles:
        new_roles.append(UserRole.USER)

    user.roles = new_roles

    try:
        db.commit()
        db.refresh(user)
        return success_response(
            message="用户角色更新成功",
            data=UserInfo.model_validate(user)
        )
    except Exception as e:
        # Undo the pending change before reporting the failure.
        db.rollback()
        return error_response(code=500, message=f"更新失败: {str(e)}")
|
||
|
||
@router.post("/password_login", response_model=ResponseModel)
async def password_login(
    login_data: UserPasswordLogin,
    db: Session = Depends(get_db)
):
    """Log a user in with phone number + password.

    Args:
        login_data: Request body with ``phone`` and ``password``.
        db: Database session from the ``get_db`` dependency.

    Returns:
        ``error_response`` with code 401 when the user does not exist, has no
        password set, or the password does not verify; otherwise
        ``success_response`` with a ``Bearer``-prefixed token and user info.
    """
    user = db.query(UserDB).filter(UserDB.phone == login_data.phone).first()

    if not user:
        return error_response(code=401, message="用户不存在")

    # Accounts created through SMS login may not have a password yet.
    if not user.password:
        return error_response(code=401, message="请先设置密码")

    if not verify_password(login_data.password, user.password):
        return error_response(code=401, message="密码错误")

    # Build the token the same way the other login endpoints do: a claims
    # dict passed as ``data`` with the phone number as ``sub``.
    access_token = create_access_token(data={"sub": user.phone})

    return success_response(
        data={
            "access_token": f"Bearer {access_token}",
            "user": UserInfo.model_validate(user)
        }
    )