stock-ai-agent/backend/app/agent/enhanced_agent.py
2026-02-03 10:08:15 +08:00

378 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
增强版Agent - 集成LLM智能分析
"""
import re
import json
from typing import Dict, Any, Optional
from app.config import get_settings
from app.agent.context import ContextManager
from app.agent.skill_manager import skill_manager
from app.skills.market_data import MarketDataSkill
from app.skills.technical_analysis import TechnicalAnalysisSkill
from app.skills.fundamental import FundamentalSkill
from app.skills.visualization import VisualizationSkill
from app.services.llm_service import llm_service
from app.utils.logger import logger
from app.utils.stock_names import search_stock_by_name, get_stock_name
class EnhancedStockAgent:
    """Stock-analysis agent that can use an LLM for intent recognition.

    When ``settings.zhipuai_api_key`` is set and ``llm_service.client`` is not
    ``None``, intents are classified by the LLM (falling back to keyword
    rules) and technical-analysis replies get an extra AI summary. Otherwise
    only the keyword rules are used.
    """

    # Keyword groups checked in order against the lower-cased message; the
    # first group containing a matching keyword decides the intent type.
    # NOTE(review): these are plain substring tests, so "ma" also fires on
    # words such as "market" -- consider tightening if English input is common.
    _RULE_KEYWORDS = (
        ("market_data", ("行情", "价格", "涨跌", "实时", "quote")),
        ("visualization", ("k线", "kline", "走势", "图表")),
        ("technical_analysis", ("技术", "指标", "macd", "rsi", "kdj", "均线", "ma")),
        ("fundamental", ("基本面", "公司", "行业", "信息")),
    )

    # Intent types that have a backing skill (intent type -> skill name).
    _INTENT_TO_SKILL = {
        "market_data": "market_data",
        "technical_analysis": "technical_analysis",
        "fundamental": "fundamental",
        "visualization": "visualization",
    }

    # A 6-digit code that is not part of a longer digit run. ``\b`` cannot be
    # used here: in Python's Unicode regexes CJK ideographs are word
    # characters, so ``\b\d{6}\b`` misses codes written directly against
    # Chinese text such as "查询600519行情".
    _CODE_PATTERN = re.compile(r"(?<!\d)\d{6}(?!\d)")
    # Runs of 2-6 CJK ideographs, tried one by one as candidate stock names.
    _CHINESE_PATTERN = re.compile(r"[\u4e00-\u9fa5]{2,6}")

    def __init__(self):
        """Set up context, settings and skills, and pick LLM or rule mode."""
        self.context_manager = ContextManager()
        self.settings = get_settings()
        self._register_skills()
        # LLM mode needs both a configured API key and an initialised client.
        self.use_llm = bool(self.settings.zhipuai_api_key) and llm_service.client is not None
        if self.use_llm:
            logger.info("Enhanced Agent初始化完成LLM模式")
        else:
            logger.info("Enhanced Agent初始化完成规则模式")

    def _register_skills(self):
        """Register the built-in skills on the shared ``skill_manager``.

        Called from ``__init__``, so every new agent registers them again.
        NOTE(review): whether repeated registration is harmless depends on
        ``skill_manager.register`` -- confirm.
        """
        skill_manager.register(MarketDataSkill())
        skill_manager.register(TechnicalAnalysisSkill())
        skill_manager.register(FundamentalSkill())
        skill_manager.register(VisualizationSkill())
        logger.info("技能注册完成")

    async def process_message(
        self,
        message: str,
        session_id: str,
        user_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Handle one user message end to end.

        Records the user message, extracts a stock code, recognises the
        intent, runs the matching skill, builds the reply and records it.

        Args:
            message: Raw user text.
            session_id: Conversation id used by the context manager.
            user_id: Optional user id (currently unused).

        Returns:
            A dict with a ``"message"`` string and an optional ``"metadata"``.
        """
        logger.info(f"处理消息: {message[:50]}...")
        self.context_manager.add_message(session_id, "user", message)

        stock_code = self._extract_stock_code(message)

        if self.use_llm:
            intent = await self._recognize_intent_with_llm(message, stock_code)
        else:
            intent = self._recognize_intent_with_rules(message, stock_code)
        logger.info(f"识别意图: {intent}")

        result = await self._execute_intent(intent, message)
        response = await self._generate_response(intent, result, stock_code)

        self.context_manager.add_message(
            session_id,
            "assistant",
            response["message"],
            metadata=response.get("metadata")
        )
        return response

    async def _recognize_intent_with_llm(
        self,
        message: str,
        stock_code: Optional[str]
    ) -> Dict[str, Any]:
        """Classify the intent with the LLM, falling back to keyword rules.

        The rules are used instead when the LLM call raises, when the
        reported confidence is below 0.5, or when the returned intent type
        has no backing skill.
        """
        try:
            llm_result = llm_service.analyze_intent(message)
            intent_type = llm_result.get("type", "unknown")
            confidence = llm_result.get("confidence", 0)

            if confidence < 0.5:
                logger.info("LLM置信度低回退到规则模式")
                return self._recognize_intent_with_rules(message, stock_code)

            skill = self._map_intent_to_skill(intent_type)
            if skill is None:
                # An unmapped type (e.g. "unknown") would otherwise reach
                # _execute_intent with skill=None; the rules can still route
                # a bare stock code to a quote lookup.
                logger.info(f"LLM意图无对应技能({intent_type})回退到规则模式")
                return self._recognize_intent_with_rules(message, stock_code)

            return {
                "type": intent_type,
                "confidence": confidence,
                "skill": skill,
                # Same per-skill params as the rule path so skills see
                # consistent arguments regardless of how intent was found.
                "params": self._skill_params(intent_type, stock_code),
            }
        except Exception as e:
            logger.error(f"LLM意图识别失败: {e}")
            return self._recognize_intent_with_rules(message, stock_code)

    def _skill_params(self, intent_type: str, stock_code: Optional[str]) -> Dict[str, Any]:
        """Build a fresh params dict for ``intent_type`` (stock code first)."""
        params: Dict[str, Any] = {"stock_code": stock_code}
        if intent_type == "market_data":
            params["data_type"] = "quote"
        elif intent_type == "technical_analysis":
            params["indicators"] = ["ma", "macd", "rsi"]
        elif intent_type == "visualization":
            params["chart_type"] = "candlestick"
        return params

    def _recognize_intent_with_rules(
        self,
        message: str,
        stock_code: Optional[str]
    ) -> Dict[str, Any]:
        """Classify the intent by keyword matching.

        Falls back to a quote lookup when a stock code is present but no
        keyword matched; otherwise returns an ``"unknown"`` intent.
        """
        message_lower = message.lower()
        for intent_type, keywords in self._RULE_KEYWORDS:
            if any(keyword in message_lower for keyword in keywords):
                return {
                    "type": intent_type,
                    "skill": intent_type,
                    "params": self._skill_params(intent_type, stock_code),
                }

        if stock_code:
            return {
                "type": "market_data",
                "skill": "market_data",
                "params": self._skill_params("market_data", stock_code),
            }

        return {
            "type": "unknown",
            "skill": None,
            "params": {}
        }

    def _map_intent_to_skill(self, intent_type: str) -> Optional[str]:
        """Return the skill name for ``intent_type``, or ``None`` if unmapped."""
        return self._INTENT_TO_SKILL.get(intent_type)

    def _extract_stock_code(self, message: str) -> Optional[str]:
        """Extract a stock code from ``message``.

        A literal 6-digit code wins; otherwise each CJK run is looked up with
        ``search_stock_by_name`` and the first hit is returned.
        NOTE(review): the greedy 2-6 character runs can glue verbs onto names
        (e.g. "查询贵州茅台"); hit rate depends on whether
        ``search_stock_by_name`` does substring matching -- confirm.
        """
        match = self._CODE_PATTERN.search(message)
        if match:
            return match.group()

        for word in self._CHINESE_PATTERN.findall(message):
            code = search_stock_by_name(word)
            if code:
                logger.info(f"识别股票名称: {word} -> {code}")
                return code
        return None

    async def _execute_intent(self, intent: Dict[str, Any], message: str) -> Dict[str, Any]:
        """Run the skill selected by ``intent``.

        Returns an error dict (``success: False``) when the intent is unknown
        or has no skill, or when no stock code is available; otherwise
        returns whatever ``skill_manager.execute_skill`` returns.
        """
        skill_name = intent.get("skill")
        if intent["type"] == "unknown" or not skill_name:
            return {
                "success": False,
                "error": "无法理解您的问题,请提供股票代码或明确的查询意图"
            }

        params = intent["params"]
        if not params.get("stock_code"):
            return {
                "success": False,
                "error": "请提供股票代码或股票名称"
            }

        return await skill_manager.execute_skill(skill_name, **params)

    async def _generate_response(
        self,
        intent: Dict[str, Any],
        result: Dict[str, Any],
        stock_code: Optional[str]
    ) -> Dict[str, Any]:
        """Turn a skill result into a reply, optionally adding an AI summary.

        The LLM summary is only appended for successful technical-analysis
        replies in LLM mode; LLM failures are logged and the plain reply is
        returned.
        """
        if not result.get("success", True):
            return {
                "message": f"抱歉,{result.get('error', '处理失败')}",
                "metadata": {"type": "error"}
            }

        data = result.get("data", result)
        if data is None:
            # A skill reporting success with no payload; format as empty.
            data = {}

        base_response = self._format_response_basic(intent, data)
        is_error = base_response.get("metadata", {}).get("type") == "error"

        if (self.use_llm and stock_code and not is_error
                and intent["type"] == "technical_analysis"):
            try:
                stock_name = get_stock_name(stock_code) or stock_code
                llm_summary = llm_service.generate_analysis_summary(
                    stock_code, stock_name, data
                )
                base_response["message"] += f"\n\n【AI分析】\n{llm_summary}"
            except Exception as e:
                logger.error(f"LLM分析生成失败: {e}")

        return base_response

    def _format_response_basic(self, intent: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
        """Format ``data`` with the formatter for the intent type."""
        if "error" in data:
            return {
                "message": f"查询失败:{data['error']}",
                "metadata": {"type": "error"}
            }

        formatters = {
            "market_data": self._format_market_data,
            "technical_analysis": self._format_technical,
            "fundamental": self._format_fundamental,
            "visualization": self._format_visualization,
        }
        formatter = formatters.get(intent["type"])
        if formatter is None:
            return {
                "message": "查询完成",
                "metadata": {"type": "data", "data": data}
            }
        return formatter(data)

    @staticmethod
    def _fmt_num(value: Any, spec: str) -> str:
        """Format ``value`` with ``spec``; numeric strings are coerced and
        anything unformattable (e.g. ``None``) becomes ``"-"``."""
        try:
            return format(value, spec)
        except (TypeError, ValueError):
            pass
        try:
            return format(float(value), spec)
        except (TypeError, ValueError):
            return "-"

    def _format_market_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Format a quote (or a K-line list when ``kline_data`` is present)."""
        if "kline_data" in data:
            kline_data = data["kline_data"]
            return {
                "message": f"已获取K线数据{len(kline_data)}条记录",
                "metadata": {"type": "kline", "data": kline_data}
            }

        num = self._fmt_num
        lines = [
            f"【{data.get('name', '股票')}】({data.get('ts_code', '')})",
            f"交易日期:{data.get('trade_date', '')}",
            f"最新价:{num(data.get('close', 0), '.2f')}",
            f"涨跌额:{num(data.get('change', 0), '.2f')}",
            f"涨跌幅:{num(data.get('pct_chg', 0), '.2f')}%",
            f"开盘价:{num(data.get('open', 0), '.2f')}",
            f"最高价:{num(data.get('high', 0), '.2f')}",
            f"最低价:{num(data.get('low', 0), '.2f')}",
            f"成交量:{num(data.get('vol', 0), '.0f')}",
            f"成交额:{num(data.get('amount', 0), '.0f')}千元",
        ]
        return {
            "message": "\n".join(lines),
            "metadata": {"type": "quote", "data": data}
        }

    def _format_technical(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Format the MA / MACD / RSI indicators present in ``data``."""
        indicators = data.get("indicators", {})
        message_parts = [f"【{data.get('stock_code', '')}】技术指标:\n"]

        if "ma" in indicators:
            ma = indicators["ma"]
            message_parts.append(f"均线:MA5={ma.get('ma5')}, MA10={ma.get('ma10')}, MA20={ma.get('ma20')}")
        if "macd" in indicators:
            macd = indicators["macd"]
            message_parts.append(f"MACD:DIF={macd.get('dif')}, DEA={macd.get('dea')}, MACD={macd.get('macd')}")
        if "rsi" in indicators:
            rsi = indicators["rsi"]
            message_parts.append(f"RSI:RSI6={rsi.get('rsi6')}, RSI12={rsi.get('rsi12')}, RSI24={rsi.get('rsi24')}")

        return {
            "message": "\n".join(message_parts),
            "metadata": {"type": "technical", "data": data}
        }

    def _format_fundamental(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Format basic company information."""
        lines = [
            f"【{data.get('name', '股票')}】基本信息",
            f"股票代码:{data.get('ts_code', '')}",
            f"所属地域:{data.get('area', '')}",
            f"所属行业:{data.get('industry', '')}",
            f"上市市场:{data.get('market', '')}",
            f"上市日期:{data.get('list_date', '')}",
        ]
        return {
            "message": "\n".join(lines),
            "metadata": {"type": "fundamental", "data": data}
        }

    def _format_visualization(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Format the chart reply; the chart payload goes in metadata."""
        return {
            "message": f"已生成{data.get('stock_code', '')}的K线图",
            "metadata": {"type": "chart", "data": data}
        }
# Module-level singleton agent. Creating it runs EnhancedStockAgent.__init__
# at import time, which registers the built-in skills on the shared
# skill_manager and decides between LLM mode and rule mode.
enhanced_agent = EnhancedStockAgent()