268 lines
8.9 KiB
Python
268 lines
8.9 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
from typing import Dict, Any, List, Optional
|
||
from datetime import datetime
|
||
|
||
from sqlalchemy import Column, Integer, String, Text, DateTime, Index
|
||
from sqlalchemy.dialects.mysql import JSON
|
||
|
||
from cryptoai.models.base import Base, logger
|
||
|
||
# Define the Agent data model.
class Agent(Base):
    """ORM model mapped to the ``agents`` table.

    One row per agent: a unique name, optional greeting text, description,
    Dify API token and JSON ``inputs`` payload, plus creation and last-update
    timestamps.
    """

    __tablename__ = 'agents'

    # Integer primary key, auto-incremented by the database.
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Agent name: required, unique, at most 100 characters.
    name = Column(String(100), nullable=False, unique=True, comment='Agent名称')
    # Optional greeting / welcome prompt text.
    hello_prompt = Column(Text, nullable=True, comment='欢迎提示语')
    # Optional free-form description.
    description = Column(Text, nullable=True, comment='Agent描述')
    # Optional Dify API token (up to 255 chars). No encryption is applied at
    # this layer. NOTE(review): confirm storing it as-is is acceptable.
    dify_token = Column(String(255), nullable=True, comment='Dify API令牌')
    # Optional input-parameter payload, stored in a MySQL JSON column.
    inputs = Column(JSON, nullable=True, comment='输入参数JSON')
    # Creation time. Defaults to datetime.now (naive local time) when the
    # caller does not supply a value.
    create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
    # Last-update time. Defaults to datetime.now on insert, and onupdate
    # refreshes it whenever SQLAlchemy issues an UPDATE for the row without
    # an explicit value for this column.
    update_time = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间')

    # Indexes and table-level options.
    __table_args__ = (
        # NOTE(review): `name` already declares unique=True, which on MySQL
        # creates a unique index on the same column; this extra non-unique
        # index looks redundant. Confirm before removing (schema change).
        Index('idx_name', 'name'),
        # Secondary index on create_time (presumably for time-ordered
        # listing; confirm).
        Index('idx_create_time', 'create_time'),
        # MySQL table options: 4-byte UTF-8 charset and Unicode collation.
        {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
    )
|
||
|
||
class AgentManager:
    """CRUD operations for ``Agent`` rows through a single database session.

    Public methods never propagate exceptions: failures are logged through
    ``logger`` and reported via a sentinel return value (``None``, ``False``
    or ``[]``). Write methods commit on success, and every method rolls the
    session back when an exception occurs so the session remains usable.
    """

    def __init__(self, db_session):
        """
        Args:
            db_session: Session used for all queries. Presumably a
                ``sqlalchemy.orm.Session``; only ``query``, ``add``,
                ``delete``, ``commit`` and ``rollback`` are called on it.
        """
        self.session = db_session

    @staticmethod
    def _agent_to_dict(agent: "Agent") -> Dict[str, Any]:
        """Convert an ``Agent`` row into the plain dict returned by the read methods."""
        return {
            'id': agent.id,
            'name': agent.name,
            'description': agent.description,
            'hello_prompt': agent.hello_prompt,
            'dify_token': agent.dify_token,
            'inputs': agent.inputs,
            'create_time': agent.create_time,
            'update_time': agent.update_time,
        }

    def _rollback(self) -> None:
        """Roll back the session after a failed operation, without raising.

        A failed statement can leave the session refusing further work until
        it is rolled back. If the rollback itself fails, the error is only
        logged so public methods keep their "return a sentinel" contract.
        """
        try:
            self.session.rollback()
        except Exception as rollback_error:
            logger.error(f"回滚数据库会话失败: {rollback_error}")

    def create_agent(self, name: str, description: Optional[str] = None,
                     hello_prompt: Optional[str] = None, dify_token: Optional[str] = None,
                     inputs: Optional[Dict[str, Any]] = None) -> Optional[int]:
        """Create a new Agent.

        Args:
            name: Agent name; must not already be used by another Agent.
            description: Agent description.
            hello_prompt: Greeting prompt.
            dify_token: Dify API token.
            inputs: Input-parameter JSON payload.

        Returns:
            The new Agent's ID, or ``None`` if the name is already taken or
            creation fails.
        """
        try:
            # Reject duplicate names up front. The unique constraint on
            # ``name`` still catches a concurrent insert, which surfaces as an
            # exception handled below.
            existing_agent = self.session.query(Agent).filter(Agent.name == name).first()
            if existing_agent is not None:
                logger.warning(f"Agent名称 {name} 已存在")
                return None

            # Take a single timestamp so a fresh row has create_time == update_time.
            now = datetime.now()
            new_agent = Agent(
                name=name,
                description=description,
                hello_prompt=hello_prompt,
                dify_token=dify_token,
                inputs=inputs,
                create_time=now,
                update_time=now,
            )

            self.session.add(new_agent)
            self.session.commit()

            logger.info(f"成功创建Agent: {name}")
            return new_agent.id

        except Exception as e:
            self._rollback()
            logger.error(f"创建Agent失败: {e}")
            return None

    def update_agent(self, agent_id: int, name: Optional[str] = None,
                     description: Optional[str] = None, hello_prompt: Optional[str] = None,
                     dify_token: Optional[str] = None, inputs: Optional[Dict[str, Any]] = None) -> bool:
        """Update an existing Agent; only arguments that are not ``None`` are applied.

        Args:
            agent_id: ID of the Agent to update.
            name: New name; rejected if another Agent already uses it.
            description: New description.
            hello_prompt: New greeting prompt.
            dify_token: New Dify API token.
            inputs: New input-parameter JSON payload.

        Returns:
            ``True`` on success; ``False`` if the Agent does not exist, the new
            name conflicts with another Agent, or the update fails.
        """
        try:
            agent = self.session.query(Agent).filter(Agent.id == agent_id).first()
            if agent is None:
                logger.warning(f"Agent ID {agent_id} 不存在")
                return False

            if name is not None:
                # Only check for conflicts when the name actually changes.
                # This runs before any attribute is modified, so an early
                # return leaves the row untouched.
                if name != agent.name:
                    existing = self.session.query(Agent).filter(Agent.name == name).first()
                    if existing is not None:
                        logger.warning(f"Agent名称 {name} 已被其他Agent使用")
                        return False
                agent.name = name

            if description is not None:
                agent.description = description

            if hello_prompt is not None:
                agent.hello_prompt = hello_prompt

            if dify_token is not None:
                agent.dify_token = dify_token

            if inputs is not None:
                agent.inputs = inputs

            agent.update_time = datetime.now()

            self.session.commit()

            logger.info(f"成功更新Agent ID: {agent_id}")
            return True

        except Exception as e:
            self._rollback()
            logger.error(f"更新Agent失败: {e}")
            return False

    def get_agent_by_id(self, agent_id: int) -> Optional[Dict[str, Any]]:
        """Fetch an Agent by ID.

        Args:
            agent_id: Agent ID.

        Returns:
            The Agent as a dict, or ``None`` if it does not exist or the query fails.
        """
        try:
            agent = self.session.query(Agent).filter(Agent.id == agent_id).first()
            return self._agent_to_dict(agent) if agent is not None else None

        except Exception as e:
            self._rollback()
            logger.error(f"获取Agent信息失败: {e}")
            return None

    def get_agent_by_name(self, name: str) -> Optional[Dict[str, Any]]:
        """Fetch an Agent by name.

        Args:
            name: Agent name.

        Returns:
            The Agent as a dict, or ``None`` if it does not exist or the query fails.
        """
        try:
            agent = self.session.query(Agent).filter(Agent.name == name).first()
            return self._agent_to_dict(agent) if agent is not None else None

        except Exception as e:
            self._rollback()
            logger.error(f"获取Agent信息失败: {e}")
            return None

    def list_agents(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]:
        """List Agents ordered by creation time (oldest first).

        Args:
            limit: Maximum number of Agents to return.
            skip: Number of Agents to skip before collecting results.

        Returns:
            A list of Agent dicts; empty if there are none or the query fails.
        """
        try:
            agents = (
                self.session.query(Agent)
                .order_by(Agent.create_time.asc())
                .offset(skip)
                .limit(limit)
                .all()
            )
            return [self._agent_to_dict(agent) for agent in agents]

        except Exception as e:
            self._rollback()
            logger.error(f"获取Agent列表失败: {e}")
            return []

    def delete_agent(self, agent_id: int) -> bool:
        """Delete an Agent.

        Args:
            agent_id: ID of the Agent to delete.

        Returns:
            ``True`` on success; ``False`` if the Agent does not exist or the
            deletion fails.
        """
        try:
            agent = self.session.query(Agent).filter(Agent.id == agent_id).first()
            if agent is None:
                logger.warning(f"Agent ID {agent_id} 不存在")
                return False

            self.session.delete(agent)
            self.session.commit()

            logger.info(f"成功删除Agent ID: {agent_id}")
            return True

        except Exception as e:
            self._rollback()
            logger.error(f"删除Agent失败: {e}")
            return False