up
This commit is contained in:
parent
ea94081617
commit
59860c6191
@ -360,15 +360,6 @@ class CryptoAgent:
|
|||||||
if self.discord_bot:
|
if self.discord_bot:
|
||||||
print(f"发送交易建议到Discord...")
|
print(f"发送交易建议到Discord...")
|
||||||
self.discord_bot.send_message(content=message)
|
self.discord_bot.send_message(content=message)
|
||||||
|
|
||||||
# 保存交易建议到数据库
|
|
||||||
try:
|
|
||||||
saved = self.db_manager.agent_feed_manager.save_agent_feed(
|
|
||||||
agent_name="Crypto Agent",
|
|
||||||
content=message
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"保存交易建议到数据库时出错: {e}")
|
|
||||||
|
|
||||||
# 导出 DeepSeek API token 使用情况
|
# 导出 DeepSeek API token 使用情况
|
||||||
self._export_token_usage()
|
self._export_token_usage()
|
||||||
|
|||||||
@ -6,9 +6,7 @@
|
|||||||
from cryptoai.models.base import Base
|
from cryptoai.models.base import Base
|
||||||
from cryptoai.models.token import Token, TokenManager
|
from cryptoai.models.token import Token, TokenManager
|
||||||
from cryptoai.models.analysis_result import AnalysisResult, AnalysisResultManager
|
from cryptoai.models.analysis_result import AnalysisResult, AnalysisResultManager
|
||||||
from cryptoai.models.agent_feed import AgentFeed, AgentFeedManager
|
|
||||||
from cryptoai.models.user import User, UserManager
|
from cryptoai.models.user import User, UserManager
|
||||||
from cryptoai.models.user_question import UserQuestion, UserQuestionManager
|
from cryptoai.models.user_question import UserQuestion, UserQuestionManager
|
||||||
from cryptoai.models.agent import Agent, AgentManager
|
|
||||||
from cryptoai.models.astock import AStock, AStockManager
|
from cryptoai.models.astock import AStock, AStockManager
|
||||||
from cryptoai.models.analysis_history import AnalysisHistory, AnalysisHistoryManager
|
from cryptoai.models.analysis_history import AnalysisHistory, AnalysisHistoryManager
|
||||||
Binary file not shown.
@ -1,268 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
from typing import Dict, Any, List, Optional
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from sqlalchemy import Column, Integer, String, Text, DateTime, Index
|
|
||||||
from sqlalchemy.dialects.mysql import JSON
|
|
||||||
|
|
||||||
from cryptoai.models.base import Base, logger
|
|
||||||
|
|
||||||
# 定义Agent数据模型
|
|
||||||
class Agent(Base):
|
|
||||||
"""Agent数据表模型"""
|
|
||||||
__tablename__ = 'agents'
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
name = Column(String(100), nullable=False, unique=True, comment='Agent名称')
|
|
||||||
hello_prompt = Column(Text, nullable=True, comment='欢迎提示语')
|
|
||||||
description = Column(Text, nullable=True, comment='Agent描述')
|
|
||||||
dify_token = Column(String(255), nullable=True, comment='Dify API令牌')
|
|
||||||
inputs = Column(JSON, nullable=True, comment='输入参数JSON')
|
|
||||||
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
|
||||||
update_time = Column(DateTime, nullable=False, default=datetime.now, onupdate=datetime.now, comment='更新时间')
|
|
||||||
|
|
||||||
# 索引和表属性
|
|
||||||
__table_args__ = (
|
|
||||||
Index('idx_name', 'name'),
|
|
||||||
Index('idx_create_time', 'create_time'),
|
|
||||||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
|
||||||
)
|
|
||||||
|
|
||||||
class AgentManager:
|
|
||||||
"""Agent管理类"""
|
|
||||||
|
|
||||||
def __init__(self, db_session):
|
|
||||||
self.session = db_session
|
|
||||||
|
|
||||||
def create_agent(self, name: str, description: Optional[str] = None,
|
|
||||||
hello_prompt: Optional[str] = None, dify_token: Optional[str] = None,
|
|
||||||
inputs: Optional[Dict[str, Any]] = None) -> Optional[int]:
|
|
||||||
"""
|
|
||||||
创建新的Agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Agent名称
|
|
||||||
description: Agent描述
|
|
||||||
hello_prompt: 欢迎提示语
|
|
||||||
dify_token: Dify API令牌
|
|
||||||
inputs: 输入参数JSON
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
新Agent的ID,如果创建失败则返回None
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 检查名称是否已存在
|
|
||||||
existing_agent = self.session.query(Agent).filter(Agent.name == name).first()
|
|
||||||
if existing_agent:
|
|
||||||
logger.warning(f"Agent名称 {name} 已存在")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 创建新Agent
|
|
||||||
new_agent = Agent(
|
|
||||||
name=name,
|
|
||||||
description=description,
|
|
||||||
hello_prompt=hello_prompt,
|
|
||||||
dify_token=dify_token,
|
|
||||||
inputs=inputs,
|
|
||||||
create_time=datetime.now(),
|
|
||||||
update_time=datetime.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
# 添加并提交
|
|
||||||
self.session.add(new_agent)
|
|
||||||
self.session.commit()
|
|
||||||
|
|
||||||
logger.info(f"成功创建Agent: {name}")
|
|
||||||
return new_agent.id
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.session.rollback()
|
|
||||||
logger.error(f"创建Agent失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def update_agent(self, agent_id: int, name: Optional[str] = None,
|
|
||||||
description: Optional[str] = None, hello_prompt: Optional[str] = None,
|
|
||||||
dify_token: Optional[str] = None, inputs: Optional[Dict[str, Any]] = None) -> bool:
|
|
||||||
"""
|
|
||||||
更新Agent信息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_id: Agent ID
|
|
||||||
name: Agent名称
|
|
||||||
description: Agent描述
|
|
||||||
hello_prompt: 欢迎提示语
|
|
||||||
dify_token: Dify API令牌
|
|
||||||
inputs: 输入参数JSON
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
更新是否成功
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 查询Agent
|
|
||||||
agent = self.session.query(Agent).filter(Agent.id == agent_id).first()
|
|
||||||
if not agent:
|
|
||||||
logger.warning(f"Agent ID {agent_id} 不存在")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 更新字段
|
|
||||||
if name is not None:
|
|
||||||
# 检查名称是否与其他Agent冲突
|
|
||||||
if name != agent.name:
|
|
||||||
existing = self.session.query(Agent).filter(Agent.name == name).first()
|
|
||||||
if existing:
|
|
||||||
logger.warning(f"Agent名称 {name} 已被其他Agent使用")
|
|
||||||
return False
|
|
||||||
agent.name = name
|
|
||||||
|
|
||||||
if description is not None:
|
|
||||||
agent.description = description
|
|
||||||
|
|
||||||
if hello_prompt is not None:
|
|
||||||
agent.hello_prompt = hello_prompt
|
|
||||||
|
|
||||||
if dify_token is not None:
|
|
||||||
agent.dify_token = dify_token
|
|
||||||
|
|
||||||
if inputs is not None:
|
|
||||||
agent.inputs = inputs
|
|
||||||
|
|
||||||
agent.update_time = datetime.now()
|
|
||||||
|
|
||||||
# 提交更改
|
|
||||||
self.session.commit()
|
|
||||||
|
|
||||||
logger.info(f"成功更新Agent ID: {agent_id}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.session.rollback()
|
|
||||||
logger.error(f"更新Agent失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_agent_by_id(self, agent_id: int) -> Optional[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
通过ID获取Agent信息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_id: Agent ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Agent信息,如果不存在则返回None
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 查询Agent
|
|
||||||
agent = self.session.query(Agent).filter(Agent.id == agent_id).first()
|
|
||||||
|
|
||||||
if agent:
|
|
||||||
# 转换为字典
|
|
||||||
return {
|
|
||||||
'id': agent.id,
|
|
||||||
'name': agent.name,
|
|
||||||
'description': agent.description,
|
|
||||||
'hello_prompt': agent.hello_prompt,
|
|
||||||
'dify_token': agent.dify_token,
|
|
||||||
'inputs': agent.inputs,
|
|
||||||
'create_time': agent.create_time,
|
|
||||||
'update_time': agent.update_time
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取Agent信息失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_agent_by_name(self, name: str) -> Optional[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
通过名称获取Agent信息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Agent名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Agent信息,如果不存在则返回None
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 查询Agent
|
|
||||||
agent = self.session.query(Agent).filter(Agent.name == name).first()
|
|
||||||
|
|
||||||
if agent:
|
|
||||||
# 转换为字典
|
|
||||||
return {
|
|
||||||
'id': agent.id,
|
|
||||||
'name': agent.name,
|
|
||||||
'description': agent.description,
|
|
||||||
'hello_prompt': agent.hello_prompt,
|
|
||||||
'dify_token': agent.dify_token,
|
|
||||||
'inputs': agent.inputs,
|
|
||||||
'create_time': agent.create_time,
|
|
||||||
'update_time': agent.update_time
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取Agent信息失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def list_agents(self, limit: int = 100, skip: int = 0) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
获取Agent列表
|
|
||||||
|
|
||||||
Args:
|
|
||||||
limit: 返回的最大数量
|
|
||||||
skip: 跳过的数量
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Agent列表
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 查询Agent列表
|
|
||||||
agents = self.session.query(Agent).order_by(Agent.create_time.asc()).offset(skip).limit(limit).all()
|
|
||||||
|
|
||||||
# 转换为字典列表
|
|
||||||
return [{
|
|
||||||
'id': agent.id,
|
|
||||||
'name': agent.name,
|
|
||||||
'description': agent.description,
|
|
||||||
'hello_prompt': agent.hello_prompt,
|
|
||||||
'dify_token': agent.dify_token,
|
|
||||||
'inputs': agent.inputs,
|
|
||||||
'create_time': agent.create_time,
|
|
||||||
'update_time': agent.update_time
|
|
||||||
} for agent in agents]
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取Agent列表失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def delete_agent(self, agent_id: int) -> bool:
|
|
||||||
"""
|
|
||||||
删除Agent
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_id: Agent ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
删除是否成功
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 查询Agent
|
|
||||||
agent = self.session.query(Agent).filter(Agent.id == agent_id).first()
|
|
||||||
if not agent:
|
|
||||||
logger.warning(f"Agent ID {agent_id} 不存在")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 删除Agent
|
|
||||||
self.session.delete(agent)
|
|
||||||
self.session.commit()
|
|
||||||
|
|
||||||
logger.info(f"成功删除Agent ID: {agent_id}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.session.rollback()
|
|
||||||
logger.error(f"删除Agent失败: {e}")
|
|
||||||
return False
|
|
||||||
@ -1,105 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
from typing import Dict, Any, List, Optional
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from sqlalchemy import Column, Integer, String, Text, DateTime, Index
|
|
||||||
|
|
||||||
from cryptoai.models.base import Base, logger
|
|
||||||
|
|
||||||
# 定义AI Agent信息流模型
|
|
||||||
class AgentFeed(Base):
|
|
||||||
"""AI Agent信息流表模型"""
|
|
||||||
__tablename__ = 'agent_feeds'
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
agent_name = Column(String(50), nullable=False, comment='AI Agent名称')
|
|
||||||
avatar_url = Column(String(255), nullable=True, comment='头像URL')
|
|
||||||
content = Column(Text, nullable=False, comment='内容')
|
|
||||||
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
|
||||||
|
|
||||||
# 索引和表属性
|
|
||||||
__table_args__ = (
|
|
||||||
Index('idx_agent_name', 'agent_name'),
|
|
||||||
Index('idx_create_time', 'create_time'),
|
|
||||||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
|
||||||
)
|
|
||||||
|
|
||||||
class AgentFeedManager:
|
|
||||||
"""Agent信息流管理类"""
|
|
||||||
|
|
||||||
def __init__(self, db_session):
|
|
||||||
self.session = db_session
|
|
||||||
|
|
||||||
def save_agent_feed(self, agent_name: str, content: str, avatar_url: Optional[str] = None) -> bool:
|
|
||||||
"""
|
|
||||||
保存AI Agent信息流到数据库
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_name: AI Agent名称
|
|
||||||
content: 内容
|
|
||||||
avatar_url: 头像URL,可选
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
保存是否成功
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 创建新记录
|
|
||||||
new_feed = AgentFeed(
|
|
||||||
agent_name=agent_name,
|
|
||||||
content=content,
|
|
||||||
avatar_url=avatar_url
|
|
||||||
)
|
|
||||||
|
|
||||||
# 添加并提交
|
|
||||||
self.session.add(new_feed)
|
|
||||||
self.session.commit()
|
|
||||||
|
|
||||||
logger.info(f"成功保存 {agent_name} 的信息流")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.session.rollback()
|
|
||||||
logger.error(f"保存信息流失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_agent_feeds(self, agent_name: Optional[str] = None, limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
获取AI Agent信息流
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_name: 可选,指定获取特定Agent的信息流
|
|
||||||
limit: 返回的最大记录数,默认20条
|
|
||||||
skip: 跳过的记录数,默认0条
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
信息流列表,如果查询失败则返回空列表
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 构建查询
|
|
||||||
query = self.session.query(AgentFeed)
|
|
||||||
|
|
||||||
# 如果指定了agent_name,则筛选
|
|
||||||
if agent_name:
|
|
||||||
query = query.filter(AgentFeed.agent_name == agent_name)
|
|
||||||
|
|
||||||
# 按创建时间降序排序并限制数量
|
|
||||||
results = query.order_by(AgentFeed.create_time.desc()).offset(skip).limit(limit).all()
|
|
||||||
|
|
||||||
# 转换为字典列表
|
|
||||||
feeds = []
|
|
||||||
for result in results:
|
|
||||||
feeds.append({
|
|
||||||
'id': result.id,
|
|
||||||
'agent_name': result.agent_name,
|
|
||||||
'avatar_url': result.avatar_url,
|
|
||||||
'content': result.content,
|
|
||||||
'create_time': result.create_time
|
|
||||||
})
|
|
||||||
|
|
||||||
return feeds
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取信息流失败: {e}")
|
|
||||||
return []
|
|
||||||
@ -81,89 +81,4 @@ async def get_stock_data(stock_code: str, start_date: Optional[str] = None, end_
|
|||||||
logger.error(f"获取股票数据失败: {e}")
|
logger.error(f"获取股票数据失败: {e}")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@router.get("/stock/data/all", summary="获取所有股票数据")
|
|
||||||
async def get_stock_data_all(stock_code: str):
|
|
||||||
|
|
||||||
result = {}
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
api = AStockAPI()
|
|
||||||
|
|
||||||
#获取股本信息
|
|
||||||
stock_shares = api.get_stock_shares(stock_code)
|
|
||||||
result["stock_shares"] = json.loads(stock_shares.to_json(orient="records"))
|
|
||||||
|
|
||||||
# 获取概念板块
|
|
||||||
concept_east = api.get_concept_east(stock_code)
|
|
||||||
result["concept_east"] = json.loads(concept_east.to_json(orient="records"))
|
|
||||||
|
|
||||||
# 获取板块
|
|
||||||
plate_east = api.get_plate_east(stock_code)
|
|
||||||
result["plate_east"] = json.loads(plate_east.to_json(orient="records"))
|
|
||||||
|
|
||||||
# 获取市场数据
|
|
||||||
market_data = api.get_market_data(stock_code)
|
|
||||||
result["market_data"] = json.loads(market_data.to_json(orient="records"))
|
|
||||||
|
|
||||||
# 获取分钟线数据
|
|
||||||
min_data = api.get_market_min_data(stock_code)
|
|
||||||
result["min_data"] = json.loads(min_data.to_json(orient="records"))
|
|
||||||
|
|
||||||
# 获取资金流向数据
|
|
||||||
flow_data = api.get_capital_flow(stock_code)
|
|
||||||
result["flow_data"] = json.loads(flow_data.to_json(orient="records"))
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取股票数据失败: {e}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
@router.post('/{stock_code}/analysis', summary="获取股票分析数据")
|
|
||||||
async def get_stock_analysis(stock_code: str, current_user: Dict[str, Any] = Depends(get_current_user)):
|
|
||||||
|
|
||||||
|
|
||||||
url = 'https://mate.aimateplus.com/v1/workflows/run'
|
|
||||||
token = 'app-nWuCOa0YfQVtAosTY3Jr5vFV'
|
|
||||||
headers = {
|
|
||||||
'Authorization': f'Bearer {token}',
|
|
||||||
'Content-Type': 'application/json'
|
|
||||||
}
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"inputs" : {
|
|
||||||
"stock_code" : stock_code
|
|
||||||
},
|
|
||||||
"response_mode": "streaming",
|
|
||||||
"user": current_user["mail"]
|
|
||||||
}
|
|
||||||
|
|
||||||
# 保存用户提问
|
|
||||||
get_db_manager().user_question_manager.save_user_question(current_user["id"], stock_code, "请分析以下股票:" + stock_code + ",并给出分析报告。")
|
|
||||||
|
|
||||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
|
||||||
|
|
||||||
# 如果响应不成功,返回错误
|
|
||||||
if response.status_code != 200:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=response.status_code,
|
|
||||||
detail=f"Failed to get response from Dify API: {response.text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取response的stream
|
|
||||||
def stream_response():
|
|
||||||
for chunk in response.iter_content(chunk_size=1024):
|
|
||||||
if chunk:
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
return StreamingResponse(stream_response(), media_type="text/plain")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -1,126 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
"""
|
|
||||||
API路由模块,为前端提供REST API接口
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Body, Query, Path
|
|
||||||
from typing import Dict, Any, List, Optional
|
|
||||||
from pydantic import BaseModel
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from cryptoai.api.deepseek_api import DeepSeekAPI
|
|
||||||
from cryptoai.utils.config_loader import ConfigLoader
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
from cryptoai.routes.user import get_current_user
|
|
||||||
import requests
|
|
||||||
from datetime import datetime
|
|
||||||
from cryptoai.utils.db_manager import get_db_manager
|
|
||||||
|
|
||||||
# 创建路由
|
|
||||||
router = APIRouter()
|
|
||||||
|
|
||||||
class ChatRequest(BaseModel):
|
|
||||||
user_prompt: str
|
|
||||||
agent_id: int
|
|
||||||
conversation_id: Optional[str] = None
|
|
||||||
|
|
||||||
class AgentCreate(BaseModel):
|
|
||||||
name: str
|
|
||||||
description: Optional[str] = None
|
|
||||||
hello_prompt: Optional[str] = None
|
|
||||||
dify_token: Optional[str] = None
|
|
||||||
inputs: Optional[Dict[str, Any]] = None
|
|
||||||
|
|
||||||
class AgentUpdate(BaseModel):
|
|
||||||
name: Optional[str] = None
|
|
||||||
description: Optional[str] = None
|
|
||||||
hello_prompt: Optional[str] = None
|
|
||||||
dify_token: Optional[str] = None
|
|
||||||
inputs: Optional[Dict[str, Any]] = None
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/create")
|
|
||||||
async def create_agent(agent: AgentCreate):
|
|
||||||
"""
|
|
||||||
创建新的AI Agent
|
|
||||||
"""
|
|
||||||
|
|
||||||
agent = get_db_manager().agent_manager.create_agent(agent.name, agent.description, agent.hello_prompt, agent.dify_token, agent.inputs)
|
|
||||||
|
|
||||||
return agent
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/list", response_model=List[Dict[str, Any]])
|
|
||||||
async def get_agents(
|
|
||||||
skip: int = Query(0, ge=0),
|
|
||||||
limit: int = Query(100, ge=1, le=1000)
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
获取所有代理
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 从数据库获取Agent列表
|
|
||||||
agents = get_db_manager().agent_manager.list_agents(limit=limit, skip=skip)
|
|
||||||
return agents
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/chat")
|
|
||||||
async def chat(request: ChatRequest, current_user: Dict[str, Any] = Depends(get_current_user)):
|
|
||||||
"""
|
|
||||||
聊天接口
|
|
||||||
"""
|
|
||||||
# 尝试从数据库获取Agent
|
|
||||||
try:
|
|
||||||
agent_id = int(request.agent_id)
|
|
||||||
agent = get_db_manager().agent_manager.get_agent_by_id(agent_id)
|
|
||||||
|
|
||||||
if not agent:
|
|
||||||
raise HTTPException(status_code=400, detail="Invalid agent ID")
|
|
||||||
|
|
||||||
token = agent.get("dify_token")
|
|
||||||
inputs = agent.get("inputs") or {}
|
|
||||||
inputs["current_date"] = datetime.now().strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
except ValueError:
|
|
||||||
raise HTTPException(status_code=400, detail="Invalid agent ID format")
|
|
||||||
|
|
||||||
url = "https://mate.aimateplus.com/v1/chat-messages"
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {token}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
data = {
|
|
||||||
"inputs": inputs,
|
|
||||||
"query": request.user_prompt,
|
|
||||||
"response_mode": "streaming",
|
|
||||||
"user": current_user["mail"]
|
|
||||||
}
|
|
||||||
|
|
||||||
if request.conversation_id:
|
|
||||||
data["conversation_id"] = request.conversation_id
|
|
||||||
|
|
||||||
logging.info(f"Chat request data: {data}")
|
|
||||||
|
|
||||||
# 保存用户提问
|
|
||||||
get_db_manager().user_question_manager.save_user_question(current_user["id"], request.agent_id, request.user_prompt)
|
|
||||||
|
|
||||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
|
||||||
|
|
||||||
# 如果响应不成功,返回错误
|
|
||||||
if response.status_code != 200:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=response.status_code,
|
|
||||||
detail=f"Failed to get response from Dify API: {response.text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取response的stream
|
|
||||||
def stream_response():
|
|
||||||
for chunk in response.iter_content(chunk_size=1024):
|
|
||||||
if chunk:
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
return StreamingResponse(stream_response(), media_type="text/plain")
|
|
||||||
@ -4,6 +4,9 @@ from cryptoai.utils.db_manager import get_db_manager
|
|||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from cryptoai.routes.user import get_current_user
|
from cryptoai.routes.user import get_current_user
|
||||||
|
from fastapi import HTTPException
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
class AnalysisHistoryRequest(BaseModel):
|
class AnalysisHistoryRequest(BaseModel):
|
||||||
@ -27,4 +30,73 @@ async def get_analysis_histories(current_user: dict = Depends(get_current_user),
|
|||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
offset: int = 0):
|
offset: int = 0):
|
||||||
history = get_db_manager().analysis_history_manager.get_user_analysis_history(current_user["id"], limit=limit, offset=offset)
|
history = get_db_manager().analysis_history_manager.get_user_analysis_history(current_user["id"], limit=limit, offset=offset)
|
||||||
return history
|
return history
|
||||||
|
|
||||||
|
|
||||||
|
class AnalysisRequest(BaseModel):
|
||||||
|
symbol: Optional[str] = None
|
||||||
|
timeframe: Optional[str] = None
|
||||||
|
stock_code: Optional[str] = None
|
||||||
|
type: str
|
||||||
|
|
||||||
|
@router.post("/analysis")
|
||||||
|
async def analysis(request: AnalysisRequest,
|
||||||
|
current_user: dict = Depends(get_current_user)):
|
||||||
|
|
||||||
|
if request.type == 'crypto':
|
||||||
|
# 检查symbol是否存在
|
||||||
|
tokens = get_db_manager().token_manager.search_token(request.symbol)
|
||||||
|
if not tokens or len(tokens) == 0:
|
||||||
|
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
|
||||||
|
|
||||||
|
symbol = tokens[0]["symbol"]
|
||||||
|
token = 'app-BbaqIAMPi0ktgaV9IizMlc2N'
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"inputs" : {
|
||||||
|
"symbol" : symbol,
|
||||||
|
"timeframe" : request.timeframe
|
||||||
|
},
|
||||||
|
"response_mode": "streaming",
|
||||||
|
"user": current_user["mail"]
|
||||||
|
}
|
||||||
|
|
||||||
|
get_db_manager().user_question_manager.save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。")
|
||||||
|
|
||||||
|
|
||||||
|
else:
|
||||||
|
stock_code = request.stock_code
|
||||||
|
token = 'app-nWuCOa0YfQVtAosTY3Jr5vFV'
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"inputs" : {
|
||||||
|
"stock_code": stock_code
|
||||||
|
},
|
||||||
|
"response_mode": "streaming",
|
||||||
|
"user": current_user["mail"]
|
||||||
|
}
|
||||||
|
|
||||||
|
get_db_manager().user_question_manager.save_user_question(current_user["id"], stock_code, "请分析以下股票:" + stock_code + ",并给出分析报告。")
|
||||||
|
|
||||||
|
url = 'https://mate.aimateplus.com/v1/workflows/run'
|
||||||
|
headers = {
|
||||||
|
'Authorization': f'Bearer {token}',
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.post(url, headers=headers, json=payload, stream=True)
|
||||||
|
|
||||||
|
# 如果响应不成功,返回错误
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=response.status_code,
|
||||||
|
detail=f"Failed to get response from Dify API: {response.text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取response的stream
|
||||||
|
def stream_response():
|
||||||
|
for chunk in response.iter_content(chunk_size=1024):
|
||||||
|
if chunk:
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
return StreamingResponse(stream_response(), media_type="text/plain")
|
||||||
@ -51,110 +51,4 @@ async def get_crypto_kline(symbol: str, timeframe: Optional[str] = None, limit:
|
|||||||
else:
|
else:
|
||||||
result[timeframe] = binance_api.get_historical_klines(symbol=symbol, interval=timeframe, limit=limit).to_dict(orient="records")
|
result[timeframe] = binance_api.get_historical_klines(symbol=symbol, interval=timeframe, limit=limit).to_dict(orient="records")
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@router.post("/analysis_v2")
|
|
||||||
async def analysis_crypto_v2(request: CryptoAnalysisRequest,
|
|
||||||
current_user: dict = Depends(get_current_user)):
|
|
||||||
# 检查symbol是否存在
|
|
||||||
tokens = get_db_manager().token_manager.search_token(request.symbol)
|
|
||||||
if not tokens or len(tokens) == 0:
|
|
||||||
raise HTTPException(status_code=400, detail="您输入的币种在币安不存在,请检查后重新输入。")
|
|
||||||
|
|
||||||
symbol = tokens[0]["symbol"]
|
|
||||||
|
|
||||||
url = 'https://mate.aimateplus.com/v1/workflows/run'
|
|
||||||
token = 'app-BbaqIAMPi0ktgaV9IizMlc2N'
|
|
||||||
headers = {
|
|
||||||
'Authorization': f'Bearer {token}',
|
|
||||||
'Content-Type': 'application/json'
|
|
||||||
}
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"inputs" : {
|
|
||||||
"symbol" : symbol,
|
|
||||||
"timeframe" : request.timeframe
|
|
||||||
},
|
|
||||||
"response_mode": "streaming",
|
|
||||||
"user": current_user["mail"]
|
|
||||||
}
|
|
||||||
|
|
||||||
# 保存用户提问
|
|
||||||
get_db_manager().user_question_manager.save_user_question(current_user["id"], symbol, "请分析以下加密货币:" + symbol + ",并给出分析报告。")
|
|
||||||
|
|
||||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
|
||||||
|
|
||||||
# 如果响应不成功,返回错误
|
|
||||||
if response.status_code != 200:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=response.status_code,
|
|
||||||
detail=f"Failed to get response from Dify API: {response.text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取response的stream
|
|
||||||
def stream_response():
|
|
||||||
for chunk in response.iter_content(chunk_size=1024):
|
|
||||||
if chunk:
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
return StreamingResponse(stream_response(), media_type="text/plain")
|
|
||||||
|
|
||||||
@router.post("/analysis")
|
|
||||||
async def analysis_crypto(request: CryptoAnalysisRequest,
|
|
||||||
current_user: dict = Depends(get_current_user)):
|
|
||||||
|
|
||||||
# 尝试从数据库获取Agent
|
|
||||||
try:
|
|
||||||
agent_id = 1
|
|
||||||
if request.timeframe:
|
|
||||||
user_prompt = f"请分析以下加密货币:{request.symbol},并给出 {request.timeframe} 级别的分析报告。"
|
|
||||||
else:
|
|
||||||
user_prompt = f"请分析以下加密货币:{request.symbol},并给出分析报告。"
|
|
||||||
|
|
||||||
agent = get_db_manager().agent_manager.get_agent_by_id(agent_id)
|
|
||||||
|
|
||||||
if not agent:
|
|
||||||
raise HTTPException(status_code=400, detail="Invalid agent ID")
|
|
||||||
|
|
||||||
token = agent.get("dify_token")
|
|
||||||
inputs = agent.get("inputs") or {}
|
|
||||||
inputs["current_date"] = datetime.now().strftime("%Y-%m-%d")
|
|
||||||
|
|
||||||
except ValueError:
|
|
||||||
raise HTTPException(status_code=400, detail="Invalid agent ID format")
|
|
||||||
|
|
||||||
url = "https://mate.aimateplus.com/v1/chat-messages"
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {token}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
data = {
|
|
||||||
"inputs": inputs,
|
|
||||||
"query": user_prompt,
|
|
||||||
"response_mode": "streaming",
|
|
||||||
"user": current_user["mail"]
|
|
||||||
}
|
|
||||||
|
|
||||||
logging.info(f"Chat request data: {data}")
|
|
||||||
|
|
||||||
# 保存用户提问
|
|
||||||
get_db_manager().user_question_manager.save_user_question(current_user["id"], agent_id, user_prompt)
|
|
||||||
|
|
||||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
|
||||||
|
|
||||||
# 如果响应不成功,返回错误
|
|
||||||
if response.status_code != 200:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=response.status_code,
|
|
||||||
detail=f"Failed to get response from Dify API: {response.text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取response的stream
|
|
||||||
def stream_response():
|
|
||||||
for chunk in response.iter_content(chunk_size=1024):
|
|
||||||
if chunk:
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
return StreamingResponse(stream_response(), media_type="text/plain")
|
|
||||||
|
|
||||||
|
|
||||||
@ -15,11 +15,8 @@ from fastapi.responses import JSONResponse
|
|||||||
import time
|
import time
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
||||||
from cryptoai.routes.agent import router as agent_router
|
|
||||||
from cryptoai.routes.feed import router as feed_router
|
|
||||||
from cryptoai.routes.user import router as user_router
|
from cryptoai.routes.user import router as user_router
|
||||||
from cryptoai.routes.adata import router as adata_router
|
from cryptoai.routes.adata import router as adata_router
|
||||||
from cryptoai.routes.question import router as question_router
|
|
||||||
from cryptoai.routes.crypto import router as crypto_router
|
from cryptoai.routes.crypto import router as crypto_router
|
||||||
from cryptoai.routes.platform import router as platform_router
|
from cryptoai.routes.platform import router as platform_router
|
||||||
from cryptoai.routes.analysis import router as analysis_router
|
from cryptoai.routes.analysis import router as analysis_router
|
||||||
@ -52,11 +49,8 @@ app.add_middleware(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 添加API路由
|
# 添加API路由
|
||||||
app.include_router(agent_router, prefix="/agent")
|
|
||||||
app.include_router(platform_router, prefix="/platform", tags=["平台信息"])
|
app.include_router(platform_router, prefix="/platform", tags=["平台信息"])
|
||||||
app.include_router(feed_router, prefix="/feed", tags=["AI Agent信息流"])
|
|
||||||
app.include_router(user_router, prefix="/user", tags=["用户管理"])
|
app.include_router(user_router, prefix="/user", tags=["用户管理"])
|
||||||
app.include_router(question_router, prefix="/question", tags=["用户提问"])
|
|
||||||
app.include_router(adata_router, prefix="/adata", tags=["A股数据"])
|
app.include_router(adata_router, prefix="/adata", tags=["A股数据"])
|
||||||
app.include_router(crypto_router, prefix="/crypto", tags=["加密货币数据"])
|
app.include_router(crypto_router, prefix="/crypto", tags=["加密货币数据"])
|
||||||
app.include_router(analysis_router, prefix="/analysis", tags=["分析历史"])
|
app.include_router(analysis_router, prefix="/analysis", tags=["分析历史"])
|
||||||
|
|||||||
@ -1,110 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
"""
|
|
||||||
AI Agent信息流API路由模块,提供信息流的增删改查功能
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from fastapi import APIRouter, HTTPException, status, Query, Depends
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from typing import Dict, Any, List, Optional
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from cryptoai.routes.user import get_current_user
|
|
||||||
from cryptoai.utils.db_manager import get_db_manager
|
|
||||||
|
|
||||||
# 配置日志
|
|
||||||
logger = logging.getLogger("feed_router")
|
|
||||||
|
|
||||||
# 创建路由
|
|
||||||
router = APIRouter()
|
|
||||||
|
|
||||||
# 请求模型
|
|
||||||
class AgentFeedCreate(BaseModel):
|
|
||||||
"""创建信息流请求模型"""
|
|
||||||
agent_name: str
|
|
||||||
content: str
|
|
||||||
avatar_url: Optional[str] = None
|
|
||||||
|
|
||||||
# 响应模型
|
|
||||||
class AgentFeedResponse(BaseModel):
|
|
||||||
"""信息流响应模型"""
|
|
||||||
id: int
|
|
||||||
agent_name: str
|
|
||||||
avatar_url: Optional[str] = None
|
|
||||||
content: str
|
|
||||||
create_time: datetime
|
|
||||||
|
|
||||||
@router.post("", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
|
|
||||||
async def create_feed(feed: AgentFeedCreate) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
创建新的AI Agent信息流
|
|
||||||
|
|
||||||
Args:
|
|
||||||
feed: 信息流创建请求
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
创建成功的状态信息
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 获取数据库管理器
|
|
||||||
db_manager = get_db_manager()
|
|
||||||
|
|
||||||
# 保存信息流
|
|
||||||
success = db_manager.agent_feed_manager.save_agent_feed(
|
|
||||||
agent_name=feed.agent_name,
|
|
||||||
content=feed.content,
|
|
||||||
avatar_url=feed.avatar_url
|
|
||||||
)
|
|
||||||
|
|
||||||
if not success:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail="保存信息流失败"
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "success",
|
|
||||||
"message": "信息流创建成功"
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"创建信息流失败: {str(e)}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"创建信息流失败: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
@router.get("", response_model=List[AgentFeedResponse])
|
|
||||||
async def get_feeds(
|
|
||||||
agent_name: Optional[str] = Query(None, description="AI Agent名称,可选"),
|
|
||||||
limit: int = Query(20, description="返回的最大记录数,默认20条"),
|
|
||||||
skip: int = Query(0, description="跳过的记录数,默认0条"),
|
|
||||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
|
||||||
) -> List[AgentFeedResponse]:
|
|
||||||
"""
|
|
||||||
获取AI Agent信息流列表
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_name: 可选,指定获取特定Agent的信息流
|
|
||||||
limit: 返回的最大记录数,默认20条
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
信息流列表
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 获取数据库管理器
|
|
||||||
db_manager = get_db_manager()
|
|
||||||
|
|
||||||
# 获取信息流
|
|
||||||
feeds = db_manager.agent_feed_manager.get_agent_feeds(agent_name=agent_name, limit=limit, skip=skip)
|
|
||||||
|
|
||||||
return feeds
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取信息流失败: {str(e)}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"获取信息流失败: {str(e)}"
|
|
||||||
)
|
|
||||||
@ -1,166 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
"""
|
|
||||||
用户提问API路由模块,提供用户提问数据的增删改查功能
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from fastapi import APIRouter, HTTPException, status, Depends, Query
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from typing import Dict, Any, List, Optional
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from cryptoai.utils.db_manager import get_db_manager
|
|
||||||
from cryptoai.routes.user import get_current_user
|
|
||||||
|
|
||||||
# 配置日志
|
|
||||||
logger = logging.getLogger("question_router")
|
|
||||||
|
|
||||||
# 创建路由
|
|
||||||
router = APIRouter()
|
|
||||||
|
|
||||||
# 请求模型
|
|
||||||
class QuestionCreate(BaseModel):
|
|
||||||
"""创建提问请求模型"""
|
|
||||||
agent_id: str
|
|
||||||
question: str
|
|
||||||
|
|
||||||
# 响应模型
|
|
||||||
class QuestionResponse(BaseModel):
|
|
||||||
"""提问响应模型"""
|
|
||||||
id: int
|
|
||||||
user_id: int
|
|
||||||
agent_id: str
|
|
||||||
question: str
|
|
||||||
create_time: datetime
|
|
||||||
|
|
||||||
@router.post("/", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
|
|
||||||
async def create_question(
|
|
||||||
question: QuestionCreate,
|
|
||||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
创建新的用户提问记录
|
|
||||||
|
|
||||||
Args:
|
|
||||||
question: 提问创建请求
|
|
||||||
current_user: 当前用户信息,由依赖项提供
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
创建成功的状态信息
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 获取数据库管理器
|
|
||||||
db_manager = get_db_manager()
|
|
||||||
|
|
||||||
# 保存提问
|
|
||||||
success = db_manager.user_question_manager.save_user_question(
|
|
||||||
user_id=current_user["id"],
|
|
||||||
agent_id=question.agent_id,
|
|
||||||
question=question.question
|
|
||||||
)
|
|
||||||
|
|
||||||
if not success:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail="保存提问失败"
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"status": "success",
|
|
||||||
"message": "提问记录创建成功"
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"创建提问记录失败: {str(e)}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"创建提问记录失败: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
@router.get("/", response_model=List[QuestionResponse])
|
|
||||||
async def get_questions(
|
|
||||||
agent_id: Optional[str] = Query(None, description="AI Agent ID,可选"),
|
|
||||||
limit: int = Query(20, description="返回的最大记录数,默认20条"),
|
|
||||||
skip: int = Query(0, description="跳过的记录数,默认0条"),
|
|
||||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
|
||||||
) -> List[QuestionResponse]:
|
|
||||||
"""
|
|
||||||
获取用户提问记录列表
|
|
||||||
|
|
||||||
Args:
|
|
||||||
agent_id: 可选,指定获取特定Agent的提问
|
|
||||||
limit: 返回的最大记录数,默认20条
|
|
||||||
skip: 跳过的记录数,默认0条
|
|
||||||
current_user: 当前用户信息,由依赖项提供
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
提问记录列表
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 获取数据库管理器
|
|
||||||
db_manager = get_db_manager()
|
|
||||||
|
|
||||||
# 获取提问记录
|
|
||||||
questions = db_manager.user_question_manager.get_user_questions(
|
|
||||||
user_id=current_user["id"],
|
|
||||||
agent_id=agent_id,
|
|
||||||
limit=limit,
|
|
||||||
skip=skip
|
|
||||||
)
|
|
||||||
|
|
||||||
return questions
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取提问记录失败: {str(e)}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"获取提问记录失败: {str(e)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
@router.get("/{question_id}", response_model=QuestionResponse)
|
|
||||||
async def get_question(
|
|
||||||
question_id: int,
|
|
||||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
|
||||||
) -> QuestionResponse:
|
|
||||||
"""
|
|
||||||
获取特定提问记录
|
|
||||||
|
|
||||||
Args:
|
|
||||||
question_id: 提问ID
|
|
||||||
current_user: 当前用户信息,由依赖项提供
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
提问记录
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 获取数据库管理器
|
|
||||||
db_manager = get_db_manager()
|
|
||||||
|
|
||||||
# 获取提问记录
|
|
||||||
question = db_manager.user_question_manager.get_user_question_by_id(question_id)
|
|
||||||
|
|
||||||
if not question:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"提问记录 {question_id} 不存在"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 检查是否是当前用户的提问
|
|
||||||
if question["user_id"] != current_user["id"]:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="没有权限查看此提问记录"
|
|
||||||
)
|
|
||||||
|
|
||||||
return question
|
|
||||||
|
|
||||||
except HTTPException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取提问记录失败: {str(e)}")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail=f"获取提问记录失败: {str(e)}"
|
|
||||||
)
|
|
||||||
@ -29,7 +29,7 @@ router = APIRouter()
|
|||||||
# JWT配置
|
# JWT配置
|
||||||
JWT_SECRET_KEY = "FX9Rf7YpvTRdXWmj0Osx8P9smSkUh6fW" # 实际应用中应该从环境变量或配置文件中获取
|
JWT_SECRET_KEY = "FX9Rf7YpvTRdXWmj0Osx8P9smSkUh6fW" # 实际应用中应该从环境变量或配置文件中获取
|
||||||
JWT_ALGORITHM = "HS256"
|
JWT_ALGORITHM = "HS256"
|
||||||
ACCESS_TOKEN_EXPIRE_MINUTES = 0 # 用户登录后不过期
|
ACCESS_TOKEN_EXPIRE_MINUTES = 180 * 24 * 60 * 60 # 180天
|
||||||
|
|
||||||
# 请求模型
|
# 请求模型
|
||||||
class UserRegister(BaseModel):
|
class UserRegister(BaseModel):
|
||||||
@ -82,7 +82,7 @@ def create_access_token(data: Dict[str, Any], expires_delta: timedelta = None) -
|
|||||||
if expires_delta:
|
if expires_delta:
|
||||||
expire = datetime.now() + expires_delta
|
expire = datetime.now() + expires_delta
|
||||||
else:
|
else:
|
||||||
expire = datetime.now() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
expire = datetime.now() + timedelta(days=180)
|
||||||
to_encode.update({"exp": expire})
|
to_encode.update({"exp": expire})
|
||||||
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
|
encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm=JWT_ALGORITHM)
|
||||||
return encoded_jwt
|
return encoded_jwt
|
||||||
@ -291,15 +291,12 @@ async def login(loginData: UserLogin) -> TokenResponse:
|
|||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
# 创建访问令牌,不过期
|
# 创建访问令牌,不过期
|
||||||
access_token_expires = None
|
access_token = create_access_token(data={"sub": user["mail"]})
|
||||||
access_token = create_access_token(
|
|
||||||
data={"sub": user["mail"]}, expires_delta=access_token_expires
|
|
||||||
)
|
|
||||||
|
|
||||||
return TokenResponse(
|
return TokenResponse(
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
token_type="bearer",
|
token_type="bearer",
|
||||||
expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
|
expires_in=ACCESS_TOKEN_EXPIRE_MINUTES,
|
||||||
user_info=UserResponse(
|
user_info=UserResponse(
|
||||||
id=user["id"],
|
id=user["id"],
|
||||||
mail=user["mail"],
|
mail=user["mail"],
|
||||||
|
|||||||
@ -14,10 +14,8 @@ from cryptoai.utils.config_loader import ConfigLoader
|
|||||||
from cryptoai.models.base import Base
|
from cryptoai.models.base import Base
|
||||||
from cryptoai.models.token import TokenManager
|
from cryptoai.models.token import TokenManager
|
||||||
from cryptoai.models.analysis_result import AnalysisResultManager
|
from cryptoai.models.analysis_result import AnalysisResultManager
|
||||||
from cryptoai.models.agent_feed import AgentFeedManager
|
|
||||||
from cryptoai.models.user import UserManager
|
from cryptoai.models.user import UserManager
|
||||||
from cryptoai.models.user_question import UserQuestionManager
|
from cryptoai.models.user_question import UserQuestionManager
|
||||||
from cryptoai.models.agent import AgentManager
|
|
||||||
from cryptoai.models.astock import AStockManager
|
from cryptoai.models.astock import AStockManager
|
||||||
from cryptoai.models.analysis_history import AnalysisHistoryManager
|
from cryptoai.models.analysis_history import AnalysisHistoryManager
|
||||||
|
|
||||||
@ -105,10 +103,8 @@ class DBManager:
|
|||||||
# 初始化各个模型的管理器
|
# 初始化各个模型的管理器
|
||||||
self.token_manager = TokenManager(session)
|
self.token_manager = TokenManager(session)
|
||||||
self.analysis_result_manager = AnalysisResultManager(session)
|
self.analysis_result_manager = AnalysisResultManager(session)
|
||||||
self.agent_feed_manager = AgentFeedManager(session)
|
|
||||||
self.user_manager = UserManager(session)
|
self.user_manager = UserManager(session)
|
||||||
self.user_question_manager = UserQuestionManager(session)
|
self.user_question_manager = UserQuestionManager(session)
|
||||||
self.agent_manager = AgentManager(session)
|
|
||||||
self.astock_manager = AStockManager(session)
|
self.astock_manager = AStockManager(session)
|
||||||
self.analysis_history_manager = AnalysisHistoryManager(session)
|
self.analysis_history_manager = AnalysisHistoryManager(session)
|
||||||
|
|
||||||
|
|||||||
@ -29,7 +29,7 @@ services:
|
|||||||
cryptoai-api:
|
cryptoai-api:
|
||||||
build: .
|
build: .
|
||||||
container_name: cryptoai-api
|
container_name: cryptoai-api
|
||||||
image: cryptoai-api:0.1.27
|
image: cryptoai-api:0.1.28
|
||||||
restart: always
|
restart: always
|
||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user