up
This commit is contained in:
parent
ea94081617
commit
59860c6191
@ -360,15 +360,6 @@ class CryptoAgent:
|
||||
if self.discord_bot:
|
||||
print(f"发送交易建议到Discord...")
|
||||
self.discord_bot.send_message(content=message)
|
||||
|
||||
# 保存交易建议到数据库
|
||||
try:
|
||||
saved = self.db_manager.agent_feed_manager.save_agent_feed(
|
||||
agent_name="Crypto Agent",
|
||||
content=message
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"保存交易建议到数据库时出错: {e}")
|
||||
|
||||
# 导出 DeepSeek API token 使用情况
|
||||
self._export_token_usage()
|
||||
|
||||
@ -6,9 +6,7 @@
|
||||
from cryptoai.models.base import Base
|
||||
from cryptoai.models.token import Token, TokenManager
|
||||
from cryptoai.models.analysis_result import AnalysisResult, AnalysisResultManager
|
||||
from cryptoai.models.agent_feed import AgentFeed, AgentFeedManager
|
||||
from cryptoai.models.user import User, UserManager
|
||||
from cryptoai.models.user_question import UserQuestion, UserQuestionManager
|
||||
from cryptoai.models.agent import Agent, AgentManager
|
||||
from cryptoai.models.astock import AStock, AStockManager
|
||||
from cryptoai.models.analysis_history import AnalysisHistory, AnalysisHistoryManager
|
||||
Binary file not shown.
@ -1,268 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, Index
|
||||
from sqlalchemy.dialects.mysql import JSON
|
||||
|
||||
from cryptoai.models.base import Base, logger
|
||||
|
||||
# 定义Agent数据模型
|
||||
class Agent(Base):
|
||||
"""Agent数据表模型"""
|
||||
__tablename__ = 'agents'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
name = Column(String(100), nullable=False, unique=True, comment='Agent名称')
|
||||
hello_prompt = Column(Text, nullable=True, comment='欢迎提示语')
|
||||
description = Column(Text, nullable=True, comment='Agent描述')
|
||||
dify_token = Column(String(255), nullable=True, comment='Dify API令牌')
|
||||
inputs = Column(JSON, nullable=True, comment='输入参数JSON')
|
||||
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||||
update_time = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间')
|
||||
|
||||
# 索引和表属性
|
||||
__table_args__ = (
|
||||
Index('idx_name', 'name'),
|
||||
Index('idx_create_time', 'create_time'),
|
||||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||||
)
|
||||
|
||||
class AgentManager:
|
||||
"""Agent管理类"""
|
||||
|
||||
def __init__(self, db_session):
|
||||
self.session = db_session
|
||||
|
||||
def create_agent(self, name: str, description: Optional[str] = None,
|
||||
hello_prompt: Optional[str] = None, dify_token: Optional[str] = None,
|
||||
inputs: Optional[Dict[str, Any]] = None) -> Optional[int]:
|
||||
"""
|
||||
创建新的Agent
|
||||
|
||||
Args:
|
||||
name: Agent名称
|
||||
description: Agent描述
|
||||
hello_prompt: 欢迎提示语
|
||||
dify_token: Dify API令牌
|
||||
inputs: 输入参数JSON
|
||||
|
||||
Returns:
|
||||
新Agent的ID,如果创建失败则返回None
|
||||
"""
|
||||
try:
|
||||
# 检查名称是否已存在
|
||||
existing_agent = self.session.query(Agent).filter(Agent.name == name).first()
|
||||
if existing_agent:
|
||||
logger.warning(f"Agent名称 {name} 已存在")
|
||||
return None
|
||||
|
||||
# 创建新Agent
|
||||
new_agent = Agent(
|
||||
name=name,
|
||||
description=description,
|
||||
hello_prompt=hello_prompt,
|
||||
dify_token=dify_token,
|
||||
inputs=inputs,
|
||||
create_time=datetime.now(),
|
||||
update_time=datetime.now()
|
||||
)
|
||||
|
||||
# 添加并提交
|
||||
self.session.add(new_agent)
|
||||
self.session.commit()
|
||||
|
||||
logger.info(f"成功创建Agent: {name}")
|
||||
return new_agent.id
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
logger.error(f"创建Agent失败: {e}")
|
||||
return None
|
||||
|
||||
def update_agent(self, agent_id: int, name: Optional[str] = None,
|
||||
description: Optional[str] = None, hello_prompt: Optional[str] = None,
|
||||
dify_token: Optional[str] = None, inputs: Optional[Dict[str, Any]] = None) -> bool:
|
||||
"""
|
||||
更新Agent信息
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
name: Agent名称
|
||||
description: Agent描述
|
||||
hello_prompt: 欢迎提示语
|
||||
dify_token: Dify API令牌
|
||||
inputs: 输入参数JSON
|
||||
|
||||
Returns:
|
||||
更新是否成功
|
||||
"""
|
||||
try:
|
||||
# 查询Agent
|
||||
agent = self.session.query(Agent).filter(Agent.id == agent_id).first()
|
||||
if not agent:
|
||||
logger.warning(f"Agent ID {agent_id} 不存在")
|
||||
return False
|
||||
|
||||
# 更新字段
|
||||
if name is not None:
|
||||
# 检查名称是否与其他Agent冲突
|
||||
if name != agent.name:
|
||||
existing = self.session.query(Agent).filter(Agent.name == name).first()
|
||||
if existing:
|
||||
logger.warning(f"Agent名称 {name} 已被其他Agent使用")
|
||||
return False
|
||||
agent.name = name
|
||||
|
||||
if description is not None:
|
||||
agent.description = description
|
||||
|
||||
if hello_prompt is not None:
|
||||
agent.hello_prompt = hello_prompt
|
||||
|
||||
if dify_token is not None:
|
||||
agent.dify_token = dify_token
|
||||
|
||||
if inputs is not None:
|
||||
agent.inputs = inputs
|
||||
|
||||
agent.update_time = datetime.now()
|
||||
|
||||
# 提交更改
|
||||
self.session.commit()
|
||||
|
||||
logger.info(f"成功更新Agent ID: {agent_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
logger.error(f"更新Agent失败: {e}")
|
||||
return False
|
||||
|
||||
def get_agent_by_id(self, agent_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
通过ID获取Agent信息
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
Agent信息,如果不存在则返回None
|
||||
"""
|
||||
try:
|
||||
# 查询Agent
|
||||
agent = self.session.query(Agent).filter(Agent.id == agent_id).first()
|
||||
|
||||
if agent:
|
||||
# 转换为字典
|
||||
return {
|
||||
'id': agent.id,
|
||||
'name': agent.name,
|
||||
'description': agent.description,
|
||||
'hello_prompt': agent.hello_prompt,
|
||||
'dify_token': agent.dify_token,
|
||||
'inputs': agent.inputs,
|
||||
'create_time': agent.create_time,
|
||||
'update_time': agent.update_time
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取Agent信息失败: {e}")
|
||||
return None
|
||||
|
||||
def get_agent_by_name(self, name: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
通过名称获取Agent信息
|
||||
|
||||
Args:
|
||||
name: Agent名称
|
||||
|
||||
Returns:
|
||||
Agent信息,如果不存在则返回None
|
||||
"""
|
||||
try:
|
||||
# 查询Agent
|
||||
agent = self.session.query(Agent).filter(Agent.name == name).first()
|
||||
|
||||
if agent:
|
||||
# 转换为字典
|
||||
return {
|
||||
'id': agent.id,
|
||||
'name': agent.name,
|
||||
'description': agent.description,
|
||||
'hello_prompt': agent.hello_prompt,
|
||||
'dify_token': agent.dify_token,
|
||||
'inputs': agent.inputs,
|
||||
'create_time': agent.create_time,
|
||||
'update_time': agent.update_time
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取Agent信息失败: {e}")
|
||||
return None
|
||||
|
||||
def list_agents(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取Agent列表
|
||||
|
||||
Args:
|
||||
limit: 返回的最大数量
|
||||
skip: 跳过的数量
|
||||
|
||||
Returns:
|
||||
Agent列表
|
||||
"""
|
||||
try:
|
||||
# 查询Agent列表
|
||||
agents = self.session.query(Agent).order_by(Agent.create_time.asc()).offset(skip).limit(limit).all()
|
||||
|
||||
# 转换为字典列表
|
||||
return [{
|
||||
'id': agent.id,
|
||||
'name': agent.name,
|
||||
'description': agent.description,
|
||||
'hello_prompt': agent.hello_prompt,
|
||||
'dify_token': agent.dify_token,
|
||||
'inputs': agent.inputs,
|
||||
'create_time': agent.create_time,
|
||||
'update_time': agent.update_time
|
||||
} for agent in agents]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取Agent列表失败: {e}")
|
||||
return []
|
||||
|
||||
def delete_agent(self, agent_id: int) -> bool:
|
||||
"""
|
||||
删除Agent
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
|
||||
Returns:
|
||||
删除是否成功
|
||||
"""
|
||||
try:
|
||||
# 查询Agent
|
||||
agent = self.session.query(Agent).filter(Agent.id == agent_id).first()
|
||||
if not agent:
|
||||
logger.warning(f"Agent ID {agent_id} 不存在")
|
||||
return False
|
||||
|
||||
# 删除Agent
|
||||
self.session.delete(agent)
|
||||
self.session.commit()
|
||||
|
||||
logger.info(f"成功删除Agent ID: {agent_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
logger.error(f"删除Agent失败: {e}")
|
||||
return False
|
||||
@ -1,105 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, String, Text, DateTime, Index
|
||||
|
||||
from cryptoai.models.base import Base, logger
|
||||
|
||||
# 定义AI Agent信息流模型
|
||||
class AgentFeed(Base):
|
||||
"""AI Agent信息流表模型"""
|
||||
__tablename__ = 'agent_feeds'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
agent_name = Column(String(50), nullable=False, comment='AI Agent名称')
|
||||
avatar_url = Column(String(255), nullable=True, comment='头像URL')
|
||||
content = Column(Text, nullable=False, comment='内容')
|
||||
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||||
|
||||
# 索引和表属性
|
||||
__table_args__ = (
|
||||
Index('idx_agent_name', 'agent_name'),
|
||||
Index('idx_create_time', 'create_time'),
|
||||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||||
)
|
||||
|
||||
class AgentFeedManager:
|
||||
"""Agent信息流管理类"""
|
||||
|
||||
def __init__(self, db_session):
|
||||
self.session = db_session
|
||||
|
||||
def save_agent_feed(self, agent_name: str, content: str, avatar_url: Optional[str] = None) -> bool:
|
||||
"""
|
||||
保存AI Agent信息流到数据库
|
||||
|
||||
Args:
|
||||
agent_name: AI Agent名称
|
||||
content: 内容
|
||||
avatar_url: 头像URL,可选
|
||||
|
||||
Returns:
|
||||
保存是否成功
|
||||
"""
|
||||
try:
|
||||
# 创建新记录
|
||||
new_feed = AgentFeed(
|
||||
agent_name=agent_name,
|
||||
content=content,
|
||||
avatar_url=avatar_url
|
||||
)
|
||||
|
||||
# 添加并提交
|
||||
self.session.add(new_feed)
|
||||
self.session.commit()
|
||||
|
||||
logger.info(f"成功保存 {agent_name} 的信息流")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.session.rollback()
|
||||
logger.error(f"保存信息流失败: {e}")
|
||||
return False
|
||||
|
||||
def get_agent_feeds(self, agent_name: Optional[str] = None, limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取AI Agent信息流
|
||||
|
||||
Args:
|
||||
agent_name: 可选,指定获取特定Agent的信息流
|
||||
limit: 返回的最大记录数,默认20条
|
||||
skip: 跳过的记录数,默认0条
|
||||
|
||||
Returns:
|
||||
信息流列表,如果查询失败则返回空列表
|
||||
"""
|
||||
try:
|
||||
# 构建查询
|
||||
query = self.session.query(AgentFeed)
|
||||
|
||||
# 如果指定了agent_name,则筛选
|
||||
if agent_name:
|
||||
query = query.filter(AgentFeed.agent_name == agent_name)
|
||||
|
||||
# 按创建时间降序排序并限制数量
|
||||
results = query.order_by(AgentFeed.create_time.desc()).offset(skip).limit(limit).all()
|
||||
|
||||
# 转换为字典列表
|
||||
feeds = []
|
||||
for result in results:
|
||||
feeds.append({
|
||||
'id': result.id,
|
||||
'agent_name': result.agent_name,
|
||||
'avatar_url': result.avatar_url,
|
||||
'content': result.content,
|
||||
'create_time': result.create_time
|
||||
})
|
||||
|
||||
return feeds
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取信息流失败: {e}")
|
||||
return []
|
||||
@ -81,89 +81,4 @@ async def get_stock_data(stock_code: str, start_date: Optional[str] = None, end_
|
||||
logger.error(f"获取股票数据失败: {e}")
|
||||
return {}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/stock/data/all", summary="获取所有股票数据")
|
||||
async def get_stock_data_all(stock_code: str):
|
||||
|
||||
result = {}
|
||||
|
||||
|
||||
try:
|
||||
api = AStockAPI()
|
||||
|
||||
#获取股本信息
|
||||
stock_shares = api.get_stock_shares(stock_code)
|
||||
result["stock_shares"] = json.loads(stock_shares.to_json(orient="records"))
|
||||
|
||||
# 获取概念板块
|
||||
concept_east = api.get_concept_east(stock_code)
|
||||
result["concept_east"] = json.loads(concept_east.to_json(orient="records"))
|
||||
|
||||
# 获取板块
|
||||
plate_east = api.get_plate_east(stock_code)
|
||||
result["plate_east"] = json.loads(plate_east.to_json(orient="records"))
|
||||
|
||||
# 获取市场数据
|
||||
market_data = api.get_market_data(stock_code)
|
||||
result["market_data"] = json.loads(market_data.to_json(orient="records"))
|
||||
|
||||
# 获取分钟线数据
|
||||
min_data = api.get_market_min_data(stock_code)
|
||||
result["min_data"] = json.loads(min_data.to_json(orient="records"))
|
||||
|
||||
# 获取资金流向数据
|
||||
flow_data = api.get_capital_flow(stock_code)
|
||||
result["flow_data"] = json.loads(flow_data.to_json(orient="records"))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取股票数据失败: {e}")
|
||||
return {}
|
||||
|
||||
return result
|
||||
|
||||
@router.post('/{stock_code}/analysis', summary="获取股票分析数据")
|
||||
async def get_stock_analysis(stock_code: str, current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
|
||||
|
||||
url = 'https://mate.aimateplus.com/v1/workflows/run'
|
||||
token = 'app-nWuCOa0YfQVtAosTY3Jr5vFV'
|
||||
headers = {
|
||||
'Authorization': f'Bearer {token}',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
data = {
|
||||
"inputs" : {
|
||||
"stock_code" : stock_code
|
||||
},
|
||||
"response_mode": "streaming",
|
||||
"user": current_user["mail"]
|
||||
}
|
||||
|
||||
# 保存用户提问
|
||||
get_db_manager().user_question_manager.save_user_question(current_user["id"], stock_code, "请分析以下股票:" + stock_code + ",并给出分析报告。")
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
|
||||
# 如果响应不成功,返回错误
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=f"Failed to get response from Dify API: {response.text}"
|
||||
)
|
||||
|
||||
# 获取response的stream
|
||||
def stream_response():
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(stream_response(), media_type="text/plain")
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
return result
|
||||
@ -1,126 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
API路由模块,为前端提供REST API接口
|
||||
"""
|
||||
|
||||
import os
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Body, Query, Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pydantic import BaseModel
|
||||
import json
|
||||
import logging
|
||||
|
||||
from cryptoai.api.deepseek_api import DeepSeekAPI
|
||||
from cryptoai.utils.config_loader import ConfigLoader
|
||||
from fastapi.responses import StreamingResponse
|
||||
from cryptoai.routes.user import get_current_user
|
||||
import requests
|
||||
from datetime import datetime
|
||||
from cryptoai.utils.db_manager import get_db_manager
|
||||
|
||||
# 创建路由
|
||||
router = APIRouter()
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
user_prompt: str
|
||||
agent_id: int
|
||||
conversation_id: Optional[str] = None
|
||||
|
||||
class AgentCreate(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
hello_prompt: Optional[str] = None
|
||||
dify_token: Optional[str] = None
|
||||
inputs: Optional[Dict[str, Any]] = None
|
||||
|
||||
class AgentUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
hello_prompt: Optional[str] = None
|
||||
dify_token: Optional[str] = None
|
||||
inputs: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@router.post("/create")
|
||||
async def create_agent(agent: AgentCreate):
|
||||
"""
|
||||
创建新的AI Agent
|
||||
"""
|
||||
|
||||
agent = get_db_manager().agent_manager.create_agent(agent.name, agent.description, agent.hello_prompt, agent.dify_token, agent.inputs)
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
@router.get("/list", response_model=List[Dict[str, Any]])
|
||||
async def get_agents(
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=1000)
|
||||
):
|
||||
"""
|
||||
获取所有代理
|
||||
"""
|
||||
|
||||
# 从数据库获取Agent列表
|
||||
agents = get_db_manager().agent_manager.list_agents(limit=limit, skip=skip)
|
||||
return agents
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
async def chat(request: ChatRequest, current_user: Dict[str, Any] = Depends(get_current_user)):
|
||||
"""
|
||||
聊天接口
|
||||
"""
|
||||
# 尝试从数据库获取Agent
|
||||
try:
|
||||
agent_id = int(request.agent_id)
|
||||
agent = get_db_manager().agent_manager.get_agent_by_id(agent_id)
|
||||
|
||||
if not agent:
|
||||
raise HTTPException(status_code=400, detail="Invalid agent ID")
|
||||
|
||||
token = agent.get("dify_token")
|
||||
inputs = agent.get("inputs") or {}
|
||||
inputs["current_date"] = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid agent ID format")
|
||||
|
||||
url = "https://mate.aimateplus.com/v1/chat-messages"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data = {
|
||||
"inputs": inputs,
|
||||
"query": request.user_prompt,
|
||||
"response_mode": "streaming",
|
||||
"user": current_user["mail"]
|
||||
}
|
||||
|
||||
if request.conversation_id:
|
||||
data["conversation_id"] = request.conversation_id
|
||||
|
||||
logging.info(f"Chat request data: {data}")
|
||||
|
||||
# 保存用户提问
|
||||
get_db_manager().user_question_manager.save_user_question(current_user["id"], request.agent_id, request.user_prompt)
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
|
||||
# 如果响应不成功,返回错误
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=f"Failed to get response from Dify API: {response.text}"
|
||||
)
|
||||
|
||||
# 获取response的stream
|
||||
def stream_response():
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(stream_response(), media_type="text/plain")
|
||||
@ -4,6 +4,9 @@ from cryptoai.utils.db_manager import get_db_manager
|
||||
from fastapi import Depends
|
||||
from pydantic import BaseModel
|
||||
from cryptoai.routes.user import get_current_user
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
import requests
|
||||
|
||||
|
||||
class AnalysisHistoryRequest(BaseModel):
|
||||
@ -27,4 +30,73 @@ async def get_analysis_histories(current_user: dict = Depends(get_current_user),
|
||||
limit: int = 10,
|
||||
offset: int = 0):
|
||||
history = get_db_manager().analysis_history_manager.get_user_analysis_history(current_user["id"], limit=limit, offset=offset)
|
||||
return history
|
||||
return history
|
||||
|
||||
|
||||
class AnalysisRequest(BaseModel):
|
||||
symbol: Optional[str] = None
|
||||
timeframe: Optional[str] = None
|
||||
stock_code: Optional[str] = None
|
||||
type: str
|
||||
|
||||
@router.post("/analysis")
|
||||
async def analysis(request: AnalysisRequest,
|
||||
current_user: dict = Depends(get_current_user)):
|
||||
|
||||
if request.type == 'crypto':
|
||||
# 检查symbol是否存在
|
||||
tokens = get_db_manager().token_manager.search_token(request.symbol)
|
||||
if not tokens or len(tokens) == 0:
|
||||
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
|
||||
|
||||
symbol = tokens[0]["symbol"]
|
||||
token = 'app-BbaqIAMPi0ktgaV9IizMlc2N'
|
||||
|
||||
payload = {
|
||||
"inputs" : {
|
||||
"symbol" : symbol,
|
||||
"timeframe" : request.timeframe
|
||||
},
|
||||
"response_mode": "streaming",
|
||||
"user": current_user["mail"]
|
||||
}
|
||||
|
||||
get_db_manager().user_question_manager.save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。")
|
||||
|
||||
|
||||
else:
|
||||
stock_code = request.stock_code
|
||||
token = 'app-nWuCOa0YfQVtAosTY3Jr5vFV'
|
||||
|
||||
payload = {
|
||||
"inputs" : {
|
||||
"stock_code": stock_code
|
||||
},
|
||||
"response_mode": "streaming",
|
||||
"user": current_user["mail"]
|
||||
}
|
||||
|
||||
get_db_manager().user_question_manager.save_user_question(current_user["id"], stock_code, "请分析以下股票:" + stock_code + ",并给出分析报告。")
|
||||
|
||||
url = 'https://mate.aimateplus.com/v1/workflows/run'
|
||||
headers = {
|
||||
'Authorization': f'Bearer {token}',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=payload, stream=True)
|
||||
|
||||
# 如果响应不成功,返回错误
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=f"Failed to get response from Dify API: {response.text}"
|
||||
)
|
||||
|
||||
# 获取response的stream
|
||||
def stream_response():
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(stream_response(), media_type="text/plain")
|
||||
@ -51,110 +51,4 @@ async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit:
|
||||
else:
|
||||
result[timeframe] = binance_api.get_historical_klines(symbol=symbol, interval=timeframe, limit=limit).to_dict(orient="records")
|
||||
|
||||
return result
|
||||
|
||||
@router.post("/analysis_v2")
|
||||
async def analysis_crypto_v2(request: CryptoAnalysisRequest,
|
||||
current_user: dict = Depends(get_current_user)):
|
||||
# 检查symbol是否存在
|
||||
tokens = get_db_manager().token_manager.search_token(request.symbol)
|
||||
if not tokens or len(tokens) == 0:
|
||||
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
|
||||
|
||||
symbol = tokens[0]["symbol"]
|
||||
|
||||
url = 'https://mate.aimateplus.com/v1/workflows/run'
|
||||
token = 'app-BbaqIAMPi0ktgaV9IizMlc2N'
|
||||
headers = {
|
||||
'Authorization': f'Bearer {token}',
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
data = {
|
||||
"inputs" : {
|
||||
"symbol" : symbol,
|
||||
"timeframe" : request.timeframe
|
||||
},
|
||||
"response_mode": "streaming",
|
||||
"user": current_user["mail"]
|
||||
}
|
||||
|
||||
# 保存用户提问
|
||||
get_db_manager().user_question_manager.save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。")
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
|
||||
# 如果响应不成功,返回错误
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=f"Failed to get response from Dify API: {response.text}"
|
||||
)
|
||||
|
||||
# 获取response的stream
|
||||
def stream_response():
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(stream_response(), media_type="text/plain")
|
||||
|
||||
@router.post("/analysis")
|
||||
async def analysis_crypto(request: CryptoAnalysisRequest,
|
||||
current_user: dict = Depends(get_current_user)):
|
||||
|
||||
# 尝试从数据库获取Agent
|
||||
try:
|
||||
agent_id = 1
|
||||
if request.timeframe:
|
||||
user_prompt = f"请分析以下加密货币:{request.symbol},并给出 {request.timeframe} 级别的分析报告。"
|
||||
else:
|
||||
user_prompt = f"请分析以下加密货币:{request.symbol},并给出分析报告。"
|
||||
|
||||
agent = get_db_manager().agent_manager.get_agent_by_id(agent_id)
|
||||
|
||||
if not agent:
|
||||
raise HTTPException(status_code=400, detail="Invalid agent ID")
|
||||
|
||||
token = agent.get("dify_token")
|
||||
inputs = agent.get("inputs") or {}
|
||||
inputs["current_date"] = datetime.now().strftime("%Y-%m-%d")
|
||||
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail="Invalid agent ID format")
|
||||
|
||||
url = "https://mate.aimateplus.com/v1/chat-messages"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data = {
|
||||
"inputs": inputs,
|
||||
"query": user_prompt,
|
||||
"response_mode": "streaming",
|
||||
"user": current_user["mail"]
|
||||
}
|
||||
|
||||
logging.info(f"Chat request data: {data}")
|
||||
|
||||
# 保存用户提问
|
||||
get_db_manager().user_question_manager.save_user_question(current_user["id"], agent_id, user_prompt)
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
|
||||
# 如果响应不成功,返回错误
|
||||
if response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=response.status_code,
|
||||
detail=f"Failed to get response from Dify API: {response.text}"
|
||||
)
|
||||
|
||||
# 获取response的stream
|
||||
def stream_response():
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
return StreamingResponse(stream_response(), media_type="text/plain")
|
||||
|
||||
|
||||
return result
|
||||
@ -15,11 +15,8 @@ from fastapi.responses import JSONResponse
|
||||
import time
|
||||
from typing import Dict, Any
|
||||
|
||||
from cryptoai.routes.agent import router as agent_router
|
||||
from cryptoai.routes.feed import router as feed_router
|
||||
from cryptoai.routes.user import router as user_router
|
||||
from cryptoai.routes.adata import router as adata_router
|
||||
from cryptoai.routes.question import router as question_router
|
||||
from cryptoai.routes.crypto import router as crypto_router
|
||||
from cryptoai.routes.platform import router as platform_router
|
||||
from cryptoai.routes.analysis import router as analysis_router
|
||||
@ -52,11 +49,8 @@ app.add_middleware(
|
||||
)
|
||||
|
||||
# 添加API路由
|
||||
app.include_router(agent_router, prefix="/agent")
|
||||
app.include_router(platform_router, prefix="/platform", tags=["平台信息"])
|
||||
app.include_router(feed_router, prefix="/feed", tags=["AI Agent信息流"])
|
||||
app.include_router(user_router, prefix="/user", tags=["用户管理"])
|
||||
app.include_router(question_router, prefix="/question", tags=["用户提问"])
|
||||
app.include_router(adata_router, prefix="/adata", tags=["A股数据"])
|
||||
app.include_router(crypto_router, prefix="/crypto", tags=["加密货币数据"])
|
||||
app.include_router(analysis_router, prefix="/analysis", tags=["分析历史"])
|
||||
|
||||
@ -1,110 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
AI Agent信息流API路由模块,提供信息流的增删改查功能
|
||||
"""
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException, status, Query, Depends
|
||||
from pydantic import BaseModel
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from cryptoai.routes.user import get_current_user
|
||||
from cryptoai.utils.db_manager import get_db_manager
|
||||
|
||||
# 配置日志
|
||||
logger = logging.getLogger("feed_router")
|
||||
|
||||
# 创建路由
|
||||
router = APIRouter()
|
||||
|
||||
# 请求模型
|
||||
class AgentFeedCreate(BaseModel):
|
||||
"""创建信息流请求模型"""
|
||||
agent_name: str
|
||||
content: str
|
||||
avatar_url: Optional[str] = None
|
||||
|
||||
# 响应模型
|
||||
class AgentFeedResponse(BaseModel):
|
||||
"""信息流响应模型"""
|
||||
id: int
|
||||
agent_name: str
|
||||
avatar_url: Optional[str] = None
|
||||
content: str
|
||||
create_time: datetime
|
||||
|
||||
@router.post("", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
|
||||
async def create_feed(feed: AgentFeedCreate) -> Dict[str, Any]:
|
||||
"""
|
||||
创建新的AI Agent信息流
|
||||
|
||||
Args:
|
||||
feed: 信息流创建请求
|
||||
|
||||
Returns:
|
||||
创建成功的状态信息
|
||||
"""
|
||||
try:
|
||||
# 获取数据库管理器
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 保存信息流
|
||||
success = db_manager.agent_feed_manager.save_agent_feed(
|
||||
agent_name=feed.agent_name,
|
||||
content=feed.content,
|
||||
avatar_url=feed.avatar_url
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="保存信息流失败"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "信息流创建成功"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建信息流失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"创建信息流失败: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("", response_model=List[AgentFeedResponse])
|
||||
async def get_feeds(
|
||||
agent_name: Optional[str] = Query(None, description="AI Agent名称,可选"),
|
||||
limit: int = Query(20, description="返回的最大记录数,默认20条"),
|
||||
skip: int = Query(0, description="跳过的记录数,默认0条"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> List[AgentFeedResponse]:
|
||||
"""
|
||||
获取AI Agent信息流列表
|
||||
|
||||
Args:
|
||||
agent_name: 可选,指定获取特定Agent的信息流
|
||||
limit: 返回的最大记录数,默认20条
|
||||
|
||||
Returns:
|
||||
信息流列表
|
||||
"""
|
||||
try:
|
||||
# 获取数据库管理器
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 获取信息流
|
||||
feeds = db_manager.agent_feed_manager.get_agent_feeds(agent_name=agent_name, limit=limit, skip=skip)
|
||||
|
||||
return feeds
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取信息流失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取信息流失败: {str(e)}"
|
||||
)
|
||||
@ -1,166 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
用户提问API路由模块,提供用户提问数据的增删改查功能
|
||||
"""
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException, status, Depends, Query
|
||||
from pydantic import BaseModel
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from cryptoai.utils.db_manager import get_db_manager
|
||||
from cryptoai.routes.user import get_current_user
|
||||
|
||||
# 配置日志
|
||||
logger = logging.getLogger("question_router")
|
||||
|
||||
# 创建路由
|
||||
router = APIRouter()
|
||||
|
||||
# 请求模型
|
||||
class QuestionCreate(BaseModel):
|
||||
"""创建提问请求模型"""
|
||||
agent_id: str
|
||||
question: str
|
||||
|
||||
# 响应模型
|
||||
class QuestionResponse(BaseModel):
|
||||
"""提问响应模型"""
|
||||
id: int
|
||||
user_id: int
|
||||
agent_id: str
|
||||
question: str
|
||||
create_time: datetime
|
||||
|
||||
@router.post("/", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
|
||||
async def create_question(
|
||||
question: QuestionCreate,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
创建新的用户提问记录
|
||||
|
||||
Args:
|
||||
question: 提问创建请求
|
||||
current_user: 当前用户信息,由依赖项提供
|
||||
|
||||
Returns:
|
||||
创建成功的状态信息
|
||||
"""
|
||||
try:
|
||||
# 获取数据库管理器
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 保存提问
|
||||
success = db_manager.user_question_manager.save_user_question(
|
||||
user_id=current_user["id"],
|
||||
agent_id=question.agent_id,
|
||||
question=question.question
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="保存提问失败"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "提问记录创建成功"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建提问记录失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"创建提问记录失败: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/", response_model=List[QuestionResponse])
|
||||
async def get_questions(
|
||||
agent_id: Optional[str] = Query(None, description="AI Agent ID,可选"),
|
||||
limit: int = Query(20, description="返回的最大记录数,默认20条"),
|
||||
skip: int = Query(0, description="跳过的记录数,默认0条"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> List[QuestionResponse]:
|
||||
"""
|
||||
获取用户提问记录列表
|
||||
|
||||
Args:
|
||||
agent_id: 可选,指定获取特定Agent的提问
|
||||
limit: 返回的最大记录数,默认20条
|
||||
skip: 跳过的记录数,默认0条
|
||||
current_user: 当前用户信息,由依赖项提供
|
||||
|
||||
Returns:
|
||||
提问记录列表
|
||||
"""
|
||||
try:
|
||||
# 获取数据库管理器
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 获取提问记录
|
||||
questions = db_manager.user_question_manager.get_user_questions(
|
||||
user_id=current_user["id"],
|
||||
agent_id=agent_id,
|
||||
limit=limit,
|
||||
skip=skip
|
||||
)
|
||||
|
||||
return questions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取提问记录失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取提问记录失败: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/{question_id}", response_model=QuestionResponse)
|
||||
async def get_question(
|
||||
question_id: int,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> QuestionResponse:
|
||||
"""
|
||||
获取特定提问记录
|
||||
|
||||
Args:
|
||||
question_id: 提问ID
|
||||
current_user: 当前用户信息,由依赖项提供
|
||||
|
||||
Returns:
|
||||
提问记录
|
||||
"""
|
||||
try:
|
||||
# 获取数据库管理器
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 获取提问记录
|
||||
question = db_manager.user_question_manager.get_user_question_by_id(question_id)
|
||||
|
||||
if not question:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"提问记录 {question_id} 不存在"
|
||||
)
|
||||
|
||||
# 检查是否是当前用户的提问
|
||||
if question["user_id"] != current_user["id"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="没有权限查看此提问记录"
|
||||
)
|
||||
|
||||
return question
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取提问记录失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取提问记录失败: {str(e)}"
|
||||
)
|
||||
@ -29,7 +29,7 @@ router = APIRouter()
|
||||
# JWT配置
|
||||
JWT_SECRET_KEY = "FX9Rf7YpvTRdXWmj0Osx8P9smSkUh6fW" # 实际应用中应该从环境变量或配置文件中获取
|
||||
JWT_ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 0 # 用户登录后不过期
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 180 * 24 * 60 * 60 # 180天
|
||||
|
||||
# 请求模型
|
||||
class UserRegister(BaseModel):
|
||||
@ -82,7 +82,7 @@ def create_access_token(data: Dict[str, Any], expires_delta: timedelta = None) -
|
||||
if expires_delta:
|
||||
expire = datetime.now() + expires_delta
|
||||
else:
|
||||
expire = datetime.now() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
expire = datetime.now() + timedelta(days=180)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||||
return encoded_jwt
|
||||
@ -291,15 +291,12 @@ async def login(loginData: UserLogin) -> TokenResponse:
|
||||
session.close()
|
||||
|
||||
# 创建访问令牌,不过期
|
||||
access_token_expires = None
|
||||
access_token = create_access_token(
|
||||
data={"sub": user["mail"]}, expires_delta=access_token_expires
|
||||
)
|
||||
access_token = create_access_token(data={"sub": user["mail"]})
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
token_type="bearer",
|
||||
expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
||||
expires_in=ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||
user_info=UserResponse(
|
||||
id=user["id"],
|
||||
mail=user["mail"],
|
||||
|
||||
@ -14,10 +14,8 @@ from cryptoai.utils.config_loader import ConfigLoader
|
||||
from cryptoai.models.base import Base
|
||||
from cryptoai.models.token import TokenManager
|
||||
from cryptoai.models.analysis_result import AnalysisResultManager
|
||||
from cryptoai.models.agent_feed import AgentFeedManager
|
||||
from cryptoai.models.user import UserManager
|
||||
from cryptoai.models.user_question import UserQuestionManager
|
||||
from cryptoai.models.agent import AgentManager
|
||||
from cryptoai.models.astock import AStockManager
|
||||
from cryptoai.models.analysis_history import AnalysisHistoryManager
|
||||
|
||||
@ -105,10 +103,8 @@ class DBManager:
|
||||
# 初始化各个模型的管理器
|
||||
self.token_manager = TokenManager(session)
|
||||
self.analysis_result_manager = AnalysisResultManager(session)
|
||||
self.agent_feed_manager = AgentFeedManager(session)
|
||||
self.user_manager = UserManager(session)
|
||||
self.user_question_manager = UserQuestionManager(session)
|
||||
self.agent_manager = AgentManager(session)
|
||||
self.astock_manager = AStockManager(session)
|
||||
self.analysis_history_manager = AnalysisHistoryManager(session)
|
||||
|
||||
|
||||
@ -29,7 +29,7 @@ services:
|
||||
cryptoai-api:
|
||||
build: .
|
||||
container_name: cryptoai-api
|
||||
image: cryptoai-api:0.1.27
|
||||
image: cryptoai-api:0.1.28
|
||||
restart: always
|
||||
ports:
|
||||
- "8000:8000"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user