diff --git a/cryptoai/routes/agent.py b/cryptoai/routes/agent.py index 1422479..0cabdf3 100644 --- a/cryptoai/routes/agent.py +++ b/cryptoai/routes/agent.py @@ -6,7 +6,7 @@ API路由模块,为前端提供REST API接口 """ import os -from fastapi import APIRouter, Depends, HTTPException, status, Body +from fastapi import APIRouter, Depends, HTTPException, status, Body, Query, Path from typing import Dict, Any, List, Optional from pydantic import BaseModel import json @@ -25,53 +25,85 @@ router = APIRouter() class ChatRequest(BaseModel): user_prompt: str - agent_id: str + agent_id: int + +class AgentCreate(BaseModel): + name: str + description: Optional[str] = None + hello_prompt: Optional[str] = None + dify_token: Optional[str] = None + inputs: Optional[Dict[str, Any]] = None + +class AgentUpdate(BaseModel): + name: Optional[str] = None + description: Optional[str] = None + hello_prompt: Optional[str] = None + dify_token: Optional[str] = None + inputs: Optional[Dict[str, Any]] = None -@router.get("/list") -async def get_agents(current_user: Dict[str, Any] = Depends(get_current_user)): +@router.post("/create") +async def create_agent(agent: AgentCreate): + """ + 创建新的AI Agent + """ + + agent = get_db_manager().create_agent(agent.name, agent.description, agent.hello_prompt, agent.dify_token, agent.inputs) + + return agent + + +@router.get("/list", response_model=List[Dict[str, Any]]) +async def get_agents( + skip: int = Query(0, ge=0), + limit: int = Query(100, ge=1, le=1000) +): """ 获取所有代理 """ - return [ - { - "id": "1", - "name": "加密货币交易助手", - "hello_prompt": "您好,我是加密货币交易助手,为您提供专业的数字货币交易分析和建议", - "description": "帮你分析做加密货币技术分析", - }, - { - "id": "2", - "name": "美股交易助手", - "hello_prompt": "您好,我是美股交易助手,您可以直接输入股票名称,比如AAPL,然后我会为您提供专业的股票交易分析和建议", - "description": "帮你分析做美股股票技术分析", - }, - # { - # "id": "3", - # "name": "期货交易助手", - # "hello_prompt": "您好,我是期货交易助手,为您提供专业的期货交易分析和建议", - # "description": "帮你分析做期货技术分析", - # } - ] + # # 检查用户权限 + # if current_user.get("level", 0) < 2: # 假设需要SVIP权限才能查看全部Agent + # # 使用硬编码的默认Agent + # return [ + # { + # "id": "1", + # "name": "Crypto Assistant", + # "hello_prompt": "您好,我是加密货币交易助手,为您提供专业的数字货币交易分析和建议", + # "description": "帮你分析做加密货币技术分析", + # }, + # { + # "id": "2", + # "name": "US Stock Assistant", + # "hello_prompt": "您好,我是美股交易助手,您可以直接输入股票名称,比如AAPL,然后我会为您提供专业的股票交易分析和建议", + # "description": "帮你分析做美股股票技术分析", + # }, + # ] + # else: + # 从数据库获取Agent列表 + agents = get_db_manager().list_agents(limit=limit, skip=skip) + return agents @router.post("/chat") -async def chat(request: ChatRequest,current_user: Dict[str, Any] = Depends(get_current_user)): +async def chat(request: ChatRequest, current_user: Dict[str, Any] = Depends(get_current_user)): """ 聊天接口 """ - if request.agent_id == "1": - token = "app-vhJecqbcLukf72g0uxAb9tcz" - elif request.agent_id == "2": - token = "app-FLIYXrCbbQIkwgXx02Y1Mxjg" - else: - raise HTTPException(status_code=400, detail="Invalid agent ID") - - inputs = {} - if request.agent_id == "2": - inputs = { - "current_date": datetime.now().strftime("%Y-%m-%d") - } + # 尝试从数据库获取Agent + try: + agent_id = int(request.agent_id) + agent = get_db_manager().get_agent_by_id(agent_id) + + if not agent: + raise HTTPException(status_code=400, detail="Invalid agent ID") + + token = agent.get("dify_token") + inputs = agent.get("inputs") or {} + if agent.get("id") == 2: + inputs["current_date"] = datetime.now().strftime("%Y-%m-%d") + + except ValueError: + raise HTTPException(status_code=400, detail="Invalid agent ID format") url = "https://mate.aimateplus.com/v1/chat-messages" headers = { @@ -79,18 +111,27 @@ async def chat(request: ChatRequest,current_user: Dict[str, Any] = Depends(get_c "Content-Type": "application/json" } data = { - "inputs" : inputs, - "query" : request.user_prompt, - "response_mode" : "streaming", - "user" : current_user["mail"] + "inputs": inputs, + "query": request.user_prompt, + "response_mode": "streaming", + "user": current_user["mail"] } + logging.info(f"Chat request data: {data}") + # 保存用户提问 get_db_manager().save_user_question(current_user["id"], request.agent_id, request.user_prompt) response = requests.post(url, headers=headers, json=data, stream=True) - #获取response 的 stream + # 如果响应不成功,返回错误 + if response.status_code != 200: + raise HTTPException( + status_code=response.status_code, + detail=f"Failed to get response from Dify API: {response.text}" + ) + + # 获取response的stream def stream_response(): for chunk in response.iter_content(chunk_size=1024): if chunk: diff --git a/cryptoai/utils/db_manager.py b/cryptoai/utils/db_manager.py index 7be84bb..a7e6ab8 100644 --- a/cryptoai/utils/db_manager.py +++ b/cryptoai/utils/db_manager.py @@ -109,6 +109,27 @@ class UserQuestion(Base): {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} ) +# 定义Agent数据模型 +class Agent(Base): + """Agent数据表模型""" + __tablename__ = 'agents' + + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String(100), nullable=False, unique=True, comment='Agent名称') + hello_prompt = Column(Text, nullable=True, comment='欢迎提示语') + description = Column(Text, nullable=True, comment='Agent描述') + dify_token = Column(String(255), nullable=True, comment='Dify API令牌') + inputs = Column(JSON, nullable=True, comment='输入参数JSON') + create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') + update_time = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间') + + # 索引和表属性 + __table_args__ = ( + Index('idx_name', 'name'), + Index('idx_create_time', 'create_time'), + {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} + ) + class DBManager: """数据库管理工具,用于连接MySQL数据库并保存智能体分析结果""" @@ -745,11 +766,339 @@ class DBManager: return None def close(self) -> None: - """关闭数据库连接""" + """关闭数据库连接(如果存在)""" if self.engine: self.engine.dispose() - self.engine = None logger.info("数据库连接已关闭") + self.engine = None + self.Session = None + + def create_agent(self, name: str, description: Optional[str] = None, + hello_prompt: Optional[str] = None, dify_token: Optional[str] = None, + inputs: Optional[Dict[str, Any]] = None) -> Optional[int]: + """ + 创建新的Agent + + Args: + name: Agent名称 + description: Agent描述 + hello_prompt: 欢迎提示语 + dify_token: Dify API令牌 + inputs: 输入参数JSON + + Returns: + 新Agent的ID,如果创建失败则返回None + """ + if not self.engine: + try: + self._init_db() + except Exception as e: + logger.error(f"重新连接数据库失败: {e}") + return None + + try: + # 创建会话 + session = self.Session() + + try: + # 检查名称是否已存在 + existing_agent = session.query(Agent).filter(Agent.name == name).first() + if existing_agent: + logger.warning(f"Agent名称 {name} 已存在") + return None + + # 创建新Agent + new_agent = Agent( + name=name, + description=description, + hello_prompt=hello_prompt, + dify_token=dify_token, + inputs=inputs, + create_time=datetime.now(), + update_time=datetime.now() + ) + + # 添加并提交 + session.add(new_agent) + session.commit() + + logger.info(f"成功创建Agent: {name}") + return new_agent.id + + except Exception as e: + session.rollback() + logger.error(f"创建Agent失败: {e}") + return None + + finally: + session.close() + + except Exception as e: + logger.error(f"创建数据库会话失败: {e}") + return None + + def update_agent(self, agent_id: int, name: Optional[str] = None, + description: Optional[str] = None, hello_prompt: Optional[str] = None, + dify_token: Optional[str] = None, inputs: Optional[Dict[str, Any]] = None) -> bool: + """ + 更新Agent信息 + + Args: + agent_id: Agent ID + name: Agent名称 + description: Agent描述 + hello_prompt: 欢迎提示语 + dify_token: Dify API令牌 + inputs: 输入参数JSON + + Returns: + 更新是否成功 + """ + if not self.engine: + try: + self._init_db() + except Exception as e: + logger.error(f"重新连接数据库失败: {e}") + return False + + try: + # 创建会话 + session = self.Session() + + try: + # 查询Agent + agent = session.query(Agent).filter(Agent.id == agent_id).first() + if not agent: + logger.warning(f"Agent ID {agent_id} 不存在") + return False + + # 更新字段 + if name is not None: + # 检查名称是否与其他Agent冲突 + if name != agent.name: + existing = session.query(Agent).filter(Agent.name == name).first() + if existing: + logger.warning(f"Agent名称 {name} 已被其他Agent使用") + return False + agent.name = name + + if description is not None: + agent.description = description + + if hello_prompt is not None: + agent.hello_prompt = hello_prompt + + if dify_token is not None: + agent.dify_token = dify_token + + if inputs is not None: + agent.inputs = inputs + + agent.update_time = datetime.now() + + # 提交更改 + session.commit() + + logger.info(f"成功更新Agent ID: {agent_id}") + return True + + except Exception as e: + session.rollback() + logger.error(f"更新Agent失败: {e}") + return False + + finally: + session.close() + + except Exception as e: + logger.error(f"创建数据库会话失败: {e}") + return False + + def get_agent_by_id(self, agent_id: int) -> Optional[Dict[str, Any]]: + """ + 通过ID获取Agent信息 + + Args: + agent_id: Agent ID + + Returns: + Agent信息,如果不存在则返回None + """ + if not self.engine: + try: + self._init_db() + except Exception as e: + logger.error(f"重新连接数据库失败: {e}") + return None + + try: + # 创建会话 + session = self.Session() + + try: + # 查询Agent + agent = session.query(Agent).filter(Agent.id == agent_id).first() + + if agent: + # 转换为字典 + return { + 'id': agent.id, + 'name': agent.name, + 'description': agent.description, + 'hello_prompt': agent.hello_prompt, + 'dify_token': agent.dify_token, + 'inputs': agent.inputs, + 'create_time': agent.create_time, + 'update_time': agent.update_time + } + else: + return None + + finally: + session.close() + + except Exception as e: + logger.error(f"获取Agent信息失败: {e}") + return None + + def get_agent_by_name(self, name: str) -> Optional[Dict[str, Any]]: + """ + 通过名称获取Agent信息 + + Args: + name: Agent名称 + + Returns: + Agent信息,如果不存在则返回None + """ + if not self.engine: + try: + self._init_db() + except Exception as e: + logger.error(f"重新连接数据库失败: {e}") + return None + + try: + # 创建会话 + session = self.Session() + + try: + # 查询Agent + agent = session.query(Agent).filter(Agent.name == name).first() + + if agent: + # 转换为字典 + return { + 'id': agent.id, + 'name': agent.name, + 'description': agent.description, + 'hello_prompt': agent.hello_prompt, + 'dify_token': agent.dify_token, + 'inputs': agent.inputs, + 'create_time': agent.create_time, + 'update_time': agent.update_time + } + else: + return None + + finally: + session.close() + + except Exception as e: + logger.error(f"获取Agent信息失败: {e}") + return None + + def list_agents(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]: + """ + 获取Agent列表 + + Args: + limit: 返回的最大数量 + skip: 跳过的数量 + + Returns: + Agent列表 + """ + if not self.engine: + try: + self._init_db() + except Exception as e: + logger.error(f"重新连接数据库失败: {e}") + return [] + + try: + # 创建会话 + session = self.Session() + + try: + # 查询Agent列表 + agents = session.query(Agent).order_by(Agent.create_time.asc()).offset(skip).limit(limit).all() + + # 转换为字典列表 + return [{ + 'id': agent.id, + 'name': agent.name, + 'description': agent.description, + 'hello_prompt': agent.hello_prompt, + 'dify_token': agent.dify_token, + 'inputs': agent.inputs, + 'create_time': agent.create_time, + 'update_time': agent.update_time + } for agent in agents] + + finally: + session.close() + + except Exception as e: + logger.error(f"获取Agent列表失败: {e}") + return [] + + def delete_agent(self, agent_id: int) -> bool: + """ + 删除Agent + + Args: + agent_id: Agent ID + + Returns: + 删除是否成功 + """ + if not self.engine: + try: + self._init_db() + except Exception as e: + logger.error(f"重新连接数据库失败: {e}") + return False + + try: + # 创建会话 + session = self.Session() + + try: + # 查询Agent + agent = session.query(Agent).filter(Agent.id == agent_id).first() + if not agent: + logger.warning(f"Agent ID {agent_id} 不存在") + return False + + # 删除Agent + session.delete(agent) + session.commit() + + logger.info(f"成功删除Agent ID: {agent_id}") + return True + + except Exception as e: + session.rollback() + logger.error(f"删除Agent失败: {e}") + return False + + finally: + session.close() + + except Exception as e: + logger.error(f"创建数据库会话失败: {e}") + return False # 单例模式 diff --git a/docker-compose.yml b/docker-compose.yml index 1af7139..44ab693 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -29,7 +29,7 @@ services: cryptoai-api: build: . container_name: cryptoai-api - image: cryptoai-api:0.0.10 + image: cryptoai-api:0.0.11 restart: always ports: - "8000:8000"