""" AI Agent核心 基于LangChain的股票分析Agent """ import re import json from typing import Dict, Any, Optional from app.config import get_settings from app.agent.context import ContextManager from app.agent.skill_manager import skill_manager from app.skills.market_data import MarketDataSkill from app.skills.technical_analysis import TechnicalAnalysisSkill from app.skills.fundamental import FundamentalSkill from app.skills.visualization import VisualizationSkill from app.utils.logger import logger class StockAnalysisAgent: """股票分析Agent""" def __init__(self): """初始化Agent""" self.context_manager = ContextManager() self.settings = get_settings() # 注册技能 self._register_skills() # 初始化LLM(简化版,使用规则匹配) # 在实际部署时,这里应该集成智谱AI GLM-4 self.use_llm = bool(self.settings.zhipuai_api_key) logger.info("Stock Analysis Agent初始化完成") def _register_skills(self): """注册所有技能""" skill_manager.register(MarketDataSkill()) skill_manager.register(TechnicalAnalysisSkill()) skill_manager.register(FundamentalSkill()) skill_manager.register(VisualizationSkill()) logger.info("技能注册完成") async def process_message( self, message: str, session_id: str, user_id: Optional[str] = None ) -> Dict[str, Any]: """ 处理用户消息 Args: message: 用户消息 session_id: 会话ID user_id: 用户ID Returns: 响应结果 """ logger.info(f"处理消息: {message[:50]}...") # 保存用户消息 self.context_manager.add_message(session_id, "user", message) # 意图识别和技能调用 intent = self._recognize_intent(message) logger.info(f"识别意图: {intent}") # 执行技能 result = await self._execute_intent(intent, message) # 生成响应 response = self._generate_response(intent, result) # 保存助手响应 self.context_manager.add_message( session_id, "assistant", response["message"], metadata=response.get("metadata") ) return response def _recognize_intent(self, message: str) -> Dict[str, Any]: """ 识别用户意图(简化版规则匹配) Args: message: 用户消息 Returns: 意图字典 """ message_lower = message.lower() # 提取股票代码 stock_code = self._extract_stock_code(message) # 行情查询 if any(keyword in message_lower for keyword in ["行情", "价格", "涨跌", "实时", "quote"]): return { "type": "market_data", "skill": "market_data", "params": { "stock_code": stock_code, "data_type": "quote" } } # K线查询 if any(keyword in message_lower for keyword in ["k线", "kline", "走势", "图表"]): return { "type": "visualization", "skill": "visualization", "params": { "stock_code": stock_code, "chart_type": "candlestick" } } # 技术分析 if any(keyword in message_lower for keyword in ["技术", "指标", "macd", "rsi", "kdj", "均线", "ma"]): return { "type": "technical_analysis", "skill": "technical_analysis", "params": { "stock_code": stock_code, "indicators": ["ma", "macd", "rsi"] } } # 基本面 if any(keyword in message_lower for keyword in ["基本面", "公司", "行业", "信息"]): return { "type": "fundamental", "skill": "fundamental", "params": { "stock_code": stock_code } } # 默认:行情查询 if stock_code: return { "type": "market_data", "skill": "market_data", "params": { "stock_code": stock_code, "data_type": "quote" } } # 无法识别 return { "type": "unknown", "skill": None, "params": {} } def _extract_stock_code(self, message: str) -> Optional[str]: """ 从消息中提取股票代码 Args: message: 用户消息 Returns: 股票代码或None """ from app.utils.stock_names import search_stock_by_name # 匹配6位数字 pattern = r'\b\d{6}\b' matches = re.findall(pattern, message) if matches: return matches[0] # 使用股票名称数据库搜索 # 提取可能的股票名称(2-6个汉字) chinese_pattern = r'[\u4e00-\u9fa5]{2,6}' chinese_words = re.findall(chinese_pattern, message) for word in chinese_words: code = search_stock_by_name(word) if code: logger.info(f"识别股票名称: {word} -> {code}") return code return None async def _execute_intent(self, intent: Dict[str, Any], message: str) -> Dict[str, Any]: """ 执行意图对应的技能 Args: intent: 意图字典 message: 原始消息 Returns: 执行结果 """ if intent["type"] == "unknown": return { "success": False, "error": "无法理解您的问题,请提供股票代码或明确的查询意图" } skill_name = intent["skill"] params = intent["params"] if not params.get("stock_code"): return { "success": False, "error": "请提供股票代码(6位数字)" } # 执行技能 result = await skill_manager.execute_skill(skill_name, **params) return result def _generate_response(self, intent: Dict[str, Any], result: Dict[str, Any]) -> Dict[str, Any]: """ 生成响应消息 Args: intent: 意图 result: 执行结果 Returns: 响应字典 """ if not result.get("success", True): return { "message": f"抱歉,{result.get('error', '处理失败')}", "metadata": { "type": "error" } } data = result.get("data", result) # 根据意图类型生成不同响应 if intent["type"] == "market_data": return self._format_market_data_response(data) elif intent["type"] == "technical_analysis": return self._format_technical_response(data) elif intent["type"] == "fundamental": return self._format_fundamental_response(data) elif intent["type"] == "visualization": return self._format_visualization_response(data) else: return { "message": "查询完成", "metadata": { "type": "data", "data": data } } def _format_market_data_response(self, data: Dict[str, Any]) -> Dict[str, Any]: """格式化行情数据响应""" if "error" in data: return { "message": f"查询失败:{data['error']}", "metadata": {"type": "error"} } if "kline_data" in data: kline_data = data["kline_data"] message = f"已获取K线数据,共{len(kline_data)}条记录" return { "message": message, "metadata": { "type": "kline", "data": kline_data } } # 实时行情 message = f""" 【{data.get('name', '股票')}】({data.get('ts_code', '')}) 交易日期:{data.get('trade_date', '')} 最新价:{data.get('close', 0):.2f} 涨跌额:{data.get('change', 0):.2f} 涨跌幅:{data.get('pct_chg', 0):.2f}% 开盘价:{data.get('open', 0):.2f} 最高价:{data.get('high', 0):.2f} 最低价:{data.get('low', 0):.2f} 成交量:{data.get('vol', 0):.0f}手 成交额:{data.get('amount', 0):.0f}千元 """.strip() return { "message": message, "metadata": { "type": "quote", "data": data } } def _format_technical_response(self, data: Dict[str, Any]) -> Dict[str, Any]: """格式化技术分析响应""" if "error" in data: return { "message": f"分析失败:{data['error']}", "metadata": {"type": "error"} } indicators = data.get("indicators", {}) message_parts = [f"【{data.get('stock_code', '')}】技术指标:\n"] if "ma" in indicators: ma = indicators["ma"] message_parts.append(f"均线:MA5={ma.get('ma5')}, MA10={ma.get('ma10')}, MA20={ma.get('ma20')}") if "macd" in indicators: macd = indicators["macd"] message_parts.append(f"MACD:DIF={macd.get('dif')}, DEA={macd.get('dea')}, MACD={macd.get('macd')}") if "rsi" in indicators: rsi = indicators["rsi"] message_parts.append(f"RSI:RSI6={rsi.get('rsi6')}, RSI12={rsi.get('rsi12')}, RSI24={rsi.get('rsi24')}") return { "message": "\n".join(message_parts), "metadata": { "type": "technical", "data": data } } def _format_fundamental_response(self, data: Dict[str, Any]) -> Dict[str, Any]: """格式化基本面响应""" if "error" in data: return { "message": f"查询失败:{data['error']}", "metadata": {"type": "error"} } message = f""" 【{data.get('name', '股票')}】基本信息 股票代码:{data.get('ts_code', '')} 所属地域:{data.get('area', '')} 所属行业:{data.get('industry', '')} 上市市场:{data.get('market', '')} 上市日期:{data.get('list_date', '')} """.strip() return { "message": message, "metadata": { "type": "fundamental", "data": data } } def _format_visualization_response(self, data: Dict[str, Any]) -> Dict[str, Any]: """格式化可视化响应""" if "error" in data: return { "message": f"生成图表失败:{data['error']}", "metadata": {"type": "error"} } return { "message": f"已生成{data.get('stock_code', '')}的K线图", "metadata": { "type": "chart", "data": data } } # 创建全局Agent实例 stock_agent = StockAnalysisAgent()