347 lines
12 KiB
Python
347 lines
12 KiB
Python
import os
from datetime import date, timedelta
from datetime import datetime
from typing import Optional
from urllib.parse import quote

import requests
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session

from cryptoai.models.analysis_history import AnalysisHistoryManager
from cryptoai.models.token import TokenManager
from cryptoai.models.user import UserManager
from cryptoai.models.user_question import UserQuestionManager
from cryptoai.models.user_subscription import UserSubscriptionManager
from cryptoai.routes.user import get_current_user
from cryptoai.utils.db_manager import get_db
|
||
|
||
class AnalysisHistoryRequest(BaseModel):
    """Request body for ``POST /analysis_history``.

    One analysis result to store in the authenticated user's history; the
    fields are handed to ``AnalysisHistoryManager.add_analysis_history``.
    """

    # Symbol / code the analysis was about.
    symbol: str
    # The analysis text to store.
    content: str
    # Optional timeframe the analysis was produced for.
    timeframe: Optional[str] = None
    # Analysis category (field name is part of the JSON body, hence `type`).
    # NOTE(review): presumably the same values as AnalysisRequest.type
    # ('crypto' / 'astock' / 'usstock') — TODO confirm.
    type: str
|
||
|
||
router = APIRouter()

# Base URL and app key of the Dify chat application used by the chat and
# conversation endpoints in this module. Both can be overridden from the
# environment so a deployment can point elsewhere or rotate the key without a
# code change; the literals remain as fallbacks so existing deployments keep
# working unchanged.
# NOTE(review): a live API key is committed to source control here — consider
# rotating it and dropping the fallback once the environment provides it.
api_url = os.environ.get("DIFY_API_URL", 'https://mate.aimateplus.com/v1')
agent_token = os.environ.get("DIFY_AGENT_TOKEN", 'app-pPtva2AdJ8hJzkBKu12ThWjD')
|
||
|
||
@router.post("/analysis_history")
async def analysis_history(request: AnalysisHistoryRequest,
                           current_user: dict = Depends(get_current_user),
                           session: Session = Depends(get_db)):
    """Save one analysis result into the caller's history.

    Args:
        request: the analysis to store (type, symbol, content, timeframe).
        current_user: authenticated user; its ``id`` owns the new entry.
        session: database session backing ``AnalysisHistoryManager``.

    Returns:
        ``{"message": "ok"}`` after the manager call returns.
    """
    AnalysisHistoryManager(session).add_analysis_history(
        current_user["id"],
        request.type,
        request.symbol,
        request.content,
        request.timeframe,
    )
    return {"message": "ok"}
|
||
|
||
@router.get("/analysis_histories")
async def get_analysis_histories(current_user: dict = Depends(get_current_user),
                                 limit: int = 10,
                                 offset: int = 0,
                                 session: Session = Depends(get_db)):
    """Return a page of the caller's saved analyses.

    Args:
        current_user: authenticated user whose history is read (by ``id``).
        limit: page size forwarded to the manager (default 10).
        offset: page offset forwarded to the manager (default 0).
        session: database session backing ``AnalysisHistoryManager``.

    Returns:
        Whatever ``AnalysisHistoryManager.get_user_analysis_history`` returns.
    """
    return AnalysisHistoryManager(session).get_user_analysis_history(
        current_user["id"], limit=limit, offset=offset
    )
|
||
|
||
|
||
class AnalysisRequest(BaseModel):
    """Request body for ``POST /analysis``.

    ``type`` selects the workflow: ``'crypto'`` uses ``symbol`` and
    ``timeframe``; ``'astock'`` and ``'usstock'`` use ``stock_code``. Any
    other value is rejected by the endpoint with HTTP 400.
    """

    # Crypto symbol to look up (used when type == 'crypto').
    symbol: Optional[str] = None
    # Timeframe forwarded to the crypto workflow (used when type == 'crypto').
    timeframe: Optional[str] = None
    # Stock code (used when type is 'astock' or 'usstock').
    stock_code: Optional[str] = None
    # Workflow selector; the name is part of the JSON body, hence `type`.
    type: str
|
||
|
||
class ChatRequest(BaseModel):
    """Request body for ``POST /chat-messages``."""

    # User message, sent to Dify as the chat `query`.
    message: str
    # Existing Dify conversation to continue; when absent or empty, no
    # conversation_id is sent upstream.
    conversation_id: Optional[str] = None
