263 lines
8.8 KiB
Python
263 lines
8.8 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
import os
|
||
import json
|
||
import logging
|
||
from typing import Dict, Any, List, Optional, Union
|
||
from datetime import datetime
|
||
|
||
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, Index, text
|
||
from sqlalchemy.ext.declarative import declarative_base
|
||
from sqlalchemy.orm import sessionmaker
|
||
from sqlalchemy.dialects.mysql import JSON
|
||
from sqlalchemy.pool import QueuePool
|
||
|
||
# 配置日志
|
||
logging.basicConfig(
|
||
level=logging.INFO,
|
||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||
)
|
||
logger = logging.getLogger('db_manager')
|
||
|
||
# 创建模型基类
|
||
Base = declarative_base()
|
||
|
||
# 定义分析结果模型
|
||
class AnalysisResult(Base):
|
||
"""分析结果表模型"""
|
||
__tablename__ = 'analysis_results'
|
||
|
||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||
agent = Column(String(50), nullable=False, comment='智能体类型(crypto, gold)')
|
||
symbol = Column(String(50), nullable=False, comment='交易对符号')
|
||
time_interval = Column(String(20), nullable=False, comment='时间间隔')
|
||
completion_result = Column(JSON, nullable=False, comment='分析结果JSON')
|
||
created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||
updated_at = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间')
|
||
|
||
# 索引
|
||
__table_args__ = (
|
||
Index('idx_agent', 'agent'),
|
||
Index('idx_symbol', 'symbol'),
|
||
Index('idx_time_interval', 'time_interval'),
|
||
Index('idx_created_at', 'created_at'),
|
||
)
|
||
|
||
class DBManager:
|
||
"""数据库管理工具,用于连接MySQL数据库并保存智能体分析结果"""
|
||
|
||
def __init__(self, host: str, port: int, user: str, password: str, db_name: str):
|
||
"""
|
||
初始化数据库管理器
|
||
|
||
Args:
|
||
host: 数据库主机地址
|
||
port: 数据库端口
|
||
user: 用户名
|
||
password: 密码
|
||
db_name: 数据库名
|
||
"""
|
||
self.host = host
|
||
self.port = port
|
||
self.user = user
|
||
self.password = password
|
||
self.db_name = db_name
|
||
self.engine = None
|
||
self.Session = None
|
||
|
||
# 初始化数据库连接
|
||
self._init_db()
|
||
|
||
def _init_db(self) -> None:
|
||
"""初始化数据库连接和表"""
|
||
try:
|
||
# 创建数据库连接
|
||
connection_string = f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}?charset=utf8mb4"
|
||
|
||
# 创建引擎,设置连接池
|
||
self.engine = create_engine(
|
||
connection_string,
|
||
echo=False, # 设置为True可以输出SQL语句(调试用)
|
||
pool_size=5, # 连接池大小
|
||
max_overflow=10, # 最大溢出连接数
|
||
pool_timeout=30, # 连接超时时间
|
||
pool_recycle=1800, # 连接回收时间(秒)
|
||
pool_pre_ping=True # 在使用连接前先ping一下,确保连接有效
|
||
)
|
||
|
||
# 创建会话工厂
|
||
self.Session = sessionmaker(bind=self.engine)
|
||
|
||
# 创建表(如果不存在)
|
||
Base.metadata.create_all(self.engine)
|
||
|
||
logger.info(f"成功连接到数据库 {self.db_name}")
|
||
|
||
except Exception as e:
|
||
logger.error(f"数据库初始化失败: {e}")
|
||
self.engine = None
|
||
|
||
def save_analysis_result(self, agent: str, symbol: str, time_interval: str,
|
||
analysis_result: Dict[str, Any]) -> bool:
|
||
"""
|
||
保存分析结果到数据库
|
||
|
||
Args:
|
||
agent: 智能体类型,例如 'crypto' 或 'gold'
|
||
symbol: 交易对符号,例如 'BTCUSDT'
|
||
time_interval: 时间间隔,例如 '1h', '4h', '1d'
|
||
analysis_result: 分析结果数据
|
||
|
||
Returns:
|
||
保存是否成功
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return False
|
||
|
||
try:
|
||
# 创建会话
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 创建新记录
|
||
new_result = AnalysisResult(
|
||
agent=agent,
|
||
symbol=symbol,
|
||
time_interval=time_interval,
|
||
completion_result=analysis_result,
|
||
created_at=datetime.now(),
|
||
updated_at=datetime.now()
|
||
)
|
||
|
||
# 添加并提交
|
||
session.add(new_result)
|
||
session.commit()
|
||
|
||
logger.info(f"成功保存 {agent} 分析结果,交易对: {symbol}, 时间间隔: {time_interval}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
session.rollback()
|
||
logger.error(f"保存分析结果失败: {e}")
|
||
return False
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建数据库会话失败: {e}")
|
||
# 如果是连接错误,尝试重新初始化
|
||
try:
|
||
self._init_db()
|
||
except:
|
||
pass
|
||
return False
|
||
|
||
def get_latest_result(self, agent: str, symbol: str, time_interval: str) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
获取最新的分析结果
|
||
|
||
Args:
|
||
agent: 智能体类型,例如 'crypto' 或 'gold'
|
||
symbol: 交易对符号,例如 'BTCUSDT'
|
||
time_interval: 时间间隔,例如 '1h', '4h', '1d'
|
||
|
||
Returns:
|
||
最新分析结果,如果查询失败则返回None
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return None
|
||
|
||
try:
|
||
# 创建会话
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 查询最新的结果
|
||
result = session.query(AnalysisResult).filter(
|
||
AnalysisResult.agent == agent,
|
||
AnalysisResult.symbol == symbol,
|
||
AnalysisResult.time_interval == time_interval
|
||
).order_by(AnalysisResult.created_at.desc()).first()
|
||
|
||
if result:
|
||
# 转换为字典
|
||
return {
|
||
'id': result.id,
|
||
'agent': result.agent,
|
||
'symbol': result.symbol,
|
||
'time_interval': result.time_interval,
|
||
'completion_result': result.completion_result,
|
||
'created_at': result.created_at
|
||
}
|
||
else:
|
||
return None
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取最新分析结果失败: {e}")
|
||
return None
|
||
|
||
def close(self) -> None:
|
||
"""关闭数据库连接"""
|
||
if self.engine:
|
||
self.engine.dispose()
|
||
self.engine = None
|
||
logger.info("数据库连接已关闭")
|
||
|
||
|
||
# 单例模式
|
||
_db_instance = None
|
||
|
||
def get_db_manager(host: Optional[str] = None,
|
||
port: Optional[int] = None,
|
||
user: Optional[str] = None,
|
||
password: Optional[str] = None,
|
||
db_name: Optional[str] = None) -> DBManager:
|
||
"""
|
||
获取数据库管理器实例(单例模式)
|
||
|
||
Args:
|
||
host: 数据库主机地址
|
||
port: 数据库端口
|
||
user: 用户名
|
||
password: 密码
|
||
db_name: 数据库名
|
||
|
||
Returns:
|
||
数据库管理器实例
|
||
"""
|
||
global _db_instance
|
||
|
||
# 如果已经初始化过,直接返回
|
||
if _db_instance is not None:
|
||
return _db_instance
|
||
|
||
# 从环境变量获取配置
|
||
db_host = host or os.environ.get('DB_HOST', 'gz-cynosdbmysql-grp-2j1cnopr.sql.tencentcdb.com')
|
||
db_port = port or int(os.environ.get('DB_PORT', '27469'))
|
||
db_user = user or os.environ.get('DB_USER', 'root')
|
||
db_password = password or os.environ.get('DB_PASSWORD', 'Aa#223388')
|
||
db_name = db_name or os.environ.get('DB_NAME', 'cryptoai')
|
||
|
||
# 创建实例
|
||
_db_instance = DBManager(
|
||
host=db_host,
|
||
port=db_port,
|
||
user=db_user,
|
||
password=db_password,
|
||
db_name=db_name
|
||
)
|
||
|
||
return _db_instance |