astock-agent/backend/app/llm/chat_agent.py
2026-04-07 20:51:00 +08:00

107 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""对话 Agent
带 function calling 循环的对话 Agent，支持流式和非流式模式。
工具调用阶段会 yield 状态消息，保持 SSE 连接活跃。
"""
import json
import logging
from typing import AsyncGenerator
from app.llm.client import chat_completion, stream_chat_completion, get_client
from app.llm.prompts import CHAT_SYSTEM_PROMPT
from app.llm.tools import CHAT_TOOLS
from app.llm.tool_executor import execute_tool
from app.config import settings
logger = logging.getLogger(__name__)

# Upper bound on the number of non-streaming tool-calling rounds per chat turn.
MAX_TOOL_ROUNDS = 5

# Tool name -> human-readable label, used to build the status message shown
# while that tool is running.
TOOL_LABELS = dict(
    get_market_temperature="查询市场温度",
    get_hot_sectors="查询热门板块",
    get_latest_recommendations="查询推荐列表",
    get_stock_kline="查询K线数据",
    get_stock_capital_flow="查询资金流向",
    search_stock="搜索股票",
)
def _parse_tool_args(raw: str) -> dict:
    """Decode a tool call's JSON ``arguments`` string, falling back to ``{}``.

    Models occasionally emit malformed JSON, an empty/``None`` value, or valid
    JSON that is not an object (``null``, a list, ...). Tools are always called
    with a dict so they never see those shapes.
    """
    try:
        args = json.loads(raw) if raw else {}
    except (json.JSONDecodeError, TypeError):
        return {}
    return args if isinstance(args, dict) else {}


def _assistant_tool_call_message(resp) -> dict:
    """Convert an assistant response that requested tools into a history message."""
    return {
        "role": "assistant",
        "content": resp.content or "",
        "tool_calls": [
            {
                "id": tc.id,
                "type": "function",
                "function": {
                    "name": tc.function.name,
                    "arguments": tc.function.arguments,
                },
            }
            for tc in resp.tool_calls
        ],
    }


async def chat_stream(messages: list[dict]) -> AsyncGenerator[dict, None]:
    """Stream a chat reply, running a function-calling (tool use) loop first.

    The tool phase uses non-streaming completions for at most
    ``MAX_TOOL_ROUNDS`` rounds. While tools execute it yields status events,
    which keeps the SSE connection active. The final answer is then produced
    by a streaming completion over the accumulated history.

    Args:
        messages: Conversation history without the system prompt. The list
            is copied and is not mutated.

    Yields:
        dict: ``{"type": "status", "content": "..."}`` for progress updates
        (an empty ``content`` clears the status), or
        ``{"type": "content", "content": "..."}`` for reply text.
    """
    client = get_client()
    if not client:
        yield {"type": "content", "content": "LLM 未配置,请在 .env 中设置 ASTOCK_DEEPSEEK_API_KEY"}
        return

    # Full message list: system prompt first, then the caller's history (copied).
    full_messages = [{"role": "system", "content": CHAT_SYSTEM_PROMPT}, *messages]

    yield {"type": "status", "content": "思考中..."}

    # Tool-use loop (non-streaming) until the model stops requesting tools.
    for _ in range(MAX_TOOL_ROUNDS):
        resp = await chat_completion(full_messages, tools=CHAT_TOOLS)
        if not resp:
            yield {"type": "content", "content": "AI 服务暂时不可用,请稍后重试"}
            return

        if not resp.tool_calls:
            break

        # The assistant turn (with its tool_calls) must precede the tool results.
        full_messages.append(_assistant_tool_call_message(resp))

        for tc in resp.tool_calls:
            name = tc.function.name
            args = _parse_tool_args(tc.function.arguments)
            tool_label = TOOL_LABELS.get(name, name)
            yield {"type": "status", "content": f"正在{tool_label}..."}
            logger.info("Chat Agent 调用工具: %s(%s)", name, args)
            try:
                result = await execute_tool(name, args)
            except Exception as exc:
                # Don't abort the SSE stream mid-reply. Feed the failure back so
                # the model can still answer. OpenAI-style APIs also expect a
                # tool message for every tool_call_id.
                logger.exception("Chat Agent 工具执行失败: %s", name)
                result = f"工具 {name} 执行失败: {exc}"
            full_messages.append({
                "role": "tool",
                "tool_call_id": tc.id,
                "content": result,
            })

        yield {"type": "status", "content": "分析数据中..."}
    else:
        # Round budget exhausted while the model still wanted tools. Answer from
        # the tool results gathered so far; the final call below passes no tools.
        logger.warning("Chat Agent 工具调用达到上限 %d 轮,直接生成回复", MAX_TOOL_ROUNDS)

    # Final reply: stream it. An empty status clears the client's indicator.
    # NOTE(review): when the loop ended because the model answered without
    # tools, that answer (resp.content) is discarded and regenerated here just
    # to get a token stream. This costs one extra completion call per turn.
    yield {"type": "status", "content": ""}
    produced_content = False
    async for delta in stream_chat_completion(full_messages):
        if delta.content:
            produced_content = True
            yield {"type": "content", "content": delta.content}
    if not produced_content:
        # Mirror the non-streaming failure message rather than leaving the
        # client with an empty reply.
        yield {"type": "content", "content": "AI 服务暂时不可用,请稍后重试"}