100 lines
3.0 KiB
Python
100 lines
3.0 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
API路由模块,为前端提供REST API接口
|
||
"""
|
||
|
||
import os
|
||
from fastapi import APIRouter, Depends, HTTPException, status, Body
|
||
from typing import Dict, Any, List, Optional
|
||
from pydantic import BaseModel
|
||
import json
|
||
import logging
|
||
|
||
from cryptoai.api.deepseek_api import DeepSeekAPI
|
||
from cryptoai.utils.config_loader import ConfigLoader
|
||
from fastapi.responses import StreamingResponse
|
||
from cryptoai.routes.user import get_current_user
|
||
import requests
|
||
from datetime import datetime
|
||
from cryptoai.utils.db_manager import get_db_manager
|
||
|
||
# 创建路由
|
||
router = APIRouter()
|
||
|
||
class ChatRequest(BaseModel):
|
||
user_prompt: str
|
||
agent_id: str
|
||
|
||
|
||
@router.get("/list")
|
||
async def get_agents(current_user: Dict[str, Any] = Depends(get_current_user)):
|
||
"""
|
||
获取所有代理
|
||
"""
|
||
return [
|
||
{
|
||
"id": "1",
|
||
"name": "加密货币交易助手",
|
||
"hello_prompt": "您好,我是加密货币交易助手,为您提供专业的数字货币交易分析和建议",
|
||
"description": "帮你分析做加密货币技术分析",
|
||
},
|
||
{
|
||
"id": "2",
|
||
"name": "美股交易助手",
|
||
"hello_prompt": "您好,我是美股交易助手,您可以直接输入股票名称,比如AAPL,然后我会为您提供专业的股票交易分析和建议",
|
||
"description": "帮你分析做美股股票技术分析",
|
||
},
|
||
# {
|
||
# "id": "3",
|
||
# "name": "期货交易助手",
|
||
# "hello_prompt": "您好,我是期货交易助手,为您提供专业的期货交易分析和建议",
|
||
# "description": "帮你分析做期货技术分析",
|
||
# }
|
||
]
|
||
|
||
|
||
@router.post("/chat")
|
||
async def chat(request: ChatRequest,current_user: Dict[str, Any] = Depends(get_current_user)):
|
||
"""
|
||
聊天接口
|
||
"""
|
||
if request.agent_id == "1":
|
||
token = "app-vhJecqbcLukf72g0uxAb9tcz"
|
||
elif request.agent_id == "2":
|
||
token = "app-FLIYXrCbbQIkwgXx02Y1Mxjg"
|
||
else:
|
||
raise HTTPException(status_code=400, detail="Invalid agent ID")
|
||
|
||
inputs = {}
|
||
if request.agent_id == "2":
|
||
inputs = {
|
||
"current_date": datetime.now().strftime("%Y-%m-%d")
|
||
}
|
||
|
||
url = "https://mate.aimateplus.com/v1/chat-messages"
|
||
headers = {
|
||
"Authorization": f"Bearer {token}",
|
||
"Content-Type": "application/json"
|
||
}
|
||
data = {
|
||
"inputs" : inputs,
|
||
"query" : request.user_prompt,
|
||
"response_mode" : "streaming",
|
||
"user" : current_user["mail"]
|
||
}
|
||
|
||
# 保存用户提问
|
||
get_db_manager().save_user_question(current_user["id"], request.agent_id, request.user_prompt)
|
||
|
||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||
|
||
#获取response 的 stream
|
||
def stream_response():
|
||
for chunk in response.iter_content(chunk_size=1024):
|
||
if chunk:
|
||
yield chunk
|
||
|
||
return StreamingResponse(stream_response(), media_type="text/plain")
|