#!/usr/bin/env python # -*- coding: utf-8 -*- import os import json import logging from typing import Dict, Any, List, Optional, Union from datetime import datetime from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, Index, text, ForeignKey from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker, relationship from sqlalchemy.dialects.mysql import JSON from sqlalchemy.pool import QueuePool from sqlalchemy import or_ from cryptoai.utils.config_loader import ConfigLoader # 配置日志 logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger('db_manager') # 创建模型基类 Base = declarative_base() # 定义Token模型 class Token(Base): """Token信息表模型""" __tablename__ = 'tokens' id = Column(Integer, primary_key=True, autoincrement=True) symbol = Column(String(50), nullable=False, unique=True, comment='交易对符号') base_asset = Column(String(20), nullable=False, comment='基础资产') quote_asset = Column(String(20), nullable=False, comment='计价资产') created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') # 索引 __table_args__ = ( Index('idx_symbol', 'symbol'), Index('idx_base_asset', 'base_asset'), Index('idx_quote_asset', 'quote_asset'), {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} ) # 定义分析结果模型 class AnalysisResult(Base): """分析结果表模型""" __tablename__ = 'analysis_results' id = Column(Integer, primary_key=True, autoincrement=True) agent = Column(String(50), nullable=False, comment='智能体类型(crypto, gold)') symbol = Column(String(50), nullable=False, comment='交易对符号') time_interval = Column(String(20), nullable=False, comment='时间间隔') completion_result = Column(JSON, nullable=False, comment='分析结果JSON') created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') updated_at = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间') # 索引 __table_args__ = ( Index('idx_agent', 'agent'), Index('idx_symbol', 'symbol'), Index('idx_time_interval', 'time_interval'), Index('idx_created_at', 'created_at'), ) # 定义AI Agent信息流模型 class AgentFeed(Base): """AI Agent信息流表模型""" __tablename__ = 'agent_feeds' id = Column(Integer, primary_key=True, autoincrement=True) agent_name = Column(String(50), nullable=False, comment='AI Agent名称') avatar_url = Column(String(255), nullable=True, comment='头像URL') content = Column(Text, nullable=False, comment='内容') create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') # 索引和表属性 __table_args__ = ( Index('idx_agent_name', 'agent_name'), Index('idx_create_time', 'create_time'), {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} ) # 定义用户数据模型 class User(Base): """用户数据表模型""" __tablename__ = 'users' id = Column(Integer, primary_key=True, autoincrement=True) mail = Column(String(100), nullable=False, unique=True, comment='邮箱') nickname = Column(String(50), nullable=False, comment='昵称') password = Column(String(100), nullable=False, comment='密码') level = Column(Integer, nullable=False, default=0, comment='用户级别(0=普通用户,1=VIP,2=SVIP)') points = Column(Integer, nullable=False, default=0, comment='用户积分') create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') # 关系 questions = relationship("UserQuestion", back_populates="user") # 索引和表属性 __table_args__ = ( Index('idx_mail', 'mail'), Index('idx_level', 'level'), Index('idx_create_time', 'create_time'), {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} ) # 定义用户提问数据模型 class UserQuestion(Base): """用户提问数据表模型""" __tablename__ = 'user_questions' id = Column(Integer, primary_key=True, autoincrement=True) user_id = Column(Integer, ForeignKey('users.id'), nullable=False, comment='用户ID') agent_id = Column(String(50), nullable=False, comment='AI Agent ID') question = Column(Text, nullable=False, comment='提问内容') create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') # 关系 user = relationship("User", back_populates="questions") # 索引和表属性 __table_args__ = ( Index('idx_user_id', 'user_id'), Index('idx_agent_id', 'agent_id'), Index('idx_create_time', 'create_time'), {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} ) # 定义Agent数据模型 class Agent(Base): """Agent数据表模型""" __tablename__ = 'agents' id = Column(Integer, primary_key=True, autoincrement=True) name = Column(String(100), nullable=False, unique=True, comment='Agent名称') hello_prompt = Column(Text, nullable=True, comment='欢迎提示语') description = Column(Text, nullable=True, comment='Agent描述') dify_token = Column(String(255), nullable=True, comment='Dify API令牌') inputs = Column(JSON, nullable=True, comment='输入参数JSON') create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') update_time = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间') # 索引和表属性 __table_args__ = ( Index('idx_name', 'name'), Index('idx_create_time', 'create_time'), {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} ) # 定义 A 股数据模型 class AStock(Base): """A股股票基本信息表模型""" __tablename__ = 'astock' stock_code = Column(String(10), primary_key=True, comment='股票代码') short_name = Column(String(50), nullable=False, comment='股票简称') exchange = Column(String(20), nullable=True, comment='交易所') list_date = Column(DateTime, nullable=True, comment='上市日期') created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') # 索引 __table_args__ = ( Index('idx_stock_code', 'stock_code', unique=True), {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} ) def _convert_timestamp_to_datetime(timestamp: Union[int, float, str, None]) -> Optional[datetime]: """ 将时间戳转换为datetime对象 Args: timestamp: Unix时间戳(毫秒级) Returns: datetime对象或None """ if timestamp is None: return None try: # 转换为整数 if isinstance(timestamp, str): timestamp = int(timestamp) # 如果是毫秒级时间戳,转换为秒级 if timestamp > 1e11: # 判断是否为毫秒级时间戳 timestamp = timestamp / 1000 return datetime.fromtimestamp(timestamp) except (ValueError, TypeError) as e: logger.error(f"时间戳转换失败: {timestamp}, 错误: {e}") return None class DBManager: """数据库管理工具,用于连接MySQL数据库并保存智能体分析结果""" def __init__(self, host: str, port: int, user: str, password: str, db_name: str): """ 初始化数据库管理器 Args: host: 数据库主机地址 port: 数据库端口 user: 用户名 password: 密码 db_name: 数据库名 """ self.host = host self.port = port self.user = user self.password = password self.db_name = db_name self.engine = None self.Session = None # 初始化数据库连接 self._init_db() def _init_db(self) -> None: """初始化数据库连接和表""" try: # 创建数据库连接 connection_string = f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}?charset=utf8mb4" # 创建引擎,设置连接池 self.engine = create_engine( connection_string, echo=False, # 设置为True可以输出SQL语句(调试用) pool_size=5, # 连接池大小 max_overflow=10, # 最大溢出连接数 pool_timeout=30, # 连接超时时间 pool_recycle=1800, # 连接回收时间(秒) pool_pre_ping=True, # 在使用连接前先ping一下,确保连接有效 connect_args={'charset': 'utf8mb4'} ) # 创建会话工厂 self.Session = sessionmaker(bind=self.engine) # 创建表(如果不存在) Base.metadata.create_all(self.engine) logger.info(f"成功连接到数据库 {self.db_name}") except Exception as e: logger.error(f"数据库初始化失败: {e}") self.engine = None def save_analysis_result(self, agent: str, symbol: str, time_interval: str, analysis_result: Dict[str, Any]) -> bool: """ 保存分析结果到数据库 Args: agent: 智能体类型,例如 'crypto' 或 'gold' symbol: 交易对符号,例如 'BTCUSDT' time_interval: 时间间隔,例如 '1h', '4h', '1d' analysis_result: 分析结果数据 Returns: 保存是否成功 """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return False try: # 创建会话 session = self.Session() try: # 创建新记录 new_result = AnalysisResult( agent=agent, symbol=symbol, time_interval=time_interval, completion_result=analysis_result, created_at=datetime.now(), updated_at=datetime.now() ) # 添加并提交 session.add(new_result) session.commit() logger.info(f"成功保存 {agent} 分析结果,交易对: {symbol}, 时间间隔: {time_interval}") return True except Exception as e: session.rollback() logger.error(f"保存分析结果失败: {e}") return False finally: session.close() except Exception as e: logger.error(f"创建数据库会话失败: {e}") # 如果是连接错误,尝试重新初始化 try: self._init_db() except: pass return False def save_agent_feed(self, agent_name: str, content: str, avatar_url: Optional[str] = None) -> bool: """ 保存AI Agent信息流到数据库 Args: agent_name: AI Agent名称 content: 内容 avatar_url: 头像URL,可选 Returns: 保存是否成功 """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return False try: # 创建会话 session = self.Session() try: # 创建新记录 new_feed = AgentFeed( agent_name=agent_name, content=content, avatar_url=avatar_url ) # 添加并提交 session.add(new_feed) session.commit() logger.info(f"成功保存 {agent_name} 的信息流") return True except Exception as e: session.rollback() logger.error(f"保存信息流失败: {e}") return False finally: session.close() except Exception as e: logger.error(f"创建数据库会话失败: {e}") # 如果是连接错误,尝试重新初始化 try: self._init_db() except: pass return False def register_user(self, mail: str, nickname: str, password: str, level: int = 0, points: int = 0) -> bool: """ 注册新用户 Args: mail: 邮箱 nickname: 昵称 password: 密码 level: 用户级别,默认为0(普通用户) points: 初始积分,默认为0 Returns: 注册是否成功 """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return False try: # 创建会话 session = self.Session() try: # 检查邮箱是否已存在 existing_user = session.query(User).filter(User.mail == mail).first() if existing_user: logger.warning(f"邮箱 {mail} 已被注册") return False # 创建新用户 new_user = User( mail=mail, nickname=nickname, password=password, # 实际应用中应该对密码进行哈希处理 level=level, points=points, create_time=datetime.now() ) # 添加并提交 session.add(new_user) session.commit() logger.info(f"成功注册用户: {mail},初始积分: {points}") return True except Exception as e: session.rollback() logger.error(f"注册用户失败: {e}") return False finally: session.close() except Exception as e: logger.error(f"创建数据库会话失败: {e}") # 如果是连接错误,尝试重新初始化 try: self._init_db() except: pass return False def get_user_by_id(self, user_id: int) -> Optional[Dict[str, Any]]: """ 通过ID获取用户信息 Args: user_id: 用户ID Returns: 用户信息,如果用户不存在则返回None """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return None try: # 创建会话 session = self.Session() try: # 查询用户 user = session.query(User).filter(User.id == user_id).first() if user: # 转换为字典 return { 'id': user.id, 'mail': user.mail, 'nickname': user.nickname, 'level': user.level, 'points': user.points, 'create_time': user.create_time } else: return None finally: session.close() except Exception as e: logger.error(f"获取用户信息失败: {e}") return None def get_user_by_mail(self, mail: str) -> Optional[Dict[str, Any]]: """ 通过邮箱获取用户信息 Args: mail: 邮箱 Returns: 用户信息,如果用户不存在则返回None """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return None try: # 创建会话 session = self.Session() try: # 查询用户 user = session.query(User).filter(User.mail == mail).first() if user: # 转换为字典 return { 'id': user.id, 'mail': user.mail, 'nickname': user.nickname, 'level': user.level, 'points': user.points, 'create_time': user.create_time } else: return None finally: session.close() except Exception as e: logger.error(f"获取用户信息失败: {e}") return None def update_user_level(self, user_id: int, level: int) -> bool: """ 更新用户级别 Args: user_id: 用户ID level: 新的用户级别 Returns: 更新是否成功 """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return False try: # 创建会话 session = self.Session() try: # 查询用户 user = session.query(User).filter(User.id == user_id).first() if not user: logger.warning(f"用户ID {user_id} 不存在") return False # 更新级别 user.level = level session.commit() logger.info(f"成功更新用户 {user.mail} 的级别为 {level}") return True except Exception as e: session.rollback() logger.error(f"更新用户级别失败: {e}") return False finally: session.close() except Exception as e: logger.error(f"创建数据库会话失败: {e}") return False def add_user_points(self, user_id: int, points: int) -> bool: """ 为用户增加积分 Args: user_id: 用户ID points: 增加的积分数量(正数) Returns: 操作是否成功 """ if points <= 0: logger.warning(f"增加的积分必须是正数: {points}") return False if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return False try: # 创建会话 session = self.Session() try: # 查询用户 user = session.query(User).filter(User.id == user_id).first() if not user: logger.warning(f"用户ID {user_id} 不存在") return False # 增加积分 user.points += points session.commit() logger.info(f"成功为用户 {user.mail} 增加 {points} 积分,当前积分: {user.points}") return True except Exception as e: session.rollback() logger.error(f"增加用户积分失败: {e}") return False finally: session.close() except Exception as e: logger.error(f"创建数据库会话失败: {e}") return False def consume_user_points(self, user_id: int, points: int) -> bool: """ 用户消费积分 Args: user_id: 用户ID points: 消费的积分数量(正数) Returns: 操作是否成功 """ if points <= 0: logger.warning(f"消费的积分必须是正数: {points}") return False if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return False try: # 创建会话 session = self.Session() try: # 查询用户 user = session.query(User).filter(User.id == user_id).first() if not user: logger.warning(f"用户ID {user_id} 不存在") return False # 检查积分是否足够 if user.points < points: logger.warning(f"用户 {user.mail} 积分不足,当前积分: {user.points},需要消费: {points}") return False # 消费积分 user.points -= points session.commit() logger.info(f"成功从用户 {user.mail} 消费 {points} 积分,剩余积分: {user.points}") return True except Exception as e: session.rollback() logger.error(f"消费用户积分失败: {e}") return False finally: session.close() except Exception as e: logger.error(f"创建数据库会话失败: {e}") return False def get_agent_feeds(self, agent_name: Optional[str] = None, limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]: """ 获取AI Agent信息流 Args: agent_name: 可选,指定获取特定Agent的信息流 limit: 返回的最大记录数,默认20条 Returns: 信息流列表,如果查询失败则返回空列表 """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return [] try: # 创建会话 session = self.Session() try: # 构建查询 query = session.query(AgentFeed) # 如果指定了agent_name,则筛选 if agent_name: query = query.filter(AgentFeed.agent_name == agent_name) # 按创建时间降序排序并限制数量 results = query.order_by(AgentFeed.create_time.desc()).offset(skip).limit(limit).all() # 转换为字典列表 feeds = [] for result in results: feeds.append({ 'id': result.id, 'agent_name': result.agent_name, 'avatar_url': result.avatar_url, 'content': result.content, 'create_time': result.create_time }) return feeds finally: session.close() except Exception as e: logger.error(f"获取信息流失败: {e}") return [] def save_user_question(self, user_id: int, agent_id: str, question: str) -> bool: """ 保存用户提问数据 Args: user_id: 用户ID agent_id: AI Agent ID question: 提问内容 Returns: 保存是否成功 """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return False try: # 创建会话 session = self.Session() try: # 创建新记录 new_question = UserQuestion( user_id=user_id, agent_id=agent_id, question=question, create_time=datetime.now() ) # 添加并提交 session.add(new_question) session.commit() logger.info(f"成功保存用户 {user_id} 对 Agent {agent_id} 的提问") return True except Exception as e: session.rollback() logger.error(f"保存用户提问失败: {e}") return False finally: session.close() except Exception as e: logger.error(f"创建数据库会话失败: {e}") # 如果是连接错误,尝试重新初始化 try: self._init_db() except: pass return False def get_user_questions(self, user_id: Optional[int] = None, agent_id: Optional[str] = None, limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]: """ 获取用户提问数据 Args: user_id: 可选,指定获取特定用户的提问 agent_id: 可选,指定获取特定Agent的提问 limit: 返回的最大记录数,默认20条 skip: 跳过的记录数,默认0条 Returns: 提问数据列表,如果查询失败则返回空列表 """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return [] try: # 创建会话 session = self.Session() try: # 构建查询 query = session.query(UserQuestion) # 如果指定了user_id,则筛选 if user_id: query = query.filter(UserQuestion.user_id == user_id) # 如果指定了agent_id,则筛选 if agent_id: query = query.filter(UserQuestion.agent_id == agent_id) # 按创建时间降序排序并限制数量 results = query.order_by(UserQuestion.create_time.desc()).offset(skip).limit(limit).all() # 转换为字典列表 questions = [] for result in results: questions.append({ 'id': result.id, 'user_id': result.user_id, 'agent_id': result.agent_id, 'question': result.question, 'create_time': result.create_time }) return questions finally: session.close() except Exception as e: logger.error(f"获取用户提问失败: {e}") return [] def get_user_question_by_id(self, question_id: int) -> Optional[Dict[str, Any]]: """ 通过ID获取用户提问数据 Args: question_id: 提问ID Returns: 提问数据,如果不存在则返回None """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return None try: # 创建会话 session = self.Session() try: # 查询提问 result = session.query(UserQuestion).filter(UserQuestion.id == question_id).first() if result: # 转换为字典 return { 'id': result.id, 'user_id': result.user_id, 'agent_id': result.agent_id, 'question': result.question, 'create_time': result.create_time } else: return None finally: session.close() except Exception as e: logger.error(f"获取用户提问失败: {e}") return None def get_latest_result(self, agent: str, symbol: str, time_interval: str) -> Optional[Dict[str, Any]]: """ 获取最新的分析结果 Args: agent: 智能体类型,例如 'crypto' 或 'gold' symbol: 交易对符号,例如 'BTCUSDT' time_interval: 时间间隔,例如 '1h', '4h', '1d' Returns: 最新分析结果,如果查询失败则返回None """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return None try: # 创建会话 session = self.Session() try: # 查询最新的结果 result = session.query(AnalysisResult).filter( AnalysisResult.agent == agent, AnalysisResult.symbol == symbol, AnalysisResult.time_interval == time_interval ).order_by(AnalysisResult.created_at.desc()).first() if result: # 转换为字典 return { 'id': result.id, 'agent': result.agent, 'symbol': result.symbol, 'time_interval': result.time_interval, 'completion_result': result.completion_result, 'created_at': result.created_at } else: return None finally: session.close() except Exception as e: logger.error(f"获取最新分析结果失败: {e}") return None def close(self) -> None: """关闭数据库连接(如果存在)""" if self.engine: self.engine.dispose() logger.info("数据库连接已关闭") self.engine = None self.Session = None def create_agent(self, name: str, description: Optional[str] = None, hello_prompt: Optional[str] = None, dify_token: Optional[str] = None, inputs: Optional[Dict[str, Any]] = None) -> Optional[int]: """ 创建新的Agent Args: name: Agent名称 description: Agent描述 hello_prompt: 欢迎提示语 dify_token: Dify API令牌 inputs: 输入参数JSON Returns: 新Agent的ID,如果创建失败则返回None """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return None try: # 创建会话 session = self.Session() try: # 检查名称是否已存在 existing_agent = session.query(Agent).filter(Agent.name == name).first() if existing_agent: logger.warning(f"Agent名称 {name} 已存在") return None # 创建新Agent new_agent = Agent( name=name, description=description, hello_prompt=hello_prompt, dify_token=dify_token, inputs=inputs, create_time=datetime.now(), update_time=datetime.now() ) # 添加并提交 session.add(new_agent) session.commit() logger.info(f"成功创建Agent: {name}") return new_agent.id except Exception as e: session.rollback() logger.error(f"创建Agent失败: {e}") return None finally: session.close() except Exception as e: logger.error(f"创建数据库会话失败: {e}") return None def update_agent(self, agent_id: int, name: Optional[str] = None, description: Optional[str] = None, hello_prompt: Optional[str] = None, dify_token: Optional[str] = None, inputs: Optional[Dict[str, Any]] = None) -> bool: """ 更新Agent信息 Args: agent_id: Agent ID name: Agent名称 description: Agent描述 hello_prompt: 欢迎提示语 dify_token: Dify API令牌 inputs: 输入参数JSON Returns: 更新是否成功 """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return False try: # 创建会话 session = self.Session() try: # 查询Agent agent = session.query(Agent).filter(Agent.id == agent_id).first() if not agent: logger.warning(f"Agent ID {agent_id} 不存在") return False # 更新字段 if name is not None: # 检查名称是否与其他Agent冲突 if name != agent.name: existing = session.query(Agent).filter(Agent.name == name).first() if existing: logger.warning(f"Agent名称 {name} 已被其他Agent使用") return False agent.name = name if description is not None: agent.description = description if hello_prompt is not None: agent.hello_prompt = hello_prompt if dify_token is not None: agent.dify_token = dify_token if inputs is not None: agent.inputs = inputs agent.update_time = datetime.now() # 提交更改 session.commit() logger.info(f"成功更新Agent ID: {agent_id}") return True except Exception as e: session.rollback() logger.error(f"更新Agent失败: {e}") return False finally: session.close() except Exception as e: logger.error(f"创建数据库会话失败: {e}") return False def get_agent_by_id(self, agent_id: int) -> Optional[Dict[str, Any]]: """ 通过ID获取Agent信息 Args: agent_id: Agent ID Returns: Agent信息,如果不存在则返回None """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return None try: # 创建会话 session = self.Session() try: # 查询Agent agent = session.query(Agent).filter(Agent.id == agent_id).first() if agent: # 转换为字典 return { 'id': agent.id, 'name': agent.name, 'description': agent.description, 'hello_prompt': agent.hello_prompt, 'dify_token': agent.dify_token, 'inputs': agent.inputs, 'create_time': agent.create_time, 'update_time': agent.update_time } else: return None finally: session.close() except Exception as e: logger.error(f"获取Agent信息失败: {e}") return None def get_agent_by_name(self, name: str) -> Optional[Dict[str, Any]]: """ 通过名称获取Agent信息 Args: name: Agent名称 Returns: Agent信息,如果不存在则返回None """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return None try: # 创建会话 session = self.Session() try: # 查询Agent agent = session.query(Agent).filter(Agent.name == name).first() if agent: # 转换为字典 return { 'id': agent.id, 'name': agent.name, 'description': agent.description, 'hello_prompt': agent.hello_prompt, 'dify_token': agent.dify_token, 'inputs': agent.inputs, 'create_time': agent.create_time, 'update_time': agent.update_time } else: return None finally: session.close() except Exception as e: logger.error(f"获取Agent信息失败: {e}") return None def list_agents(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]: """ 获取Agent列表 Args: limit: 返回的最大数量 skip: 跳过的数量 Returns: Agent列表 """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return [] try: # 创建会话 session = self.Session() try: # 查询Agent列表 agents = session.query(Agent).order_by(Agent.create_time.asc()).offset(skip).limit(limit).all() # 转换为字典列表 return [{ 'id': agent.id, 'name': agent.name, 'description': agent.description, 'hello_prompt': agent.hello_prompt, 'dify_token': agent.dify_token, 'inputs': agent.inputs, 'create_time': agent.create_time, 'update_time': agent.update_time } for agent in agents] finally: session.close() except Exception as e: logger.error(f"获取Agent列表失败: {e}") return [] def delete_agent(self, agent_id: int) -> bool: """ 删除Agent Args: agent_id: Agent ID Returns: 删除是否成功 """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return False try: # 创建会话 session = self.Session() try: # 查询Agent agent = session.query(Agent).filter(Agent.id == agent_id).first() if not agent: logger.warning(f"Agent ID {agent_id} 不存在") return False # 删除Agent session.delete(agent) session.commit() logger.info(f"成功删除Agent ID: {agent_id}") return True except Exception as e: session.rollback() logger.error(f"删除Agent失败: {e}") return False finally: session.close() except Exception as e: logger.error(f"创建数据库会话失败: {e}") return False def create_stocks(self, stocks: List[Dict[str, Any]]) -> bool: """ 批量创建股票信息 Args: stocks: 股票信息列表,每个元素为包含以下键的字典: - stock_code: 股票代码 - short_name: 股票简称 - exchange: 交易所(可选) - list_date: 上市日期(可选,Unix时间戳,毫秒级) Returns: 创建是否成功 """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return False try: session = self.Session() try: for stock_data in stocks: # 检查股票代码是否已存在 existing_stock = session.query(AStock).filter( AStock.stock_code == stock_data['stock_code'] ).first() if existing_stock: logger.warning(f"股票代码 {stock_data['stock_code']} 已存在,跳过") continue # 创建新股票记录 new_stock = AStock( stock_code=stock_data['stock_code'], short_name=stock_data['short_name'], exchange=stock_data.get('exchange') ) # 处理上市日期 if 'list_date' in stock_data and stock_data['list_date']: list_date = _convert_timestamp_to_datetime(stock_data['list_date']) if list_date: new_stock.list_date = list_date session.add(new_stock) # 批量提交 session.commit() logger.info(f"成功批量创建股票信息,共 {len(stocks)} 条记录") return True except Exception as e: session.rollback() logger.error(f"批量创建股票信息失败: {e}") return False finally: session.close() except Exception as e: logger.error(f"创建数据库会话失败: {e}") return False def create_stock(self, stock_code: str, short_name: str, exchange: Optional[str] = None, list_date: Optional[Union[int, float, str]] = None) -> bool: """ 创建新的股票信息 Args: stock_code: 股票代码 short_name: 股票简称 exchange: 交易所(可选) list_date: 上市日期(可选,Unix时间戳,毫秒级) Returns: 创建是否成功 """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return False try: session = self.Session() try: # 检查股票代码是否已存在 existing_stock = session.query(AStock).filter(AStock.stock_code == stock_code).first() if existing_stock: logger.warning(f"股票代码 {stock_code} 已存在") return False # 创建新股票记录 new_stock = AStock( stock_code=stock_code, short_name=short_name, exchange=exchange ) # 处理上市日期 if list_date: converted_date = _convert_timestamp_to_datetime(list_date) if converted_date: new_stock.list_date = converted_date session.add(new_stock) session.commit() logger.info(f"成功创建股票信息: {stock_code} - {short_name}") return True except Exception as e: session.rollback() logger.error(f"创建股票信息失败: {e}") return False finally: session.close() except Exception as e: logger.error(f"创建数据库会话失败: {e}") return False def update_stock(self, stock_code: str, short_name: Optional[str] = None, exchange: Optional[str] = None, list_date: Optional[datetime] = None) -> bool: """ 更新股票信息 Args: stock_code: 股票代码 short_name: 股票简称 exchange: 交易所 list_date: 上市日期 Returns: 更新是否成功 """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return False try: session = self.Session() try: # 查询股票 stock = session.query(AStock).filter(AStock.stock_code == stock_code).first() if not stock: logger.warning(f"股票代码 {stock_code} 不存在") return False # 更新字段 if short_name is not None: stock.short_name = short_name if exchange is not None: stock.exchange = exchange if list_date is not None: stock.list_date = list_date stock.updated_at = datetime.now() session.commit() logger.info(f"成功更新股票信息: {stock_code}") return True except Exception as e: session.rollback() logger.error(f"更新股票信息失败: {e}") return False finally: session.close() except Exception as e: logger.error(f"创建数据库会话失败: {e}") return False def delete_stock(self, stock_code: str) -> bool: """ 删除股票信息 Args: stock_code: 股票代码 Returns: 删除是否成功 """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return False try: session = self.Session() try: # 查询股票 stock = session.query(AStock).filter(AStock.stock_code == stock_code).first() if not stock: logger.warning(f"股票代码 {stock_code} 不存在") return False # 删除股票 session.delete(stock) session.commit() logger.info(f"成功删除股票信息: {stock_code}") return True except Exception as e: session.rollback() logger.error(f"删除股票信息失败: {e}") return False finally: session.close() except Exception as e: logger.error(f"创建数据库会话失败: {e}") return False def get_stock_by_code(self, stock_code: str) -> Optional[Dict[str, Any]]: """ 通过股票代码获取股票信息 Args: stock_code: 股票代码 Returns: 股票信息,如果不存在则返回None """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return None try: session = self.Session() try: # 查询股票 stock = session.query(AStock).filter(AStock.stock_code == stock_code).first() if stock: return { 'stock_code': stock.stock_code, 'short_name': stock.short_name, 'exchange': stock.exchange, 'list_date': stock.list_date, 'created_at': stock.created_at, 'updated_at': stock.updated_at } else: return None finally: session.close() except Exception as e: logger.error(f"获取股票信息失败: {e}") return None def search_stock(self, key: str, limit: int = 10) -> List[Dict[str, Any]]: """ 搜索股票 """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return [] try: session = self.Session() try: # 查询股票 stocks = session.query(AStock).filter(AStock.short_name.like(f"{key}%") | AStock.stock_code.like(f"{key}%")).limit(limit).all() return [{ 'stock_code': stock.stock_code, 'short_name': stock.short_name, 'exchange': stock.exchange, 'list_date': stock.list_date, 'created_at': stock.created_at } for stock in stocks] finally: session.close() except Exception as e: logger.error(f"获取股票信息失败: {e}") return [] def get_stock_by_name(self, short_name: str) -> List[Dict[str, Any]]: """ 通过股票简称获取股票信息(可能有多个结果) Args: short_name: 股票简称 Returns: 股票信息列表 """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return [] try: session = self.Session() try: # 查询股票 stocks = session.query(AStock).filter(AStock.short_name == short_name).all() return [{ 'stock_code': stock.stock_code, 'short_name': stock.short_name, 'exchange': stock.exchange, 'list_date': stock.list_date, 'created_at': stock.created_at } for stock in stocks] finally: session.close() except Exception as e: logger.error(f"获取股票信息失败: {e}") return [] def list_stocks(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]: """ 获取股票列表 Args: limit: 返回的最大数量 skip: 跳过的数量 Returns: 股票信息列表 """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return [] try: session = self.Session() try: # 查询股票列表 stocks = session.query(AStock).order_by(AStock.stock_code).offset(skip).limit(limit).all() return [{ 'stock_code': stock.stock_code, 'short_name': stock.short_name, 'exchange': stock.exchange, 'list_date': stock.list_date, 'created_at': stock.created_at } for stock in stocks] finally: session.close() except Exception as e: logger.error(f"获取股票列表失败: {e}") return [] def create_token(self, symbol: str, base_asset: str, quote_asset: str) -> bool: """ 创建新的Token信息 Args: symbol: 交易对符号 base_asset: 基础资产 quote_asset: 计价资产 Returns: 创建是否成功 """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return False try: session = self.Session() try: # 检查交易对是否已存在 existing_token = session.query(Token).filter(Token.symbol == symbol).first() if existing_token: logger.warning(f"交易对 {symbol} 已存在") return False # 创建新Token记录 new_token = Token( symbol=symbol, base_asset=base_asset, quote_asset=quote_asset ) session.add(new_token) session.commit() logger.info(f"成功创建Token信息: {symbol}") return True except Exception as e: session.rollback() logger.error(f"创建Token信息失败: {e}") return False finally: session.close() except Exception as e: logger.error(f"创建数据库会话失败: {e}") return False def delete_token(self, symbol: str) -> bool: """ 删除Token信息 Args: symbol: 交易对符号 Returns: 删除是否成功 """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return False try: session = self.Session() try: # 查询Token token = session.query(Token).filter(Token.symbol == symbol).first() if not token: logger.warning(f"交易对 {symbol} 不存在") return False # 删除Token session.delete(token) session.commit() logger.info(f"成功删除Token信息: {symbol}") return True except Exception as e: session.rollback() logger.error(f"删除Token信息失败: {e}") return False finally: session.close() except Exception as e: logger.error(f"创建数据库会话失败: {e}") return False def search_token(self, key: str, limit: int = 10) -> List[Dict[str, Any]]: """ 搜索Token """ if not self.engine: try: self._init_db() except Exception as e: logger.error(f"重新连接数据库失败: {e}") return [] try: session = self.Session() # 使用 SQLAlchemy 的 ORM 查询 tokens = session.query(Token).filter(Token.symbol.like(f"{key}%")).limit(limit).all() return [{ 'symbol': token.symbol, 'base_asset': token.base_asset, 'quote_asset': token.quote_asset } for token in tokens] except Exception as e: logger.error(f"获取Token信息失败: {e}") return [] finally: session.close() # 单例模式 _db_instance = None def get_db_manager(host: Optional[str] = None, port: Optional[int] = None, user: Optional[str] = None, password: Optional[str] = None, db_name: Optional[str] = None) -> DBManager: """ 获取数据库管理器实例(单例模式) Args: host: 数据库主机地址 port: 数据库端口 user: 用户名 password: 密码 db_name: 数据库名 Returns: 数据库管理器实例 """ global _db_instance # 如果已经初始化过,直接返回 if _db_instance is not None: return _db_instance # 如果未指定参数,从配置加载器获取数据库配置 if host is None or port is None or user is None or password is None or db_name is None: config_loader = ConfigLoader() db_config = config_loader.get_database_config() # 使用配置中的值或默认值 db_host = host or db_config.get('host') db_port = port or db_config.get('port') db_user = user or db_config.get('user') db_password = password or db_config.get('password') db_name = db_name or db_config.get('db_name') logger.info(f"从配置加载数据库连接信息: {db_host}:{db_port}/{db_name}") else: db_host = host db_port = port db_user = user db_password = password db_name = db_name # 创建实例 _db_instance = DBManager( host=db_host, port=db_port, user=db_user, password=db_password, db_name=db_name ) return _db_instance