|
||
|
||
class StopStreamingRequest(BaseModel):
    """Request body for ``POST /stop_streaming``."""

    # Id of the Dify streaming chat task to stop; used in the upstream
    # `/chat-messages/{task_id}/stop` path.
    task_id: str
|
||
|
||
@router.post("/stop_streaming")
async def stop_streaming(request: StopStreamingRequest,
                         current_user: dict = Depends(get_current_user),
                         session: Session = Depends(get_db)):
    """Ask Dify to stop an in-flight streaming chat task.

    Args:
        request: carries the ``task_id`` of the task to stop.
        current_user: authenticated user; its ``mail`` is sent as Dify ``user``.
        session: database session dependency (not used by this handler).

    Returns:
        ``{"message": "ok"}`` when Dify answers 200 with ``result == "success"``.

    Raises:
        HTTPException: Dify's status code when it answers non-200; 502 when
            Dify cannot be reached or answers 200 without a success result.
    """
    # The task id comes from the client: quote it so it cannot alter the
    # upstream path (e.g. with "/", "?" or "..").
    url = f'{api_url}/chat-messages/{quote(request.task_id, safe="")}/stop'
    headers = {
        'Authorization': f'Bearer {agent_token}',
        'Content-Type': 'application/json'
    }
    payload = {
        "user": current_user["mail"]
    }

    try:
        # requests is blocking; run it in a worker thread so the event loop
        # keeps serving other requests, and never wait forever.
        response = await run_in_threadpool(
            requests.post, url, headers=headers, json=payload, timeout=30
        )
    except requests.RequestException as exc:
        raise HTTPException(status_code=502,
                            detail=f"Failed to stop streaming: {exc}") from exc

    if response.status_code != 200:
        raise HTTPException(
            status_code=response.status_code,
            detail=f"Failed to stop streaming: {response.text}"
        )

    try:
        body = response.json()
    except ValueError:  # non-JSON body
        body = None
    if isinstance(body, dict) and body.get("result") == "success":
        return {"message": "ok"}

    # A 200 without a success result is an upstream failure. Reporting it with
    # response.status_code (always 200 here) would make the error look
    # successful to clients, so use 502 Bad Gateway.
    raise HTTPException(
        status_code=502,
        detail=f"Failed to stop streaming: {response.text}"
    )
|
||
|
||
class RenameConversationRequest(BaseModel):
    """Request body for ``POST /conversations/{conversation_id}/name``."""

    # New conversation name, forwarded to Dify.
    name: str
|
||
|
||
@router.post('/conversations/{conversation_id}/name')
async def rename_conversation(conversation_id: str,
                              request: RenameConversationRequest,
                              current_user: dict = Depends(get_current_user),
                              session: Session = Depends(get_db)):
    """Rename one of the caller's Dify conversations.

    Args:
        conversation_id: Dify conversation to rename (path parameter).
        request: carries the new ``name``.
        current_user: authenticated user; its ``mail`` is sent as Dify ``user``.
        session: database session dependency (not used by this handler).

    Returns:
        Dify's JSON response body, passed through unchanged.

    Raises:
        HTTPException: Dify's status code when it answers non-200; 502 when
            Dify cannot be reached.
    """
    # Quote the client-supplied id so it cannot alter the upstream path.
    url = f'{api_url}/conversations/{quote(conversation_id, safe="")}/name'
    headers = {
        'Authorization': f'Bearer {agent_token}',
        'Content-Type': 'application/json'
    }
    payload = {
        "name": request.name,
        "user": current_user["mail"]
    }

    try:
        # Blocking HTTP call: run off the event loop, with a bounded wait.
        response = await run_in_threadpool(
            requests.post, url, headers=headers, json=payload, timeout=30
        )
    except requests.RequestException as exc:
        raise HTTPException(status_code=502,
                            detail=f"Failed to rename conversation: {exc}") from exc

    if response.status_code != 200:
        raise HTTPException(
            status_code=response.status_code,
            detail=f"Failed to rename conversation: {response.text}"
        )

    return response.json()
