crypto.ai/cryptoai/utils/db_manager.py
2025-05-25 10:23:59 +08:00

232 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import logging
from typing import Dict, Any, List, Optional, Union
from datetime import datetime
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool
from cryptoai.utils.config_loader import ConfigLoader
from cryptoai.models.base import Base
from cryptoai.models.token import TokenManager
from cryptoai.models.analysis_result import AnalysisResultManager
from cryptoai.models.user import UserManager
from cryptoai.models.user_question import UserQuestionManager
from cryptoai.models.astock import AStockManager
from cryptoai.models.analysis_history import AnalysisHistoryManager
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('db_manager')
class DBManager:
"""
数据库管理工具用于连接MySQL数据库并提供各个模型的管理器
使用方法:
- 调用 get_db_manager() 获取数据库管理器实例
- 使用 db_manager.token_manager.xxx() 直接访问各个模型管理器的方法
- 使用 db_manager.get_session() 可以获取一个新的数据库会话
"""
def __init__(self, host: str, port: int, user: str, password: str, db_name: str):
"""
初始化数据库管理器
Args:
host: 数据库主机地址
port: 数据库端口
user: 用户名
password: 密码
db_name: 数据库名
"""
self.host = host
self.port = port
self.user = user
self.password = password
self.db_name = db_name
self.engine = None
self.Session = None
# 初始化数据库连接
self._init_db()
# 初始化各个管理器
self._init_managers()
def _init_db(self) -> None:
"""初始化数据库连接和表"""
try:
# 创建数据库连接
connection_string = f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}?charset=utf8mb4"
# 创建引擎,设置连接池
self.engine = create_engine(
connection_string,
echo=False, # 设置为True可以输出SQL语句调试用
pool_size=5, # 连接池大小
max_overflow=10, # 最大溢出连接数
pool_timeout=30, # 连接超时时间
pool_recycle=1800, # 连接回收时间(秒)
pool_pre_ping=True, # 在使用连接前先ping一下确保连接有效
connect_args={'charset': 'utf8mb4'}
)
# 创建会话工厂
self.Session = sessionmaker(bind=self.engine)
# 创建表(如果不存在)
Base.metadata.create_all(self.engine)
logger.info(f"成功连接到数据库 {self.db_name}")
except Exception as e:
logger.error(f"数据库初始化失败: {e}")
self.engine = None
def _init_managers(self) -> None:
"""初始化各个模型的管理器"""
if not self.engine:
logger.error("引擎未初始化,无法创建管理器")
return
try:
session = self.Session()
# 初始化各个模型的管理器
self.token_manager = TokenManager(session)
self.analysis_result_manager = AnalysisResultManager(session)
self.user_manager = UserManager(session)
self.user_question_manager = UserQuestionManager(session)
self.astock_manager = AStockManager(session)
self.analysis_history_manager = AnalysisHistoryManager(session)
logger.info("成功初始化所有模型管理器")
except Exception as e:
logger.error(f"管理器初始化失败: {e}")
if session:
session.close()
def get_session(self):
"""
获取新的数据库会话
Returns:
SQLAlchemy session对象如果初始化失败则返回None
"""
if not self.Session:
try:
self._init_db()
except Exception as e:
logger.error(f"重新初始化数据库失败: {e}")
return None
return self.Session()
def refresh_managers(self) -> bool:
"""
刷新所有管理器,重新建立会话
Returns:
刷新是否成功
"""
try:
# 关闭旧会话(如果有)
self.close()
# 重新初始化数据库连接
self._init_db()
# 重新初始化管理器
self._init_managers()
return self.engine is not None
except Exception as e:
logger.error(f"刷新管理器失败: {e}")
return False
def close(self) -> None:
"""关闭数据库连接(如果存在)"""
if self.engine:
self.engine.dispose()
logger.info("数据库连接已关闭")
self.engine = None
self.Session = None
# 单例模式
_db_instance = None
def get_db_manager(host: Optional[str] = None,
port: Optional[int] = None,
user: Optional[str] = None,
password: Optional[str] = None,
db_name: Optional[str] = None) -> DBManager:
"""
获取数据库管理器实例(单例模式)
Args:
host: 数据库主机地址
port: 数据库端口
user: 用户名
password: 密码
db_name: 数据库名
Returns:
数据库管理器实例
使用示例:
```python
# 获取数据库管理器
db_manager = get_db_manager()
# 使用Token管理器
tokens = db_manager.token_manager.search_token("BTC")
# 使用用户管理器
user = db_manager.user_manager.get_user_by_mail("example@test.com")
```
"""
global _db_instance
# 如果已经初始化过,直接返回
if _db_instance is not None:
return _db_instance
# 如果未指定参数,从配置加载器获取数据库配置
if host is None or port is None or user is None or password is None or db_name is None:
config_loader = ConfigLoader()
db_config = config_loader.get_database_config()
# 使用配置中的值或默认值
db_host = host or db_config.get('host')
db_port = port or db_config.get('port')
db_user = user or db_config.get('user')
db_password = password or db_config.get('password')
db_name = db_name or db_config.get('db_name')
logger.info(f"从配置加载数据库连接信息: {db_host}:{db_port}/{db_name}")
else:
db_host = host
db_port = port
db_user = user
db_password = password
db_name = db_name
# 创建实例
_db_instance = DBManager(
host=db_host,
port=db_port,
user=db_user,
password=db_password,
db_name=db_name
)
return _db_instance