astock-agent/backend/app/api/chat.py
2026-04-23 17:24:55 +08:00

66 lines
1.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""对话 API
POST /api/chat/stream - SSE 流式对话
"""
import json
import logging
import traceback
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app.core.deps import get_current_user
from app.db.error_logger import log_error
from app.llm.chat_agent import chat_stream
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
# Every route declared below is mounted under /api/chat and grouped under the
# "chat" tag in the generated OpenAPI docs.
router = APIRouter(tags=["chat"], prefix="/api/chat")
class ChatMessage(BaseModel):
    """One chat message in the request body of POST /api/chat/stream."""

    # Speaker role; forwarded unchanged to chat_stream as the dict "role" key.
    # NOTE(review): this is an unconstrained str taken from the client request,
    # so any role (e.g. "system") can be submitted. Consider restricting it to
    # the roles chat_stream expects — confirm against the frontend first.
    role: str
    # Message text; forwarded unchanged as the dict "content" key.
    content: str
class ChatRequest(BaseModel):
    """Request body for POST /api/chat/stream."""

    # Messages in the order sent by the client; each one is converted to a
    # {"role": ..., "content": ...} dict before being passed to chat_stream.
    messages: list[ChatMessage]
def _sse_frame(payload: str) -> str:
    """Wrap *payload* as a single Server-Sent Events ``data:`` frame."""
    return f"data: {payload}\n\n"


@router.post("/stream")
async def chat_stream_endpoint(req: ChatRequest, current_user: dict = Depends(get_current_user)):
    """Stream a chat reply to the client as Server-Sent Events (SSE).

    The request messages are converted to plain ``{"role", "content"}`` dicts
    and passed to ``chat_stream`` together with the authenticated user. Each
    item yielded by ``chat_stream`` is JSON-encoded (non-ASCII kept as-is) and
    sent as one ``data:`` frame. The stream always ends with ``data: [DONE]``.

    If ``chat_stream`` or the JSON encoding of one of its items raises, the
    error is logged with its traceback and recorded via ``log_error``. The
    client then receives a ``{"type": "content", ...}`` frame carrying the
    error text, followed by ``[DONE]``.
    """
    messages = [{"role": m.role, "content": m.content} for m in req.messages]

    async def event_generator():
        try:
            async for msg in chat_stream(messages, current_user=current_user):
                yield _sse_frame(json.dumps(msg, ensure_ascii=False))
        except Exception as e:
            # logger.exception also writes the traceback to the application log,
            # not only to the persisted error record below.
            logger.exception("Chat stream error: %s", e)
            # Persisting the error is best-effort: if log_error itself fails, the
            # client must still receive the error frame and the [DONE] terminator
            # rather than a silently truncated stream.
            try:
                await log_error(
                    "chat",
                    f"Chat stream error: {e}",
                    detail=traceback.format_exc(),
                    context={"method": "POST", "path": "/api/chat/stream"},
                )
            except Exception:
                logger.exception("Failed to persist chat stream error")
            # NOTE(review): the raw exception text is sent to the client and may
            # expose internal details (upstream URLs, DB errors). Consider a
            # generic message here.
            error_data = json.dumps(
                {"type": "content", "content": f"出错了: {e}"},
                ensure_ascii=False,
            )
            yield _sse_frame(error_data)
        # Reached after normal completion or after the error frame. Cancellation
        # and generator close (BaseException subclasses) propagate without it.
        yield _sse_frame("[DONE]")

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Tells nginx-style reverse proxies not to buffer the stream.
            "X-Accel-Buffering": "no",
        },
    )