|
||
|
||
@router.delete('/conversations/{conversation_id}')
async def delete_conversation(conversation_id: str,
                              current_user: dict = Depends(get_current_user),
                              session: Session = Depends(get_db)):
    """Delete one of the caller's Dify conversations.

    Args:
        conversation_id: Dify conversation to delete (path parameter).
        current_user: authenticated user; its ``mail`` is sent as Dify ``user``.
        session: database session dependency (not used by this handler).

    Returns:
        ``{"message": "ok"}`` when Dify answers with any 2xx status.

    Raises:
        HTTPException: Dify's status code when it answers non-2xx; 502 when
            Dify cannot be reached.
    """
    # Quote the client-supplied id so it cannot alter the upstream path.
    url = f'{api_url}/conversations/{quote(conversation_id, safe="")}'
    headers = {
        'Authorization': f'Bearer {agent_token}',
        'Content-Type': 'application/json'
    }
    payload = {
        "user": current_user["mail"]
    }

    try:
        # Blocking HTTP call: run off the event loop, with a bounded wait.
        response = await run_in_threadpool(
            requests.delete, url, headers=headers, json=payload, timeout=30
        )
    except requests.RequestException as exc:
        raise HTTPException(status_code=502,
                            detail=f"Failed to delete conversation: {exc}") from exc

    # Accept any 2xx: some Dify versions answer a successful delete with
    # 204 No Content instead of 200.
    if not 200 <= response.status_code < 300:
        raise HTTPException(
            status_code=response.status_code,
            detail=f"Failed to delete conversation: {response.text}"
        )

    return {"message" : "ok"}
|
||
|
||
@router.get('/conversations')
async def get_conversations(current_user: dict = Depends(get_current_user),
                            session: Session = Depends(get_db),
                            limit: int = 5):
    """List the caller's most recent Dify conversations.

    Args:
        current_user: authenticated user; its ``mail`` is sent as Dify ``user``.
        session: database session dependency (not used by this handler).
        limit: number of conversations to request from Dify (query parameter,
            default 5 — the previously hard-coded value).

    Returns:
        A list of ``{"id", "name", "created_at"}`` dicts taken from the
        ``data`` array of Dify's response.

    Raises:
        HTTPException: Dify's status code when it answers non-200; 502 when
            Dify cannot be reached.
    """
    url = f'{api_url}/conversations'
    headers = {
        'Authorization': f'Bearer {agent_token}',
        'Content-Type': 'application/json'
    }
    # Let requests percent-encode the query string: an address such as
    # "a+b@example.com" f-stringed into the URL would have "+" decoded as a
    # space upstream and match the wrong (or no) user.
    params = {"user": current_user["mail"], "limit": limit}

    try:
        # Blocking HTTP call: run off the event loop, with a bounded wait.
        response = await run_in_threadpool(
            requests.get, url, headers=headers, params=params, timeout=30
        )
    except requests.RequestException as exc:
        raise HTTPException(
            status_code=502,
            detail=f"Failed to get conversations from Dify API: {exc}"
        ) from exc

    if response.status_code != 200:
        raise HTTPException(
            status_code=response.status_code,
            detail=f"Failed to get conversations from Dify API: {response.text}"
        )

    return [
        {
            "id": conversation["id"],
            "name": conversation["name"],
            "created_at": conversation["created_at"],
        }
        for conversation in response.json()["data"]
    ]
|
||
|
||
@router.get('/conversation_messages/{conversation_id}')
async def get_conversation_messages(conversation_id: str,
                                    current_user: dict = Depends(get_current_user),
                                    session: Session = Depends(get_db)):
    """Return the messages of one of the caller's Dify conversations.

    Args:
        conversation_id: Dify conversation to read (path parameter).
        current_user: authenticated user; its ``mail`` is sent as Dify ``user``.
        session: database session dependency (not used by this handler).

    Returns:
        A list of ``{"id", "conversation_id", "query", "answer", "created_at"}``
        dicts taken from the ``data`` array of Dify's response.

    Raises:
        HTTPException: Dify's status code when it answers non-200; 502 when
            Dify cannot be reached.
    """
    url = f'{api_url}/messages'
    headers = {
        'Authorization': f'Bearer {agent_token}',
        'Content-Type': 'application/json'
    }
    # Let requests percent-encode both values; f-stringing them into the URL
    # breaks on "+", "&", "#" etc. in the mail address or conversation id.
    params = {"conversation_id": conversation_id, "user": current_user["mail"]}

    try:
        # Blocking HTTP call: run off the event loop, with a bounded wait.
        response = await run_in_threadpool(
            requests.get, url, headers=headers, params=params, timeout=30
        )
    except requests.RequestException as exc:
        raise HTTPException(
            status_code=502,
            detail=f"Failed to get messages from Dify API: {exc}"
        ) from exc

    if response.status_code != 200:
        raise HTTPException(
            status_code=response.status_code,
            detail=f"Failed to get messages from Dify API: {response.text}"
        )

    return [
        {
            "id": message["id"],
            "conversation_id": message["conversation_id"],
            "query": message["query"],
            "answer": message["answer"],
            "created_at": message["created_at"],
        }
        for message in response.json()["data"]
    ]
