#!/usr/bin/env python # -*- coding: utf-8 -*- import os import json import logging from typing import Dict, Any, List, Optional, Union from datetime import datetime from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, Index, text from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from sqlalchemy.dialects.mysql import JSON from sqlalchemy.pool import QueuePool from utils.config_loader import ConfigLoader # 配置日志 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger('db_manager') # 创建模型基类 Base = declarative_base() # 定义分析结果模型 class AnalysisResult(Base): """分析结果表模型""" __tablename__ = 'analysis_results' id = Column(Integer, primary_key=True, autoincrement=True) agent = Column(String(50), nullable=False, comment='智能体类型(crypto, gold)') symbol = Column(String(50), nullable=False, comment='交易对符号') time_interval = Column(String(20), nullable=False, comment='时间间隔') completion_result = Column(JSON, nullable=False, comment='分析结果JSON') created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') updated_at = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间') # 索引 __table_args__ = ( Index('idx_agent', 'agent'), Index('idx_symbol', 'symbol'), Index('idx_time_interval', 'time_interval'), Index('idx_created_at', 'created_at'), ) class DBManager: """数据库管理工具,用于连接MySQL数据库并保存智能体分析结果""" def __init__(self, host: str, port: int, user: str, password: str, db_name: str): """ 初始化数据库管理器 Args: host: 数据库主机地址 port: 数据库端口 user: 用户名 password: 密码 db_name: 数据库名 """ self.host = host self.port = port self.user = user self.password = password self.db_name = db_name self.engine = None self.Session = None # 初始化数据库连接 self._init_db() def _init_db(self) -> None: """初始化数据库连接和表""" try: # 创建数据库连接 connection_string = f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}?charset=utf8mb4" # 创建引擎,设置连接池 self.engine = create_engine( connection_string, echo=False, # 设置为True可以输出SQL语句(调试用) pool_size=5, # 连接池大小 max_overflow=10, # 最大溢出连接数 pool_timeout=30, # 连接超时时间 pool_recycle=1800, # 连接回收时间(秒) pool_pre_ping=True # 在使用连接前先ping一下,确保连接有效 ) # 创建会话工厂 self.Session = sessionmaker(bind=self.engine) # 创建表(如果不存在) Base.metadata.create_all(self.engine) logger.info(f"成功连接到数据库 {self.db_name}") except Exception as e: logger.error(f"数据库初始化失败: {e}") self.engine = None def save_analysis_result(self, agent: str, symbol: str, time_interval: str, analysis_result: Dict[str, Any]) -> bool: """ 保存分析结果到数据库 Args: agent: 智能体类型,例如 'crypto' 或 'gold' symbol: 交易对符号,例如 'BTCUSDT' time_interval: 时间间隔,例如 '1h', '4h', '1d' analysis_result: 分析结果数据 Returns: 保存是否成功 """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return False try: # 创建会话 session = self.Session() try: # 创建新记录 new_result = AnalysisResult( agent=agent, symbol=symbol, time_interval=time_interval, completion_result=analysis_result, created_at=datetime.now(), updated_at=datetime.now() ) # 添加并提交 session.add(new_result) session.commit() logger.info(f"成功保存 {agent} 分析结果,交易对: {symbol}, 时间间隔: {time_interval}") return True except Exception as e: session.rollback() logger.error(f"保存分析结果失败: {e}") return False finally: session.close() except Exception as e: logger.error(f"创建数据库会话失败: {e}") # 如果是连接错误,尝试重新初始化 try: self._init_db() except: pass return False def get_latest_result(self, agent: str, symbol: str, time_interval: str) -> Optional[Dict[str, Any]]: """ 获取最新的分析结果 Args: agent: 智能体类型,例如 'crypto' 或 'gold' symbol: 交易对符号,例如 'BTCUSDT' time_interval: 时间间隔,例如 '1h', '4h', '1d' Returns: 最新分析结果,如果查询失败则返回None """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return None try: # 创建会话 session = self.Session() try: # 查询最新的结果 result = session.query(AnalysisResult).filter( AnalysisResult.agent == agent, AnalysisResult.symbol == symbol, AnalysisResult.time_interval == time_interval ).order_by(AnalysisResult.created_at.desc()).first() if result: # 转换为字典 return { 'id': result.id, 'agent': result.agent, 'symbol': result.symbol, 'time_interval': result.time_interval, 'completion_result': result.completion_result, 'created_at': result.created_at } else: return None finally: session.close() except Exception as e: logger.error(f"获取最新分析结果失败: {e}") return None def close(self) -> None: """关闭数据库连接""" if self.engine: self.engine.dispose() self.engine = None logger.info("数据库连接已关闭") # 单例模式 _db_instance = None def get_db_manager(host: Optional[str] = None, port: Optional[int] = None, user: Optional[str] = None, password: Optional[str] = None, db_name: Optional[str] = None) -> DBManager: """ 获取数据库管理器实例(单例模式) Args: host: 数据库主机地址 port: 数据库端口 user: 用户名 password: 密码 db_name: 数据库名 Returns: 数据库管理器实例 """ global _db_instance # 如果已经初始化过,直接返回 if _db_instance is not None: return _db_instance # 如果未指定参数,从配置加载器获取数据库配置 if host is None or port is None or user is None or password is None or db_name is None: config_loader = ConfigLoader() db_config = config_loader.get_database_config() # 使用配置中的值或默认值 db_host = host or db_config.get('host') db_port = port or db_config.get('port') db_user = user or db_config.get('user') db_password = password or db_config.get('password') db_name = db_name or db_config.get('db_name') logger.info(f"从配置加载数据库连接信息: {db_host}:{db_port}/{db_name}") else: db_host = host db_port = port db_user = user db_password = password db_name = db_name # 创建实例 _db_instance = DBManager( host=db_host, port=db_port, user=db_user, password=db_password, db_name=db_name ) return _db_instance