104 lines
2.7 KiB
Python
104 lines
2.7 KiB
Python
"""
|
||
JWT认证中间件
|
||
"""
|
||
from fastapi import Request, HTTPException, Depends
|
||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||
from sqlalchemy.orm import Session
|
||
from app.services.jwt_service import jwt_service
|
||
from app.models.database import User
|
||
from app.services.db_service import db_service
|
||
from app.utils.logger import logger
|
||
|
||
# Shared Bearer-token extractor used as a FastAPI dependency.
# auto_error=True (FastAPI's default) makes FastAPI reject requests that lack a
# usable "Authorization: Bearer ..." header before get_current_user runs.
security = HTTPBearer(auto_error=True)
|
||
|
||
|
||
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security)
) -> User:
    """
    Resolve the currently logged-in user from a Bearer JWT.

    The token comes from the ``Authorization: Bearer ...`` header (extracted by
    the module-level ``security`` scheme). It is validated with
    ``jwt_service.verify_token``, and its ``sub`` claim is converted to an
    integer user id. The matching user is loaded from the database and returned
    only if ``is_active`` is true.

    Args:
        credentials: HTTP Bearer credentials extracted by ``HTTPBearer``.

    Returns:
        The active ``User`` whose id equals the token's ``sub`` claim.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            any of the following happens:
            - the token is rejected (``ValueError``): the detail is the error
              message;
            - the user does not exist or is disabled: the detail is
              "用户不存在或已禁用";
            - any other error occurs: the detail is "认证失败".
    """
    try:
        token = credentials.credentials
        logger.info(f"收到认证请求,token前10位: {token[:10] if token else 'None'}")

        # Validate the token.
        # NOTE(review): the ValueError handler below assumes verify_token
        # signals an invalid or expired token by raising ValueError. Confirm
        # this against jwt_service.
        payload = jwt_service.verify_token(token)
        # NOTE(review): a non-numeric "sub" also raises ValueError here, and
        # its raw message ("invalid literal for int() ...") is then sent to the
        # client as the 401 detail.
        user_id = int(payload.get("sub"))

        # Look up the user; disabled accounts are treated the same as missing ones.
        # NOTE(review): this is a synchronous DB call inside an async def. It
        # presumably blocks the event loop while it runs; consider a threadpool.
        db = db_service.get_session()
        try:
            user = db.query(User).filter(
                User.id == user_id,
                User.is_active == True,  # noqa: E712 - SQLAlchemy needs == to build the SQL clause
            ).first()

            if not user:
                logger.warning(f"用户不存在或已禁用: user_id={user_id}")
                raise HTTPException(
                    status_code=401,
                    detail="用户不存在或已禁用",
                    headers={"WWW-Authenticate": "Bearer"},
                )

            logger.info(f"认证成功: user_id={user.id}, phone={user.phone}")
            return user
        finally:
            # Always release the session, including on the error paths above.
            db.close()

    except HTTPException:
        # Already a well-formed 401. Re-raise it unchanged; otherwise the generic
        # handler below would log it as an error and replace its detail.
        raise
    except ValueError as e:
        logger.warning(f"Token验证失败: {e}")
        raise HTTPException(
            status_code=401,
            detail=str(e),
            headers={"WWW-Authenticate": "Bearer"},
        ) from e
    except Exception as e:
        # Any unexpected failure (e.g. a missing "sub" claim or a DB error) is
        # reported to the client as a generic authentication failure.
        logger.error(f"认证异常: {e}")
        raise HTTPException(
            status_code=401,
            detail="认证失败",
            headers={"WWW-Authenticate": "Bearer"},
        ) from e
|
||
|
||
|
||
def get_client_ip(request: Request) -> str:
    """
    Return a best-effort client IP address for *request*.

    Lookup order:
      1. The left-most entry of ``X-Forwarded-For``, if non-empty (used in
         proxy / load-balancer setups; the header may carry a comma-separated
         chain).
      2. The ``X-Real-IP`` header, if non-blank.
      3. ``request.client.host`` from the direct connection.
      4. The literal ``"unknown"`` when none of the above is available.

    NOTE(review): ``X-Forwarded-For`` and ``X-Real-IP`` are client-controllable
    unless a trusted reverse proxy overwrites them. The value returned here can
    therefore be spoofed. Do not use it for security decisions (rate limiting,
    allow-lists) without confirming the deployment sanitizes these headers.

    Args:
        request: FastAPI request object.

    Returns:
        The client IP address as a string, or ``"unknown"``.
    """
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        first_hop = forwarded.split(",")[0].strip()
        # A malformed header such as ", 10.0.0.1" would otherwise yield "".
        if first_hop:
            return first_hop

    real_ip = request.headers.get("X-Real-IP")
    if real_ip:
        real_ip = real_ip.strip()
        # Ignore a whitespace-only value rather than returning it.
        if real_ip:
            return real_ip

    # Fall back to the peer address of the direct connection.
    if request.client:
        return request.client.host

    return "unknown"
|