增加一个 logger 中间件。

This commit is contained in:
aaron 2025-01-13 19:01:39 +08:00
parent 2f2b54f092
commit e4fbe74f8c
6 changed files with 154 additions and 9 deletions

View File

@ -112,7 +112,7 @@ async def login(
# 创建访问令牌
access_token = create_access_token(
data={"sub": user.phone}
data={"phone": user.phone,"userid":user.userid}
)
# 设置JWT cookie
@ -168,7 +168,7 @@ async def mock_login(
# 创建访问令牌
access_token = create_access_token(
data={"sub": user.phone}
data={"phone": user.phone,"userid":user.userid}
)
# 设置JWT cookie
@ -273,7 +273,7 @@ async def password_login(
return error_response(code=401, message="密码错误")
# 生成访问令牌
access_token = create_access_token(data={"sub": user.phone})
access_token = create_access_token(data={"phone": user.phone,"userid":user.userid})
return success_response(
data={

22
app/core/logger.py Normal file
View File

@ -0,0 +1,22 @@
from typing import Dict, Any
from app.models.database import SessionLocal
from app.models.request_log import RequestLogDB
import json
from threading import Thread
def save_request_log(log_data: Dict[str, Any]):
"""保存请求日志到数据库"""
db = SessionLocal()
try:
log = RequestLogDB(**log_data)
db.add(log)
db.commit()
except Exception as e:
db.rollback()
print(f"保存日志失败: {str(e)}")
finally:
db.close()
def log_request_async(log_data: Dict[str, Any]):
"""在新线程中异步处理日志"""
Thread(target=save_request_log, args=(log_data,), daemon=True).start()

View File

@ -10,9 +10,12 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
to_encode = data.copy()
if expires_delta is not None:
expire = datetime.now(timezone.utc) + expires_delta
to_encode.update({"exp": expire})
# 更新 token 数据,不设置过期时间
to_encode.update({
"userid": data.get("userid"),
"phone": data.get("phone")
})
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm="HS256")
return encoded_jwt
@ -41,12 +44,12 @@ def clear_jwt_cookie(response: Response):
def verify_token(token: str) -> Optional[str]:
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
phone: str = payload.get("sub")
phone: str = payload.get("phone")
if phone is None:
return None
return phone
except JWTError:
return None
return None
def get_password_hash(password: str) -> str:
"""获取密码哈希值"""
@ -54,4 +57,15 @@ def get_password_hash(password: str) -> str:
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证密码"""
return pwd_context.verify(plain_password, hashed_password)
return pwd_context.verify(plain_password, hashed_password)
def decode_jwt(token: str) -> dict:
"""解码 JWT token 获取完整信息"""
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
return {
"userid": payload.get("userid"),
"phone": payload.get("phone")
}
except:
return None

View File

@ -6,6 +6,7 @@ from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.core.response import error_response
from fastapi import HTTPException
from app.middleware.request_logger import RequestLoggerMiddleware
# 创建数据库表
Base.metadata.create_all(bind=engine)
@ -25,6 +26,9 @@ app.add_middleware(
allow_headers=["*"],
)
# 添加请求日志中间件
app.add_middleware(RequestLoggerMiddleware)
# 添加用户路由
app.include_router(user.router, prefix="/api/user", tags=["用户"])
app.include_router(address.router, prefix="/api/address", tags=["配送地址"])

View File

@ -0,0 +1,86 @@
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
import time
from app.core.logger import log_request_async
from app.core.security import verify_token, decode_jwt
import json
import copy
class RequestLoggerMiddleware(BaseHTTPMiddleware):
LOGGED_METHODS = {"POST", "PUT", "DELETE"}
SENSITIVE_PATHS = {
"/api/user/login",
"/api/user/password-login",
"/api/user/reset-password"
}
SENSITIVE_FIELDS = {"password", "verify_code", "old_password", "new_password"}
def filter_sensitive_data(self, data: dict, path: str) -> dict:
"""过滤敏感数据"""
if not data or path not in self.SENSITIVE_PATHS:
return data
filtered_data = copy.deepcopy(data)
for field in self.SENSITIVE_FIELDS:
if field in filtered_data:
filtered_data[field] = "***"
return filtered_data
async def dispatch(self, request: Request, call_next):
method = request.method
if method not in self.LOGGED_METHODS:
return await call_next(request)
start_time = time.time()
path = request.url.path
headers = dict(request.headers)
query_params = dict(request.query_params)
# 获取并过滤请求体
body = None
try:
body_bytes = await request.body()
if body_bytes:
body = json.loads(body_bytes)
body = self.filter_sensitive_data(body, path)
except:
pass
# 从 Authorization 头获取 token
token = None
auth_header = headers.get('authorization')
if auth_header and auth_header.startswith('Bearer '):
token = auth_header.split(' ')[1]
# 从 token 获取用户信息
user_id = None
if token:
try:
payload = decode_jwt(token)
if payload:
user_id = payload.get("userid")
except:
pass
# 处理请求
response = await call_next(request)
# 计算响应时间
response_time = int((time.time() - start_time) * 1000)
# 异步记录日志
log_data = {
"path": path,
"method": method,
"headers": headers,
"query_params": query_params,
"body": body,
"user_id": user_id,
"ip_address": request.client.host,
"status_code": response.status_code,
"response_time": response_time
}
log_request_async(log_data)
return response

19
app/models/request_log.py Normal file
View File

@ -0,0 +1,19 @@
from sqlalchemy import Column, String, Integer, DateTime, JSON
from sqlalchemy.sql import func
from .database import Base
class RequestLogDB(Base):
__tablename__ = "request_logs"
id = Column(Integer, primary_key=True, autoincrement=True)
path = Column(String(200), nullable=False)
method = Column(String(10), nullable=False)
headers = Column(JSON)
query_params = Column(JSON)
body = Column(JSON, nullable=True)
user_id = Column(Integer, nullable=True)
user_info = Column(JSON, nullable=True) # 添加用户详细信息
ip_address = Column(String(50))
status_code = Column(Integer)
response_time = Column(Integer) # 毫秒
create_time = Column(DateTime(timezone=True), server_default=func.now())