1858 lines
61 KiB
Python
1858 lines
61 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
import os
|
||
import json
|
||
import logging
|
||
from typing import Dict, Any, List, Optional, Union
|
||
from datetime import datetime
|
||
|
||
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, Index, text, ForeignKey
|
||
from sqlalchemy.ext.declarative import declarative_base
|
||
from sqlalchemy.orm import sessionmaker, relationship
|
||
from sqlalchemy.dialects.mysql import JSON
|
||
from sqlalchemy.pool import QueuePool
|
||
from sqlalchemy import or_
|
||
|
||
from cryptoai.utils.config_loader import ConfigLoader
|
||
|
||
# 配置日志
|
||
logging.basicConfig(
|
||
level=logging.INFO,
|
||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||
)
|
||
logger = logging.getLogger('db_manager')
|
||
|
||
# 创建模型基类
|
||
Base = declarative_base()
|
||
|
||
# 定义Token模型
|
||
class Token(Base):
|
||
"""Token信息表模型"""
|
||
__tablename__ = 'tokens'
|
||
|
||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||
symbol = Column(String(50), nullable=False, unique=True, comment='交易对符号')
|
||
base_asset = Column(String(20), nullable=False, comment='基础资产')
|
||
quote_asset = Column(String(20), nullable=False, comment='计价资产')
|
||
created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||
|
||
# 索引
|
||
__table_args__ = (
|
||
Index('idx_symbol', 'symbol'),
|
||
Index('idx_base_asset', 'base_asset'),
|
||
Index('idx_quote_asset', 'quote_asset'),
|
||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||
)
|
||
|
||
# 定义分析结果模型
|
||
class AnalysisResult(Base):
|
||
"""分析结果表模型"""
|
||
__tablename__ = 'analysis_results'
|
||
|
||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||
agent = Column(String(50), nullable=False, comment='智能体类型(crypto, gold)')
|
||
symbol = Column(String(50), nullable=False, comment='交易对符号')
|
||
time_interval = Column(String(20), nullable=False, comment='时间间隔')
|
||
completion_result = Column(JSON, nullable=False, comment='分析结果JSON')
|
||
created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||
updated_at = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间')
|
||
|
||
# 索引
|
||
__table_args__ = (
|
||
Index('idx_agent', 'agent'),
|
||
Index('idx_symbol', 'symbol'),
|
||
Index('idx_time_interval', 'time_interval'),
|
||
Index('idx_created_at', 'created_at'),
|
||
)
|
||
|
||
# 定义AI Agent信息流模型
|
||
class AgentFeed(Base):
|
||
"""AI Agent信息流表模型"""
|
||
__tablename__ = 'agent_feeds'
|
||
|
||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||
agent_name = Column(String(50), nullable=False, comment='AI Agent名称')
|
||
avatar_url = Column(String(255), nullable=True, comment='头像URL')
|
||
content = Column(Text, nullable=False, comment='内容')
|
||
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||
|
||
# 索引和表属性
|
||
__table_args__ = (
|
||
Index('idx_agent_name', 'agent_name'),
|
||
Index('idx_create_time', 'create_time'),
|
||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||
)
|
||
|
||
# 定义用户数据模型
|
||
class User(Base):
|
||
"""用户数据表模型"""
|
||
__tablename__ = 'users'
|
||
|
||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||
mail = Column(String(100), nullable=False, unique=True, comment='邮箱')
|
||
nickname = Column(String(50), nullable=False, comment='昵称')
|
||
password = Column(String(100), nullable=False, comment='密码')
|
||
level = Column(Integer, nullable=False, default=0, comment='用户级别(0=普通用户,1=VIP,2=SVIP)')
|
||
points = Column(Integer, nullable=False, default=0, comment='用户积分')
|
||
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||
|
||
# 关系
|
||
questions = relationship("UserQuestion", back_populates="user")
|
||
|
||
# 索引和表属性
|
||
__table_args__ = (
|
||
Index('idx_mail', 'mail'),
|
||
Index('idx_level', 'level'),
|
||
Index('idx_create_time', 'create_time'),
|
||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||
)
|
||
|
||
# 定义用户提问数据模型
|
||
class UserQuestion(Base):
|
||
"""用户提问数据表模型"""
|
||
__tablename__ = 'user_questions'
|
||
|
||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||
user_id = Column(Integer, ForeignKey('users.id'), nullable=False, comment='用户ID')
|
||
agent_id = Column(String(50), nullable=False, comment='AI Agent ID')
|
||
question = Column(Text, nullable=False, comment='提问内容')
|
||
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||
|
||
# 关系
|
||
user = relationship("User", back_populates="questions")
|
||
|
||
# 索引和表属性
|
||
__table_args__ = (
|
||
Index('idx_user_id', 'user_id'),
|
||
Index('idx_agent_id', 'agent_id'),
|
||
Index('idx_create_time', 'create_time'),
|
||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||
)
|
||
|
||
# 定义Agent数据模型
|
||
class Agent(Base):
|
||
"""Agent数据表模型"""
|
||
__tablename__ = 'agents'
|
||
|
||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||
name = Column(String(100), nullable=False, unique=True, comment='Agent名称')
|
||
hello_prompt = Column(Text, nullable=True, comment='欢迎提示语')
|
||
description = Column(Text, nullable=True, comment='Agent描述')
|
||
dify_token = Column(String(255), nullable=True, comment='Dify API令牌')
|
||
inputs = Column(JSON, nullable=True, comment='输入参数JSON')
|
||
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||
update_time = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间')
|
||
|
||
# 索引和表属性
|
||
__table_args__ = (
|
||
Index('idx_name', 'name'),
|
||
Index('idx_create_time', 'create_time'),
|
||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||
)
|
||
|
||
# 定义 A 股数据模型
|
||
class AStock(Base):
|
||
"""A股股票基本信息表模型"""
|
||
__tablename__ = 'astock'
|
||
|
||
stock_code = Column(String(10), primary_key=True, comment='股票代码')
|
||
short_name = Column(String(50), nullable=False, comment='股票简称')
|
||
exchange = Column(String(20), nullable=True, comment='交易所')
|
||
list_date = Column(DateTime, nullable=True, comment='上市日期')
|
||
created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||
|
||
# 索引
|
||
__table_args__ = (
|
||
Index('idx_stock_code', 'stock_code', unique=True),
|
||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||
)
|
||
|
||
def _convert_timestamp_to_datetime(timestamp: Union[int, float, str, None]) -> Optional[datetime]:
|
||
"""
|
||
将时间戳转换为datetime对象
|
||
|
||
Args:
|
||
timestamp: Unix时间戳(毫秒级)
|
||
|
||
Returns:
|
||
datetime对象或None
|
||
"""
|
||
if timestamp is None:
|
||
return None
|
||
|
||
try:
|
||
# 转换为整数
|
||
if isinstance(timestamp, str):
|
||
timestamp = int(timestamp)
|
||
|
||
# 如果是毫秒级时间戳,转换为秒级
|
||
if timestamp > 1e11: # 判断是否为毫秒级时间戳
|
||
timestamp = timestamp / 1000
|
||
|
||
return datetime.fromtimestamp(timestamp)
|
||
except (ValueError, TypeError) as e:
|
||
logger.error(f"时间戳转换失败: {timestamp}, 错误: {e}")
|
||
return None
|
||
|
||
class DBManager:
|
||
"""数据库管理工具,用于连接MySQL数据库并保存智能体分析结果"""
|
||
|
||
def __init__(self, host: str, port: int, user: str, password: str, db_name: str):
|
||
"""
|
||
初始化数据库管理器
|
||
|
||
Args:
|
||
host: 数据库主机地址
|
||
port: 数据库端口
|
||
user: 用户名
|
||
password: 密码
|
||
db_name: 数据库名
|
||
"""
|
||
self.host = host
|
||
self.port = port
|
||
self.user = user
|
||
self.password = password
|
||
self.db_name = db_name
|
||
self.engine = None
|
||
self.Session = None
|
||
|
||
# 初始化数据库连接
|
||
self._init_db()
|
||
|
||
def _init_db(self) -> None:
|
||
"""初始化数据库连接和表"""
|
||
try:
|
||
# 创建数据库连接
|
||
connection_string = f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.db_name}?charset=utf8mb4"
|
||
|
||
# 创建引擎,设置连接池
|
||
self.engine = create_engine(
|
||
connection_string,
|
||
echo=False, # 设置为True可以输出SQL语句(调试用)
|
||
pool_size=5, # 连接池大小
|
||
max_overflow=10, # 最大溢出连接数
|
||
pool_timeout=30, # 连接超时时间
|
||
pool_recycle=1800, # 连接回收时间(秒)
|
||
pool_pre_ping=True, # 在使用连接前先ping一下,确保连接有效
|
||
connect_args={'charset': 'utf8mb4'}
|
||
)
|
||
|
||
# 创建会话工厂
|
||
self.Session = sessionmaker(bind=self.engine)
|
||
|
||
# 创建表(如果不存在)
|
||
Base.metadata.create_all(self.engine)
|
||
|
||
logger.info(f"成功连接到数据库 {self.db_name}")
|
||
|
||
except Exception as e:
|
||
logger.error(f"数据库初始化失败: {e}")
|
||
self.engine = None
|
||
|
||
def save_analysis_result(self, agent: str, symbol: str, time_interval: str,
|
||
analysis_result: Dict[str, Any]) -> bool:
|
||
"""
|
||
保存分析结果到数据库
|
||
|
||
Args:
|
||
agent: 智能体类型,例如 'crypto' 或 'gold'
|
||
symbol: 交易对符号,例如 'BTCUSDT'
|
||
time_interval: 时间间隔,例如 '1h', '4h', '1d'
|
||
analysis_result: 分析结果数据
|
||
|
||
Returns:
|
||
保存是否成功
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return False
|
||
|
||
try:
|
||
# 创建会话
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 创建新记录
|
||
new_result = AnalysisResult(
|
||
agent=agent,
|
||
symbol=symbol,
|
||
time_interval=time_interval,
|
||
completion_result=analysis_result,
|
||
created_at=datetime.now(),
|
||
updated_at=datetime.now()
|
||
)
|
||
|
||
# 添加并提交
|
||
session.add(new_result)
|
||
session.commit()
|
||
|
||
logger.info(f"成功保存 {agent} 分析结果,交易对: {symbol}, 时间间隔: {time_interval}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
session.rollback()
|
||
logger.error(f"保存分析结果失败: {e}")
|
||
return False
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建数据库会话失败: {e}")
|
||
# 如果是连接错误,尝试重新初始化
|
||
try:
|
||
self._init_db()
|
||
except:
|
||
pass
|
||
return False
|
||
|
||
def save_agent_feed(self, agent_name: str, content: str, avatar_url: Optional[str] = None) -> bool:
|
||
"""
|
||
保存AI Agent信息流到数据库
|
||
|
||
Args:
|
||
agent_name: AI Agent名称
|
||
content: 内容
|
||
avatar_url: 头像URL,可选
|
||
|
||
Returns:
|
||
保存是否成功
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return False
|
||
|
||
try:
|
||
# 创建会话
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 创建新记录
|
||
new_feed = AgentFeed(
|
||
agent_name=agent_name,
|
||
content=content,
|
||
avatar_url=avatar_url
|
||
)
|
||
|
||
# 添加并提交
|
||
session.add(new_feed)
|
||
session.commit()
|
||
|
||
logger.info(f"成功保存 {agent_name} 的信息流")
|
||
return True
|
||
|
||
except Exception as e:
|
||
session.rollback()
|
||
logger.error(f"保存信息流失败: {e}")
|
||
return False
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建数据库会话失败: {e}")
|
||
# 如果是连接错误,尝试重新初始化
|
||
try:
|
||
self._init_db()
|
||
except:
|
||
pass
|
||
return False
|
||
|
||
def register_user(self, mail: str, nickname: str, password: str, level: int = 0, points: int = 0) -> bool:
|
||
"""
|
||
注册新用户
|
||
|
||
Args:
|
||
mail: 邮箱
|
||
nickname: 昵称
|
||
password: 密码
|
||
level: 用户级别,默认为0(普通用户)
|
||
points: 初始积分,默认为0
|
||
|
||
Returns:
|
||
注册是否成功
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return False
|
||
|
||
try:
|
||
# 创建会话
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 检查邮箱是否已存在
|
||
existing_user = session.query(User).filter(User.mail == mail).first()
|
||
if existing_user:
|
||
logger.warning(f"邮箱 {mail} 已被注册")
|
||
return False
|
||
|
||
# 创建新用户
|
||
new_user = User(
|
||
mail=mail,
|
||
nickname=nickname,
|
||
password=password, # 实际应用中应该对密码进行哈希处理
|
||
level=level,
|
||
points=points,
|
||
create_time=datetime.now()
|
||
)
|
||
|
||
# 添加并提交
|
||
session.add(new_user)
|
||
session.commit()
|
||
|
||
logger.info(f"成功注册用户: {mail},初始积分: {points}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
session.rollback()
|
||
logger.error(f"注册用户失败: {e}")
|
||
return False
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建数据库会话失败: {e}")
|
||
# 如果是连接错误,尝试重新初始化
|
||
try:
|
||
self._init_db()
|
||
except:
|
||
pass
|
||
return False
|
||
|
||
def get_user_by_id(self, user_id: int) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
通过ID获取用户信息
|
||
|
||
Args:
|
||
user_id: 用户ID
|
||
|
||
Returns:
|
||
用户信息,如果用户不存在则返回None
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return None
|
||
|
||
try:
|
||
# 创建会话
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 查询用户
|
||
user = session.query(User).filter(User.id == user_id).first()
|
||
|
||
if user:
|
||
# 转换为字典
|
||
return {
|
||
'id': user.id,
|
||
'mail': user.mail,
|
||
'nickname': user.nickname,
|
||
'level': user.level,
|
||
'points': user.points,
|
||
'create_time': user.create_time
|
||
}
|
||
else:
|
||
return None
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取用户信息失败: {e}")
|
||
return None
|
||
|
||
def get_user_by_mail(self, mail: str) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
通过邮箱获取用户信息
|
||
|
||
Args:
|
||
mail: 邮箱
|
||
|
||
Returns:
|
||
用户信息,如果用户不存在则返回None
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return None
|
||
|
||
try:
|
||
# 创建会话
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 查询用户
|
||
user = session.query(User).filter(User.mail == mail).first()
|
||
|
||
if user:
|
||
# 转换为字典
|
||
return {
|
||
'id': user.id,
|
||
'mail': user.mail,
|
||
'nickname': user.nickname,
|
||
'level': user.level,
|
||
'points': user.points,
|
||
'create_time': user.create_time
|
||
}
|
||
else:
|
||
return None
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取用户信息失败: {e}")
|
||
return None
|
||
|
||
def update_user_level(self, user_id: int, level: int) -> bool:
|
||
"""
|
||
更新用户级别
|
||
|
||
Args:
|
||
user_id: 用户ID
|
||
level: 新的用户级别
|
||
|
||
Returns:
|
||
更新是否成功
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return False
|
||
|
||
try:
|
||
# 创建会话
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 查询用户
|
||
user = session.query(User).filter(User.id == user_id).first()
|
||
|
||
if not user:
|
||
logger.warning(f"用户ID {user_id} 不存在")
|
||
return False
|
||
|
||
# 更新级别
|
||
user.level = level
|
||
session.commit()
|
||
|
||
logger.info(f"成功更新用户 {user.mail} 的级别为 {level}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
session.rollback()
|
||
logger.error(f"更新用户级别失败: {e}")
|
||
return False
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建数据库会话失败: {e}")
|
||
return False
|
||
|
||
def add_user_points(self, user_id: int, points: int) -> bool:
|
||
"""
|
||
为用户增加积分
|
||
|
||
Args:
|
||
user_id: 用户ID
|
||
points: 增加的积分数量(正数)
|
||
|
||
Returns:
|
||
操作是否成功
|
||
"""
|
||
if points <= 0:
|
||
logger.warning(f"增加的积分必须是正数: {points}")
|
||
return False
|
||
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return False
|
||
|
||
try:
|
||
# 创建会话
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 查询用户
|
||
user = session.query(User).filter(User.id == user_id).first()
|
||
|
||
if not user:
|
||
logger.warning(f"用户ID {user_id} 不存在")
|
||
return False
|
||
|
||
# 增加积分
|
||
user.points += points
|
||
session.commit()
|
||
|
||
logger.info(f"成功为用户 {user.mail} 增加 {points} 积分,当前积分: {user.points}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
session.rollback()
|
||
logger.error(f"增加用户积分失败: {e}")
|
||
return False
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建数据库会话失败: {e}")
|
||
return False
|
||
|
||
def consume_user_points(self, user_id: int, points: int) -> bool:
|
||
"""
|
||
用户消费积分
|
||
|
||
Args:
|
||
user_id: 用户ID
|
||
points: 消费的积分数量(正数)
|
||
|
||
Returns:
|
||
操作是否成功
|
||
"""
|
||
if points <= 0:
|
||
logger.warning(f"消费的积分必须是正数: {points}")
|
||
return False
|
||
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return False
|
||
|
||
try:
|
||
# 创建会话
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 查询用户
|
||
user = session.query(User).filter(User.id == user_id).first()
|
||
|
||
if not user:
|
||
logger.warning(f"用户ID {user_id} 不存在")
|
||
return False
|
||
|
||
# 检查积分是否足够
|
||
if user.points < points:
|
||
logger.warning(f"用户 {user.mail} 积分不足,当前积分: {user.points},需要消费: {points}")
|
||
return False
|
||
|
||
# 消费积分
|
||
user.points -= points
|
||
session.commit()
|
||
|
||
logger.info(f"成功从用户 {user.mail} 消费 {points} 积分,剩余积分: {user.points}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
session.rollback()
|
||
logger.error(f"消费用户积分失败: {e}")
|
||
return False
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建数据库会话失败: {e}")
|
||
return False
|
||
|
||
def get_agent_feeds(self, agent_name: Optional[str] = None, limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]:
|
||
"""
|
||
获取AI Agent信息流
|
||
|
||
Args:
|
||
agent_name: 可选,指定获取特定Agent的信息流
|
||
limit: 返回的最大记录数,默认20条
|
||
|
||
Returns:
|
||
信息流列表,如果查询失败则返回空列表
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return []
|
||
|
||
try:
|
||
# 创建会话
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 构建查询
|
||
query = session.query(AgentFeed)
|
||
|
||
# 如果指定了agent_name,则筛选
|
||
if agent_name:
|
||
query = query.filter(AgentFeed.agent_name == agent_name)
|
||
|
||
# 按创建时间降序排序并限制数量
|
||
results = query.order_by(AgentFeed.create_time.desc()).offset(skip).limit(limit).all()
|
||
|
||
# 转换为字典列表
|
||
feeds = []
|
||
for result in results:
|
||
feeds.append({
|
||
'id': result.id,
|
||
'agent_name': result.agent_name,
|
||
'avatar_url': result.avatar_url,
|
||
'content': result.content,
|
||
'create_time': result.create_time
|
||
})
|
||
|
||
return feeds
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取信息流失败: {e}")
|
||
return []
|
||
|
||
def save_user_question(self, user_id: int, agent_id: str, question: str) -> bool:
|
||
"""
|
||
保存用户提问数据
|
||
|
||
Args:
|
||
user_id: 用户ID
|
||
agent_id: AI Agent ID
|
||
question: 提问内容
|
||
|
||
Returns:
|
||
保存是否成功
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return False
|
||
|
||
try:
|
||
# 创建会话
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 创建新记录
|
||
new_question = UserQuestion(
|
||
user_id=user_id,
|
||
agent_id=agent_id,
|
||
question=question,
|
||
create_time=datetime.now()
|
||
)
|
||
|
||
# 添加并提交
|
||
session.add(new_question)
|
||
session.commit()
|
||
|
||
logger.info(f"成功保存用户 {user_id} 对 Agent {agent_id} 的提问")
|
||
return True
|
||
|
||
except Exception as e:
|
||
session.rollback()
|
||
logger.error(f"保存用户提问失败: {e}")
|
||
return False
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建数据库会话失败: {e}")
|
||
# 如果是连接错误,尝试重新初始化
|
||
try:
|
||
self._init_db()
|
||
except:
|
||
pass
|
||
return False
|
||
|
||
def get_user_questions(self, user_id: Optional[int] = None, agent_id: Optional[str] = None,
|
||
limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]:
|
||
"""
|
||
获取用户提问数据
|
||
|
||
Args:
|
||
user_id: 可选,指定获取特定用户的提问
|
||
agent_id: 可选,指定获取特定Agent的提问
|
||
limit: 返回的最大记录数,默认20条
|
||
skip: 跳过的记录数,默认0条
|
||
|
||
Returns:
|
||
提问数据列表,如果查询失败则返回空列表
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return []
|
||
|
||
try:
|
||
# 创建会话
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 构建查询
|
||
query = session.query(UserQuestion)
|
||
|
||
# 如果指定了user_id,则筛选
|
||
if user_id:
|
||
query = query.filter(UserQuestion.user_id == user_id)
|
||
|
||
# 如果指定了agent_id,则筛选
|
||
if agent_id:
|
||
query = query.filter(UserQuestion.agent_id == agent_id)
|
||
|
||
# 按创建时间降序排序并限制数量
|
||
results = query.order_by(UserQuestion.create_time.desc()).offset(skip).limit(limit).all()
|
||
|
||
# 转换为字典列表
|
||
questions = []
|
||
for result in results:
|
||
questions.append({
|
||
'id': result.id,
|
||
'user_id': result.user_id,
|
||
'agent_id': result.agent_id,
|
||
'question': result.question,
|
||
'create_time': result.create_time
|
||
})
|
||
|
||
return questions
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取用户提问失败: {e}")
|
||
return []
|
||
|
||
def get_user_question_by_id(self, question_id: int) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
通过ID获取用户提问数据
|
||
|
||
Args:
|
||
question_id: 提问ID
|
||
|
||
Returns:
|
||
提问数据,如果不存在则返回None
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return None
|
||
|
||
try:
|
||
# 创建会话
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 查询提问
|
||
result = session.query(UserQuestion).filter(UserQuestion.id == question_id).first()
|
||
|
||
if result:
|
||
# 转换为字典
|
||
return {
|
||
'id': result.id,
|
||
'user_id': result.user_id,
|
||
'agent_id': result.agent_id,
|
||
'question': result.question,
|
||
'create_time': result.create_time
|
||
}
|
||
else:
|
||
return None
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取用户提问失败: {e}")
|
||
return None
|
||
|
||
def get_latest_result(self, agent: str, symbol: str, time_interval: str) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
获取最新的分析结果
|
||
|
||
Args:
|
||
agent: 智能体类型,例如 'crypto' 或 'gold'
|
||
symbol: 交易对符号,例如 'BTCUSDT'
|
||
time_interval: 时间间隔,例如 '1h', '4h', '1d'
|
||
|
||
Returns:
|
||
最新分析结果,如果查询失败则返回None
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return None
|
||
|
||
try:
|
||
# 创建会话
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 查询最新的结果
|
||
result = session.query(AnalysisResult).filter(
|
||
AnalysisResult.agent == agent,
|
||
AnalysisResult.symbol == symbol,
|
||
AnalysisResult.time_interval == time_interval
|
||
).order_by(AnalysisResult.created_at.desc()).first()
|
||
|
||
if result:
|
||
# 转换为字典
|
||
return {
|
||
'id': result.id,
|
||
'agent': result.agent,
|
||
'symbol': result.symbol,
|
||
'time_interval': result.time_interval,
|
||
'completion_result': result.completion_result,
|
||
'created_at': result.created_at
|
||
}
|
||
else:
|
||
return None
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取最新分析结果失败: {e}")
|
||
return None
|
||
|
||
def close(self) -> None:
|
||
"""关闭数据库连接(如果存在)"""
|
||
if self.engine:
|
||
self.engine.dispose()
|
||
logger.info("数据库连接已关闭")
|
||
self.engine = None
|
||
self.Session = None
|
||
|
||
def create_agent(self, name: str, description: Optional[str] = None,
|
||
hello_prompt: Optional[str] = None, dify_token: Optional[str] = None,
|
||
inputs: Optional[Dict[str, Any]] = None) -> Optional[int]:
|
||
"""
|
||
创建新的Agent
|
||
|
||
Args:
|
||
name: Agent名称
|
||
description: Agent描述
|
||
hello_prompt: 欢迎提示语
|
||
dify_token: Dify API令牌
|
||
inputs: 输入参数JSON
|
||
|
||
Returns:
|
||
新Agent的ID,如果创建失败则返回None
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return None
|
||
|
||
try:
|
||
# 创建会话
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 检查名称是否已存在
|
||
existing_agent = session.query(Agent).filter(Agent.name == name).first()
|
||
if existing_agent:
|
||
logger.warning(f"Agent名称 {name} 已存在")
|
||
return None
|
||
|
||
# 创建新Agent
|
||
new_agent = Agent(
|
||
name=name,
|
||
description=description,
|
||
hello_prompt=hello_prompt,
|
||
dify_token=dify_token,
|
||
inputs=inputs,
|
||
create_time=datetime.now(),
|
||
update_time=datetime.now()
|
||
)
|
||
|
||
# 添加并提交
|
||
session.add(new_agent)
|
||
session.commit()
|
||
|
||
logger.info(f"成功创建Agent: {name}")
|
||
return new_agent.id
|
||
|
||
except Exception as e:
|
||
session.rollback()
|
||
logger.error(f"创建Agent失败: {e}")
|
||
return None
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建数据库会话失败: {e}")
|
||
return None
|
||
|
||
def update_agent(self, agent_id: int, name: Optional[str] = None,
|
||
description: Optional[str] = None, hello_prompt: Optional[str] = None,
|
||
dify_token: Optional[str] = None, inputs: Optional[Dict[str, Any]] = None) -> bool:
|
||
"""
|
||
更新Agent信息
|
||
|
||
Args:
|
||
agent_id: Agent ID
|
||
name: Agent名称
|
||
description: Agent描述
|
||
hello_prompt: 欢迎提示语
|
||
dify_token: Dify API令牌
|
||
inputs: 输入参数JSON
|
||
|
||
Returns:
|
||
更新是否成功
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return False
|
||
|
||
try:
|
||
# 创建会话
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 查询Agent
|
||
agent = session.query(Agent).filter(Agent.id == agent_id).first()
|
||
if not agent:
|
||
logger.warning(f"Agent ID {agent_id} 不存在")
|
||
return False
|
||
|
||
# 更新字段
|
||
if name is not None:
|
||
# 检查名称是否与其他Agent冲突
|
||
if name != agent.name:
|
||
existing = session.query(Agent).filter(Agent.name == name).first()
|
||
if existing:
|
||
logger.warning(f"Agent名称 {name} 已被其他Agent使用")
|
||
return False
|
||
agent.name = name
|
||
|
||
if description is not None:
|
||
agent.description = description
|
||
|
||
if hello_prompt is not None:
|
||
agent.hello_prompt = hello_prompt
|
||
|
||
if dify_token is not None:
|
||
agent.dify_token = dify_token
|
||
|
||
if inputs is not None:
|
||
agent.inputs = inputs
|
||
|
||
agent.update_time = datetime.now()
|
||
|
||
# 提交更改
|
||
session.commit()
|
||
|
||
logger.info(f"成功更新Agent ID: {agent_id}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
session.rollback()
|
||
logger.error(f"更新Agent失败: {e}")
|
||
return False
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建数据库会话失败: {e}")
|
||
return False
|
||
|
||
def get_agent_by_id(self, agent_id: int) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
通过ID获取Agent信息
|
||
|
||
Args:
|
||
agent_id: Agent ID
|
||
|
||
Returns:
|
||
Agent信息,如果不存在则返回None
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return None
|
||
|
||
try:
|
||
# 创建会话
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 查询Agent
|
||
agent = session.query(Agent).filter(Agent.id == agent_id).first()
|
||
|
||
if agent:
|
||
# 转换为字典
|
||
return {
|
||
'id': agent.id,
|
||
'name': agent.name,
|
||
'description': agent.description,
|
||
'hello_prompt': agent.hello_prompt,
|
||
'dify_token': agent.dify_token,
|
||
'inputs': agent.inputs,
|
||
'create_time': agent.create_time,
|
||
'update_time': agent.update_time
|
||
}
|
||
else:
|
||
return None
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取Agent信息失败: {e}")
|
||
return None
|
||
|
||
def get_agent_by_name(self, name: str) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
通过名称获取Agent信息
|
||
|
||
Args:
|
||
name: Agent名称
|
||
|
||
Returns:
|
||
Agent信息,如果不存在则返回None
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return None
|
||
|
||
try:
|
||
# 创建会话
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 查询Agent
|
||
agent = session.query(Agent).filter(Agent.name == name).first()
|
||
|
||
if agent:
|
||
# 转换为字典
|
||
return {
|
||
'id': agent.id,
|
||
'name': agent.name,
|
||
'description': agent.description,
|
||
'hello_prompt': agent.hello_prompt,
|
||
'dify_token': agent.dify_token,
|
||
'inputs': agent.inputs,
|
||
'create_time': agent.create_time,
|
||
'update_time': agent.update_time
|
||
}
|
||
else:
|
||
return None
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取Agent信息失败: {e}")
|
||
return None
|
||
|
||
def list_agents(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]:
|
||
"""
|
||
获取Agent列表
|
||
|
||
Args:
|
||
limit: 返回的最大数量
|
||
skip: 跳过的数量
|
||
|
||
Returns:
|
||
Agent列表
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return []
|
||
|
||
try:
|
||
# 创建会话
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 查询Agent列表
|
||
agents = session.query(Agent).order_by(Agent.create_time.asc()).offset(skip).limit(limit).all()
|
||
|
||
# 转换为字典列表
|
||
return [{
|
||
'id': agent.id,
|
||
'name': agent.name,
|
||
'description': agent.description,
|
||
'hello_prompt': agent.hello_prompt,
|
||
'dify_token': agent.dify_token,
|
||
'inputs': agent.inputs,
|
||
'create_time': agent.create_time,
|
||
'update_time': agent.update_time
|
||
} for agent in agents]
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取Agent列表失败: {e}")
|
||
return []
|
||
|
||
def delete_agent(self, agent_id: int) -> bool:
|
||
"""
|
||
删除Agent
|
||
|
||
Args:
|
||
agent_id: Agent ID
|
||
|
||
Returns:
|
||
删除是否成功
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return False
|
||
|
||
try:
|
||
# 创建会话
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 查询Agent
|
||
agent = session.query(Agent).filter(Agent.id == agent_id).first()
|
||
if not agent:
|
||
logger.warning(f"Agent ID {agent_id} 不存在")
|
||
return False
|
||
|
||
# 删除Agent
|
||
session.delete(agent)
|
||
session.commit()
|
||
|
||
logger.info(f"成功删除Agent ID: {agent_id}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
session.rollback()
|
||
logger.error(f"删除Agent失败: {e}")
|
||
return False
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建数据库会话失败: {e}")
|
||
return False
|
||
|
||
def create_stocks(self, stocks: List[Dict[str, Any]]) -> bool:
|
||
"""
|
||
批量创建股票信息
|
||
|
||
Args:
|
||
stocks: 股票信息列表,每个元素为包含以下键的字典:
|
||
- stock_code: 股票代码
|
||
- short_name: 股票简称
|
||
- exchange: 交易所(可选)
|
||
- list_date: 上市日期(可选,Unix时间戳,毫秒级)
|
||
|
||
Returns:
|
||
创建是否成功
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return False
|
||
|
||
try:
|
||
session = self.Session()
|
||
|
||
try:
|
||
for stock_data in stocks:
|
||
# 检查股票代码是否已存在
|
||
existing_stock = session.query(AStock).filter(
|
||
AStock.stock_code == stock_data['stock_code']
|
||
).first()
|
||
|
||
if existing_stock:
|
||
logger.warning(f"股票代码 {stock_data['stock_code']} 已存在,跳过")
|
||
continue
|
||
|
||
# 创建新股票记录
|
||
new_stock = AStock(
|
||
stock_code=stock_data['stock_code'],
|
||
short_name=stock_data['short_name'],
|
||
exchange=stock_data.get('exchange')
|
||
)
|
||
|
||
# 处理上市日期
|
||
if 'list_date' in stock_data and stock_data['list_date']:
|
||
list_date = _convert_timestamp_to_datetime(stock_data['list_date'])
|
||
if list_date:
|
||
new_stock.list_date = list_date
|
||
|
||
session.add(new_stock)
|
||
|
||
# 批量提交
|
||
session.commit()
|
||
logger.info(f"成功批量创建股票信息,共 {len(stocks)} 条记录")
|
||
return True
|
||
|
||
except Exception as e:
|
||
session.rollback()
|
||
logger.error(f"批量创建股票信息失败: {e}")
|
||
return False
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建数据库会话失败: {e}")
|
||
return False
|
||
|
||
def create_stock(self, stock_code: str, short_name: str,
|
||
exchange: Optional[str] = None,
|
||
list_date: Optional[Union[int, float, str]] = None) -> bool:
|
||
"""
|
||
创建新的股票信息
|
||
|
||
Args:
|
||
stock_code: 股票代码
|
||
short_name: 股票简称
|
||
exchange: 交易所(可选)
|
||
list_date: 上市日期(可选,Unix时间戳,毫秒级)
|
||
|
||
Returns:
|
||
创建是否成功
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return False
|
||
|
||
try:
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 检查股票代码是否已存在
|
||
existing_stock = session.query(AStock).filter(AStock.stock_code == stock_code).first()
|
||
if existing_stock:
|
||
logger.warning(f"股票代码 {stock_code} 已存在")
|
||
return False
|
||
|
||
# 创建新股票记录
|
||
new_stock = AStock(
|
||
stock_code=stock_code,
|
||
short_name=short_name,
|
||
exchange=exchange
|
||
)
|
||
|
||
# 处理上市日期
|
||
if list_date:
|
||
converted_date = _convert_timestamp_to_datetime(list_date)
|
||
if converted_date:
|
||
new_stock.list_date = converted_date
|
||
|
||
session.add(new_stock)
|
||
session.commit()
|
||
|
||
logger.info(f"成功创建股票信息: {stock_code} - {short_name}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
session.rollback()
|
||
logger.error(f"创建股票信息失败: {e}")
|
||
return False
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建数据库会话失败: {e}")
|
||
return False
|
||
|
||
def update_stock(self, stock_code: str, short_name: Optional[str] = None,
|
||
exchange: Optional[str] = None, list_date: Optional[datetime] = None) -> bool:
|
||
"""
|
||
更新股票信息
|
||
|
||
Args:
|
||
stock_code: 股票代码
|
||
short_name: 股票简称
|
||
exchange: 交易所
|
||
list_date: 上市日期
|
||
|
||
Returns:
|
||
更新是否成功
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return False
|
||
|
||
try:
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 查询股票
|
||
stock = session.query(AStock).filter(AStock.stock_code == stock_code).first()
|
||
if not stock:
|
||
logger.warning(f"股票代码 {stock_code} 不存在")
|
||
return False
|
||
|
||
# 更新字段
|
||
if short_name is not None:
|
||
stock.short_name = short_name
|
||
if exchange is not None:
|
||
stock.exchange = exchange
|
||
if list_date is not None:
|
||
stock.list_date = list_date
|
||
|
||
stock.updated_at = datetime.now()
|
||
|
||
session.commit()
|
||
|
||
logger.info(f"成功更新股票信息: {stock_code}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
session.rollback()
|
||
logger.error(f"更新股票信息失败: {e}")
|
||
return False
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建数据库会话失败: {e}")
|
||
return False
|
||
|
||
def delete_stock(self, stock_code: str) -> bool:
|
||
"""
|
||
删除股票信息
|
||
|
||
Args:
|
||
stock_code: 股票代码
|
||
|
||
Returns:
|
||
删除是否成功
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return False
|
||
|
||
try:
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 查询股票
|
||
stock = session.query(AStock).filter(AStock.stock_code == stock_code).first()
|
||
if not stock:
|
||
logger.warning(f"股票代码 {stock_code} 不存在")
|
||
return False
|
||
|
||
# 删除股票
|
||
session.delete(stock)
|
||
session.commit()
|
||
|
||
logger.info(f"成功删除股票信息: {stock_code}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
session.rollback()
|
||
logger.error(f"删除股票信息失败: {e}")
|
||
return False
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建数据库会话失败: {e}")
|
||
return False
|
||
|
||
def get_stock_by_code(self, stock_code: str) -> Optional[Dict[str, Any]]:
|
||
"""
|
||
通过股票代码获取股票信息
|
||
|
||
Args:
|
||
stock_code: 股票代码
|
||
|
||
Returns:
|
||
股票信息,如果不存在则返回None
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return None
|
||
|
||
try:
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 查询股票
|
||
stock = session.query(AStock).filter(AStock.stock_code == stock_code).first()
|
||
|
||
if stock:
|
||
return {
|
||
'stock_code': stock.stock_code,
|
||
'short_name': stock.short_name,
|
||
'exchange': stock.exchange,
|
||
'list_date': stock.list_date,
|
||
'created_at': stock.created_at,
|
||
'updated_at': stock.updated_at
|
||
}
|
||
else:
|
||
return None
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取股票信息失败: {e}")
|
||
return None
|
||
|
||
def search_stock(self, key: str, limit: int = 10) -> List[Dict[str, Any]]:
|
||
"""
|
||
搜索股票
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return []
|
||
|
||
try:
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 查询股票
|
||
stocks = session.query(AStock).filter(AStock.short_name.like(f"{key}%") | AStock.stock_code.like(f"{key}%")).limit(limit).all()
|
||
|
||
return [{
|
||
'stock_code': stock.stock_code,
|
||
'short_name': stock.short_name,
|
||
'exchange': stock.exchange,
|
||
'list_date': stock.list_date,
|
||
'created_at': stock.created_at
|
||
} for stock in stocks]
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取股票信息失败: {e}")
|
||
return []
|
||
|
||
def get_stock_by_name(self, short_name: str) -> List[Dict[str, Any]]:
|
||
"""
|
||
通过股票简称获取股票信息(可能有多个结果)
|
||
|
||
Args:
|
||
short_name: 股票简称
|
||
|
||
Returns:
|
||
股票信息列表
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return []
|
||
|
||
try:
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 查询股票
|
||
stocks = session.query(AStock).filter(AStock.short_name == short_name).all()
|
||
|
||
return [{
|
||
'stock_code': stock.stock_code,
|
||
'short_name': stock.short_name,
|
||
'exchange': stock.exchange,
|
||
'list_date': stock.list_date,
|
||
'created_at': stock.created_at
|
||
} for stock in stocks]
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取股票信息失败: {e}")
|
||
return []
|
||
|
||
def list_stocks(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]:
|
||
"""
|
||
获取股票列表
|
||
|
||
Args:
|
||
limit: 返回的最大数量
|
||
skip: 跳过的数量
|
||
|
||
Returns:
|
||
股票信息列表
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return []
|
||
|
||
try:
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 查询股票列表
|
||
stocks = session.query(AStock).order_by(AStock.stock_code).offset(skip).limit(limit).all()
|
||
|
||
return [{
|
||
'stock_code': stock.stock_code,
|
||
'short_name': stock.short_name,
|
||
'exchange': stock.exchange,
|
||
'list_date': stock.list_date,
|
||
'created_at': stock.created_at
|
||
} for stock in stocks]
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取股票列表失败: {e}")
|
||
return []
|
||
|
||
def create_token(self, symbol: str, base_asset: str, quote_asset: str) -> bool:
|
||
"""
|
||
创建新的Token信息
|
||
|
||
Args:
|
||
symbol: 交易对符号
|
||
base_asset: 基础资产
|
||
quote_asset: 计价资产
|
||
|
||
Returns:
|
||
创建是否成功
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return False
|
||
|
||
try:
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 检查交易对是否已存在
|
||
existing_token = session.query(Token).filter(Token.symbol == symbol).first()
|
||
if existing_token:
|
||
logger.warning(f"交易对 {symbol} 已存在")
|
||
return False
|
||
|
||
# 创建新Token记录
|
||
new_token = Token(
|
||
symbol=symbol,
|
||
base_asset=base_asset,
|
||
quote_asset=quote_asset
|
||
)
|
||
|
||
session.add(new_token)
|
||
session.commit()
|
||
|
||
logger.info(f"成功创建Token信息: {symbol}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
session.rollback()
|
||
logger.error(f"创建Token信息失败: {e}")
|
||
return False
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建数据库会话失败: {e}")
|
||
return False
|
||
|
||
def delete_token(self, symbol: str) -> bool:
|
||
"""
|
||
删除Token信息
|
||
|
||
Args:
|
||
symbol: 交易对符号
|
||
|
||
Returns:
|
||
删除是否成功
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return False
|
||
|
||
try:
|
||
session = self.Session()
|
||
|
||
try:
|
||
# 查询Token
|
||
token = session.query(Token).filter(Token.symbol == symbol).first()
|
||
if not token:
|
||
logger.warning(f"交易对 {symbol} 不存在")
|
||
return False
|
||
|
||
# 删除Token
|
||
session.delete(token)
|
||
session.commit()
|
||
|
||
logger.info(f"成功删除Token信息: {symbol}")
|
||
return True
|
||
|
||
except Exception as e:
|
||
session.rollback()
|
||
logger.error(f"删除Token信息失败: {e}")
|
||
return False
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建数据库会话失败: {e}")
|
||
return False
|
||
|
||
def search_token(self, key: str, limit: int = 10) -> List[Dict[str, Any]]:
|
||
"""
|
||
搜索Token
|
||
"""
|
||
if not self.engine:
|
||
try:
|
||
self._init_db()
|
||
except Exception as e:
|
||
logger.error(f"重新连接数据库失败: {e}")
|
||
return []
|
||
|
||
try:
|
||
session = self.Session()
|
||
|
||
# 使用 SQLAlchemy 的 ORM 查询
|
||
tokens = session.query(Token).filter(Token.symbol.like(f"{key}%")).limit(limit).all()
|
||
|
||
return [{
|
||
'symbol': token.symbol,
|
||
'base_asset': token.base_asset,
|
||
'quote_asset': token.quote_asset
|
||
} for token in tokens]
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取Token信息失败: {e}")
|
||
return []
|
||
|
||
finally:
|
||
session.close()
|
||
|
||
# 单例模式
|
||
_db_instance = None
|
||
|
||
def get_db_manager(host: Optional[str] = None,
|
||
port: Optional[int] = None,
|
||
user: Optional[str] = None,
|
||
password: Optional[str] = None,
|
||
db_name: Optional[str] = None) -> DBManager:
|
||
"""
|
||
获取数据库管理器实例(单例模式)
|
||
|
||
Args:
|
||
host: 数据库主机地址
|
||
port: 数据库端口
|
||
user: 用户名
|
||
password: 密码
|
||
db_name: 数据库名
|
||
|
||
Returns:
|
||
数据库管理器实例
|
||
"""
|
||
global _db_instance
|
||
|
||
# 如果已经初始化过,直接返回
|
||
if _db_instance is not None:
|
||
return _db_instance
|
||
|
||
# 如果未指定参数,从配置加载器获取数据库配置
|
||
if host is None or port is None or user is None or password is None or db_name is None:
|
||
config_loader = ConfigLoader()
|
||
db_config = config_loader.get_database_config()
|
||
|
||
# 使用配置中的值或默认值
|
||
db_host = host or db_config.get('host')
|
||
db_port = port or db_config.get('port')
|
||
db_user = user or db_config.get('user')
|
||
db_password = password or db_config.get('password')
|
||
db_name = db_name or db_config.get('db_name')
|
||
|
||
logger.info(f"从配置加载数据库连接信息: {db_host}:{db_port}/{db_name}")
|
||
else:
|
||
db_host = host
|
||
db_port = port
|
||
db_user = user
|
||
db_password = password
|
||
db_name = db_name
|
||
|
||
# 创建实例
|
||
_db_instance = DBManager(
|
||
host=db_host,
|
||
port=db_port,
|
||
user=db_user,
|
||
password=db_password,
|
||
db_name=db_name
|
||
)
|
||
|
||
return _db_instance |