增加一个 logger 中间件。
This commit is contained in:
parent
2f2b54f092
commit
e4fbe74f8c
@ -112,7 +112,7 @@ async def login(
|
|||||||
|
|
||||||
# 创建访问令牌
|
# 创建访问令牌
|
||||||
access_token = create_access_token(
|
access_token = create_access_token(
|
||||||
data={"sub": user.phone}
|
data={"phone": user.phone,"userid":user.userid}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 设置JWT cookie
|
# 设置JWT cookie
|
||||||
@ -168,7 +168,7 @@ async def mock_login(
|
|||||||
|
|
||||||
# 创建访问令牌
|
# 创建访问令牌
|
||||||
access_token = create_access_token(
|
access_token = create_access_token(
|
||||||
data={"sub": user.phone}
|
data={"phone": user.phone,"userid":user.userid}
|
||||||
)
|
)
|
||||||
|
|
||||||
# 设置JWT cookie
|
# 设置JWT cookie
|
||||||
@ -273,7 +273,7 @@ async def password_login(
|
|||||||
return error_response(code=401, message="密码错误")
|
return error_response(code=401, message="密码错误")
|
||||||
|
|
||||||
# 生成访问令牌
|
# 生成访问令牌
|
||||||
access_token = create_access_token(data={"sub": user.phone})
|
access_token = create_access_token(data={"phone": user.phone,"userid":user.userid})
|
||||||
|
|
||||||
return success_response(
|
return success_response(
|
||||||
data={
|
data={
|
||||||
|
|||||||
22
app/core/logger.py
Normal file
22
app/core/logger.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from typing import Dict, Any
|
||||||
|
from app.models.database import SessionLocal
|
||||||
|
from app.models.request_log import RequestLogDB
|
||||||
|
import json
|
||||||
|
from threading import Thread
|
||||||
|
|
||||||
|
def save_request_log(log_data: Dict[str, Any]):
|
||||||
|
"""保存请求日志到数据库"""
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
log = RequestLogDB(**log_data)
|
||||||
|
db.add(log)
|
||||||
|
db.commit()
|
||||||
|
except Exception as e:
|
||||||
|
db.rollback()
|
||||||
|
print(f"保存日志失败: {str(e)}")
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
def log_request_async(log_data: Dict[str, Any]):
|
||||||
|
"""在新线程中异步处理日志"""
|
||||||
|
Thread(target=save_request_log, args=(log_data,), daemon=True).start()
|
||||||
@ -10,9 +10,12 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
|||||||
|
|
||||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
|
||||||
to_encode = data.copy()
|
to_encode = data.copy()
|
||||||
if expires_delta is not None:
|
|
||||||
expire = datetime.now(timezone.utc) + expires_delta
|
# 更新 token 数据,不设置过期时间
|
||||||
to_encode.update({"exp": expire})
|
to_encode.update({
|
||||||
|
"userid": data.get("userid"),
|
||||||
|
"phone": data.get("phone")
|
||||||
|
})
|
||||||
|
|
||||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm="HS256")
|
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm="HS256")
|
||||||
return encoded_jwt
|
return encoded_jwt
|
||||||
@ -41,12 +44,12 @@ def clear_jwt_cookie(response: Response):
|
|||||||
def verify_token(token: str) -> Optional[str]:
|
def verify_token(token: str) -> Optional[str]:
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
|
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
|
||||||
phone: str = payload.get("sub")
|
phone: str = payload.get("phone")
|
||||||
if phone is None:
|
if phone is None:
|
||||||
return None
|
return None
|
||||||
return phone
|
return phone
|
||||||
except JWTError:
|
except JWTError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_password_hash(password: str) -> str:
|
def get_password_hash(password: str) -> str:
|
||||||
"""获取密码哈希值"""
|
"""获取密码哈希值"""
|
||||||
@ -54,4 +57,15 @@ def get_password_hash(password: str) -> str:
|
|||||||
|
|
||||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||||
"""验证密码"""
|
"""验证密码"""
|
||||||
return pwd_context.verify(plain_password, hashed_password)
|
return pwd_context.verify(plain_password, hashed_password)
|
||||||
|
|
||||||
|
def decode_jwt(token: str) -> dict:
|
||||||
|
"""解码 JWT token 获取完整信息"""
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
|
||||||
|
return {
|
||||||
|
"userid": payload.get("userid"),
|
||||||
|
"phone": payload.get("phone")
|
||||||
|
}
|
||||||
|
except:
|
||||||
|
return None
|
||||||
@ -6,6 +6,7 @@ from fastapi.exceptions import RequestValidationError
|
|||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from app.core.response import error_response
|
from app.core.response import error_response
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
from app.middleware.request_logger import RequestLoggerMiddleware
|
||||||
|
|
||||||
# 创建数据库表
|
# 创建数据库表
|
||||||
Base.metadata.create_all(bind=engine)
|
Base.metadata.create_all(bind=engine)
|
||||||
@ -25,6 +26,9 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 添加请求日志中间件
|
||||||
|
app.add_middleware(RequestLoggerMiddleware)
|
||||||
|
|
||||||
# 添加用户路由
|
# 添加用户路由
|
||||||
app.include_router(user.router, prefix="/api/user", tags=["用户"])
|
app.include_router(user.router, prefix="/api/user", tags=["用户"])
|
||||||
app.include_router(address.router, prefix="/api/address", tags=["配送地址"])
|
app.include_router(address.router, prefix="/api/address", tags=["配送地址"])
|
||||||
|
|||||||
86
app/middleware/request_logger.py
Normal file
86
app/middleware/request_logger.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
from fastapi import Request
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
import time
|
||||||
|
from app.core.logger import log_request_async
|
||||||
|
from app.core.security import verify_token, decode_jwt
|
||||||
|
import json
|
||||||
|
import copy
|
||||||
|
|
||||||
|
class RequestLoggerMiddleware(BaseHTTPMiddleware):
|
||||||
|
LOGGED_METHODS = {"POST", "PUT", "DELETE"}
|
||||||
|
SENSITIVE_PATHS = {
|
||||||
|
"/api/user/login",
|
||||||
|
"/api/user/password-login",
|
||||||
|
"/api/user/reset-password"
|
||||||
|
}
|
||||||
|
SENSITIVE_FIELDS = {"password", "verify_code", "old_password", "new_password"}
|
||||||
|
|
||||||
|
def filter_sensitive_data(self, data: dict, path: str) -> dict:
|
||||||
|
"""过滤敏感数据"""
|
||||||
|
if not data or path not in self.SENSITIVE_PATHS:
|
||||||
|
return data
|
||||||
|
|
||||||
|
filtered_data = copy.deepcopy(data)
|
||||||
|
for field in self.SENSITIVE_FIELDS:
|
||||||
|
if field in filtered_data:
|
||||||
|
filtered_data[field] = "***"
|
||||||
|
return filtered_data
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
method = request.method
|
||||||
|
|
||||||
|
if method not in self.LOGGED_METHODS:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
path = request.url.path
|
||||||
|
headers = dict(request.headers)
|
||||||
|
query_params = dict(request.query_params)
|
||||||
|
|
||||||
|
# 获取并过滤请求体
|
||||||
|
body = None
|
||||||
|
try:
|
||||||
|
body_bytes = await request.body()
|
||||||
|
if body_bytes:
|
||||||
|
body = json.loads(body_bytes)
|
||||||
|
body = self.filter_sensitive_data(body, path)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 从 Authorization 头获取 token
|
||||||
|
token = None
|
||||||
|
auth_header = headers.get('authorization')
|
||||||
|
if auth_header and auth_header.startswith('Bearer '):
|
||||||
|
token = auth_header.split(' ')[1]
|
||||||
|
|
||||||
|
# 从 token 获取用户信息
|
||||||
|
user_id = None
|
||||||
|
if token:
|
||||||
|
try:
|
||||||
|
payload = decode_jwt(token)
|
||||||
|
if payload:
|
||||||
|
user_id = payload.get("userid")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 处理请求
|
||||||
|
response = await call_next(request)
|
||||||
|
|
||||||
|
# 计算响应时间
|
||||||
|
response_time = int((time.time() - start_time) * 1000)
|
||||||
|
|
||||||
|
# 异步记录日志
|
||||||
|
log_data = {
|
||||||
|
"path": path,
|
||||||
|
"method": method,
|
||||||
|
"headers": headers,
|
||||||
|
"query_params": query_params,
|
||||||
|
"body": body,
|
||||||
|
"user_id": user_id,
|
||||||
|
"ip_address": request.client.host,
|
||||||
|
"status_code": response.status_code,
|
||||||
|
"response_time": response_time
|
||||||
|
}
|
||||||
|
log_request_async(log_data)
|
||||||
|
|
||||||
|
return response
|
||||||
19
app/models/request_log.py
Normal file
19
app/models/request_log.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
from sqlalchemy import Column, String, Integer, DateTime, JSON
|
||||||
|
from sqlalchemy.sql import func
|
||||||
|
from .database import Base
|
||||||
|
|
||||||
|
class RequestLogDB(Base):
|
||||||
|
__tablename__ = "request_logs"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
path = Column(String(200), nullable=False)
|
||||||
|
method = Column(String(10), nullable=False)
|
||||||
|
headers = Column(JSON)
|
||||||
|
query_params = Column(JSON)
|
||||||
|
body = Column(JSON, nullable=True)
|
||||||
|
user_id = Column(Integer, nullable=True)
|
||||||
|
user_info = Column(JSON, nullable=True) # 添加用户详细信息
|
||||||
|
ip_address = Column(String(50))
|
||||||
|
status_code = Column(Integer)
|
||||||
|
response_time = Column(Integer) # 毫秒
|
||||||
|
create_time = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
Loading…
Reference in New Issue
Block a user