390 lines
12 KiB
Python
390 lines
12 KiB
Python
import random
import secrets
import string
from datetime import datetime
from datetime import timedelta
from typing import List
from typing import Optional

import redis
from fastapi import APIRouter, HTTPException, Depends, Response, Body
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from unisdk.exception import UniException
from unisdk.sms import UniSMS

from app.api.deps import get_current_user, get_admin_user
from app.core.config import settings
from app.core.response import success_response, error_response, ResponseModel
from app.core.security import create_access_token, set_jwt_cookie, clear_jwt_cookie, get_password_hash, verify_password
from app.models.coupon import CouponDB, UserCouponDB
from app.models.database import get_db
from app.models.user import UserLogin, UserInfo, ResetPasswordRequest,PhoneLoginRequest,VerifyCodeRequest, UserDB, UserUpdate, UserRole, UserPasswordLogin, ReferralUserInfo, generate_user_code
|
|
|
|
router = APIRouter()

# Module-wide Redis connection; it holds the SMS verification codes under
# "verify_code:<phone>". decode_responses=True makes reads return str instead
# of bytes, so stored codes can be compared with the submitted code directly.
redis_client = redis.Redis(
    host=settings.REDIS_HOST,
    port=settings.REDIS_PORT,
    db=settings.REDIS_DB,
    password=settings.REDIS_PASSWORD,
    decode_responses=True,
)

# SMS gateway client used to deliver verification codes.
client = UniSMS(settings.UNI_APP_ID)
|
|
|
|
@router.post("/send-code")
async def send_verify_code(request: VerifyCodeRequest):
    """Send a 6-digit SMS verification code to ``request.phone``.

    The code is sent through the UniSMS client. Only after the provider reports
    success (``res.code == "0"``) is it stored in Redis under
    ``verify_code:<phone>`` with a TTL of
    ``settings.VERIFICATION_CODE_EXPIRE_SECONDS``. ``/login`` reads and
    consumes that key.

    Returns:
        ``success_response`` when the SMS was sent and the code stored.
        Otherwise an ``error_response`` describing the failure.
    """
    phone = request.phone
    # The code is an authentication secret: draw it from the OS CSPRNG
    # (``secrets``) rather than ``random``, whose output is predictable.
    code = "".join(secrets.choice(string.digits) for _ in range(6))
    # NOTE(review): there is no per-phone / per-IP rate limit on this endpoint;
    # consider one to prevent SMS flooding and cost abuse.

    try:
        # client.send performs a blocking HTTP request. Run it in the threadpool
        # so it does not stall the event loop for every other request.
        res = await run_in_threadpool(
            client.send,
            {
                "to": phone,
                "signature": settings.UNI_SMS_SIGN,
                "templateId": settings.UNI_SMS_TEMPLATE_ID,
                "templateData": {"code": code},
            },
        )
    except UniException as e:
        return error_response(message=f"发送验证码失败: {str(e)}")

    if res.code != "0":
        return error_response(message=f"短信发送失败: {res.message}")

    try:
        redis_client.setex(
            f"verify_code:{phone}",
            settings.VERIFICATION_CODE_EXPIRE_SECONDS,
            code,
        )
    except redis.RedisError:
        # The SMS went out, but the code could not be stored, so it can never
        # be verified. Report a clean failure instead of an unhandled 500.
        return error_response(message="验证码存储失败，请稍后重试")

    return success_response(message="验证码已发送")
