crypto.ai/cryptoai/utils/db_manager.py
2025-05-09 23:37:28 +08:00

1160 lines
38 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import json
import logging
from typing import Dict, Any, List, Optional, Union
from datetime import datetime
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, Index, text, ForeignKey
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, relationship
from sqlalchemy.dialects.mysql import JSON
from sqlalchemy.pool import QueuePool
from cryptoai.utils.config_loader import ConfigLoader
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('db_manager')
# 创建模型基类
Base = declarative_base()
# 定义分析结果模型
class AnalysisResult(Base):
"""分析结果表模型"""
__tablename__ = 'analysis_results'
id = Column(Integer, primary_key=True, autoincrement=True)
agent = Column(String(50), nullable=False, comment='智能体类型(crypto, gold)')
symbol = Column(String(50), nullable=False, comment='交易对符号')
time_interval = Column(String(20), nullable=False, comment='时间间隔')
completion_result = Column(JSON, nullable=False, comment='分析结果JSON')
created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
updated_at = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间')
# 索引
__table_args__ = (
Index('idx_agent', 'agent'),
Index('idx_symbol', 'symbol'),
Index('idx_time_interval', 'time_interval'),
Index('idx_created_at', 'created_at'),
)
# 定义AI Agent信息流模型
class AgentFeed(Base):
"""AI Agent信息流表模型"""
__tablename__ = 'agent_feeds'
id = Column(Integer, primary_key=True, autoincrement=True)
agent_name = Column(String(50), nullable=False, comment='AI Agent名称')
avatar_url = Column(String(255), nullable=True, comment='头像URL')
content = Column(Text, nullable=False, comment='内容')
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
# 索引和表属性
__table_args__ = (
Index('idx_agent_name', 'agent_name'),
Index('idx_create_time', 'create_time'),
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
)
# 定义用户数据模型
class User(Base):
"""用户数据表模型"""
__tablename__ = 'users'
id = Column(Integer, primary_key=True, autoincrement=True)
mail = Column(String(100), nullable=False, unique=True, comment='邮箱')
nickname = Column(String(50), nullable=False, comment='昵称')
password = Column(String(100), nullable=False, comment='密码')
level = Column(Integer, nullable=False, default=0, comment='用户级别(0=普通用户,1=VIP,2=SVIP)')
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
# 关系
questions = relationship("UserQuestion", back_populates="user")
# 索引和表属性
__table_args__ = (
Index('idx_mail', 'mail'),
Index('idx_level', 'level'),
Index('idx_create_time', 'create_time'),
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
)
# 定义用户提问数据模型
class UserQuestion(Base):
"""用户提问数据表模型"""
__tablename__ = 'user_questions'
id = Column(Integer, primary_key=True, autoincrement=True)
user_id = Column(Integer, ForeignKey('users.id'), nullable=False, comment='用户ID')
agent_id = Column(String(50), nullable=False, comment='AI Agent ID')
question = Column(Text, nullable=False, comment='提问内容')
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
# 关系
user = relationship("User", back_populates="questions")
# 索引和表属性
__table_args__ = (
Index('idx_user_id', 'user_id'),
Index('idx_agent_id', 'agent_id'),
Index('idx_create_time', 'create_time'),
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
)
# 定义Agent数据模型
class Agent(Base):
"""Agent数据表模型"""
__tablename__ = 'agents'
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String(100), nullable=False, unique=True, comment='Agent名称')
hello_prompt = Column(Text, nullable=True, comment='欢迎提示语')
description = Column(Text, nullable=True, comment='Agent描述')
dify_token = Column(String(255), nullable=True, comment='Dify API令牌')
inputs = Column(JSON, nullable=True, comment='输入参数JSON')
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
update_time = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间')
# 索引和表属性
__table_args__ = (
Index('idx_name', 'name'),
Index('idx_create_time', 'create_time'),
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
)
class DBManager:
"""数据库管理工具用于连接MySQL数据库并保存智能体分析结果"""
def __init__(self, host: str, port: int, user: str, password: str, db_name: str):
"""
初始化数据库管理器
Args:
host: 数据库主机地址
port: 数据库端口
user: 用户名
password: 密码
db_name: 数据库名
"""
self.host = host
self.port = port
self.user = user
self.password = password
self.db_name = db_name
self.engine = None
self.Session = None
# 初始化数据库连接
self._init_db()
def _init_db(self) -> None:
"""初始化数据库连接和表"""
try:
# 创建数据库连接
connection_string = f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}?charset=utf8mb4"
# 创建引擎,设置连接池
self.engine = create_engine(
connection_string,
echo=False, # 设置为True可以输出SQL语句调试用
pool_size=5, # 连接池大小
max_overflow=10, # 最大溢出连接数
pool_timeout=30, # 连接超时时间
pool_recycle=1800, # 连接回收时间(秒)
pool_pre_ping=True, # 在使用连接前先ping一下确保连接有效
connect_args={'charset': 'utf8mb4'}
)
# 创建会话工厂
self.Session = sessionmaker(bind=self.engine)
# 创建表(如果不存在)
Base.metadata.create_all(self.engine)
logger.info(f"成功连接到数据库 {self.db_name}")
except Exception as e:
logger.error(f"数据库初始化失败: {e}")
self.engine = None
def save_analysis_result(self, agent: str, symbol: str, time_interval: str,
analysis_result: Dict[str, Any]) -> bool:
"""
保存分析结果到数据库
Args:
agent: 智能体类型,例如 'crypto''gold'
symbol: 交易对符号,例如 'BTCUSDT'
time_interval: 时间间隔,例如 '1h', '4h', '1d'
analysis_result: 分析结果数据
Returns:
保存是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
# 创建会话
session = self.Session()
try:
# 创建新记录
new_result = AnalysisResult(
agent=agent,
symbol=symbol,
time_interval=time_interval,
completion_result=analysis_result,
created_at=datetime.now(),
updated_at=datetime.now()
)
# 添加并提交
session.add(new_result)
session.commit()
logger.info(f"成功保存 {agent} 分析结果,交易对: {symbol}, 时间间隔: {time_interval}")
return True
except Exception as e:
session.rollback()
logger.error(f"保存分析结果失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
# 如果是连接错误,尝试重新初始化
try:
self._init_db()
except:
pass
return False
def save_agent_feed(self, agent_name: str, content: str, avatar_url: Optional[str] = None) -> bool:
"""
保存AI Agent信息流到数据库
Args:
agent_name: AI Agent名称
content: 内容
avatar_url: 头像URL可选
Returns:
保存是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
# 创建会话
session = self.Session()
try:
# 创建新记录
new_feed = AgentFeed(
agent_name=agent_name,
content=content,
avatar_url=avatar_url
)
# 添加并提交
session.add(new_feed)
session.commit()
logger.info(f"成功保存 {agent_name} 的信息流")
return True
except Exception as e:
session.rollback()
logger.error(f"保存信息流失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
# 如果是连接错误,尝试重新初始化
try:
self._init_db()
except:
pass
return False
def register_user(self, mail: str, nickname: str, password: str, level: int = 0) -> bool:
"""
注册新用户
Args:
mail: 邮箱
nickname: 昵称
password: 密码
level: 用户级别默认为0普通用户
Returns:
注册是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
# 创建会话
session = self.Session()
try:
# 检查邮箱是否已存在
existing_user = session.query(User).filter(User.mail == mail).first()
if existing_user:
logger.warning(f"邮箱 {mail} 已被注册")
return False
# 创建新用户
new_user = User(
mail=mail,
nickname=nickname,
password=password, # 实际应用中应该对密码进行哈希处理
level=level,
create_time=datetime.now()
)
# 添加并提交
session.add(new_user)
session.commit()
logger.info(f"成功注册用户: {mail}")
return True
except Exception as e:
session.rollback()
logger.error(f"注册用户失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
# 如果是连接错误,尝试重新初始化
try:
self._init_db()
except:
pass
return False
def get_user_by_mail(self, mail: str) -> Optional[Dict[str, Any]]:
"""
通过邮箱获取用户信息
Args:
mail: 邮箱
Returns:
用户信息如果用户不存在则返回None
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return None
try:
# 创建会话
session = self.Session()
try:
# 查询用户
user = session.query(User).filter(User.mail == mail).first()
if user:
# 转换为字典
return {
'id': user.id,
'mail': user.mail,
'nickname': user.nickname,
'level': user.level,
'create_time': user.create_time
}
else:
return None
finally:
session.close()
except Exception as e:
logger.error(f"获取用户信息失败: {e}")
return None
def get_user_by_id(self, user_id: int) -> Optional[Dict[str, Any]]:
"""
通过ID获取用户信息
Args:
user_id: 用户ID
Returns:
用户信息如果用户不存在则返回None
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return None
try:
# 创建会话
session = self.Session()
try:
# 查询用户
user = session.query(User).filter(User.id == user_id).first()
if user:
# 转换为字典
return {
'id': user.id,
'mail': user.mail,
'nickname': user.nickname,
'level': user.level,
'create_time': user.create_time
}
else:
return None
finally:
session.close()
except Exception as e:
logger.error(f"获取用户信息失败: {e}")
return None
def update_user_level(self, user_id: int, level: int) -> bool:
"""
更新用户级别
Args:
user_id: 用户ID
level: 新的用户级别
Returns:
更新是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
# 创建会话
session = self.Session()
try:
# 查询用户
user = session.query(User).filter(User.id == user_id).first()
if not user:
logger.warning(f"用户ID {user_id} 不存在")
return False
# 更新级别
user.level = level
session.commit()
logger.info(f"成功更新用户 {user.mail} 的级别为 {level}")
return True
except Exception as e:
session.rollback()
logger.error(f"更新用户级别失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return False
def get_agent_feeds(self, agent_name: Optional[str] = None, limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]:
"""
获取AI Agent信息流
Args:
agent_name: 可选指定获取特定Agent的信息流
limit: 返回的最大记录数默认20条
Returns:
信息流列表,如果查询失败则返回空列表
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return []
try:
# 创建会话
session = self.Session()
try:
# 构建查询
query = session.query(AgentFeed)
# 如果指定了agent_name则筛选
if agent_name:
query = query.filter(AgentFeed.agent_name == agent_name)
# 按创建时间降序排序并限制数量
results = query.order_by(AgentFeed.create_time.desc()).offset(skip).limit(limit).all()
# 转换为字典列表
feeds = []
for result in results:
feeds.append({
'id': result.id,
'agent_name': result.agent_name,
'avatar_url': result.avatar_url,
'content': result.content,
'create_time': result.create_time
})
return feeds
finally:
session.close()
except Exception as e:
logger.error(f"获取信息流失败: {e}")
return []
def save_user_question(self, user_id: int, agent_id: str, question: str) -> bool:
"""
保存用户提问数据
Args:
user_id: 用户ID
agent_id: AI Agent ID
question: 提问内容
Returns:
保存是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
# 创建会话
session = self.Session()
try:
# 创建新记录
new_question = UserQuestion(
user_id=user_id,
agent_id=agent_id,
question=question,
create_time=datetime.now()
)
# 添加并提交
session.add(new_question)
session.commit()
logger.info(f"成功保存用户 {user_id} 对 Agent {agent_id} 的提问")
return True
except Exception as e:
session.rollback()
logger.error(f"保存用户提问失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
# 如果是连接错误,尝试重新初始化
try:
self._init_db()
except:
pass
return False
def get_user_questions(self, user_id: Optional[int] = None, agent_id: Optional[str] = None,
limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]:
"""
获取用户提问数据
Args:
user_id: 可选,指定获取特定用户的提问
agent_id: 可选指定获取特定Agent的提问
limit: 返回的最大记录数默认20条
skip: 跳过的记录数默认0条
Returns:
提问数据列表,如果查询失败则返回空列表
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return []
try:
# 创建会话
session = self.Session()
try:
# 构建查询
query = session.query(UserQuestion)
# 如果指定了user_id则筛选
if user_id:
query = query.filter(UserQuestion.user_id == user_id)
# 如果指定了agent_id则筛选
if agent_id:
query = query.filter(UserQuestion.agent_id == agent_id)
# 按创建时间降序排序并限制数量
results = query.order_by(UserQuestion.create_time.desc()).offset(skip).limit(limit).all()
# 转换为字典列表
questions = []
for result in results:
questions.append({
'id': result.id,
'user_id': result.user_id,
'agent_id': result.agent_id,
'question': result.question,
'create_time': result.create_time
})
return questions
finally:
session.close()
except Exception as e:
logger.error(f"获取用户提问失败: {e}")
return []
def get_user_question_by_id(self, question_id: int) -> Optional[Dict[str, Any]]:
"""
通过ID获取用户提问数据
Args:
question_id: 提问ID
Returns:
提问数据如果不存在则返回None
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return None
try:
# 创建会话
session = self.Session()
try:
# 查询提问
result = session.query(UserQuestion).filter(UserQuestion.id == question_id).first()
if result:
# 转换为字典
return {
'id': result.id,
'user_id': result.user_id,
'agent_id': result.agent_id,
'question': result.question,
'create_time': result.create_time
}
else:
return None
finally:
session.close()
except Exception as e:
logger.error(f"获取用户提问失败: {e}")
return None
def get_latest_result(self, agent: str, symbol: str, time_interval: str) -> Optional[Dict[str, Any]]:
"""
获取最新的分析结果
Args:
agent: 智能体类型,例如 'crypto''gold'
symbol: 交易对符号,例如 'BTCUSDT'
time_interval: 时间间隔,例如 '1h', '4h', '1d'
Returns:
最新分析结果如果查询失败则返回None
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return None
try:
# 创建会话
session = self.Session()
try:
# 查询最新的结果
result = session.query(AnalysisResult).filter(
AnalysisResult.agent == agent,
AnalysisResult.symbol == symbol,
AnalysisResult.time_interval == time_interval
).order_by(AnalysisResult.created_at.desc()).first()
if result:
# 转换为字典
return {
'id': result.id,
'agent': result.agent,
'symbol': result.symbol,
'time_interval': result.time_interval,
'completion_result': result.completion_result,
'created_at': result.created_at
}
else:
return None
finally:
session.close()
except Exception as e:
logger.error(f"获取最新分析结果失败: {e}")
return None
def close(self) -> None:
"""关闭数据库连接(如果存在)"""
if self.engine:
self.engine.dispose()
logger.info("数据库连接已关闭")
self.engine = None
self.Session = None
def create_agent(self, name: str, description: Optional[str] = None,
hello_prompt: Optional[str] = None, dify_token: Optional[str] = None,
inputs: Optional[Dict[str, Any]] = None) -> Optional[int]:
"""
创建新的Agent
Args:
name: Agent名称
description: Agent描述
hello_prompt: 欢迎提示语
dify_token: Dify API令牌
inputs: 输入参数JSON
Returns:
新Agent的ID如果创建失败则返回None
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return None
try:
# 创建会话
session = self.Session()
try:
# 检查名称是否已存在
existing_agent = session.query(Agent).filter(Agent.name == name).first()
if existing_agent:
logger.warning(f"Agent名称 {name} 已存在")
return None
# 创建新Agent
new_agent = Agent(
name=name,
description=description,
hello_prompt=hello_prompt,
dify_token=dify_token,
inputs=inputs,
create_time=datetime.now(),
update_time=datetime.now()
)
# 添加并提交
session.add(new_agent)
session.commit()
logger.info(f"成功创建Agent: {name}")
return new_agent.id
except Exception as e:
session.rollback()
logger.error(f"创建Agent失败: {e}")
return None
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return None
def update_agent(self, agent_id: int, name: Optional[str] = None,
description: Optional[str] = None, hello_prompt: Optional[str] = None,
dify_token: Optional[str] = None, inputs: Optional[Dict[str, Any]] = None) -> bool:
"""
更新Agent信息
Args:
agent_id: Agent ID
name: Agent名称
description: Agent描述
hello_prompt: 欢迎提示语
dify_token: Dify API令牌
inputs: 输入参数JSON
Returns:
更新是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
# 创建会话
session = self.Session()
try:
# 查询Agent
agent = session.query(Agent).filter(Agent.id == agent_id).first()
if not agent:
logger.warning(f"Agent ID {agent_id} 不存在")
return False
# 更新字段
if name is not None:
# 检查名称是否与其他Agent冲突
if name != agent.name:
existing = session.query(Agent).filter(Agent.name == name).first()
if existing:
logger.warning(f"Agent名称 {name} 已被其他Agent使用")
return False
agent.name = name
if description is not None:
agent.description = description
if hello_prompt is not None:
agent.hello_prompt = hello_prompt
if dify_token is not None:
agent.dify_token = dify_token
if inputs is not None:
agent.inputs = inputs
agent.update_time = datetime.now()
# 提交更改
session.commit()
logger.info(f"成功更新Agent ID: {agent_id}")
return True
except Exception as e:
session.rollback()
logger.error(f"更新Agent失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return False
def get_agent_by_id(self, agent_id: int) -> Optional[Dict[str, Any]]:
"""
通过ID获取Agent信息
Args:
agent_id: Agent ID
Returns:
Agent信息如果不存在则返回None
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return None
try:
# 创建会话
session = self.Session()
try:
# 查询Agent
agent = session.query(Agent).filter(Agent.id == agent_id).first()
if agent:
# 转换为字典
return {
'id': agent.id,
'name': agent.name,
'description': agent.description,
'hello_prompt': agent.hello_prompt,
'dify_token': agent.dify_token,
'inputs': agent.inputs,
'create_time': agent.create_time,
'update_time': agent.update_time
}
else:
return None
finally:
session.close()
except Exception as e:
logger.error(f"获取Agent信息失败: {e}")
return None
def get_agent_by_name(self, name: str) -> Optional[Dict[str, Any]]:
"""
通过名称获取Agent信息
Args:
name: Agent名称
Returns:
Agent信息如果不存在则返回None
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return None
try:
# 创建会话
session = self.Session()
try:
# 查询Agent
agent = session.query(Agent).filter(Agent.name == name).first()
if agent:
# 转换为字典
return {
'id': agent.id,
'name': agent.name,
'description': agent.description,
'hello_prompt': agent.hello_prompt,
'dify_token': agent.dify_token,
'inputs': agent.inputs,
'create_time': agent.create_time,
'update_time': agent.update_time
}
else:
return None
finally:
session.close()
except Exception as e:
logger.error(f"获取Agent信息失败: {e}")
return None
def list_agents(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]:
"""
获取Agent列表
Args:
limit: 返回的最大数量
skip: 跳过的数量
Returns:
Agent列表
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return []
try:
# 创建会话
session = self.Session()
try:
# 查询Agent列表
agents = session.query(Agent).order_by(Agent.create_time.asc()).offset(skip).limit(limit).all()
# 转换为字典列表
return [{
'id': agent.id,
'name': agent.name,
'description': agent.description,
'hello_prompt': agent.hello_prompt,
'dify_token': agent.dify_token,
'inputs': agent.inputs,
'create_time': agent.create_time,
'update_time': agent.update_time
} for agent in agents]
finally:
session.close()
except Exception as e:
logger.error(f"获取Agent列表失败: {e}")
return []
def delete_agent(self, agent_id: int) -> bool:
"""
删除Agent
Args:
agent_id: Agent ID
Returns:
删除是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
# 创建会话
session = self.Session()
try:
# 查询Agent
agent = session.query(Agent).filter(Agent.id == agent_id).first()
if not agent:
logger.warning(f"Agent ID {agent_id} 不存在")
return False
# 删除Agent
session.delete(agent)
session.commit()
logger.info(f"成功删除Agent ID: {agent_id}")
return True
except Exception as e:
session.rollback()
logger.error(f"删除Agent失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return False
# 单例模式
_db_instance = None
def get_db_manager(host: Optional[str] = None,
port: Optional[int] = None,
user: Optional[str] = None,
password: Optional[str] = None,
db_name: Optional[str] = None) -> DBManager:
"""
获取数据库管理器实例(单例模式)
Args:
host: 数据库主机地址
port: 数据库端口
user: 用户名
password: 密码
db_name: 数据库名
Returns:
数据库管理器实例
"""
global _db_instance
# 如果已经初始化过,直接返回
if _db_instance is not None:
return _db_instance
# 如果未指定参数,从配置加载器获取数据库配置
if host is None or port is None or user is None or password is None or db_name is None:
config_loader = ConfigLoader()
db_config = config_loader.get_database_config()
# 使用配置中的值或默认值
db_host = host or db_config.get('host')
db_port = port or db_config.get('port')
db_user = user or db_config.get('user')
db_password = password or db_config.get('password')
db_name = db_name or db_config.get('db_name')
logger.info(f"从配置加载数据库连接信息: {db_host}:{db_port}/{db_name}")
else:
db_host = host
db_port = port
db_user = user
db_password = password
db_name = db_name
# 创建实例
_db_instance = DBManager(
host=db_host,
port=db_port,
user=db_user,
password=db_password,
db_name=db_name
)
return _db_instance