增加一个 logger 中间件。
This commit is contained in:
parent
2f2b54f092
commit
e4fbe74f8c
@ -112,7 +112,7 @@ async def login(
|
||||
|
||||
# 创建访问令牌
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.phone}
|
||||
data={"phone": user.phone,"userid":user.userid}
|
||||
)
|
||||
|
||||
# 设置JWT cookie
|
||||
@ -168,7 +168,7 @@ async def mock_login(
|
||||
|
||||
# 创建访问令牌
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.phone}
|
||||
data={"phone": user.phone,"userid":user.userid}
|
||||
)
|
||||
|
||||
# 设置JWT cookie
|
||||
@ -273,7 +273,7 @@ async def password_login(
|
||||
return error_response(code=401, message="密码错误")
|
||||
|
||||
# 生成访问令牌
|
||||
access_token = create_access_token(data={"sub": user.phone})
|
||||
access_token = create_access_token(data={"phone": user.phone,"userid":user.userid})
|
||||
|
||||
return success_response(
|
||||
data={
|
||||
|
||||
22
app/core/logger.py
Normal file
22
app/core/logger.py
Normal file
@ -0,0 +1,22 @@
|
||||
from typing import Dict, Any
|
||||
from app.models.database import SessionLocal
|
||||
from app.models.request_log import RequestLogDB
|
||||
import json
|
||||
from threading import Thread
|
||||
|
||||
def save_request_log(log_data: Dict[str, Any]):
|
||||
"""保存请求日志到数据库"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
log = RequestLogDB(**log_data)
|
||||
db.add(log)
|
||||
db.commit()
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
print(f"保存日志失败: {str(e)}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def log_request_async(log_data: Dict[str, Any]):
|
||||
"""在新线程中异步处理日志"""
|
||||
Thread(target=save_request_log, args=(log_data,), daemon=True).start()
|
||||
@ -10,9 +10,12 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||
to_encode = data.copy()
|
||||
if expires_delta is not None:
|
||||
expire = datetime.now(timezone.utc) + expires_delta
|
||||
to_encode.update({"exp": expire})
|
||||
|
||||
# 更新 token 数据,不设置过期时间
|
||||
to_encode.update({
|
||||
"userid": data.get("userid"),
|
||||
"phone": data.get("phone")
|
||||
})
|
||||
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm="HS256")
|
||||
return encoded_jwt
|
||||
@ -41,12 +44,12 @@ def clear_jwt_cookie(response: Response):
|
||||
def verify_token(token: str) -> Optional[str]:
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
|
||||
phone: str = payload.get("sub")
|
||||
phone: str = payload.get("phone")
|
||||
if phone is None:
|
||||
return None
|
||||
return phone
|
||||
except JWTError:
|
||||
return None
|
||||
return None
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""获取密码哈希值"""
|
||||
@ -54,4 +57,15 @@ def get_password_hash(password: str) -> str:
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""验证密码"""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
def decode_jwt(token: str) -> dict:
|
||||
"""解码 JWT token 获取完整信息"""
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
|
||||
return {
|
||||
"userid": payload.get("userid"),
|
||||
"phone": payload.get("phone")
|
||||
}
|
||||
except:
|
||||
return None
|
||||
@ -6,6 +6,7 @@ from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
from app.core.response import error_response
|
||||
from fastapi import HTTPException
|
||||
from app.middleware.request_logger import RequestLoggerMiddleware
|
||||
|
||||
# 创建数据库表
|
||||
Base.metadata.create_all(bind=engine)
|
||||
@ -25,6 +26,9 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 添加请求日志中间件
|
||||
app.add_middleware(RequestLoggerMiddleware)
|
||||
|
||||
# 添加用户路由
|
||||
app.include_router(user.router, prefix="/api/user", tags=["用户"])
|
||||
app.include_router(address.router, prefix="/api/address", tags=["配送地址"])
|
||||
|
||||
86
app/middleware/request_logger.py
Normal file
86
app/middleware/request_logger.py
Normal file
@ -0,0 +1,86 @@
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
import time
|
||||
from app.core.logger import log_request_async
|
||||
from app.core.security import verify_token, decode_jwt
|
||||
import json
|
||||
import copy
|
||||
|
||||
class RequestLoggerMiddleware(BaseHTTPMiddleware):
|
||||
LOGGED_METHODS = {"POST", "PUT", "DELETE"}
|
||||
SENSITIVE_PATHS = {
|
||||
"/api/user/login",
|
||||
"/api/user/password-login",
|
||||
"/api/user/reset-password"
|
||||
}
|
||||
SENSITIVE_FIELDS = {"password", "verify_code", "old_password", "new_password"}
|
||||
|
||||
def filter_sensitive_data(self, data: dict, path: str) -> dict:
|
||||
"""过滤敏感数据"""
|
||||
if not data or path not in self.SENSITIVE_PATHS:
|
||||
return data
|
||||
|
||||
filtered_data = copy.deepcopy(data)
|
||||
for field in self.SENSITIVE_FIELDS:
|
||||
if field in filtered_data:
|
||||
filtered_data[field] = "***"
|
||||
return filtered_data
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
method = request.method
|
||||
|
||||
if method not in self.LOGGED_METHODS:
|
||||
return await call_next(request)
|
||||
|
||||
start_time = time.time()
|
||||
path = request.url.path
|
||||
headers = dict(request.headers)
|
||||
query_params = dict(request.query_params)
|
||||
|
||||
# 获取并过滤请求体
|
||||
body = None
|
||||
try:
|
||||
body_bytes = await request.body()
|
||||
if body_bytes:
|
||||
body = json.loads(body_bytes)
|
||||
body = self.filter_sensitive_data(body, path)
|
||||
except:
|
||||
pass
|
||||
|
||||
# 从 Authorization 头获取 token
|
||||
token = None
|
||||
auth_header = headers.get('authorization')
|
||||
if auth_header and auth_header.startswith('Bearer '):
|
||||
token = auth_header.split(' ')[1]
|
||||
|
||||
# 从 token 获取用户信息
|
||||
user_id = None
|
||||
if token:
|
||||
try:
|
||||
payload = decode_jwt(token)
|
||||
if payload:
|
||||
user_id = payload.get("userid")
|
||||
except:
|
||||
pass
|
||||
|
||||
# 处理请求
|
||||
response = await call_next(request)
|
||||
|
||||
# 计算响应时间
|
||||
response_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
# 异步记录日志
|
||||
log_data = {
|
||||
"path": path,
|
||||
"method": method,
|
||||
"headers": headers,
|
||||
"query_params": query_params,
|
||||
"body": body,
|
||||
"user_id": user_id,
|
||||
"ip_address": request.client.host,
|
||||
"status_code": response.status_code,
|
||||
"response_time": response_time
|
||||
}
|
||||
log_request_async(log_data)
|
||||
|
||||
return response
|
||||
19
app/models/request_log.py
Normal file
19
app/models/request_log.py
Normal file
@ -0,0 +1,19 @@
|
||||
from sqlalchemy import Column, String, Integer, DateTime, JSON
|
||||
from sqlalchemy.sql import func
|
||||
from .database import Base
|
||||
|
||||
class RequestLogDB(Base):
|
||||
__tablename__ = "request_logs"
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
path = Column(String(200), nullable=False)
|
||||
method = Column(String(10), nullable=False)
|
||||
headers = Column(JSON)
|
||||
query_params = Column(JSON)
|
||||
body = Column(JSON, nullable=True)
|
||||
user_id = Column(Integer, nullable=True)
|
||||
user_info = Column(JSON, nullable=True) # 添加用户详细信息
|
||||
ip_address = Column(String(50))
|
||||
status_code = Column(Integer)
|
||||
response_time = Column(Integer) # 毫秒
|
||||
create_time = Column(DateTime(timezone=True), server_default=func.now())
|
||||
Loading…
Reference in New Issue
Block a user