127 lines
3.5 KiB
Python
127 lines
3.5 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
API路由模块,为前端提供REST API接口
|
||
"""
|
||
|
||
import os
|
||
from fastapi import APIRouter, Depends, HTTPException, status, Body, Query, Path
|
||
from typing import Dict, Any, List, Optional
|
||
from pydantic import BaseModel
|
||
import json
|
||
import logging
|
||
|
||
from cryptoai.api.deepseek_api import DeepSeekAPI
|
||
from cryptoai.utils.config_loader import ConfigLoader
|
||
from fastapi.responses import StreamingResponse
|
||
from cryptoai.routes.user import get_current_user
|
||
import requests
|
||
from datetime import datetime
|
||
from cryptoai.utils.db_manager import get_db_manager
|
||
|
||
# Create the router that the endpoint functions below register on.
# NOTE(review): presumably included into the FastAPI app (possibly with a
# prefix) elsewhere — confirm against the application setup.
router = APIRouter()
|
||
|
||
class ChatRequest(BaseModel):
    """Request body for the ``/chat`` endpoint."""

    # The user's message; forwarded to Dify as the ``query`` field.
    user_prompt: str
    # Primary key of the agent whose Dify app should answer.
    agent_id: int
    # Existing Dify conversation to continue; when empty/None a new
    # conversation is started (the field is only sent upstream if truthy).
    conversation_id: Optional[str] = None
|
||
|
||
class AgentCreate(BaseModel):
    """Request body for ``/create``; every field is passed to the DB layer."""

    # Display name of the agent (required).
    name: str
    # Optional free-text description.
    description: Optional[str] = None
    # Optional greeting text. NOTE(review): presumably shown when a chat
    # starts — its consumer is not visible in this module.
    hello_prompt: Optional[str] = None
    # Bearer token for the agent's Dify app; ``/chat`` sends it in the
    # Authorization header, so an agent without one cannot chat.
    dify_token: Optional[str] = None
    # Extra input variables for the Dify app. ``/chat`` reads the stored
    # agent's ``inputs`` and adds ``current_date`` before sending them.
    inputs: Optional[Dict[str, Any]] = None
|
||
|
||
class AgentUpdate(BaseModel):
    """Agent update payload in which every field is optional.

    NOTE(review): not referenced by any endpoint in this module; presumably
    intended for partial updates where ``None`` means "leave unchanged" —
    confirm with the caller before relying on that.
    """

    # New display name, if provided.
    name: Optional[str] = None
    # New description, if provided.
    description: Optional[str] = None
    # New greeting text, if provided.
    hello_prompt: Optional[str] = None
    # New Dify app bearer token, if provided.
    dify_token: Optional[str] = None
    # New Dify input variables, if provided.
    inputs: Optional[Dict[str, Any]] = None
|
||
|
||
|
||
@router.post("/create")
async def create_agent(agent: AgentCreate):
    """Create a new AI agent from the submitted payload.

    The payload fields are handed positionally to the DB manager's
    ``create_agent`` in the order name, description, hello_prompt,
    dify_token, inputs. Whatever it returns is sent back unchanged.
    """
    db = get_db_manager()
    # Keep the request model and the created record under distinct names
    # instead of rebinding the ``agent`` parameter.
    created = db.create_agent(
        agent.name,
        agent.description,
        agent.hello_prompt,
        agent.dify_token,
        agent.inputs,
    )
    return created
|
||
|
||
|
||
@router.get("/list", response_model=List[Dict[str, Any]])
async def get_agents(
    skip: int = Query(0, ge=0),
    limit: int = Query(100, ge=1, le=1000),
):
    """Return one page of agents from the database.

    Args:
        skip: Number of agents to skip (validated as ``>= 0``).
        limit: Maximum number of agents to return (validated to 1..1000).

    Returns:
        The list produced by the DB manager's ``list_agents``.
    """
    return get_db_manager().list_agents(skip=skip, limit=limit)
|
||
|
||
|
||
@router.post("/chat")
async def chat(request: ChatRequest, current_user: Dict[str, Any] = Depends(get_current_user)):
    """Proxy a chat message to the agent's Dify app and stream the reply back.

    Steps:
        1. Look up the agent by ``request.agent_id``.
        2. Build a streaming ``chat-messages`` request from the agent's
           ``dify_token`` and stored ``inputs``, plus ``current_date``.
        3. Record the user's question.
        4. Relay the upstream body chunk by chunk.

    Args:
        request: Prompt, target agent id and optional Dify conversation id.
        current_user: Authenticated user supplied by ``get_current_user``.
            Its ``"mail"`` is sent as the Dify ``user``, and its ``"id"`` is
            used when saving the question.

    Returns:
        A ``StreamingResponse`` (``text/plain``) relaying the upstream body.

    Raises:
        HTTPException:
            - 400 if the agent id is malformed or no such agent exists.
            - 500 if the agent has no ``dify_token`` configured.
            - 502 if the upstream request fails at the transport level
              (connection error, timeout).
            - The upstream status code if Dify answers with a non-200 status.
    """
    # Local import keeps the module-level import block unchanged.
    from fastapi.concurrency import run_in_threadpool

    # Only the conversion belongs in the try, so that errors raised by the DB
    # lookup are not misreported as a malformed id.
    try:
        agent_id = int(request.agent_id)
    except ValueError:
        raise HTTPException(status_code=400, detail="Invalid agent ID format") from None

    agent = get_db_manager().get_agent_by_id(agent_id)
    if not agent:
        raise HTTPException(status_code=400, detail="Invalid agent ID")

    token = agent.get("dify_token")
    if not token:
        # dify_token is optional at creation time; without it the upstream
        # call would be sent as "Bearer None" and fail with an opaque error.
        raise HTTPException(status_code=500, detail="Agent has no Dify token configured")

    # Copy before adding current_date so the dict returned by the DB layer is
    # never mutated (in case it is cached or shared between requests).
    inputs = dict(agent.get("inputs") or {})
    inputs["current_date"] = datetime.now().strftime("%Y-%m-%d")

    url = "https://mate.aimateplus.com/v1/chat-messages"
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    data = {
        "inputs": inputs,
        "query": request.user_prompt,
        "response_mode": "streaming",
        "user": current_user["mail"],
    }
    if request.conversation_id:
        data["conversation_id"] = request.conversation_id

    logging.info("Chat request data: %s", data)

    # Persist the user's question before contacting the upstream service.
    get_db_manager().save_user_question(current_user["id"], request.agent_id, request.user_prompt)

    try:
        # requests is synchronous: run it in a worker thread so the event loop
        # is not blocked while waiting for the upstream headers.
        # timeout=(connect, read): the read timeout bounds the gap between
        # received bytes, not the total length of the stream.
        response = await run_in_threadpool(
            requests.post,
            url,
            headers=headers,
            json=data,
            stream=True,
            timeout=(10, 300),
        )
    except requests.RequestException as exc:
        logging.error("Dify API request failed: %s", exc)
        raise HTTPException(status_code=502, detail="Failed to reach Dify API") from exc

    # If the upstream response is not successful, return an error.
    if response.status_code != 200:
        try:
            detail = f"Failed to get response from Dify API: {response.text}"
        finally:
            # With stream=True the pooled connection stays checked out until
            # the response is closed.
            response.close()
        raise HTTPException(status_code=response.status_code, detail=detail)

    def stream_response():
        # Relay upstream bytes as they arrive. Always release the connection,
        # including when iteration stops early and the generator is closed.
        try:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    yield chunk
        finally:
            response.close()

    return StreamingResponse(stream_response(), media_type="text/plain")
|