This commit is contained in:
aaron 2025-03-16 00:06:45 +08:00
parent ab7fa6b895
commit f1b6f56f3e
6 changed files with 112 additions and 52 deletions

View File

@ -197,7 +197,9 @@ async def password_login(
request: Request = None request: Request = None
): ):
"""密码登录""" """密码登录"""
print(f"login_data: {login_data}")
user = db.query(UserDB).filter(UserDB.phone == login_data.phone).first() user = db.query(UserDB).filter(UserDB.phone == login_data.phone).first()
print(f"user: {user}")
if not user: if not user:
return error_response(code=401, message="用户不存在") return error_response(code=401, message="用户不存在")

View File

@ -3,20 +3,51 @@ from app.models.database import SessionLocal
from app.models.request_log import RequestLogDB from app.models.request_log import RequestLogDB
import json import json
from threading import Thread from threading import Thread
import logging
import traceback
def save_request_log(log_data: Dict[str, Any]): def save_request_log(log_data: Dict[str, Any]):
"""保存请求日志到数据库""" """保存请求日志到数据库"""
db = SessionLocal() db = None
try: try:
db = SessionLocal()
# 确保headers和body可以序列化为JSON
if 'headers' in log_data and log_data['headers']:
try:
# 尝试将headers转换为JSON字符串再解析回来确保可序列化
json.dumps(log_data['headers'])
except (TypeError, OverflowError):
# 如果无法序列化,则转换为字符串
log_data['headers'] = str(log_data['headers'])
if 'body' in log_data and log_data['body']:
try:
# 尝试将body转换为JSON字符串再解析回来确保可序列化
json.dumps(log_data['body'])
except (TypeError, OverflowError):
# 如果无法序列化,则转换为字符串
log_data['body'] = str(log_data['body'])
# 创建日志记录
log = RequestLogDB(**log_data) log = RequestLogDB(**log_data)
db.add(log) db.add(log)
db.commit() db.commit()
print(f"请求日志已保存: {log_data['path']}")
except Exception as e: except Exception as e:
db.rollback()
print(f"保存日志失败: {str(e)}") print(f"保存日志失败: {str(e)}")
print(traceback.format_exc())
if db:
db.rollback()
finally: finally:
db.close() if db:
db.close()
def log_request_async(log_data: Dict[str, Any]): def log_request_async(log_data: Dict[str, Any]):
"""在新线程中异步处理日志""" """在新线程中异步处理日志"""
Thread(target=save_request_log, args=(log_data,), daemon=True).start() try:
Thread(target=save_request_log, args=(log_data,), daemon=True).start()
except Exception as e:
print(f"启动日志线程失败: {str(e)}")
print(traceback.format_exc())

View File

@ -26,7 +26,6 @@ def verify_token(token: str) -> Optional[str]:
try: try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"]) payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
sub: str = payload.get("sub") sub: str = payload.get("sub")
print(f"payload: {payload}")
return sub return sub
except JWTError: except JWTError:
return None return None

View File

