This commit is contained in:
aaron 2025-03-10 20:18:52 +08:00
parent fbe860a014
commit 6c2333da4c
6 changed files with 327 additions and 12 deletions

View File

@ -70,14 +70,17 @@ async def get_communities(
): ):
"""获取社区列表""" """获取社区列表"""
# 构建查询, 关联社区分润 # 构建查询, 关联社区分润
# 使用一次查询获取所有需要的数据,减少数据库连接使用时间
query = db.query(CommunityDB).options(joinedload(CommunityDB.community_profit_sharing)) query = db.query(CommunityDB).options(joinedload(CommunityDB.community_profit_sharing))
# 状态过滤 # 状态过滤
if status: if status:
query = query.filter(CommunityDB.status == status) query = query.filter(CommunityDB.status == status)
# 获取总数 # 获取总数(使用子查询优化计数操作)
total = query.count() from sqlalchemy import func
count_query = query.statement.with_only_columns([func.count()]).order_by(None)
total = db.execute(count_query).scalar()
# 查询数据 # 查询数据
communities = query.offset(skip).limit(limit).all() communities = query.offset(skip).limit(limit).all()

View File

@ -0,0 +1,63 @@
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from sqlalchemy import text
from app.models.database import get_db, get_active_sessions_count, get_long_running_sessions
from app.core.response import success_response, ResponseModel
from app.core.db_monitor import DBConnectionMonitor
router = APIRouter()
@router.get("/health", response_model=ResponseModel)
async def health_check(db: Session = Depends(get_db)):
"""健康检查端点检查API和数据库连接状态"""
# 尝试执行简单查询以验证数据库连接
try:
db.execute(text("SELECT 1")).scalar()
db_status = "healthy"
except Exception as e:
db_status = f"unhealthy: {str(e)}"
# 获取连接池状态和性能统计
all_stats = DBConnectionMonitor.get_all_stats()
# 获取活跃会话信息
active_sessions_count = get_active_sessions_count()
long_running_sessions = get_long_running_sessions(threshold_seconds=30)
return success_response(data={
"status": "ok",
"database": {
"status": db_status,
"connection_pool": all_stats["connection_pool"],
"active_sessions": active_sessions_count,
"long_running_sessions": len(long_running_sessions)
}
})
@router.get("/stats", response_model=ResponseModel)
async def performance_stats(
include_slow_queries: bool = Query(False, description="是否包含慢查询记录"),
include_long_sessions: bool = Query(False, description="是否包含长时间运行的会话详情"),
db: Session = Depends(get_db)
):
"""获取详细的性能统计信息"""
# 获取所有统计信息
all_stats = DBConnectionMonitor.get_all_stats()
# 如果不需要包含慢查询记录,则移除它们以减少响应大小
if not include_slow_queries and "slow_queries" in all_stats["performance_stats"]:
all_stats["performance_stats"]["slow_queries_count"] = len(all_stats["performance_stats"]["slow_queries"])
del all_stats["performance_stats"]["slow_queries"]
# 添加会话信息
all_stats["sessions"] = {
"active_count": get_active_sessions_count()
}
# 如果需要包含长时间运行的会话详情
if include_long_sessions:
all_stats["sessions"]["long_running"] = get_long_running_sessions(threshold_seconds=30)
else:
all_stats["sessions"]["long_running_count"] = len(get_long_running_sessions(threshold_seconds=30))
return success_response(data=all_stats)

View File

@ -155,7 +155,7 @@ def calculate_delivery_share(order: ShippingOrderDB, db: Session) -> float:
CommunityProfitSharing.community_id == order.address_community_id CommunityProfitSharing.community_id == order.address_community_id
).first() ).first()
if sharing: if sharing:
return round(order.original_amount_with_additional_fee * (sharing.delivery_rate / 100.0), 2) return round(order.original_amount_with_additional_fee * (float(sharing.delivery_rate) / 100.0), 2)
else: else:
return 0 return 0

149
app/core/db_monitor.py Normal file
View File