|
|
|
|
@router.post("/login")
async def login(
    user_login: UserLogin,
    db: Session = Depends(get_db),
    response: Response = None
):
    """Log in with phone number + SMS code, registering the user on first use.

    The submitted code is checked against Redis key ``verify_code:<phone>``
    (written by ``/send-code``) and consumed on success. If no user has this
    phone, one is created with a generated ``user_code``, the optional
    ``referral_code`` and the ``USER`` role. The registration coupons are
    issued in the same transaction.

    On success a JWT carrying ``phone`` and ``userid`` is created. It is set as
    a cookie on ``response`` when one is available (FastAPI injects it through
    the ``Response`` annotation) and is also returned in the body.
    """
    phone = user_login.phone
    verify_code = user_login.verify_code
    code_key = f"verify_code:{phone}"

    stored_code = redis_client.get(code_key)
    if not stored_code or stored_code != verify_code:
        return error_response(message="验证码错误或已过期")

    # Consume the code. DEL returns how many keys it removed, so only the
    # request that actually deleted the key may continue. A concurrent request
    # that read the same code before the delete is rejected, which keeps codes
    # single-use.
    if redis_client.delete(code_key) != 1:
        return error_response(message="验证码错误或已过期")
    # NOTE(review): wrong guesses are not counted, so a 6-digit code can be
    # brute-forced within its TTL; consider a per-phone attempt limit.

    # Find the user, or register them on first login.
    user = db.query(UserDB).filter(UserDB.phone == phone).first()
    if not user:
        user_code = generate_user_code(db)
        user = UserDB(
            username=f"user_{phone[-4:]}",
            phone=phone,
            user_code=user_code,
            referral_code=user_login.referral_code,
            roles=[UserRole.USER],
        )
        db.add(user)
        db.flush()  # assigns user.userid, needed for the coupon rows

        # Issue sign-up coupons before the single commit.
        issue_register_coupons(db, user.userid)

        db.commit()
        db.refresh(user)

    access_token = create_access_token(
        data={"phone": user.phone, "userid": user.userid}
    )

    if response:
        set_jwt_cookie(response, access_token)

    return success_response(
        message="登录成功",
        data={
            "user": UserInfo.model_validate(user),
            "access_token": access_token,
            "token_type": "bearer",
        },
    )
|
|
|
|
@router.get("/info", response_model=ResponseModel)
async def get_user_info(
    current_user: UserDB = Depends(get_current_user)
):
    """Return the authenticated user's profile.

    ``current_user`` is resolved by the ``get_current_user`` dependency. The
    ORM row is converted to the ``UserInfo`` schema and wrapped with
    ``success_response``.
    """
    profile = UserInfo.model_validate(current_user)
    return success_response(data=profile)
|
|
|
|
@router.post("/phone-login", response_model=ResponseModel)
async def phone_login(
    request: PhoneLoginRequest,
    db: Session = Depends(get_db),
    response: Response = None
):
    """Phone-number login (test environment); no verification code is checked.

    Looks the user up by ``request.phone``. On first use it registers one
    (generated ``user_code``, optional ``referral_code``, ``USER`` role) and
    issues the registration coupons in the same transaction. It then creates a
    JWT carrying ``phone`` and ``userid``, sets it as a cookie when a
    ``response`` is available, and returns it in the body.

    WARNING(review): nothing here proves the caller owns the phone number, so
    any client can obtain a token for any user. Confirm this route is disabled
    or not mounted outside test environments.
    """
    phone = request.phone
    account = db.query(UserDB).filter(UserDB.phone == phone).first()

    if not account:
        new_user_code = generate_user_code(db)
        account = UserDB(
            username=f"user_{phone[-4:]}",
            phone=phone,
            user_code=new_user_code,
            referral_code=request.referral_code,
            roles=[UserRole.USER],
        )
        db.add(account)
        db.flush()  # populate account.userid for the coupon rows
        issue_register_coupons(db, account.userid)
        db.commit()
        db.refresh(account)

    token = create_access_token(
        data={"phone": account.phone, "userid": account.userid}
    )
    if response:
        set_jwt_cookie(response, token)

    body = {
        "user": UserInfo.model_validate(account),
        "access_token": token,
        "token_type": "bearer",
    }
    return success_response(message="登录成功", data=body)
|
|
|
|
@router.post("/logout", response_model=ResponseModel)
async def logout(
    response: Response,
    current_user: UserDB = Depends(get_current_user)
):
    """Log out by clearing the JWT cookie on the outgoing response.

    ``current_user`` is not read in the body. The dependency presumably rejects
    unauthenticated callers before the cookie is cleared. Only the cookie is
    removed here; the token itself is not revoked server-side.
    """
    clear_jwt_cookie(response)
    result = success_response(message="退出登录成功")
    return result
