"""对话 API (Chat API).

POST /api/chat/stream - SSE 流式对话 (streaming chat via Server-Sent Events)
"""

import json
import logging
import traceback

from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from app.core.deps import get_current_user
from app.db.error_logger import log_error
from app.llm.chat_agent import chat_stream

# Logger named after this module so records can be filtered per module.
logger = logging.getLogger(name=__name__)

# Every route declared below is mounted under /api/chat; FastAPI uses the
# "chat" tag to group these operations in the generated OpenAPI schema.
router = APIRouter(tags=["chat"], prefix="/api/chat")


class ChatMessage(BaseModel):
    """One turn of the conversation supplied by the client."""

    # Who produced this turn. Any string is accepted here; presumably values
    # such as "user" / "assistant" are expected downstream — not validated.
    role: str
    # The turn's text.
    content: str


class ChatRequest(BaseModel):
    """Request body accepted by ``POST /api/chat/stream``."""

    # Ordered conversation history to hand to the chat agent.
    messages: list[ChatMessage]


@router.post("/stream")
async def chat_stream_endpoint(req: ChatRequest, current_user: dict = Depends(get_current_user)):
    """Stream a chat reply to the client as Server-Sent Events (SSE).

    Every item produced by ``chat_stream`` is JSON-encoded (non-ASCII
    characters kept as-is) and sent as a single ``data: ...`` event. The
    stream always finishes with a ``data: [DONE]`` sentinel, including when
    an error occurs mid-stream.

    Args:
        req: Conversation history to forward to the chat agent.
        current_user: Authenticated user resolved by ``get_current_user``;
            passed through to ``chat_stream``.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.
    """
    messages = [{"role": m.role, "content": m.content} for m in req.messages]

    async def event_generator():
        try:
            async for msg in chat_stream(messages, current_user=current_user):
                data = json.dumps(msg, ensure_ascii=False)
                yield f"data: {data}\n\n"
            yield "data: [DONE]\n\n"
        except Exception as e:
            # logger.exception also writes the traceback to the application
            # log; %-style args defer formatting to the logging framework.
            logger.exception("Chat stream error: %s", e)
            try:
                await log_error(
                    "chat",
                    f"Chat stream error: {e}",
                    detail=traceback.format_exc(),
                    context={"method": "POST", "path": "/api/chat/stream"},
                )
            except Exception:
                # Recording the error is best-effort. If log_error itself
                # raises, the generator must not abort here, otherwise the
                # client never receives the error event or the [DONE]
                # sentinel below and is left with an unterminated stream.
                logger.exception("Failed to record chat stream error")
            # NOTE(review): the raw exception text is shown to the client;
            # consider a generic message if internal details are sensitive.
            error_data = json.dumps(
                {"type": "content", "content": f"出错了: {e}"},
                ensure_ascii=False,
            )
            yield f"data: {error_data}\n\n"
            yield "data: [DONE]\n\n"

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            # Ask clients and intermediaries not to cache the event stream.
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Disables response buffering in nginx (when it fronts the app)
            # so each event is flushed to the client immediately.
            "X-Accel-Buffering": "no",
        },
    )