@ -0,0 +1,149 @@
from app.models.database import engine
import logging
from fastapi import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
import time
import threading
from collections import defaultdict, deque
import datetime
logger = logging.getLogger(__name__)
class DBStats:
"""数据库统计信息"""
def __init__(self):
self.slow_queries = deque(maxlen=100) # 最多保存100条慢查询记录
self.endpoint_stats = defaultdict(lambda: {"count": 0, "total_time": 0, "max_time": 0})
self.hourly_requests = defaultdict(int)
self.lock = threading.Lock()
def record_slow_query(self, method, path, duration):
"""记录慢查询"""
with self.lock:
self.slow_queries.append({
"method": method,
"path": path,
"duration": duration,
"timestamp": datetime.datetime.now().isoformat()
})
def record_request(self, method, path, duration):
"""记录请求统计"""
with self.lock:
key = f"{method} {path}"
self.endpoint_stats[key]["count"] += 1
self.endpoint_stats[key]["total_time"] += duration
self.endpoint_stats[key]["max_time"] = max(self.endpoint_stats[key]["max_time"], duration)
# 记录每小时请求数
hour = datetime.datetime.now().strftime("%Y-%m-%d %H:00")
self.hourly_requests[hour] += 1
def get_stats(self):
"""获取统计信息"""
with self.lock:
# 计算平均响应时间
avg_times = {}
for endpoint, stats in self.endpoint_stats.items():
if stats["count"] > 0:
avg_times[endpoint] = stats["total_time"] / stats["count"]
# 按平均响应时间排序的前10个端点
sorted_endpoints = sorted(
avg_times.items(),
key=lambda x: x[1],
reverse=True
)[:10]
return {
"slow_queries": list(self.slow_queries),
"top_slow_endpoints": sorted_endpoints,
"hourly_requests": dict(self.hourly_requests)
}
class DBConnectionMonitor:
"""数据库连接监控工具"""
# 初始化统计对象
stats = DBStats()
@staticmethod
def get_connection_stats():
"""获取当前连接池状态"""
pool = engine.pool
return {
"pool_size": pool.size(),
"checkedin": pool.checkedin(),
"checkedout": pool.checkedout(),
"overflow": pool.overflow(),
}
@staticmethod
def log_connection_stats():
"""记录当前连接池状态"""
stats = DBConnectionMonitor.get_connection_stats()
logger.info(f"DB Connection Pool Stats: {stats}")
return stats
@staticmethod
def get_all_stats():
"""获取所有统计信息"""
return {
"connection_pool": DBConnectionMonitor.get_connection_stats(),
"performance_stats": DBConnectionMonitor.stats.get_stats()
}
class DBMonitorMiddleware(BaseHTTPMiddleware):
"""数据库连接监控中间件,定期记录连接池状态"""
def __init__(self, app: FastAPI, log_interval: int = 100, slow_query_threshold: float = 1.0):
super().__init__(app)
self.log_interval = log_interval
self.request_count = 0
self.slow_query_threshold = slow_query_threshold
async def dispatch(self, request: Request, call_next):
self.request_count += 1
# 每处理N个请求记录一次连接池状态
if self.request_count % self.log_interval == 0:
DBConnectionMonitor.log_connection_stats()
# 记录请求处理时间
start_time = time.time()
response = await call_next(request)
process_time = time.time() - start_time
# 记录请求统计
method = request.method
path = request.url.path
DBConnectionMonitor.stats.record_request(method, path, process_time)
# 如果请求处理时间超过阈值,记录为慢查询
if process_time > self.slow_query_threshold:
logger.warning(f"Slow request: {method} {path} took {process_time:.2f}s")
DBConnectionMonitor.stats.record_slow_query(method, path, process_time)
return response
def setup_db_monitor(app: FastAPI, log_interval: int = 100, slow_query_threshold: float = 1.0):
"""设置数据库监控"""
app.add_middleware(
DBMonitorMiddleware,
log_interval=log_interval,
slow_query_threshold=slow_query_threshold
)
@app.on_event("startup")
async def startup_db_monitor():
logger.info("Starting DB connection monitoring")
DBConnectionMonitor.log_connection_stats()
@app.on_event("shutdown")
async def shutdown_db_monitor():
logger.info("Final DB connection stats before shutdown")
DBConnectionMonitor.log_connection_stats()

View File

@ -1,6 +1,6 @@
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from app.api.endpoints import wechat,user, address, community, station, order, coupon, community_building, upload, merchant, merchant_product, merchant_order, point, config, merchant_category, log, account,merchant_pay_order, message, bank_card, withdraw, mp, point_product, point_product_order, coupon_activity, dashboard, wecom, feedback, timeperiod, community_timeperiod, order_additional_fee, ai, community_set, community_set_mapping, community_profit_sharing, partner from app.api.endpoints import wechat,user, address, community, station, order, coupon, community_building, upload, merchant, merchant_product, merchant_order, point, config, merchant_category, log, account,merchant_pay_order, message, bank_card, withdraw, mp, point_product, point_product_order, coupon_activity, dashboard, wecom, feedback, timeperiod, community_timeperiod, order_additional_fee, ai, community_set, community_set_mapping, community_profit_sharing, partner, health
from app.models.database import Base, engine from app.models.database import Base, engine
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
@ -15,6 +15,7 @@ from app.api.endpoints import wecom
from app.api.endpoints import feedback from app.api.endpoints import feedback
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
import os import os
from app.core.db_monitor import setup_db_monitor
# 创建数据库表 # 创建数据库表
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
@ -26,6 +27,9 @@ app = FastAPI(
docs_url="/docs" if settings.DEBUG else None docs_url="/docs" if settings.DEBUG else None
) )
# 设置数据库连接监控
setup_db_monitor(app)
app.default_response_class = CustomJSONResponse app.default_response_class = CustomJSONResponse
# 配置 CORS # 配置 CORS
@ -80,6 +84,7 @@ app.include_router(upload.router, prefix="/api/upload", tags=["文件上传"])
app.include_router(config.router, prefix="/api/config", tags=["系统配置"]) app.include_router(config.router, prefix="/api/config", tags=["系统配置"])
app.include_router(log.router, prefix="/api/logs", tags=["系统日志"]) app.include_router(log.router, prefix="/api/logs", tags=["系统日志"])
app.include_router(feedback.router, prefix="/api/feedback", tags=["反馈"]) app.include_router(feedback.router, prefix="/api/feedback", tags=["反馈"])
app.include_router(health.router, prefix="/api/health", tags=["系统健康检查"])
@app.get("/") @app.get("/")

