crypto.ai/cryptoai/utils/db_manager.py
2025-05-23 14:45:20 +08:00

1919 lines
63 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import json
import logging
from typing import Dict, Any, List, Optional, Union
from datetime import datetime
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, Index, text, ForeignKey
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, relationship
from sqlalchemy.dialects.mysql import JSON
from sqlalchemy.pool import QueuePool
from sqlalchemy import or_
from cryptoai.utils.config_loader import ConfigLoader
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('db_manager')
# 创建模型基类
Base = declarative_base()
# 定义Token模型
class Token(Base):
"""Token信息表模型"""
__tablename__ = 'tokens'
id = Column(Integer, primary_key=True, autoincrement=True)
symbol = Column(String(50), nullable=False, unique=True, comment='交易对符号')
base_asset = Column(String(20), nullable=False, comment='基础资产')
quote_asset = Column(String(20), nullable=False, comment='计价资产')
created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
# 索引
__table_args__ = (
Index('idx_symbol', 'symbol'),
Index('idx_base_asset', 'base_asset'),
Index('idx_quote_asset', 'quote_asset'),
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
)
# 定义分析结果模型
class AnalysisResult(Base):
"""分析结果表模型"""
__tablename__ = 'analysis_results'
id = Column(Integer, primary_key=True, autoincrement=True)
agent = Column(String(50), nullable=False, comment='智能体类型(crypto, gold)')
symbol = Column(String(50), nullable=False, comment='交易对符号')
time_interval = Column(String(20), nullable=False, comment='时间间隔')
completion_result = Column(JSON, nullable=False, comment='分析结果JSON')
created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
updated_at = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间')
# 索引
__table_args__ = (
Index('idx_agent', 'agent'),
Index('idx_symbol', 'symbol'),
Index('idx_time_interval', 'time_interval'),
Index('idx_created_at', 'created_at'),
)
# 定义AI Agent信息流模型
class AgentFeed(Base):
"""AI Agent信息流表模型"""
__tablename__ = 'agent_feeds'
id = Column(Integer, primary_key=True, autoincrement=True)
agent_name = Column(String(50), nullable=False, comment='AI Agent名称')
avatar_url = Column(String(255), nullable=True, comment='头像URL')
content = Column(Text, nullable=False, comment='内容')
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
# 索引和表属性
__table_args__ = (
Index('idx_agent_name', 'agent_name'),
Index('idx_create_time', 'create_time'),
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
)
# 定义用户数据模型
class User(Base):
"""用户数据表模型"""
__tablename__ = 'users'
id = Column(Integer, primary_key=True, autoincrement=True)
mail = Column(String(100), nullable=False, unique=True, comment='邮箱')
nickname = Column(String(50), nullable=False, comment='昵称')
password = Column(String(100), nullable=False, comment='密码')
level = Column(Integer, nullable=False, default=0, comment='用户级别(0=普通用户,1=VIP,2=SVIP)')
points = Column(Integer, nullable=False, default=0, comment='用户积分')
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
# 关系
questions = relationship("UserQuestion", back_populates="user")
# 索引和表属性
__table_args__ = (
Index('idx_mail', 'mail'),
Index('idx_level', 'level'),
Index('idx_create_time', 'create_time'),
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
)
# 定义用户提问数据模型
class UserQuestion(Base):
"""用户提问数据表模型"""
__tablename__ = 'user_questions'
id = Column(Integer, primary_key=True, autoincrement=True)
user_id = Column(Integer, ForeignKey('users.id'), nullable=False, comment='用户ID')
agent_id = Column(String(50), nullable=False, comment='AI Agent ID')
question = Column(Text, nullable=False, comment='提问内容')
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
# 关系
user = relationship("User", back_populates="questions")
# 索引和表属性
__table_args__ = (
Index('idx_user_id', 'user_id'),
Index('idx_agent_id', 'agent_id'),
Index('idx_create_time', 'create_time'),
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
)
# 定义Agent数据模型
class Agent(Base):
"""Agent数据表模型"""
__tablename__ = 'agents'
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String(100), nullable=False, unique=True, comment='Agent名称')
hello_prompt = Column(Text, nullable=True, comment='欢迎提示语')
description = Column(Text, nullable=True, comment='Agent描述')
dify_token = Column(String(255), nullable=True, comment='Dify API令牌')
inputs = Column(JSON, nullable=True, comment='输入参数JSON')
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
update_time = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间')
# 索引和表属性
__table_args__ = (
Index('idx_name', 'name'),
Index('idx_create_time', 'create_time'),
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
)
# 定义 A 股数据模型
class AStock(Base):
"""A股股票基本信息表模型"""
__tablename__ = 'astock'
stock_code = Column(String(10), primary_key=True, comment='股票代码')
short_name = Column(String(50), nullable=False, comment='股票简称')
exchange = Column(String(20), nullable=True, comment='交易所')
list_date = Column(DateTime, nullable=True, comment='上市日期')
created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
# 索引
__table_args__ = (
Index('idx_stock_code', 'stock_code', unique=True),
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
)
def _convert_timestamp_to_datetime(timestamp: Union[int, float, str, None]) -> Optional[datetime]:
"""
将时间戳转换为datetime对象
Args:
timestamp: Unix时间戳毫秒级
Returns:
datetime对象或None
"""
if timestamp is None:
return None
try:
# 转换为整数
if isinstance(timestamp, str):
timestamp = int(timestamp)
# 如果是毫秒级时间戳,转换为秒级
if timestamp > 1e11: # 判断是否为毫秒级时间戳
timestamp = timestamp / 1000
return datetime.fromtimestamp(timestamp)
except (ValueError, TypeError) as e:
logger.error(f"时间戳转换失败: {timestamp}, 错误: {e}")
return None
class DBManager:
"""数据库管理工具用于连接MySQL数据库并保存智能体分析结果"""
def __init__(self, host: str, port: int, user: str, password: str, db_name: str):
"""
初始化数据库管理器
Args:
host: 数据库主机地址
port: 数据库端口
user: 用户名
password: 密码
db_name: 数据库名
"""
self.host = host
self.port = port
self.user = user
self.password = password
self.db_name = db_name
self.engine = None
self.Session = None
# 初始化数据库连接
self._init_db()
def _init_db(self) -> None:
"""初始化数据库连接和表"""
try:
# 创建数据库连接
connection_string = f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}?charset=utf8mb4"
# 创建引擎,设置连接池
self.engine = create_engine(
connection_string,
echo=False, # 设置为True可以输出SQL语句调试用
pool_size=5, # 连接池大小
max_overflow=10, # 最大溢出连接数
pool_timeout=30, # 连接超时时间
pool_recycle=1800, # 连接回收时间(秒)
pool_pre_ping=True, # 在使用连接前先ping一下确保连接有效
connect_args={'charset': 'utf8mb4'}
)
# 创建会话工厂
self.Session = sessionmaker(bind=self.engine)
# 创建表(如果不存在)
Base.metadata.create_all(self.engine)
logger.info(f"成功连接到数据库 {self.db_name}")
except Exception as e:
logger.error(f"数据库初始化失败: {e}")
self.engine = None
def save_analysis_result(self, agent: str, symbol: str, time_interval: str,
analysis_result: Dict[str, Any]) -> bool:
"""
保存分析结果到数据库
Args:
agent: 智能体类型,例如 'crypto''gold'
symbol: 交易对符号,例如 'BTCUSDT'
time_interval: 时间间隔,例如 '1h', '4h', '1d'
analysis_result: 分析结果数据
Returns:
保存是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
# 创建会话
session = self.Session()
try:
# 创建新记录
new_result = AnalysisResult(
agent=agent,
symbol=symbol,
time_interval=time_interval,
completion_result=analysis_result,
created_at=datetime.now(),
updated_at=datetime.now()
)
# 添加并提交
session.add(new_result)
session.commit()
logger.info(f"成功保存 {agent} 分析结果,交易对: {symbol}, 时间间隔: {time_interval}")
return True
except Exception as e:
session.rollback()
logger.error(f"保存分析结果失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
# 如果是连接错误,尝试重新初始化
try:
self._init_db()
except:
pass
return False
def save_agent_feed(self, agent_name: str, content: str, avatar_url: Optional[str] = None) -> bool:
"""
保存AI Agent信息流到数据库
Args:
agent_name: AI Agent名称
content: 内容
avatar_url: 头像URL可选
Returns:
保存是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
# 创建会话
session = self.Session()
try:
# 创建新记录
new_feed = AgentFeed(
agent_name=agent_name,
content=content,
avatar_url=avatar_url
)
# 添加并提交
session.add(new_feed)
session.commit()
logger.info(f"成功保存 {agent_name} 的信息流")
return True
except Exception as e:
session.rollback()
logger.error(f"保存信息流失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
# 如果是连接错误,尝试重新初始化
try:
self._init_db()
except:
pass
return False
def register_user(self, mail: str, nickname: str, password: str, level: int = 0, points: int = 0) -> bool:
"""
注册新用户
Args:
mail: 邮箱
nickname: 昵称
password: 密码
level: 用户级别默认为0普通用户
points: 初始积分默认为0
Returns:
注册是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
# 创建会话
session = self.Session()
try:
# 检查邮箱是否已存在
existing_user = session.query(User).filter(User.mail == mail).first()
if existing_user:
logger.warning(f"邮箱 {mail} 已被注册")
return False
# 创建新用户
new_user = User(
mail=mail,
nickname=nickname,
password=password, # 实际应用中应该对密码进行哈希处理
level=level,
points=points,
create_time=datetime.now()
)
# 添加并提交
session.add(new_user)
session.commit()
logger.info(f"成功注册用户: {mail},初始积分: {points}")
return True
except Exception as e:
session.rollback()
logger.error(f"注册用户失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
# 如果是连接错误,尝试重新初始化
try:
self._init_db()
except:
pass
return False
def get_user_count(self) -> int:
"""
获取用户数量
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return 0
try:
# 创建会话
session = self.Session()
try:
# 查询用户数量
user_count = session.query(User).count()
return user_count
except Exception as e:
logger.error(f"获取用户数量失败: {e}")
return 0
finally:
session.close()
except Exception as e:
logger.error(f"获取用户数量失败: {e}")
return 0
def get_user_by_id(self, user_id: int) -> Optional[Dict[str, Any]]:
"""
通过ID获取用户信息
Args:
user_id: 用户ID
Returns:
用户信息如果用户不存在则返回None
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return None
try:
# 创建会话
session = self.Session()
try:
# 查询用户
user = session.query(User).filter(User.id == user_id).first()
if user:
# 转换为字典
return {
'id': user.id,
'mail': user.mail,
'nickname': user.nickname,
'level': user.level,
'points': user.points,
'create_time': user.create_time
}
else:
return None
finally:
session.close()
except Exception as e:
logger.error(f"获取用户信息失败: {e}")
return None
def get_user_by_mail(self, mail: str) -> Optional[Dict[str, Any]]:
"""
通过邮箱获取用户信息
Args:
mail: 邮箱
Returns:
用户信息如果用户不存在则返回None
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return None
try:
# 创建会话
session = self.Session()
try:
# 查询用户
user = session.query(User).filter(User.mail == mail).first()
if user:
# 转换为字典
return {
'id': user.id,
'mail': user.mail,
'nickname': user.nickname,
'level': user.level,
'points': user.points,
'create_time': user.create_time
}
else:
return None
finally:
session.close()
except Exception as e:
logger.error(f"获取用户信息失败: {e}")
return None
def update_user_level(self, user_id: int, level: int) -> bool:
"""
更新用户级别
Args:
user_id: 用户ID
level: 新的用户级别
Returns:
更新是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
# 创建会话
session = self.Session()
try:
# 查询用户
user = session.query(User).filter(User.id == user_id).first()
if not user:
logger.warning(f"用户ID {user_id} 不存在")
return False
# 更新级别
user.level = level
session.commit()
logger.info(f"成功更新用户 {user.mail} 的级别为 {level}")
return True
except Exception as e:
session.rollback()
logger.error(f"更新用户级别失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return False
def add_user_points(self, user_id: int, points: int) -> bool:
"""
为用户增加积分
Args:
user_id: 用户ID
points: 增加的积分数量(正数)
Returns:
操作是否成功
"""
if points <= 0:
logger.warning(f"增加的积分必须是正数: {points}")
return False
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
# 创建会话
session = self.Session()
try:
# 查询用户
user = session.query(User).filter(User.id == user_id).first()
if not user:
logger.warning(f"用户ID {user_id} 不存在")
return False
# 增加积分
user.points += points
session.commit()
logger.info(f"成功为用户 {user.mail} 增加 {points} 积分,当前积分: {user.points}")
return True
except Exception as e:
session.rollback()
logger.error(f"增加用户积分失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return False
def consume_user_points(self, user_id: int, points: int) -> bool:
"""
用户消费积分
Args:
user_id: 用户ID
points: 消费的积分数量(正数)
Returns:
操作是否成功
"""
if points <= 0:
logger.warning(f"消费的积分必须是正数: {points}")
return False
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
# 创建会话
session = self.Session()
try:
# 查询用户
user = session.query(User).filter(User.id == user_id).first()
if not user:
logger.warning(f"用户ID {user_id} 不存在")
return False
# 检查积分是否足够
if user.points < points:
logger.warning(f"用户 {user.mail} 积分不足,当前积分: {user.points},需要消费: {points}")
return False
# 消费积分
user.points -= points
session.commit()
logger.info(f"成功从用户 {user.mail} 消费 {points} 积分,剩余积分: {user.points}")
return True
except Exception as e:
session.rollback()
logger.error(f"消费用户积分失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return False
def get_agent_feeds(self, agent_name: Optional[str] = None, limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]:
"""
获取AI Agent信息流
Args:
agent_name: 可选指定获取特定Agent的信息流
limit: 返回的最大记录数默认20条
Returns:
信息流列表,如果查询失败则返回空列表
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return []
try:
# 创建会话
session = self.Session()
try:
# 构建查询
query = session.query(AgentFeed)
# 如果指定了agent_name则筛选
if agent_name:
query = query.filter(AgentFeed.agent_name == agent_name)
# 按创建时间降序排序并限制数量
results = query.order_by(AgentFeed.create_time.desc()).offset(skip).limit(limit).all()
# 转换为字典列表
feeds = []
for result in results:
feeds.append({
'id': result.id,
'agent_name': result.agent_name,
'avatar_url': result.avatar_url,
'content': result.content,
'create_time': result.create_time
})
return feeds
finally:
session.close()
except Exception as e:
logger.error(f"获取信息流失败: {e}")
return []
def save_user_question(self, user_id: int, agent_id: str, question: str) -> bool:
"""
保存用户提问数据
Args:
user_id: 用户ID
agent_id: AI Agent ID
question: 提问内容
Returns:
保存是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
# 创建会话
session = self.Session()
try:
# 创建新记录
new_question = UserQuestion(
user_id=user_id,
agent_id=agent_id,
question=question,
create_time=datetime.now()
)
# 添加并提交
session.add(new_question)
session.commit()
logger.info(f"成功保存用户 {user_id} 对 Agent {agent_id} 的提问")
return True
except Exception as e:
session.rollback()
logger.error(f"保存用户提问失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
# 如果是连接错误,尝试重新初始化
try:
self._init_db()
except:
pass
return False
def get_user_question_count(self) -> int:
"""
获取用户提问数量
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return 0
try:
# 创建会话
session = self.Session()
try:
# 查询用户提问数量
question_count = session.query(UserQuestion).count()
return question_count
except Exception as e:
logger.error(f"获取用户提问数量失败: {e}")
return 0
finally:
session.close()
except Exception as e:
logger.error(f"获取用户提问数量失败: {e}")
return 0
def get_user_questions(self, user_id: Optional[int] = None, agent_id: Optional[str] = None,
limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]:
"""
获取用户提问数据
Args:
user_id: 可选,指定获取特定用户的提问
agent_id: 可选指定获取特定Agent的提问
limit: 返回的最大记录数默认20条
skip: 跳过的记录数默认0条
Returns:
提问数据列表,如果查询失败则返回空列表
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return []
try:
# 创建会话
session = self.Session()
try:
# 构建查询
query = session.query(UserQuestion)
# 如果指定了user_id则筛选
if user_id:
query = query.filter(UserQuestion.user_id == user_id)
# 如果指定了agent_id则筛选
if agent_id:
query = query.filter(UserQuestion.agent_id == agent_id)
# 按创建时间降序排序并限制数量
results = query.order_by(UserQuestion.create_time.desc()).offset(skip).limit(limit).all()
# 转换为字典列表
questions = []
for result in results:
questions.append({
'id': result.id,
'user_id': result.user_id,
'agent_id': result.agent_id,
'question': result.question,
'create_time': result.create_time
})
return questions
finally:
session.close()
except Exception as e:
logger.error(f"获取用户提问失败: {e}")
return []
def get_user_question_by_id(self, question_id: int) -> Optional[Dict[str, Any]]:
"""
通过ID获取用户提问数据
Args:
question_id: 提问ID
Returns:
提问数据如果不存在则返回None
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return None
try:
# 创建会话
session = self.Session()
try:
# 查询提问
result = session.query(UserQuestion).filter(UserQuestion.id == question_id).first()
if result:
# 转换为字典
return {
'id': result.id,
'user_id': result.user_id,
'agent_id': result.agent_id,
'question': result.question,
'create_time': result.create_time
}
else:
return None
finally:
session.close()
except Exception as e:
logger.error(f"获取用户提问失败: {e}")
return None
def get_latest_result(self, agent: str, symbol: str, time_interval: str) -> Optional[Dict[str, Any]]:
"""
获取最新的分析结果
Args:
agent: 智能体类型,例如 'crypto''gold'
symbol: 交易对符号,例如 'BTCUSDT'
time_interval: 时间间隔,例如 '1h', '4h', '1d'
Returns:
最新分析结果如果查询失败则返回None
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return None
try:
# 创建会话
session = self.Session()
try:
# 查询最新的结果
result = session.query(AnalysisResult).filter(
AnalysisResult.agent == agent,
AnalysisResult.symbol == symbol,
AnalysisResult.time_interval == time_interval
).order_by(AnalysisResult.created_at.desc()).first()
if result:
# 转换为字典
return {
'id': result.id,
'agent': result.agent,
'symbol': result.symbol,
'time_interval': result.time_interval,
'completion_result': result.completion_result,
'created_at': result.created_at
}
else:
return None
finally:
session.close()
except Exception as e:
logger.error(f"获取最新分析结果失败: {e}")
return None
def close(self) -> None:
"""关闭数据库连接(如果存在)"""
if self.engine:
self.engine.dispose()
logger.info("数据库连接已关闭")
self.engine = None
self.Session = None
def create_agent(self, name: str, description: Optional[str] = None,
hello_prompt: Optional[str] = None, dify_token: Optional[str] = None,
inputs: Optional[Dict[str, Any]] = None) -> Optional[int]:
"""
创建新的Agent
Args:
name: Agent名称
description: Agent描述
hello_prompt: 欢迎提示语
dify_token: Dify API令牌
inputs: 输入参数JSON
Returns:
新Agent的ID如果创建失败则返回None
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return None
try:
# 创建会话
session = self.Session()
try:
# 检查名称是否已存在
existing_agent = session.query(Agent).filter(Agent.name == name).first()
if existing_agent:
logger.warning(f"Agent名称 {name} 已存在")
return None
# 创建新Agent
new_agent = Agent(
name=name,
description=description,
hello_prompt=hello_prompt,
dify_token=dify_token,
inputs=inputs,
create_time=datetime.now(),
update_time=datetime.now()
)
# 添加并提交
session.add(new_agent)
session.commit()
logger.info(f"成功创建Agent: {name}")
return new_agent.id
except Exception as e:
session.rollback()
logger.error(f"创建Agent失败: {e}")
return None
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return None
def update_agent(self, agent_id: int, name: Optional[str] = None,
description: Optional[str] = None, hello_prompt: Optional[str] = None,
dify_token: Optional[str] = None, inputs: Optional[Dict[str, Any]] = None) -> bool:
"""
更新Agent信息
Args:
agent_id: Agent ID
name: Agent名称
description: Agent描述
hello_prompt: 欢迎提示语
dify_token: Dify API令牌
inputs: 输入参数JSON
Returns:
更新是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
# 创建会话
session = self.Session()
try:
# 查询Agent
agent = session.query(Agent).filter(Agent.id == agent_id).first()
if not agent:
logger.warning(f"Agent ID {agent_id} 不存在")
return False
# 更新字段
if name is not None:
# 检查名称是否与其他Agent冲突
if name != agent.name:
existing = session.query(Agent).filter(Agent.name == name).first()
if existing:
logger.warning(f"Agent名称 {name} 已被其他Agent使用")
return False
agent.name = name
if description is not None:
agent.description = description
if hello_prompt is not None:
agent.hello_prompt = hello_prompt
if dify_token is not None:
agent.dify_token = dify_token
if inputs is not None:
agent.inputs = inputs
agent.update_time = datetime.now()
# 提交更改
session.commit()
logger.info(f"成功更新Agent ID: {agent_id}")
return True
except Exception as e:
session.rollback()
logger.error(f"更新Agent失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return False
def get_agent_by_id(self, agent_id: int) -> Optional[Dict[str, Any]]:
"""
通过ID获取Agent信息
Args:
agent_id: Agent ID
Returns:
Agent信息如果不存在则返回None
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return None
try:
# 创建会话
session = self.Session()
try:
# 查询Agent
agent = session.query(Agent).filter(Agent.id == agent_id).first()
if agent:
# 转换为字典
return {
'id': agent.id,
'name': agent.name,
'description': agent.description,
'hello_prompt': agent.hello_prompt,
'dify_token': agent.dify_token,
'inputs': agent.inputs,
'create_time': agent.create_time,
'update_time': agent.update_time
}
else:
return None
finally:
session.close()
except Exception as e:
logger.error(f"获取Agent信息失败: {e}")
return None
def get_agent_by_name(self, name: str) -> Optional[Dict[str, Any]]:
"""
通过名称获取Agent信息
Args:
name: Agent名称
Returns:
Agent信息如果不存在则返回None
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return None
try:
# 创建会话
session = self.Session()
try:
# 查询Agent
agent = session.query(Agent).filter(Agent.name == name).first()
if agent:
# 转换为字典
return {
'id': agent.id,
'name': agent.name,
'description': agent.description,
'hello_prompt': agent.hello_prompt,
'dify_token': agent.dify_token,
'inputs': agent.inputs,
'create_time': agent.create_time,
'update_time': agent.update_time
}
else:
return None
finally:
session.close()
except Exception as e:
logger.error(f"获取Agent信息失败: {e}")
return None
def list_agents(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]:
"""
获取Agent列表
Args:
limit: 返回的最大数量
skip: 跳过的数量
Returns:
Agent列表
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return []
try:
# 创建会话
session = self.Session()
try:
# 查询Agent列表
agents = session.query(Agent).order_by(Agent.create_time.asc()).offset(skip).limit(limit).all()
# 转换为字典列表
return [{
'id': agent.id,
'name': agent.name,
'description': agent.description,
'hello_prompt': agent.hello_prompt,
'dify_token': agent.dify_token,
'inputs': agent.inputs,
'create_time': agent.create_time,
'update_time': agent.update_time
} for agent in agents]
finally:
session.close()
except Exception as e:
logger.error(f"获取Agent列表失败: {e}")
return []
def delete_agent(self, agent_id: int) -> bool:
"""
删除Agent
Args:
agent_id: Agent ID
Returns:
删除是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
# 创建会话
session = self.Session()
try:
# 查询Agent
agent = session.query(Agent).filter(Agent.id == agent_id).first()
if not agent:
logger.warning(f"Agent ID {agent_id} 不存在")
return False
# 删除Agent
session.delete(agent)
session.commit()
logger.info(f"成功删除Agent ID: {agent_id}")
return True
except Exception as e:
session.rollback()
logger.error(f"删除Agent失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return False
def create_stocks(self, stocks: List[Dict[str, Any]]) -> bool:
"""
批量创建股票信息
Args:
stocks: 股票信息列表,每个元素为包含以下键的字典:
- stock_code: 股票代码
- short_name: 股票简称
- exchange: 交易所(可选)
- list_date: 上市日期可选Unix时间戳毫秒级
Returns:
创建是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
session = self.Session()
try:
for stock_data in stocks:
# 检查股票代码是否已存在
existing_stock = session.query(AStock).filter(
AStock.stock_code == stock_data['stock_code']
).first()
if existing_stock:
logger.warning(f"股票代码 {stock_data['stock_code']} 已存在,跳过")
continue
# 创建新股票记录
new_stock = AStock(
stock_code=stock_data['stock_code'],
short_name=stock_data['short_name'],
exchange=stock_data.get('exchange')
)
# 处理上市日期
if 'list_date' in stock_data and stock_data['list_date']:
list_date = _convert_timestamp_to_datetime(stock_data['list_date'])
if list_date:
new_stock.list_date = list_date
session.add(new_stock)
# 批量提交
session.commit()
logger.info(f"成功批量创建股票信息,共 {len(stocks)} 条记录")
return True
except Exception as e:
session.rollback()
logger.error(f"批量创建股票信息失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return False
def create_stock(self, stock_code: str, short_name: str,
exchange: Optional[str] = None,
list_date: Optional[Union[int, float, str]] = None) -> bool:
"""
创建新的股票信息
Args:
stock_code: 股票代码
short_name: 股票简称
exchange: 交易所(可选)
list_date: 上市日期可选Unix时间戳毫秒级
Returns:
创建是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
session = self.Session()
try:
# 检查股票代码是否已存在
existing_stock = session.query(AStock).filter(AStock.stock_code == stock_code).first()
if existing_stock:
logger.warning(f"股票代码 {stock_code} 已存在")
return False
# 创建新股票记录
new_stock = AStock(
stock_code=stock_code,
short_name=short_name,
exchange=exchange
)
# 处理上市日期
if list_date:
converted_date = _convert_timestamp_to_datetime(list_date)
if converted_date:
new_stock.list_date = converted_date
session.add(new_stock)
session.commit()
logger.info(f"成功创建股票信息: {stock_code} - {short_name}")
return True
except Exception as e:
session.rollback()
logger.error(f"创建股票信息失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return False
def update_stock(self, stock_code: str, short_name: Optional[str] = None,
exchange: Optional[str] = None, list_date: Optional[datetime] = None) -> bool:
"""
更新股票信息
Args:
stock_code: 股票代码
short_name: 股票简称
exchange: 交易所
list_date: 上市日期
Returns:
更新是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
session = self.Session()
try:
# 查询股票
stock = session.query(AStock).filter(AStock.stock_code == stock_code).first()
if not stock:
logger.warning(f"股票代码 {stock_code} 不存在")
return False
# 更新字段
if short_name is not None:
stock.short_name = short_name
if exchange is not None:
stock.exchange = exchange
if list_date is not None:
stock.list_date = list_date
stock.updated_at = datetime.now()
session.commit()
logger.info(f"成功更新股票信息: {stock_code}")
return True
except Exception as e:
session.rollback()
logger.error(f"更新股票信息失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return False
def delete_stock(self, stock_code: str) -> bool:
"""
删除股票信息
Args:
stock_code: 股票代码
Returns:
删除是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
session = self.Session()
try:
# 查询股票
stock = session.query(AStock).filter(AStock.stock_code == stock_code).first()
if not stock:
logger.warning(f"股票代码 {stock_code} 不存在")
return False
# 删除股票
session.delete(stock)
session.commit()
logger.info(f"成功删除股票信息: {stock_code}")
return True
except Exception as e:
session.rollback()
logger.error(f"删除股票信息失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return False
def get_stock_by_code(self, stock_code: str) -> Optional[Dict[str, Any]]:
"""
通过股票代码获取股票信息
Args:
stock_code: 股票代码
Returns:
股票信息如果不存在则返回None
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return None
try:
session = self.Session()
try:
# 查询股票
stock = session.query(AStock).filter(AStock.stock_code == stock_code).first()
if stock:
return {
'stock_code': stock.stock_code,
'short_name': stock.short_name,
'exchange': stock.exchange,
'list_date': stock.list_date,
'created_at': stock.created_at,
'updated_at': stock.updated_at
}
else:
return None
finally:
session.close()
except Exception as e:
logger.error(f"获取股票信息失败: {e}")
return None
def search_stock(self, key: str, limit: int = 10) -> List[Dict[str, Any]]:
"""
搜索股票
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return []
try:
session = self.Session()
try:
# 查询股票
stocks = session.query(AStock).filter(AStock.short_name.like(f"{key}%") | AStock.stock_code.like(f"{key}%")).limit(limit).all()
return [{
'stock_code': stock.stock_code,
'short_name': stock.short_name,
'exchange': stock.exchange,
'list_date': stock.list_date,
'created_at': stock.created_at
} for stock in stocks]
finally:
session.close()
except Exception as e:
logger.error(f"获取股票信息失败: {e}")
return []
def get_stock_by_name(self, short_name: str) -> List[Dict[str, Any]]:
"""
通过股票简称获取股票信息(可能有多个结果)
Args:
short_name: 股票简称
Returns:
股票信息列表
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return []
try:
session = self.Session()
try:
# 查询股票
stocks = session.query(AStock).filter(AStock.short_name == short_name).all()
return [{
'stock_code': stock.stock_code,
'short_name': stock.short_name,
'exchange': stock.exchange,
'list_date': stock.list_date,
'created_at': stock.created_at
} for stock in stocks]
finally:
session.close()
except Exception as e:
logger.error(f"获取股票信息失败: {e}")
return []
def list_stocks(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]:
"""
获取股票列表
Args:
limit: 返回的最大数量
skip: 跳过的数量
Returns:
股票信息列表
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return []
try:
session = self.Session()
try:
# 查询股票列表
stocks = session.query(AStock).order_by(AStock.stock_code).offset(skip).limit(limit).all()
return [{
'stock_code': stock.stock_code,
'short_name': stock.short_name,
'exchange': stock.exchange,
'list_date': stock.list_date,
'created_at': stock.created_at
} for stock in stocks]
finally:
session.close()
except Exception as e:
logger.error(f"获取股票列表失败: {e}")
return []
def create_token(self, symbol: str, base_asset: str, quote_asset: str) -> bool:
"""
创建新的Token信息
Args:
symbol: 交易对符号
base_asset: 基础资产
quote_asset: 计价资产
Returns:
创建是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
session = self.Session()
try:
# 检查交易对是否已存在
existing_token = session.query(Token).filter(Token.symbol == symbol).first()
if existing_token:
logger.warning(f"交易对 {symbol} 已存在")
return False
# 创建新Token记录
new_token = Token(
symbol=symbol,
base_asset=base_asset,
quote_asset=quote_asset
)
session.add(new_token)
session.commit()
logger.info(f"成功创建Token信息: {symbol}")
return True
except Exception as e:
session.rollback()
logger.error(f"创建Token信息失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return False
def delete_token(self, symbol: str) -> bool:
"""
删除Token信息
Args:
symbol: 交易对符号
Returns:
删除是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
session = self.Session()
try:
# 查询Token
token = session.query(Token).filter(Token.symbol == symbol).first()
if not token:
logger.warning(f"交易对 {symbol} 不存在")
return False
# 删除Token
session.delete(token)
session.commit()
logger.info(f"成功删除Token信息: {symbol}")
return True
except Exception as e:
session.rollback()
logger.error(f"删除Token信息失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return False
def search_token(self, key: str, limit: int = 10) -> List[Dict[str, Any]]:
"""
搜索Token
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return []
try:
session = self.Session()
# 使用 SQLAlchemy 的 ORM 查询
tokens = session.query(Token).filter(Token.symbol.like(f"{key}%") | Token.base_asset.like(f"{key}%")).limit(limit).all()
return [{
'symbol': token.symbol,
'base_asset': token.base_asset,
'quote_asset': token.quote_asset
} for token in tokens]
except Exception as e:
logger.error(f"获取Token信息失败: {e}")
return []
finally:
session.close()
# 单例模式
_db_instance = None
def get_db_manager(host: Optional[str] = None,
port: Optional[int] = None,
user: Optional[str] = None,
password: Optional[str] = None,
db_name: Optional[str] = None) -> DBManager:
"""
获取数据库管理器实例(单例模式)
Args:
host: 数据库主机地址
port: 数据库端口
user: 用户名
password: 密码
db_name: 数据库名
Returns:
数据库管理器实例
"""
global _db_instance
# 如果已经初始化过,直接返回
if _db_instance is not None:
return _db_instance
# 如果未指定参数,从配置加载器获取数据库配置
if host is None or port is None or user is None or password is None or db_name is None:
config_loader = ConfigLoader()
db_config = config_loader.get_database_config()
# 使用配置中的值或默认值
db_host = host or db_config.get('host')
db_port = port or db_config.get('port')
db_user = user or db_config.get('user')
db_password = password or db_config.get('password')
db_name = db_name or db_config.get('db_name')
logger.info(f"从配置加载数据库连接信息: {db_host}:{db_port}/{db_name}")
else:
db_host = host
db_port = port
db_user = user
db_password = password
db_name = db_name
# 创建实例
_db_instance = DBManager(
host=db_host,
port=db_port,
user=db_user,
password=db_password,
db_name=db_name
)
return _db_instance