|
|
|
|
@router.put("/update", response_model=ResponseModel)
async def update_user_info(
    update_data: UserUpdate,
    db: Session = Depends(get_db),
    current_user: UserDB = Depends(get_current_user)
):
    """Partially update the current user's profile.

    Only fields the client explicitly sent are applied
    (``model_dump(exclude_unset=True)``). A field explicitly sent as ``null``
    IS included and is written as ``None``.

    Returns error code 400 when no fields were sent. If the commit fails, the
    session is rolled back and error code 500 is returned with the exception
    text. Otherwise returns the refreshed ``UserInfo``.
    """
    changes = update_data.model_dump(exclude_unset=True)
    if not changes:
        return error_response(code=400, message="没有提供要更新的字段")

    # NOTE(review): every field of UserUpdate is copied onto the ORM row
    # verbatim; confirm UserUpdate exposes no privileged columns.
    for attr_name in changes:
        setattr(current_user, attr_name, changes[attr_name])

    try:
        db.commit()
        db.refresh(current_user)
        return success_response(
            message="用户信息更新成功",
            data=UserInfo.model_validate(current_user),
        )
    except Exception as e:
        db.rollback()
        return error_response(code=500, message=f"更新失败: {str(e)}")
|
|
|
|
@router.put("/roles", response_model=ResponseModel)
async def update_user_roles(
    user_id: int,
    roles: List[UserRole],
    db: Session = Depends(get_db),
    admin: UserDB = Depends(get_admin_user)
):
    """Replace a user's roles (admin only).

    Args:
        user_id: id of the user whose roles are replaced.
        roles: the new role list; must be non-empty. Duplicates are removed
            (first occurrence wins), and ``UserRole.USER`` is appended if
            missing.
        admin: the acting administrator, resolved by ``get_admin_user``.

    Returns error code 404 if the user does not exist and 400 if ``roles`` is
    empty. If the commit fails, the session is rolled back and error code 500
    is returned. Otherwise returns the refreshed ``UserInfo``.
    """
    user = db.query(UserDB).filter(UserDB.userid == user_id).first()
    if not user:
        return error_response(code=404, message="用户不存在")

    # Every user must keep at least one role.
    if not roles:
        return error_response(code=400, message="用户必须至少有一个角色")

    # Deduplicate while keeping the caller's order. list(set(...)) would store
    # the roles in set iteration (hash) order; enum members hash via strings,
    # whose hashes are randomized per process, so the persisted order would
    # vary. Build a new list instead of mutating the request body.
    new_roles = list(dict.fromkeys(roles))

    # The basic USER role must always be present.
    if UserRole.USER not in new_roles:
        new_roles.append(UserRole.USER)

    user.roles = new_roles

    try:
        db.commit()
        db.refresh(user)
        return success_response(
            message="用户角色更新成功",
            data=UserInfo.model_validate(user),
        )
    except Exception as e:
        db.rollback()
        return error_response(code=500, message=f"更新失败: {str(e)}")
|
|
|
|
@router.post("/password-login", response_model=ResponseModel)
async def password_login(
    login_data: UserPasswordLogin,
    db: Session = Depends(get_db),
    response: Response = None
):
    """Log in with phone number and password.

    Looks the user up by ``login_data.phone`` and checks the password with
    ``verify_password`` against the stored hash. Every failure (unknown phone,
    no password set, wrong password) returns the same 401 message, so this
    endpoint cannot be used to discover which phone numbers are registered.

    On success a JWT carrying ``phone`` and ``userid`` is created. As in the
    SMS login endpoints, it is set as a cookie when a ``response`` is available
    (FastAPI injects it through the ``Response`` annotation) and is returned in
    the body together with the ``UserInfo``.
    """
    user = db.query(UserDB).filter(UserDB.phone == login_data.phone).first()

    if (
        not user
        or not user.password
        or not verify_password(login_data.password, user.password)
    ):
        return error_response(code=401, message="手机号或密码错误")
    # NOTE(review): verify_password is skipped for unknown phones, so response
    # timing may still differ; consider hashing a dummy password in that case.

    access_token = create_access_token(data={"phone": user.phone, "userid": user.userid})

    # Keep cookie-based auth consistent with the other login endpoints.
    if response:
        set_jwt_cookie(response, access_token)

    return success_response(
        data={
            "access_token": access_token,
            "token_type": "bearer",
            "user": UserInfo.model_validate(user),
        }
    )
