This commit is contained in:
aaron 2025-03-06 15:04:10 +08:00
parent 37b255f524
commit 7bea1fabc3
4 changed files with 27 additions and 26 deletions

View File

@ -26,11 +26,11 @@ async def get_current_user(
print(f"token: {token}")
sub, phone = verify_token(token)
sub = verify_token(token)
if not sub:
raise HTTPException(status_code=401, detail="Token已过期或无效")
user = db.query(UserDB).filter(UserDB.phone == phone).first()
user = db.query(UserDB).filter(UserDB.phone == sub).first()
if not user:
raise HTTPException(status_code=401, detail="用户未登录")
return user

View File

@ -103,7 +103,7 @@ async def login(
# 创建访问令牌
access_token = create_access_token(
data={"sub": user.userid, "phone": user.phone}
data={"phone": user.phone}
)
# 设置JWT cookie
@ -203,7 +203,7 @@ async def phone_login(
# 创建访问令牌
access_token = create_access_token(
data={"phone": user.phone,"sub":user.userid}
data={"phone": user.phone}
)
# 设置JWT cookie
@ -343,7 +343,7 @@ async def password_login(
return error_response(code=401, message="配送员账户,请先设置归属小区")
# 生成访问令牌
access_token = create_access_token(data={"phone": user.phone,"sub":user.userid})
access_token = create_access_token(data={"phone": user.phone})
# 设置JWT cookie
if response:

View File

@ -10,7 +10,10 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
to_encode = data.copy()
to_encode = {
"sub": data.get("phone")
}
if expires_delta:
to_encode.update({"exp": datetime.now(timezone.utc) + expires_delta})
else:
@ -45,12 +48,11 @@ def clear_jwt_cookie(response: Response):
def verify_token(token: str) -> Optional[str]:
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
sub: str = payload.get("userid")
phone: str = payload.get("phone")
sub: str = payload.get("sub")
print(f"payload: {payload}")
return sub, phone
return sub
except JWTError:
return None, None
return None
def get_password_hash(password: str) -> str:
"""获取密码哈希值"""
@ -65,8 +67,7 @@ def decode_jwt(token: str) -> dict:
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
return {
"userid": payload.get("userid"),
"phone": payload.get("phone")
"phone": payload.get("sub"),
}
except:
return None

View File

@ -48,20 +48,20 @@ class RequestLoggerMiddleware(BaseHTTPMiddleware):
pass
# 从 Authorization 头获取 token
token = None
auth_header = headers.get('authorization')
if auth_header and auth_header.startswith('Bearer '):
token = auth_header.split(' ')[1]
# token = None
# auth_header = headers.get('authorization')
# if auth_header and auth_header.startswith('Bearer '):
# token = auth_header.split(' ')[1]
# 从 token 获取用户信息
user_id = None
if token:
try:
payload = decode_jwt(token)
if payload:
user_id = payload.get("userid")
except:
pass
# # 从 token 获取用户信息
# user_id = None
# if token:
# try:
# payload = decode_jwt(token)
# if payload:
# user_id = payload.get("phone")
# except:
# pass
# 处理请求
response = await call_next(request)
@ -76,7 +76,7 @@ class RequestLoggerMiddleware(BaseHTTPMiddleware):
"headers": headers,
"query_params": query_params,
"body": body,
"user_id": user_id,
# "user_id": user_id,
"ip_address": request.client.host,
"status_code": response.status_code,
"response_time": response_time