#!/usr/bin/env python # -*- coding: utf-8 -*- import os import logging from typing import Dict, Any, List, Optional, Union from datetime import datetime from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import QueuePool from cryptoai.utils.config_loader import ConfigLoader from cryptoai.models.base import Base from cryptoai.models.token import TokenManager from cryptoai.models.analysis_result import AnalysisResultManager from cryptoai.models.user import UserManager from cryptoai.models.user_question import UserQuestionManager from cryptoai.models.astock import AStockManager from cryptoai.models.analysis_history import AnalysisHistoryManager # 配置日志 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger('db_manager') class DBManager: """ 数据库管理工具,用于连接MySQL数据库并提供各个模型的管理器 使用方法: - 调用 get_db_manager() 获取数据库管理器实例 - 使用 db_manager.token_manager.xxx() 直接访问各个模型管理器的方法 - 使用 db_manager.get_session() 可以获取一个新的数据库会话 """ def __init__(self, host: str, port: int, user: str, password: str, db_name: str): """ 初始化数据库管理器 Args: host: 数据库主机地址 port: 数据库端口 user: 用户名 password: 密码 db_name: 数据库名 """ self.host = host self.port = port self.user = user self.password = password self.db_name = db_name self.engine = None self.Session = None # 初始化数据库连接 self._init_db() # 初始化各个管理器 self._init_managers() def _init_db(self) -> None: """初始化数据库连接和表""" try: # 创建数据库连接 connection_string = f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}?charset=utf8mb4" # 创建引擎,设置连接池 self.engine = create_engine( connection_string, echo=False, # 设置为True可以输出SQL语句(调试用) pool_size=5, # 连接池大小 max_overflow=10, # 最大溢出连接数 pool_timeout=30, # 连接超时时间 pool_recycle=1800, # 连接回收时间(秒) pool_pre_ping=True, # 在使用连接前先ping一下,确保连接有效 connect_args={'charset': 'utf8mb4'} ) # 创建会话工厂 self.Session = sessionmaker(bind=self.engine) # 创建表(如果不存在) Base.metadata.create_all(self.engine) logger.info(f"成功连接到数据库 {self.db_name}") except Exception as e: logger.error(f"数据库初始化失败: {e}") self.engine = None def _init_managers(self) -> None: """初始化各个模型的管理器""" if not self.engine: logger.error("引擎未初始化,无法创建管理器") return try: session = self.Session() # 初始化各个模型的管理器 self.token_manager = TokenManager(session) self.analysis_result_manager = AnalysisResultManager(session) self.user_manager = UserManager(session) self.user_question_manager = UserQuestionManager(session) self.astock_manager = AStockManager(session) self.analysis_history_manager = AnalysisHistoryManager(session) logger.info("成功初始化所有模型管理器") except Exception as e: logger.error(f"管理器初始化失败: {e}") if session: session.close() def get_session(self): """ 获取新的数据库会话 Returns: SQLAlchemy session对象,如果初始化失败则返回None """ if not self.Session: try: self._init_db() except Exception as e: logger.error(f"重新初始化数据库失败: {e}") return None return self.Session() def refresh_managers(self) -> bool: """ 刷新所有管理器,重新建立会话 Returns: 刷新是否成功 """ try: # 关闭旧会话(如果有) self.close() # 重新初始化数据库连接 self._init_db() # 重新初始化管理器 self._init_managers() return self.engine is not None except Exception as e: logger.error(f"刷新管理器失败: {e}") return False def close(self) -> None: """关闭数据库连接(如果存在)""" if self.engine: self.engine.dispose() logger.info("数据库连接已关闭") self.engine = None self.Session = None # 单例模式 _db_instance = None def get_db_manager(host: Optional[str] = None, port: Optional[int] = None, user: Optional[str] = None, password: Optional[str] = None, db_name: Optional[str] = None) -> DBManager: """ 获取数据库管理器实例(单例模式) Args: host: 数据库主机地址 port: 数据库端口 user: 用户名 password: 密码 db_name: 数据库名 Returns: 数据库管理器实例 使用示例: ```python # 获取数据库管理器 db_manager = get_db_manager() # 使用Token管理器 tokens = db_manager.token_manager.search_token("BTC") # 使用用户管理器 user = db_manager.user_manager.get_user_by_mail("example@test.com") ``` """ global _db_instance # 如果已经初始化过,直接返回 if _db_instance is not None: return _db_instance # 如果未指定参数,从配置加载器获取数据库配置 if host is None or port is None or user is None or password is None or db_name is None: config_loader = ConfigLoader() db_config = config_loader.get_database_config() # 使用配置中的值或默认值 db_host = host or db_config.get('host') db_port = port or db_config.get('port') db_user = user or db_config.get('user') db_password = password or db_config.get('password') db_name = db_name or db_config.get('db_name') logger.info(f"从配置加载数据库连接信息: {db_host}:{db_port}/{db_name}") else: db_host = host db_port = port db_user = user db_password = password db_name = db_name # 创建实例 _db_instance = DBManager( host=db_host, port=db_port, user=db_user, password=db_password, db_name=db_name ) return _db_instance