|
||
|
||
@router.post("/chat-messages")
async def chat(request: ChatRequest,
               current_user: dict = Depends(get_current_user),
               session: Session = Depends(get_db)):
    """Relay a chat message to the Dify chat app and stream the reply back.

    Users without an unexpired subscription need at least one point. The
    question is saved before the upstream call, and one point is deducted from
    non-members only after Dify accepts the request (HTTP 200).

    Args:
        request: the message and optional conversation to continue.
        current_user: authenticated user (``id``, ``mail`` and ``points`` are read).
        session: database session for the subscription / question / user managers.

    Returns:
        A ``StreamingResponse`` relaying Dify's streamed body.

    Raises:
        HTTPException: 999 when a non-member has fewer than 1 point (custom
            status kept as-is); Dify's status code when it answers non-200;
            502 when Dify cannot be reached.
    """
    # Check whether the user has an active subscription.
    user_subscription_manager = UserSubscriptionManager(session)
    user_subscription = user_subscription_manager.get_subscription_by_user_id(current_user["id"])
    is_member = user_subscription and user_subscription["expire_time"] > datetime.now()
    user_points = current_user["points"]

    if not is_member and user_points < 1:
        raise HTTPException(status_code=999, detail="你的免费次数不足,你可以订阅会员。")

    payload = {
        "inputs" : {},
        "query": request.message,
        "response_mode": "streaming",
        "user": current_user["mail"],
    }
    if request.conversation_id:
        payload["conversation_id"] = request.conversation_id

    url = f'{api_url}/chat-messages'
    headers = {
        'Authorization': f'Bearer {agent_token}',
        'Content-Type': 'application/json'
    }

    manager = UserQuestionManager(session)
    manager.save_user_question(current_user["id"],"chat-messages", request.message)

    try:
        # Blocking HTTP call: run off the event loop. (connect, read) timeout:
        # the read part bounds the wait for each chunk, so a stalled upstream
        # can no longer hang the worker forever.
        response = await run_in_threadpool(
            requests.post, url, headers=headers, json=payload, stream=True,
            timeout=(10, 600),
        )
    except requests.RequestException as exc:
        raise HTTPException(
            status_code=502,
            detail=f"Failed to get response from Dify API: {exc}"
        ) from exc

    # If the upstream call failed, return the error (and release the connection).
    if response.status_code != 200:
        detail = f"Failed to get response from Dify API: {response.text}"
        response.close()
        raise HTTPException(status_code=response.status_code, detail=detail)

    # Deduct the user's point.
    if not is_member:
        try:
            UserManager(session).consume_user_points(current_user["id"], 1)
        except Exception:
            # Don't leak the open upstream stream if the deduction fails.
            response.close()
            raise

    # Relay the upstream stream chunk by chunk.
    def stream_response():
        # Close the upstream connection however streaming ends (exhausted,
        # client disconnected, or error) so it is not leaked.
        try:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    yield chunk
        finally:
            response.close()

    return StreamingResponse(stream_response(), media_type="text/plain")
