379 lines
11 KiB
Python
379 lines
11 KiB
Python
"""
|
||
AI Agent核心
|
||
基于LangChain的股票分析Agent
|
||
"""
|
||
import re
|
||
import json
|
||
from typing import Dict, Any, Optional
|
||
from app.config import get_settings
|
||
from app.agent.context import ContextManager
|
||
from app.agent.skill_manager import skill_manager
|
||
from app.skills.market_data import MarketDataSkill
|
||
from app.skills.technical_analysis import TechnicalAnalysisSkill
|
||
from app.skills.fundamental import FundamentalSkill
|
||
from app.skills.visualization import VisualizationSkill
|
||
from app.utils.logger import logger
|
||
|
||
|
||
class StockAnalysisAgent:
|
||
"""股票分析Agent"""
|
||
|
||
def __init__(self):
|
||
"""初始化Agent"""
|
||
self.context_manager = ContextManager()
|
||
self.settings = get_settings()
|
||
|
||
# 注册技能
|
||
self._register_skills()
|
||
|
||
# 初始化LLM(简化版,使用规则匹配)
|
||
# 在实际部署时,这里应该集成智谱AI GLM-4
|
||
self.use_llm = bool(self.settings.zhipuai_api_key)
|
||
|
||
logger.info("Stock Analysis Agent初始化完成")
|
||
|
||
def _register_skills(self):
|
||
"""注册所有技能"""
|
||
skill_manager.register(MarketDataSkill())
|
||
skill_manager.register(TechnicalAnalysisSkill())
|
||
skill_manager.register(FundamentalSkill())
|
||
skill_manager.register(VisualizationSkill())
|
||
logger.info("技能注册完成")
|
||
|
||
async def process_message(
|
||
self,
|
||
message: str,
|
||
session_id: str,
|
||
user_id: Optional[str] = None
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
处理用户消息
|
||
|
||
Args:
|
||
message: 用户消息
|
||
session_id: 会话ID
|
||
user_id: 用户ID
|
||
|
||
Returns:
|
||
响应结果
|
||
"""
|
||
logger.info(f"处理消息: {message[:50]}...")
|
||
|
||
# 保存用户消息
|
||
self.context_manager.add_message(session_id, "user", message)
|
||
|
||
# 意图识别和技能调用
|
||
intent = self._recognize_intent(message)
|
||
logger.info(f"识别意图: {intent}")
|
||
|
||
# 执行技能
|
||
result = await self._execute_intent(intent, message)
|
||
|
||
# 生成响应
|
||
response = self._generate_response(intent, result)
|
||
|
||
# 保存助手响应
|
||
self.context_manager.add_message(
|
||
session_id,
|
||
"assistant",
|
||
response["message"],
|
||
metadata=response.get("metadata")
|
||
)
|
||
|
||
return response
|
||
|
||
def _recognize_intent(self, message: str) -> Dict[str, Any]:
|
||
"""
|
||
识别用户意图(简化版规则匹配)
|
||
|
||
Args:
|
||
message: 用户消息
|
||
|
||
Returns:
|
||
意图字典
|
||
"""
|
||
message_lower = message.lower()
|
||
|
||
# 提取股票代码
|
||
stock_code = self._extract_stock_code(message)
|
||
|
||
# 行情查询
|
||
if any(keyword in message_lower for keyword in ["行情", "价格", "涨跌", "实时", "quote"]):
|
||
return {
|
||
"type": "market_data",
|
||
"skill": "market_data",
|
||
"params": {
|
||
"stock_code": stock_code,
|
||
"data_type": "quote"
|
||
}
|
||
}
|
||
|
||
# K线查询
|
||
if any(keyword in message_lower for keyword in ["k线", "kline", "走势", "图表"]):
|
||
return {
|
||
"type": "visualization",
|
||
"skill": "visualization",
|
||
"params": {
|
||
"stock_code": stock_code,
|
||
"chart_type": "candlestick"
|
||
}
|
||
}
|
||
|
||
# 技术分析
|
||
if any(keyword in message_lower for keyword in ["技术", "指标", "macd", "rsi", "kdj", "均线", "ma"]):
|
||
return {
|
||
"type": "technical_analysis",
|
||
"skill": "technical_analysis",
|
||
"params": {
|
||
"stock_code": stock_code,
|
||
"indicators": ["ma", "macd", "rsi"]
|
||
}
|
||
}
|
||
|
||
# 基本面
|
||
if any(keyword in message_lower for keyword in ["基本面", "公司", "行业", "信息"]):
|
||
return {
|
||
"type": "fundamental",
|
||
"skill": "fundamental",
|
||
"params": {
|
||
"stock_code": stock_code
|
||
}
|
||
}
|
||
|
||
# 默认:行情查询
|
||
if stock_code:
|
||
return {
|
||
"type": "market_data",
|
||
"skill": "market_data",
|
||
"params": {
|
||
"stock_code": stock_code,
|
||
"data_type": "quote"
|
||
}
|
||
}
|
||
|
||
# 无法识别
|
||
return {
|
||
"type": "unknown",
|
||
"skill": None,
|
||
"params": {}
|
||
}
|
||
|
||
def _extract_stock_code(self, message: str) -> Optional[str]:
|
||
"""
|
||
从消息中提取股票代码
|
||
|
||
Args:
|
||
message: 用户消息
|
||
|
||
Returns:
|
||
股票代码或None
|
||
"""
|
||
from app.utils.stock_names import search_stock_by_name
|
||
|
||
# 匹配6位数字
|
||
pattern = r'\b\d{6}\b'
|
||
matches = re.findall(pattern, message)
|
||
|
||
if matches:
|
||
return matches[0]
|
||
|
||
# 使用股票名称数据库搜索
|
||
# 提取可能的股票名称(2-6个汉字)
|
||
chinese_pattern = r'[\u4e00-\u9fa5]{2,6}'
|
||
chinese_words = re.findall(chinese_pattern, message)
|
||
|
||
for word in chinese_words:
|
||
code = search_stock_by_name(word)
|
||
if code:
|
||
logger.info(f"识别股票名称: {word} -> {code}")
|
||
return code
|
||
|
||
return None
|
||
|
||
async def _execute_intent(self, intent: Dict[str, Any], message: str) -> Dict[str, Any]:
|
||
"""
|
||
执行意图对应的技能
|
||
|
||
Args:
|
||
intent: 意图字典
|
||
message: 原始消息
|
||
|
||
Returns:
|
||
执行结果
|
||
"""
|
||
if intent["type"] == "unknown":
|
||
return {
|
||
"success": False,
|
||
"error": "无法理解您的问题,请提供股票代码或明确的查询意图"
|
||
}
|
||
|
||
skill_name = intent["skill"]
|
||
params = intent["params"]
|
||
|
||
if not params.get("stock_code"):
|
||
return {
|
||
"success": False,
|
||
"error": "请提供股票代码(6位数字)"
|
||
}
|
||
|
||
# 执行技能
|
||
result = await skill_manager.execute_skill(skill_name, **params)
|
||
|
||
return result
|
||
|
||
def _generate_response(self, intent: Dict[str, Any], result: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""
|
||
生成响应消息
|
||
|
||
Args:
|
||
intent: 意图
|
||
result: 执行结果
|
||
|
||
Returns:
|
||
响应字典
|
||
"""
|
||
if not result.get("success", True):
|
||
return {
|
||
"message": f"抱歉,{result.get('error', '处理失败')}",
|
||
"metadata": {
|
||
"type": "error"
|
||
}
|
||
}
|
||
|
||
data = result.get("data", result)
|
||
|
||
# 根据意图类型生成不同响应
|
||
if intent["type"] == "market_data":
|
||
return self._format_market_data_response(data)
|
||
elif intent["type"] == "technical_analysis":
|
||
return self._format_technical_response(data)
|
||
elif intent["type"] == "fundamental":
|
||
return self._format_fundamental_response(data)
|
||
elif intent["type"] == "visualization":
|
||
return self._format_visualization_response(data)
|
||
else:
|
||
return {
|
||
"message": "查询完成",
|
||
"metadata": {
|
||
"type": "data",
|
||
"data": data
|
||
}
|
||
}
|
||
|
||
def _format_market_data_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""格式化行情数据响应"""
|
||
if "error" in data:
|
||
return {
|
||
"message": f"查询失败:{data['error']}",
|
||
"metadata": {"type": "error"}
|
||
}
|
||
|
||
if "kline_data" in data:
|
||
kline_data = data["kline_data"]
|
||
message = f"已获取K线数据,共{len(kline_data)}条记录"
|
||
return {
|
||
"message": message,
|
||
"metadata": {
|
||
"type": "kline",
|
||
"data": kline_data
|
||
}
|
||
}
|
||
|
||
# 实时行情
|
||
message = f"""
|
||
【{data.get('name', '股票')}】({data.get('ts_code', '')})
|
||
交易日期:{data.get('trade_date', '')}
|
||
最新价:{data.get('close', 0):.2f}
|
||
涨跌额:{data.get('change', 0):.2f}
|
||
涨跌幅:{data.get('pct_chg', 0):.2f}%
|
||
开盘价:{data.get('open', 0):.2f}
|
||
最高价:{data.get('high', 0):.2f}
|
||
最低价:{data.get('low', 0):.2f}
|
||
成交量:{data.get('vol', 0):.0f}手
|
||
成交额:{data.get('amount', 0):.0f}千元
|
||
""".strip()
|
||
|
||
return {
|
||
"message": message,
|
||
"metadata": {
|
||
"type": "quote",
|
||
"data": data
|
||
}
|
||
}
|
||
|
||
def _format_technical_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""格式化技术分析响应"""
|
||
if "error" in data:
|
||
return {
|
||
"message": f"分析失败:{data['error']}",
|
||
"metadata": {"type": "error"}
|
||
}
|
||
|
||
indicators = data.get("indicators", {})
|
||
message_parts = [f"【{data.get('stock_code', '')}】技术指标:\n"]
|
||
|
||
if "ma" in indicators:
|
||
ma = indicators["ma"]
|
||
message_parts.append(f"均线:MA5={ma.get('ma5')}, MA10={ma.get('ma10')}, MA20={ma.get('ma20')}")
|
||
|
||
if "macd" in indicators:
|
||
macd = indicators["macd"]
|
||
message_parts.append(f"MACD:DIF={macd.get('dif')}, DEA={macd.get('dea')}, MACD={macd.get('macd')}")
|
||
|
||
if "rsi" in indicators:
|
||
rsi = indicators["rsi"]
|
||
message_parts.append(f"RSI:RSI6={rsi.get('rsi6')}, RSI12={rsi.get('rsi12')}, RSI24={rsi.get('rsi24')}")
|
||
|
||
return {
|
||
"message": "\n".join(message_parts),
|
||
"metadata": {
|
||
"type": "technical",
|
||
"data": data
|
||
}
|
||
}
|
||
|
||
def _format_fundamental_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""格式化基本面响应"""
|
||
if "error" in data:
|
||
return {
|
||
"message": f"查询失败:{data['error']}",
|
||
"metadata": {"type": "error"}
|
||
}
|
||
|
||
message = f"""
|
||
【{data.get('name', '股票')}】基本信息
|
||
股票代码:{data.get('ts_code', '')}
|
||
所属地域:{data.get('area', '')}
|
||
所属行业:{data.get('industry', '')}
|
||
上市市场:{data.get('market', '')}
|
||
上市日期:{data.get('list_date', '')}
|
||
""".strip()
|
||
|
||
return {
|
||
"message": message,
|
||
"metadata": {
|
||
"type": "fundamental",
|
||
"data": data
|
||
}
|
||
}
|
||
|
||
def _format_visualization_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""格式化可视化响应"""
|
||
if "error" in data:
|
||
return {
|
||
"message": f"生成图表失败:{data['error']}",
|
||
"metadata": {"type": "error"}
|
||
}
|
||
|
||
return {
|
||
"message": f"已生成{data.get('stock_code', '')}的K线图",
|
||
"metadata": {
|
||
"type": "chart",
|
||
"data": data
|
||
}
|
||
}
|
||
|
||
|
||
# 创建全局Agent实例
|
||
stock_agent = StockAnalysisAgent()
|