diff --git a/app/api/endpoints/community.py b/app/api/endpoints/community.py index eb5aec0..5cbaf4f 100644 --- a/app/api/endpoints/community.py +++ b/app/api/endpoints/community.py @@ -70,14 +70,17 @@ async def get_communities( ): """获取社区列表""" # 构建查询, 关联社区分润 + # 使用一次查询获取所有需要的数据,减少数据库连接使用时间 query = db.query(CommunityDB).options(joinedload(CommunityDB.community_profit_sharing)) # 状态过滤 if status: query = query.filter(CommunityDB.status == status) - # 获取总数 - total = query.count() + # 获取总数(使用子查询优化计数操作) + from sqlalchemy import func + count_query = query.statement.with_only_columns([func.count()]).order_by(None) + total = db.execute(count_query).scalar() # 查询数据 communities = query.offset(skip).limit(limit).all() diff --git a/app/api/endpoints/health.py b/app/api/endpoints/health.py new file mode 100644 index 0000000..d5baeff --- /dev/null +++ b/app/api/endpoints/health.py @@ -0,0 +1,63 @@ +from fastapi import APIRouter, Depends, Query +from sqlalchemy.orm import Session +from sqlalchemy import text +from app.models.database import get_db, get_active_sessions_count, get_long_running_sessions +from app.core.response import success_response, ResponseModel +from app.core.db_monitor import DBConnectionMonitor + +router = APIRouter() + +@router.get("/health", response_model=ResponseModel) +async def health_check(db: Session = Depends(get_db)): + """健康检查端点,检查API和数据库连接状态""" + # 尝试执行简单查询以验证数据库连接 + try: + db.execute(text("SELECT 1")).scalar() + db_status = "healthy" + except Exception as e: + db_status = f"unhealthy: {str(e)}" + + # 获取连接池状态和性能统计 + all_stats = DBConnectionMonitor.get_all_stats() + + # 获取活跃会话信息 + active_sessions_count = get_active_sessions_count() + long_running_sessions = get_long_running_sessions(threshold_seconds=30) + + return success_response(data={ + "status": "ok", + "database": { + "status": db_status, + "connection_pool": all_stats["connection_pool"], + "active_sessions": active_sessions_count, + "long_running_sessions": len(long_running_sessions) + } + }) + +@router.get("/stats", response_model=ResponseModel) +async def performance_stats( + include_slow_queries: bool = Query(False, description="是否包含慢查询记录"), + include_long_sessions: bool = Query(False, description="是否包含长时间运行的会话详情"), + db: Session = Depends(get_db) +): + """获取详细的性能统计信息""" + # 获取所有统计信息 + all_stats = DBConnectionMonitor.get_all_stats() + + # 如果不需要包含慢查询记录,则移除它们以减少响应大小 + if not include_slow_queries and "slow_queries" in all_stats["performance_stats"]: + all_stats["performance_stats"]["slow_queries_count"] = len(all_stats["performance_stats"]["slow_queries"]) + del all_stats["performance_stats"]["slow_queries"] + + # 添加会话信息 + all_stats["sessions"] = { + "active_count": get_active_sessions_count() + } + + # 如果需要包含长时间运行的会话详情 + if include_long_sessions: + all_stats["sessions"]["long_running"] = get_long_running_sessions(threshold_seconds=30) + else: + all_stats["sessions"]["long_running_count"] = len(get_long_running_sessions(threshold_seconds=30)) + + return success_response(data=all_stats) \ No newline at end of file diff --git a/app/api/endpoints/order.py b/app/api/endpoints/order.py index 2a68196..f430c70 100644 --- a/app/api/endpoints/order.py +++ b/app/api/endpoints/order.py @@ -155,7 +155,7 @@ def calculate_delivery_share(order: ShippingOrderDB, db: Session) -> float: CommunityProfitSharing.community_id == order.address_community_id ).first() if sharing: - return round(order.original_amount_with_additional_fee * (sharing.delivery_rate / 100.0), 2) + return round(order.original_amount_with_additional_fee * (float(sharing.delivery_rate) / 100.0), 2) else: return 0 diff --git a/app/core/db_monitor.py b/app/core/db_monitor.py new file mode 100644 index 0000000..978862d --- /dev/null +++ b/app/core/db_monitor.py @@ -0,0 +1,149 @@ +from app.models.database import engine +import logging +from fastapi import FastAPI +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +import time +import threading +from collections import defaultdict, deque +import datetime + +logger = logging.getLogger(__name__) + +class DBStats: + """数据库统计信息""" + def __init__(self): + self.slow_queries = deque(maxlen=100) # 最多保存100条慢查询记录 + self.endpoint_stats = defaultdict(lambda: {"count": 0, "total_time": 0, "max_time": 0}) + self.hourly_requests = defaultdict(int) + self.lock = threading.Lock() + + def record_slow_query(self, method, path, duration): + """记录慢查询""" + with self.lock: + self.slow_queries.append({ + "method": method, + "path": path, + "duration": duration, + "timestamp": datetime.datetime.now().isoformat() + }) + + def record_request(self, method, path, duration): + """记录请求统计""" + with self.lock: + key = f"{method} {path}" + self.endpoint_stats[key]["count"] += 1 + self.endpoint_stats[key]["total_time"] += duration + self.endpoint_stats[key]["max_time"] = max(self.endpoint_stats[key]["max_time"], duration) + + # 记录每小时请求数 + hour = datetime.datetime.now().strftime("%Y-%m-%d %H:00") + self.hourly_requests[hour] += 1 + + def get_stats(self): + """获取统计信息""" + with self.lock: + # 计算平均响应时间 + avg_times = {} + for endpoint, stats in self.endpoint_stats.items(): + if stats["count"] > 0: + avg_times[endpoint] = stats["total_time"] / stats["count"] + + # 按平均响应时间排序的前10个端点 + sorted_endpoints = sorted( + avg_times.items(), + key=lambda x: x[1], + reverse=True + )[:10] + + return { + "slow_queries": list(self.slow_queries), + "top_slow_endpoints": sorted_endpoints, + "hourly_requests": dict(self.hourly_requests) + } + + +class DBConnectionMonitor: + """数据库连接监控工具""" + + # 初始化统计对象 + stats = DBStats() + + @staticmethod + def get_connection_stats(): + """获取当前连接池状态""" + pool = engine.pool + return { + "pool_size": pool.size(), + "checkedin": pool.checkedin(), + "checkedout": pool.checkedout(), + "overflow": pool.overflow(), + } + + @staticmethod + def log_connection_stats(): + """记录当前连接池状态""" + stats = DBConnectionMonitor.get_connection_stats() + logger.info(f"DB Connection Pool Stats: {stats}") + return stats + + @staticmethod + def get_all_stats(): + """获取所有统计信息""" + return { + "connection_pool": DBConnectionMonitor.get_connection_stats(), + "performance_stats": DBConnectionMonitor.stats.get_stats() + } + + +class DBMonitorMiddleware(BaseHTTPMiddleware): + """数据库连接监控中间件,定期记录连接池状态""" + + def __init__(self, app: FastAPI, log_interval: int = 100, slow_query_threshold: float = 1.0): + super().__init__(app) + self.log_interval = log_interval + self.request_count = 0 + self.slow_query_threshold = slow_query_threshold + + async def dispatch(self, request: Request, call_next): + self.request_count += 1 + + # 每处理N个请求记录一次连接池状态 + if self.request_count % self.log_interval == 0: + DBConnectionMonitor.log_connection_stats() + + # 记录请求处理时间 + start_time = time.time() + response = await call_next(request) + process_time = time.time() - start_time + + # 记录请求统计 + method = request.method + path = request.url.path + DBConnectionMonitor.stats.record_request(method, path, process_time) + + # 如果请求处理时间超过阈值,记录为慢查询 + if process_time > self.slow_query_threshold: + logger.warning(f"Slow request: {method} {path} took {process_time:.2f}s") + DBConnectionMonitor.stats.record_slow_query(method, path, process_time) + + return response + + +def setup_db_monitor(app: FastAPI, log_interval: int = 100, slow_query_threshold: float = 1.0): + """设置数据库监控""" + app.add_middleware( + DBMonitorMiddleware, + log_interval=log_interval, + slow_query_threshold=slow_query_threshold + ) + + @app.on_event("startup") + async def startup_db_monitor(): + logger.info("Starting DB connection monitoring") + DBConnectionMonitor.log_connection_stats() + + @app.on_event("shutdown") + async def shutdown_db_monitor(): + logger.info("Final DB connection stats before shutdown") + DBConnectionMonitor.log_connection_stats() \ No newline at end of file diff --git a/app/main.py b/app/main.py index 93bc7ae..79d7b87 100644 --- a/app/main.py +++ b/app/main.py @@ -1,6 +1,6 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from app.api.endpoints import wechat,user, address, community, station, order, coupon, community_building, upload, merchant, merchant_product, merchant_order, point, config, merchant_category, log, account,merchant_pay_order, message, bank_card, withdraw, mp, point_product, point_product_order, coupon_activity, dashboard, wecom, feedback, timeperiod, community_timeperiod, order_additional_fee, ai, community_set, community_set_mapping, community_profit_sharing, partner +from app.api.endpoints import wechat,user, address, community, station, order, coupon, community_building, upload, merchant, merchant_product, merchant_order, point, config, merchant_category, log, account,merchant_pay_order, message, bank_card, withdraw, mp, point_product, point_product_order, coupon_activity, dashboard, wecom, feedback, timeperiod, community_timeperiod, order_additional_fee, ai, community_set, community_set_mapping, community_profit_sharing, partner, health from app.models.database import Base, engine from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse @@ -15,6 +15,7 @@ from app.api.endpoints import wecom from app.api.endpoints import feedback from starlette.middleware.sessions import SessionMiddleware import os +from app.core.db_monitor import setup_db_monitor # 创建数据库表 Base.metadata.create_all(bind=engine) @@ -26,6 +27,9 @@ app = FastAPI( docs_url="/docs" if settings.DEBUG else None ) +# 设置数据库连接监控 +setup_db_monitor(app) + app.default_response_class = CustomJSONResponse # 配置 CORS @@ -80,6 +84,7 @@ app.include_router(upload.router, prefix="/api/upload", tags=["文件上传"]) app.include_router(config.router, prefix="/api/config", tags=["系统配置"]) app.include_router(log.router, prefix="/api/logs", tags=["系统日志"]) app.include_router(feedback.router, prefix="/api/feedback", tags=["反馈"]) +app.include_router(health.router, prefix="/api/health", tags=["系统健康检查"]) @app.get("/") diff --git a/app/models/database.py b/app/models/database.py index b53ee26..eac6900 100644 --- a/app/models/database.py +++ b/app/models/database.py @@ -1,25 +1,120 @@ -from sqlalchemy import create_engine +from sqlalchemy import create_engine, event from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import sessionmaker, scoped_session from app.core.config import settings import pymysql +import logging +import time +import threading +from contextlib import contextmanager + +# 设置日志记录器 +logger = logging.getLogger(__name__) # 注册 MySQL Python SQL Driver pymysql.install_as_MySQLdb() +# 创建数据库引擎 engine = create_engine( settings.SQLALCHEMY_DATABASE_URL, - pool_pre_ping=True, # 自动处理断开的连接 - pool_recycle=3600, # 连接回收时间 + pool_pre_ping=True, # 自动处理断开的连接 + pool_recycle=3600, # 连接回收时间(1小时) + pool_size=10, # 连接池大小 + max_overflow=20, # 允许的最大连接数超出pool_size的数量 + pool_timeout=30, # 获取连接的超时时间(秒) + echo=settings.DEBUG, # 在调试模式下打印SQL语句 ) -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +# 创建线程安全的会话工厂 +SessionLocal = scoped_session( + sessionmaker(autocommit=False, autoflush=False, bind=engine) +) + +# 声明基类 Base = declarative_base() +# 会话跟踪 +active_sessions = {} +session_lock = threading.Lock() + +# 添加事件监听器,记录长时间运行的查询 +@event.listens_for(engine, "before_cursor_execute") +def before_cursor_execute(conn, cursor, statement, parameters, context, executemany): + """在执行SQL查询前记录时间""" + conn.info.setdefault('query_start_time', []).append(time.time()) + if settings.DEBUG: + # 安全地记录SQL语句,避免敏感信息泄露 + safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白 + logger.debug(f"开始执行查询: {safe_statement}...") + +@event.listens_for(engine, "after_cursor_execute") +def after_cursor_execute(conn, cursor, statement, parameters, context, executemany): + """在执行SQL查询后计算耗时并记录慢查询""" + total = time.time() - conn.info['query_start_time'].pop(-1) + + # 记录慢查询 + if total > 0.5: # 记录超过0.5秒的查询 + # 安全地记录SQL语句,避免敏感信息泄露 + safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白 + logger.warning(f"慢查询 ({total:.2f}s): {safe_statement}...") + # 依赖项 def get_db(): - db = SessionLocal() + """获取数据库会话""" + session_id = threading.get_ident() + session = SessionLocal() + + # 记录活跃会话 + with session_lock: + active_sessions[session_id] = { + "created_at": time.time(), + "thread_id": session_id + } + try: - yield db + yield session finally: - db.close() \ No newline at end of file + # 关闭会话并从活跃会话中移除 + session.close() + with session_lock: + if session_id in active_sessions: + del active_sessions[session_id] + +@contextmanager +def get_db_context(): + """上下文管理器版本的get_db,用于非依赖项场景""" + session = SessionLocal() + session_id = threading.get_ident() + + # 记录活跃会话 + with session_lock: + active_sessions[session_id] = { + "created_at": time.time(), + "thread_id": session_id + } + + try: + yield session + session.commit() + except Exception as e: + session.rollback() + raise e + finally: + session.close() + with session_lock: + if session_id in active_sessions: + del active_sessions[session_id] + +def get_active_sessions_count(): + """获取当前活跃会话数量""" + with session_lock: + return len(active_sessions) + +def get_long_running_sessions(threshold_seconds=60): + """获取长时间运行的会话""" + current_time = time.time() + with session_lock: + return { + thread_id: info for thread_id, info in active_sessions.items() + if current_time - info["created_at"] > threshold_seconds + } \ No newline at end of file