crypto.ai/cryptoai/utils/db_manager.py
2025-04-28 14:53:05 +08:00

263 lines
8.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import json
import logging
from typing import Dict, Any, List, Optional, Union
from datetime import datetime
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, Index, text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.dialects.mysql import JSON
from sqlalchemy.pool import QueuePool
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('db_manager')
# 创建模型基类
Base = declarative_base()
# 定义分析结果模型
class AnalysisResult(Base):
"""分析结果表模型"""
__tablename__ = 'analysis_results'
id = Column(Integer, primary_key=True, autoincrement=True)
agent = Column(String(50), nullable=False, comment='智能体类型(crypto, gold)')
symbol = Column(String(50), nullable=False, comment='交易对符号')
time_interval = Column(String(20), nullable=False, comment='时间间隔')
completion_result = Column(JSON, nullable=False, comment='分析结果JSON')
created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
updated_at = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间')
# 索引
__table_args__ = (
Index('idx_agent', 'agent'),
Index('idx_symbol', 'symbol'),
Index('idx_time_interval', 'time_interval'),
Index('idx_created_at', 'created_at'),
)
class DBManager:
"""数据库管理工具用于连接MySQL数据库并保存智能体分析结果"""
def __init__(self, host: str, port: int, user: str, password: str, db_name: str):
"""
初始化数据库管理器
Args:
host: 数据库主机地址
port: 数据库端口
user: 用户名
password: 密码
db_name: 数据库名
"""
self.host = host
self.port = port
self.user = user
self.password = password
self.db_name = db_name
self.engine = None
self.Session = None
# 初始化数据库连接
self._init_db()
def _init_db(self) -> None:
"""初始化数据库连接和表"""
try:
# 创建数据库连接
connection_string = f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}?charset=utf8mb4"
# 创建引擎,设置连接池
self.engine = create_engine(
connection_string,
echo=False, # 设置为True可以输出SQL语句调试用
pool_size=5, # 连接池大小
max_overflow=10, # 最大溢出连接数
pool_timeout=30, # 连接超时时间
pool_recycle=1800, # 连接回收时间(秒)
pool_pre_ping=True # 在使用连接前先ping一下确保连接有效
)
# 创建会话工厂
self.Session = sessionmaker(bind=self.engine)
# 创建表(如果不存在)
Base.metadata.create_all(self.engine)
logger.info(f"成功连接到数据库 {self.db_name}")
except Exception as e:
logger.error(f"数据库初始化失败: {e}")
self.engine = None
def save_analysis_result(self, agent: str, symbol: str, time_interval: str,
analysis_result: Dict[str, Any]) -> bool:
"""
保存分析结果到数据库
Args:
agent: 智能体类型,例如 'crypto''gold'
symbol: 交易对符号,例如 'BTCUSDT'
time_interval: 时间间隔,例如 '1h', '4h', '1d'
analysis_result: 分析结果数据
Returns:
保存是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
# 创建会话
session = self.Session()
try:
# 创建新记录
new_result = AnalysisResult(
agent=agent,
symbol=symbol,
time_interval=time_interval,
completion_result=analysis_result,
created_at=datetime.now(),
updated_at=datetime.now()
)
# 添加并提交
session.add(new_result)
session.commit()
logger.info(f"成功保存 {agent} 分析结果,交易对: {symbol}, 时间间隔: {time_interval}")
return True
except Exception as e:
session.rollback()
logger.error(f"保存分析结果失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
# 如果是连接错误,尝试重新初始化
try:
self._init_db()
except:
pass
return False
def get_latest_result(self, agent: str, symbol: str, time_interval: str) -> Optional[Dict[str, Any]]:
"""
获取最新的分析结果
Args:
agent: 智能体类型,例如 'crypto''gold'
symbol: 交易对符号,例如 'BTCUSDT'
time_interval: 时间间隔,例如 '1h', '4h', '1d'
Returns:
最新分析结果如果查询失败则返回None
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return None
try:
# 创建会话
session = self.Session()
try:
# 查询最新的结果
result = session.query(AnalysisResult).filter(
AnalysisResult.agent == agent,
AnalysisResult.symbol == symbol,
AnalysisResult.time_interval == time_interval
).order_by(AnalysisResult.created_at.desc()).first()
if result:
# 转换为字典
return {
'id': result.id,
'agent': result.agent,
'symbol': result.symbol,
'time_interval': result.time_interval,
'completion_result': result.completion_result,
'created_at': result.created_at
}
else:
return None
finally:
session.close()
except Exception as e:
logger.error(f"获取最新分析结果失败: {e}")
return None
def close(self) -> None:
"""关闭数据库连接"""
if self.engine:
self.engine.dispose()
self.engine = None
logger.info("数据库连接已关闭")
# 单例模式
_db_instance = None
def get_db_manager(host: Optional[str] = None,
port: Optional[int] = None,
user: Optional[str] = None,
password: Optional[str] = None,
db_name: Optional[str] = None) -> DBManager:
"""
获取数据库管理器实例(单例模式)
Args:
host: 数据库主机地址
port: 数据库端口
user: 用户名
password: 密码
db_name: 数据库名
Returns:
数据库管理器实例
"""
global _db_instance
# 如果已经初始化过,直接返回
if _db_instance is not None:
return _db_instance
# 从环境变量获取配置
db_host = host or os.environ.get('DB_HOST', 'gz-cynosdbmysql-grp-2j1cnopr.sql.tencentcdb.com')
db_port = port or int(os.environ.get('DB_PORT', '27469'))
db_user = user or os.environ.get('DB_USER', 'root')
db_password = password or os.environ.get('DB_PASSWORD', 'Aa#223388')
db_name = db_name or os.environ.get('DB_NAME', 'cryptoai')
# 创建实例
_db_instance = DBManager(
host=db_host,
port=db_port,
user=db_user,
password=db_password,
db_name=db_name
)
return _db_instance