This commit is contained in:
aaron 2025-05-09 23:37:28 +08:00
parent 560597a892
commit 2654788e8f
3 changed files with 435 additions and 45 deletions

View File

@ -6,7 +6,7 @@ API路由模块为前端提供REST API接口
""" """
import os import os
from fastapi import APIRouter, Depends, HTTPException, status, Body from fastapi import APIRouter, Depends, HTTPException, status, Body, Query, Path
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
from pydantic import BaseModel from pydantic import BaseModel
import json import json
@ -25,34 +25,63 @@ router = APIRouter()
class ChatRequest(BaseModel): class ChatRequest(BaseModel):
user_prompt: str user_prompt: str
agent_id: str agent_id: int
class AgentCreate(BaseModel):
name: str
description: Optional[str] = None
hello_prompt: Optional[str] = None
dify_token: Optional[str] = None
inputs: Optional[Dict[str, Any]] = None
class AgentUpdate(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
hello_prompt: Optional[str] = None
dify_token: Optional[str] = None
inputs: Optional[Dict[str, Any]] = None
@router.get("/list") @router.post("/create")
async def get_agents(current_user: Dict[str, Any] = Depends(get_current_user)): async def create_agent(agent: AgentCreate):
"""
创建新的AI Agent
"""
agent = get_db_manager().create_agent(agent.name, agent.description, agent.hello_prompt, agent.dify_token, agent.inputs)
return agent
@router.get("/list", response_model=List[Dict[str, Any]])
async def get_agents(
skip: int = Query(0, ge=0),
limit: int = Query(100, ge=1, le=1000)
):
""" """
获取所有代理 获取所有代理
""" """
return [ # # 检查用户权限
{ # if current_user.get("level", 0) < 2: # 假设需要SVIP权限才能查看全部Agent
"id": "1", # # 使用硬编码的默认Agent
"name": "加密货币交易助手", # return [
"hello_prompt": "您好,我是加密货币交易助手,为您提供专业的数字货币交易分析和建议",
"description": "帮你分析做加密货币技术分析",
},
{
"id": "2",
"name": "美股交易助手",
"hello_prompt": "您好我是美股交易助手您可以直接输入股票名称比如AAPL然后我会为您提供专业的股票交易分析和建议",
"description": "帮你分析做美股股票技术分析",
},
# { # {
# "id": "3", # "id": "1",
# "name": "期货交易助手", # "name": "Crypto Assistant",
# "hello_prompt": "您好,我是期货交易助手,为您提供专业的期货交易分析和建议", # "hello_prompt": "您好,我是加密货币交易助手,为您提供专业的数字货币交易分析和建议",
# "description": "帮你分析做期货技术分析", # "description": "帮你分析做加密货币技术分析",
# } # },
] # {
# "id": "2",
# "name": "US Stock Assistant",
# "hello_prompt": "您好我是美股交易助手您可以直接输入股票名称比如AAPL然后我会为您提供专业的股票交易分析和建议",
# "description": "帮你分析做美股股票技术分析",
# },
# ]
# else:
# 从数据库获取Agent列表
agents = get_db_manager().list_agents(limit=limit, skip=skip)
return agents
@router.post("/chat") @router.post("/chat")
@ -60,18 +89,21 @@ async def chat(request: ChatRequest,current_user: Dict[str, Any] = Depends(get_c
""" """
聊天接口 聊天接口
""" """
if request.agent_id == "1": # 尝试从数据库获取Agent
token = "app-vhJecqbcLukf72g0uxAb9tcz" try:
elif request.agent_id == "2": agent_id = int(request.agent_id)
token = "app-FLIYXrCbbQIkwgXx02Y1Mxjg" agent = get_db_manager().get_agent_by_id(agent_id)
else:
if not agent:
raise HTTPException(status_code=400, detail="Invalid agent ID") raise HTTPException(status_code=400, detail="Invalid agent ID")
inputs = {} token = agent.get("dify_token")
if request.agent_id == "2": inputs = agent.get("inputs") or {}
inputs = { if agent.get("id") == 2:
"current_date": datetime.now().strftime("%Y-%m-%d") inputs["current_date"] = datetime.now().strftime("%Y-%m-%d")
}
except ValueError:
raise HTTPException(status_code=400, detail="Invalid agent ID format")
url = "https://mate.aimateplus.com/v1/chat-messages" url = "https://mate.aimateplus.com/v1/chat-messages"
headers = { headers = {
@ -85,11 +117,20 @@ async def chat(request: ChatRequest,current_user: Dict[str, Any] = Depends(get_c
"user": current_user["mail"] "user": current_user["mail"]
} }
logging.info(f"Chat request data: {data}")
# 保存用户提问 # 保存用户提问
get_db_manager().save_user_question(current_user["id"], request.agent_id, request.user_prompt) get_db_manager().save_user_question(current_user["id"], request.agent_id, request.user_prompt)
response = requests.post(url, headers=headers, json=data, stream=True) response = requests.post(url, headers=headers, json=data, stream=True)
# 如果响应不成功,返回错误
if response.status_code != 200:
raise HTTPException(
status_code=response.status_code,
detail=f"Failed to get response from Dify API: {response.text}"
)
# 获取response的stream # 获取response的stream
def stream_response(): def stream_response():
for chunk in response.iter_content(chunk_size=1024): for chunk in response.iter_content(chunk_size=1024):

View File

@ -109,6 +109,27 @@ class UserQuestion(Base):
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
) )
# 定义Agent数据模型
class Agent(Base):
"""Agent数据表模型"""
__tablename__ = 'agents'
id = Column(Integer, primary_key=True, autoincrement=True)
name = Column(String(100), nullable=False, unique=True, comment='Agent名称')
hello_prompt = Column(Text, nullable=True, comment='欢迎提示语')
description = Column(Text, nullable=True, comment='Agent描述')
dify_token = Column(String(255), nullable=True, comment='Dify API令牌')
inputs = Column(JSON, nullable=True, comment='输入参数JSON')
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
update_time = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间')
# 索引和表属性
__table_args__ = (
Index('idx_name', 'name'),
Index('idx_create_time', 'create_time'),
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
)
class DBManager: class DBManager:
"""数据库管理工具用于连接MySQL数据库并保存智能体分析结果""" """数据库管理工具用于连接MySQL数据库并保存智能体分析结果"""
@ -745,11 +766,339 @@ class DBManager:
return None return None
def close(self) -> None: def close(self) -> None:
"""关闭数据库连接""" """关闭数据库连接(如果存在)"""
if self.engine: if self.engine:
self.engine.dispose() self.engine.dispose()
self.engine = None
logger.info("数据库连接已关闭") logger.info("数据库连接已关闭")
self.engine = None
self.Session = None
def create_agent(self, name: str, description: Optional[str] = None,
hello_prompt: Optional[str] = None, dify_token: Optional[str] = None,
inputs: Optional[Dict[str, Any]] = None) -> Optional[int]:
"""
创建新的Agent
Args:
name: Agent名称
description: Agent描述
hello_prompt: 欢迎提示语
dify_token: Dify API令牌
inputs: 输入参数JSON
Returns:
新Agent的ID如果创建失败则返回None
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return None
try:
# 创建会话
session = self.Session()
try:
# 检查名称是否已存在
existing_agent = session.query(Agent).filter(Agent.name == name).first()
if existing_agent:
logger.warning(f"Agent名称 {name} 已存在")
return None
# 创建新Agent
new_agent = Agent(
name=name,
description=description,
hello_prompt=hello_prompt,
dify_token=dify_token,
inputs=inputs,
create_time=datetime.now(),
update_time=datetime.now()
)
# 添加并提交
session.add(new_agent)
session.commit()
logger.info(f"成功创建Agent: {name}")
return new_agent.id
except Exception as e:
session.rollback()
logger.error(f"创建Agent失败: {e}")
return None
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return None
def update_agent(self, agent_id: int, name: Optional[str] = None,
description: Optional[str] = None, hello_prompt: Optional[str] = None,
dify_token: Optional[str] = None, inputs: Optional[Dict[str, Any]] = None) -> bool:
"""
更新Agent信息
Args:
agent_id: Agent ID
name: Agent名称
description: Agent描述
hello_prompt: 欢迎提示语
dify_token: Dify API令牌
inputs: 输入参数JSON
Returns:
更新是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
# 创建会话
session = self.Session()
try:
# 查询Agent
agent = session.query(Agent).filter(Agent.id == agent_id).first()
if not agent:
logger.warning(f"Agent ID {agent_id} 不存在")
return False
# 更新字段
if name is not None:
# 检查名称是否与其他Agent冲突
if name != agent.name:
existing = session.query(Agent).filter(Agent.name == name).first()
if existing:
logger.warning(f"Agent名称 {name} 已被其他Agent使用")
return False
agent.name = name
if description is not None:
agent.description = description
if hello_prompt is not None:
agent.hello_prompt = hello_prompt
if dify_token is not None:
agent.dify_token = dify_token
if inputs is not None:
agent.inputs = inputs
agent.update_time = datetime.now()
# 提交更改
session.commit()
logger.info(f"成功更新Agent ID: {agent_id}")
return True
except Exception as e:
session.rollback()
logger.error(f"更新Agent失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return False
def get_agent_by_id(self, agent_id: int) -> Optional[Dict[str, Any]]:
"""
通过ID获取Agent信息
Args:
agent_id: Agent ID
Returns:
Agent信息如果不存在则返回None
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return None
try:
# 创建会话
session = self.Session()
try:
# 查询Agent
agent = session.query(Agent).filter(Agent.id == agent_id).first()
if agent:
# 转换为字典
return {
'id': agent.id,
'name': agent.name,
'description': agent.description,
'hello_prompt': agent.hello_prompt,
'dify_token': agent.dify_token,
'inputs': agent.inputs,
'create_time': agent.create_time,
'update_time': agent.update_time
}
else:
return None
finally:
session.close()
except Exception as e:
logger.error(f"获取Agent信息失败: {e}")
return None
def get_agent_by_name(self, name: str) -> Optional[Dict[str, Any]]:
"""
通过名称获取Agent信息
Args:
name: Agent名称
Returns:
Agent信息如果不存在则返回None
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return None
try:
# 创建会话
session = self.Session()
try:
# 查询Agent
agent = session.query(Agent).filter(Agent.name == name).first()
if agent:
# 转换为字典
return {
'id': agent.id,
'name': agent.name,
'description': agent.description,
'hello_prompt': agent.hello_prompt,
'dify_token': agent.dify_token,
'inputs': agent.inputs,
'create_time': agent.create_time,
'update_time': agent.update_time
}
else:
return None
finally:
session.close()
except Exception as e:
logger.error(f"获取Agent信息失败: {e}")
return None
def list_agents(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]:
"""
获取Agent列表
Args:
limit: 返回的最大数量
skip: 跳过的数量
Returns:
Agent列表
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return []
try:
# 创建会话
session = self.Session()
try:
# 查询Agent列表
agents = session.query(Agent).order_by(Agent.create_time.asc()).offset(skip).limit(limit).all()
# 转换为字典列表
return [{
'id': agent.id,
'name': agent.name,
'description': agent.description,
'hello_prompt': agent.hello_prompt,
'dify_token': agent.dify_token,
'inputs': agent.inputs,
'create_time': agent.create_time,
'update_time': agent.update_time
} for agent in agents]
finally:
session.close()
except Exception as e:
logger.error(f"获取Agent列表失败: {e}")
return []
def delete_agent(self, agent_id: int) -> bool:
"""
删除Agent
Args:
agent_id: Agent ID
Returns:
删除是否成功
"""
if not self.engine:
try:
self._init_db()
except Exception as e:
logger.error(f"重新连接数据库失败: {e}")
return False
try:
# 创建会话
session = self.Session()
try:
# 查询Agent
agent = session.query(Agent).filter(Agent.id == agent_id).first()
if not agent:
logger.warning(f"Agent ID {agent_id} 不存在")
return False
# 删除Agent
session.delete(agent)
session.commit()
logger.info(f"成功删除Agent ID: {agent_id}")
return True
except Exception as e:
session.rollback()
logger.error(f"删除Agent失败: {e}")
return False
finally:
session.close()
except Exception as e:
logger.error(f"创建数据库会话失败: {e}")
return False
# 单例模式 # 单例模式

View File

@ -29,7 +29,7 @@ services:
cryptoai-api: cryptoai-api:
build: . build: .
container_name: cryptoai-api container_name: cryptoai-api
image: cryptoai-api:0.0.10 image: cryptoai-api:0.0.11
restart: always restart: always
ports: ports:
- "8000:8000" - "8000:8000"