View File

@ -1,25 +1,120 @@
from sqlalchemy import create_engine from sqlalchemy import create_engine, event
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker, scoped_session
from app.core.config import settings from app.core.config import settings
import pymysql import pymysql
import logging
import time
import threading
from contextlib import contextmanager
# 设置日志记录器
logger = logging.getLogger(__name__)
# 注册 MySQL Python SQL Driver # 注册 MySQL Python SQL Driver
pymysql.install_as_MySQLdb() pymysql.install_as_MySQLdb()
# 创建数据库引擎
engine = create_engine( engine = create_engine(
settings.SQLALCHEMY_DATABASE_URL, settings.SQLALCHEMY_DATABASE_URL,
pool_pre_ping=True, # 自动处理断开的连接 pool_pre_ping=True, # 自动处理断开的连接
pool_recycle=3600, # 连接回收时间 pool_recycle=3600, # 连接回收时间1小时
pool_size=10, # 连接池大小
max_overflow=20, # 允许的最大连接数超出pool_size的数量
pool_timeout=30, # 获取连接的超时时间(秒)
echo=settings.DEBUG, # 在调试模式下打印SQL语句
) )
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# 创建线程安全的会话工厂
SessionLocal = scoped_session(
sessionmaker(autocommit=False, autoflush=False, bind=engine)
)
# 声明基类
Base = declarative_base() Base = declarative_base()
# 会话跟踪
active_sessions = {}
session_lock = threading.Lock()
# 添加事件监听器,记录长时间运行的查询
@event.listens_for(engine, "before_cursor_execute")
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
"""在执行SQL查询前记录时间"""
conn.info.setdefault('query_start_time', []).append(time.time())
if settings.DEBUG:
# 安全地记录SQL语句避免敏感信息泄露
safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白
logger.debug(f"开始执行查询: {safe_statement}...")
@event.listens_for(engine, "after_cursor_execute")
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
"""在执行SQL查询后计算耗时并记录慢查询"""
total = time.time() - conn.info['query_start_time'].pop(-1)
# 记录慢查询
if total > 0.5: # 记录超过0.5秒的查询
# 安全地记录SQL语句避免敏感信息泄露
safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白
logger.warning(f"慢查询 ({total:.2f}s): {safe_statement}...")
# 依赖项 # 依赖项
def get_db(): def get_db():
db = SessionLocal() """获取数据库会话"""
session_id = threading.get_ident()
session = SessionLocal()
# 记录活跃会话
with session_lock:
active_sessions[session_id] = {
"created_at": time.time(),
"thread_id": session_id
}
try: try:
yield db yield session
finally: finally:
db.close() # 关闭会话并从活跃会话中移除
session.close()
with session_lock:
if session_id in active_sessions:
del active_sessions[session_id]
@contextmanager
def get_db_context():
"""上下文管理器版本的get_db用于非依赖项场景"""
session = SessionLocal()
session_id = threading.get_ident()
# 记录活跃会话
with session_lock:
active_sessions[session_id] = {
"created_at": time.time(),
"thread_id": session_id
}
try:
yield session
session.commit()
except Exception as e:
session.rollback()
raise e
finally:
session.close()
with session_lock:
if session_id in active_sessions:
del active_sessions[session_id]
def get_active_sessions_count():
"""获取当前活跃会话数量"""
with session_lock:
return len(active_sessions)
def get_long_running_sessions(threshold_seconds=60):
"""获取长时间运行的会话"""
current_time = time.time()
with session_lock:
return {
thread_id: info for thread_id, info in active_sessions.items()
if current_time - info["created_at"] > threshold_seconds
}