@ -5,6 +5,12 @@ from app.core.logger import log_request_async
from app.core.security import verify_token, decode_jwt from app.core.security import verify_token, decode_jwt
import json import json
import copy import copy
import asyncio
from starlette.responses import JSONResponse
import logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
class RequestLoggerMiddleware(BaseHTTPMiddleware): class RequestLoggerMiddleware(BaseHTTPMiddleware):
LOGGED_METHODS = {"POST", "PUT", "DELETE"} LOGGED_METHODS = {"POST", "PUT", "DELETE"}
@ -26,6 +32,17 @@ class RequestLoggerMiddleware(BaseHTTPMiddleware):
filtered_data[field] = "***" filtered_data[field] = "***"
return filtered_data return filtered_data
def filter_headers(self, headers: dict) -> dict:
"""过滤敏感请求头"""
filtered_headers = copy.deepcopy(headers)
sensitive_headers = ["authorization", "cookie"]
for header in sensitive_headers:
if header in filtered_headers:
filtered_headers[header] = "***"
return filtered_headers
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
method = request.method method = request.method
@ -34,54 +51,65 @@ class RequestLoggerMiddleware(BaseHTTPMiddleware):
start_time = time.time() start_time = time.time()
path = request.url.path path = request.url.path
headers = dict(request.headers)
# 不记录健康检查请求
if path.endswith("/health") or path.endswith("/ping"):
return await call_next(request)
# 过滤敏感请求头
headers = self.filter_headers(dict(request.headers))
query_params = dict(request.query_params) query_params = dict(request.query_params)
# 获取并过滤请求体 # 创建请求体的副本,而不消费原始请求体
body = None body = None
try: try:
# 保存原始请求体内容
body_bytes = await request.body() body_bytes = await request.body()
# 重要设置_body属性使FastAPI可以再次读取请求体
request._body = body_bytes
if body_bytes: if body_bytes:
body = json.loads(body_bytes) try:
body = self.filter_sensitive_data(body, path) body_str = body_bytes.decode()
print(f"请求体: {body}") body = json.loads(body_str)
except: body = self.filter_sensitive_data(body, path)
pass except json.JSONDecodeError:
body = {"raw": "无法解析的JSON数据"}
except UnicodeDecodeError:
body = {"raw": "无法解码的二进制数据"}
except Exception as e:
logger.error(f"读取请求体异常: {str(e)}")
body = {"error": "读取请求体异常"}
# 从 Authorization 头获取 token # 处理请求,添加超时保护
# token = None try:
# auth_header = headers.get('authorization') response = await call_next(request)
# if auth_header and auth_header.startswith('Bearer '):
# token = auth_header.split(' ')[1] # 计算响应时间
response_time = int((time.time() - start_time) * 1000)
# # 从 token 获取用户信息
# user_id = None # 异步记录日志,捕获可能的异常
# if token: try:
# try: log_data = {
# payload = decode_jwt(token) "path": path,
# if payload: "method": method,
# user_id = payload.get("phone") "headers": headers,
# except: "query_params": query_params,
# pass "body": body,
"ip_address": request.client.host if hasattr(request, "client") and request.client else "unknown",
# 处理请求 "status_code": response.status_code,
response = await call_next(request) "response_time": response_time
}
# 计算响应时间 log_request_async(log_data)
response_time = int((time.time() - start_time) * 1000) except Exception as e:
logger.error(f"记录请求日志异常: {str(e)}")
# 异步记录日志
log_data = { return response
"path": path,
"method": method, except Exception as e:
"headers": headers, logger.error(f"请求处理异常: {str(e)}")
"query_params": query_params, return JSONResponse(
"body": body, status_code=500,
# "user_id": user_id, content={"code": 500, "message": "服务器内部错误"}
"ip_address": request.client.host, )
"status_code": response.status_code,
"response_time": response_time
}
log_request_async(log_data)
return response

View File

@ -112,9 +112,9 @@ class UserUpdate(BaseModel):
extra = "forbid" # 禁止额外字段 extra = "forbid" # 禁止额外字段
class UserPasswordLogin(BaseModel): class UserPasswordLogin(BaseModel):
phone: str = Field(..., pattern="^1[3-9]\d{9}$") phone: str = Field(..., pattern="^1[3-9]\d{9}$", description="手机号")
password: str = Field(..., min_length=6, max_length=20) password: str = Field(..., min_length=6, max_length=20, description="密码")
role: UserRole = Field(default=UserRole.DELIVERYMAN) role: UserRole = Field(default=UserRole.ADMIN, description="角色")
class ChangePasswordRequest(BaseModel): class ChangePasswordRequest(BaseModel):
phone: str = Field(..., pattern="^1[3-9]\d{9}$") phone: str = Field(..., pattern="^1[3-9]\d{9}$")

Binary file not shown.