crypto.ai/cryptoai/routes/question.py
2025-05-24 12:08:57 +08:00

166 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
用户提问API路由模块提供用户提问数据的增删改查功能
"""
import logging
from fastapi import APIRouter, HTTPException, status, Depends, Query
from pydantic import BaseModel
from typing import Dict, Any, List, Optional
from datetime import datetime
from cryptoai.utils.db_manager import get_db_manager
from cryptoai.routes.user import get_current_user
# 配置日志
logger = logging.getLogger("question_router")
# 创建路由
router = APIRouter()
# 请求模型
class QuestionCreate(BaseModel):
"""创建提问请求模型"""
agent_id: str
question: str
# 响应模型
class QuestionResponse(BaseModel):
"""提问响应模型"""
id: int
user_id: int
agent_id: str
question: str
create_time: datetime
@router.post("/", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
async def create_question(
question: QuestionCreate,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
"""
创建新的用户提问记录
Args:
question: 提问创建请求
current_user: 当前用户信息,由依赖项提供
Returns:
创建成功的状态信息
"""
try:
# 获取数据库管理器
db_manager = get_db_manager()
# 保存提问
success = db_manager.user_question_manager.save_user_question(
user_id=current_user["id"],
agent_id=question.agent_id,
question=question.question
)
if not success:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="保存提问失败"
)
return {
"status": "success",
"message": "提问记录创建成功"
}
except Exception as e:
logger.error(f"创建提问记录失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"创建提问记录失败: {str(e)}"
)
@router.get("/", response_model=List[QuestionResponse])
async def get_questions(
agent_id: Optional[str] = Query(None, description="AI Agent ID可选"),
limit: int = Query(20, description="返回的最大记录数默认20条"),
skip: int = Query(0, description="跳过的记录数默认0条"),
current_user: Dict[str, Any] = Depends(get_current_user)
) -> List[QuestionResponse]:
"""
获取用户提问记录列表
Args:
agent_id: 可选指定获取特定Agent的提问
limit: 返回的最大记录数默认20条
skip: 跳过的记录数默认0条
current_user: 当前用户信息,由依赖项提供
Returns:
提问记录列表
"""
try:
# 获取数据库管理器
db_manager = get_db_manager()
# 获取提问记录
questions = db_manager.user_question_manager.get_user_questions(
user_id=current_user["id"],
agent_id=agent_id,
limit=limit,
skip=skip
)
return questions
except Exception as e:
logger.error(f"获取提问记录失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取提问记录失败: {str(e)}"
)
@router.get("/{question_id}", response_model=QuestionResponse)
async def get_question(
question_id: int,
current_user: Dict[str, Any] = Depends(get_current_user)
) -> QuestionResponse:
"""
获取特定提问记录
Args:
question_id: 提问ID
current_user: 当前用户信息,由依赖项提供
Returns:
提问记录
"""
try:
# 获取数据库管理器
db_manager = get_db_manager()
# 获取提问记录
question = db_manager.user_question_manager.get_user_question_by_id(question_id)
if not question:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"提问记录 {question_id} 不存在"
)
# 检查是否是当前用户的提问
if question["user_id"] != current_user["id"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="没有权限查看此提问记录"
)
return question
except HTTPException:
raise
except Exception as e:
logger.error(f"获取提问记录失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取提问记录失败: {str(e)}"
)