From e4fbe74f8c86a6fff9cee048b8c2783a9632fe2f Mon Sep 17 00:00:00 2001 From: aaron <> Date: Mon, 13 Jan 2025 19:01:39 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=B8=80=E4=B8=AA=20logger?= =?UTF-8?q?=20=E4=B8=AD=E9=97=B4=E4=BB=B6=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/endpoints/user.py | 6 +-- app/core/logger.py | 22 ++++++++ app/core/security.py | 26 +++++++--- app/main.py | 4 ++ app/middleware/request_logger.py | 86 ++++++++++++++++++++++++++++++++ app/models/request_log.py | 19 +++++++ 6 files changed, 154 insertions(+), 9 deletions(-) create mode 100644 app/core/logger.py create mode 100644 app/middleware/request_logger.py create mode 100644 app/models/request_log.py diff --git a/app/api/endpoints/user.py b/app/api/endpoints/user.py index 81ba298..811f7d7 100644 --- a/app/api/endpoints/user.py +++ b/app/api/endpoints/user.py @@ -112,7 +112,7 @@ async def login( # 创建访问令牌 access_token = create_access_token( - data={"sub": user.phone} + data={"phone": user.phone,"userid":user.userid} ) # 设置JWT cookie @@ -168,7 +168,7 @@ async def mock_login( # 创建访问令牌 access_token = create_access_token( - data={"sub": user.phone} + data={"phone": user.phone,"userid":user.userid} ) # 设置JWT cookie @@ -273,7 +273,7 @@ async def password_login( return error_response(code=401, message="密码错误") # 生成访问令牌 - access_token = create_access_token(data={"sub": user.phone}) + access_token = create_access_token(data={"phone": user.phone,"userid":user.userid}) return success_response( data={ diff --git a/app/core/logger.py b/app/core/logger.py new file mode 100644 index 0000000..a56fc7a --- /dev/null +++ b/app/core/logger.py @@ -0,0 +1,22 @@ +from typing import Dict, Any +from app.models.database import SessionLocal +from app.models.request_log import RequestLogDB +import json +from threading import Thread + +def save_request_log(log_data: Dict[str, Any]): + """保存请求日志到数据库""" + db = SessionLocal() + try: + log = RequestLogDB(**log_data) + db.add(log) + db.commit() + except Exception as e: + db.rollback() + print(f"保存日志失败: {str(e)}") + finally: + db.close() + +def log_request_async(log_data: Dict[str, Any]): + """在新线程中异步处理日志""" + Thread(target=save_request_log, args=(log_data,), daemon=True).start() \ No newline at end of file diff --git a/app/core/security.py b/app/core/security.py index 4778b68..f67a814 100644 --- a/app/core/security.py +++ b/app/core/security.py @@ -10,9 +10,12 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: to_encode = data.copy() - if expires_delta is not None: - expire = datetime.now(timezone.utc) + expires_delta - to_encode.update({"exp": expire}) + + # 更新 token 数据,不设置过期时间 + to_encode.update({ + "userid": data.get("userid"), + "phone": data.get("phone") + }) encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm="HS256") return encoded_jwt @@ -41,12 +44,12 @@ def clear_jwt_cookie(response: Response): def verify_token(token: str) -> Optional[str]: try: payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"]) - phone: str = payload.get("sub") + phone: str = payload.get("phone") if phone is None: return None return phone except JWTError: - return None + return None def get_password_hash(password: str) -> str: """获取密码哈希值""" @@ -54,4 +57,15 @@ def get_password_hash(password: str) -> str: def verify_password(plain_password: str, hashed_password: str) -> bool: """验证密码""" - return pwd_context.verify(plain_password, hashed_password) \ No newline at end of file + return pwd_context.verify(plain_password, hashed_password) + +def decode_jwt(token: str) -> dict: + """解码 JWT token 获取完整信息""" + try: + payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"]) + return { + "userid": payload.get("userid"), + "phone": payload.get("phone") + } + except: + return None \ No newline at end of file diff --git a/app/main.py b/app/main.py index 98c7c3d..f6ee2bf 100644 --- a/app/main.py +++ b/app/main.py @@ -6,6 +6,7 @@ from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse from app.core.response import error_response from fastapi import HTTPException +from app.middleware.request_logger import RequestLoggerMiddleware # 创建数据库表 Base.metadata.create_all(bind=engine) @@ -25,6 +26,9 @@ app.add_middleware( allow_headers=["*"], ) +# 添加请求日志中间件 +app.add_middleware(RequestLoggerMiddleware) + # 添加用户路由 app.include_router(user.router, prefix="/api/user", tags=["用户"]) app.include_router(address.router, prefix="/api/address", tags=["配送地址"]) diff --git a/app/middleware/request_logger.py b/app/middleware/request_logger.py new file mode 100644 index 0000000..d36c469 --- /dev/null +++ b/app/middleware/request_logger.py @@ -0,0 +1,86 @@ +from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware +import time +from app.core.logger import log_request_async +from app.core.security import verify_token, decode_jwt +import json +import copy + +class RequestLoggerMiddleware(BaseHTTPMiddleware): + LOGGED_METHODS = {"POST", "PUT", "DELETE"} + SENSITIVE_PATHS = { + "/api/user/login", + "/api/user/password-login", + "/api/user/reset-password" + } + SENSITIVE_FIELDS = {"password", "verify_code", "old_password", "new_password"} + + def filter_sensitive_data(self, data: dict, path: str) -> dict: + """过滤敏感数据""" + if not data or path not in self.SENSITIVE_PATHS: + return data + + filtered_data = copy.deepcopy(data) + for field in self.SENSITIVE_FIELDS: + if field in filtered_data: + filtered_data[field] = "***" + return filtered_data + + async def dispatch(self, request: Request, call_next): + method = request.method + + if method not in self.LOGGED_METHODS: + return await call_next(request) + + start_time = time.time() + path = request.url.path + headers = dict(request.headers) + query_params = dict(request.query_params) + + # 获取并过滤请求体 + body = None + try: + body_bytes = await request.body() + if body_bytes: + body = json.loads(body_bytes) + body = self.filter_sensitive_data(body, path) + except: + pass + + # 从 Authorization 头获取 token + token = None + auth_header = headers.get('authorization') + if auth_header and auth_header.startswith('Bearer '): + token = auth_header.split(' ')[1] + + # 从 token 获取用户信息 + user_id = None + if token: + try: + payload = decode_jwt(token) + if payload: + user_id = payload.get("userid") + except: + pass + + # 处理请求 + response = await call_next(request) + + # 计算响应时间 + response_time = int((time.time() - start_time) * 1000) + + # 异步记录日志 + log_data = { + "path": path, + "method": method, + "headers": headers, + "query_params": query_params, + "body": body, + "user_id": user_id, + "ip_address": request.client.host, + "status_code": response.status_code, + "response_time": response_time + } + log_request_async(log_data) + + return response \ No newline at end of file diff --git a/app/models/request_log.py b/app/models/request_log.py new file mode 100644 index 0000000..1b919f1 --- /dev/null +++ b/app/models/request_log.py @@ -0,0 +1,19 @@ +from sqlalchemy import Column, String, Integer, DateTime, JSON +from sqlalchemy.sql import func +from .database import Base + +class RequestLogDB(Base): + __tablename__ = "request_logs" + + id = Column(Integer, primary_key=True, autoincrement=True) + path = Column(String(200), nullable=False) + method = Column(String(10), nullable=False) + headers = Column(JSON) + query_params = Column(JSON) + body = Column(JSON, nullable=True) + user_id = Column(Integer, nullable=True) + user_info = Column(JSON, nullable=True) # 添加用户详细信息 + ip_address = Column(String(50)) + status_code = Column(Integer) + response_time = Column(Integer) # 毫秒 + create_time = Column(DateTime(timezone=True), server_default=func.now()) \ No newline at end of file