stock-ai-agent/backend/app/middleware/auth_middleware.py
2026-02-04 14:56:03 +08:00

104 lines
2.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
JWT认证中间件
"""
from fastapi import Request, HTTPException, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.orm import Session
from app.services.jwt_service import jwt_service
from app.models.database import User
from app.services.db_service import db_service
from app.utils.logger import logger
# Shared HTTP Bearer security scheme; get_current_user receives the parsed
# "Authorization: Bearer <token>" credentials through Depends(security).
# NOTE(review): default HTTPBearer() settings are used (auto_error on), so a
# request without the header is presumably rejected by FastAPI before
# get_current_user runs — the exact status code depends on the FastAPI version.
security = HTTPBearer()
def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 HTTPException carrying the standard Bearer challenge header."""
    return HTTPException(
        status_code=401,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security)
) -> User:
    """
    Resolve the currently logged-in user from a JWT bearer token.

    The token is verified with ``jwt_service.verify_token``; its ``sub`` claim
    is parsed as an integer user id and looked up in the database. Only users
    with ``is_active == True`` are accepted.

    Args:
        credentials: HTTP Bearer credentials extracted by ``security``.

    Returns:
        The matching active ``User``. The DB session used for the lookup is
        closed before this returns (NOTE(review): the instance is therefore
        detached; lazy-loaded relationships may not be accessible — confirm).

    Raises:
        HTTPException: 401 when the token fails verification, its ``sub``
            claim is missing or non-numeric, the user does not exist or is
            disabled, or any other error occurs during authentication.
    """
    try:
        token = credentials.credentials
        logger.info(f"收到认证请求token前10位: {token[:10] if token else 'None'}")

        # verify_token reports a bad token via ValueError; that message is
        # passed to the client in the ValueError handler below.
        payload = jwt_service.verify_token(token)

        # Parse "sub" on its own so a malformed claim neither leaks Python's
        # int() error text to the client nor is misreported as an unexpected
        # error by the generic handler.
        sub = payload.get("sub")
        try:
            user_id = int(sub)
        except (TypeError, ValueError):
            logger.warning(f"Token中sub无效: {sub!r}")
            raise _unauthorized("认证失败") from None

        # Look the user up; only active accounts may authenticate.
        # NOTE(review): this synchronous query runs inside an async function
        # and blocks the event loop while it executes — consider a sync
        # dependency (run in FastAPI's threadpool) or an async session.
        db = db_service.get_session()
        try:
            user = db.query(User).filter(
                User.id == user_id,
                User.is_active == True,  # noqa: E712 - SQLAlchemy column expression
            ).first()
            if not user:
                logger.warning(f"用户不存在或已禁用: user_id={user_id}")
                raise _unauthorized("用户不存在或已禁用")
            # NOTE(review): logs the user's phone number at INFO level (PII).
            logger.info(f"认证成功: user_id={user.id}, phone={user.phone}")
            return user
        finally:
            db.close()
    except HTTPException:
        # Already a well-formed 401 (e.g. unknown/disabled user, bad "sub"):
        # re-raise unchanged instead of letting the generic handler below
        # replace its detail with "认证失败".
        raise
    except ValueError as e:
        logger.warning(f"Token验证失败: {e}")
        raise _unauthorized(str(e)) from e
    except Exception as e:
        logger.error(f"认证异常: {e}")
        raise _unauthorized("认证失败") from e
def get_client_ip(request: Request) -> str:
    """
    Return the best-guess client IP address for a request.

    Lookup order: the first non-blank entry of ``X-Forwarded-For``, then
    ``X-Real-IP``, then the socket peer address (``request.client.host``).
    Falls back to ``"unknown"`` when none is available.

    NOTE(review): ``X-Forwarded-For`` / ``X-Real-IP`` are supplied by the
    client unless a trusted reverse proxy overwrites them, so the result is
    spoofable when the app is reachable directly. Do not rely on it for
    security decisions without confirming the deployment's proxy setup.

    Args:
        request: FastAPI request object.

    Returns:
        The client IP address string, or ``"unknown"``.
    """
    # Proxy / load-balancer deployments: the header may hold a chain such as
    # "client, proxy1, proxy2"; the left-most entry is the original client.
    # Skip blank entries so a malformed value like ", 1.2.3.4" or "   " does
    # not produce an empty string.
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        for candidate in forwarded.split(","):
            candidate = candidate.strip()
            if candidate:
                return candidate

    # Single-address header commonly set by nginx-style proxies.
    real_ip = (request.headers.get("X-Real-IP") or "").strip()
    if real_ip:
        return real_ip

    # Direct connection: use the transport-level peer address.
    if request.client:
        return request.client.host
    return "unknown"