"""对话 Agent 带 function calling 循环的对话 Agent,支持流式和非流式模式。 工具调用阶段会 yield 状态消息,保持 SSE 连接活跃。 """ import json import logging from typing import AsyncGenerator from app.llm.client import chat_completion, stream_chat_completion, get_client from app.llm.prompts import CHAT_SYSTEM_PROMPT from app.llm.tools import CHAT_TOOLS from app.llm.tool_executor import execute_tool from app.config import settings logger = logging.getLogger(__name__) MAX_TOOL_ROUNDS = 5 # 工具名称映射(用于状态提示) TOOL_LABELS = { "get_market_temperature": "查询市场温度", "get_hot_sectors": "查询热门板块", "get_latest_recommendations": "查询推荐列表", "get_stock_kline": "查询K线数据", "get_stock_capital_flow": "查询资金流向", "search_stock": "搜索股票", } async def chat_stream(messages: list[dict]) -> AsyncGenerator[dict, None]: """流式对话,支持 tool use 循环 Yields: dict: {"type": "status", "content": "..."} 或 {"type": "content", "content": "..."} """ client = get_client() if not client: yield {"type": "content", "content": "LLM 未配置,请在 .env 中设置 ASTOCK_DEEPSEEK_API_KEY"} return # 构建完整消息列表 full_messages = [{"role": "system", "content": CHAT_SYSTEM_PROMPT}] full_messages.extend(messages) # Tool use 循环(非流式,直到没有 tool_calls) for round_num in range(MAX_TOOL_ROUNDS): if round_num == 0: yield {"type": "status", "content": "思考中..."} resp = await chat_completion(full_messages, tools=CHAT_TOOLS) if not resp: yield {"type": "content", "content": "AI 服务暂时不可用,请稍后重试"} return # 检查是否有 tool_calls if not resp.tool_calls: break # 将 assistant 消息(含 tool_calls)加入历史 full_messages.append({ "role": "assistant", "content": resp.content or "", "tool_calls": [ { "id": tc.id, "type": "function", "function": { "name": tc.function.name, "arguments": tc.function.arguments, }, } for tc in resp.tool_calls ], }) # 执行每个工具调用 for tc in resp.tool_calls: try: args = json.loads(tc.function.arguments) except json.JSONDecodeError: args = {} tool_label = TOOL_LABELS.get(tc.function.name, tc.function.name) yield {"type": "status", "content": f"正在{tool_label}..."} logger.info(f"Chat Agent 调用工具: {tc.function.name}({args})") result = await execute_tool(tc.function.name, args) full_messages.append({ "role": "tool", "tool_call_id": tc.id, "content": result, }) yield {"type": "status", "content": "分析数据中..."} else: # 超过最大轮次,用最后的消息生成回复 pass # 最终回复:流式输出 yield {"type": "status", "content": ""} # 清除状态 async for delta in stream_chat_completion(full_messages): if delta.content: yield {"type": "content", "content": delta.content}