diff --git a/cryptoai/agents/crypto_agent.py b/cryptoai/agents/crypto_agent.py index b93f774..a45f8a3 100644 --- a/cryptoai/agents/crypto_agent.py +++ b/cryptoai/agents/crypto_agent.py @@ -285,7 +285,7 @@ class CryptoAgent: # 保存分析结果到数据库 try: # 保存到数据库 - saved = self.db_manager.save_analysis_result( + saved = self.db_manager.analysis_result_manager.save_analysis_result( agent='crypto', symbol=symbol, time_interval=self.time_interval, @@ -363,7 +363,7 @@ class CryptoAgent: # 保存交易建议到数据库 try: - saved = self.db_manager.save_agent_feed( + saved = self.db_manager.agent_feed_manager.save_agent_feed( agent_name="Crypto Agent", content=message ) diff --git a/cryptoai/models/__init__.py b/cryptoai/models/__init__.py index 61b498c..cd47491 100644 --- a/cryptoai/models/__init__.py +++ b/cryptoai/models/__init__.py @@ -1,4 +1,14 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -"""模型模块,提供数据处理和模型定义。""" \ No newline at end of file +"""模型模块,提供数据处理和模型定义。""" + +from cryptoai.models.base import Base +from cryptoai.models.token import Token, TokenManager +from cryptoai.models.analysis_result import AnalysisResult, AnalysisResultManager +from cryptoai.models.agent_feed import AgentFeed, AgentFeedManager +from cryptoai.models.user import User, UserManager +from cryptoai.models.user_question import UserQuestion, UserQuestionManager +from cryptoai.models.agent import Agent, AgentManager +from cryptoai.models.astock import AStock, AStockManager +from cryptoai.models.analysis_history import AnalysisHistory, AnalysisHistoryManager \ No newline at end of file diff --git a/cryptoai/models/__pycache__/__init__.cpython-313.pyc b/cryptoai/models/__pycache__/__init__.cpython-313.pyc index 8c58016..b7f3263 100644 Binary files a/cryptoai/models/__pycache__/__init__.cpython-313.pyc and b/cryptoai/models/__pycache__/__init__.cpython-313.pyc differ diff --git a/cryptoai/models/agent.py b/cryptoai/models/agent.py new file mode 100644 index 0000000..44cd211 --- /dev/null +++ b/cryptoai/models/agent.py @@ -0,0 +1,268 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Dict, Any, List, Optional +from datetime import datetime + +from sqlalchemy import Column, Integer, String, Text, DateTime, Index +from sqlalchemy.dialects.mysql import JSON + +from cryptoai.models.base import Base, logger + +# 定义Agent数据模型 +class Agent(Base): + """Agent数据表模型""" + __tablename__ = 'agents' + + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String(100), nullable=False, unique=True, comment='Agent名称') + hello_prompt = Column(Text, nullable=True, comment='欢迎提示语') + description = Column(Text, nullable=True, comment='Agent描述') + dify_token = Column(String(255), nullable=True, comment='Dify API令牌') + inputs = Column(JSON, nullable=True, comment='输入参数JSON') + create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') + update_time = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间') + + # 索引和表属性 + __table_args__ = ( + Index('idx_name', 'name'), + Index('idx_create_time', 'create_time'), + {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} + ) + +class AgentManager: + """Agent管理类""" + + def __init__(self, db_session): + self.session = db_session + + def create_agent(self, name: str, description: Optional[str] = None, + hello_prompt: Optional[str] = None, dify_token: Optional[str] = None, + inputs: Optional[Dict[str, Any]] = None) -> Optional[int]: + """ + 创建新的Agent + + Args: + name: Agent名称 + description: Agent描述 + hello_prompt: 欢迎提示语 + dify_token: Dify API令牌 + inputs: 输入参数JSON + + Returns: + 新Agent的ID,如果创建失败则返回None + """ + try: + # 检查名称是否已存在 + existing_agent = self.session.query(Agent).filter(Agent.name == name).first() + if existing_agent: + logger.warning(f"Agent名称 {name} 已存在") + return None + + # 创建新Agent + new_agent = Agent( + name=name, + description=description, + hello_prompt=hello_prompt, + dify_token=dify_token, + inputs=inputs, + create_time=datetime.now(), + update_time=datetime.now() + ) + + # 添加并提交 + self.session.add(new_agent) + self.session.commit() + + logger.info(f"成功创建Agent: {name}") + return new_agent.id + + except Exception as e: + self.session.rollback() + logger.error(f"创建Agent失败: {e}") + return None + + def update_agent(self, agent_id: int, name: Optional[str] = None, + description: Optional[str] = None, hello_prompt: Optional[str] = None, + dify_token: Optional[str] = None, inputs: Optional[Dict[str, Any]] = None) -> bool: + """ + 更新Agent信息 + + Args: + agent_id: Agent ID + name: Agent名称 + description: Agent描述 + hello_prompt: 欢迎提示语 + dify_token: Dify API令牌 + inputs: 输入参数JSON + + Returns: + 更新是否成功 + """ + try: + # 查询Agent + agent = self.session.query(Agent).filter(Agent.id == agent_id).first() + if not agent: + logger.warning(f"Agent ID {agent_id} 不存在") + return False + + # 更新字段 + if name is not None: + # 检查名称是否与其他Agent冲突 + if name != agent.name: + existing = self.session.query(Agent).filter(Agent.name == name).first() + if existing: + logger.warning(f"Agent名称 {name} 已被其他Agent使用") + return False + agent.name = name + + if description is not None: + agent.description = description + + if hello_prompt is not None: + agent.hello_prompt = hello_prompt + + if dify_token is not None: + agent.dify_token = dify_token + + if inputs is not None: + agent.inputs = inputs + + agent.update_time = datetime.now() + + # 提交更改 + self.session.commit() + + logger.info(f"成功更新Agent ID: {agent_id}") + return True + + except Exception as e: + self.session.rollback() + logger.error(f"更新Agent失败: {e}") + return False + + def get_agent_by_id(self, agent_id: int) -> Optional[Dict[str, Any]]: + """ + 通过ID获取Agent信息 + + Args: + agent_id: Agent ID + + Returns: + Agent信息,如果不存在则返回None + """ + try: + # 查询Agent + agent = self.session.query(Agent).filter(Agent.id == agent_id).first() + + if agent: + # 转换为字典 + return { + 'id': agent.id, + 'name': agent.name, + 'description': agent.description, + 'hello_prompt': agent.hello_prompt, + 'dify_token': agent.dify_token, + 'inputs': agent.inputs, + 'create_time': agent.create_time, + 'update_time': agent.update_time + } + else: + return None + + except Exception as e: + logger.error(f"获取Agent信息失败: {e}") + return None + + def get_agent_by_name(self, name: str) -> Optional[Dict[str, Any]]: + """ + 通过名称获取Agent信息 + + Args: + name: Agent名称 + + Returns: + Agent信息,如果不存在则返回None + """ + try: + # 查询Agent + agent = self.session.query(Agent).filter(Agent.name == name).first() + + if agent: + # 转换为字典 + return { + 'id': agent.id, + 'name': agent.name, + 'description': agent.description, + 'hello_prompt': agent.hello_prompt, + 'dify_token': agent.dify_token, + 'inputs': agent.inputs, + 'create_time': agent.create_time, + 'update_time': agent.update_time + } + else: + return None + + except Exception as e: + logger.error(f"获取Agent信息失败: {e}") + return None + + def list_agents(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]: + """ + 获取Agent列表 + + Args: + limit: 返回的最大数量 + skip: 跳过的数量 + + Returns: + Agent列表 + """ + try: + # 查询Agent列表 + agents = self.session.query(Agent).order_by(Agent.create_time.asc()).offset(skip).limit(limit).all() + + # 转换为字典列表 + return [{ + 'id': agent.id, + 'name': agent.name, + 'description': agent.description, + 'hello_prompt': agent.hello_prompt, + 'dify_token': agent.dify_token, + 'inputs': agent.inputs, + 'create_time': agent.create_time, + 'update_time': agent.update_time + } for agent in agents] + + except Exception as e: + logger.error(f"获取Agent列表失败: {e}") + return [] + + def delete_agent(self, agent_id: int) -> bool: + """ + 删除Agent + + Args: + agent_id: Agent ID + + Returns: + 删除是否成功 + """ + try: + # 查询Agent + agent = self.session.query(Agent).filter(Agent.id == agent_id).first() + if not agent: + logger.warning(f"Agent ID {agent_id} 不存在") + return False + + # 删除Agent + self.session.delete(agent) + self.session.commit() + + logger.info(f"成功删除Agent ID: {agent_id}") + return True + + except Exception as e: + self.session.rollback() + logger.error(f"删除Agent失败: {e}") + return False \ No newline at end of file diff --git a/cryptoai/models/agent_feed.py b/cryptoai/models/agent_feed.py new file mode 100644 index 0000000..9a7197e --- /dev/null +++ b/cryptoai/models/agent_feed.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Dict, Any, List, Optional +from datetime import datetime + +from sqlalchemy import Column, Integer, String, Text, DateTime, Index + +from cryptoai.models.base import Base, logger + +# 定义AI Agent信息流模型 +class AgentFeed(Base): + """AI Agent信息流表模型""" + __tablename__ = 'agent_feeds' + + id = Column(Integer, primary_key=True, autoincrement=True) + agent_name = Column(String(50), nullable=False, comment='AI Agent名称') + avatar_url = Column(String(255), nullable=True, comment='头像URL') + content = Column(Text, nullable=False, comment='内容') + create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') + + # 索引和表属性 + __table_args__ = ( + Index('idx_agent_name', 'agent_name'), + Index('idx_create_time', 'create_time'), + {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} + ) + +class AgentFeedManager: + """Agent信息流管理类""" + + def __init__(self, db_session): + self.session = db_session + + def save_agent_feed(self, agent_name: str, content: str, avatar_url: Optional[str] = None) -> bool: + """ + 保存AI Agent信息流到数据库 + + Args: + agent_name: AI Agent名称 + content: 内容 + avatar_url: 头像URL,可选 + + Returns: + 保存是否成功 + """ + try: + # 创建新记录 + new_feed = AgentFeed( + agent_name=agent_name, + content=content, + avatar_url=avatar_url + ) + + # 添加并提交 + self.session.add(new_feed) + self.session.commit() + + logger.info(f"成功保存 {agent_name} 的信息流") + return True + + except Exception as e: + self.session.rollback() + logger.error(f"保存信息流失败: {e}") + return False + + def get_agent_feeds(self, agent_name: Optional[str] = None, limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]: + """ + 获取AI Agent信息流 + + Args: + agent_name: 可选,指定获取特定Agent的信息流 + limit: 返回的最大记录数,默认20条 + skip: 跳过的记录数,默认0条 + + Returns: + 信息流列表,如果查询失败则返回空列表 + """ + try: + # 构建查询 + query = self.session.query(AgentFeed) + + # 如果指定了agent_name,则筛选 + if agent_name: + query = query.filter(AgentFeed.agent_name == agent_name) + + # 按创建时间降序排序并限制数量 + results = query.order_by(AgentFeed.create_time.desc()).offset(skip).limit(limit).all() + + # 转换为字典列表 + feeds = [] + for result in results: + feeds.append({ + 'id': result.id, + 'agent_name': result.agent_name, + 'avatar_url': result.avatar_url, + 'content': result.content, + 'create_time': result.create_time + }) + + return feeds + + except Exception as e: + logger.error(f"获取信息流失败: {e}") + return [] \ No newline at end of file diff --git a/cryptoai/models/analysis_history.py b/cryptoai/models/analysis_history.py new file mode 100644 index 0000000..57cc1d0 --- /dev/null +++ b/cryptoai/models/analysis_history.py @@ -0,0 +1,189 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Dict, Any, List, Optional +from datetime import datetime + +from sqlalchemy import Column, Integer, String, Text, DateTime, Index, ForeignKey +from sqlalchemy.orm import relationship + +from cryptoai.models.base import Base, logger + +# 定义分析历史模型 +class AnalysisHistory(Base): + """分析历史表模型""" + __tablename__ = 'analysis_history' + + id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column(Integer, ForeignKey('users.id'), nullable=False, comment='用户ID') + type = Column(String(20), nullable=False, comment='分析类型(crypto, astock)') + symbol = Column(String(50), nullable=False, comment='交易符号') + timeframe = Column(String(20), nullable=True, comment='时间框架') + content = Column(Text, nullable=False, comment='分析内容') + create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') + + # 关系 + user = relationship("User", back_populates="analysis_histories") + + # 索引和表属性 + __table_args__ = ( + Index('idx_user_id', 'user_id'), + Index('idx_type', 'type'), + Index('idx_symbol', 'symbol'), + Index('idx_create_time', 'create_time'), + {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} + ) + +class AnalysisHistoryManager: + """分析历史管理类""" + + def __init__(self, db_session): + self.session = db_session + + def add_analysis_history(self, user_id: int, type: str, symbol: str, + content: str, timeframe: str = None) -> bool: + """ + 添加分析历史记录 + + Args: + user_id: 用户ID + type: 分析类型(crypto, astock) + symbol: 交易符号 + content: 分析内容 + timeframe: 时间框架(可选) + + Returns: + 添加是否成功 + """ + try: + # 创建新记录 + new_history = AnalysisHistory( + user_id=user_id, + type=type, + symbol=symbol, + timeframe=timeframe, + content=content, + create_time=datetime.now() + ) + + # 添加并提交 + self.session.add(new_history) + self.session.commit() + + logger.info(f"成功添加分析历史记录,用户ID: {user_id}, 类型: {type}, 交易符号: {symbol}") + return True + + except Exception as e: + self.session.rollback() + logger.error(f"添加分析历史记录失败: {e}") + return False + + def delete_analysis_history(self, history_id: int) -> bool: + """ + 删除分析历史记录 + + Args: + history_id: 历史记录ID + + Returns: + 删除是否成功 + """ + try: + # 查询记录 + history = self.session.query(AnalysisHistory).filter(AnalysisHistory.id == history_id).first() + + if not history: + logger.warning(f"分析历史记录ID {history_id} 不存在") + return False + + # 删除记录 + self.session.delete(history) + self.session.commit() + + logger.info(f"成功删除分析历史记录,ID: {history_id}") + return True + + except Exception as e: + self.session.rollback() + logger.error(f"删除分析历史记录失败: {e}") + return False + + def get_analysis_history_by_id(self, history_id: int) -> Optional[Dict[str, Any]]: + """ + 根据ID获取分析历史记录 + + Args: + history_id: 历史记录ID + + Returns: + 分析历史记录,如果不存在则返回None + """ + try: + # 查询记录 + history = self.session.query(AnalysisHistory).filter(AnalysisHistory.id == history_id).first() + + if history: + # 转换为字典 + return { + 'id': history.id, + 'user_id': history.user_id, + 'type': history.type, + 'symbol': history.symbol, + 'timeframe': history.timeframe, + 'content': history.content, + 'create_time': history.create_time + } + else: + return None + + except Exception as e: + logger.error(f"获取分析历史记录失败: {e}") + return None + + def get_user_analysis_history(self, user_id: int, type: str = None, + symbol: str = None, limit: int = 10, + offset: int = 0) -> List[Dict[str, Any]]: + """ + 获取用户的分析历史记录 + + Args: + user_id: 用户ID + type: 分析类型筛选(可选) + symbol: 交易符号筛选(可选) + limit: 返回记录数量限制 + offset: 分页偏移量 + + Returns: + 分析历史记录列表 + """ + try: + # 构建查询 + query = self.session.query(AnalysisHistory).filter(AnalysisHistory.user_id == user_id) + + # 添加可选过滤条件 + if type: + query = query.filter(AnalysisHistory.type == type) + if symbol: + query = query.filter(AnalysisHistory.symbol == symbol) + + # 按创建时间降序排序,并应用分页 + results = query.order_by(AnalysisHistory.create_time.desc()).limit(limit).offset(offset).all() + + # 转换为字典列表 + history_list = [] + for history in results: + history_list.append({ + 'id': history.id, + 'user_id': history.user_id, + 'type': history.type, + 'symbol': history.symbol, + 'timeframe': history.timeframe, + 'content': history.content, + 'create_time': history.create_time + }) + + return history_list + + except Exception as e: + logger.error(f"获取用户分析历史记录失败: {e}") + return [] \ No newline at end of file diff --git a/cryptoai/models/analysis_result.py b/cryptoai/models/analysis_result.py new file mode 100644 index 0000000..b522e85 --- /dev/null +++ b/cryptoai/models/analysis_result.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Dict, Any, List, Optional +from datetime import datetime + +from sqlalchemy import Column, Integer, String, DateTime, Index +from sqlalchemy.dialects.mysql import JSON + +from cryptoai.models.base import Base, logger + +# 定义分析结果模型 +class AnalysisResult(Base): + """分析结果表模型""" + __tablename__ = 'analysis_results' + + id = Column(Integer, primary_key=True, autoincrement=True) + agent = Column(String(50), nullable=False, comment='智能体类型(crypto, gold)') + symbol = Column(String(50), nullable=False, comment='交易对符号') + time_interval = Column(String(20), nullable=False, comment='时间间隔') + completion_result = Column(JSON, nullable=False, comment='分析结果JSON') + created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') + updated_at = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间') + + # 索引 + __table_args__ = ( + Index('idx_agent', 'agent'), + Index('idx_symbol', 'symbol'), + Index('idx_time_interval', 'time_interval'), + Index('idx_created_at', 'created_at'), + ) + +class AnalysisResultManager: + """分析结果管理类""" + + def __init__(self, db_session): + self.session = db_session + + def save_analysis_result(self, agent: str, symbol: str, time_interval: str, + analysis_result: Dict[str, Any]) -> bool: + """ + 保存分析结果到数据库 + + Args: + agent: 智能体类型,例如 'crypto' 或 'gold' + symbol: 交易对符号,例如 'BTCUSDT' + time_interval: 时间间隔,例如 '1h', '4h', '1d' + analysis_result: 分析结果数据 + + Returns: + 保存是否成功 + """ + try: + # 创建新记录 + new_result = AnalysisResult( + agent=agent, + symbol=symbol, + time_interval=time_interval, + completion_result=analysis_result, + created_at=datetime.now(), + updated_at=datetime.now() + ) + + # 添加并提交 + self.session.add(new_result) + self.session.commit() + + logger.info(f"成功保存 {agent} 分析结果,交易对: {symbol}, 时间间隔: {time_interval}") + return True + + except Exception as e: + self.session.rollback() + logger.error(f"保存分析结果失败: {e}") + return False + + def get_latest_result(self, agent: str, symbol: str, time_interval: str) -> Optional[Dict[str, Any]]: + """ + 获取最新的分析结果 + + Args: + agent: 智能体类型,例如 'crypto' 或 'gold' + symbol: 交易对符号,例如 'BTCUSDT' + time_interval: 时间间隔,例如 '1h', '4h', '1d' + + Returns: + 最新分析结果,如果查询失败则返回None + """ + try: + # 查询最新的结果 + result = self.session.query(AnalysisResult).filter( + AnalysisResult.agent == agent, + AnalysisResult.symbol == symbol, + AnalysisResult.time_interval == time_interval + ).order_by(AnalysisResult.created_at.desc()).first() + + if result: + # 转换为字典 + return { + 'id': result.id, + 'agent': result.agent, + 'symbol': result.symbol, + 'time_interval': result.time_interval, + 'completion_result': result.completion_result, + 'created_at': result.created_at + } + else: + return None + + except Exception as e: + logger.error(f"获取最新分析结果失败: {e}") + return None \ No newline at end of file diff --git a/cryptoai/models/astock.py b/cryptoai/models/astock.py new file mode 100644 index 0000000..90e179e --- /dev/null +++ b/cryptoai/models/astock.py @@ -0,0 +1,314 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Dict, Any, List, Optional, Union +from datetime import datetime + +from sqlalchemy import Column, Integer, String, DateTime, Index + +from cryptoai.models.base import Base, logger +from cryptoai.utils.db_utils import convert_timestamp_to_datetime + +# 定义 A 股数据模型 +class AStock(Base): + """A股股票基本信息表模型""" + __tablename__ = 'astock' + + stock_code = Column(String(10), primary_key=True, comment='股票代码') + short_name = Column(String(50), nullable=False, comment='股票简称') + exchange = Column(String(20), nullable=True, comment='交易所') + list_date = Column(DateTime, nullable=True, comment='上市日期') + created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') + + # 索引 + __table_args__ = ( + Index('idx_stock_code', 'stock_code', unique=True), + {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} + ) + +class AStockManager: + """A股管理类""" + + def __init__(self, db_session): + self.session = db_session + + def create_stocks(self, stocks: List[Dict[str, Any]]) -> bool: + """ + 批量创建股票信息 + + Args: + stocks: 股票信息列表,每个元素为包含以下键的字典: + - stock_code: 股票代码 + - short_name: 股票简称 + - exchange: 交易所(可选) + - list_date: 上市日期(可选,Unix时间戳,毫秒级) + + Returns: + 创建是否成功 + """ + try: + for stock_data in stocks: + # 检查股票代码是否已存在 + existing_stock = self.session.query(AStock).filter( + AStock.stock_code == stock_data['stock_code'] + ).first() + + if existing_stock: + logger.warning(f"股票代码 {stock_data['stock_code']} 已存在,跳过") + continue + + # 创建新股票记录 + new_stock = AStock( + stock_code=stock_data['stock_code'], + short_name=stock_data['short_name'], + exchange=stock_data.get('exchange') + ) + + # 处理上市日期 + if 'list_date' in stock_data and stock_data['list_date']: + list_date = convert_timestamp_to_datetime(stock_data['list_date']) + if list_date: + new_stock.list_date = list_date + + self.session.add(new_stock) + + # 批量提交 + self.session.commit() + logger.info(f"成功批量创建股票信息,共 {len(stocks)} 条记录") + return True + + except Exception as e: + self.session.rollback() + logger.error(f"批量创建股票信息失败: {e}") + return False + + def create_stock(self, stock_code: str, short_name: str, + exchange: Optional[str] = None, + list_date: Optional[Union[int, float, str]] = None) -> bool: + """ + 创建新的股票信息 + + Args: + stock_code: 股票代码 + short_name: 股票简称 + exchange: 交易所(可选) + list_date: 上市日期(可选,Unix时间戳,毫秒级) + + Returns: + 创建是否成功 + """ + try: + # 检查股票代码是否已存在 + existing_stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first() + if existing_stock: + logger.warning(f"股票代码 {stock_code} 已存在") + return False + + # 创建新股票记录 + new_stock = AStock( + stock_code=stock_code, + short_name=short_name, + exchange=exchange + ) + + # 处理上市日期 + if list_date: + converted_date = convert_timestamp_to_datetime(list_date) + if converted_date: + new_stock.list_date = converted_date + + self.session.add(new_stock) + self.session.commit() + + logger.info(f"成功创建股票信息: {stock_code} - {short_name}") + return True + + except Exception as e: + self.session.rollback() + logger.error(f"创建股票信息失败: {e}") + return False + + def update_stock(self, stock_code: str, short_name: Optional[str] = None, + exchange: Optional[str] = None, + list_date: Optional[Union[int, float, str, datetime]] = None) -> bool: + """ + 更新股票信息 + + Args: + stock_code: 股票代码 + short_name: 股票简称 + exchange: 交易所 + list_date: 上市日期(可以是datetime对象或时间戳) + + Returns: + 更新是否成功 + """ + try: + # 查询股票 + stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first() + if not stock: + logger.warning(f"股票代码 {stock_code} 不存在") + return False + + # 更新字段 + if short_name is not None: + stock.short_name = short_name + if exchange is not None: + stock.exchange = exchange + if list_date is not None: + if isinstance(list_date, datetime): + stock.list_date = list_date + else: + converted_date = convert_timestamp_to_datetime(list_date) + if converted_date: + stock.list_date = converted_date + + self.session.commit() + + logger.info(f"成功更新股票信息: {stock_code}") + return True + + except Exception as e: + self.session.rollback() + logger.error(f"更新股票信息失败: {e}") + return False + + def delete_stock(self, stock_code: str) -> bool: + """ + 删除股票信息 + + Args: + stock_code: 股票代码 + + Returns: + 删除是否成功 + """ + try: + # 查询股票 + stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first() + if not stock: + logger.warning(f"股票代码 {stock_code} 不存在") + return False + + # 删除股票 + self.session.delete(stock) + self.session.commit() + + logger.info(f"成功删除股票信息: {stock_code}") + return True + + except Exception as e: + self.session.rollback() + logger.error(f"删除股票信息失败: {e}") + return False + + def get_stock_by_code(self, stock_code: str) -> Optional[Dict[str, Any]]: + """ + 通过股票代码获取股票信息 + + Args: + stock_code: 股票代码 + + Returns: + 股票信息,如果不存在则返回None + """ + try: + # 查询股票 + stock = self.session.query(AStock).filter(AStock.stock_code == stock_code).first() + + if stock: + return { + 'stock_code': stock.stock_code, + 'short_name': stock.short_name, + 'exchange': stock.exchange, + 'list_date': stock.list_date, + 'created_at': stock.created_at + } + else: + return None + + except Exception as e: + logger.error(f"获取股票信息失败: {e}") + return None + + def search_stock(self, key: str, limit: int = 10) -> List[Dict[str, Any]]: + """ + 搜索股票 + + Args: + key: 搜索关键词 + limit: 最大返回数量 + + Returns: + 股票信息列表 + """ + try: + # 查询股票 + stocks = self.session.query(AStock).filter( + AStock.short_name.like(f"{key}%") | AStock.stock_code.like(f"{key}%") + ).limit(limit).all() + + return [{ + 'stock_code': stock.stock_code, + 'short_name': stock.short_name, + 'exchange': stock.exchange, + 'list_date': stock.list_date, + 'created_at': stock.created_at + } for stock in stocks] + + except Exception as e: + logger.error(f"获取股票信息失败: {e}") + return [] + + def get_stock_by_name(self, short_name: str) -> List[Dict[str, Any]]: + """ + 通过股票简称获取股票信息(可能有多个结果) + + Args: + short_name: 股票简称 + + Returns: + 股票信息列表 + """ + try: + # 查询股票 + stocks = self.session.query(AStock).filter(AStock.short_name == short_name).all() + + return [{ + 'stock_code': stock.stock_code, + 'short_name': stock.short_name, + 'exchange': stock.exchange, + 'list_date': stock.list_date, + 'created_at': stock.created_at + } for stock in stocks] + + except Exception as e: + logger.error(f"获取股票信息失败: {e}") + return [] + + def list_stocks(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]: + """ + 获取股票列表 + + Args: + limit: 返回的最大数量 + skip: 跳过的数量 + + Returns: + 股票信息列表 + """ + try: + # 查询股票列表 + stocks = self.session.query(AStock).order_by(AStock.stock_code).offset(skip).limit(limit).all() + + return [{ + 'stock_code': stock.stock_code, + 'short_name': stock.short_name, + 'exchange': stock.exchange, + 'list_date': stock.list_date, + 'created_at': stock.created_at + } for stock in stocks] + + except Exception as e: + logger.error(f"获取股票列表失败: {e}") + return [] \ No newline at end of file diff --git a/cryptoai/models/base.py b/cryptoai/models/base.py new file mode 100644 index 0000000..2cd33a1 --- /dev/null +++ b/cryptoai/models/base.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging +from datetime import datetime + +from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, Index, text, ForeignKey +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker, relationship +from sqlalchemy.dialects.mysql import JSON + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger('db_models') + +# 创建模型基类 +Base = declarative_base() \ No newline at end of file diff --git a/cryptoai/models/token.py b/cryptoai/models/token.py new file mode 100644 index 0000000..45df755 --- /dev/null +++ b/cryptoai/models/token.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Dict, Any, List, Optional +from datetime import datetime + +from sqlalchemy import Column, Integer, String, DateTime, Index + +from cryptoai.models.base import Base, logger + +# 定义Token模型 +class Token(Base): + """Token信息表模型""" + __tablename__ = 'tokens' + + id = Column(Integer, primary_key=True, autoincrement=True) + symbol = Column(String(50), nullable=False, unique=True, comment='交易对符号') + base_asset = Column(String(20), nullable=False, comment='基础资产') + quote_asset = Column(String(20), nullable=False, comment='计价资产') + created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') + + # 索引 + __table_args__ = ( + Index('idx_symbol', 'symbol'), + Index('idx_base_asset', 'base_asset'), + Index('idx_quote_asset', 'quote_asset'), + {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} + ) + +class TokenManager: + """Token管理类""" + + def __init__(self, db_session): + self.session = db_session + + def create_token(self, symbol: str, base_asset: str, quote_asset: str) -> bool: + """ + 创建新的Token信息 + + Args: + symbol: 交易对符号 + base_asset: 基础资产 + quote_asset: 计价资产 + + Returns: + 创建是否成功 + """ + try: + # 检查交易对是否已存在 + existing_token = self.session.query(Token).filter(Token.symbol == symbol).first() + if existing_token: + logger.warning(f"交易对 {symbol} 已存在") + return False + + # 创建新Token记录 + new_token = Token( + symbol=symbol, + base_asset=base_asset, + quote_asset=quote_asset + ) + + self.session.add(new_token) + self.session.commit() + + logger.info(f"成功创建Token信息: {symbol}") + return True + + except Exception as e: + self.session.rollback() + logger.error(f"创建Token信息失败: {e}") + return False + + def delete_token(self, symbol: str) -> bool: + """ + 删除Token信息 + + Args: + symbol: 交易对符号 + + Returns: + 删除是否成功 + """ + try: + # 查询Token + token = self.session.query(Token).filter(Token.symbol == symbol).first() + if not token: + logger.warning(f"交易对 {symbol} 不存在") + return False + + # 删除Token + self.session.delete(token) + self.session.commit() + + logger.info(f"成功删除Token信息: {symbol}") + return True + + except Exception as e: + self.session.rollback() + logger.error(f"删除Token信息失败: {e}") + return False + + def search_token(self, key: str, limit: int = 10) -> List[Dict[str, Any]]: + """ + 搜索Token + + Args: + key: 搜索关键词 + limit: 最大返回数量 + + Returns: + Token信息列表 + """ + try: + # 使用 SQLAlchemy 的 ORM 查询 + tokens = self.session.query(Token).filter( + Token.symbol.like(f"{key}%") | Token.base_asset.like(f"{key}%") + ).limit(limit).all() + + return [{ + 'symbol': token.symbol, + 'base_asset': token.base_asset, + 'quote_asset': token.quote_asset + } for token in tokens] + + except Exception as e: + logger.error(f"搜索Token信息失败: {e}") + return [] \ No newline at end of file diff --git a/cryptoai/models/user.py b/cryptoai/models/user.py new file mode 100644 index 0000000..e03ad85 --- /dev/null +++ b/cryptoai/models/user.py @@ -0,0 +1,280 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Dict, Any, List, Optional +from datetime import datetime + +from sqlalchemy import Column, Integer, String, DateTime, Index +from sqlalchemy.orm import relationship + +from cryptoai.models.base import Base, logger + +# 定义用户数据模型 +class User(Base): + """用户数据表模型""" + __tablename__ = 'users' + + id = Column(Integer, primary_key=True, autoincrement=True) + mail = Column(String(100), nullable=False, unique=True, comment='邮箱') + nickname = Column(String(50), nullable=False, comment='昵称') + password = Column(String(100), nullable=False, comment='密码') + level = Column(Integer, nullable=False, default=0, comment='用户级别(0=普通用户,1=VIP,2=SVIP)') + points = Column(Integer, nullable=False, default=0, comment='用户积分') + create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') + + # 关系 + questions = relationship("UserQuestion", back_populates="user") + analysis_histories = relationship("AnalysisHistory", back_populates="user") + + # 索引和表属性 + __table_args__ = ( + Index('idx_mail', 'mail'), + Index('idx_level', 'level'), + Index('idx_create_time', 'create_time'), + {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} + ) + +class UserManager: + """用户管理类""" + + def __init__(self, db_session): + self.session = db_session + + def register_user(self, mail: str, nickname: str, password: str, level: int = 0, points: int = 0) -> bool: + """ + 注册新用户 + + Args: + mail: 邮箱 + nickname: 昵称 + password: 密码 + level: 用户级别,默认为0(普通用户) + points: 初始积分,默认为0 + + Returns: + 注册是否成功 + """ + try: + # 检查邮箱是否已存在 + existing_user = self.session.query(User).filter(User.mail == mail).first() + if existing_user: + logger.warning(f"邮箱 {mail} 已被注册") + return False + + # 创建新用户 + new_user = User( + mail=mail, + nickname=nickname, + password=password, # 实际应用中应该对密码进行哈希处理 + level=level, + points=points, + create_time=datetime.now() + ) + + # 添加并提交 + self.session.add(new_user) + self.session.commit() + + logger.info(f"成功注册用户: {mail},初始积分: {points}") + return True + + except Exception as e: + self.session.rollback() + logger.error(f"注册用户失败: {e}") + return False + + def get_user_count(self) -> int: + """ + 获取用户数量 + """ + try: + # 查询用户数量 + user_count = self.session.query(User).count() + + return user_count + except Exception as e: + logger.error(f"获取用户数量失败: {e}") + return 0 + + def get_user_by_id(self, user_id: int) -> Optional[Dict[str, Any]]: + """ + 通过ID获取用户信息 + + Args: + user_id: 用户ID + + Returns: + 用户信息,如果用户不存在则返回None + """ + try: + # 查询用户 + user = self.session.query(User).filter(User.id == user_id).first() + + if user: + # 转换为字典 + return { + 'id': user.id, + 'mail': user.mail, + 'nickname': user.nickname, + 'level': user.level, + 'points': user.points, + 'create_time': user.create_time + } + else: + return None + + except Exception as e: + logger.error(f"获取用户信息失败: {e}") + return None + + def login(self, mail: str, password: str) -> Optional[Dict[str, Any]]: + """ + 登录 + """ + + user = self.session.query(User).filter(User.mail == mail).first() + + if not user: + return None + + if user.password != password: + return None + + return {'id': user.id, 'mail': user.mail, 'nickname': user.nickname, 'level': user.level, 'points': user.points, 'create_time': user.create_time} + + def get_user_by_mail(self, mail: str) -> Optional[Dict[str, Any]]: + """ + 通过邮箱获取用户信息 + + Args: + mail: 邮箱 + + Returns: + 用户信息,如果用户不存在则返回None + """ + try: + # 查询用户 + user = self.session.query(User).filter(User.mail == mail).first() + + if user: + # 转换为字典 + return { + 'id': user.id, + 'mail': user.mail, + 'nickname': user.nickname, + 'level': user.level, + 'points': user.points, + 'create_time': user.create_time + } + else: + return None + + except Exception as e: + logger.error(f"获取用户信息失败: {e}") + return None + + def update_user_level(self, user_id: int, level: int) -> bool: + """ + 更新用户级别 + + Args: + user_id: 用户ID + level: 新的用户级别 + + Returns: + 更新是否成功 + """ + try: + # 查询用户 + user = self.session.query(User).filter(User.id == user_id).first() + + if not user: + logger.warning(f"用户ID {user_id} 不存在") + return False + + # 更新级别 + user.level = level + self.session.commit() + + logger.info(f"成功更新用户 {user.mail} 的级别为 {level}") + return True + + except Exception as e: + self.session.rollback() + logger.error(f"更新用户级别失败: {e}") + return False + + def add_user_points(self, user_id: int, points: int) -> bool: + """ + 为用户增加积分 + + Args: + user_id: 用户ID + points: 增加的积分数量(正数) + + Returns: + 操作是否成功 + """ + if points <= 0: + logger.warning(f"增加的积分必须是正数: {points}") + return False + + try: + # 查询用户 + user = self.session.query(User).filter(User.id == user_id).first() + + if not user: + logger.warning(f"用户ID {user_id} 不存在") + return False + + # 增加积分 + user.points += points + self.session.commit() + + logger.info(f"成功为用户 {user.mail} 增加 {points} 积分,当前积分: {user.points}") + return True + + except Exception as e: + self.session.rollback() + logger.error(f"增加用户积分失败: {e}") + return False + + def consume_user_points(self, user_id: int, points: int) -> bool: + """ + 用户消费积分 + + Args: + user_id: 用户ID + points: 消费的积分数量(正数) + + Returns: + 操作是否成功 + """ + if points <= 0: + logger.warning(f"消费的积分必须是正数: {points}") + return False + + try: + # 查询用户 + user = self.session.query(User).filter(User.id == user_id).first() + + if not user: + logger.warning(f"用户ID {user_id} 不存在") + return False + + # 检查积分是否足够 + if user.points < points: + logger.warning(f"用户 {user.mail} 积分不足,当前积分: {user.points},需要消费: {points}") + return False + + # 消费积分 + user.points -= points + self.session.commit() + + logger.info(f"成功从用户 {user.mail} 消费 {points} 积分,剩余积分: {user.points}") + return True + + except Exception as e: + self.session.rollback() + logger.error(f"消费用户积分失败: {e}") + return False \ No newline at end of file diff --git a/cryptoai/models/user_question.py b/cryptoai/models/user_question.py new file mode 100644 index 0000000..02c6b12 --- /dev/null +++ b/cryptoai/models/user_question.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from typing import Dict, Any, List, Optional +from datetime import datetime + +from sqlalchemy import Column, Integer, String, Text, DateTime, Index, ForeignKey +from sqlalchemy.orm import relationship + +from cryptoai.models.base import Base, logger +from cryptoai.models.user import User + +# 定义用户提问数据模型 +class UserQuestion(Base): + """用户提问数据表模型""" + __tablename__ = 'user_questions' + + id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column(Integer, ForeignKey('users.id'), nullable=False, comment='用户ID') + agent_id = Column(String(50), nullable=False, comment='AI Agent ID') + question = Column(Text, nullable=False, comment='提问内容') + create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') + + # 关系 + user = relationship("User", back_populates="questions") + + # 索引和表属性 + __table_args__ = ( + Index('idx_user_id', 'user_id'), + Index('idx_agent_id', 'agent_id'), + Index('idx_create_time', 'create_time'), + {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} + ) + +class UserQuestionManager: + """用户提问管理类""" + + def __init__(self, db_session): + self.session = db_session + + def save_user_question(self, user_id: int, agent_id: str, question: str) -> bool: + """ + 保存用户提问数据 + + Args: + user_id: 用户ID + agent_id: AI Agent ID + question: 提问内容 + + Returns: + 保存是否成功 + """ + try: + # 创建新记录 + new_question = UserQuestion( + user_id=user_id, + agent_id=agent_id, + question=question, + create_time=datetime.now() + ) + + # 添加并提交 + self.session.add(new_question) + self.session.commit() + + logger.info(f"成功保存用户 {user_id} 对 Agent {agent_id} 的提问") + return True + + except Exception as e: + self.session.rollback() + logger.error(f"保存用户提问失败: {e}") + return False + + def get_user_question_count(self) -> int: + """ + 获取用户提问数量 + + Returns: + 提问总数 + """ + try: + # 查询用户提问数量 + question_count = self.session.query(UserQuestion).count() + + return question_count + except Exception as e: + logger.error(f"获取用户提问数量失败: {e}") + return 0 + + def get_user_questions(self, user_id: Optional[int] = None, agent_id: Optional[str] = None, + limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]: + """ + 获取用户提问数据 + + Args: + user_id: 可选,指定获取特定用户的提问 + agent_id: 可选,指定获取特定Agent的提问 + limit: 返回的最大记录数,默认20条 + skip: 跳过的记录数,默认0条 + + Returns: + 提问数据列表,如果查询失败则返回空列表 + """ + try: + # 构建查询 + query = self.session.query(UserQuestion) + + # 如果指定了user_id,则筛选 + if user_id: + query = query.filter(UserQuestion.user_id == user_id) + + # 如果指定了agent_id,则筛选 + if agent_id: + query = query.filter(UserQuestion.agent_id == agent_id) + + # 按创建时间降序排序并限制数量 + results = query.order_by(UserQuestion.create_time.desc()).offset(skip).limit(limit).all() + + # 转换为字典列表 + questions = [] + for result in results: + questions.append({ + 'id': result.id, + 'user_id': result.user_id, + 'agent_id': result.agent_id, + 'question': result.question, + 'create_time': result.create_time + }) + + return questions + + except Exception as e: + logger.error(f"获取用户提问失败: {e}") + return [] + + def get_user_question_by_id(self, question_id: int) -> Optional[Dict[str, Any]]: + """ + 通过ID获取用户提问数据 + + Args: + question_id: 提问ID + + Returns: + 提问数据,如果不存在则返回None + """ + try: + # 查询提问 + result = self.session.query(UserQuestion).filter(UserQuestion.id == question_id).first() + + if result: + # 转换为字典 + return { + 'id': result.id, + 'user_id': result.user_id, + 'agent_id': result.agent_id, + 'question': result.question, + 'create_time': result.create_time + } + else: + return None + + except Exception as e: + logger.error(f"获取用户提问失败: {e}") + return None \ No newline at end of file diff --git a/cryptoai/routes/adata.py b/cryptoai/routes/adata.py index 61cd48d..4cba3b2 100644 --- a/cryptoai/routes/adata.py +++ b/cryptoai/routes/adata.py @@ -22,7 +22,7 @@ logger.setLevel(logging.DEBUG) @router.get("/stock/search") async def search_stock(key: str, limit: int = 10): manager = get_db_manager() - result = manager.search_stock(key, limit) + result = manager.astock_manager.search_stock(key, limit) return result @@ -123,16 +123,9 @@ async def get_stock_data_all(stock_code: str): return result - @router.post('/{stock_code}/analysis', summary="获取股票分析数据") async def get_stock_analysis(stock_code: str, current_user: Dict[str, Any] = Depends(get_current_user)): - # 检查stock_code是否存在 - # codes = get_db_manager().search_stock(stock_code) - # if not codes or len(codes) == 0: - # raise HTTPException(status_code=400, detail="您输入的股票代码不存在,请检查后重新输入。") - - # stock_code = codes[0]["stock_code"] url = 'https://mate.aimateplus.com/v1/workflows/run' token = 'app-nWuCOa0YfQVtAosTY3Jr5vFV' @@ -150,7 +143,7 @@ async def get_stock_analysis(stock_code: str, current_user: Dict[str, Any] = Dep } # 保存用户提问 - get_db_manager().save_user_question(current_user["id"], stock_code, "请分析以下股票:" + stock_code + ",并给出分析报告。") + get_db_manager().user_question_manager.save_user_question(current_user["id"], stock_code, "请分析以下股票:" + stock_code + ",并给出分析报告。") response = requests.post(url, headers=headers, json=data, stream=True) diff --git a/cryptoai/routes/agent.py b/cryptoai/routes/agent.py index 3e2b72c..451d704 100644 --- a/cryptoai/routes/agent.py +++ b/cryptoai/routes/agent.py @@ -49,7 +49,7 @@ async def create_agent(agent: AgentCreate): 创建新的AI Agent """ - agent = get_db_manager().create_agent(agent.name, agent.description, agent.hello_prompt, agent.dify_token, agent.inputs) + agent = get_db_manager().agent_manager.create_agent(agent.name, agent.description, agent.hello_prompt, agent.dify_token, agent.inputs) return agent @@ -64,7 +64,7 @@ async def get_agents( """ # 从数据库获取Agent列表 - agents = get_db_manager().list_agents(limit=limit, skip=skip) + agents = get_db_manager().agent_manager.list_agents(limit=limit, skip=skip) return agents @@ -76,7 +76,7 @@ async def chat(request: ChatRequest, current_user: Dict[str, Any] = Depends(get_ # 尝试从数据库获取Agent try: agent_id = int(request.agent_id) - agent = get_db_manager().get_agent_by_id(agent_id) + agent = get_db_manager().agent_manager.get_agent_by_id(agent_id) if not agent: raise HTTPException(status_code=400, detail="Invalid agent ID") @@ -106,7 +106,7 @@ async def chat(request: ChatRequest, current_user: Dict[str, Any] = Depends(get_ logging.info(f"Chat request data: {data}") # 保存用户提问 - get_db_manager().save_user_question(current_user["id"], request.agent_id, request.user_prompt) + get_db_manager().user_question_manager.save_user_question(current_user["id"], request.agent_id, request.user_prompt) response = requests.post(url, headers=headers, json=data, stream=True) diff --git a/cryptoai/routes/analysis.py b/cryptoai/routes/analysis.py new file mode 100644 index 0000000..748cd05 --- /dev/null +++ b/cryptoai/routes/analysis.py @@ -0,0 +1,30 @@ +from fastapi import APIRouter +from typing import Optional +from cryptoai.utils.db_manager import get_db_manager +from fastapi import Depends +from pydantic import BaseModel +from cryptoai.routes.user import get_current_user + + +class AnalysisHistoryRequest(BaseModel): + symbol: str + content: str + timeframe: Optional[str] = None + type: str + +router = APIRouter() + +@router.post("/analysis_history") +async def analysis_history(request: AnalysisHistoryRequest, + current_user: dict = Depends(get_current_user)): + + get_db_manager().analysis_history_manager.add_analysis_history(current_user["id"], request.type, request.symbol, request.content, request.timeframe) + + return {"message": "ok"} + +@router.get("/analysis_histories") +async def get_analysis_histories(current_user: dict = Depends(get_current_user), + limit: int = 10, + offset: int = 0): + history = get_db_manager().analysis_history_manager.get_user_analysis_history(current_user["id"], limit=limit, offset=offset) + return history \ No newline at end of file diff --git a/cryptoai/routes/crypto.py b/cryptoai/routes/crypto.py index 22437ee..53d9e2d 100644 --- a/cryptoai/routes/crypto.py +++ b/cryptoai/routes/crypto.py @@ -34,7 +34,7 @@ class CryptoAnalysisRequest(BaseModel): @router.get("/kline/{symbol}") async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit: Optional[int] = 100): # 检查symbol是否存在 - tokens = get_db_manager().search_token(symbol) + tokens = get_db_manager().token_manager.search_token(symbol) if not tokens or len(tokens) == 0: raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。") @@ -57,7 +57,7 @@ async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit: async def analysis_crypto_v2(request: CryptoAnalysisRequest, current_user: dict = Depends(get_current_user)): # 检查symbol是否存在 - tokens = get_db_manager().search_token(request.symbol) + tokens = get_db_manager().token_manager.search_token(request.symbol) if not tokens or len(tokens) == 0: raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。") @@ -80,7 +80,7 @@ async def analysis_crypto_v2(request: CryptoAnalysisRequest, } # 保存用户提问 - get_db_manager().save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。") + get_db_manager().user_question_manager.save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。") response = requests.post(url, headers=headers, json=data, stream=True) @@ -111,7 +111,7 @@ async def analysis_crypto(request: CryptoAnalysisRequest, else: user_prompt = f"请分析以下加密货币:{request.symbol},并给出分析报告。" - agent = get_db_manager().get_agent_by_id(agent_id) + agent = get_db_manager().agent_manager.get_agent_by_id(agent_id) if not agent: raise HTTPException(status_code=400, detail="Invalid agent ID") @@ -138,7 +138,7 @@ async def analysis_crypto(request: CryptoAnalysisRequest, logging.info(f"Chat request data: {data}") # 保存用户提问 - get_db_manager().save_user_question(current_user["id"], agent_id, user_prompt) + get_db_manager().user_question_manager.save_user_question(current_user["id"], agent_id, user_prompt) response = requests.post(url, headers=headers, json=data, stream=True) diff --git a/cryptoai/routes/fastapi_app.py b/cryptoai/routes/fastapi_app.py index e653e72..2e93107 100644 --- a/cryptoai/routes/fastapi_app.py +++ b/cryptoai/routes/fastapi_app.py @@ -22,6 +22,7 @@ from cryptoai.routes.adata import router as adata_router from cryptoai.routes.question import router as question_router from cryptoai.routes.crypto import router as crypto_router from cryptoai.routes.platform import router as platform_router +from cryptoai.routes.analysis import router as analysis_router # 配置日志 logging.basicConfig( @@ -58,6 +59,8 @@ app.include_router(user_router, prefix="/user", tags=["用户管理"]) app.include_router(question_router, prefix="/question", tags=["用户提问"]) app.include_router(adata_router, prefix="/adata", tags=["A股数据"]) app.include_router(crypto_router, prefix="/crypto", tags=["加密货币数据"]) +app.include_router(analysis_router, prefix="/analysis", tags=["分析历史"]) + # 请求计时中间件 @app.middleware("http") async def add_process_time_header(request: Request, call_next): diff --git a/cryptoai/routes/feed.py b/cryptoai/routes/feed.py index 0754e19..e5c6793 100644 --- a/cryptoai/routes/feed.py +++ b/cryptoai/routes/feed.py @@ -52,7 +52,7 @@ async def create_feed(feed: AgentFeedCreate) -> Dict[str, Any]: db_manager = get_db_manager() # 保存信息流 - success = db_manager.save_agent_feed( + success = db_manager.agent_feed_manager.save_agent_feed( agent_name=feed.agent_name, content=feed.content, avatar_url=feed.avatar_url @@ -98,7 +98,7 @@ async def get_feeds( db_manager = get_db_manager() # 获取信息流 - feeds = db_manager.get_agent_feeds(agent_name=agent_name, limit=limit, skip=skip) + feeds = db_manager.agent_feed_manager.get_agent_feeds(agent_name=agent_name, limit=limit, skip=skip) return feeds diff --git a/cryptoai/routes/platform.py b/cryptoai/routes/platform.py index a2e3c0c..516648b 100644 --- a/cryptoai/routes/platform.py +++ b/cryptoai/routes/platform.py @@ -14,8 +14,8 @@ async def get_platform_info(): result = {} try: - result["user_count"] = db_manager.get_user_count() - result["question_count"] = db_manager.get_user_question_count() + result["user_count"] = db_manager.user_manager.get_user_count() + result["question_count"] = db_manager.user_question_manager.get_user_question_count() return result except Exception as e: diff --git a/cryptoai/routes/question.py b/cryptoai/routes/question.py index 1a0be14..bbcddde 100644 --- a/cryptoai/routes/question.py +++ b/cryptoai/routes/question.py @@ -55,7 +55,7 @@ async def create_question( db_manager = get_db_manager() # 保存提问 - success = db_manager.save_user_question( + success = db_manager.user_question_manager.save_user_question( user_id=current_user["id"], agent_id=question.agent_id, question=question.question @@ -103,7 +103,7 @@ async def get_questions( db_manager = get_db_manager() # 获取提问记录 - questions = db_manager.get_user_questions( + questions = db_manager.user_question_manager.get_user_questions( user_id=current_user["id"], agent_id=agent_id, limit=limit, @@ -139,7 +139,7 @@ async def get_question( db_manager = get_db_manager() # 获取提问记录 - question = db_manager.get_user_question_by_id(question_id) + question = db_manager.user_question_manager.get_user_question_by_id(question_id) if not question: raise HTTPException( diff --git a/cryptoai/routes/user.py b/cryptoai/routes/user.py index 9ace636..688b37a 100644 --- a/cryptoai/routes/user.py +++ b/cryptoai/routes/user.py @@ -104,7 +104,7 @@ async def get_current_user(request: Request) -> Dict[str, Any]: raise credentials_exception db_manager = get_db_manager() - user = db_manager.get_user_by_mail(mail) + user = db_manager.user_manager.get_user_by_mail(mail) if user is None: raise credentials_exception return user @@ -125,7 +125,7 @@ async def send_verification_code(request: SendVerificationCodeRequest) -> Dict[s db_manager = get_db_manager() # 检查邮箱是否已被注册 - user = db_manager.get_user_by_mail(request.mail) + user = db_manager.user_manager.get_user_by_mail(request.mail) if user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -231,7 +231,7 @@ async def login(loginData: UserLogin) -> TokenResponse: db_manager = get_db_manager() # 获取用户信息 - user = db_manager.get_user_by_mail(loginData.mail) + user = db_manager.user_manager.get_user_by_mail(loginData.mail) if not user: raise HTTPException( @@ -246,9 +246,8 @@ async def login(loginData: UserLogin) -> TokenResponse: # 查询用户的密码哈希 session = db_manager.Session() try: - from cryptoai.utils.db_manager import User - db_user = session.query(User).filter(User.mail == loginData.mail).first() - if not db_user or db_user.password != hashed_password: + user = db_manager.user_manager.login(loginData.mail, hashed_password) + if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="邮箱或密码错误", @@ -334,7 +333,7 @@ async def update_user_level( db_manager = get_db_manager() # 更新用户级别 - success = db_manager.update_user_level(user_id, level) + success = db_manager.user_manager.update_user_level(user_id, level) if not success: raise HTTPException( @@ -373,7 +372,7 @@ async def get_user_points( db_manager = get_db_manager() # 获取用户信息 - user = db_manager.get_user_by_id(user_id) + user = db_manager.user_manager.get_user_by_id(user_id) if not user: raise HTTPException( @@ -416,7 +415,7 @@ async def add_user_points( db_manager = get_db_manager() # 添加积分 - success = db_manager.add_user_points(user_id, points) + success = db_manager.user_manager.add_user_points(user_id, points) if not success: raise HTTPException( @@ -425,7 +424,7 @@ async def add_user_points( ) # 获取更新后的用户信息 - user = db_manager.get_user_by_id(user_id) + user = db_manager.user_manager.get_user_by_id(user_id) return { "status": "success", @@ -461,7 +460,7 @@ async def consume_user_points( db_manager = get_db_manager() # 消费积分 - success = db_manager.consume_user_points(user_id, points) + success = db_manager.user_manager.consume_user_points(user_id, points) if not success: raise HTTPException( @@ -470,7 +469,7 @@ async def consume_user_points( ) # 获取更新后的用户信息 - user = db_manager.get_user_by_id(user_id) + user = db_manager.user_manager.get_user_by_id(user_id) return { "status": "success", diff --git a/cryptoai/utils/db_manager.py b/cryptoai/utils/db_manager.py index 8e3c156..0abff47 100644 --- a/cryptoai/utils/db_manager.py +++ b/cryptoai/utils/db_manager.py @@ -2,19 +2,24 @@ # -*- coding: utf-8 -*- import os -import json import logging from typing import Dict, Any, List, Optional, Union from datetime import datetime -from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, Index, text, ForeignKey -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker, relationship -from sqlalchemy.dialects.mysql import JSON +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker from sqlalchemy.pool import QueuePool -from sqlalchemy import or_ from cryptoai.utils.config_loader import ConfigLoader +from cryptoai.models.base import Base +from cryptoai.models.token import TokenManager +from cryptoai.models.analysis_result import AnalysisResultManager +from cryptoai.models.agent_feed import AgentFeedManager +from cryptoai.models.user import UserManager +from cryptoai.models.user_question import UserQuestionManager +from cryptoai.models.agent import AgentManager +from cryptoai.models.astock import AStockManager +from cryptoai.models.analysis_history import AnalysisHistoryManager # 配置日志 logging.basicConfig( @@ -23,180 +28,15 @@ logging.basicConfig( ) logger = logging.getLogger('db_manager') -# 创建模型基类 -Base = declarative_base() - -# 定义Token模型 -class Token(Base): - """Token信息表模型""" - __tablename__ = 'tokens' - - id = Column(Integer, primary_key=True, autoincrement=True) - symbol = Column(String(50), nullable=False, unique=True, comment='交易对符号') - base_asset = Column(String(20), nullable=False, comment='基础资产') - quote_asset = Column(String(20), nullable=False, comment='计价资产') - created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') - - # 索引 - __table_args__ = ( - Index('idx_symbol', 'symbol'), - Index('idx_base_asset', 'base_asset'), - Index('idx_quote_asset', 'quote_asset'), - {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} - ) - -# 定义分析结果模型 -class AnalysisResult(Base): - """分析结果表模型""" - __tablename__ = 'analysis_results' - - id = Column(Integer, primary_key=True, autoincrement=True) - agent = Column(String(50), nullable=False, comment='智能体类型(crypto, gold)') - symbol = Column(String(50), nullable=False, comment='交易对符号') - time_interval = Column(String(20), nullable=False, comment='时间间隔') - completion_result = Column(JSON, nullable=False, comment='分析结果JSON') - created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') - updated_at = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间') - - # 索引 - __table_args__ = ( - Index('idx_agent', 'agent'), - Index('idx_symbol', 'symbol'), - Index('idx_time_interval', 'time_interval'), - Index('idx_created_at', 'created_at'), - ) - -# 定义AI Agent信息流模型 -class AgentFeed(Base): - """AI Agent信息流表模型""" - __tablename__ = 'agent_feeds' - - id = Column(Integer, primary_key=True, autoincrement=True) - agent_name = Column(String(50), nullable=False, comment='AI Agent名称') - avatar_url = Column(String(255), nullable=True, comment='头像URL') - content = Column(Text, nullable=False, comment='内容') - create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') - - # 索引和表属性 - __table_args__ = ( - Index('idx_agent_name', 'agent_name'), - Index('idx_create_time', 'create_time'), - {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} - ) - -# 定义用户数据模型 -class User(Base): - """用户数据表模型""" - __tablename__ = 'users' - - id = Column(Integer, primary_key=True, autoincrement=True) - mail = Column(String(100), nullable=False, unique=True, comment='邮箱') - nickname = Column(String(50), nullable=False, comment='昵称') - password = Column(String(100), nullable=False, comment='密码') - level = Column(Integer, nullable=False, default=0, comment='用户级别(0=普通用户,1=VIP,2=SVIP)') - points = Column(Integer, nullable=False, default=0, comment='用户积分') - create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') - - # 关系 - questions = relationship("UserQuestion", back_populates="user") - - # 索引和表属性 - __table_args__ = ( - Index('idx_mail', 'mail'), - Index('idx_level', 'level'), - Index('idx_create_time', 'create_time'), - {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} - ) - -# 定义用户提问数据模型 -class UserQuestion(Base): - """用户提问数据表模型""" - __tablename__ = 'user_questions' - - id = Column(Integer, primary_key=True, autoincrement=True) - user_id = Column(Integer, ForeignKey('users.id'), nullable=False, comment='用户ID') - agent_id = Column(String(50), nullable=False, comment='AI Agent ID') - question = Column(Text, nullable=False, comment='提问内容') - create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') - - # 关系 - user = relationship("User", back_populates="questions") - - # 索引和表属性 - __table_args__ = ( - Index('idx_user_id', 'user_id'), - Index('idx_agent_id', 'agent_id'), - Index('idx_create_time', 'create_time'), - {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} - ) - -# 定义Agent数据模型 -class Agent(Base): - """Agent数据表模型""" - __tablename__ = 'agents' - - id = Column(Integer, primary_key=True, autoincrement=True) - name = Column(String(100), nullable=False, unique=True, comment='Agent名称') - hello_prompt = Column(Text, nullable=True, comment='欢迎提示语') - description = Column(Text, nullable=True, comment='Agent描述') - dify_token = Column(String(255), nullable=True, comment='Dify API令牌') - inputs = Column(JSON, nullable=True, comment='输入参数JSON') - create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') - update_time = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间') - - # 索引和表属性 - __table_args__ = ( - Index('idx_name', 'name'), - Index('idx_create_time', 'create_time'), - {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} - ) - -# 定义 A 股数据模型 -class AStock(Base): - """A股股票基本信息表模型""" - __tablename__ = 'astock' - - stock_code = Column(String(10), primary_key=True, comment='股票代码') - short_name = Column(String(50), nullable=False, comment='股票简称') - exchange = Column(String(20), nullable=True, comment='交易所') - list_date = Column(DateTime, nullable=True, comment='上市日期') - created_at = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') - - # 索引 - __table_args__ = ( - Index('idx_stock_code', 'stock_code', unique=True), - {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} - ) - -def _convert_timestamp_to_datetime(timestamp: Union[int, float, str, None]) -> Optional[datetime]: - """ - 将时间戳转换为datetime对象 - - Args: - timestamp: Unix时间戳(毫秒级) - - Returns: - datetime对象或None - """ - if timestamp is None: - return None - - try: - # 转换为整数 - if isinstance(timestamp, str): - timestamp = int(timestamp) - - # 如果是毫秒级时间戳,转换为秒级 - if timestamp > 1e11: # 判断是否为毫秒级时间戳 - timestamp = timestamp / 1000 - - return datetime.fromtimestamp(timestamp) - except (ValueError, TypeError) as e: - logger.error(f"时间戳转换失败: {timestamp}, 错误: {e}") - return None - class DBManager: - """数据库管理工具,用于连接MySQL数据库并保存智能体分析结果""" + """ + 数据库管理工具,用于连接MySQL数据库并提供各个模型的管理器 + + 使用方法: + - 调用 get_db_manager() 获取数据库管理器实例 + - 使用 db_manager.token_manager.xxx() 直接访问各个模型管理器的方法 + - 使用 db_manager.get_session() 可以获取一个新的数据库会话 + """ def __init__(self, host: str, port: int, user: str, password: str, db_name: str): """ @@ -219,6 +59,9 @@ class DBManager: # 初始化数据库连接 self._init_db() + + # 初始化各个管理器 + self._init_managers() def _init_db(self) -> None: """初始化数据库连接和表""" @@ -250,762 +93,70 @@ class DBManager: logger.error(f"数据库初始化失败: {e}") self.engine = None - def save_analysis_result(self, agent: str, symbol: str, time_interval: str, - analysis_result: Dict[str, Any]) -> bool: - """ - 保存分析结果到数据库 - - Args: - agent: 智能体类型,例如 'crypto' 或 'gold' - symbol: 交易对符号,例如 'BTCUSDT' - time_interval: 时间间隔,例如 '1h', '4h', '1d' - analysis_result: 分析结果数据 - - Returns: - 保存是否成功 - """ + def _init_managers(self) -> None: + """初始化各个模型的管理器""" if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return False - + logger.error("引擎未初始化,无法创建管理器") + return + try: - # 创建会话 session = self.Session() - try: - # 创建新记录 - new_result = AnalysisResult( - agent=agent, - symbol=symbol, - time_interval=time_interval, - completion_result=analysis_result, - created_at=datetime.now(), - updated_at=datetime.now() - ) - - # 添加并提交 - session.add(new_result) - session.commit() - - logger.info(f"成功保存 {agent} 分析结果,交易对: {symbol}, 时间间隔: {time_interval}") - return True - - except Exception as e: - session.rollback() - logger.error(f"保存分析结果失败: {e}") - return False - - finally: - session.close() - + # 初始化各个模型的管理器 + self.token_manager = TokenManager(session) + self.analysis_result_manager = AnalysisResultManager(session) + self.agent_feed_manager = AgentFeedManager(session) + self.user_manager = UserManager(session) + self.user_question_manager = UserQuestionManager(session) + self.agent_manager = AgentManager(session) + self.astock_manager = AStockManager(session) + self.analysis_history_manager = AnalysisHistoryManager(session) + + logger.info("成功初始化所有模型管理器") + except Exception as e: - logger.error(f"创建数据库会话失败: {e}") - # 如果是连接错误,尝试重新初始化 - try: - self._init_db() - except: - pass - return False + logger.error(f"管理器初始化失败: {e}") + if session: + session.close() - def save_agent_feed(self, agent_name: str, content: str, avatar_url: Optional[str] = None) -> bool: + def get_session(self): """ - 保存AI Agent信息流到数据库 + 获取新的数据库会话 - Args: - agent_name: AI Agent名称 - content: 内容 - avatar_url: 头像URL,可选 - Returns: - 保存是否成功 + SQLAlchemy session对象,如果初始化失败则返回None """ - if not self.engine: + if not self.Session: try: self._init_db() except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return False - - try: - # 创建会话 - session = self.Session() - - try: - # 创建新记录 - new_feed = AgentFeed( - agent_name=agent_name, - content=content, - avatar_url=avatar_url - ) - - # 添加并提交 - session.add(new_feed) - session.commit() - - logger.info(f"成功保存 {agent_name} 的信息流") - return True - - except Exception as e: - session.rollback() - logger.error(f"保存信息流失败: {e}") - return False - - finally: - session.close() - - except Exception as e: - logger.error(f"创建数据库会话失败: {e}") - # 如果是连接错误,尝试重新初始化 - try: - self._init_db() - except: - pass - return False - - def register_user(self, mail: str, nickname: str, password: str, level: int = 0, points: int = 0) -> bool: - """ - 注册新用户 - - Args: - mail: 邮箱 - nickname: 昵称 - password: 密码 - level: 用户级别,默认为0(普通用户) - points: 初始积分,默认为0 - - Returns: - 注册是否成功 - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return False - - try: - # 创建会话 - session = self.Session() - - try: - # 检查邮箱是否已存在 - existing_user = session.query(User).filter(User.mail == mail).first() - if existing_user: - logger.warning(f"邮箱 {mail} 已被注册") - return False - - # 创建新用户 - new_user = User( - mail=mail, - nickname=nickname, - password=password, # 实际应用中应该对密码进行哈希处理 - level=level, - points=points, - create_time=datetime.now() - ) - - # 添加并提交 - session.add(new_user) - session.commit() - - logger.info(f"成功注册用户: {mail},初始积分: {points}") - return True - - except Exception as e: - session.rollback() - logger.error(f"注册用户失败: {e}") - return False - - finally: - session.close() - - except Exception as e: - logger.error(f"创建数据库会话失败: {e}") - # 如果是连接错误,尝试重新初始化 - try: - self._init_db() - except: - pass - return False - - def get_user_count(self) -> int: - """ - 获取用户数量 - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return 0 - - try: - # 创建会话 - session = self.Session() - - try: - # 查询用户数量 - user_count = session.query(User).count() - - return user_count - except Exception as e: - logger.error(f"获取用户数量失败: {e}") - return 0 - finally: - session.close() - - except Exception as e: - logger.error(f"获取用户数量失败: {e}") - return 0 - - def get_user_by_id(self, user_id: int) -> Optional[Dict[str, Any]]: - """ - 通过ID获取用户信息 - - Args: - user_id: 用户ID - - Returns: - 用户信息,如果用户不存在则返回None - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") + logger.error(f"重新初始化数据库失败: {e}") return None - try: - # 创建会话 - session = self.Session() - - try: - # 查询用户 - user = session.query(User).filter(User.id == user_id).first() - - if user: - # 转换为字典 - return { - 'id': user.id, - 'mail': user.mail, - 'nickname': user.nickname, - 'level': user.level, - 'points': user.points, - 'create_time': user.create_time - } - else: - return None - - finally: - session.close() - - except Exception as e: - logger.error(f"获取用户信息失败: {e}") - return None + return self.Session() - def get_user_by_mail(self, mail: str) -> Optional[Dict[str, Any]]: + def refresh_managers(self) -> bool: """ - 通过邮箱获取用户信息 + 刷新所有管理器,重新建立会话 - Args: - mail: 邮箱 - Returns: - 用户信息,如果用户不存在则返回None + 刷新是否成功 """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return None - try: - # 创建会话 - session = self.Session() + # 关闭旧会话(如果有) + self.close() - try: - # 查询用户 - user = session.query(User).filter(User.mail == mail).first() - - if user: - # 转换为字典 - return { - 'id': user.id, - 'mail': user.mail, - 'nickname': user.nickname, - 'level': user.level, - 'points': user.points, - 'create_time': user.create_time - } - else: - return None - - finally: - session.close() - + # 重新初始化数据库连接 + self._init_db() + + # 重新初始化管理器 + self._init_managers() + + return self.engine is not None except Exception as e: - logger.error(f"获取用户信息失败: {e}") - return None - - def update_user_level(self, user_id: int, level: int) -> bool: - """ - 更新用户级别 - - Args: - user_id: 用户ID - level: 新的用户级别 - - Returns: - 更新是否成功 - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return False - - try: - # 创建会话 - session = self.Session() - - try: - # 查询用户 - user = session.query(User).filter(User.id == user_id).first() - - if not user: - logger.warning(f"用户ID {user_id} 不存在") - return False - - # 更新级别 - user.level = level - session.commit() - - logger.info(f"成功更新用户 {user.mail} 的级别为 {level}") - return True - - except Exception as e: - session.rollback() - logger.error(f"更新用户级别失败: {e}") - return False - - finally: - session.close() - - except Exception as e: - logger.error(f"创建数据库会话失败: {e}") + logger.error(f"刷新管理器失败: {e}") return False - def add_user_points(self, user_id: int, points: int) -> bool: - """ - 为用户增加积分 - - Args: - user_id: 用户ID - points: 增加的积分数量(正数) - - Returns: - 操作是否成功 - """ - if points <= 0: - logger.warning(f"增加的积分必须是正数: {points}") - return False - - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return False - - try: - # 创建会话 - session = self.Session() - - try: - # 查询用户 - user = session.query(User).filter(User.id == user_id).first() - - if not user: - logger.warning(f"用户ID {user_id} 不存在") - return False - - # 增加积分 - user.points += points - session.commit() - - logger.info(f"成功为用户 {user.mail} 增加 {points} 积分,当前积分: {user.points}") - return True - - except Exception as e: - session.rollback() - logger.error(f"增加用户积分失败: {e}") - return False - - finally: - session.close() - - except Exception as e: - logger.error(f"创建数据库会话失败: {e}") - return False - - def consume_user_points(self, user_id: int, points: int) -> bool: - """ - 用户消费积分 - - Args: - user_id: 用户ID - points: 消费的积分数量(正数) - - Returns: - 操作是否成功 - """ - if points <= 0: - logger.warning(f"消费的积分必须是正数: {points}") - return False - - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return False - - try: - # 创建会话 - session = self.Session() - - try: - # 查询用户 - user = session.query(User).filter(User.id == user_id).first() - - if not user: - logger.warning(f"用户ID {user_id} 不存在") - return False - - # 检查积分是否足够 - if user.points < points: - logger.warning(f"用户 {user.mail} 积分不足,当前积分: {user.points},需要消费: {points}") - return False - - # 消费积分 - user.points -= points - session.commit() - - logger.info(f"成功从用户 {user.mail} 消费 {points} 积分,剩余积分: {user.points}") - return True - - except Exception as e: - session.rollback() - logger.error(f"消费用户积分失败: {e}") - return False - - finally: - session.close() - - except Exception as e: - logger.error(f"创建数据库会话失败: {e}") - return False - - def get_agent_feeds(self, agent_name: Optional[str] = None, limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]: - """ - 获取AI Agent信息流 - - Args: - agent_name: 可选,指定获取特定Agent的信息流 - limit: 返回的最大记录数,默认20条 - - Returns: - 信息流列表,如果查询失败则返回空列表 - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return [] - - try: - # 创建会话 - session = self.Session() - - try: - # 构建查询 - query = session.query(AgentFeed) - - # 如果指定了agent_name,则筛选 - if agent_name: - query = query.filter(AgentFeed.agent_name == agent_name) - - # 按创建时间降序排序并限制数量 - results = query.order_by(AgentFeed.create_time.desc()).offset(skip).limit(limit).all() - - # 转换为字典列表 - feeds = [] - for result in results: - feeds.append({ - 'id': result.id, - 'agent_name': result.agent_name, - 'avatar_url': result.avatar_url, - 'content': result.content, - 'create_time': result.create_time - }) - - return feeds - - finally: - session.close() - - except Exception as e: - logger.error(f"获取信息流失败: {e}") - return [] - - def save_user_question(self, user_id: int, agent_id: str, question: str) -> bool: - """ - 保存用户提问数据 - - Args: - user_id: 用户ID - agent_id: AI Agent ID - question: 提问内容 - - Returns: - 保存是否成功 - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return False - - try: - # 创建会话 - session = self.Session() - - try: - # 创建新记录 - new_question = UserQuestion( - user_id=user_id, - agent_id=agent_id, - question=question, - create_time=datetime.now() - ) - - # 添加并提交 - session.add(new_question) - session.commit() - - logger.info(f"成功保存用户 {user_id} 对 Agent {agent_id} 的提问") - return True - - except Exception as e: - session.rollback() - logger.error(f"保存用户提问失败: {e}") - return False - - finally: - session.close() - - except Exception as e: - logger.error(f"创建数据库会话失败: {e}") - # 如果是连接错误,尝试重新初始化 - try: - self._init_db() - except: - pass - return False - - def get_user_question_count(self) -> int: - """ - 获取用户提问数量 - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return 0 - - try: - # 创建会话 - session = self.Session() - - try: - # 查询用户提问数量 - question_count = session.query(UserQuestion).count() - - return question_count - except Exception as e: - logger.error(f"获取用户提问数量失败: {e}") - return 0 - finally: - session.close() - - except Exception as e: - logger.error(f"获取用户提问数量失败: {e}") - return 0 - - - def get_user_questions(self, user_id: Optional[int] = None, agent_id: Optional[str] = None, - limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]: - """ - 获取用户提问数据 - - Args: - user_id: 可选,指定获取特定用户的提问 - agent_id: 可选,指定获取特定Agent的提问 - limit: 返回的最大记录数,默认20条 - skip: 跳过的记录数,默认0条 - - Returns: - 提问数据列表,如果查询失败则返回空列表 - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return [] - - try: - # 创建会话 - session = self.Session() - - try: - # 构建查询 - query = session.query(UserQuestion) - - # 如果指定了user_id,则筛选 - if user_id: - query = query.filter(UserQuestion.user_id == user_id) - - # 如果指定了agent_id,则筛选 - if agent_id: - query = query.filter(UserQuestion.agent_id == agent_id) - - # 按创建时间降序排序并限制数量 - results = query.order_by(UserQuestion.create_time.desc()).offset(skip).limit(limit).all() - - # 转换为字典列表 - questions = [] - for result in results: - questions.append({ - 'id': result.id, - 'user_id': result.user_id, - 'agent_id': result.agent_id, - 'question': result.question, - 'create_time': result.create_time - }) - - return questions - - finally: - session.close() - - except Exception as e: - logger.error(f"获取用户提问失败: {e}") - return [] - - def get_user_question_by_id(self, question_id: int) -> Optional[Dict[str, Any]]: - """ - 通过ID获取用户提问数据 - - Args: - question_id: 提问ID - - Returns: - 提问数据,如果不存在则返回None - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return None - - try: - # 创建会话 - session = self.Session() - - try: - # 查询提问 - result = session.query(UserQuestion).filter(UserQuestion.id == question_id).first() - - if result: - # 转换为字典 - return { - 'id': result.id, - 'user_id': result.user_id, - 'agent_id': result.agent_id, - 'question': result.question, - 'create_time': result.create_time - } - else: - return None - - finally: - session.close() - - except Exception as e: - logger.error(f"获取用户提问失败: {e}") - return None - - def get_latest_result(self, agent: str, symbol: str, time_interval: str) -> Optional[Dict[str, Any]]: - """ - 获取最新的分析结果 - - Args: - agent: 智能体类型,例如 'crypto' 或 'gold' - symbol: 交易对符号,例如 'BTCUSDT' - time_interval: 时间间隔,例如 '1h', '4h', '1d' - - Returns: - 最新分析结果,如果查询失败则返回None - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return None - - try: - # 创建会话 - session = self.Session() - - try: - # 查询最新的结果 - result = session.query(AnalysisResult).filter( - AnalysisResult.agent == agent, - AnalysisResult.symbol == symbol, - AnalysisResult.time_interval == time_interval - ).order_by(AnalysisResult.created_at.desc()).first() - - if result: - # 转换为字典 - return { - 'id': result.id, - 'agent': result.agent, - 'symbol': result.symbol, - 'time_interval': result.time_interval, - 'completion_result': result.completion_result, - 'created_at': result.created_at - } - else: - return None - - finally: - session.close() - - except Exception as e: - logger.error(f"获取最新分析结果失败: {e}") - return None - def close(self) -> None: """关闭数据库连接(如果存在)""" if self.engine: @@ -1013,852 +164,6 @@ class DBManager: logger.info("数据库连接已关闭") self.engine = None self.Session = None - - def create_agent(self, name: str, description: Optional[str] = None, - hello_prompt: Optional[str] = None, dify_token: Optional[str] = None, - inputs: Optional[Dict[str, Any]] = None) -> Optional[int]: - """ - 创建新的Agent - - Args: - name: Agent名称 - description: Agent描述 - hello_prompt: 欢迎提示语 - dify_token: Dify API令牌 - inputs: 输入参数JSON - - Returns: - 新Agent的ID,如果创建失败则返回None - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return None - - try: - # 创建会话 - session = self.Session() - - try: - # 检查名称是否已存在 - existing_agent = session.query(Agent).filter(Agent.name == name).first() - if existing_agent: - logger.warning(f"Agent名称 {name} 已存在") - return None - - # 创建新Agent - new_agent = Agent( - name=name, - description=description, - hello_prompt=hello_prompt, - dify_token=dify_token, - inputs=inputs, - create_time=datetime.now(), - update_time=datetime.now() - ) - - # 添加并提交 - session.add(new_agent) - session.commit() - - logger.info(f"成功创建Agent: {name}") - return new_agent.id - - except Exception as e: - session.rollback() - logger.error(f"创建Agent失败: {e}") - return None - - finally: - session.close() - - except Exception as e: - logger.error(f"创建数据库会话失败: {e}") - return None - - def update_agent(self, agent_id: int, name: Optional[str] = None, - description: Optional[str] = None, hello_prompt: Optional[str] = None, - dify_token: Optional[str] = None, inputs: Optional[Dict[str, Any]] = None) -> bool: - """ - 更新Agent信息 - - Args: - agent_id: Agent ID - name: Agent名称 - description: Agent描述 - hello_prompt: 欢迎提示语 - dify_token: Dify API令牌 - inputs: 输入参数JSON - - Returns: - 更新是否成功 - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return False - - try: - # 创建会话 - session = self.Session() - - try: - # 查询Agent - agent = session.query(Agent).filter(Agent.id == agent_id).first() - if not agent: - logger.warning(f"Agent ID {agent_id} 不存在") - return False - - # 更新字段 - if name is not None: - # 检查名称是否与其他Agent冲突 - if name != agent.name: - existing = session.query(Agent).filter(Agent.name == name).first() - if existing: - logger.warning(f"Agent名称 {name} 已被其他Agent使用") - return False - agent.name = name - - if description is not None: - agent.description = description - - if hello_prompt is not None: - agent.hello_prompt = hello_prompt - - if dify_token is not None: - agent.dify_token = dify_token - - if inputs is not None: - agent.inputs = inputs - - agent.update_time = datetime.now() - - # 提交更改 - session.commit() - - logger.info(f"成功更新Agent ID: {agent_id}") - return True - - except Exception as e: - session.rollback() - logger.error(f"更新Agent失败: {e}") - return False - - finally: - session.close() - - except Exception as e: - logger.error(f"创建数据库会话失败: {e}") - return False - - def get_agent_by_id(self, agent_id: int) -> Optional[Dict[str, Any]]: - """ - 通过ID获取Agent信息 - - Args: - agent_id: Agent ID - - Returns: - Agent信息,如果不存在则返回None - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return None - - try: - # 创建会话 - session = self.Session() - - try: - # 查询Agent - agent = session.query(Agent).filter(Agent.id == agent_id).first() - - if agent: - # 转换为字典 - return { - 'id': agent.id, - 'name': agent.name, - 'description': agent.description, - 'hello_prompt': agent.hello_prompt, - 'dify_token': agent.dify_token, - 'inputs': agent.inputs, - 'create_time': agent.create_time, - 'update_time': agent.update_time - } - else: - return None - - finally: - session.close() - - except Exception as e: - logger.error(f"获取Agent信息失败: {e}") - return None - - def get_agent_by_name(self, name: str) -> Optional[Dict[str, Any]]: - """ - 通过名称获取Agent信息 - - Args: - name: Agent名称 - - Returns: - Agent信息,如果不存在则返回None - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return None - - try: - # 创建会话 - session = self.Session() - - try: - # 查询Agent - agent = session.query(Agent).filter(Agent.name == name).first() - - if agent: - # 转换为字典 - return { - 'id': agent.id, - 'name': agent.name, - 'description': agent.description, - 'hello_prompt': agent.hello_prompt, - 'dify_token': agent.dify_token, - 'inputs': agent.inputs, - 'create_time': agent.create_time, - 'update_time': agent.update_time - } - else: - return None - - finally: - session.close() - - except Exception as e: - logger.error(f"获取Agent信息失败: {e}") - return None - - def list_agents(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]: - """ - 获取Agent列表 - - Args: - limit: 返回的最大数量 - skip: 跳过的数量 - - Returns: - Agent列表 - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return [] - - try: - # 创建会话 - session = self.Session() - - try: - # 查询Agent列表 - agents = session.query(Agent).order_by(Agent.create_time.asc()).offset(skip).limit(limit).all() - - # 转换为字典列表 - return [{ - 'id': agent.id, - 'name': agent.name, - 'description': agent.description, - 'hello_prompt': agent.hello_prompt, - 'dify_token': agent.dify_token, - 'inputs': agent.inputs, - 'create_time': agent.create_time, - 'update_time': agent.update_time - } for agent in agents] - - finally: - session.close() - - except Exception as e: - logger.error(f"获取Agent列表失败: {e}") - return [] - - def delete_agent(self, agent_id: int) -> bool: - """ - 删除Agent - - Args: - agent_id: Agent ID - - Returns: - 删除是否成功 - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return False - - try: - # 创建会话 - session = self.Session() - - try: - # 查询Agent - agent = session.query(Agent).filter(Agent.id == agent_id).first() - if not agent: - logger.warning(f"Agent ID {agent_id} 不存在") - return False - - # 删除Agent - session.delete(agent) - session.commit() - - logger.info(f"成功删除Agent ID: {agent_id}") - return True - - except Exception as e: - session.rollback() - logger.error(f"删除Agent失败: {e}") - return False - - finally: - session.close() - - except Exception as e: - logger.error(f"创建数据库会话失败: {e}") - return False - - def create_stocks(self, stocks: List[Dict[str, Any]]) -> bool: - """ - 批量创建股票信息 - - Args: - stocks: 股票信息列表,每个元素为包含以下键的字典: - - stock_code: 股票代码 - - short_name: 股票简称 - - exchange: 交易所(可选) - - list_date: 上市日期(可选,Unix时间戳,毫秒级) - - Returns: - 创建是否成功 - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return False - - try: - session = self.Session() - - try: - for stock_data in stocks: - # 检查股票代码是否已存在 - existing_stock = session.query(AStock).filter( - AStock.stock_code == stock_data['stock_code'] - ).first() - - if existing_stock: - logger.warning(f"股票代码 {stock_data['stock_code']} 已存在,跳过") - continue - - # 创建新股票记录 - new_stock = AStock( - stock_code=stock_data['stock_code'], - short_name=stock_data['short_name'], - exchange=stock_data.get('exchange') - ) - - # 处理上市日期 - if 'list_date' in stock_data and stock_data['list_date']: - list_date = _convert_timestamp_to_datetime(stock_data['list_date']) - if list_date: - new_stock.list_date = list_date - - session.add(new_stock) - - # 批量提交 - session.commit() - logger.info(f"成功批量创建股票信息,共 {len(stocks)} 条记录") - return True - - except Exception as e: - session.rollback() - logger.error(f"批量创建股票信息失败: {e}") - return False - - finally: - session.close() - - except Exception as e: - logger.error(f"创建数据库会话失败: {e}") - return False - - def create_stock(self, stock_code: str, short_name: str, - exchange: Optional[str] = None, - list_date: Optional[Union[int, float, str]] = None) -> bool: - """ - 创建新的股票信息 - - Args: - stock_code: 股票代码 - short_name: 股票简称 - exchange: 交易所(可选) - list_date: 上市日期(可选,Unix时间戳,毫秒级) - - Returns: - 创建是否成功 - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return False - - try: - session = self.Session() - - try: - # 检查股票代码是否已存在 - existing_stock = session.query(AStock).filter(AStock.stock_code == stock_code).first() - if existing_stock: - logger.warning(f"股票代码 {stock_code} 已存在") - return False - - # 创建新股票记录 - new_stock = AStock( - stock_code=stock_code, - short_name=short_name, - exchange=exchange - ) - - # 处理上市日期 - if list_date: - converted_date = _convert_timestamp_to_datetime(list_date) - if converted_date: - new_stock.list_date = converted_date - - session.add(new_stock) - session.commit() - - logger.info(f"成功创建股票信息: {stock_code} - {short_name}") - return True - - except Exception as e: - session.rollback() - logger.error(f"创建股票信息失败: {e}") - return False - - finally: - session.close() - - except Exception as e: - logger.error(f"创建数据库会话失败: {e}") - return False - - def update_stock(self, stock_code: str, short_name: Optional[str] = None, - exchange: Optional[str] = None, list_date: Optional[datetime] = None) -> bool: - """ - 更新股票信息 - - Args: - stock_code: 股票代码 - short_name: 股票简称 - exchange: 交易所 - list_date: 上市日期 - - Returns: - 更新是否成功 - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return False - - try: - session = self.Session() - - try: - # 查询股票 - stock = session.query(AStock).filter(AStock.stock_code == stock_code).first() - if not stock: - logger.warning(f"股票代码 {stock_code} 不存在") - return False - - # 更新字段 - if short_name is not None: - stock.short_name = short_name - if exchange is not None: - stock.exchange = exchange - if list_date is not None: - stock.list_date = list_date - - stock.updated_at = datetime.now() - - session.commit() - - logger.info(f"成功更新股票信息: {stock_code}") - return True - - except Exception as e: - session.rollback() - logger.error(f"更新股票信息失败: {e}") - return False - - finally: - session.close() - - except Exception as e: - logger.error(f"创建数据库会话失败: {e}") - return False - - def delete_stock(self, stock_code: str) -> bool: - """ - 删除股票信息 - - Args: - stock_code: 股票代码 - - Returns: - 删除是否成功 - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return False - - try: - session = self.Session() - - try: - # 查询股票 - stock = session.query(AStock).filter(AStock.stock_code == stock_code).first() - if not stock: - logger.warning(f"股票代码 {stock_code} 不存在") - return False - - # 删除股票 - session.delete(stock) - session.commit() - - logger.info(f"成功删除股票信息: {stock_code}") - return True - - except Exception as e: - session.rollback() - logger.error(f"删除股票信息失败: {e}") - return False - - finally: - session.close() - - except Exception as e: - logger.error(f"创建数据库会话失败: {e}") - return False - - def get_stock_by_code(self, stock_code: str) -> Optional[Dict[str, Any]]: - """ - 通过股票代码获取股票信息 - - Args: - stock_code: 股票代码 - - Returns: - 股票信息,如果不存在则返回None - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return None - - try: - session = self.Session() - - try: - # 查询股票 - stock = session.query(AStock).filter(AStock.stock_code == stock_code).first() - - if stock: - return { - 'stock_code': stock.stock_code, - 'short_name': stock.short_name, - 'exchange': stock.exchange, - 'list_date': stock.list_date, - 'created_at': stock.created_at, - 'updated_at': stock.updated_at - } - else: - return None - - finally: - session.close() - - except Exception as e: - logger.error(f"获取股票信息失败: {e}") - return None - - def search_stock(self, key: str, limit: int = 10) -> List[Dict[str, Any]]: - """ - 搜索股票 - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return [] - - try: - session = self.Session() - - try: - # 查询股票 - stocks = session.query(AStock).filter(AStock.short_name.like(f"{key}%") | AStock.stock_code.like(f"{key}%")).limit(limit).all() - - return [{ - 'stock_code': stock.stock_code, - 'short_name': stock.short_name, - 'exchange': stock.exchange, - 'list_date': stock.list_date, - 'created_at': stock.created_at - } for stock in stocks] - - finally: - session.close() - - except Exception as e: - logger.error(f"获取股票信息失败: {e}") - return [] - - def get_stock_by_name(self, short_name: str) -> List[Dict[str, Any]]: - """ - 通过股票简称获取股票信息(可能有多个结果) - - Args: - short_name: 股票简称 - - Returns: - 股票信息列表 - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return [] - - try: - session = self.Session() - - try: - # 查询股票 - stocks = session.query(AStock).filter(AStock.short_name == short_name).all() - - return [{ - 'stock_code': stock.stock_code, - 'short_name': stock.short_name, - 'exchange': stock.exchange, - 'list_date': stock.list_date, - 'created_at': stock.created_at - } for stock in stocks] - - finally: - session.close() - - except Exception as e: - logger.error(f"获取股票信息失败: {e}") - return [] - - def list_stocks(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]: - """ - 获取股票列表 - - Args: - limit: 返回的最大数量 - skip: 跳过的数量 - - Returns: - 股票信息列表 - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return [] - - try: - session = self.Session() - - try: - # 查询股票列表 - stocks = session.query(AStock).order_by(AStock.stock_code).offset(skip).limit(limit).all() - - return [{ - 'stock_code': stock.stock_code, - 'short_name': stock.short_name, - 'exchange': stock.exchange, - 'list_date': stock.list_date, - 'created_at': stock.created_at - } for stock in stocks] - - finally: - session.close() - - except Exception as e: - logger.error(f"获取股票列表失败: {e}") - return [] - - def create_token(self, symbol: str, base_asset: str, quote_asset: str) -> bool: - """ - 创建新的Token信息 - - Args: - symbol: 交易对符号 - base_asset: 基础资产 - quote_asset: 计价资产 - - Returns: - 创建是否成功 - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return False - - try: - session = self.Session() - - try: - # 检查交易对是否已存在 - existing_token = session.query(Token).filter(Token.symbol == symbol).first() - if existing_token: - logger.warning(f"交易对 {symbol} 已存在") - return False - - # 创建新Token记录 - new_token = Token( - symbol=symbol, - base_asset=base_asset, - quote_asset=quote_asset - ) - - session.add(new_token) - session.commit() - - logger.info(f"成功创建Token信息: {symbol}") - return True - - except Exception as e: - session.rollback() - logger.error(f"创建Token信息失败: {e}") - return False - - finally: - session.close() - - except Exception as e: - logger.error(f"创建数据库会话失败: {e}") - return False - - def delete_token(self, symbol: str) -> bool: - """ - 删除Token信息 - - Args: - symbol: 交易对符号 - - Returns: - 删除是否成功 - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return False - - try: - session = self.Session() - - try: - # 查询Token - token = session.query(Token).filter(Token.symbol == symbol).first() - if not token: - logger.warning(f"交易对 {symbol} 不存在") - return False - - # 删除Token - session.delete(token) - session.commit() - - logger.info(f"成功删除Token信息: {symbol}") - return True - - except Exception as e: - session.rollback() - logger.error(f"删除Token信息失败: {e}") - return False - - finally: - session.close() - - except Exception as e: - logger.error(f"创建数据库会话失败: {e}") - return False - - def search_token(self, key: str, limit: int = 10) -> List[Dict[str, Any]]: - """ - 搜索Token - """ - if not self.engine: - try: - self._init_db() - except Exception as e: - logger.error(f"重新连接数据库失败: {e}") - return [] - - try: - session = self.Session() - - # 使用 SQLAlchemy 的 ORM 查询 - tokens = session.query(Token).filter(Token.symbol.like(f"{key}%") | Token.base_asset.like(f"{key}%")).limit(limit).all() - - return [{ - 'symbol': token.symbol, - 'base_asset': token.base_asset, - 'quote_asset': token.quote_asset - } for token in tokens] - - except Exception as e: - logger.error(f"获取Token信息失败: {e}") - return [] - - finally: - session.close() # 单例模式 _db_instance = None @@ -1880,6 +185,18 @@ def get_db_manager(host: Optional[str] = None, Returns: 数据库管理器实例 + + 使用示例: + ```python + # 获取数据库管理器 + db_manager = get_db_manager() + + # 使用Token管理器 + tokens = db_manager.token_manager.search_token("BTC") + + # 使用用户管理器 + user = db_manager.user_manager.get_user_by_mail("example@test.com") + ``` """ global _db_instance diff --git a/cryptoai/utils/db_utils.py b/cryptoai/utils/db_utils.py new file mode 100644 index 0000000..9a0b73d --- /dev/null +++ b/cryptoai/utils/db_utils.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import logging +from typing import Optional, Union +from datetime import datetime + +# 配置日志 +logger = logging.getLogger('db_utils') + +def convert_timestamp_to_datetime(timestamp: Union[int, float, str, None]) -> Optional[datetime]: + """ + 将时间戳转换为datetime对象 + + Args: + timestamp: Unix时间戳(毫秒级) + + Returns: + datetime对象或None + """ + if timestamp is None: + return None + + try: + # 转换为整数 + if isinstance(timestamp, str): + timestamp = int(timestamp) + + # 如果是毫秒级时间戳,转换为秒级 + if timestamp > 1e11: # 判断是否为毫秒级时间戳 + timestamp = timestamp / 1000 + + return datetime.fromtimestamp(timestamp) + except (ValueError, TypeError) as e: + logger.error(f"时间戳转换失败: {timestamp}, 错误: {e}") + return None \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 0ee86f8..46b77f3 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -29,7 +29,7 @@ services: cryptoai-api: build: . container_name: cryptoai-api - image: cryptoai-api:0.1.25 + image: cryptoai-api:0.1.26 restart: always ports: - "8000:8000" diff --git a/test.py b/test.py index 94819a7..05ed715 100644 --- a/test.py +++ b/test.py @@ -17,4 +17,4 @@ if __name__ == "__main__": base_asset = symbol.split('USDT')[0] quote_asset = 'USDT' - manager.create_token(symbol,base_asset, quote_asset) + manager.token_manager.create_token(symbol,base_asset, quote_asset)