astock-agent/backend/app/llm/chat_agent.py
2026-04-30 23:29:52 +08:00

114 lines
3.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

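"""Chat agent.

Conversational agent with a function-calling loop, supporting both streaming
and non-streaming modes. During the tool-calling phase it yields status
messages so the SSE connection stays active.
"""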
import json
import logging
from typing import AsyncGenerator
from app.llm.client import chat_completion, stream_chat_completion, get_client
from app.llm.prompts import CHAT_SYSTEM_PROMPT
from app.llm.tools import CHAT_TOOLS
from app.llm.tool_executor import execute_tool, set_chat_user_context
from app.config import settings
logger: logging.Logger = logging.getLogger(__name__)

# Maximum number of non-streaming tool-calling rounds before the agent stops
# offering tools and produces its final streamed answer.
MAX_TOOL_ROUNDS = 5

# Tool name -> human-readable label used in the "正在<label>..." status events
# emitted while a tool call is running. Insertion order matches the tool list.
TOOL_LABELS = dict(
    get_strategy_board="读取今日作战结论",
    get_market_temperature="查询市场温度",
    get_hot_sectors="查询热门板块",
    get_latest_recommendations="查询推荐列表",
    get_user_watchlist_snapshot="读取自选股作战池",
    get_stock_kline="查询K线数据",
    get_stock_capital_flow="查询资金流向",
    diagnose_stock="生成个股会诊",
    search_stock="搜索股票",
)
def _parse_tool_arguments(raw_arguments: str | None, tool_name: str) -> dict:
    """Decode the JSON ``arguments`` payload of a model-issued tool call into a dict.

    The model may emit an empty string, ``None``, malformed JSON, or valid JSON
    that is not an object (``null``, a list, a number). All of these degrade to
    an empty dict so the tool still runs with its defaults. Non-empty bad
    payloads are logged as warnings rather than silently discarded.

    Args:
        raw_arguments: The raw ``tc.function.arguments`` value from the response.
        tool_name: Tool name, used only for the warning message.

    Returns:
        The decoded arguments object, or ``{}`` when it is missing or unusable.
    """
    if not raw_arguments:
        return {}
    try:
        parsed = json.loads(raw_arguments)
    except (json.JSONDecodeError, TypeError):
        logger.warning("Chat Agent 工具参数解析失败: %s(%r)", tool_name, raw_arguments)
        return {}
    if not isinstance(parsed, dict):
        # execute_tool is always handed a dict (see the `{}` fallback above);
        # a non-object payload is treated the same as an unparsable one.
        logger.warning("Chat Agent 工具参数不是 JSON 对象: %s(%r)", tool_name, raw_arguments)
        return {}
    return parsed


def _assistant_tool_call_message(resp) -> dict:
    """Build the assistant history entry for a response that requested tools.

    ``resp`` is the non-streaming result of ``chat_completion`` (it exposes
    ``.content`` and ``.tool_calls``). Each tool call is echoed back with its
    ``id`` so that the ``role: "tool"`` results appended afterwards can
    reference it through ``tool_call_id``.
    """
    return {
        "role": "assistant",
        "content": resp.content or "",
        "tool_calls": [
            {
                "id": tc.id,
                "type": "function",
                "function": {
                    "name": tc.function.name,
                    "arguments": tc.function.arguments,
                },
            }
            for tc in resp.tool_calls
        ],
    }


async def chat_stream(messages: list[dict], current_user: dict | None = None) -> AsyncGenerator[dict, None]:
    """Run one chat turn with a tool-use loop, streaming the final answer.

    The tool phase uses non-streaming ``chat_completion`` calls with
    ``CHAT_TOOLS``, for at most ``MAX_TOOL_ROUNDS`` rounds. While tools run,
    status events are yielded so the SSE connection keeps receiving data.
    The final answer comes from a separate ``stream_chat_completion`` call and
    is yielded delta by delta.

    Args:
        messages: Conversation history without the system prompt. The list is
            copied into a new history and is not mutated.
        current_user: Passed to ``set_chat_user_context`` for the duration of
            the turn (presumably read by tool implementations — confirm in
            ``tool_executor``). It is reset to ``None`` when the generator
            finishes, fails, or is closed.

    Yields:
        Either ``{"type": "status", "content": str}`` (an empty string clears
        the status) or ``{"type": "content", "content": str}`` (answer text).
    """
    client = get_client()
    if not client:
        yield {"type": "content", "content": "LLM 未配置,请在 .env 中设置 ASTOCK_DEEPSEEK_API_KEY"}
        return

    set_chat_user_context(current_user)
    full_messages = [{"role": "system", "content": CHAT_SYSTEM_PROMPT}, *messages]
    # Set from the last non-streaming round when the model stopped requesting
    # tools. It is only used if the streaming call below yields no text at all.
    fallback_answer = ""
    try:
        yield {"type": "status", "content": "整理今日作战上下文..."}

        # Tool-use loop: non-streaming until the model returns no tool_calls.
        for _ in range(MAX_TOOL_ROUNDS):
            resp = await chat_completion(full_messages, tools=CHAT_TOOLS)
            if not resp:
                yield {"type": "content", "content": "AI 服务暂时不可用,请稍后重试"}
                return

            if not resp.tool_calls:
                fallback_answer = resp.content or ""
                break

            full_messages.append(_assistant_tool_call_message(resp))

            for tc in resp.tool_calls:
                tool_name = tc.function.name
                args = _parse_tool_arguments(tc.function.arguments, tool_name)
                tool_label = TOOL_LABELS.get(tool_name, tool_name)
                yield {"type": "status", "content": f"正在{tool_label}..."}
                logger.info("Chat Agent 调用工具: %s(%s)", tool_name, args)
                result = await execute_tool(tool_name, args)
                full_messages.append({
                    "role": "tool",
                    "tool_call_id": tc.id,
                    "content": result,
                })

            yield {"type": "status", "content": "整理作战结论中..."}
        # If the round limit is hit, the history ends with tool results and the
        # streaming call below (issued without tools) must answer from them.

        yield {"type": "status", "content": ""}  # clear the status line
        streamed_any = False
        async for delta in stream_chat_completion(full_messages):
            if delta.content:
                streamed_any = True
                yield {"type": "content", "content": delta.content}

        if not streamed_any:
            # Never end the turn with an empty answer: reuse the non-streaming
            # reply if one was obtained, otherwise report the service failure.
            yield {"type": "content", "content": fallback_answer or "AI 服务暂时不可用,请稍后重试"}
    finally:
        set_chat_user_context(None)