update
This commit is contained in:
parent
fbe860a014
commit
6c2333da4c
@ -70,14 +70,17 @@ async def get_communities(
|
|||||||
):
|
):
|
||||||
"""获取社区列表"""
|
"""获取社区列表"""
|
||||||
# 构建查询, 关联社区分润
|
# 构建查询, 关联社区分润
|
||||||
|
# 使用一次查询获取所有需要的数据,减少数据库连接使用时间
|
||||||
query = db.query(CommunityDB).options(joinedload(CommunityDB.community_profit_sharing))
|
query = db.query(CommunityDB).options(joinedload(CommunityDB.community_profit_sharing))
|
||||||
|
|
||||||
# 状态过滤
|
# 状态过滤
|
||||||
if status:
|
if status:
|
||||||
query = query.filter(CommunityDB.status == status)
|
query = query.filter(CommunityDB.status == status)
|
||||||
|
|
||||||
# 获取总数
|
# 获取总数(使用子查询优化计数操作)
|
||||||
total = query.count()
|
from sqlalchemy import func
|
||||||
|
count_query = query.statement.with_only_columns([func.count()]).order_by(None)
|
||||||
|
total = db.execute(count_query).scalar()
|
||||||
|
|
||||||
# 查询数据
|
# 查询数据
|
||||||
communities = query.offset(skip).limit(limit).all()
|
communities = query.offset(skip).limit(limit).all()
|
||||||
|
|||||||
63
app/api/endpoints/health.py
Normal file
63
app/api/endpoints/health.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
from fastapi import APIRouter, Depends, Query
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy import text
|
||||||
|
from app.models.database import get_db, get_active_sessions_count, get_long_running_sessions
|
||||||
|
from app.core.response import success_response, ResponseModel
|
||||||
|
from app.core.db_monitor import DBConnectionMonitor
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.get("/health", response_model=ResponseModel)
|
||||||
|
async def health_check(db: Session = Depends(get_db)):
|
||||||
|
"""健康检查端点,检查API和数据库连接状态"""
|
||||||
|
# 尝试执行简单查询以验证数据库连接
|
||||||
|
try:
|
||||||
|
db.execute(text("SELECT 1")).scalar()
|
||||||
|
db_status = "healthy"
|
||||||
|
except Exception as e:
|
||||||
|
db_status = f"unhealthy: {str(e)}"
|
||||||
|
|
||||||
|
# 获取连接池状态和性能统计
|
||||||
|
all_stats = DBConnectionMonitor.get_all_stats()
|
||||||
|
|
||||||
|
# 获取活跃会话信息
|
||||||
|
active_sessions_count = get_active_sessions_count()
|
||||||
|
long_running_sessions = get_long_running_sessions(threshold_seconds=30)
|
||||||
|
|
||||||
|
return success_response(data={
|
||||||
|
"status": "ok",
|
||||||
|
"database": {
|
||||||
|
"status": db_status,
|
||||||
|
"connection_pool": all_stats["connection_pool"],
|
||||||
|
"active_sessions": active_sessions_count,
|
||||||
|
"long_running_sessions": len(long_running_sessions)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
@router.get("/stats", response_model=ResponseModel)
|
||||||
|
async def performance_stats(
|
||||||
|
include_slow_queries: bool = Query(False, description="是否包含慢查询记录"),
|
||||||
|
include_long_sessions: bool = Query(False, description="是否包含长时间运行的会话详情"),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""获取详细的性能统计信息"""
|
||||||
|
# 获取所有统计信息
|
||||||
|
all_stats = DBConnectionMonitor.get_all_stats()
|
||||||
|
|
||||||
|
# 如果不需要包含慢查询记录,则移除它们以减少响应大小
|
||||||
|
if not include_slow_queries and "slow_queries" in all_stats["performance_stats"]:
|
||||||
|
all_stats["performance_stats"]["slow_queries_count"] = len(all_stats["performance_stats"]["slow_queries"])
|
||||||
|
del all_stats["performance_stats"]["slow_queries"]
|
||||||
|
|
||||||
|
# 添加会话信息
|
||||||
|
all_stats["sessions"] = {
|
||||||
|
"active_count": get_active_sessions_count()
|
||||||
|
}
|
||||||
|
|
||||||
|
# 如果需要包含长时间运行的会话详情
|
||||||
|
if include_long_sessions:
|
||||||
|
all_stats["sessions"]["long_running"] = get_long_running_sessions(threshold_seconds=30)
|
||||||
|
else:
|
||||||
|
all_stats["sessions"]["long_running_count"] = len(get_long_running_sessions(threshold_seconds=30))
|
||||||
|
|
||||||
|
return success_response(data=all_stats)
|
||||||
@ -155,7 +155,7 @@ def calculate_delivery_share(order: ShippingOrderDB, db: Session) -> float:
|
|||||||
CommunityProfitSharing.community_id == order.address_community_id
|
CommunityProfitSharing.community_id == order.address_community_id
|
||||||
).first()
|
).first()
|
||||||
if sharing:
|
if sharing:
|
||||||
return round(order.original_amount_with_additional_fee * (sharing.delivery_rate / 100.0), 2)
|
return round(order.original_amount_with_additional_fee * (float(sharing.delivery_rate) / 100.0), 2)
|
||||||
else:
|
else:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|||||||
149
app/core/db_monitor.py
Normal file
149
app/core/db_monitor.py
Normal file
@ -0,0 +1,149 @@
|
|||||||
|
from app.models.database import engine
|
||||||
|
import logging
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.requests import Request
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class DBStats:
|
||||||
|
"""数据库统计信息"""
|
||||||
|
def __init__(self):
|
||||||
|
self.slow_queries = deque(maxlen=100) # 最多保存100条慢查询记录
|
||||||
|
self.endpoint_stats = defaultdict(lambda: {"count": 0, "total_time": 0, "max_time": 0})
|
||||||
|
self.hourly_requests = defaultdict(int)
|
||||||
|
self.lock = threading.Lock()
|
||||||
|
|
||||||
|
def record_slow_query(self, method, path, duration):
|
||||||
|
"""记录慢查询"""
|
||||||
|
with self.lock:
|
||||||
|
self.slow_queries.append({
|
||||||
|
"method": method,
|
||||||
|
"path": path,
|
||||||
|
"duration": duration,
|
||||||
|
"timestamp": datetime.datetime.now().isoformat()
|
||||||
|
})
|
||||||
|
|
||||||
|
def record_request(self, method, path, duration):
|
||||||
|
"""记录请求统计"""
|
||||||
|
with self.lock:
|
||||||
|
key = f"{method} {path}"
|
||||||
|
self.endpoint_stats[key]["count"] += 1
|
||||||
|
self.endpoint_stats[key]["total_time"] += duration
|
||||||
|
self.endpoint_stats[key]["max_time"] = max(self.endpoint_stats[key]["max_time"], duration)
|
||||||
|
|
||||||
|
# 记录每小时请求数
|
||||||
|
hour = datetime.datetime.now().strftime("%Y-%m-%d %H:00")
|
||||||
|
self.hourly_requests[hour] += 1
|
||||||
|
|
||||||
|
def get_stats(self):
|
||||||
|
"""获取统计信息"""
|
||||||
|
with self.lock:
|
||||||
|
# 计算平均响应时间
|
||||||
|
avg_times = {}
|
||||||
|
for endpoint, stats in self.endpoint_stats.items():
|
||||||
|
if stats["count"] > 0:
|
||||||
|
avg_times[endpoint] = stats["total_time"] / stats["count"]
|
||||||
|
|
||||||
|
# 按平均响应时间排序的前10个端点
|
||||||
|
sorted_endpoints = sorted(
|
||||||
|
avg_times.items(),
|
||||||
|
key=lambda x: x[1],
|
||||||
|
reverse=True
|
||||||
|
)[:10]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"slow_queries": list(self.slow_queries),
|
||||||
|
"top_slow_endpoints": sorted_endpoints,
|
||||||
|
"hourly_requests": dict(self.hourly_requests)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DBConnectionMonitor:
|
||||||
|
"""数据库连接监控工具"""
|
||||||
|
|
||||||
|
# 初始化统计对象
|
||||||
|
stats = DBStats()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_connection_stats():
|
||||||
|
"""获取当前连接池状态"""
|
||||||
|
pool = engine.pool
|
||||||
|
return {
|
||||||
|
"pool_size": pool.size(),
|
||||||
|
"checkedin": pool.checkedin(),
|
||||||
|
"checkedout": pool.checkedout(),
|
||||||
|
"overflow": pool.overflow(),
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def log_connection_stats():
|
||||||
|
"""记录当前连接池状态"""
|
||||||
|
stats = DBConnectionMonitor.get_connection_stats()
|
||||||
|
logger.info(f"DB Connection Pool Stats: {stats}")
|
||||||
|
return stats
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_all_stats():
|
||||||
|
"""获取所有统计信息"""
|
||||||
|
return {
|
||||||
|
"connection_pool": DBConnectionMonitor.get_connection_stats(),
|
||||||
|
"performance_stats": DBConnectionMonitor.stats.get_stats()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DBMonitorMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""数据库连接监控中间件,定期记录连接池状态"""
|
||||||
|
|
||||||
|
def __init__(self, app: FastAPI, log_interval: int = 100, slow_query_threshold: float = 1.0):
|
||||||
|
super().__init__(app)
|
||||||
|
self.log_interval = log_interval
|
||||||
|
self.request_count = 0
|
||||||
|
self.slow_query_threshold = slow_query_threshold
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
self.request_count += 1
|
||||||
|
|
||||||
|
# 每处理N个请求记录一次连接池状态
|
||||||
|
if self.request_count % self.log_interval == 0:
|
||||||
|
DBConnectionMonitor.log_connection_stats()
|
||||||
|
|
||||||
|
# 记录请求处理时间
|
||||||
|
start_time = time.time()
|
||||||
|
response = await call_next(request)
|
||||||
|
process_time = time.time() - start_time
|
||||||
|
|
||||||
|
# 记录请求统计
|
||||||
|
method = request.method
|
||||||
|
path = request.url.path
|
||||||
|
DBConnectionMonitor.stats.record_request(method, path, process_time)
|
||||||
|
|
||||||
|
# 如果请求处理时间超过阈值,记录为慢查询
|
||||||
|
if process_time > self.slow_query_threshold:
|
||||||
|
logger.warning(f"Slow request: {method} {path} took {process_time:.2f}s")
|
||||||
|
DBConnectionMonitor.stats.record_slow_query(method, path, process_time)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def setup_db_monitor(app: FastAPI, log_interval: int = 100, slow_query_threshold: float = 1.0):
|
||||||
|
"""设置数据库监控"""
|
||||||
|
app.add_middleware(
|
||||||
|
DBMonitorMiddleware,
|
||||||
|
log_interval=log_interval,
|
||||||
|
slow_query_threshold=slow_query_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def startup_db_monitor():
|
||||||
|
logger.info("Starting DB connection monitoring")
|
||||||
|
DBConnectionMonitor.log_connection_stats()
|
||||||
|
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
async def shutdown_db_monitor():
|
||||||
|
logger.info("Final DB connection stats before shutdown")
|
||||||
|
DBConnectionMonitor.log_connection_stats()
|
||||||
@ -1,6 +1,6 @@
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from app.api.endpoints import wechat,user, address, community, station, order, coupon, community_building, upload, merchant, merchant_product, merchant_order, point, config, merchant_category, log, account,merchant_pay_order, message, bank_card, withdraw, mp, point_product, point_product_order, coupon_activity, dashboard, wecom, feedback, timeperiod, community_timeperiod, order_additional_fee, ai, community_set, community_set_mapping, community_profit_sharing, partner
|
from app.api.endpoints import wechat,user, address, community, station, order, coupon, community_building, upload, merchant, merchant_product, merchant_order, point, config, merchant_category, log, account,merchant_pay_order, message, bank_card, withdraw, mp, point_product, point_product_order, coupon_activity, dashboard, wecom, feedback, timeperiod, community_timeperiod, order_additional_fee, ai, community_set, community_set_mapping, community_profit_sharing, partner, health
|
||||||
from app.models.database import Base, engine
|
from app.models.database import Base, engine
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
@ -15,6 +15,7 @@ from app.api.endpoints import wecom
|
|||||||
from app.api.endpoints import feedback
|
from app.api.endpoints import feedback
|
||||||
from starlette.middleware.sessions import SessionMiddleware
|
from starlette.middleware.sessions import SessionMiddleware
|
||||||
import os
|
import os
|
||||||
|
from app.core.db_monitor import setup_db_monitor
|
||||||
|
|
||||||
# 创建数据库表
|
# 创建数据库表
|
||||||
Base.metadata.create_all(bind=engine)
|
Base.metadata.create_all(bind=engine)
|
||||||
@ -26,6 +27,9 @@ app = FastAPI(
|
|||||||
docs_url="/docs" if settings.DEBUG else None
|
docs_url="/docs" if settings.DEBUG else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 设置数据库连接监控
|
||||||
|
setup_db_monitor(app)
|
||||||
|
|
||||||
app.default_response_class = CustomJSONResponse
|
app.default_response_class = CustomJSONResponse
|
||||||
|
|
||||||
# 配置 CORS
|
# 配置 CORS
|
||||||
@ -80,6 +84,7 @@ app.include_router(upload.router, prefix="/api/upload", tags=["文件上传"])
|
|||||||
app.include_router(config.router, prefix="/api/config", tags=["系统配置"])
|
app.include_router(config.router, prefix="/api/config", tags=["系统配置"])
|
||||||
app.include_router(log.router, prefix="/api/logs", tags=["系统日志"])
|
app.include_router(log.router, prefix="/api/logs", tags=["系统日志"])
|
||||||
app.include_router(feedback.router, prefix="/api/feedback", tags=["反馈"])
|
app.include_router(feedback.router, prefix="/api/feedback", tags=["反馈"])
|
||||||
|
app.include_router(health.router, prefix="/api/health", tags=["系统健康检查"])
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
|
|||||||
@ -1,25 +1,120 @@
|
|||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine, event
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker, scoped_session
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
import pymysql
|
import pymysql
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import threading
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
# 设置日志记录器
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# 注册 MySQL Python SQL Driver
|
# 注册 MySQL Python SQL Driver
|
||||||
pymysql.install_as_MySQLdb()
|
pymysql.install_as_MySQLdb()
|
||||||
|
|
||||||
|
# 创建数据库引擎
|
||||||
engine = create_engine(
|
engine = create_engine(
|
||||||
settings.SQLALCHEMY_DATABASE_URL,
|
settings.SQLALCHEMY_DATABASE_URL,
|
||||||
pool_pre_ping=True, # 自动处理断开的连接
|
pool_pre_ping=True, # 自动处理断开的连接
|
||||||
pool_recycle=3600, # 连接回收时间
|
pool_recycle=3600, # 连接回收时间(1小时)
|
||||||
|
pool_size=10, # 连接池大小
|
||||||
|
max_overflow=20, # 允许的最大连接数超出pool_size的数量
|
||||||
|
pool_timeout=30, # 获取连接的超时时间(秒)
|
||||||
|
echo=settings.DEBUG, # 在调试模式下打印SQL语句
|
||||||
)
|
)
|
||||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
|
||||||
|
|
||||||
|
# 创建线程安全的会话工厂
|
||||||
|
SessionLocal = scoped_session(
|
||||||
|
sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 声明基类
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
|
# 会话跟踪
|
||||||
|
active_sessions = {}
|
||||||
|
session_lock = threading.Lock()
|
||||||
|
|
||||||
|
# 添加事件监听器,记录长时间运行的查询
|
||||||
|
@event.listens_for(engine, "before_cursor_execute")
|
||||||
|
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||||||
|
"""在执行SQL查询前记录时间"""
|
||||||
|
conn.info.setdefault('query_start_time', []).append(time.time())
|
||||||
|
if settings.DEBUG:
|
||||||
|
# 安全地记录SQL语句,避免敏感信息泄露
|
||||||
|
safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白
|
||||||
|
logger.debug(f"开始执行查询: {safe_statement}...")
|
||||||
|
|
||||||
|
@event.listens_for(engine, "after_cursor_execute")
|
||||||
|
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
|
||||||
|
"""在执行SQL查询后计算耗时并记录慢查询"""
|
||||||
|
total = time.time() - conn.info['query_start_time'].pop(-1)
|
||||||
|
|
||||||
|
# 记录慢查询
|
||||||
|
if total > 0.5: # 记录超过0.5秒的查询
|
||||||
|
# 安全地记录SQL语句,避免敏感信息泄露
|
||||||
|
safe_statement = ' '.join(statement.split())[:200] # 截断并移除多余空白
|
||||||
|
logger.warning(f"慢查询 ({total:.2f}s): {safe_statement}...")
|
||||||
|
|
||||||
# 依赖项
|
# 依赖项
|
||||||
def get_db():
|
def get_db():
|
||||||
db = SessionLocal()
|
"""获取数据库会话"""
|
||||||
|
session_id = threading.get_ident()
|
||||||
|
session = SessionLocal()
|
||||||
|
|
||||||
|
# 记录活跃会话
|
||||||
|
with session_lock:
|
||||||
|
active_sessions[session_id] = {
|
||||||
|
"created_at": time.time(),
|
||||||
|
"thread_id": session_id
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield db
|
yield session
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
# 关闭会话并从活跃会话中移除
|
||||||
|
session.close()
|
||||||
|
with session_lock:
|
||||||
|
if session_id in active_sessions:
|
||||||
|
del active_sessions[session_id]
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def get_db_context():
|
||||||
|
"""上下文管理器版本的get_db,用于非依赖项场景"""
|
||||||
|
session = SessionLocal()
|
||||||
|
session_id = threading.get_ident()
|
||||||
|
|
||||||
|
# 记录活跃会话
|
||||||
|
with session_lock:
|
||||||
|
active_sessions[session_id] = {
|
||||||
|
"created_at": time.time(),
|
||||||
|
"thread_id": session_id
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
session.commit()
|
||||||
|
except Exception as e:
|
||||||
|
session.rollback()
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
with session_lock:
|
||||||
|
if session_id in active_sessions:
|
||||||
|
del active_sessions[session_id]
|
||||||
|
|
||||||
|
def get_active_sessions_count():
|
||||||
|
"""获取当前活跃会话数量"""
|
||||||
|
with session_lock:
|
||||||
|
return len(active_sessions)
|
||||||
|
|
||||||
|
def get_long_running_sessions(threshold_seconds=60):
|
||||||
|
"""获取长时间运行的会话"""
|
||||||
|
current_time = time.time()
|
||||||
|
with session_lock:
|
||||||
|
return {
|
||||||
|
thread_id: info for thread_id, info in active_sessions.items()
|
||||||
|
if current_time - info["created_at"] > threshold_seconds
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user