diff --git a/app/api/endpoints/user.py b/app/api/endpoints/user.py index 69c9c74..4155a13 100644 --- a/app/api/endpoints/user.py +++ b/app/api/endpoints/user.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, HTTPException, Depends, Response, Body from sqlalchemy.orm import Session -from app.models.user import UserLogin, UserInfo, VerifyCodeRequest, UserDB, UserUpdate, UserRole, UserPasswordLogin +from app.models.user import UserLogin, UserInfo, VerifyCodeRequest, UserDB, UserUpdate, UserRole, UserPasswordLogin, ReferralUserInfo from app.api.deps import get_current_user, get_admin_user from app.models.database import get_db import random @@ -254,4 +254,28 @@ async def password_login( "access_token": f"Bearer {access_token}", "user": UserInfo.model_validate(user) } - ) \ No newline at end of file + ) + +@router.get("/referrals", response_model=ResponseModel) +async def get_referral_users( + db: Session = Depends(get_db), + current_user: UserDB = Depends(get_current_user) +): + """获取我邀请的用户列表""" + referral_users = db.query(UserDB).filter( + UserDB.referral_code == current_user.user_code + ).order_by( + UserDB.create_time.desc() + ).all() + + # 处理手机号脱敏 + def mask_phone(phone: str) -> str: + return f"{phone[:3]}****{phone[7:]}" + + return success_response(data=[ + ReferralUserInfo( + username=user.username, + phone=mask_phone(user.phone), + create_time=user.create_time + ) for user in referral_users + ]) \ No newline at end of file diff --git a/app/models/user.py b/app/models/user.py index c671025..76f470e 100644 --- a/app/models/user.py +++ b/app/models/user.py @@ -1,10 +1,13 @@ -from sqlalchemy import Column, String, DateTime, Integer, Boolean, Enum +from sqlalchemy import Column, String, DateTime, Integer, Boolean, Enum, ForeignKey from sqlalchemy.sql import func from sqlalchemy.dialects.mysql import JSON from pydantic import BaseModel, Field from .database import Base -from typing import Optional, List +from typing import Optional, List, Union +from datetime import datetime import enum +import random +import string class UserRole(str, enum.Enum): USER = "user" @@ -12,9 +15,9 @@ class UserRole(str, enum.Enum): ADMIN = "admin" class Gender(str, enum.Enum): - MALE = "male" - FEMALE = "female" - UNKNOWN = "unknown" + MALE = "MALE" + FEMALE = "FEMALE" + UNKNOWN = "UNKNOWN" # 数据库模型 class UserDB(Base): @@ -23,6 +26,8 @@ class UserDB(Base): userid = Column(Integer, primary_key=True,autoincrement=True, index=True) username = Column(String(50)) phone = Column(String(11), unique=True, index=True) + user_code = Column(String(6), unique=True, nullable=False) + referral_code = Column(String(6), ForeignKey("users.user_code"), nullable=True) password = Column(String(128), nullable=True) # 加密后的密码 avatar = Column(String(200), nullable=True) # 头像URL地址 gender = Column(Enum(Gender), nullable=False, default=Gender.UNKNOWN) @@ -34,11 +39,14 @@ class UserDB(Base): class UserLogin(BaseModel): phone: str = Field(..., pattern="^1[3-9]\d{9}$") verify_code: str = Field(..., min_length=6, max_length=6) + referral_code: Optional[str] = Field(None, min_length=6, max_length=6) class UserInfo(BaseModel): userid: int username: str phone: str + user_code: str + referral_code: Optional[str] avatar: Optional[str] = None gender: Gender roles: List[UserRole] @@ -59,4 +67,22 @@ class UserUpdate(BaseModel): class UserPasswordLogin(BaseModel): phone: str = Field(..., pattern="^1[3-9]\d{9}$") - password: str = Field(..., min_length=6, max_length=20) \ No newline at end of file + password: str = Field(..., min_length=6, max_length=20) + +def generate_user_code() -> str: + """生成6位大写字母+数字的用户编码""" + chars = string.ascii_uppercase + string.digits + while True: + code = ''.join(random.choices(chars, k=6)) + # 检查是否已存在 + exists = UserDB.query.filter_by(user_code=code).first() + if not exists: + return code + +class ReferralUserInfo(BaseModel): + username: str + phone: str # 会在API中处理脱敏 + create_time: datetime + + class Config: + from_attributes = True \ No newline at end of file