134 lines
3.8 KiB
Python
134 lines
3.8 KiB
Python
"""
|
||
对话API路由
|
||
"""
|
||
from fastapi import APIRouter, HTTPException, Depends
|
||
from fastapi.responses import StreamingResponse
|
||
from typing import Optional
|
||
import uuid
|
||
import json
|
||
import asyncio
|
||
from app.models.chat import ChatRequest, ChatResponse
|
||
from app.models.database import User
|
||
from app.agent.smart_agent import smart_agent # 使用智能Agent
|
||
from app.middleware.auth_middleware import get_current_user
|
||
from app.utils.logger import logger
|
||
|
||
# Router holding the chat endpoints below; it is mounted onto the application
# elsewhere (the URL prefix is not visible in this file).
router: APIRouter = APIRouter()
|
||
|
||
|
||
@router.post("/message", response_model=ChatResponse)
async def send_message(request: ChatRequest):
    """
    Send a message to the agent and return its complete (non-streamed) reply.

    Args:
        request: Chat request. ``session_id`` is optional (a new UUID4 is
            generated when it is missing or empty); ``user_id`` is taken
            from the request body exactly as the client supplied it.

    Returns:
        ChatResponse with the agent's ``message``, the session id that was
        used, and the agent's ``metadata`` (``None`` when absent).

    Raises:
        HTTPException: re-raised unchanged if one is raised while processing;
            any other failure is logged with its traceback and reported as
            status 500 with the error text as ``detail``.
    """
    try:
        # Reuse the caller's session, or start a new one.
        session_id = request.session_id or str(uuid.uuid4())

        # NOTE(review): unlike /message/stream, this endpoint has no
        # get_current_user dependency and forwards the client-supplied
        # request.user_id to the agent. Confirm this is intentional;
        # otherwise any caller can act on behalf of any user.
        response = await smart_agent.process_message(
            message=request.message,
            session_id=session_id,
            user_id=request.user_id,
        )

        return ChatResponse(
            message=response["message"],
            session_id=session_id,
            metadata=response.get("metadata"),
        )

    except HTTPException:
        # Already carries a deliberate status code; don't mask it as a 500.
        raise
    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception(f"处理消息失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||
|
||
|
||
@router.get("/history/{session_id}")
async def get_history(session_id: str, limit: int = 50):
    """
    Return the stored conversation history for a session.

    Args:
        session_id: Session ID whose context is looked up.
        limit: Maximum number of messages to return. When the stored context
            is a list or tuple, only its last ``limit`` entries are returned;
            ``limit <= 0`` returns an empty sequence. Contexts of any other
            type are returned unchanged.

    Returns:
        ``{"session_id": session_id, "messages": <context>}``.

    Raises:
        HTTPException: status 500 with the error text as ``detail`` if the
            lookup fails (the traceback is logged).
    """
    # NOTE(review): no authentication dependency on this route, so anyone
    # who knows a session_id can read its history. Confirm this is intended.
    try:
        context = smart_agent.context_manager.get_context(session_id)

        # Previously ``limit`` was accepted but ignored, so the whole history
        # was always returned. Assumes the context is ordered oldest-first
        # (so the tail is the most recent messages); TODO confirm against
        # context_manager.get_context.
        if isinstance(context, (list, tuple)):
            context = context[-limit:] if limit > 0 else context[:0]

        return {
            "session_id": session_id,
            "messages": context,
        }

    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception(f"获取历史失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||
|
||
|
||
def _sse_event(payload: dict) -> str:
    """Encode *payload* as a single Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


@router.post("/message/stream")
async def send_message_stream(
    request: ChatRequest,
    current_user: User = Depends(get_current_user)
):
    """
    Stream the agent's reply to a message as Server-Sent Events.

    Each event is a ``data: <json>`` frame. The stream sends, in order:
    one ``{"type": "session_id", ...}`` event; one
    ``{"type": "content", "content": chunk}`` event per chunk yielded by the
    agent; then ``{"type": "done"}``. If the agent fails mid-stream, a
    ``{"type": "error", "error": ...}`` event is sent instead of ``done``.

    Args:
        request: Chat request. ``session_id`` is optional (a new UUID4 is
            generated when it is missing or empty).
        current_user: Authenticated user; ``str(current_user.id)`` is passed
            to the agent as the user id.

    Returns:
        StreamingResponse with media type ``text/event-stream``.

    Raises:
        HTTPException: status 500 if the response cannot be set up. Errors
            raised after streaming has started are reported as SSE ``error``
            events, because the HTTP status has already been sent by then.
    """
    try:
        # Reuse the caller's session, or start a new one.
        session_id = request.session_id or str(uuid.uuid4())

        # Resolve the user id now, while the request is still being handled.
        # The generator below only runs after this function returns, when
        # StreamingResponse iterates it. If User is an ORM instance bound to
        # a session that get_current_user closes (presumably — TODO confirm),
        # a lazy/expired attribute load at that point could fail. Doing it
        # here also lets the outer handler report such a failure as a 500.
        user_id = str(current_user.id)

        async def event_generator():
            """Yield the SSE frames described in the enclosing docstring."""
            try:
                # Tell the client which session this stream belongs to.
                yield _sse_event({'type': 'session_id', 'session_id': session_id})

                # Small pause so the first frame is flushed before agent work starts.
                await asyncio.sleep(0.01)

                async for chunk in smart_agent.process_message_stream(
                    message=request.message,
                    session_id=session_id,
                    user_id=user_id,
                ):
                    yield _sse_event({'type': 'content', 'content': chunk})
                    # Yield control briefly so each chunk can be written out.
                    await asyncio.sleep(0.001)

                # Completion signal.
                yield _sse_event({'type': 'done'})

            except Exception as e:
                # The status line has already been sent, so report in-band.
                logger.exception(f"流式处理消息失败: {e}")
                yield _sse_event({'type': 'error', 'error': str(e)})

        return StreamingResponse(
            event_generator(),
            media_type="text/event-stream",
            # NOTE(review): "Connection" and "Transfer-Encoding" are hop-by-hop
            # headers, normally managed by the ASGI server and not allowed in
            # HTTP/2. Kept for backward compatibility; consider dropping them.
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                # Disables nginx response buffering so frames arrive promptly.
                "X-Accel-Buffering": "no",
                "Transfer-Encoding": "chunked"
            }
        )

    except Exception as e:
        # logger.exception records the traceback, which logger.error dropped.
        logger.exception(f"创建流式响应失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|