|
|
|
|
@router.get("/referrals", response_model=ResponseModel)
async def get_referral_users(
    db: Session = Depends(get_db),
    current_user: UserDB = Depends(get_current_user)
):
    """List the users invited by the current user, newest first.

    A user counts as invited when their ``referral_code`` equals the current
    user's ``user_code``. Phone numbers are masked: the first three characters
    are kept, ``****`` is inserted, and everything from index 7 onward is kept.
    """
    def _masked(number: str) -> str:
        head, tail = number[:3], number[7:]
        return head + "****" + tail

    invitees = (
        db.query(UserDB)
        .filter(UserDB.referral_code == current_user.user_code)
        .order_by(UserDB.create_time.desc())
        .all()
    )

    items = [
        ReferralUserInfo(
            username=invitee.username,
            phone=_masked(invitee.phone),
            create_time=invitee.create_time,
        )
        for invitee in invitees
    ]
    return success_response(data=items)
|
|
|
|
def issue_register_coupons(db: Session, user_id: int):
    """Add the sign-up coupons configured in ``settings.REGISTER_COUPONS``.

    Each config entry is a mapping with ``coupon_id`` and ``count``. For every
    entry whose coupon exists, ``count`` unused ``UserCouponDB`` rows are added
    to the session, all expiring 90 days from now. Entries whose coupon is not
    found are skipped without error. Nothing is flushed or committed here; the
    caller owns the transaction.

    Args:
        db: database session the coupon rows are added to.
        user_id: id of the user receiving the coupons.
    """
    coupon_configs = settings.REGISTER_COUPONS

    # Valid for roughly three months. datetime.now() is naive local time.
    valid_until = datetime.now() + timedelta(days=90)

    for entry in coupon_configs:
        template = db.query(CouponDB).filter(
            CouponDB.id == entry["coupon_id"]
        ).first()
        if not template:
            continue

        for _ in range(entry["count"]):
            db.add(
                UserCouponDB(
                    user_id=user_id,
                    coupon_id=template.id,
                    coupon_name=template.name,
                    coupon_amount=template.amount,
                    expire_time=valid_until,
                    status="unused",
                )
            )
|
|
|
|
@router.get("/list", response_model=ResponseModel)
async def get_user_list(
    skip: int = 0,
    limit: int = 10,
    db: Session = Depends(get_db),
    admin: UserDB = Depends(get_admin_user)  # admin only
):
    """Return one page of users, newest first (admin only).

    Args:
        skip: number of records to skip.
        limit: maximum number of records to return.

    Returns:
        ``{"total": <count of all users>, "items": [UserInfo, ...]}``. Each
        item's phone is masked to the first three characters, ``****`` and the
        characters from index 7 on.
    """
    # NOTE(review): skip/limit are not bounded here; consider validating them
    # (e.g. non-negative, capped limit).
    def _masked(number: str) -> str:
        return number[:3] + "****" + number[7:]

    total = db.query(UserDB).count()
    page = (
        db.query(UserDB)
        .order_by(UserDB.create_time.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )

    items = []
    for row in page:
        info = UserInfo.model_validate(row)
        info.phone = _masked(info.phone)
        items.append(info)

    payload = {"total": total, "items": items}
    return success_response(data=payload)
|
|
|
|
@router.post("/reset-password", response_model=ResponseModel)
async def reset_password(
    request: ResetPasswordRequest,
    db: Session = Depends(get_db),
    admin: UserDB = Depends(get_admin_user)  # admin only
):
    """Set a new password for a user (admin only).

    ``request.new_password`` is hashed with ``get_password_hash`` before it is
    stored. Returns error code 404 when ``request.user_id`` does not exist. If
    the commit fails, the session is rolled back and error code 500 is
    returned. Otherwise returns the user's id, username and masked phone.
    """
    target = db.query(UserDB).filter(UserDB.userid == request.user_id).first()
    if not target:
        return error_response(code=404, message="用户不存在")

    target.password = get_password_hash(request.new_password)

    try:
        db.commit()
        summary = {
            "userid": target.userid,
            "username": target.username,
            # Masked phone: first 3 chars, ****, then chars from index 7 on.
            "phone": f"{target.phone[:3]}****{target.phone[7:]}",
        }
        return success_response(message="密码重置成功", data=summary)
    except Exception as e:
        db.rollback()
        return error_response(code=500, message=f"密码重置失败: {str(e)}")