|
||
|
||
@router.post("/analysis")
async def analysis(request: AnalysisRequest,
                   current_user: dict = Depends(get_current_user),
                   session: Session = Depends(get_db)):
    """Run a Dify analysis workflow and stream its output back to the client.

    ``request.type`` selects the workflow app and its inputs:

    * ``'crypto'``  -- ``request.symbol`` is looked up with
      ``TokenManager.search_token``; the first match's ``symbol`` is sent
      together with ``request.timeframe``.
    * ``'astock'``  -- ``request.stock_code`` is sent as ``stock_code``.
    * ``'usstock'`` -- ``request.stock_code`` is sent as ``stock`` together
      with a window from 180 days ago to today (``YYYY-MM-DD``).

    The question is recorded via ``UserQuestionManager.save_user_question``
    before the workflow is called. This handler does not check or consume points.

    Args:
        request: workflow selector plus symbol / stock code / timeframe.
        current_user: authenticated user (``id`` and ``mail`` are read).
        session: database session for the token and question managers.

    Returns:
        A ``StreamingResponse`` relaying the workflow's streamed body.

    Raises:
        HTTPException: 400 for an unsupported type, a missing symbol or stock
            code, or an unknown crypto symbol; Dify's status code when the
            workflow call is answered non-200; 502 when Dify cannot be reached.
    """
    # NOTE(review): the workflow app keys below are credentials committed to
    # source control — consider loading them from configuration.
    if request.type == 'crypto':
        # search_token's behaviour for an empty/None query is not visible
        # here; reject it up front rather than risk an arbitrary match or crash.
        if not request.symbol:
            raise HTTPException(status_code=400, detail="请输入币种。")
        # Make sure the symbol exists.
        tokens = TokenManager(session).search_token(request.symbol)
        if not tokens:
            raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")

        symbol = tokens[0]["symbol"]
        token = 'app-BbaqIAMPi0ktgaV9IizMlc2N'
        inputs = {
            "symbol" : symbol,
            "timeframe" : request.timeframe
        }
        question_subject = symbol
        question = "请分析以下加密货币:" + symbol + ",并给出分析报告。"

    elif request.type == 'astock':
        stock_code = request.stock_code
        # Without this check a missing code crashed below on `str + None`.
        if not stock_code:
            raise HTTPException(status_code=400, detail="请输入股票代码。")
        token = 'app-nWuCOa0YfQVtAosTY3Jr5vFV'
        inputs = {
            "stock_code": stock_code
        }
        question_subject = stock_code
        question = "请分析以下A股股票:" + stock_code + ",并给出分析报告。"

    elif request.type == 'usstock':
        stock_code = request.stock_code
        # Without this check a missing code crashed below on `str + None`.
        if not stock_code:
            raise HTTPException(status_code=400, detail="请输入股票代码。")
        token = 'app-gFjHuqwMEFzu7oNAMWAlZXBG'
        # Read the date once so both ends of the window agree across midnight.
        today = date.today()
        inputs = {
            "stock": stock_code,
            "start_date": (today - timedelta(days=180)).strftime("%Y-%m-%d"),
            "end_date": today.strftime("%Y-%m-%d")
        }
        question_subject = stock_code
        question = "请分析以下美股股票:" + stock_code + ",并给出分析报告。"

    else:
        raise HTTPException(status_code=400, detail="不支持的类型")

    payload = {
        "inputs" : inputs,
        "response_mode": "streaming",
        "user": current_user["mail"]
    }

    UserQuestionManager(session).save_user_question(current_user["id"], question_subject, question)

    url = f'{api_url}/workflows/run'
    headers = {
        'Authorization': f'Bearer {token}',
        'Content-Type': 'application/json'
    }

    try:
        # Blocking HTTP call: run off the event loop. (connect, read) timeout:
        # the read part bounds the wait for each chunk of the stream.
        response = await run_in_threadpool(
            requests.post, url, headers=headers, json=payload, stream=True,
            timeout=(10, 600),
        )
    except requests.RequestException as exc:
        raise HTTPException(
            status_code=502,
            detail=f"Failed to get response from Dify API: {exc}"
        ) from exc

    # If the upstream call failed, return the error (and release the connection).
    if response.status_code != 200:
        detail = f"Failed to get response from Dify API: {response.text}"
        response.close()
        raise HTTPException(status_code=response.status_code, detail=detail)

    # Relay the upstream stream chunk by chunk.
    def stream_response():
        # Close the upstream connection however streaming ends (exhausted,
        # client disconnected, or error) so it is not leaked.
        try:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    yield chunk
        finally:
            response.close()

    return StreamingResponse(stream_response(), media_type="text/plain")