stock-ai-agent/backend/app/agent/core.py
2026-02-03 10:08:15 +08:00

379 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
AI Agent核心
基于LangChain的股票分析Agent
"""
import re
import json
from typing import Dict, Any, Optional
from app.config import get_settings
from app.agent.context import ContextManager
from app.agent.skill_manager import skill_manager
from app.skills.market_data import MarketDataSkill
from app.skills.technical_analysis import TechnicalAnalysisSkill
from app.skills.fundamental import FundamentalSkill
from app.skills.visualization import VisualizationSkill
from app.utils.logger import logger
class StockAnalysisAgent:
    """Rule-based stock analysis agent.

    Routes a user message to one of the registered skills using keyword
    rules, awaits the skill through ``skill_manager`` and renders a chat reply
    (text plus metadata) that is stored in the session context.
    """

    # Keyword lists for rule-based intent matching, compared against the
    # lower-cased message. They are checked in this order and the first list
    # with a hit decides the intent.
    _QUOTE_KEYWORDS = ("行情", "价格", "涨跌", "实时", "quote")
    _KLINE_KEYWORDS = ("k线", "kline", "走势", "图表")
    _TECHNICAL_KEYWORDS = ("技术", "指标", "macd", "rsi", "kdj", "均线")
    _FUNDAMENTAL_KEYWORDS = ("基本面", "公司", "行业", "信息")

    # "ma" (moving average) as a bare substring fires on any English word that
    # contains it ("summary", "market"), so it only counts when it is not part
    # of a longer Latin word; "ma5" and "看ma" still match.
    _MA_PATTERN = re.compile(r"(?<![a-z])ma(?![a-z])")

    # A 6-digit code that is not part of a longer digit run. Lookarounds are
    # used instead of \b because CJK characters are word characters in
    # Python's Unicode regexes, so \b never fires in e.g. "查询600519行情".
    _STOCK_CODE_PATTERN = re.compile(r"(?<!\d)\d{6}(?!\d)")

    # Candidate stock names: runs of 2-6 CJK ideographs.
    _CHINESE_WORD_PATTERN = re.compile(r"[\u4e00-\u9fa5]{2,6}")

    def __init__(self):
        """Create the context manager, load settings and register skills."""
        self.context_manager = ContextManager()
        self.settings = get_settings()
        self._register_skills()
        # Intent recognition is currently keyword-rule based; an LLM
        # (Zhipu AI GLM-4) is meant to be integrated for real deployments.
        # NOTE(review): use_llm is not read anywhere in this class.
        self.use_llm = bool(self.settings.zhipuai_api_key)
        logger.info("Stock Analysis Agent初始化完成")

    def _register_skills(self):
        """Register every skill on the shared ``skill_manager``.

        NOTE(review): this runs once per agent instance; whether registering
        the same skill twice is harmless depends on ``skill_manager.register``.
        """
        skill_manager.register(MarketDataSkill())
        skill_manager.register(TechnicalAnalysisSkill())
        skill_manager.register(FundamentalSkill())
        skill_manager.register(VisualizationSkill())
        logger.info("技能注册完成")

    async def process_message(
        self,
        message: str,
        session_id: str,
        user_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Handle one user message end to end.

        Stores the user message, recognizes the intent, runs the matching
        skill, renders the reply and stores it as the assistant message.

        Args:
            message: Raw user message.
            session_id: Session whose context receives both messages.
            user_id: Accepted for API compatibility; not used here.

        Returns:
            Response dict with a ``message`` string and a ``metadata`` dict.
        """
        logger.info(f"处理消息: {message[:50]}...")
        self.context_manager.add_message(session_id, "user", message)

        intent = self._recognize_intent(message)
        logger.info(f"识别意图: {intent}")

        result = await self._execute_intent(intent, message)
        response = self._generate_response(intent, result)

        self.context_manager.add_message(
            session_id,
            "assistant",
            response["message"],
            metadata=response.get("metadata")
        )
        return response

    def _recognize_intent(self, message: str) -> Dict[str, Any]:
        """Recognize the user's intent with simple keyword rules.

        Args:
            message: Raw user message.

        Returns:
            Dict with ``type``, ``skill`` (same as type, or None when
            unknown) and ``params`` for the skill call.
        """
        message_lower = message.lower()
        stock_code = self._extract_stock_code(message)

        if any(keyword in message_lower for keyword in self._QUOTE_KEYWORDS):
            return self._make_intent(
                "market_data", stock_code=stock_code, data_type="quote"
            )
        if any(keyword in message_lower for keyword in self._KLINE_KEYWORDS):
            return self._make_intent(
                "visualization", stock_code=stock_code, chart_type="candlestick"
            )
        if (any(keyword in message_lower for keyword in self._TECHNICAL_KEYWORDS)
                or self._MA_PATTERN.search(message_lower)):
            return self._make_intent(
                "technical_analysis",
                stock_code=stock_code,
                indicators=["ma", "macd", "rsi"],
            )
        if any(keyword in message_lower for keyword in self._FUNDAMENTAL_KEYWORDS):
            return self._make_intent("fundamental", stock_code=stock_code)

        # No keyword matched: a recognizable stock defaults to a quote query.
        if stock_code:
            return self._make_intent(
                "market_data", stock_code=stock_code, data_type="quote"
            )

        return {
            "type": "unknown",
            "skill": None,
            "params": {}
        }

    @staticmethod
    def _make_intent(intent_type: str, **params: Any) -> Dict[str, Any]:
        """Build an intent dict whose skill name equals its type."""
        return {"type": intent_type, "skill": intent_type, "params": params}

    def _extract_stock_code(self, message: str) -> Optional[str]:
        """Extract a stock code from the message.

        A standalone 6-digit number wins; otherwise each 2-6 character CJK run
        is looked up with ``search_stock_by_name``, in message order.

        Args:
            message: Raw user message.

        Returns:
            The stock code, or None when nothing matched.
        """
        match = self._STOCK_CODE_PATTERN.search(message)
        if match:
            return match.group()

        # Imported lazily: only the name-lookup path needs the name database.
        from app.utils.stock_names import search_stock_by_name

        # NOTE(review): greedy runs can glue a name to following words
        # (e.g. "贵州茅台的行"); matching then depends on search_stock_by_name.
        for word in self._CHINESE_WORD_PATTERN.findall(message):
            code = search_stock_by_name(word)
            if code:
                logger.info(f"识别股票名称: {word} -> {code}")
                return code
        return None

    async def _execute_intent(self, intent: Dict[str, Any], message: str) -> Dict[str, Any]:
        """Run the skill selected by ``intent``.

        Args:
            intent: Intent dict from ``_recognize_intent``.
            message: Original message (currently unused).

        Returns:
            The skill result, or ``{"success": False, "error": ...}`` when the
            intent is unknown or no stock code was found.
        """
        if intent["type"] == "unknown":
            return {
                "success": False,
                "error": "无法理解您的问题,请提供股票代码或明确的查询意图"
            }

        params = intent["params"]
        if not params.get("stock_code"):
            return {
                "success": False,
                "error": "请提供股票代码6位数字"
            }

        return await skill_manager.execute_skill(intent["skill"], **params)

    def _generate_response(self, intent: Dict[str, Any], result: Dict[str, Any]) -> Dict[str, Any]:
        """Render a skill result into a chat response.

        Args:
            intent: Intent dict that produced ``result``.
            result: Skill result; payload is read from ``result["data"]``
                when present and not None, else the result itself is used.

        Returns:
            Dict with ``message`` and ``metadata``.
        """
        if not result.get("success", True):
            return {
                "message": f"抱歉,{result.get('error', '处理失败')}",
                "metadata": {
                    "type": "error"
                }
            }

        # A present-but-None "data" would crash the formatters on
        # `"error" in data`; fall back to the whole result instead.
        data = result.get("data")
        if data is None:
            data = result

        formatters = {
            "market_data": self._format_market_data_response,
            "technical_analysis": self._format_technical_response,
            "fundamental": self._format_fundamental_response,
            "visualization": self._format_visualization_response,
        }
        formatter = formatters.get(intent["type"])
        if formatter is not None:
            return formatter(data)

        return {
            "message": "查询完成",
            "metadata": {
                "type": "data",
                "data": data
            }
        }

    @staticmethod
    def _number_or_zero(data: Dict[str, Any], key: str) -> Any:
        """Return ``data[key]`` for numeric formatting; 0 if missing or None."""
        value = data.get(key)
        return 0 if value is None else value

    def _format_market_data_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Format a K-line list or a real-time quote."""
        if "error" in data:
            return {
                "message": f"查询失败:{data['error']}",
                "metadata": {"type": "error"}
            }

        if "kline_data" in data:
            kline_data = data["kline_data"]
            return {
                "message": f"已获取K线数据{len(kline_data)}条记录",
                "metadata": {
                    "type": "kline",
                    "data": kline_data
                }
            }

        # Real-time quote. None values are shown as 0 like missing keys, since
        # formatting None with ".2f" raises TypeError.
        num = self._number_or_zero
        message = "\n".join([
            f"【{data.get('name', '股票')}】({data.get('ts_code', '')})",
            f"交易日期:{data.get('trade_date', '')}",
            f"最新价:{num(data, 'close'):.2f}",
            f"涨跌额:{num(data, 'change'):.2f}",
            f"涨跌幅:{num(data, 'pct_chg'):.2f}%",
            f"开盘价:{num(data, 'open'):.2f}",
            f"最高价:{num(data, 'high'):.2f}",
            f"最低价:{num(data, 'low'):.2f}",
            f"成交量:{num(data, 'vol'):.0f}",
            f"成交额:{num(data, 'amount'):.0f}千元",
        ]).strip()
        return {
            "message": message,
            "metadata": {
                "type": "quote",
                "data": data
            }
        }

    def _format_technical_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Format MA / MACD / RSI indicator values present in ``data``."""
        if "error" in data:
            return {
                "message": f"分析失败:{data['error']}",
                "metadata": {"type": "error"}
            }

        indicators = data.get("indicators", {})
        message_parts = [f"【{data.get('stock_code', '')}】技术指标:\n"]

        if "ma" in indicators:
            ma = indicators["ma"]
            message_parts.append(f"均线MA5={ma.get('ma5')}, MA10={ma.get('ma10')}, MA20={ma.get('ma20')}")
        if "macd" in indicators:
            macd = indicators["macd"]
            message_parts.append(f"MACDDIF={macd.get('dif')}, DEA={macd.get('dea')}, MACD={macd.get('macd')}")
        if "rsi" in indicators:
            rsi = indicators["rsi"]
            message_parts.append(f"RSIRSI6={rsi.get('rsi6')}, RSI12={rsi.get('rsi12')}, RSI24={rsi.get('rsi24')}")

        return {
            "message": "\n".join(message_parts),
            "metadata": {
                "type": "technical",
                "data": data
            }
        }

    def _format_fundamental_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Format basic company information."""
        if "error" in data:
            return {
                "message": f"查询失败:{data['error']}",
                "metadata": {"type": "error"}
            }

        message = "\n".join([
            f"【{data.get('name', '股票')}】基本信息",
            f"股票代码:{data.get('ts_code', '')}",
            f"所属地域:{data.get('area', '')}",
            f"所属行业:{data.get('industry', '')}",
            f"上市市场:{data.get('market', '')}",
            f"上市日期:{data.get('list_date', '')}",
        ]).strip()
        return {
            "message": message,
            "metadata": {
                "type": "fundamental",
                "data": data
            }
        }

    def _format_visualization_response(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Format a chart-generation result."""
        if "error" in data:
            return {
                "message": f"生成图表失败:{data['error']}",
                "metadata": {"type": "error"}
            }

        return {
            "message": f"已生成{data.get('stock_code', '')}的K线图",
            "metadata": {
                "type": "chart",
                "data": data
            }
        }
# Create the global Agent instance.
# NOTE(review): this runs StockAnalysisAgent.__init__ at import time, which
# loads settings and registers all skills on the shared skill_manager.
stock_agent: StockAnalysisAgent = StockAnalysisAgent()