update
This commit is contained in:
parent
ef415d4ec9
commit
03cb7fe83a
@ -17,11 +17,14 @@ from cryptoai.utils.config_loader import ConfigLoader
|
||||
from fastapi.responses import StreamingResponse
|
||||
from cryptoai.routes.user import get_current_user
|
||||
import requests
|
||||
from cryptoai.utils.db_manager import get_db_manager
|
||||
|
||||
# 创建路由
|
||||
router = APIRouter()
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
user_prompt: str
|
||||
agent_id: str
|
||||
|
||||
|
||||
@router.get("/list")
|
||||
@ -57,6 +60,10 @@ async def chat(request: ChatRequest,current_user: Dict[str, Any] = Depends(get_c
|
||||
"response_mode" : "streaming",
|
||||
"user" : current_user["mail"]
|
||||
}
|
||||
|
||||
# 保存用户提问
|
||||
get_db_manager().save_user_question(current_user["id"], request.agent_id, request.user_prompt)
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||
|
||||
#获取response 的 stream
|
||||
|
||||
@ -18,6 +18,7 @@ from typing import Dict, Any
|
||||
from cryptoai.routes.agent import router as agent_router
|
||||
from cryptoai.routes.feed import router as feed_router
|
||||
from cryptoai.routes.user import router as user_router
|
||||
from cryptoai.routes.question import router as question_router
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
@ -50,6 +51,7 @@ app.add_middleware(
|
||||
app.include_router(agent_router, prefix="/agent")
|
||||
app.include_router(feed_router, prefix="/feed", tags=["AI Agent信息流"])
|
||||
app.include_router(user_router, prefix="/user", tags=["用户管理"])
|
||||
app.include_router(question_router, prefix="/question", tags=["用户提问"])
|
||||
|
||||
# 请求计时中间件
|
||||
@app.middleware("http")
|
||||
|
||||
166
cryptoai/routes/question.py
Normal file
166
cryptoai/routes/question.py
Normal file
@ -0,0 +1,166 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
用户提问API路由模块,提供用户提问数据的增删改查功能
|
||||
"""
|
||||
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException, status, Depends, Query
|
||||
from pydantic import BaseModel
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from cryptoai.utils.db_manager import get_db_manager
|
||||
from cryptoai.routes.user import get_current_user
|
||||
|
||||
# 配置日志
|
||||
logger = logging.getLogger("question_router")
|
||||
|
||||
# 创建路由
|
||||
router = APIRouter()
|
||||
|
||||
# 请求模型
|
||||
class QuestionCreate(BaseModel):
|
||||
"""创建提问请求模型"""
|
||||
agent_id: str
|
||||
question: str
|
||||
|
||||
# 响应模型
|
||||
class QuestionResponse(BaseModel):
|
||||
"""提问响应模型"""
|
||||
id: int
|
||||
user_id: int
|
||||
agent_id: str
|
||||
question: str
|
||||
create_time: datetime
|
||||
|
||||
@router.post("/", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
|
||||
async def create_question(
|
||||
question: QuestionCreate,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
创建新的用户提问记录
|
||||
|
||||
Args:
|
||||
question: 提问创建请求
|
||||
current_user: 当前用户信息,由依赖项提供
|
||||
|
||||
Returns:
|
||||
创建成功的状态信息
|
||||
"""
|
||||
try:
|
||||
# 获取数据库管理器
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 保存提问
|
||||
success = db_manager.save_user_question(
|
||||
user_id=current_user["id"],
|
||||
agent_id=question.agent_id,
|
||||
question=question.question
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="保存提问失败"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "提问记录创建成功"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建提问记录失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"创建提问记录失败: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/", response_model=List[QuestionResponse])
|
||||
async def get_questions(
|
||||
agent_id: Optional[str] = Query(None, description="AI Agent ID,可选"),
|
||||
limit: int = Query(20, description="返回的最大记录数,默认20条"),
|
||||
skip: int = Query(0, description="跳过的记录数,默认0条"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> List[QuestionResponse]:
|
||||
"""
|
||||
获取用户提问记录列表
|
||||
|
||||
Args:
|
||||
agent_id: 可选,指定获取特定Agent的提问
|
||||
limit: 返回的最大记录数,默认20条
|
||||
skip: 跳过的记录数,默认0条
|
||||
current_user: 当前用户信息,由依赖项提供
|
||||
|
||||
Returns:
|
||||
提问记录列表
|
||||
"""
|
||||
try:
|
||||
# 获取数据库管理器
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 获取提问记录
|
||||
questions = db_manager.get_user_questions(
|
||||
user_id=current_user["id"],
|
||||
agent_id=agent_id,
|
||||
limit=limit,
|
||||
skip=skip
|
||||
)
|
||||
|
||||
return questions
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取提问记录失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取提问记录失败: {str(e)}"
|
||||
)
|
||||
|
||||
@router.get("/{question_id}", response_model=QuestionResponse)
|
||||
async def get_question(
|
||||
question_id: int,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> QuestionResponse:
|
||||
"""
|
||||
获取特定提问记录
|
||||
|
||||
Args:
|
||||
question_id: 提问ID
|
||||
current_user: 当前用户信息,由依赖项提供
|
||||
|
||||
Returns:
|
||||
提问记录
|
||||
"""
|
||||
try:
|
||||
# 获取数据库管理器
|
||||
db_manager = get_db_manager()
|
||||
|
||||
# 获取提问记录
|
||||
question = db_manager.get_user_question_by_id(question_id)
|
||||
|
||||
if not question:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"提问记录 {question_id} 不存在"
|
||||
)
|
||||
|
||||
# 检查是否是当前用户的提问
|
||||
if question["user_id"] != current_user["id"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="没有权限查看此提问记录"
|
||||
)
|
||||
|
||||
return question
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取提问记录失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"获取提问记录失败: {str(e)}"
|
||||
)
|
||||
@ -7,9 +7,9 @@ import logging
|
||||
from typing import Dict, Any, List, Optional, Union
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, Index, text
|
||||
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, Index, text, ForeignKey
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker, relationship
|
||||
from sqlalchemy.dialects.mysql import JSON
|
||||
from sqlalchemy.pool import QueuePool
|
||||
|
||||
@ -76,6 +76,9 @@ class User(Base):
|
||||
level = Column(Integer, nullable=False, default=0, comment='用户级别(0=普通用户,1=VIP,2=SVIP)')
|
||||
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||||
|
||||
# 关系
|
||||
questions = relationship("UserQuestion", back_populates="user")
|
||||
|
||||
# 索引和表属性
|
||||
__table_args__ = (
|
||||
Index('idx_mail', 'mail'),
|
||||
@ -84,6 +87,28 @@ class User(Base):
|
||||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||||
)
|
||||
|
||||
# 定义用户提问数据模型
|
||||
class UserQuestion(Base):
|
||||
"""用户提问数据表模型"""
|
||||
__tablename__ = 'user_questions'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id = Column(Integer, ForeignKey('users.id'), nullable=False, comment='用户ID')
|
||||
agent_id = Column(String(50), nullable=False, comment='AI Agent ID')
|
||||
question = Column(Text, nullable=False, comment='提问内容')
|
||||
create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间')
|
||||
|
||||
# 关系
|
||||
user = relationship("User", back_populates="questions")
|
||||
|
||||
# 索引和表属性
|
||||
__table_args__ = (
|
||||
Index('idx_user_id', 'user_id'),
|
||||
Index('idx_agent_id', 'agent_id'),
|
||||
Index('idx_create_time', 'create_time'),
|
||||
{'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'}
|
||||
)
|
||||
|
||||
class DBManager:
|
||||
"""数据库管理工具,用于连接MySQL数据库并保存智能体分析结果"""
|
||||
|
||||
@ -508,6 +533,166 @@ class DBManager:
|
||||
logger.error(f"获取信息流失败: {e}")
|
||||
return []
|
||||
|
||||
def save_user_question(self, user_id: int, agent_id: str, question: str) -> bool:
|
||||
"""
|
||||
保存用户提问数据
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
agent_id: AI Agent ID
|
||||
question: 提问内容
|
||||
|
||||
Returns:
|
||||
保存是否成功
|
||||
"""
|
||||
if not self.engine:
|
||||
try:
|
||||
self._init_db()
|
||||
except Exception as e:
|
||||
logger.error(f"重新连接数据库失败: {e}")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 创建会话
|
||||
session = self.Session()
|
||||
|
||||
try:
|
||||
# 创建新记录
|
||||
new_question = UserQuestion(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
question=question,
|
||||
create_time=datetime.now()
|
||||
)
|
||||
|
||||
# 添加并提交
|
||||
session.add(new_question)
|
||||
session.commit()
|
||||
|
||||
logger.info(f"成功保存用户 {user_id} 对 Agent {agent_id} 的提问")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.error(f"保存用户提问失败: {e}")
|
||||
return False
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建数据库会话失败: {e}")
|
||||
# 如果是连接错误,尝试重新初始化
|
||||
try:
|
||||
self._init_db()
|
||||
except:
|
||||
pass
|
||||
return False
|
||||
|
||||
def get_user_questions(self, user_id: Optional[int] = None, agent_id: Optional[str] = None,
|
||||
limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取用户提问数据
|
||||
|
||||
Args:
|
||||
user_id: 可选,指定获取特定用户的提问
|
||||
agent_id: 可选,指定获取特定Agent的提问
|
||||
limit: 返回的最大记录数,默认20条
|
||||
skip: 跳过的记录数,默认0条
|
||||
|
||||
Returns:
|
||||
提问数据列表,如果查询失败则返回空列表
|
||||
"""
|
||||
if not self.engine:
|
||||
try:
|
||||
self._init_db()
|
||||
except Exception as e:
|
||||
logger.error(f"重新连接数据库失败: {e}")
|
||||
return []
|
||||
|
||||
try:
|
||||
# 创建会话
|
||||
session = self.Session()
|
||||
|
||||
try:
|
||||
# 构建查询
|
||||
query = session.query(UserQuestion)
|
||||
|
||||
# 如果指定了user_id,则筛选
|
||||
if user_id:
|
||||
query = query.filter(UserQuestion.user_id == user_id)
|
||||
|
||||
# 如果指定了agent_id,则筛选
|
||||
if agent_id:
|
||||
query = query.filter(UserQuestion.agent_id == agent_id)
|
||||
|
||||
# 按创建时间降序排序并限制数量
|
||||
results = query.order_by(UserQuestion.create_time.desc()).offset(skip).limit(limit).all()
|
||||
|
||||
# 转换为字典列表
|
||||
questions = []
|
||||
for result in results:
|
||||
questions.append({
|
||||
'id': result.id,
|
||||
'user_id': result.user_id,
|
||||
'agent_id': result.agent_id,
|
||||
'question': result.question,
|
||||
'create_time': result.create_time
|
||||
})
|
||||
|
||||
return questions
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户提问失败: {e}")
|
||||
return []
|
||||
|
||||
def get_user_question_by_id(self, question_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
通过ID获取用户提问数据
|
||||
|
||||
Args:
|
||||
question_id: 提问ID
|
||||
|
||||
Returns:
|
||||
提问数据,如果不存在则返回None
|
||||
"""
|
||||
if not self.engine:
|
||||
try:
|
||||
self._init_db()
|
||||
except Exception as e:
|
||||
logger.error(f"重新连接数据库失败: {e}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# 创建会话
|
||||
session = self.Session()
|
||||
|
||||
try:
|
||||
# 查询提问
|
||||
result = session.query(UserQuestion).filter(UserQuestion.id == question_id).first()
|
||||
|
||||
if result:
|
||||
# 转换为字典
|
||||
return {
|
||||
'id': result.id,
|
||||
'user_id': result.user_id,
|
||||
'agent_id': result.agent_id,
|
||||
'question': result.question,
|
||||
'create_time': result.create_time
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户提问失败: {e}")
|
||||
return None
|
||||
|
||||
def get_latest_result(self, agent: str, symbol: str, time_interval: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取最新的分析结果
|
||||
|
||||
@ -75,6 +75,31 @@ def update_table_charset():
|
||||
|
||||
logger.info("成功更新users表字符集为utf8mb4")
|
||||
|
||||
# 检查user_questions表是否存在
|
||||
result = session.execute(text("""
|
||||
SELECT COUNT(*)
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = 'user_questions';
|
||||
"""))
|
||||
|
||||
table_exists = result.scalar() > 0
|
||||
|
||||
# 如果user_questions表存在,更新其字符集
|
||||
if table_exists:
|
||||
session.execute(text("""
|
||||
ALTER TABLE user_questions
|
||||
CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
||||
"""))
|
||||
|
||||
# 特别更新question列的字符集
|
||||
session.execute(text("""
|
||||
ALTER TABLE user_questions
|
||||
MODIFY question TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci;
|
||||
"""))
|
||||
|
||||
logger.info("成功更新user_questions表字符集为utf8mb4")
|
||||
|
||||
session.commit()
|
||||
logger.info("成功更新数据库表字符集为utf8mb4")
|
||||
return True
|
||||
|
||||
@ -29,7 +29,7 @@ services:
|
||||
cryptoai-api:
|
||||
build: .
|
||||
container_name: cryptoai-api
|
||||
image: cryptoai-api:0.0.6
|
||||
image: cryptoai-api:0.0.7
|
||||
restart: always
|
||||
ports:
|
||||
- "8000:8000"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user