236 lines
7.6 KiB
Python
236 lines
7.6 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
import os
|
||
import logging
|
||
from typing import Dict, Any, List, Optional, Union
|
||
from datetime import datetime
|
||
|
||
from sqlalchemy import create_engine
|
||
from sqlalchemy.orm import sessionmaker
|
||
from sqlalchemy.pool import QueuePool
|
||
|
||
from cryptoai.utils.config_loader import ConfigLoader
|
||
from cryptoai.models.base import Base
|
||
from cryptoai.models.token import TokenManager
|
||
from cryptoai.models.analysis_result import AnalysisResultManager
|
||
from cryptoai.models.agent_feed import AgentFeedManager
|
||
from cryptoai.models.user import UserManager
|
||
from cryptoai.models.user_question import UserQuestionManager
|
||
from cryptoai.models.agent import AgentManager
|
||
from cryptoai.models.astock import AStockManager
|
||
from cryptoai.models.analysis_history import AnalysisHistoryManager
|
||
|
||
# 配置日志
|
||
logging.basicConfig(
|
||
level=logging.INFO,
|
||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||
)
|
||
logger = logging.getLogger('db_manager')
|
||
|
||
class DBManager:
|
||
"""
|
||
数据库管理工具,用于连接MySQL数据库并提供各个模型的管理器
|
||
|
||
使用方法:
|
||
- 调用 get_db_manager() 获取数据库管理器实例
|
||
- 使用 db_manager.token_manager.xxx() 直接访问各个模型管理器的方法
|
||
- 使用 db_manager.get_session() 可以获取一个新的数据库会话
|
||
"""
|
||
|
||
def __init__(self, host: str, port: int, user: str, password: str, db_name: str):
|
||
"""
|
||
初始化数据库管理器
|
||
|
||
Args:
|
||
host: 数据库主机地址
|
||
port: 数据库端口
|
||
user: 用户名
|
||
password: 密码
|
||
db_name: 数据库名
|
||
"""
|
||
self.host = host
|
||
self.port = port
|
||
self.user = user
|
||
self.password = password
|
||
self.db_name = db_name
|
||
self.engine = None
|
||
self.Session = None
|
||
|
||
# 初始化数据库连接
|
||
self._init_db()
|
||
|
||
# 初始化各个管理器
|
||
self._init_managers()
|
||
|
||
def _init_db(self) -> None:
|
||
"""初始化数据库连接和表"""
|
||
try:
|
||
# 创建数据库连接
|
||
connection_string = f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}?charset=utf8mb4"
|
||
|
||
# 创建引擎,设置连接池
|
||
self.engine = create_engine(
|
||
connection_string,
|
||
echo=False, # 设置为True可以输出SQL语句(调试用)
|
||
pool_size=5, # 连接池大小
|
||
max_overflow=10, # 最大溢出连接数
|
||
pool_timeout=30, # 连接超时时间
|
||
pool_recycle=1800, # 连接回收时间(秒)
|
||
pool_pre_ping=True, # 在使用连接前先ping一下,确保连接有效
|
||
connect_args={'charset': 'utf8mb4'}
|
||
)
|
||
|
||
# 创建会话工厂
|
||
self.Session = sessionmaker(bind=self.engine)
|
||
|
||
# 创建表(如果不存在)
|
||
Base.metadata.create_all(self.engine)
|
||
|
||
logger.info(f"成功连接到数据库 {self.db_name}")
|
||
|
||
except Exception as e:
|
||
logger.error(f"数据库初始化失败: {e}")
|
||
self.engine = None
|
||
|
||
def _init_managers(self) -> None:
|
||
"""初始化各个模型的管理器"""
|
||
if not self.engine:
|
||
logger.error("引擎未初始化,无法创建管理器")
|
||
return
|
||
|
||
try:
|
||
session = self.Session()
|
||
|
||
# 初始化各个模型的管理器
|
||
self.token_manager = TokenManager(session)
|
||
self.analysis_result_manager = AnalysisResultManager(session)
|
||
self.agent_feed_manager = AgentFeedManager(session)
|
||
self.user_manager = UserManager(session)
|
||
self.user_question_manager = UserQuestionManager(session)
|
||
self.agent_manager = AgentManager(session)
|
||
self.astock_manager = AStockManager(session)
|
||
self.analysis_history_manager = AnalysisHistoryManager(session)
|
||
|
||
logger.info("成功初始化所有模型管理器")
|
||
|
||
except Exception as e:
|
||
logger.error(f"管理器初始化失败: {e}")
|
||
if session:
|
||
session.close()
|
||
|
||
def get_session(self):
|
||
"""
|
||
获取新的数据库会话
|
||
|
||
Returns:
|
||
SQLAlchemy session对象,如果初始化失败则返回None
|
||
"""
|
||
if not self.Session:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新初始化数据库失败: {e}")
|
||
return None
|
||
|
||
return self.Session()
|
||
|
||
def refresh_managers(self) -> bool:
|
||
"""
|
||
刷新所有管理器,重新建立会话
|
||
|
||
Returns:
|
||
刷新是否成功
|
||
"""
|
||
try:
|
||
# 关闭旧会话(如果有)
|
||
self.close()
|
||
|
||
# 重新初始化数据库连接
|
||
self._init_db()
|
||
|
||
# 重新初始化管理器
|
||
self._init_managers()
|
||
|
||
return self.engine is not None
|
||
except Exception as e:
|
||
logger.error(f"刷新管理器失败: {e}")
|
||
return False
|
||
|
||
def close(self) -> None:
|
||
"""关闭数据库连接(如果存在)"""
|
||
if self.engine:
|
||
self.engine.dispose()
|
||
logger.info("数据库连接已关闭")
|
||
self.engine = None
|
||
self.Session = None
|
||
|
||
# 单例模式
|
||
_db_instance = None
|
||
|
||
def get_db_manager(host: Optional[str] = None,
|
||
port: Optional[int] = None,
|
||
user: Optional[str] = None,
|
||
password: Optional[str] = None,
|
||
db_name: Optional[str] = None) -> DBManager:
|
||
"""
|
||
获取数据库管理器实例(单例模式)
|
||
|
||
Args:
|
||
host: 数据库主机地址
|
||
port: 数据库端口
|
||
user: 用户名
|
||
password: 密码
|
||
db_name: 数据库名
|
||
|
||
Returns:
|
||
数据库管理器实例
|
||
|
||
使用示例:
|
||
```python
|
||
# 获取数据库管理器
|
||
db_manager = get_db_manager()
|
||
|
||
# 使用Token管理器
|
||
tokens = db_manager.token_manager.search_token("BTC")
|
||
|
||
# 使用用户管理器
|
||
user = db_manager.user_manager.get_user_by_mail("example@test.com")
|
||
```
|
||
"""
|
||
global _db_instance
|
||
|
||
# 如果已经初始化过,直接返回
|
||
if _db_instance is not None:
|
||
return _db_instance
|
||
|
||
# 如果未指定参数,从配置加载器获取数据库配置
|
||
if host is None or port is None or user is None or password is None or db_name is None:
|
||
config_loader = ConfigLoader()
|
||
db_config = config_loader.get_database_config()
|
||
|
||
# 使用配置中的值或默认值
|
||
db_host = host or db_config.get('host')
|
||
db_port = port or db_config.get('port')
|
||
db_user = user or db_config.get('user')
|
||
db_password = password or db_config.get('password')
|
||
db_name = db_name or db_config.get('db_name')
|
||
|
||
logger.info(f"从配置加载数据库连接信息: {db_host}:{db_port}/{db_name}")
|
||
else:
|
||
db_host = host
|
||
db_port = port
|
||
db_user = user
|
||
db_password = password
|
||
db_name = db_name
|
||
|
||
# 创建实例
|
||
_db_instance = DBManager(
|
||
host=db_host,
|
||
port=db_port,
|
||
user=db_user,
|
||
password=db_password,
|
||
db_name=db_name
|
||
)
|
||
|
||
return _db_instance |