""" 增强版Agent - 集成LLM智能分析 """ import re import json from typing import Dict, Any, Optional from app.config import get_settings from app.agent.context import ContextManager from app.agent.skill_manager import skill_manager from app.skills.market_data import MarketDataSkill from app.skills.technical_analysis import TechnicalAnalysisSkill from app.skills.fundamental import FundamentalSkill from app.skills.visualization import VisualizationSkill from app.services.llm_service import llm_service from app.utils.logger import logger from app.utils.stock_names import search_stock_by_name, get_stock_name class EnhancedStockAgent: """增强版股票分析Agent(集成LLM)""" def __init__(self): """初始化Agent""" self.context_manager = ContextManager() self.settings = get_settings() # 注册技能 self._register_skills() # 检查LLM是否可用 self.use_llm = bool(self.settings.zhipuai_api_key) and llm_service.client is not None if self.use_llm: logger.info("Enhanced Agent初始化完成(LLM模式)") else: logger.info("Enhanced Agent初始化完成(规则模式)") def _register_skills(self): """注册所有技能""" skill_manager.register(MarketDataSkill()) skill_manager.register(TechnicalAnalysisSkill()) skill_manager.register(FundamentalSkill()) skill_manager.register(VisualizationSkill()) logger.info("技能注册完成") async def process_message( self, message: str, session_id: str, user_id: Optional[str] = None ) -> Dict[str, Any]: """ 处理用户消息(增强版) Args: message: 用户消息 session_id: 会话ID user_id: 用户ID Returns: 响应结果 """ logger.info(f"处理消息: {message[:50]}...") # 保存用户消息 self.context_manager.add_message(session_id, "user", message) # 提取股票代码 stock_code = self._extract_stock_code(message) # 使用LLM或规则识别意图 if self.use_llm: intent = await self._recognize_intent_with_llm(message, stock_code) else: intent = self._recognize_intent_with_rules(message, stock_code) logger.info(f"识别意图: {intent}") # 执行技能 result = await self._execute_intent(intent, message) # 生成响应(使用LLM增强) response = await self._generate_response(intent, result, stock_code) # 保存助手响应 self.context_manager.add_message( session_id, "assistant", response["message"], metadata=response.get("metadata") ) return response async def _recognize_intent_with_llm( self, message: str, stock_code: Optional[str] ) -> Dict[str, Any]: """使用LLM识别意图""" try: llm_result = llm_service.analyze_intent(message) intent_type = llm_result.get("type", "unknown") confidence = llm_result.get("confidence", 0) # 如果置信度太低,回退到规则模式 if confidence < 0.5: logger.info("LLM置信度低,回退到规则模式") return self._recognize_intent_with_rules(message, stock_code) # 构建意图 intent = { "type": intent_type, "confidence": confidence, "skill": self._map_intent_to_skill(intent_type), "params": {"stock_code": stock_code} if stock_code else {} } return intent except Exception as e: logger.error(f"LLM意图识别失败: {e}") return self._recognize_intent_with_rules(message, stock_code) def _recognize_intent_with_rules( self, message: str, stock_code: Optional[str] ) -> Dict[str, Any]: """使用规则识别意图(原有逻辑)""" message_lower = message.lower() # 行情查询 if any(keyword in message_lower for keyword in ["行情", "价格", "涨跌", "实时", "quote"]): return { "type": "market_data", "skill": "market_data", "params": { "stock_code": stock_code, "data_type": "quote" } } # K线查询 if any(keyword in message_lower for keyword in ["k线", "kline", "走势", "图表"]): return { "type": "visualization", "skill": "visualization", "params": { "stock_code": stock_code, "chart_type": "candlestick" } } # 技术分析 if any(keyword in message_lower for keyword in ["技术", "指标", "macd", "rsi", "kdj", "均线", "ma"]): return { "type": "technical_analysis", "skill": "technical_analysis", "params": { "stock_code": stock_code, "indicators": ["ma", "macd", "rsi"] } } # 基本面 if any(keyword in message_lower for keyword in ["基本面", "公司", "行业", "信息"]): return { "type": "fundamental", "skill": "fundamental", "params": { "stock_code": stock_code } } # 默认:行情查询 if stock_code: return { "type": "market_data", "skill": "market_data", "params": { "stock_code": stock_code, "data_type": "quote" } } # 无法识别 return { "type": "unknown", "skill": None, "params": {} } def _map_intent_to_skill(self, intent_type: str) -> Optional[str]: """将意图类型映射到技能名称""" mapping = { "market_data": "market_data", "technical_analysis": "technical_analysis", "fundamental": "fundamental", "visualization": "visualization" } return mapping.get(intent_type) def _extract_stock_code(self, message: str) -> Optional[str]: """从消息中提取股票代码""" # 匹配6位数字 pattern = r'\b\d{6}\b' matches = re.findall(pattern, message) if matches: return matches[0] # 使用股票名称数据库搜索 chinese_pattern = r'[\u4e00-\u9fa5]{2,6}' chinese_words = re.findall(chinese_pattern, message) for word in chinese_words: code = search_stock_by_name(word) if code: logger.info(f"识别股票名称: {word} -> {code}") return code return None async def _execute_intent(self, intent: Dict[str, Any], message: str) -> Dict[str, Any]: """执行意图对应的技能""" if intent["type"] == "unknown": return { "success": False, "error": "无法理解您的问题,请提供股票代码或明确的查询意图" } skill_name = intent["skill"] params = intent["params"] if not params.get("stock_code"): return { "success": False, "error": "请提供股票代码或股票名称" } # 执行技能 result = await skill_manager.execute_skill(skill_name, **params) return result async def _generate_response( self, intent: Dict[str, Any], result: Dict[str, Any], stock_code: Optional[str] ) -> Dict[str, Any]: """生成响应消息(使用LLM增强)""" if not result.get("success", True): return { "message": f"抱歉,{result.get('error', '处理失败')}", "metadata": {"type": "error"} } data = result.get("data", result) # 基础格式化 base_response = self._format_response_basic(intent, data) # 如果启用LLM,添加智能分析 if self.use_llm and stock_code and intent["type"] == "technical_analysis": try: stock_name = get_stock_name(stock_code) or stock_code llm_summary = llm_service.generate_analysis_summary( stock_code, stock_name, data ) base_response["message"] += f"\n\n【AI分析】\n{llm_summary}" except Exception as e: logger.error(f"LLM分析生成失败: {e}") return base_response def _format_response_basic(self, intent: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]: """基础响应格式化(原有逻辑)""" if "error" in data: return { "message": f"查询失败:{data['error']}", "metadata": {"type": "error"} } intent_type = intent["type"] if intent_type == "market_data": return self._format_market_data(data) elif intent_type == "technical_analysis": return self._format_technical(data) elif intent_type == "fundamental": return self._format_fundamental(data) elif intent_type == "visualization": return self._format_visualization(data) else: return { "message": "查询完成", "metadata": {"type": "data", "data": data} } def _format_market_data(self, data: Dict[str, Any]) -> Dict[str, Any]: """格式化行情数据""" if "kline_data" in data: kline_data = data["kline_data"] message = f"已获取K线数据,共{len(kline_data)}条记录" return { "message": message, "metadata": {"type": "kline", "data": kline_data} } message = f""" 【{data.get('name', '股票')}】({data.get('ts_code', '')}) 交易日期:{data.get('trade_date', '')} 最新价:{data.get('close', 0):.2f} 涨跌额:{data.get('change', 0):.2f} 涨跌幅:{data.get('pct_chg', 0):.2f}% 开盘价:{data.get('open', 0):.2f} 最高价:{data.get('high', 0):.2f} 最低价:{data.get('low', 0):.2f} 成交量:{data.get('vol', 0):.0f}手 成交额:{data.get('amount', 0):.0f}千元 """.strip() return { "message": message, "metadata": {"type": "quote", "data": data} } def _format_technical(self, data: Dict[str, Any]) -> Dict[str, Any]: """格式化技术分析""" indicators = data.get("indicators", {}) message_parts = [f"【{data.get('stock_code', '')}】技术指标:\n"] if "ma" in indicators: ma = indicators["ma"] message_parts.append(f"均线:MA5={ma.get('ma5')}, MA10={ma.get('ma10')}, MA20={ma.get('ma20')}") if "macd" in indicators: macd = indicators["macd"] message_parts.append(f"MACD:DIF={macd.get('dif')}, DEA={macd.get('dea')}, MACD={macd.get('macd')}") if "rsi" in indicators: rsi = indicators["rsi"] message_parts.append(f"RSI:RSI6={rsi.get('rsi6')}, RSI12={rsi.get('rsi12')}, RSI24={rsi.get('rsi24')}") return { "message": "\n".join(message_parts), "metadata": {"type": "technical", "data": data} } def _format_fundamental(self, data: Dict[str, Any]) -> Dict[str, Any]: """格式化基本面""" message = f""" 【{data.get('name', '股票')}】基本信息 股票代码:{data.get('ts_code', '')} 所属地域:{data.get('area', '')} 所属行业:{data.get('industry', '')} 上市市场:{data.get('market', '')} 上市日期:{data.get('list_date', '')} """.strip() return { "message": message, "metadata": {"type": "fundamental", "data": data} } def _format_visualization(self, data: Dict[str, Any]) -> Dict[str, Any]: """格式化可视化""" return { "message": f"已生成{data.get('stock_code', '')}的K线图", "metadata": {"type": "chart", "data": data} } # 创建全局Agent实例 enhanced_agent = EnhancedStockAgent()