"""对话 Agent 带 function calling 循环的对话 Agent,支持流式和非流式模式。 工具调用阶段会 yield 状态消息,保持 SSE 连接活跃。 """ import json import logging from typing import AsyncGenerator from app.llm.client import chat_completion, stream_chat_completion, get_client from app.llm.prompts import CHAT_SYSTEM_PROMPT from app.llm.tools import CHAT_TOOLS from app.llm.tool_executor import execute_tool, set_chat_user_context from app.config import settings logger = logging.getLogger(__name__) MAX_TOOL_ROUNDS = 5 # 工具名称映射(用于状态提示) TOOL_LABELS = { "get_strategy_board": "读取今日作战结论", "get_market_temperature": "查询市场温度", "get_hot_sectors": "查询热门板块", "get_latest_recommendations": "查询推荐列表", "get_user_watchlist_snapshot": "读取自选股作战池", "get_stock_kline": "查询K线数据", "get_stock_capital_flow": "查询资金流向", "diagnose_stock": "生成个股会诊", "search_stock": "搜索股票", } async def chat_stream(messages: list[dict], current_user: dict | None = None) -> AsyncGenerator[dict, None]: """流式对话,支持 tool use 循环 Yields: dict: {"type": "status", "content": "..."} 或 {"type": "content", "content": "..."} """ client = get_client() if not client: yield {"type": "content", "content": "LLM 未配置,请在 .env 中设置 ASTOCK_DEEPSEEK_API_KEY"} return set_chat_user_context(current_user) # 构建完整消息列表 full_messages = [{"role": "system", "content": CHAT_SYSTEM_PROMPT}] full_messages.extend(messages) try: # Tool use 循环(非流式,直到没有 tool_calls) for round_num in range(MAX_TOOL_ROUNDS): if round_num == 0: yield {"type": "status", "content": "整理今日作战上下文..."} resp = await chat_completion(full_messages, tools=CHAT_TOOLS) if not resp: yield {"type": "content", "content": "AI 服务暂时不可用,请稍后重试"} return # 检查是否有 tool_calls if not resp.tool_calls: break # 将 assistant 消息(含 tool_calls)加入历史 full_messages.append({ "role": "assistant", "content": resp.content or "", "tool_calls": [ { "id": tc.id, "type": "function", "function": { "name": tc.function.name, "arguments": tc.function.arguments, }, } for tc in resp.tool_calls ], }) # 执行每个工具调用 for tc in resp.tool_calls: try: args = json.loads(tc.function.arguments) except json.JSONDecodeError: args = {} tool_label = TOOL_LABELS.get(tc.function.name, tc.function.name) yield {"type": "status", "content": f"正在{tool_label}..."} logger.info(f"Chat Agent 调用工具: {tc.function.name}({args})") result = await execute_tool(tc.function.name, args) full_messages.append({ "role": "tool", "tool_call_id": tc.id, "content": result, }) yield {"type": "status", "content": "整理作战结论中..."} else: pass # 最终回复:流式输出 yield {"type": "status", "content": ""} # 清除状态 async for delta in stream_chat_completion(full_messages): if delta.content: yield {"type": "content", "content": delta.content} finally: set_chat_user_context(None)