166 lines
4.7 KiB
Python
166 lines
4.7 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
用户提问API路由模块,提供用户提问数据的增删改查功能
|
||
"""
|
||
|
||
import logging
|
||
from fastapi import APIRouter, HTTPException, status, Depends, Query
|
||
from pydantic import BaseModel
|
||
from typing import Dict, Any, List, Optional
|
||
from datetime import datetime
|
||
|
||
from cryptoai.utils.db_manager import get_db_manager
|
||
from cryptoai.routes.user import get_current_user
|
||
|
||
# 配置日志
|
||
logger = logging.getLogger("question_router")
|
||
|
||
# 创建路由
|
||
router = APIRouter()
|
||
|
||
# 请求模型
|
||
class QuestionCreate(BaseModel):
|
||
"""创建提问请求模型"""
|
||
agent_id: str
|
||
question: str
|
||
|
||
# 响应模型
|
||
class QuestionResponse(BaseModel):
|
||
"""提问响应模型"""
|
||
id: int
|
||
user_id: int
|
||
agent_id: str
|
||
question: str
|
||
create_time: datetime
|
||
|
||
@router.post("/", response_model=Dict[str, Any], status_code=status.HTTP_201_CREATED)
|
||
async def create_question(
|
||
question: QuestionCreate,
|
||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
创建新的用户提问记录
|
||
|
||
Args:
|
||
question: 提问创建请求
|
||
current_user: 当前用户信息,由依赖项提供
|
||
|
||
Returns:
|
||
创建成功的状态信息
|
||
"""
|
||
try:
|
||
# 获取数据库管理器
|
||
db_manager = get_db_manager()
|
||
|
||
# 保存提问
|
||
success = db_manager.save_user_question(
|
||
user_id=current_user["id"],
|
||
agent_id=question.agent_id,
|
||
question=question.question
|
||
)
|
||
|
||
if not success:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail="保存提问失败"
|
||
)
|
||
|
||
return {
|
||
"status": "success",
|
||
"message": "提问记录创建成功"
|
||
}
|
||
|
||
except Exception as e:
|
||
logger.error(f"创建提问记录失败: {str(e)}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail=f"创建提问记录失败: {str(e)}"
|
||
)
|
||
|
||
@router.get("/", response_model=List[QuestionResponse])
|
||
async def get_questions(
|
||
agent_id: Optional[str] = Query(None, description="AI Agent ID,可选"),
|
||
limit: int = Query(20, description="返回的最大记录数,默认20条"),
|
||
skip: int = Query(0, description="跳过的记录数,默认0条"),
|
||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||
) -> List[QuestionResponse]:
|
||
"""
|
||
获取用户提问记录列表
|
||
|
||
Args:
|
||
agent_id: 可选,指定获取特定Agent的提问
|
||
limit: 返回的最大记录数,默认20条
|
||
skip: 跳过的记录数,默认0条
|
||
current_user: 当前用户信息,由依赖项提供
|
||
|
||
Returns:
|
||
提问记录列表
|
||
"""
|
||
try:
|
||
# 获取数据库管理器
|
||
db_manager = get_db_manager()
|
||
|
||
# 获取提问记录
|
||
questions = db_manager.get_user_questions(
|
||
user_id=current_user["id"],
|
||
agent_id=agent_id,
|
||
limit=limit,
|
||
skip=skip
|
||
)
|
||
|
||
return questions
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取提问记录失败: {str(e)}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail=f"获取提问记录失败: {str(e)}"
|
||
)
|
||
|
||
@router.get("/{question_id}", response_model=QuestionResponse)
|
||
async def get_question(
|
||
question_id: int,
|
||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||
) -> QuestionResponse:
|
||
"""
|
||
获取特定提问记录
|
||
|
||
Args:
|
||
question_id: 提问ID
|
||
current_user: 当前用户信息,由依赖项提供
|
||
|
||
Returns:
|
||
提问记录
|
||
"""
|
||
try:
|
||
# 获取数据库管理器
|
||
db_manager = get_db_manager()
|
||
|
||
# 获取提问记录
|
||
question = db_manager.get_user_question_by_id(question_id)
|
||
|
||
if not question:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_404_NOT_FOUND,
|
||
detail=f"提问记录 {question_id} 不存在"
|
||
)
|
||
|
||
# 检查是否是当前用户的提问
|
||
if question["user_id"] != current_user["id"]:
|
||
raise HTTPException(
|
||
status_code=status.HTTP_403_FORBIDDEN,
|
||
detail="没有权限查看此提问记录"
|
||
)
|
||
|
||
return question
|
||
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
logger.error(f"获取提问记录失败: {str(e)}")
|
||
raise HTTPException(
|
||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||
detail=f"获取提问记录失败: {str(e)}"
|
||
) |