update
parent fbe860a014 · commit 6c2333da4c
@@ -70,14 +70,17 @@ async def get_communities(
):
    """Get the community list"""
-    # Build the query, joining the community profit-sharing relation
+    # Fetch all required data in a single query to reduce the time a database connection is held
+    query = db.query(CommunityDB).options(joinedload(CommunityDB.community_profit_sharing))

    # Filter by status
    if status:
        query = query.filter(CommunityDB.status == status)

-    # Get the total count
-    total = query.count()
+    # Get the total count (use a subquery to optimize the count operation)
+    from sqlalchemy import func
+    count_query = query.statement.with_only_columns([func.count()]).order_by(None)
+    total = db.execute(count_query).scalar()

    # Query the page of data
    communities = query.offset(skip).limit(limit).all()

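Note: the count pattern above avoids loading full rows just to compute a total. A minimal standalone sketch of the same idea (the helper name is illustrative; on SQLAlchemy 1.4+/2.0, with_only_columns takes columns positionally rather than as a list):

from sqlalchemy import func

def count_rows(db, query):
    # Swap the SELECT list for COUNT(*) and drop ORDER BY before executing
    count_stmt = query.statement.with_only_columns([func.count()]).order_by(None)
    return db.execute(count_stmt).scalar()

total = count_rows(db, query)  # used in place of query.count()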
app/api/endpoints/health.py (new file, 63 lines)
@@ -0,0 +1,63 @@
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from sqlalchemy import text
from app.models.database import get_db, get_active_sessions_count, get_long_running_sessions
from app.core.response import success_response, ResponseModel
from app.core.db_monitor import DBConnectionMonitor

router = APIRouter()


@router.get("/health", response_model=ResponseModel)
async def health_check(db: Session = Depends(get_db)):
    """Health-check endpoint: checks API and database connection status"""
    # Run a simple query to verify the database connection
    try:
        db.execute(text("SELECT 1")).scalar()
        db_status = "healthy"
    except Exception as e:
        db_status = f"unhealthy: {str(e)}"

    # Get connection-pool status and performance statistics
    all_stats = DBConnectionMonitor.get_all_stats()

    # Get active-session information
    active_sessions_count = get_active_sessions_count()
    long_running_sessions = get_long_running_sessions(threshold_seconds=30)

    return success_response(data={
        "status": "ok",
        "database": {
            "status": db_status,
            "connection_pool": all_stats["connection_pool"],
            "active_sessions": active_sessions_count,
            "long_running_sessions": len(long_running_sessions)
        }
    })


@router.get("/stats", response_model=ResponseModel)
async def performance_stats(
    include_slow_queries: bool = Query(False, description="是否包含慢查询记录"),
    include_long_sessions: bool = Query(False, description="是否包含长时间运行的会话详情"),
    db: Session = Depends(get_db)
):
    """Return detailed performance statistics"""
    # Collect all statistics
    all_stats = DBConnectionMonitor.get_all_stats()

    # If slow-query records were not requested, remove them to reduce the response size
    if not include_slow_queries and "slow_queries" in all_stats["performance_stats"]:
        all_stats["performance_stats"]["slow_queries_count"] = len(all_stats["performance_stats"]["slow_queries"])
        del all_stats["performance_stats"]["slow_queries"]

    # Add session information
    all_stats["sessions"] = {
        "active_count": get_active_sessions_count()
    }

    # Include details of long-running sessions only when requested
    if include_long_sessions:
        all_stats["sessions"]["long_running"] = get_long_running_sessions(threshold_seconds=30)
    else:
        all_stats["sessions"]["long_running_count"] = len(get_long_running_sessions(threshold_seconds=30))

    return success_response(data=all_stats)
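A quick way to exercise the two new endpoints once the router is mounted under /api/health (see the main.py hunk below) is a TestClient smoke test along these lines; it assumes app.main exposes the FastAPI app and that success_response wraps the payload under a "data" key:

from fastapi.testclient import TestClient
from app.main import app  # assumed application module path

client = TestClient(app)

health = client.get("/api/health/health").json()
print(health["data"]["database"]["status"])  # "healthy" when SELECT 1 succeeds

stats = client.get("/api/health/stats", params={"include_slow_queries": True}).json()
print(stats["data"]["performance_stats"].keys())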
@@ -155,7 +155,7 @@ def calculate_delivery_share(order: ShippingOrderDB, db: Session) -> float:
        CommunityProfitSharing.community_id == order.address_community_id
    ).first()
    if sharing:
-        return round(order.original_amount_with_additional_fee * (sharing.delivery_rate / 100.0), 2)
+        return round(order.original_amount_with_additional_fee * (float(sharing.delivery_rate) / 100.0), 2)
    else:
        return 0

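The one-line change above wraps delivery_rate in float(), presumably because the column is stored as a DECIMAL and Python refuses to mix Decimal and float in arithmetic. A minimal reproduction with hypothetical values:

from decimal import Decimal

amount = 25.5            # float, e.g. order.original_amount_with_additional_fee
rate = Decimal("12.00")  # DECIMAL column value, e.g. sharing.delivery_rate

# amount * (rate / 100.0)  ->  TypeError: unsupported operand type(s) for /: 'decimal.Decimal' and 'float'
share = round(amount * (float(rate) / 100.0), 2)  # the committed fix
print(share)             # 3.06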
app/core/db_monitor.py (new file, 149 lines)
@@ -0,0 +1,149 @@
from app.models.database import engine
import logging
from fastapi import FastAPI
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
import time
import threading
from collections import defaultdict, deque
import datetime

logger = logging.getLogger(__name__)


class DBStats:
    """Database statistics"""
    def __init__(self):
        self.slow_queries = deque(maxlen=100)  # keep at most 100 slow-query records
        self.endpoint_stats = defaultdict(lambda: {"count": 0, "total_time": 0, "max_time": 0})
        self.hourly_requests = defaultdict(int)
        self.lock = threading.Lock()

    def record_slow_query(self, method, path, duration):
        """Record a slow query"""
        with self.lock:
            self.slow_queries.append({
                "method": method,
                "path": path,
                "duration": duration,
                "timestamp": datetime.datetime.now().isoformat()
            })

    def record_request(self, method, path, duration):
        """Record request statistics"""
        with self.lock:
            key = f"{method} {path}"
            self.endpoint_stats[key]["count"] += 1
            self.endpoint_stats[key]["total_time"] += duration
            self.endpoint_stats[key]["max_time"] = max(self.endpoint_stats[key]["max_time"], duration)

            # Count requests per hour
            hour = datetime.datetime.now().strftime("%Y-%m-%d %H:00")
            self.hourly_requests[hour] += 1

    def get_stats(self):
        """Get the collected statistics"""
        with self.lock:
            # Compute the average response time per endpoint
            avg_times = {}
            for endpoint, stats in self.endpoint_stats.items():
                if stats["count"] > 0:
                    avg_times[endpoint] = stats["total_time"] / stats["count"]

            # Top 10 endpoints by average response time
            sorted_endpoints = sorted(
                avg_times.items(),
                key=lambda x: x[1],
                reverse=True
            )[:10]

            return {
                "slow_queries": list(self.slow_queries),
                "top_slow_endpoints": sorted_endpoints,
                "hourly_requests": dict(self.hourly_requests)
            }


class DBConnectionMonitor:
    """Database connection monitoring utility"""

    # Initialize the statistics object
    stats = DBStats()

    @staticmethod
    def get_connection_stats():
        """Get the current connection-pool status"""
        pool = engine.pool
        return {
            "pool_size": pool.size(),
            "checkedin": pool.checkedin(),
            "checkedout": pool.checkedout(),
            "overflow": pool.overflow(),
        }

    @staticmethod
    def log_connection_stats():
        """Log the current connection-pool status"""
        stats = DBConnectionMonitor.get_connection_stats()
        logger.info(f"DB Connection Pool Stats: {stats}")
        return stats

    @staticmethod
    def get_all_stats():
        """Get all statistics"""
        return {
            "connection_pool": DBConnectionMonitor.get_connection_stats(),
            "performance_stats": DBConnectionMonitor.stats.get_stats()
        }


class DBMonitorMiddleware(BaseHTTPMiddleware):
    """Database connection monitoring middleware; periodically logs connection-pool status"""

    def __init__(self, app: FastAPI, log_interval: int = 100, slow_query_threshold: float = 1.0):
        super().__init__(app)
        self.log_interval = log_interval
        self.request_count = 0
        self.slow_query_threshold = slow_query_threshold

    async def dispatch(self, request: Request, call_next):
        self.request_count += 1

        # Log the connection-pool status once every N requests
        if self.request_count % self.log_interval == 0:
            DBConnectionMonitor.log_connection_stats()

        # Time the request
        start_time = time.time()
        response = await call_next(request)
        process_time = time.time() - start_time

        # Record request statistics
        method = request.method
        path = request.url.path
        DBConnectionMonitor.stats.record_request(method, path, process_time)

        # If the request took longer than the threshold, record it as a slow query
        if process_time > self.slow_query_threshold:
            logger.warning(f"Slow request: {method} {path} took {process_time:.2f}s")
            DBConnectionMonitor.stats.record_slow_query(method, path, process_time)

        return response


def setup_db_monitor(app: FastAPI, log_interval: int = 100, slow_query_threshold: float = 1.0):
    """Set up database monitoring"""
    app.add_middleware(
        DBMonitorMiddleware,
        log_interval=log_interval,
        slow_query_threshold=slow_query_threshold
    )

    @app.on_event("startup")
    async def startup_db_monitor():
        logger.info("Starting DB connection monitoring")
        DBConnectionMonitor.log_connection_stats()

    @app.on_event("shutdown")
    async def shutdown_db_monitor():
        logger.info("Final DB connection stats before shutdown")
        DBConnectionMonitor.log_connection_stats()
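For reference, a sketch of wiring the monitor with non-default settings (the values are illustrative; main.py below uses the defaults):

from fastapi import FastAPI
from app.core.db_monitor import setup_db_monitor, DBConnectionMonitor

app = FastAPI()
setup_db_monitor(app, log_interval=50, slow_query_threshold=0.5)  # log pool stats every 50 requests, flag requests slower than 0.5s

# Pool state can also be inspected ad hoc, e.g. from a debug shell:
print(DBConnectionMonitor.get_connection_stats())  # {"pool_size": ..., "checkedin": ..., "checkedout": ..., "overflow": ...}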
@@ -1,6 +1,6 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
-from app.api.endpoints import wechat,user, address, community, station, order, coupon, community_building, upload, merchant, merchant_product, merchant_order, point, config, merchant_category, log, account,merchant_pay_order, message, bank_card, withdraw, mp, point_product, point_product_order, coupon_activity, dashboard, wecom, feedback, timeperiod, community_timeperiod, order_additional_fee, ai, community_set, community_set_mapping, community_profit_sharing, partner
+from app.api.endpoints import wechat,user, address, community, station, order, coupon, community_building, upload, merchant, merchant_product, merchant_order, point, config, merchant_category, log, account,merchant_pay_order, message, bank_card, withdraw, mp, point_product, point_product_order, coupon_activity, dashboard, wecom, feedback, timeperiod, community_timeperiod, order_additional_fee, ai, community_set, community_set_mapping, community_profit_sharing, partner, health
from app.models.database import Base, engine
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
@@ -15,6 +15,7 @@ from app.api.endpoints import wecom
from app.api.endpoints import feedback
from starlette.middleware.sessions import SessionMiddleware
import os
+from app.core.db_monitor import setup_db_monitor

# Create the database tables
Base.metadata.create_all(bind=engine)
@@ -26,6 +27,9 @@ app = FastAPI(
    docs_url="/docs" if settings.DEBUG else None
)

+# Set up database connection monitoring
+setup_db_monitor(app)
+
app.default_response_class = CustomJSONResponse

# Configure CORS
@@ -80,6 +84,7 @@ app.include_router(upload.router, prefix="/api/upload", tags=["文件上传"])
app.include_router(config.router, prefix="/api/config", tags=["系统配置"])
app.include_router(log.router, prefix="/api/logs", tags=["系统日志"])
app.include_router(feedback.router, prefix="/api/feedback", tags=["反馈"])
+app.include_router(health.router, prefix="/api/health", tags=["系统健康检查"])


@app.get("/")
@@ -1,25 +1,120 @@
-from sqlalchemy import create_engine
+from sqlalchemy import create_engine, event
from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy.orm import sessionmaker, scoped_session
from app.core.config import settings
import pymysql
+import logging
+import time
+import threading
+from contextlib import contextmanager
+
+# Set up the logger
+logger = logging.getLogger(__name__)

# Register the MySQL Python SQL driver
pymysql.install_as_MySQLdb()

# Create the database engine
engine = create_engine(
    settings.SQLALCHEMY_DATABASE_URL,
-    pool_pre_ping=True,  # automatically handle dropped connections
-    pool_recycle=3600,  # connection recycle time
+    pool_pre_ping=True,  # automatically handle dropped connections
+    pool_recycle=3600,  # connection recycle time (1 hour)
+    pool_size=10,  # connection pool size
+    max_overflow=20,  # maximum number of connections allowed beyond pool_size
+    pool_timeout=30,  # timeout (seconds) when acquiring a connection
+    echo=settings.DEBUG,  # print SQL statements in debug mode
)
-SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
+
+# Create a thread-safe session factory
+SessionLocal = scoped_session(
+    sessionmaker(autocommit=False, autoflush=False, bind=engine)
+)

# Declarative base
Base = declarative_base()

+# Session tracking
+active_sessions = {}
+session_lock = threading.Lock()
+
+# Add event listeners to record long-running queries
+@event.listens_for(engine, "before_cursor_execute")
+def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
+    """Record the start time before a SQL query executes"""
+    conn.info.setdefault('query_start_time', []).append(time.time())
+    if settings.DEBUG:
+        # Log the SQL statement safely to avoid leaking sensitive information
+        safe_statement = ' '.join(statement.split())[:200]  # collapse extra whitespace and truncate
+        logger.debug(f"开始执行查询: {safe_statement}...")
+
+@event.listens_for(engine, "after_cursor_execute")
+def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
+    """Compute the elapsed time after a SQL query executes and record slow queries"""
+    total = time.time() - conn.info['query_start_time'].pop(-1)
+
+    # Record slow queries
+    if total > 0.5:  # record queries slower than 0.5 seconds
+        # Log the SQL statement safely to avoid leaking sensitive information
+        safe_statement = ' '.join(statement.split())[:200]  # collapse extra whitespace and truncate
+        logger.warning(f"慢查询 ({total:.2f}s): {safe_statement}...")
+
# Dependency
def get_db():
-    db = SessionLocal()
+    """Get a database session"""
+    session_id = threading.get_ident()
+    session = SessionLocal()
+
+    # Track the active session
+    with session_lock:
+        active_sessions[session_id] = {
+            "created_at": time.time(),
+            "thread_id": session_id
+        }
+
    try:
-        yield db
+        yield session
    finally:
-        db.close()
+        # Close the session and remove it from the active set
+        session.close()
+        with session_lock:
+            if session_id in active_sessions:
+                del active_sessions[session_id]
+
+@contextmanager
+def get_db_context():
+    """Context-manager version of get_db, for non-dependency scenarios"""
+    session = SessionLocal()
+    session_id = threading.get_ident()
+
+    # Track the active session
+    with session_lock:
+        active_sessions[session_id] = {
+            "created_at": time.time(),
+            "thread_id": session_id
+        }
+
+    try:
+        yield session
+        session.commit()
+    except Exception as e:
+        session.rollback()
+        raise e
+    finally:
+        session.close()
+        with session_lock:
+            if session_id in active_sessions:
+                del active_sessions[session_id]
+
+def get_active_sessions_count():
+    """Return the number of currently active sessions"""
+    with session_lock:
+        return len(active_sessions)
+
+def get_long_running_sessions(threshold_seconds=60):
+    """Return sessions that have been running longer than the threshold"""
+    current_time = time.time()
+    with session_lock:
+        return {
+            thread_id: info for thread_id, info in active_sessions.items()
+            if current_time - info["created_at"] > threshold_seconds
+        }
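A sketch of how the new context-manager session can be used outside request handlers, e.g. in a script or background job (the query is illustrative):

from sqlalchemy import text
from app.models.database import get_db_context, get_long_running_sessions

with get_db_context() as db:
    db.execute(text("SELECT 1"))  # commits on success, rolls back and re-raises on error

# Sessions open longer than a minute can be listed when hunting for leaks:
print(get_long_running_sessions(threshold_seconds=60))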