update
This commit is contained in:
parent
560597a892
commit
2654788e8f
@ -6,7 +6,7 @@ API路由模块,为前端提供REST API接口
|
||||
"""
|
||||
|
||||
import os
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Body
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Body, Query, Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pydantic import BaseModel
|
||||
import json
|
||||
@ -25,53 +25,85 @@ router = APIRouter()
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
user_prompt: str
|
||||
agent_id: str
|
||||
agent_id: int
|
||||
|
||||
class AgentCreate(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
hello_prompt: Optional[str] = None
|
||||
dify_token: Optional[str] = None
|
||||
inputs: Optional[Dict[str, Any]] = None
|
||||
|
||||
class AgentUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
hello_prompt: Optional[str] = None
|
||||
dify_token: Optional[str] = None
|
||||
inputs: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@router.get("/list")
|
||||
async def get_agents(current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
@router.post("/create")
|
||||
async def create_agent(agent: AgentCreate):
|
||||
"""
|
||||
创建新的AI Agent
|
||||
"""
|
||||
|
||||
agent = get_db_manager().create_agent(agent.name, agent.description, agent.hello_prompt, agent.dify_token, agent.inputs)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@router.get("/list", response_model=List[Dict[str, Any]])
|
||||
async def get_agents(
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=1000)
|
||||
):
|
||||
"""
|
||||
获取所有代理
|
||||
"""
|
||||
return [
|
||||
{
|
||||
"id": "1",
|
||||
"name": "加密货币交易助手",
|
||||
"hello_prompt": "您好,我是加密货币交易助手,为您提供专业的数字货币交易分析和建议",
|
||||
"description": "帮你分析做加密货币技术分析",
|
||||
},
|
||||
{
|
||||
"id": "2",
|
||||
"name": "美股交易助手",
|
||||
"hello_prompt": "您好,我是美股交易助手,您可以直接输入股票名称,比如AAPL,然后我会为您提供专业的股票交易分析和建议",
|
||||
"description": "帮你分析做美股股票技术分析",
|
||||
},
|
||||
# # 检查用户权限
|
||||
# if current_user.get("level", 0) < 2: # 假设需要SVIP权限才能查看全部Agent
|
||||
# # 使用硬编码的默认Agent
|
||||
# return [
|
||||
# {
|
||||
# "id": "3",
|
||||
# "name": "期货交易助手",
|
||||
# "hello_prompt": "您好,我是期货交易助手,为您提供专业的期货交易分析和建议",
|
||||
# "description": "帮你分析做期货技术分析",
|
||||
# }
|
||||
]
|
||||
# "id": "1",
|
||||
# "name": "Crypto Assistant",
|
||||
# "hello_prompt": "您好,我是加密货币交易助手,为您提供专业的数字货币交易分析和建议",
|
||||
# "description": "帮你分析做加密货币技术分析",
|
||||
# },
|
||||
# {
|
||||
# "id": "2",
|
||||
# "name": "US Stock Assistant",
|
||||
# "hello_prompt": "您好,我是美股交易助手,您可以直接输入股票名称,比如AAPL,然后我会为您提供专业的股票交易分析和建议",
|
||||
# "description": "帮你分析做美股股票技术分析",
|
||||
# },
|
||||
# ]
|
||||
# else:
|
||||
# 从数据库获取Agent列表
|
||||
agents = get_db_manager().list_agents(limit=limit, skip=skip)
|
||||
return agents
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
async def chat(request: ChatRequest,current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
async def chat(request: ChatRequest, current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
"""
|
||||
聊天接口
|
||||
"""
|
||||
if request.agent_id == "1":
|
||||
token = "app-vhJecqbcLukf72g0uxAb9tcz"
|
||||
elif request.agent_id == "2":
|
||||
token = "app-FLIYXrCbbQIkwgXx02Y1Mxjg"
|
||||
else:
|
||||
# 尝试从数据库获取Agent
|
||||
try:
|
||||
agent_id = int(request.agent_id)
|
||||
agent = get_db_manager().get_agent_by_id(agent_id)
|
||||
|
||||
if not agent:
|
||||
raise HTTPException(status_code=400, detail="Invalid agent ID")
|
||||
|
||||
inputs = {}
|
||||
if request.agent_id == "2":
|
||||
inputs = {
|
||||
"current_date": datetime.now().strftime("%Y-%m-%d")
|
||||
}
|
||||
token = agent.get("dify_token")
|
||||
inputs = agent.get("inputs") or {}
|
||||
if agent.get("id") == 2:
|
||||
inputs["current_date"] = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid agent ID format")
|
||||
|
||||
url = "https://mate.aimateplus.com/v1/chat-messages"
|
||||
headers = {
|
||||
@ -79,18 +111,27 @@ async def chat(request: ChatRequest,current_user: Dict[str, Any] = Depends(get_c
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data = {
|
||||
"inputs" : inputs,
|
||||
"query" : request.user_prompt,
|
||||
"response_mode" : "streaming",
|
||||
"user" : current_user["mail"]
|
||||
"inputs": inputs,
|
||||
"query": request.user_prompt,
|
||||
"response_mode": "streaming",
|
||||
"user": current_user["mail"]
|
||||
}
|
||||
|
||||
logging.info(f"Chat request data: {data}")
|
||||
|
||||
# 保存用户提问
|
||||
get_db_manager().save_user_question(current_user["id"], request.agent_id, request.user_prompt)
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
|
||||
#获取response 的 stream
|
||||
# 如果响应不成功,返回错误
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=f"Failed to get response from Dify API: {response.text}"
|
||||
)
|
||||
|
||||
# 获取response的stream
|
||||
def stream_response():
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
|
||||
@ -109,6 +109,27 @@ class UserQuestion(Base):
|
||||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||||
)
|
||||
|
||||
# 定义Agent数据模型
|
||||
class Agent(Base):
|
||||
"""Agent数据表模型"""
|
||||
__tablename__ = 'agents'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
name = Column(String(100), nullable=False, unique=True, comment='Agent名称')
|
||||
hello_prompt = Column(Text, nullable=True, comment='欢迎提示语')
|
||||
description = Column(Text, nullable=True, comment='Agent描述')
|
||||
dify_token = Column(String(255), nullable=True, comment='Dify API令牌')
|
||||
inputs = Column(JSON, nullable=True, comment='输入参数JSON')
|
||||
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||||
update_time = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间')
|
||||
|
||||
# 索引和表属性
|
||||
__table_args__ = (
|
||||
Index('idx_name', 'name'),
|
||||
Index('idx_create_time', 'create_time'),
|
||||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||||
)
|
||||
|
||||
class DBManager:
|
||||
"""数据库管理工具,用于连接MySQL数据库并保存智能体分析结果"""
|
||||
|
||||
@ -745,11 +766,339 @@ class DBManager:
|
||||
return None
|
||||
|
||||
def close(self) -> None:
|
||||
"""关闭数据库连接"""
|
||||
"""关闭数据库连接(如果存在)"""
|
||||
if self.engine:
|
||||
self.engine.dispose()
|
||||
self.engine = None
|
||||
logger.info("数据库连接已关闭")
|
||||
self.engine = None
|
||||
self.Session = None
|
||||
|
||||
def create_agent(self, name: str, description: Optional[str] = None,
|
||||
hello_prompt: Optional[str] = None, dify_token: Optional[str] = None,
|
||||
inputs: Optional[Dict[str, Any]] = None) -> Optional[int]:
|
||||
"""
|
||||
创建新的Agent
|
||||
|
||||
Args:
|
||||
name: Agent名称
|
||||
description: Agent描述
|
||||
hello_prompt: 欢迎提示语
|
||||
dify_token: Dify API令牌
|
||||
inputs: 输入参数JSON
|
||||
|
||||
Returns:
|
||||
新Agent的ID,如果创建失败则返回None
|
||||
"""
|
||||
if not self.engine:
|
||||
try:
|
||||
self._init_db()
|
||||
except Exception as e:
|
||||
logger.error(f"重新连接数据库失败: {e}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 创建会话
|
||||
session = self.Session()
|
||||
|
||||
try:
|
||||
# 检查名称是否已存在
|
||||
existing_agent = session.query(Agent).filter(Agent.name == name).first()
|
||||
if existing_agent:
|
||||
logger.warning(f"Agent名称 {name} 已存在")
|
||||
return None
|
||||
|
||||
# 创建新Agent
|
||||
new_agent = Agent(
|
||||
name=name,
|
||||
description=description,
|
||||
hello_prompt=hello_prompt,
|
||||
dify_token=dify_token,
|
||||
inputs=inputs,
|
||||
create_time=datetime.now(),
|
||||
update_time=datetime.now()
|
||||
)
|
||||
|
||||
# 添加并提交
|
||||
session.add(new_agent)
|
||||
session.commit()
|
||||
|
||||
logger.info(f"成功创建Agent: {name}")
|
||||
return new_agent.id
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"创建Agent失败: {e}")
|
||||
return None
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建数据库会话失败: {e}")
|
||||
return None
|
||||
|
||||
def update_agent(self, agent_id: int, name: Optional[str] = None,
|
||||
description: Optional[str] = None, hello_prompt: Optional[str] = None,
|
||||
dify_token: Optional[str] = None, inputs: Optional[Dict[str, Any]] = None) -> bool:
|
||||
"""
|
||||
更新Agent信息
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
name: Agent名称
|
||||
description: Agent描述
|
||||
hello_prompt: 欢迎提示语
|
||||
dify_token: Dify API令牌
|
||||
inputs: 输入参数JSON
|
||||
|
||||
Returns:
|
||||
更新是否成功
|
||||
"""
|
||||
if not self.engine:
|
||||
try:
|
||||
self._init_db()
|
||||
except Exception as e:
|
||||
logger.error(f"重新连接数据库失败: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 创建会话
|
||||
session = self.Session()
|
||||
|
||||
try:
|
||||
# 查询Agent
|
||||
agent = session.query(Agent).filter(Agent.id == agent_id).first()
|
||||
if not agent:
|
||||
logger.warning(f"Agent ID {agent_id} 不存在")
|
||||
return False
|
||||
|
||||
# 更新字段
|
||||
if name is not None:
|
||||
# 检查名称是否与其他Agent冲突
|
||||
if name != agent.name:
|
||||
existing = session.query(Agent).filter(Agent.name == name).first()
|
||||
if existing:
|
||||
logger.warning(f"Agent名称 {name} 已被其他Agent使用")
|
||||
return False
|
||||
agent.name = name
|
||||
|
||||
if description is not None:
|
||||
agent.description = description
|
||||
|
||||
if hello_prompt is not None:
|
||||
agent.hello_prompt = hello_prompt
|
||||
|
||||
if dify_token is not None:
|
||||
agent.dify_token = dify_token
|
||||
|
||||
if inputs is not None:
|
||||
agent.inputs = inputs
|
||||
|
||||
agent.update_time = datetime.now()
|
||||
|
||||
# 提交更改
|
||||
session.commit()
|
||||
|
||||
logger.info(f"成功更新Agent ID: {agent_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"更新Agent失败: {e}")
|
||||
return False
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建数据库会话失败: {e}")
|
||||
return False
|
||||
|
||||
def get_agent_by_id(self, agent_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
通过ID获取Agent信息
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
Agent信息,如果不存在则返回None
|
||||
"""
|
||||
if not self.engine:
|
||||
try:
|
||||
self._init_db()
|
||||
except Exception as e:
|
||||
logger.error(f"重新连接数据库失败: {e}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 创建会话
|
||||
session = self.Session()
|
||||
|
||||
try:
|
||||
# 查询Agent
|
||||
agent = session.query(Agent).filter(Agent.id == agent_id).first()
|
||||
|
||||
if agent:
|
||||
# 转换为字典
|
||||
return {
|
||||
'id': agent.id,
|
||||
'name': agent.name,
|
||||
'description': agent.description,
|
||||
'hello_prompt': agent.hello_prompt,
|
||||
'dify_token': agent.dify_token,
|
||||
'inputs': agent.inputs,
|
||||
'create_time': agent.create_time,
|
||||
'update_time': agent.update_time
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取Agent信息失败: {e}")
|
||||
return None
|
||||
|
||||
def get_agent_by_name(self, name: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
通过名称获取Agent信息
|
||||
|
||||
Args:
|
||||
name: Agent名称
|
||||
|
||||
Returns:
|
||||
Agent信息,如果不存在则返回None
|
||||
"""
|
||||
if not self.engine:
|
||||
try:
|
||||
self._init_db()
|
||||
except Exception as e:
|
||||
logger.error(f"重新连接数据库失败: {e}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 创建会话
|
||||
session = self.Session()
|
||||
|
||||
try:
|
||||
# 查询Agent
|
||||
agent = session.query(Agent).filter(Agent.name == name).first()
|
||||
|
||||
if agent:
|
||||
# 转换为字典
|
||||
return {
|
||||
'id': agent.id,
|
||||
'name': agent.name,
|
||||
'description': agent.description,
|
||||
'hello_prompt': agent.hello_prompt,
|
||||
'dify_token': agent.dify_token,
|
||||
'inputs': agent.inputs,
|
||||
'create_time': agent.create_time,
|
||||
'update_time': agent.update_time
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取Agent信息失败: {e}")
|
||||
return None
|
||||
|
||||
def list_agents(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取Agent列表
|
||||
|
||||
Args:
|
||||
limit: 返回的最大数量
|
||||
skip: 跳过的数量
|
||||
|
||||
Returns:
|
||||
Agent列表
|
||||
"""
|
||||
if not self.engine:
|
||||
try:
|
||||
self._init_db()
|
||||
except Exception as e:
|
||||
logger.error(f"重新连接数据库失败: {e}")
|
||||
return []
|
||||
|
||||
try:
|
||||
# 创建会话
|
||||
session = self.Session()
|
||||
|
||||
try:
|
||||
# 查询Agent列表
|
||||
agents = session.query(Agent).order_by(Agent.create_time.asc()).offset(skip).limit(limit).all()
|
||||
|
||||
# 转换为字典列表
|
||||
return [{
|
||||
'id': agent.id,
|
||||
'name': agent.name,
|
||||
'description': agent.description,
|
||||
'hello_prompt': agent.hello_prompt,
|
||||
'dify_token': agent.dify_token,
|
||||
'inputs': agent.inputs,
|
||||
'create_time': agent.create_time,
|
||||
'update_time': agent.update_time
|
||||
} for agent in agents]
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取Agent列表失败: {e}")
|
||||
return []
|
||||
|
||||
def delete_agent(self, agent_id: int) -> bool:
|
||||
"""
|
||||
删除Agent
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
删除是否成功
|
||||
"""
|
||||
if not self.engine:
|
||||
try:
|
||||
self._init_db()
|
||||
except Exception as e:
|
||||
logger.error(f"重新连接数据库失败: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 创建会话
|
||||
session = self.Session()
|
||||
|
||||
try:
|
||||
# 查询Agent
|
||||
agent = session.query(Agent).filter(Agent.id == agent_id).first()
|
||||
if not agent:
|
||||
logger.warning(f"Agent ID {agent_id} 不存在")
|
||||
return False
|
||||
|
||||
# 删除Agent
|
||||
session.delete(agent)
|
||||
session.commit()
|
||||
|
||||
logger.info(f"成功删除Agent ID: {agent_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"删除Agent失败: {e}")
|
||||
return False
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建数据库会话失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# 单例模式
|
||||
|
||||
@ -29,7 +29,7 @@ services:
|
||||
cryptoai-api:
|
||||
build: .
|
||||
container_name: cryptoai-api
|
||||
image: cryptoai-api:0.0.10
|
||||
image: cryptoai-api:0.0.11
|
||||
restart: always
|
||||
ports:
|
||||
- "8000:8000"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user