378 lines
12 KiB
Python
378 lines
12 KiB
Python
"""
|
||
增强版Agent - 集成LLM智能分析
|
||
"""
|
||
import re
|
||
import json
|
||
from typing import Dict, Any, Optional
|
||
from app.config import get_settings
|
||
from app.agent.context import ContextManager
|
||
from app.agent.skill_manager import skill_manager
|
||
from app.skills.market_data import MarketDataSkill
|
||
from app.skills.technical_analysis import TechnicalAnalysisSkill
|
||
from app.skills.fundamental import FundamentalSkill
|
||
from app.skills.visualization import VisualizationSkill
|
||
from app.services.llm_service import llm_service
|
||
from app.utils.logger import logger
|
||
from app.utils.stock_names import search_stock_by_name, get_stock_name
|
||
|
||
|
||
class EnhancedStockAgent:
    """Stock-analysis agent that combines keyword rules with optional LLM help.

    The agent extracts a stock code from the user message, recognises the
    intent (via the LLM when available, otherwise via keyword rules), runs the
    matching skill through ``skill_manager`` and formats the skill result into
    a ``{"message": ..., "metadata": ...}`` response.
    """

    # Intent type -> skill name registered in ``skill_manager``.
    _INTENT_SKILLS = {
        "market_data": "market_data",
        "technical_analysis": "technical_analysis",
        "fundamental": "fundamental",
        "visualization": "visualization",
    }

    # Extra skill parameters sent with every intent of a given type, so the
    # LLM path and the rule path issue identical skill calls.
    _DEFAULT_PARAMS = {
        "market_data": {"data_type": "quote"},
        "visualization": {"chart_type": "candlestick"},
        "technical_analysis": {"indicators": ("ma", "macd", "rsi")},
        "fundamental": {},
    }

    # Ordered keyword rules: the first rule with a hit wins.
    _RULE_KEYWORDS = (
        ("market_data", ("行情", "价格", "涨跌", "实时", "quote")),
        ("visualization", ("k线", "kline", "走势", "图表")),
        ("technical_analysis", ("技术", "指标", "均线")),
        ("fundamental", ("基本面", "公司", "行业", "信息")),
    )

    # Indicator abbreviations must not be embedded in a longer ASCII word
    # (otherwise "market" or "format" would count as "ma"); trailing digits
    # such as "ma5" / "rsi6" are accepted.
    _INDICATOR_TOKEN_RE = re.compile(
        r"(?<![a-z])(?:macd|rsi|kdj|[es]?ma)\d*(?![a-z])"
    )

    # Exactly six digits that are not part of a longer digit run.  Digit
    # lookarounds are used instead of ``\b`` because CJK characters are word
    # characters in Python 3, so ``\b`` never matches between Chinese text and
    # a directly adjacent number.
    _STOCK_CODE_RE = re.compile(r"(?<!\d)\d{6}(?!\d)")

    # Runs of 2-6 CJK ideographs used as candidate stock names.
    _CHINESE_WORD_RE = re.compile(r"[\u4e00-\u9fa5]{2,6}")

    def __init__(self):
        """Create the context manager, register skills and detect LLM support."""
        self.context_manager = ContextManager()
        self.settings = get_settings()

        # Register skills with the shared skill manager.
        self._register_skills()

        # The LLM is used only when an API key is configured and the service
        # actually holds a client.
        self.use_llm = (
            bool(self.settings.zhipuai_api_key)
            and llm_service.client is not None
        )

        if self.use_llm:
            logger.info("Enhanced Agent初始化完成(LLM模式)")
        else:
            logger.info("Enhanced Agent初始化完成(规则模式)")

    def _register_skills(self):
        """Register every skill this agent can dispatch to."""
        skill_manager.register(MarketDataSkill())
        skill_manager.register(TechnicalAnalysisSkill())
        skill_manager.register(FundamentalSkill())
        skill_manager.register(VisualizationSkill())
        logger.info("技能注册完成")

    async def process_message(
        self,
        message: str,
        session_id: str,
        user_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Handle one user message end to end.

        Args:
            message: Raw user message.
            session_id: Conversation id used to store both sides of the turn.
            user_id: Optional user id (accepted for interface compatibility;
                not read by this method).

        Returns:
            Response dict with a ``"message"`` string and optional
            ``"metadata"``.
        """
        logger.info(f"处理消息: {message[:50]}...")

        # Persist the user turn before doing any work.
        self.context_manager.add_message(session_id, "user", message)

        stock_code = self._extract_stock_code(message)

        # Recognise the intent with the LLM when available, else with rules.
        if self.use_llm:
            intent = await self._recognize_intent_with_llm(message, stock_code)
        else:
            intent = self._recognize_intent_with_rules(message, stock_code)

        logger.info(f"识别意图: {intent}")

        result = await self._execute_intent(intent, message)
        response = await self._generate_response(intent, result, stock_code)

        # Persist the assistant turn.
        self.context_manager.add_message(
            session_id,
            "assistant",
            response["message"],
            metadata=response.get("metadata")
        )

        return response

    def _build_intent(
        self,
        intent_type: str,
        stock_code: Optional[str]
    ) -> Dict[str, Any]:
        """Build the intent dict for *intent_type* with its default params."""
        params: Dict[str, Any] = {"stock_code": stock_code}
        for key, value in self._DEFAULT_PARAMS.get(intent_type, {}).items():
            # Hand out a fresh list so callers cannot mutate the class default.
            params[key] = list(value) if isinstance(value, tuple) else value
        return {
            "type": intent_type,
            "skill": self._map_intent_to_skill(intent_type),
            "params": params,
        }

    async def _recognize_intent_with_llm(
        self,
        message: str,
        stock_code: Optional[str]
    ) -> Dict[str, Any]:
        """Recognise the intent with the LLM, falling back to keyword rules.

        The rule-based recogniser is used when the LLM call fails, when its
        confidence is below 0.5, or when it returns an intent type that has no
        registered skill (other than the explicit ``"unknown"``).
        """
        try:
            # NOTE(review): called without ``await`` in the original, so this
            # is presumably a synchronous call -- confirm it does not block
            # the event loop for long.
            llm_result = llm_service.analyze_intent(message)

            intent_type = llm_result.get("type", "unknown")
            confidence = llm_result.get("confidence", 0)

            # Low confidence: trust the keyword rules instead.
            if confidence < 0.5:
                logger.info("LLM置信度低,回退到规则模式")
                return self._recognize_intent_with_rules(message, stock_code)
        except Exception as e:
            logger.error(f"LLM意图识别失败: {e}")
            return self._recognize_intent_with_rules(message, stock_code)

        # An intent type without a skill cannot be executed; let the rules try.
        if intent_type != "unknown" and self._map_intent_to_skill(intent_type) is None:
            logger.info(f"LLM返回未支持的意图类型: {intent_type},回退到规则模式")
            return self._recognize_intent_with_rules(message, stock_code)

        intent = self._build_intent(intent_type, stock_code)
        intent["confidence"] = confidence
        return intent

    def _recognize_intent_with_rules(
        self,
        message: str,
        stock_code: Optional[str]
    ) -> Dict[str, Any]:
        """Recognise the intent from keywords in *message*.

        Rules are checked in order (quote, chart, technical, fundamental).
        When nothing matches, a message that carries a stock code defaults to
        a quote lookup; otherwise the intent is ``"unknown"``.
        """
        text = message.lower()

        for intent_type, keywords in self._RULE_KEYWORDS:
            matched = any(keyword in text for keyword in keywords)
            if not matched and intent_type == "technical_analysis":
                matched = self._INDICATOR_TOKEN_RE.search(text) is not None
            if matched:
                return self._build_intent(intent_type, stock_code)

        # Default: a bare stock code means "show me the quote".
        if stock_code:
            return self._build_intent("market_data", stock_code)

        # Nothing recognisable.
        return {
            "type": "unknown",
            "skill": None,
            "params": {}
        }

    def _map_intent_to_skill(self, intent_type: str) -> Optional[str]:
        """Return the skill name for *intent_type*, or ``None`` if unmapped."""
        return self._INTENT_SKILLS.get(intent_type)

    def _extract_stock_code(self, message: str) -> Optional[str]:
        """Extract a stock code from *message*.

        A standalone six-digit number wins; otherwise each run of Chinese
        characters is looked up as a stock name.

        Returns:
            The stock code, or ``None`` when none could be determined.
        """
        code_match = self._STOCK_CODE_RE.search(message)
        if code_match:
            return code_match.group(0)

        # Fall back to the stock-name lookup.
        for word in self._CHINESE_WORD_RE.findall(message):
            code = search_stock_by_name(word)
            if code:
                logger.info(f"识别股票名称: {word} -> {code}")
                return code

        return None

    async def _execute_intent(self, intent: Dict[str, Any], message: str) -> Dict[str, Any]:
        """Run the skill selected by *intent* and return its result.

        Returns a ``{"success": False, "error": ...}`` dict instead of calling
        the skill manager when the intent is unknown, has no skill, or lacks a
        stock code.
        """
        skill_name = intent.get("skill")

        # A missing skill is treated like an unknown intent so the skill
        # manager is never asked to run ``None``.
        if intent["type"] == "unknown" or not skill_name:
            return {
                "success": False,
                "error": "无法理解您的问题,请提供股票代码或明确的查询意图"
            }

        params = intent["params"]

        if not params.get("stock_code"):
            return {
                "success": False,
                "error": "请提供股票代码或股票名称"
            }

        return await skill_manager.execute_skill(skill_name, **params)

    async def _generate_response(
        self,
        intent: Dict[str, Any],
        result: Dict[str, Any],
        stock_code: Optional[str]
    ) -> Dict[str, Any]:
        """Turn a skill result into the user-facing response.

        Technical-analysis responses additionally get an LLM-written summary
        when the LLM is enabled; a failure to produce it is logged and the
        basic response is returned unchanged.
        """
        if not result.get("success", True):
            return {
                "message": f"抱歉,{result.get('error', '处理失败')}",
                "metadata": {"type": "error"}
            }

        data = result.get("data", result)

        base_response = self._format_response_basic(intent, data)

        wants_llm_summary = (
            self.use_llm
            and stock_code
            and intent["type"] == "technical_analysis"
            # Do not decorate an error message with an AI analysis.
            and base_response.get("metadata", {}).get("type") != "error"
        )
        if wants_llm_summary:
            try:
                stock_name = get_stock_name(stock_code) or stock_code
                llm_summary = llm_service.generate_analysis_summary(
                    stock_code, stock_name, data
                )
                # Skip empty summaries rather than printing "None".
                if llm_summary:
                    base_response["message"] += f"\n\n【AI分析】\n{llm_summary}"
            except Exception as e:
                logger.error(f"LLM分析生成失败: {e}")

        return base_response

    def _format_response_basic(self, intent: Dict[str, Any], data: Dict[str, Any]) -> Dict[str, Any]:
        """Format *data* with the formatter that matches the intent type."""
        if "error" in data:
            return {
                "message": f"查询失败:{data['error']}",
                "metadata": {"type": "error"}
            }

        formatters = {
            "market_data": self._format_market_data,
            "technical_analysis": self._format_technical,
            "fundamental": self._format_fundamental,
            "visualization": self._format_visualization,
        }
        formatter = formatters.get(intent["type"])
        if formatter is not None:
            return formatter(data)

        return {
            "message": "查询完成",
            "metadata": {"type": "data", "data": data}
        }

    @staticmethod
    def _as_number(value: Any) -> float:
        """Coerce *value* to ``float``; ``None`` or unparsable input gives 0."""
        try:
            return float(value)
        except (TypeError, ValueError):
            return 0.0

    def _format_market_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Format a quote (or a K-line payload) into a response."""
        if "kline_data" in data:
            kline_data = data["kline_data"]
            return {
                "message": f"已获取K线数据,共{len(kline_data)}条记录",
                "metadata": {"type": "kline", "data": kline_data}
            }

        def num(key: str) -> float:
            # Missing, None or non-numeric fields are rendered as 0 instead of
            # raising inside the format spec.
            return self._as_number(data.get(key, 0))

        message = "\n".join([
            f"【{data.get('name', '股票')}】({data.get('ts_code', '')})",
            f"交易日期:{data.get('trade_date', '')}",
            f"最新价:{num('close'):.2f}",
            f"涨跌额:{num('change'):.2f}",
            f"涨跌幅:{num('pct_chg'):.2f}%",
            f"开盘价:{num('open'):.2f}",
            f"最高价:{num('high'):.2f}",
            f"最低价:{num('low'):.2f}",
            f"成交量:{num('vol'):.0f}手",
            f"成交额:{num('amount'):.0f}千元",
        ]).strip()

        return {
            "message": message,
            "metadata": {"type": "quote", "data": data}
        }

    def _format_technical(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Format the MA / MACD / RSI indicator values into a response."""
        indicators = data.get("indicators", {})
        message_parts = [f"【{data.get('stock_code', '')}】技术指标:\n"]

        if "ma" in indicators:
            ma = indicators["ma"]
            message_parts.append(f"均线:MA5={ma.get('ma5')}, MA10={ma.get('ma10')}, MA20={ma.get('ma20')}")

        if "macd" in indicators:
            macd = indicators["macd"]
            message_parts.append(f"MACD:DIF={macd.get('dif')}, DEA={macd.get('dea')}, MACD={macd.get('macd')}")

        if "rsi" in indicators:
            rsi = indicators["rsi"]
            message_parts.append(f"RSI:RSI6={rsi.get('rsi6')}, RSI12={rsi.get('rsi12')}, RSI24={rsi.get('rsi24')}")

        return {
            "message": "\n".join(message_parts),
            "metadata": {"type": "technical", "data": data}
        }

    def _format_fundamental(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Format basic company information into a response."""
        message = "\n".join([
            f"【{data.get('name', '股票')}】基本信息",
            f"股票代码:{data.get('ts_code', '')}",
            f"所属地域:{data.get('area', '')}",
            f"所属行业:{data.get('industry', '')}",
            f"上市市场:{data.get('market', '')}",
            f"上市日期:{data.get('list_date', '')}",
        ]).strip()

        return {
            "message": message,
            "metadata": {"type": "fundamental", "data": data}
        }

    def _format_visualization(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Format a chart payload into a response."""
        return {
            "message": f"已生成{data.get('stock_code', '')}的K线图",
            "metadata": {"type": "chart", "data": data}
        }
|
||
|
||
|
||
# Module-level shared agent instance; it is constructed when this module is imported.
enhanced_agent = EnhancedStockAgent()
|