140 lines
4.6 KiB
Python
140 lines
4.6 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
import os
|
||
import yaml
|
||
from typing import Dict, Any, Optional
|
||
|
||
|
||
class ConfigLoader:
|
||
"""配置加载器,用于读取和获取配置信息"""
|
||
|
||
def __init__(self, config_path: Optional[str] = None):
|
||
"""
|
||
初始化配置加载器
|
||
|
||
Args:
|
||
config_path: 配置文件路径,如果为None,则使用默认路径
|
||
"""
|
||
if config_path is None:
|
||
# 默认配置文件路径
|
||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||
config_dir = os.path.join(os.path.dirname(current_dir), 'config')
|
||
config_path = os.path.join(config_dir, 'config.yaml')
|
||
|
||
# 如果默认配置文件不存在,则使用示例配置文件
|
||
if not os.path.exists(config_path):
|
||
config_path = os.path.join(config_dir, 'config.example.yaml')
|
||
print(f"配置文件 config.yaml 不存在,使用示例配置文件: {config_path}")
|
||
|
||
self.config_path = config_path
|
||
|
||
# 加载配置
|
||
self.config = self._load_config()
|
||
|
||
# 确保database配置存在
|
||
if 'database' not in self.config:
|
||
self.config['database'] = {
|
||
'host': 'gz-cynosdbmysql-grp-2j1cnopr.sql.tencentcdb.com',
|
||
'port': 27469,
|
||
'user': 'root',
|
||
'password': 'Aa#223388',
|
||
'db_name': 'cryptoai'
|
||
}
|
||
|
||
def _load_config(self) -> Dict[str, Any]:
|
||
"""
|
||
加载配置文件
|
||
|
||
Returns:
|
||
配置字典
|
||
"""
|
||
try:
|
||
with open(self.config_path, 'r', encoding='utf-8') as f:
|
||
config = yaml.safe_load(f)
|
||
print(f"已加载配置文件: {self.config_path}")
|
||
return config
|
||
except Exception as e:
|
||
print(f"加载配置文件失败: {e}")
|
||
return {}
|
||
|
||
def get_config(self, section: str) -> Dict[str, Any]:
|
||
"""
|
||
获取指定部分的配置
|
||
|
||
Args:
|
||
section: 配置部分名称
|
||
|
||
Returns:
|
||
配置字典
|
||
"""
|
||
return self.config.get(section, {})
|
||
|
||
def get_binance_config(self) -> Dict[str, Any]:
|
||
"""获取Binance配置"""
|
||
return self.get_config('binance')
|
||
|
||
def get_okx_config(self) -> Dict[str, Any]:
|
||
"""获取OKX配置"""
|
||
return self.get_config('okx')
|
||
|
||
def get_deepseek_config(self) -> Dict[str, Any]:
|
||
"""获取DeepSeek配置"""
|
||
return self.get_config('deepseek')
|
||
|
||
def get_alltick_config(self) -> Dict[str, Any]:
|
||
"""获取AllTick配置"""
|
||
return self.get_config('alltick')
|
||
|
||
def get_crypto_config(self) -> Dict[str, Any]:
|
||
"""获取加密货币配置"""
|
||
return self.get_config('crypto')
|
||
|
||
def get_data_config(self) -> Dict[str, Any]:
|
||
"""获取数据配置"""
|
||
return self.get_config('data')
|
||
|
||
def get_agent_config(self) -> Dict[str, Any]:
|
||
"""获取Agent配置"""
|
||
return self.get_config('agent')
|
||
|
||
def get_logging_config(self) -> Dict[str, Any]:
|
||
"""获取日志配置"""
|
||
return self.get_config('logging')
|
||
|
||
def get_dingtalk_config(self) -> Dict[str, Any]:
|
||
"""获取钉钉机器人配置"""
|
||
return self.get_config('dingtalk')
|
||
|
||
def get_database_config(self) -> Dict[str, Any]:
|
||
"""获取数据库配置"""
|
||
# 首先从配置文件获取
|
||
db_config = self.get_config('database')
|
||
|
||
# 使用环境变量覆盖(如果存在)
|
||
if os.environ.get('DB_HOST'):
|
||
db_config['host'] = os.environ.get('DB_HOST')
|
||
if os.environ.get('DB_PORT'):
|
||
db_config['port'] = int(os.environ.get('DB_PORT'))
|
||
if os.environ.get('DB_USER'):
|
||
db_config['user'] = os.environ.get('DB_USER')
|
||
if os.environ.get('DB_PASSWORD'):
|
||
db_config['password'] = os.environ.get('DB_PASSWORD')
|
||
if os.environ.get('DB_NAME'):
|
||
db_config['db_name'] = os.environ.get('DB_NAME')
|
||
|
||
# 确保返回默认值(如果配置不存在)
|
||
default_config = {
|
||
'host': 'gz-cynosdbmysql-grp-2j1cnopr.sql.tencentcdb.com',
|
||
'port': 27469,
|
||
'user': 'root',
|
||
'password': 'Aa#223388',
|
||
'db_name': 'cryptoai'
|
||
}
|
||
|
||
# 合并默认配置和实际配置
|
||
for key, value in default_config.items():
|
||
if key not in db_config:
|
||
db_config[key] = value
|
||
|
||
return db_config |