diff --git a/cryptoai/routes/agent.py b/cryptoai/routes/agent.py index 27fdf54..7f840dc 100644 --- a/cryptoai/routes/agent.py +++ b/cryptoai/routes/agent.py @@ -17,11 +17,14 @@ from cryptoai.utils.config_loader import ConfigLoader from fastapi.responses import StreamingResponse from cryptoai.routes.user import get_current_user import requests +from cryptoai.utils.db_manager import get_db_manager + # 创建路由 router = APIRouter() class ChatRequest(BaseModel): user_prompt: str + agent_id: str @router.get("/list") @@ -57,6 +60,10 @@ async def chat(request: ChatRequest,current_user: Dict[str, Any] = Depends(get_c "response_mode" : "streaming", "user" : current_user["mail"] } + + # 保存用户提问 + get_db_manager().save_user_question(current_user["id"], request.agent_id, request.user_prompt) + response = requests.post(url, headers=headers, json=data, stream=True) #获取response 的 stream diff --git a/cryptoai/routes/fastapi_app.py b/cryptoai/routes/fastapi_app.py index a6532ed..6d5a461 100644 --- a/cryptoai/routes/fastapi_app.py +++ b/cryptoai/routes/fastapi_app.py @@ -18,6 +18,7 @@ from typing import Dict, Any from cryptoai.routes.agent import router as agent_router from cryptoai.routes.feed import router as feed_router from cryptoai.routes.user import router as user_router +from cryptoai.routes.question import router as question_router # 配置日志 logging.basicConfig( @@ -50,6 +51,7 @@ app.add_middleware( app.include_router(agent_router, prefix="/agent") app.include_router(feed_router, prefix="/feed", tags=["AI Agent信息流"]) app.include_router(user_router, prefix="/user", tags=["用户管理"]) +app.include_router(question_router, prefix="/question", tags=["用户提问"]) # 请求计时中间件 @app.middleware("http") diff --git a/cryptoai/routes/question.py b/cryptoai/routes/question.py new file mode 100644 index 0000000..1a0be14 --- /dev/null +++ b/cryptoai/routes/question.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +""" +用户提问API路由模块,提供用户提问数据的增删改查功能 +""" + +import logging +from fastapi import APIRouter, HTTPException, status, Depends, Query +from pydantic import BaseModel +from typing import Dict, Any, List, Optional +from datetime import datetime + +from cryptoai.utils.db_manager import get_db_manager +from cryptoai.routes.user import get_current_user + +# 配置日志 +logger = logging.getLogger("question_router") + +# 创建路由 +router = APIRouter() + +# 请求模型 +class QuestionCreate(BaseModel): + """创建提问请求模型""" + agent_id: str + question: str + +# 响应模型 +class QuestionResponse(BaseModel): + """提问响应模型""" + id: int + user_id: int + agent_id: str + question: str + create_time: datetime + +@router.post("/", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED) +async def create_question( + question: QuestionCreate, + current_user: Dict[str, Any] = Depends(get_current_user) +) -> Dict[str, Any]: + """ + 创建新的用户提问记录 + + Args: + question: 提问创建请求 + current_user: 当前用户信息,由依赖项提供 + + Returns: + 创建成功的状态信息 + """ + try: + # 获取数据库管理器 + db_manager = get_db_manager() + + # 保存提问 + success = db_manager.save_user_question( + user_id=current_user["id"], + agent_id=question.agent_id, + question=question.question + ) + + if not success: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="保存提问失败" + ) + + return { + "status": "success", + "message": "提问记录创建成功" + } + + except Exception as e: + logger.error(f"创建提问记录失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"创建提问记录失败: {str(e)}" + ) + +@router.get("/", response_model=List[QuestionResponse]) +async def get_questions( + agent_id: Optional[str] = Query(None, description="AI Agent ID,可选"), + limit: int = Query(20, description="返回的最大记录数,默认20条"), + skip: int = Query(0, description="跳过的记录数,默认0条"), + current_user: Dict[str, Any] = Depends(get_current_user) +) -> List[QuestionResponse]: + """ + 获取用户提问记录列表 + + Args: + agent_id: 可选,指定获取特定Agent的提问 + limit: 返回的最大记录数,默认20条 + skip: 跳过的记录数,默认0条 + current_user: 当前用户信息,由依赖项提供 + + Returns: + 提问记录列表 + """ + try: + # 获取数据库管理器 + db_manager = get_db_manager() + + # 获取提问记录 + questions = db_manager.get_user_questions( + user_id=current_user["id"], + agent_id=agent_id, + limit=limit, + skip=skip + ) + + return questions + + except Exception as e: + logger.error(f"获取提问记录失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"获取提问记录失败: {str(e)}" + ) + +@router.get("/{question_id}", response_model=QuestionResponse) +async def get_question( + question_id: int, + current_user: Dict[str, Any] = Depends(get_current_user) +) -> QuestionResponse: + """ + 获取特定提问记录 + + Args: + question_id: 提问ID + current_user: 当前用户信息,由依赖项提供 + + Returns: + 提问记录 + """ + try: + # 获取数据库管理器 + db_manager = get_db_manager() + + # 获取提问记录 + question = db_manager.get_user_question_by_id(question_id) + + if not question: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"提问记录 {question_id} 不存在" + ) + + # 检查是否是当前用户的提问 + if question["user_id"] != current_user["id"]: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="没有权限查看此提问记录" + ) + + return question + + except HTTPException: + raise + except Exception as e: + logger.error(f"获取提问记录失败: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"获取提问记录失败: {str(e)}" + ) \ No newline at end of file diff --git a/cryptoai/utils/db_manager.py b/cryptoai/utils/db_manager.py index 4ff21c2..7be84bb 100644 --- a/cryptoai/utils/db_manager.py +++ b/cryptoai/utils/db_manager.py @@ -7,9 +7,9 @@ import logging from typing import Dict, Any, List, Optional, Union from datetime import datetime -from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, Index, text +from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, Index, text, ForeignKey from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import sessionmaker, relationship from sqlalchemy.dialects.mysql import JSON from sqlalchemy.pool import QueuePool @@ -76,6 +76,9 @@ class User(Base): level = Column(Integer, nullable=False, default=0, comment='用户级别(0=普通用户,1=VIP,2=SVIP)') create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') + # 关系 + questions = relationship("UserQuestion", back_populates="user") + # 索引和表属性 __table_args__ = ( Index('idx_mail', 'mail'), @@ -84,6 +87,28 @@ class User(Base): {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} ) +# 定义用户提问数据模型 +class UserQuestion(Base): + """用户提问数据表模型""" + __tablename__ = 'user_questions' + + id = Column(Integer, primary_key=True, autoincrement=True) + user_id = Column(Integer, ForeignKey('users.id'), nullable=False, comment='用户ID') + agent_id = Column(String(50), nullable=False, comment='AI Agent ID') + question = Column(Text, nullable=False, comment='提问内容') + create_time = Column(DateTime, nullable=False, default=datetime.now, comment='创建时间') + + # 关系 + user = relationship("User", back_populates="questions") + + # 索引和表属性 + __table_args__ = ( + Index('idx_user_id', 'user_id'), + Index('idx_agent_id', 'agent_id'), + Index('idx_create_time', 'create_time'), + {'mysql_charset': 'utf8mb4', 'mysql_collate': 'utf8mb4_unicode_ci'} + ) + class DBManager: """数据库管理工具,用于连接MySQL数据库并保存智能体分析结果""" @@ -508,6 +533,166 @@ class DBManager: logger.error(f"获取信息流失败: {e}") return [] + def save_user_question(self, user_id: int, agent_id: str, question: str) -> bool: + """ + 保存用户提问数据 + + Args: + user_id: 用户ID + agent_id: AI Agent ID + question: 提问内容 + + Returns: + 保存是否成功 + """ + if not self.engine: + try: + self._init_db() + except Exception as e: + logger.error(f"重新连接数据库失败: {e}") + return False + + try: + # 创建会话 + session = self.Session() + + try: + # 创建新记录 + new_question = UserQuestion( + user_id=user_id, + agent_id=agent_id, + question=question, + create_time=datetime.now() + ) + + # 添加并提交 + session.add(new_question) + session.commit() + + logger.info(f"成功保存用户 {user_id} 对 Agent {agent_id} 的提问") + return True + + except Exception as e: + session.rollback() + logger.error(f"保存用户提问失败: {e}") + return False + + finally: + session.close() + + except Exception as e: + logger.error(f"创建数据库会话失败: {e}") + # 如果是连接错误,尝试重新初始化 + try: + self._init_db() + except: + pass + return False + + def get_user_questions(self, user_id: Optional[int] = None, agent_id: Optional[str] = None, + limit: int = 20, skip: int = 0) -> List[Dict[str, Any]]: + """ + 获取用户提问数据 + + Args: + user_id: 可选,指定获取特定用户的提问 + agent_id: 可选,指定获取特定Agent的提问 + limit: 返回的最大记录数,默认20条 + skip: 跳过的记录数,默认0条 + + Returns: + 提问数据列表,如果查询失败则返回空列表 + """ + if not self.engine: + try: + self._init_db() + except Exception as e: + logger.error(f"重新连接数据库失败: {e}") + return [] + + try: + # 创建会话 + session = self.Session() + + try: + # 构建查询 + query = session.query(UserQuestion) + + # 如果指定了user_id,则筛选 + if user_id: + query = query.filter(UserQuestion.user_id == user_id) + + # 如果指定了agent_id,则筛选 + if agent_id: + query = query.filter(UserQuestion.agent_id == agent_id) + + # 按创建时间降序排序并限制数量 + results = query.order_by(UserQuestion.create_time.desc()).offset(skip).limit(limit).all() + + # 转换为字典列表 + questions = [] + for result in results: + questions.append({ + 'id': result.id, + 'user_id': result.user_id, + 'agent_id': result.agent_id, + 'question': result.question, + 'create_time': result.create_time + }) + + return questions + + finally: + session.close() + + except Exception as e: + logger.error(f"获取用户提问失败: {e}") + return [] + + def get_user_question_by_id(self, question_id: int) -> Optional[Dict[str, Any]]: + """ + 通过ID获取用户提问数据 + + Args: + question_id: 提问ID + + Returns: + 提问数据,如果不存在则返回None + """ + if not self.engine: + try: + self._init_db() + except Exception as e: + logger.error(f"重新连接数据库失败: {e}") + return None + + try: + # 创建会话 + session = self.Session() + + try: + # 查询提问 + result = session.query(UserQuestion).filter(UserQuestion.id == question_id).first() + + if result: + # 转换为字典 + return { + 'id': result.id, + 'user_id': result.user_id, + 'agent_id': result.agent_id, + 'question': result.question, + 'create_time': result.create_time + } + else: + return None + + finally: + session.close() + + except Exception as e: + logger.error(f"获取用户提问失败: {e}") + return None + def get_latest_result(self, agent: str, symbol: str, time_interval: str) -> Optional[Dict[str, Any]]: """ 获取最新的分析结果 diff --git a/cryptoai/utils/update_db_charset.py b/cryptoai/utils/update_db_charset.py index 08d2818..2397686 100644 --- a/cryptoai/utils/update_db_charset.py +++ b/cryptoai/utils/update_db_charset.py @@ -75,6 +75,31 @@ def update_table_charset(): logger.info("成功更新users表字符集为utf8mb4") + # 检查user_questions表是否存在 + result = session.execute(text(""" + SELECT COUNT(*) + FROM information_schema.tables + WHERE table_schema = DATABASE() + AND table_name = 'user_questions'; + """)) + + table_exists = result.scalar() > 0 + + # 如果user_questions表存在,更新其字符集 + if table_exists: + session.execute(text(""" + ALTER TABLE user_questions + CONVERT TO CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + """)) + + # 特别更新question列的字符集 + session.execute(text(""" + ALTER TABLE user_questions + MODIFY question TEXT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; + """)) + + logger.info("成功更新user_questions表字符集为utf8mb4") + session.commit() logger.info("成功更新数据库表字符集为utf8mb4") return True diff --git a/docker-compose.yml b/docker-compose.yml index 948a8cf..c3850bb 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -29,7 +29,7 @@ services: cryptoai-api: build: . container_name: cryptoai-api - image: cryptoai-api:0.0.6 + image: cryptoai-api:0.0.7 restart: always ports: - "8000:8000"