114 lines
3.9 KiB
Python
114 lines
3.9 KiB
Python
"""对话 Agent
|
||
|
||
带 function calling 循环的对话 Agent,支持流式和非流式模式。
|
||
工具调用阶段会 yield 状态消息,保持 SSE 连接活跃。
|
||
"""
|
||
|
||
import json
|
||
import logging
|
||
from typing import AsyncGenerator
|
||
|
||
from app.llm.client import chat_completion, stream_chat_completion, get_client
|
||
from app.llm.prompts import CHAT_SYSTEM_PROMPT
|
||
from app.llm.tools import CHAT_TOOLS
|
||
from app.llm.tool_executor import execute_tool, set_chat_user_context
|
||
from app.config import settings
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
MAX_TOOL_ROUNDS = 5
|
||
|
||
# 工具名称映射(用于状态提示)
|
||
TOOL_LABELS = {
|
||
"get_strategy_board": "读取今日作战结论",
|
||
"get_market_temperature": "查询市场温度",
|
||
"get_hot_sectors": "查询热门板块",
|
||
"get_latest_recommendations": "查询推荐列表",
|
||
"get_user_watchlist_snapshot": "读取自选股作战池",
|
||
"get_stock_kline": "查询K线数据",
|
||
"get_stock_capital_flow": "查询资金流向",
|
||
"diagnose_stock": "生成个股会诊",
|
||
"search_stock": "搜索股票",
|
||
}
|
||
|
||
|
||
async def chat_stream(messages: list[dict], current_user: dict | None = None) -> AsyncGenerator[dict, None]:
|
||
"""流式对话,支持 tool use 循环
|
||
|
||
Yields:
|
||
dict: {"type": "status", "content": "..."} 或 {"type": "content", "content": "..."}
|
||
"""
|
||
client = get_client()
|
||
if not client:
|
||
yield {"type": "content", "content": "LLM 未配置,请在 .env 中设置 ASTOCK_DEEPSEEK_API_KEY"}
|
||
return
|
||
|
||
set_chat_user_context(current_user)
|
||
|
||
# 构建完整消息列表
|
||
full_messages = [{"role": "system", "content": CHAT_SYSTEM_PROMPT}]
|
||
full_messages.extend(messages)
|
||
|
||
try:
|
||
# Tool use 循环(非流式,直到没有 tool_calls)
|
||
for round_num in range(MAX_TOOL_ROUNDS):
|
||
if round_num == 0:
|
||
yield {"type": "status", "content": "整理今日作战上下文..."}
|
||
|
||
resp = await chat_completion(full_messages, tools=CHAT_TOOLS)
|
||
if not resp:
|
||
yield {"type": "content", "content": "AI 服务暂时不可用,请稍后重试"}
|
||
return
|
||
|
||
# 检查是否有 tool_calls
|
||
if not resp.tool_calls:
|
||
break
|
||
|
||
# 将 assistant 消息(含 tool_calls)加入历史
|
||
full_messages.append({
|
||
"role": "assistant",
|
||
"content": resp.content or "",
|
||
"tool_calls": [
|
||
{
|
||
"id": tc.id,
|
||
"type": "function",
|
||
"function": {
|
||
"name": tc.function.name,
|
||
"arguments": tc.function.arguments,
|
||
},
|
||
}
|
||
for tc in resp.tool_calls
|
||
],
|
||
})
|
||
|
||
# 执行每个工具调用
|
||
for tc in resp.tool_calls:
|
||
try:
|
||
args = json.loads(tc.function.arguments)
|
||
except json.JSONDecodeError:
|
||
args = {}
|
||
|
||
tool_label = TOOL_LABELS.get(tc.function.name, tc.function.name)
|
||
yield {"type": "status", "content": f"正在{tool_label}..."}
|
||
|
||
logger.info(f"Chat Agent 调用工具: {tc.function.name}({args})")
|
||
result = await execute_tool(tc.function.name, args)
|
||
|
||
full_messages.append({
|
||
"role": "tool",
|
||
"tool_call_id": tc.id,
|
||
"content": result,
|
||
})
|
||
|
||
yield {"type": "status", "content": "整理作战结论中..."}
|
||
else:
|
||
pass
|
||
|
||
# 最终回复:流式输出
|
||
yield {"type": "status", "content": ""} # 清除状态
|
||
async for delta in stream_chat_completion(full_messages):
|
||
if delta.content:
|
||
yield {"type": "content", "content": delta.content}
|
||
finally:
|
||
set_chat_user_context(None)
|