update
This commit is contained in:
parent
ab7fa6b895
commit
f1b6f56f3e
@ -197,7 +197,9 @@ async def password_login(
|
|||||||
request: Request = None
|
request: Request = None
|
||||||
):
|
):
|
||||||
"""密码登录"""
|
"""密码登录"""
|
||||||
|
print(f"login_data: {login_data}")
|
||||||
user = db.query(UserDB).filter(UserDB.phone == login_data.phone).first()
|
user = db.query(UserDB).filter(UserDB.phone == login_data.phone).first()
|
||||||
|
print(f"user: {user}")
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
return error_response(code=401, message="用户不存在")
|
return error_response(code=401, message="用户不存在")
|
||||||
|
|||||||
@ -3,20 +3,51 @@ from app.models.database import SessionLocal
|
|||||||
from app.models.request_log import RequestLogDB
|
from app.models.request_log import RequestLogDB
|
||||||
import json
|
import json
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
|
||||||
def save_request_log(log_data: Dict[str, Any]):
|
def save_request_log(log_data: Dict[str, Any]):
|
||||||
"""保存请求日志到数据库"""
|
"""保存请求日志到数据库"""
|
||||||
db = SessionLocal()
|
db = None
|
||||||
try:
|
try:
|
||||||
|
db = SessionLocal()
|
||||||
|
|
||||||
|
# 确保headers和body可以序列化为JSON
|
||||||
|
if 'headers' in log_data and log_data['headers']:
|
||||||
|
try:
|
||||||
|
# 尝试将headers转换为JSON字符串再解析回来,确保可序列化
|
||||||
|
json.dumps(log_data['headers'])
|
||||||
|
except (TypeError, OverflowError):
|
||||||
|
# 如果无法序列化,则转换为字符串
|
||||||
|
log_data['headers'] = str(log_data['headers'])
|
||||||
|
|
||||||
|
if 'body' in log_data and log_data['body']:
|
||||||
|
try:
|
||||||
|
# 尝试将body转换为JSON字符串再解析回来,确保可序列化
|
||||||
|
json.dumps(log_data['body'])
|
||||||
|
except (TypeError, OverflowError):
|
||||||
|
# 如果无法序列化,则转换为字符串
|
||||||
|
log_data['body'] = str(log_data['body'])
|
||||||
|
|
||||||
|
# 创建日志记录
|
||||||
log = RequestLogDB(**log_data)
|
log = RequestLogDB(**log_data)
|
||||||
db.add(log)
|
db.add(log)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
print(f"请求日志已保存: {log_data['path']}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db.rollback()
|
|
||||||
print(f"保存日志失败: {str(e)}")
|
print(f"保存日志失败: {str(e)}")
|
||||||
|
print(traceback.format_exc())
|
||||||
|
if db:
|
||||||
|
db.rollback()
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
if db:
|
||||||
|
db.close()
|
||||||
|
|
||||||
def log_request_async(log_data: Dict[str, Any]):
|
def log_request_async(log_data: Dict[str, Any]):
|
||||||
"""在新线程中异步处理日志"""
|
"""在新线程中异步处理日志"""
|
||||||
Thread(target=save_request_log, args=(log_data,), daemon=True).start()
|
try:
|
||||||
|
Thread(target=save_request_log, args=(log_data,), daemon=True).start()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"启动日志线程失败: {str(e)}")
|
||||||
|
print(traceback.format_exc())
|
||||||
@ -26,7 +26,6 @@ def verify_token(token: str) -> Optional[str]:
|
|||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
|
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"])
|
||||||
sub: str = payload.get("sub")
|
sub: str = payload.get("sub")
|
||||||
print(f"payload: {payload}")
|
|
||||||
return sub
|
return sub
|
||||||
except JWTError:
|
except JWTError:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@ -5,6 +5,12 @@ from app.core.logger import log_request_async
|
|||||||
from app.core.security import verify_token, decode_jwt
|
from app.core.security import verify_token, decode_jwt
|
||||||
import json
|
import json
|
||||||
import copy
|
import copy
|
||||||
|
import asyncio
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
class RequestLoggerMiddleware(BaseHTTPMiddleware):
|
class RequestLoggerMiddleware(BaseHTTPMiddleware):
|
||||||
LOGGED_METHODS = {"POST", "PUT", "DELETE"}
|
LOGGED_METHODS = {"POST", "PUT", "DELETE"}
|
||||||
@ -26,6 +32,17 @@ class RequestLoggerMiddleware(BaseHTTPMiddleware):
|
|||||||
filtered_data[field] = "***"
|
filtered_data[field] = "***"
|
||||||
return filtered_data
|
return filtered_data
|
||||||
|
|
||||||
|
def filter_headers(self, headers: dict) -> dict:
|
||||||
|
"""过滤敏感请求头"""
|
||||||
|
filtered_headers = copy.deepcopy(headers)
|
||||||
|
sensitive_headers = ["authorization", "cookie"]
|
||||||
|
|
||||||
|
for header in sensitive_headers:
|
||||||
|
if header in filtered_headers:
|
||||||
|
filtered_headers[header] = "***"
|
||||||
|
|
||||||
|
return filtered_headers
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
method = request.method
|
method = request.method
|
||||||
|
|
||||||
@ -34,54 +51,65 @@ class RequestLoggerMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
path = request.url.path
|
path = request.url.path
|
||||||
headers = dict(request.headers)
|
|
||||||
|
# 不记录健康检查请求
|
||||||
|
if path.endswith("/health") or path.endswith("/ping"):
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
# 过滤敏感请求头
|
||||||
|
headers = self.filter_headers(dict(request.headers))
|
||||||
query_params = dict(request.query_params)
|
query_params = dict(request.query_params)
|
||||||
|
|
||||||
# 获取并过滤请求体
|
# 创建请求体的副本,而不消费原始请求体
|
||||||
body = None
|
body = None
|
||||||
try:
|
try:
|
||||||
|
# 保存原始请求体内容
|
||||||
body_bytes = await request.body()
|
body_bytes = await request.body()
|
||||||
|
|
||||||
|
# 重要:设置_body属性,使FastAPI可以再次读取请求体
|
||||||
|
request._body = body_bytes
|
||||||
|
|
||||||
if body_bytes:
|
if body_bytes:
|
||||||
body = json.loads(body_bytes)
|
try:
|
||||||
body = self.filter_sensitive_data(body, path)
|
body_str = body_bytes.decode()
|
||||||
print(f"请求体: {body}")
|
body = json.loads(body_str)
|
||||||
except:
|
body = self.filter_sensitive_data(body, path)
|
||||||
pass
|
except json.JSONDecodeError:
|
||||||
|
body = {"raw": "无法解析的JSON数据"}
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
body = {"raw": "无法解码的二进制数据"}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取请求体异常: {str(e)}")
|
||||||
|
body = {"error": "读取请求体异常"}
|
||||||
|
|
||||||
# 从 Authorization 头获取 token
|
# 处理请求,添加超时保护
|
||||||
# token = None
|
try:
|
||||||
# auth_header = headers.get('authorization')
|
response = await call_next(request)
|
||||||
# if auth_header and auth_header.startswith('Bearer '):
|
|
||||||
# token = auth_header.split(' ')[1]
|
# 计算响应时间
|
||||||
|
response_time = int((time.time() - start_time) * 1000)
|
||||||
# # 从 token 获取用户信息
|
|
||||||
# user_id = None
|
# 异步记录日志,捕获可能的异常
|
||||||
# if token:
|
try:
|
||||||
# try:
|
log_data = {
|
||||||
# payload = decode_jwt(token)
|
"path": path,
|
||||||
# if payload:
|
"method": method,
|
||||||
# user_id = payload.get("phone")
|
"headers": headers,
|
||||||
# except:
|
"query_params": query_params,
|
||||||
# pass
|
"body": body,
|
||||||
|
"ip_address": request.client.host if hasattr(request, "client") and request.client else "unknown",
|
||||||
# 处理请求
|
"status_code": response.status_code,
|
||||||
response = await call_next(request)
|
"response_time": response_time
|
||||||
|
}
|
||||||
# 计算响应时间
|
log_request_async(log_data)
|
||||||
response_time = int((time.time() - start_time) * 1000)
|
except Exception as e:
|
||||||
|
logger.error(f"记录请求日志异常: {str(e)}")
|
||||||
# 异步记录日志
|
|
||||||
log_data = {
|
return response
|
||||||
"path": path,
|
|
||||||
"method": method,
|
except Exception as e:
|
||||||
"headers": headers,
|
logger.error(f"请求处理异常: {str(e)}")
|
||||||
"query_params": query_params,
|
return JSONResponse(
|
||||||
"body": body,
|
status_code=500,
|
||||||
# "user_id": user_id,
|
content={"code": 500, "message": "服务器内部错误"}
|
||||||
"ip_address": request.client.host,
|
)
|
||||||
"status_code": response.status_code,
|
|
||||||
"response_time": response_time
|
|
||||||
}
|
|
||||||
log_request_async(log_data)
|
|
||||||
|
|
||||||
return response
|
|
||||||
@ -112,9 +112,9 @@ class UserUpdate(BaseModel):
|
|||||||
extra = "forbid" # 禁止额外字段
|
extra = "forbid" # 禁止额外字段
|
||||||
|
|
||||||
class UserPasswordLogin(BaseModel):
|
class UserPasswordLogin(BaseModel):
|
||||||
phone: str = Field(..., pattern="^1[3-9]\d{9}$")
|
phone: str = Field(..., pattern="^1[3-9]\d{9}$", description="手机号")
|
||||||
password: str = Field(..., min_length=6, max_length=20)
|
password: str = Field(..., min_length=6, max_length=20, description="密码")
|
||||||
role: UserRole = Field(default=UserRole.DELIVERYMAN)
|
role: UserRole = Field(default=UserRole.ADMIN, description="角色")
|
||||||
|
|
||||||
class ChangePasswordRequest(BaseModel):
|
class ChangePasswordRequest(BaseModel):
|
||||||
phone: str = Field(..., pattern="^1[3-9]\d{9}$")
|
phone: str = Field(..., pattern="^1[3-9]\d{9}$")
|
||||||
|
|||||||
BIN
jobs.sqlite
BIN
jobs.sqlite
Binary file not shown.
Loading…
Reference in New Issue
Block a user