update
This commit is contained in:
parent
c5c88bd73e
commit
f1656e8189
@ -35,7 +35,8 @@ class ContextManager:
|
||||
for msg in messages:
|
||||
context.append({
|
||||
"role": msg.role,
|
||||
"content": msg.content
|
||||
"content": msg.content,
|
||||
"metadata": msg.metadata if hasattr(msg, 'metadata') else {}
|
||||
})
|
||||
|
||||
return context
|
||||
@ -91,3 +92,104 @@ class ContextManager:
|
||||
formatted.append(f"{role}: {msg['content']}")
|
||||
|
||||
return "\n".join(formatted)
|
||||
|
||||
def extract_context_info(self, session_id: str) -> Dict:
|
||||
"""
|
||||
提取上下文信息
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
ContextInfo: {
|
||||
'last_stock': str | None, # 上次讨论的股票
|
||||
'last_topic': str | None, # 上次的话题
|
||||
'user_preferences': dict # 用户偏好
|
||||
}
|
||||
"""
|
||||
history = self.get_context(session_id)
|
||||
|
||||
return {
|
||||
'last_stock': self._extract_last_stock(history),
|
||||
'last_topic': self._extract_last_topic(history),
|
||||
'user_preferences': self._analyze_user_preferences(history)
|
||||
}
|
||||
|
||||
def _extract_last_stock(self, history: List[Dict]) -> Optional[str]:
|
||||
"""
|
||||
从历史对话中提取最后讨论的股票
|
||||
|
||||
Args:
|
||||
history: 对话历史
|
||||
|
||||
Returns:
|
||||
股票代码或None
|
||||
"""
|
||||
# 从后往前查找
|
||||
for msg in reversed(history):
|
||||
if msg['role'] == 'assistant':
|
||||
metadata = msg.get('metadata', {})
|
||||
if isinstance(metadata, dict):
|
||||
# 尝试从不同位置提取股票代码
|
||||
if 'data' in metadata:
|
||||
data = metadata['data']
|
||||
if isinstance(data, dict):
|
||||
if 'stock_code' in data:
|
||||
return data['stock_code']
|
||||
if 'ts_code' in data:
|
||||
return data['ts_code']
|
||||
|
||||
# 尝试从intent中提取
|
||||
if 'intent' in metadata:
|
||||
intent = metadata['intent']
|
||||
if isinstance(intent, dict) and 'target' in intent:
|
||||
target = intent['target']
|
||||
if isinstance(target, dict) and 'stock_code' in target:
|
||||
return target['stock_code']
|
||||
|
||||
return None
|
||||
|
||||
def _extract_last_topic(self, history: List[Dict]) -> Optional[str]:
|
||||
"""
|
||||
从历史对话中提取最后的话题
|
||||
|
||||
Args:
|
||||
history: 对话历史
|
||||
|
||||
Returns:
|
||||
话题或None
|
||||
"""
|
||||
if not history:
|
||||
return None
|
||||
|
||||
# 获取最后一条用户消息
|
||||
for msg in reversed(history):
|
||||
if msg['role'] == 'user':
|
||||
content = msg['content']
|
||||
# 简单提取话题(前50个字符)
|
||||
return content[:50] if len(content) > 50 else content
|
||||
|
||||
return None
|
||||
|
||||
def _analyze_user_preferences(self, history: List[Dict]) -> Dict:
|
||||
"""
|
||||
分析用户偏好
|
||||
|
||||
Args:
|
||||
history: 对话历史
|
||||
|
||||
Returns:
|
||||
用户偏好字典
|
||||
"""
|
||||
preferences = {
|
||||
'preferred_style': 'casual',
|
||||
'typical_time_scope': 'short_term',
|
||||
'frequent_dimensions': []
|
||||
}
|
||||
|
||||
# 简单的偏好分析(可以后续扩展)
|
||||
if len(history) > 5:
|
||||
# 如果对话较多,可能是专业用户
|
||||
preferences['preferred_style'] = 'professional'
|
||||
|
||||
return preferences
|
||||
|
||||
305
backend/app/agent/question_analyzer.py
Normal file
305
backend/app/agent/question_analyzer.py
Normal file
@ -0,0 +1,305 @@
|
||||
"""
|
||||
问题分析器 - 使用LLM深度理解用户意图
|
||||
"""
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional, List
|
||||
from app.services.llm_service import llm_service
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class QuestionAnalyzer:
|
||||
"""智能问题分析器 - 使用LLM深度理解用户意图"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化问题分析器"""
|
||||
self.use_llm = llm_service.client is not None
|
||||
if not self.use_llm:
|
||||
logger.warning("LLM未配置,QuestionAnalyzer将使用降级模式")
|
||||
|
||||
async def analyze_question(
|
||||
self,
|
||||
question: str,
|
||||
context: List[Dict],
|
||||
session_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
深度分析用户问题
|
||||
|
||||
Args:
|
||||
question: 用户问题
|
||||
context: 对话历史上下文
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
QuestionIntent: {
|
||||
'type': 'stock_analysis' | 'market_overview' | 'knowledge' | 'chat',
|
||||
'target': {
|
||||
'stock_code': str,
|
||||
'stock_name': str,
|
||||
'market': 'A股' | '美股'
|
||||
},
|
||||
'dimensions': {
|
||||
'price_trend': bool, # 价格走势
|
||||
'technical': bool, # 技术指标
|
||||
'fundamental': bool, # 基本面
|
||||
'valuation': bool, # 估值
|
||||
'money_flow': bool, # 资金流向
|
||||
'risk': bool # 风险分析
|
||||
},
|
||||
'time_scope': {
|
||||
'short_term': bool, # 短期(1-2周)
|
||||
'medium_term': bool, # 中期(1-3月)
|
||||
'long_term': bool # 长期(半年+)
|
||||
},
|
||||
'analysis_depth': 'quick' | 'standard' | 'deep',
|
||||
'specific_concerns': List[str], # 特定关注点
|
||||
'context_references': {
|
||||
'refers_to_previous': bool,
|
||||
'comparison_target': str | None
|
||||
},
|
||||
'user_style': {
|
||||
'tone': 'professional' | 'casual',
|
||||
'detail_level': 'brief' | 'detailed'
|
||||
}
|
||||
}
|
||||
"""
|
||||
if not self.use_llm:
|
||||
# 降级模式:返回基本的意图分析
|
||||
return self._fallback_analysis(question)
|
||||
|
||||
# 构建上下文字符串
|
||||
context_str = self._format_context(context)
|
||||
|
||||
# 构建LLM prompt
|
||||
prompt = self._build_analysis_prompt(question, context_str)
|
||||
|
||||
try:
|
||||
# 异步调用LLM
|
||||
result = await self._call_llm_async(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.3,
|
||||
max_tokens=800
|
||||
)
|
||||
|
||||
if not result:
|
||||
logger.warning("LLM返回空结果,使用降级模式")
|
||||
return self._fallback_analysis(question)
|
||||
|
||||
# 清理和解析JSON
|
||||
intent = self._parse_llm_response(result)
|
||||
|
||||
if intent:
|
||||
logger.info(f"问题分析成功: type={intent.get('type')}, dimensions={intent.get('dimensions')}")
|
||||
return intent
|
||||
else:
|
||||
logger.warning("JSON解析失败,使用降级模式")
|
||||
return self._fallback_analysis(question)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"问题分析失败: {e}")
|
||||
return self._fallback_analysis(question)
|
||||
|
||||
async def _call_llm_async(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
temperature: float = 0.3,
|
||||
max_tokens: int = 800
|
||||
) -> Optional[str]:
|
||||
"""异步调用LLM"""
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(
|
||||
None,
|
||||
lambda: llm_service.chat(messages, temperature, max_tokens)
|
||||
)
|
||||
|
||||
def _format_context(self, context: List[Dict]) -> str:
|
||||
"""格式化对话历史上下文"""
|
||||
if not context:
|
||||
return ""
|
||||
|
||||
context_str = "\n\n【对话历史】\n"
|
||||
# 只取最近4条消息
|
||||
for msg in context[-4:]:
|
||||
role = "用户" if msg["role"] == "user" else "助手"
|
||||
content = msg['content'][:100] # 限制长度
|
||||
context_str += f"{role}: {content}\n"
|
||||
|
||||
return context_str
|
||||
|
||||
def _build_analysis_prompt(self, question: str, context_str: str) -> str:
|
||||
"""构建问题分析的LLM prompt"""
|
||||
prompt = f"""你是一个专业的金融问题分析专家。请深度分析用户的问题,提取结构化信息。
|
||||
|
||||
{context_str}
|
||||
|
||||
【当前问题】
|
||||
用户: {question}
|
||||
|
||||
请分析以下维度:
|
||||
|
||||
1. **问题类型**
|
||||
- stock_analysis: 针对特定股票的分析(如"贵州茅台怎么样"、"分析比亚迪"、"AAPL走势")
|
||||
- market_overview: 市场整体分析(如"最近有什么投资机会"、"现在适合买股票吗")
|
||||
- knowledge: 金融知识问答(如"什么是MACD"、"如何看K线图")
|
||||
- chat: 一般对话(如"你好"、"在吗")
|
||||
|
||||
2. **用户关注维度**(如果是stock_analysis)
|
||||
分析用户想了解哪些方面:
|
||||
- price_trend: 价格走势、涨跌情况、最新价格
|
||||
- technical: 技术指标(MACD、RSI、均线、KDJ等)
|
||||
- fundamental: 基本面(公司业务、行业地位、财务状况)
|
||||
- valuation: 估值水平(PE、PB、市值、估值是否合理)
|
||||
- money_flow: 资金流向、主力动向、大单流入流出
|
||||
- risk: 风险分析、风险提示、投资风险
|
||||
|
||||
3. **时间范围**
|
||||
- short_term: 短期(1-2周)- 如"短期走势"、"近期表现"
|
||||
- medium_term: 中期(1-3月)- 如"中期趋势"、"未来一个月"
|
||||
- long_term: 长期(半年以上)- 如"长期投资"、"适合长期持有吗"
|
||||
|
||||
4. **分析深度**
|
||||
- quick: 快速查看(只需要基本信息,如"价格多少")
|
||||
- standard: 标准分析(常规分析,如"怎么样"、"分析一下")
|
||||
- deep: 深度分析(全面详细,如"全面分析"、"深度研究")
|
||||
|
||||
5. **特定关注点**
|
||||
提取用户明确提到的关注点,如:
|
||||
- "支撑位在哪"
|
||||
- "盈利能力如何"
|
||||
- "适合长期持有吗"
|
||||
- "有没有金叉"
|
||||
|
||||
6. **上下文引用**
|
||||
- 是否引用了之前的对话("这只股票"、"它"、"那技术面呢")
|
||||
- 是否要求对比分析("和上次相比"、"对比一下")
|
||||
|
||||
7. **用户风格**
|
||||
- tone: professional(专业,使用专业术语)/ casual(随意,通俗易懂)
|
||||
- detail_level: brief(简洁,简短回答)/ detailed(详细,详细分析)
|
||||
|
||||
请以JSON格式返回分析结果:
|
||||
{{
|
||||
"type": "问题类型",
|
||||
"target": {{
|
||||
"stock_code": "股票代码(如有,只返回纯数字代码,如600519或002594,不要包含市场标识)",
|
||||
"stock_name": "股票名称(如有,只返回公司名称,如贵州茅台或比亚迪)",
|
||||
"market": "A股/美股"
|
||||
}},
|
||||
"dimensions": {{
|
||||
"price_trend": true/false,
|
||||
"technical": true/false,
|
||||
"fundamental": true/false,
|
||||
"valuation": true/false,
|
||||
"money_flow": true/false,
|
||||
"risk": true/false
|
||||
}},
|
||||
"time_scope": {{
|
||||
"short_term": true/false,
|
||||
"medium_term": true/false,
|
||||
"long_term": true/false
|
||||
}},
|
||||
"analysis_depth": "quick/standard/deep",
|
||||
"specific_concerns": ["关注点1", "关注点2"],
|
||||
"context_references": {{
|
||||
"refers_to_previous": true/false,
|
||||
"comparison_target": "对比目标(如有)"
|
||||
}},
|
||||
"user_style": {{
|
||||
"tone": "professional/casual",
|
||||
"detail_level": "brief/detailed"
|
||||
}}
|
||||
}}
|
||||
|
||||
只返回JSON,不要有任何其他内容。"""
|
||||
|
||||
return prompt
|
||||
|
||||
def _parse_llm_response(self, response: str) -> Optional[Dict[str, Any]]:
|
||||
"""解析LLM返回的JSON响应"""
|
||||
try:
|
||||
# 清理结果,移除可能的markdown代码块标记
|
||||
result = response.strip()
|
||||
if result.startswith("```json"):
|
||||
result = result[7:]
|
||||
if result.startswith("```"):
|
||||
result = result[3:]
|
||||
if result.endswith("```"):
|
||||
result = result[:-3]
|
||||
result = result.strip()
|
||||
|
||||
# 检查是否为空
|
||||
if not result:
|
||||
return None
|
||||
|
||||
# 解析JSON
|
||||
intent = json.loads(result)
|
||||
return intent
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"JSON解析失败: {e}, 原始响应: {response[:200]}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"解析LLM响应失败: {e}")
|
||||
return None
|
||||
|
||||
def _fallback_analysis(self, question: str) -> Dict[str, Any]:
|
||||
"""降级模式:基于规则的简单分析"""
|
||||
question_lower = question.lower()
|
||||
|
||||
# 简单的关键词匹配
|
||||
is_stock_query = any(kw in question for kw in [
|
||||
"股票", "分析", "怎么样", "如何", "走势", "价格", "涨", "跌"
|
||||
])
|
||||
|
||||
if is_stock_query:
|
||||
# 尝试提取股票名称(简单规则)
|
||||
return {
|
||||
'type': 'stock_analysis',
|
||||
'target': {
|
||||
'stock_code': '',
|
||||
'stock_name': '',
|
||||
'market': 'A股'
|
||||
},
|
||||
'dimensions': {
|
||||
'price_trend': True,
|
||||
'technical': True,
|
||||
'fundamental': True,
|
||||
'valuation': False,
|
||||
'money_flow': False,
|
||||
'risk': False
|
||||
},
|
||||
'time_scope': {
|
||||
'short_term': True,
|
||||
'medium_term': True,
|
||||
'long_term': False
|
||||
},
|
||||
'analysis_depth': 'standard',
|
||||
'specific_concerns': [],
|
||||
'context_references': {
|
||||
'refers_to_previous': False,
|
||||
'comparison_target': None
|
||||
},
|
||||
'user_style': {
|
||||
'tone': 'casual',
|
||||
'detail_level': 'detailed'
|
||||
}
|
||||
}
|
||||
else:
|
||||
# 默认为一般对话
|
||||
return {
|
||||
'type': 'chat',
|
||||
'target': {},
|
||||
'dimensions': {},
|
||||
'time_scope': {},
|
||||
'analysis_depth': 'quick',
|
||||
'specific_concerns': [],
|
||||
'context_references': {
|
||||
'refers_to_previous': False,
|
||||
'comparison_target': None
|
||||
},
|
||||
'user_style': {
|
||||
'tone': 'casual',
|
||||
'detail_level': 'brief'
|
||||
}
|
||||
}
|
||||
@ -2,7 +2,8 @@
|
||||
技能管理器
|
||||
管理所有技能的注册、发现和调用
|
||||
"""
|
||||
from typing import Dict, Optional, List, Type
|
||||
import asyncio
|
||||
from typing import Dict, Optional, List, Type, Any
|
||||
from app.skills.base import BaseSkill
|
||||
from app.utils.logger import logger
|
||||
|
||||
@ -174,6 +175,158 @@ class SkillManager:
|
||||
"""
|
||||
return [skill.get_info() for skill in self._skills.values()]
|
||||
|
||||
async def execute_plan(
|
||||
self,
|
||||
plan: Dict[str, Any],
|
||||
stock_code: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
执行技能规划
|
||||
|
||||
Args:
|
||||
plan: 技能执行计划(来自SkillPlanner)
|
||||
stock_code: 股票代码
|
||||
|
||||
Returns:
|
||||
{
|
||||
'results': {
|
||||
'market_data': {...},
|
||||
'technical_analysis': {...},
|
||||
...
|
||||
},
|
||||
'execution_time': float,
|
||||
'errors': List[str]
|
||||
}
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
skills = plan.get('skills', [])
|
||||
strategy = plan.get('execution_strategy', 'parallel')
|
||||
|
||||
logger.info(f"开始执行技能规划: {len(skills)}个技能, 策略: {strategy}")
|
||||
|
||||
if strategy == 'parallel':
|
||||
results = await self._execute_parallel(skills, stock_code)
|
||||
else:
|
||||
results = await self._execute_sequential(skills, stock_code)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
logger.info(f"技能规划执行完成,耗时: {execution_time:.2f}秒")
|
||||
|
||||
return {
|
||||
'results': results['results'],
|
||||
'execution_time': execution_time,
|
||||
'errors': results['errors']
|
||||
}
|
||||
|
||||
async def _execute_parallel(
|
||||
self,
|
||||
skills: List[Dict[str, Any]],
|
||||
stock_code: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
并行执行技能(按优先级分组)
|
||||
|
||||
Args:
|
||||
skills: 技能列表
|
||||
stock_code: 股票代码
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
# 按优先级分组
|
||||
priority_groups = {}
|
||||
for skill_info in skills:
|
||||
priority = skill_info['priority']
|
||||
if priority not in priority_groups:
|
||||
priority_groups[priority] = []
|
||||
priority_groups[priority].append(skill_info)
|
||||
|
||||
all_results = {}
|
||||
all_errors = []
|
||||
|
||||
# 按优先级顺序执行
|
||||
for priority in sorted(priority_groups.keys()):
|
||||
skill_group = priority_groups[priority]
|
||||
logger.info(f"执行优先级 {priority} 的技能: {[s['name'] for s in skill_group]}")
|
||||
|
||||
# 同一优先级的技能并行执行
|
||||
tasks = []
|
||||
for skill_info in skill_group:
|
||||
params = skill_info['params'].copy()
|
||||
params['stock_code'] = stock_code
|
||||
task = self.execute_skill(skill_info['name'], **params)
|
||||
tasks.append((skill_info['name'], task))
|
||||
|
||||
# 等待所有任务完成
|
||||
results = await asyncio.gather(*[task for _, task in tasks], return_exceptions=True)
|
||||
|
||||
# 处理结果
|
||||
for (skill_name, _), result in zip(tasks, results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"技能执行失败: {skill_name}, {result}")
|
||||
all_errors.append(f"{skill_name}: {str(result)}")
|
||||
all_results[skill_name] = {'error': str(result)}
|
||||
elif result.get('success'):
|
||||
all_results[skill_name] = result.get('data', {})
|
||||
else:
|
||||
error_msg = result.get('error', '未知错误')
|
||||
logger.error(f"技能执行失败: {skill_name}, {error_msg}")
|
||||
all_errors.append(f"{skill_name}: {error_msg}")
|
||||
all_results[skill_name] = {'error': error_msg}
|
||||
|
||||
return {
|
||||
'results': all_results,
|
||||
'errors': all_errors
|
||||
}
|
||||
|
||||
async def _execute_sequential(
|
||||
self,
|
||||
skills: List[Dict[str, Any]],
|
||||
stock_code: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
串行执行技能
|
||||
|
||||
Args:
|
||||
skills: 技能列表
|
||||
stock_code: 股票代码
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
all_results = {}
|
||||
all_errors = []
|
||||
|
||||
for skill_info in skills:
|
||||
skill_name = skill_info['name']
|
||||
params = skill_info['params'].copy()
|
||||
params['stock_code'] = stock_code
|
||||
|
||||
logger.info(f"执行技能: {skill_name}")
|
||||
|
||||
try:
|
||||
result = await self.execute_skill(skill_name, **params)
|
||||
|
||||
if result.get('success'):
|
||||
all_results[skill_name] = result.get('data', {})
|
||||
else:
|
||||
error_msg = result.get('error', '未知错误')
|
||||
logger.error(f"技能执行失败: {skill_name}, {error_msg}")
|
||||
all_errors.append(f"{skill_name}: {error_msg}")
|
||||
all_results[skill_name] = {'error': error_msg}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"技能执行异常: {skill_name}, {e}")
|
||||
all_errors.append(f"{skill_name}: {str(e)}")
|
||||
all_results[skill_name] = {'error': str(e)}
|
||||
|
||||
return {
|
||||
'results': all_results,
|
||||
'errors': all_errors
|
||||
}
|
||||
|
||||
|
||||
# 创建全局技能管理器实例
|
||||
skill_manager = SkillManager()
|
||||
|
||||
256
backend/app/agent/skill_planner.py
Normal file
256
backend/app/agent/skill_planner.py
Normal file
@ -0,0 +1,256 @@
|
||||
"""
|
||||
技能规划器 - 根据用户意图智能选择技能组合
|
||||
"""
|
||||
from typing import Dict, Any, List, Set
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class SkillPlanner:
|
||||
"""智能技能规划器 - 根据问题意图动态选择技能"""
|
||||
|
||||
# 维度到技能的映射
|
||||
DIMENSION_SKILL_MAP = {
|
||||
'price_trend': {
|
||||
'required': ['market_data'],
|
||||
'optional': []
|
||||
},
|
||||
'technical': {
|
||||
'required': ['market_data', 'technical_analysis'],
|
||||
'optional': ['visualization']
|
||||
},
|
||||
'fundamental': {
|
||||
'required': ['fundamental'],
|
||||
'optional': []
|
||||
},
|
||||
'valuation': {
|
||||
'required': ['advanced_data'],
|
||||
'optional': []
|
||||
},
|
||||
'money_flow': {
|
||||
'required': ['advanced_data'],
|
||||
'optional': []
|
||||
},
|
||||
'risk': {
|
||||
'required': ['technical_analysis', 'advanced_data'],
|
||||
'optional': []
|
||||
}
|
||||
}
|
||||
|
||||
# 技能依赖关系
|
||||
SKILL_DEPENDENCIES = {
|
||||
'technical_analysis': ['market_data'], # 技术分析依赖行情数据
|
||||
'visualization': ['market_data'], # 可视化依赖行情数据
|
||||
}
|
||||
|
||||
# 技能优先级(数字越小优先级越高)
|
||||
SKILL_PRIORITY = {
|
||||
'market_data': 1, # 最高优先级
|
||||
'fundamental': 1,
|
||||
'technical_analysis': 2,
|
||||
'advanced_data': 2,
|
||||
'visualization': 3, # 最低优先级
|
||||
'us_stock_analysis': 1
|
||||
}
|
||||
|
||||
# 分析深度策略
|
||||
DEPTH_STRATEGY = {
|
||||
'quick': {
|
||||
'max_skills': 2,
|
||||
'include_optional': False,
|
||||
'use_cache': True
|
||||
},
|
||||
'standard': {
|
||||
'max_skills': 4,
|
||||
'include_optional': True,
|
||||
'use_cache': True
|
||||
},
|
||||
'deep': {
|
||||
'max_skills': None, # 无限制
|
||||
'include_optional': True,
|
||||
'use_cache': False
|
||||
}
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
"""初始化技能规划器"""
|
||||
logger.info("技能规划器初始化")
|
||||
|
||||
def plan_skills(self, intent: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
根据意图规划技能执行
|
||||
|
||||
Args:
|
||||
intent: 问题意图(来自QuestionAnalyzer)
|
||||
|
||||
Returns:
|
||||
SkillExecutionPlan: {
|
||||
'skills': [
|
||||
{
|
||||
'name': 'market_data',
|
||||
'params': {...},
|
||||
'priority': 1,
|
||||
'required': True,
|
||||
'reason': '用户关注价格走势'
|
||||
},
|
||||
...
|
||||
],
|
||||
'execution_strategy': 'parallel' | 'sequential',
|
||||
'cache_strategy': 'use' | 'bypass'
|
||||
}
|
||||
"""
|
||||
# 1. 根据维度映射技能
|
||||
skills = self._map_dimensions_to_skills(intent.get('dimensions', {}))
|
||||
|
||||
# 2. 根据分析深度调整
|
||||
depth = intent.get('analysis_depth', 'standard')
|
||||
skills = self._apply_depth_strategy(skills, depth)
|
||||
|
||||
# 3. 解析依赖关系
|
||||
skills = self._resolve_dependencies(skills)
|
||||
|
||||
# 4. 去重
|
||||
skills = list(set(skills))
|
||||
|
||||
# 5. 排序(按优先级)
|
||||
sorted_skills = self._sort_by_priority(skills)
|
||||
|
||||
# 6. 构建执行计划
|
||||
plan = {
|
||||
'skills': [
|
||||
{
|
||||
'name': skill,
|
||||
'params': self._get_skill_params(skill, intent),
|
||||
'priority': self.SKILL_PRIORITY.get(skill, 5),
|
||||
'required': True,
|
||||
'reason': self._get_skill_reason(skill, intent)
|
||||
}
|
||||
for skill in sorted_skills
|
||||
],
|
||||
'execution_strategy': self._determine_strategy(sorted_skills),
|
||||
'cache_strategy': 'use' if self.DEPTH_STRATEGY[depth]['use_cache'] else 'bypass'
|
||||
}
|
||||
|
||||
logger.info(f"技能规划完成: {[s['name'] for s in plan['skills']]}, 策略: {plan['execution_strategy']}")
|
||||
return plan
|
||||
|
||||
def _map_dimensions_to_skills(self, dimensions: Dict[str, bool]) -> List[str]:
|
||||
"""将用户关注维度映射到技能"""
|
||||
skills = []
|
||||
|
||||
for dimension, enabled in dimensions.items():
|
||||
if enabled and dimension in self.DIMENSION_SKILL_MAP:
|
||||
mapping = self.DIMENSION_SKILL_MAP[dimension]
|
||||
skills.extend(mapping['required'])
|
||||
# 可选技能稍后根据深度策略添加
|
||||
|
||||
return skills
|
||||
|
||||
def _apply_depth_strategy(self, skills: List[str], depth: str) -> List[str]:
|
||||
"""根据分析深度调整技能列表"""
|
||||
strategy = self.DEPTH_STRATEGY.get(depth, self.DEPTH_STRATEGY['standard'])
|
||||
|
||||
# 如果有最大技能数限制
|
||||
if strategy['max_skills'] is not None and len(skills) > strategy['max_skills']:
|
||||
# 按优先级保留前N个
|
||||
sorted_skills = self._sort_by_priority(skills)
|
||||
skills = sorted_skills[:strategy['max_skills']]
|
||||
|
||||
return skills
|
||||
|
||||
def _resolve_dependencies(self, skills: List[str]) -> List[str]:
|
||||
"""解析技能依赖关系,自动添加依赖的技能"""
|
||||
resolved_skills = set(skills)
|
||||
|
||||
for skill in skills:
|
||||
if skill in self.SKILL_DEPENDENCIES:
|
||||
dependencies = self.SKILL_DEPENDENCIES[skill]
|
||||
resolved_skills.update(dependencies)
|
||||
|
||||
return list(resolved_skills)
|
||||
|
||||
def _sort_by_priority(self, skills: List[str]) -> List[str]:
|
||||
"""按优先级排序技能"""
|
||||
return sorted(skills, key=lambda s: self.SKILL_PRIORITY.get(s, 999))
|
||||
|
||||
def _determine_strategy(self, skills: List[str]) -> str:
|
||||
"""确定执行策略(并行/串行)"""
|
||||
# 如果技能数量少于等于3,使用并行
|
||||
if len(skills) <= 3:
|
||||
return 'parallel'
|
||||
else:
|
||||
# 技能较多时,按优先级分组并行
|
||||
return 'parallel'
|
||||
|
||||
def _get_skill_params(self, skill_name: str, intent: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""获取技能执行参数"""
|
||||
params = {}
|
||||
|
||||
if skill_name == 'market_data':
|
||||
params['data_type'] = 'quote'
|
||||
|
||||
elif skill_name == 'technical_analysis':
|
||||
# 根据用户关注点决定指标
|
||||
indicators = ['ma', 'macd']
|
||||
specific_concerns = intent.get('specific_concerns', [])
|
||||
|
||||
if any('rsi' in concern.lower() for concern in specific_concerns):
|
||||
indicators.append('rsi')
|
||||
if any('kdj' in concern.lower() for concern in specific_concerns):
|
||||
indicators.append('kdj')
|
||||
if any('布林' in concern or 'boll' in concern.lower() for concern in specific_concerns):
|
||||
indicators.append('boll')
|
||||
|
||||
params['indicators'] = indicators
|
||||
|
||||
elif skill_name == 'advanced_data':
|
||||
# 根据维度决定数据类型
|
||||
data_types = []
|
||||
dimensions = intent.get('dimensions', {})
|
||||
|
||||
if dimensions.get('valuation'):
|
||||
data_types.append('valuation')
|
||||
if dimensions.get('money_flow'):
|
||||
data_types.append('money_flow')
|
||||
|
||||
if not data_types:
|
||||
data_types = ['valuation', 'money_flow']
|
||||
|
||||
params['data_types'] = data_types
|
||||
|
||||
return params
|
||||
|
||||
def _get_skill_reason(self, skill_name: str, intent: Dict[str, Any]) -> str:
|
||||
"""获取调用该技能的原因"""
|
||||
dimensions = intent.get('dimensions', {})
|
||||
reasons = []
|
||||
|
||||
if skill_name == 'market_data':
|
||||
if dimensions.get('price_trend'):
|
||||
reasons.append('用户关注价格走势')
|
||||
else:
|
||||
reasons.append('获取基础行情数据')
|
||||
|
||||
elif skill_name == 'technical_analysis':
|
||||
if dimensions.get('technical'):
|
||||
reasons.append('用户关注技术指标')
|
||||
else:
|
||||
reasons.append('提供技术面分析')
|
||||
|
||||
elif skill_name == 'fundamental':
|
||||
if dimensions.get('fundamental'):
|
||||
reasons.append('用户关注基本面')
|
||||
else:
|
||||
reasons.append('提供公司基本信息')
|
||||
|
||||
elif skill_name == 'advanced_data':
|
||||
if dimensions.get('valuation'):
|
||||
reasons.append('用户关注估值')
|
||||
if dimensions.get('money_flow'):
|
||||
reasons.append('用户关注资金流向')
|
||||
if not reasons:
|
||||
reasons.append('提供高级财务数据')
|
||||
|
||||
elif skill_name == 'visualization':
|
||||
reasons.append('生成K线图表')
|
||||
|
||||
return ', '.join(reasons) if reasons else '提供分析数据'
|
||||
@ -8,6 +8,8 @@ from typing import Dict, Any, Optional, List
|
||||
from app.config import get_settings
|
||||
from app.agent.context import ContextManager
|
||||
from app.agent.skill_manager import skill_manager
|
||||
from app.agent.question_analyzer import QuestionAnalyzer
|
||||
from app.agent.skill_planner import SkillPlanner
|
||||
from app.skills.market_data import MarketDataSkill
|
||||
from app.skills.technical_analysis import TechnicalAnalysisSkill
|
||||
from app.skills.fundamental import FundamentalSkill
|
||||
@ -27,6 +29,10 @@ class SmartStockAgent:
|
||||
self.context_manager = ContextManager()
|
||||
self.settings = get_settings()
|
||||
|
||||
# 初始化智能组件
|
||||
self.question_analyzer = QuestionAnalyzer()
|
||||
self.skill_planner = SkillPlanner()
|
||||
|
||||
# 注册技能
|
||||
self._register_skills()
|
||||
|
||||
@ -34,7 +40,7 @@ class SmartStockAgent:
|
||||
self.use_llm = bool(self.settings.zhipuai_api_key) and llm_service.client is not None
|
||||
|
||||
if self.use_llm:
|
||||
logger.info("Smart Agent初始化完成(LLM深度集成模式 + Tushare Pro高级数据)")
|
||||
logger.info("Smart Agent初始化完成(智能模式 + LLM深度集成 + Tushare Pro高级数据)")
|
||||
else:
|
||||
logger.warning("Smart Agent初始化完成(规则模式,建议配置LLM)")
|
||||
|
||||
@ -63,82 +69,19 @@ class SmartStockAgent:
|
||||
user_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
处理用户消息(智能版)
|
||||
处理用户消息(非流式,已废弃,保留用于兼容)
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
session_id: 会话ID
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
响应结果
|
||||
实际使用 process_message_stream 进行流式输出
|
||||
"""
|
||||
logger.info(f"处理消息: {message[:50]}...")
|
||||
# 收集流式输出
|
||||
full_response = ""
|
||||
async for chunk in self.process_message_stream(message, session_id, user_id):
|
||||
full_response += chunk
|
||||
|
||||
# 保存用户消息
|
||||
self.context_manager.add_message(session_id, "user", message)
|
||||
|
||||
# 第一步:使用LLM理解问题意图
|
||||
intent_analysis = await self._analyze_question_intent(message, session_id)
|
||||
|
||||
if not intent_analysis:
|
||||
# 不直接说"无法理解",而是引导用户
|
||||
response = {
|
||||
"message": """我是您的金融智能助手,可以帮您:
|
||||
|
||||
📊 **股票分析** - 分析个股走势、技术指标、基本面
|
||||
📈 **市场观察** - 解读大盘走势、行业热点
|
||||
📚 **知识问答** - 解答金融投资相关问题
|
||||
|
||||
您可以这样问我:
|
||||
• "分析一下贵州茅台"
|
||||
• "现在A股市场怎么样"
|
||||
• "什么是MACD指标"
|
||||
|
||||
请告诉我您想了解什么?""",
|
||||
"metadata": {"type": "guide"}
|
||||
return {
|
||||
"message": full_response,
|
||||
"metadata": {"type": "text"}
|
||||
}
|
||||
self.context_manager.add_message(session_id, "assistant", response["message"])
|
||||
return response
|
||||
|
||||
# 第二步:根据意图类型处理
|
||||
question_type = intent_analysis['type']
|
||||
|
||||
if question_type == 'stock_specific':
|
||||
# 针对特定股票的问题
|
||||
response = await self._handle_stock_question(intent_analysis, message)
|
||||
elif question_type == 'macro_finance':
|
||||
# 宏观金融问题
|
||||
response = await self._handle_macro_question(intent_analysis, message)
|
||||
elif question_type == 'knowledge':
|
||||
# 金融知识问答
|
||||
response = await self._handle_knowledge_question(intent_analysis, message)
|
||||
elif question_type == 'general_chat':
|
||||
# 一般对话,引导用户
|
||||
response = await self._handle_general_chat(intent_analysis, message)
|
||||
else:
|
||||
# 未知类型,智能引导
|
||||
response = {
|
||||
"message": f"""我理解您想了解:{intent_analysis.get('description', '相关信息')}
|
||||
|
||||
作为金融智能助手,我擅长:
|
||||
• 📊 分析具体股票(如"分析比亚迪")
|
||||
• 📈 解读市场走势(如"现在大盘怎么样")
|
||||
• 📚 解答金融知识(如"什么是市盈率")
|
||||
|
||||
能否更具体地告诉我您想了解什么?""",
|
||||
"metadata": {"type": "guide"}
|
||||
}
|
||||
|
||||
# 保存助手响应
|
||||
self.context_manager.add_message(
|
||||
session_id,
|
||||
"assistant",
|
||||
response["message"],
|
||||
metadata=response.get("metadata")
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _is_comprehensive_analysis(self, message: str) -> bool:
|
||||
"""
|
||||
@ -1732,7 +1675,7 @@ RSI:{technical.get('rsi', 0):.2f if technical.get('rsi') else '计算中'}
|
||||
user_id: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
流式处理用户消息
|
||||
流式处理用户消息(智能模式)
|
||||
|
||||
Args:
|
||||
message: 用户消息
|
||||
@ -1742,62 +1685,52 @@ RSI:{technical.get('rsi', 0):.2f if technical.get('rsi') else '计算中'}
|
||||
Yields:
|
||||
响应文本片段
|
||||
"""
|
||||
logger.info(f"流式处理消息: {message[:50]}...")
|
||||
logger.info(f"[智能模式-流式] 处理消息: {message[:50]}...")
|
||||
|
||||
# 保存用户消息
|
||||
# 1. 保存用户消息
|
||||
self.context_manager.add_message(session_id, "user", message)
|
||||
|
||||
# 第一步:使用LLM理解问题意图
|
||||
intent_analysis = await self._analyze_question_intent(message, session_id)
|
||||
# 2. 提取上下文信息
|
||||
context_info = self.context_manager.extract_context_info(session_id)
|
||||
logger.info(f"[智能模式-流式] 上下文信息: last_stock={context_info.get('last_stock')}")
|
||||
|
||||
if not intent_analysis:
|
||||
# 引导用户
|
||||
guide_message = """我是您的金融智能助手,可以帮您:
|
||||
# 3. 深度问题分析
|
||||
intent = await self.question_analyzer.analyze_question(
|
||||
question=message,
|
||||
context=self.context_manager.get_context(session_id),
|
||||
session_id=session_id
|
||||
)
|
||||
|
||||
📊 **股票分析** - 分析个股走势、技术指标、基本面
|
||||
📈 **市场观察** - 解读大盘走势、行业热点
|
||||
📚 **知识问答** - 解答金融投资相关问题
|
||||
logger.info(f"[智能模式-流式] 问题分析: type={intent.get('type')}, dimensions={intent.get('dimensions')}")
|
||||
|
||||
您可以这样问我:
|
||||
• "分析一下贵州茅台"
|
||||
• "现在A股市场怎么样"
|
||||
• "什么是MACD指标"
|
||||
|
||||
请告诉我您想了解什么?"""
|
||||
self.context_manager.add_message(session_id, "assistant", guide_message)
|
||||
yield guide_message
|
||||
return
|
||||
|
||||
# 第二步:根据意图类型处理(流式)
|
||||
question_type = intent_analysis['type']
|
||||
# 4. 处理上下文引用(代词解析)
|
||||
if intent.get('context_references', {}).get('refers_to_previous'):
|
||||
intent = self._resolve_context_references(intent, context_info)
|
||||
logger.info(f"[智能模式-流式] 上下文解析后: target={intent.get('target')}")
|
||||
|
||||
# 5. 根据问题类型分发(流式)
|
||||
full_response = ""
|
||||
if question_type == 'stock_specific':
|
||||
# 针对特定股票的问题 - 流式输出
|
||||
async for chunk in self._handle_stock_question_stream(intent_analysis, message):
|
||||
if intent['type'] == 'stock_analysis':
|
||||
async for chunk in self._handle_stock_analysis_stream(intent, message):
|
||||
full_response += chunk
|
||||
yield chunk
|
||||
elif question_type in ['macro_finance', 'knowledge', 'general_chat']:
|
||||
# 其他类型 - 直接流式输出(不需要特殊处理)
|
||||
response = await self._handle_other_question(question_type, intent_analysis, message)
|
||||
elif intent['type'] == 'market_overview':
|
||||
response = await self._handle_macro_question(intent, message)
|
||||
full_response = response["message"]
|
||||
for char in full_response:
|
||||
yield char
|
||||
elif intent['type'] == 'knowledge':
|
||||
response = await self._handle_knowledge_question(intent, message)
|
||||
full_response = response["message"]
|
||||
# 逐字yield
|
||||
for char in full_response:
|
||||
yield char
|
||||
else:
|
||||
# 未知类型
|
||||
guide_message = f"""我理解您想了解:{intent_analysis.get('description', '相关信息')}
|
||||
response = await self._handle_general_chat(intent, message)
|
||||
full_response = response["message"]
|
||||
for char in full_response:
|
||||
yield char
|
||||
|
||||
作为金融智能助手,我擅长:
|
||||
• 📊 分析具体股票(如"分析比亚迪")
|
||||
• 📈 解读市场走势(如"现在大盘怎么样")
|
||||
• 📚 解答金融知识(如"什么是市盈率")
|
||||
|
||||
能否更具体地告诉我您想了解什么?"""
|
||||
full_response = guide_message
|
||||
yield guide_message
|
||||
|
||||
# 保存助手响应
|
||||
# 6. 保存助手响应
|
||||
self.context_manager.add_message(session_id, "assistant", full_response)
|
||||
|
||||
async def _handle_other_question(
|
||||
@ -1882,9 +1815,9 @@ RSI:{technical.get('rsi', 0):.2f if technical.get('rsi') else '计算中'}
|
||||
yield f"分析{stock_name}时出错:{str(e)}"
|
||||
|
||||
async def _handle_us_stock_stream(self, keyword: str, message: str):
|
||||
"""流式处理美股分析"""
|
||||
"""流式处理美股分析(智能模式)"""
|
||||
symbol = self._get_us_stock_symbol(keyword)
|
||||
logger.info(f"流式处理美股查询: {keyword} -> {symbol}")
|
||||
logger.info(f"[智能模式-流式] 美股查询: {keyword} -> {symbol}")
|
||||
|
||||
try:
|
||||
result = await skill_manager.execute_skill("us_stock_analysis", symbol=symbol, analysis_type="comprehensive")
|
||||
@ -1893,9 +1826,20 @@ RSI:{technical.get('rsi', 0):.2f if technical.get('rsi') else '计算中'}
|
||||
yield f"抱歉,未找到美股 {symbol}。请确认股票代码是否正确。"
|
||||
return
|
||||
|
||||
# 使用LLM流式分析
|
||||
# 使用智能模式的动态prompt生成
|
||||
if self.use_llm:
|
||||
async for chunk in self._llm_us_stock_analysis_stream(result["data"], message):
|
||||
# 构建美股数据的动态prompt
|
||||
us_data = result["data"]
|
||||
prompt = self._build_us_stock_dynamic_prompt(us_data, symbol, message)
|
||||
|
||||
# 流式生成
|
||||
stream = llm_service.chat_stream(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.7,
|
||||
max_tokens=2000
|
||||
)
|
||||
|
||||
for chunk in stream:
|
||||
yield chunk
|
||||
else:
|
||||
yield self._format_us_stock_data(result["data"])
|
||||
@ -1904,6 +1848,71 @@ RSI:{technical.get('rsi', 0):.2f if technical.get('rsi') else '计算中'}
|
||||
logger.error(f"美股查询失败: {e}")
|
||||
yield f"查询美股 {symbol} 时出错:{str(e)}"
|
||||
|
||||
def _build_us_stock_dynamic_prompt(
|
||||
self,
|
||||
data: Dict[str, Any],
|
||||
symbol: str,
|
||||
user_message: str
|
||||
) -> str:
|
||||
"""
|
||||
为美股构建动态prompt
|
||||
|
||||
Args:
|
||||
data: 美股数据
|
||||
symbol: 股票代码
|
||||
user_message: 用户消息
|
||||
|
||||
Returns:
|
||||
prompt字符串
|
||||
"""
|
||||
# 提取数据
|
||||
name = data.get('name', symbol)
|
||||
current_price = data.get('current_price', 0)
|
||||
change = data.get('change', 0)
|
||||
change_percent = data.get('change_percent', 0)
|
||||
volume = data.get('volume', 0)
|
||||
market_cap = data.get('market_cap', 0)
|
||||
pe_ratio = data.get('pe_ratio', 0)
|
||||
|
||||
# 技术指标
|
||||
technical = data.get('technical_indicators', {})
|
||||
ma5 = technical.get('ma5', 0)
|
||||
ma10 = technical.get('ma10', 0)
|
||||
ma20 = technical.get('ma20', 0)
|
||||
rsi = technical.get('rsi', 0)
|
||||
macd = technical.get('macd', 0)
|
||||
|
||||
prompt = f"""你是一个专业的美股分析师。请根据以下数据分析【{name}({symbol})】。
|
||||
|
||||
**用户问题**: {user_message}
|
||||
|
||||
## 数据信息
|
||||
|
||||
**行情数据**:
|
||||
- 最新价: ${current_price:.2f}
|
||||
- 涨跌: ${change:+.2f} ({change_percent:+.2f}%)
|
||||
- 成交量: {volume:,.0f}
|
||||
- 市值: ${market_cap:,.0f}
|
||||
- 市盈率(PE): {pe_ratio:.2f}
|
||||
|
||||
**技术指标**:
|
||||
- 均线: MA5=${ma5:.2f}, MA10=${ma10:.2f}, MA20=${ma20:.2f}
|
||||
- RSI: {rsi:.2f}
|
||||
- MACD: {macd:.4f}
|
||||
|
||||
## 分析要求
|
||||
|
||||
请根据用户的问题,提供自然、有针对性的分析。不要使用固定格式,而是像专业分析师一样,用自然的语言回答用户的问题。
|
||||
|
||||
- 如果用户关注价格走势,重点分析价格和趋势
|
||||
- 如果用户关注技术指标,重点分析技术面
|
||||
- 如果用户关注基本面,重点分析公司情况和估值
|
||||
|
||||
请直接开始分析,不要添加日期标题。最后声明:"以上分析仅供参考,不构成投资建议。美股投资有风险,请谨慎决策。"
|
||||
"""
|
||||
|
||||
return prompt
|
||||
|
||||
async def _llm_comprehensive_analysis_stream(self, data: Dict[str, Any], user_message: str, is_index: bool = False):
|
||||
"""使用LLM流式进行综合分析"""
|
||||
from datetime import datetime
|
||||
@ -2122,5 +2131,449 @@ MACD:{f"{technical.get('macd'):.4f}" if technical.get('macd') else '计算中'
|
||||
yield chunk
|
||||
|
||||
|
||||
# ==================== 新增:智能模式方法 ====================
|
||||
|
||||
async def _handle_stock_analysis_v2(
|
||||
self,
|
||||
intent: Dict[str, Any],
|
||||
message: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
处理股票分析请求(智能模式 V2)
|
||||
|
||||
Args:
|
||||
intent: 问题意图
|
||||
message: 用户消息
|
||||
|
||||
Returns:
|
||||
响应结果
|
||||
"""
|
||||
target = intent.get('target', {})
|
||||
stock_code = target.get('stock_code')
|
||||
stock_name = target.get('stock_name')
|
||||
market = target.get('market', 'A股')
|
||||
|
||||
# 如果没有股票代码,尝试匹配
|
||||
if not stock_code and stock_name:
|
||||
# 检测是否为美股
|
||||
is_us_stock = self._is_us_stock(stock_name, market)
|
||||
|
||||
if is_us_stock:
|
||||
return await self._handle_us_stock(stock_name, message)
|
||||
|
||||
# A股匹配
|
||||
stock_info = await self._match_stock_with_llm(stock_name)
|
||||
if not stock_info:
|
||||
return {
|
||||
"message": f"抱歉,未找到股票\"{stock_name}\"。请确认名称或代码是否正确。",
|
||||
"metadata": {"type": "error"}
|
||||
}
|
||||
|
||||
stock_code = stock_info['code']
|
||||
stock_name = stock_info['name']
|
||||
|
||||
if not stock_code:
|
||||
return {
|
||||
"message": "抱歉,我没有识别到您提到的股票。请提供更明确的股票代码或名称。",
|
||||
"metadata": {"type": "error"}
|
||||
}
|
||||
|
||||
logger.info(f"[智能模式] 分析股票: {stock_name}({stock_code})")
|
||||
|
||||
# 1. 技能规划
|
||||
plan = self.skill_planner.plan_skills(intent)
|
||||
logger.info(f"[智能模式] 技能规划: {[s['name'] for s in plan['skills']]}")
|
||||
|
||||
# 2. 执行技能
|
||||
execution_results = await skill_manager.execute_plan(
|
||||
plan=plan,
|
||||
stock_code=stock_code
|
||||
)
|
||||
|
||||
if execution_results['errors']:
|
||||
logger.warning(f"[智能模式] 技能执行有错误: {execution_results['errors']}")
|
||||
|
||||
# 3. 智能生成回答
|
||||
analysis = await self._generate_intelligent_response(
|
||||
intent=intent,
|
||||
execution_results=execution_results['results'],
|
||||
stock_code=stock_code,
|
||||
stock_name=stock_name,
|
||||
user_message=message
|
||||
)
|
||||
|
||||
return {
|
||||
"message": analysis,
|
||||
"metadata": {
|
||||
"type": "stock_analysis",
|
||||
"intent": intent,
|
||||
"plan": plan,
|
||||
"data": {
|
||||
"stock_code": stock_code,
|
||||
"stock_name": stock_name,
|
||||
**execution_results['results']
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async def _generate_intelligent_response(
|
||||
self,
|
||||
intent: Dict[str, Any],
|
||||
execution_results: Dict[str, Any],
|
||||
stock_code: str,
|
||||
stock_name: str,
|
||||
user_message: str
|
||||
) -> str:
|
||||
"""
|
||||
智能生成回答 - 根据用户意图定制
|
||||
|
||||
Args:
|
||||
intent: 问题意图
|
||||
execution_results: 技能执行结果
|
||||
stock_code: 股票代码
|
||||
stock_name: 股票名称
|
||||
user_message: 用户消息
|
||||
|
||||
Returns:
|
||||
分析报告
|
||||
"""
|
||||
# 1. 构建动态prompt
|
||||
prompt = self._build_dynamic_prompt(
|
||||
intent=intent,
|
||||
data=execution_results,
|
||||
stock_code=stock_code,
|
||||
stock_name=stock_name,
|
||||
user_message=user_message
|
||||
)
|
||||
|
||||
# 2. 调用LLM生成
|
||||
max_tokens = self._calculate_max_tokens(intent)
|
||||
response = await self._call_llm_async(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.7,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
|
||||
if not response:
|
||||
# 降级到规则化格式
|
||||
return self._format_fallback_response(execution_results, stock_name)
|
||||
|
||||
return response
|
||||
|
||||
def _build_dynamic_prompt(
|
||||
self,
|
||||
intent: Dict[str, Any],
|
||||
data: Dict[str, Any],
|
||||
stock_code: str,
|
||||
stock_name: str,
|
||||
user_message: str
|
||||
) -> str:
|
||||
"""
|
||||
根据意图动态构建prompt
|
||||
|
||||
Args:
|
||||
intent: 问题意图
|
||||
data: 执行结果数据
|
||||
stock_code: 股票代码
|
||||
stock_name: 股票名称
|
||||
user_message: 用户消息
|
||||
|
||||
Returns:
|
||||
prompt字符串
|
||||
"""
|
||||
dimensions = intent.get('dimensions', {})
|
||||
time_scope = intent.get('time_scope', {})
|
||||
specific_concerns = intent.get('specific_concerns', [])
|
||||
user_style = intent.get('user_style', {})
|
||||
|
||||
# 基础部分
|
||||
prompt_parts = [
|
||||
f"你是一个专业的股票分析师。请根据以下数据分析【{stock_name}({stock_code})】。",
|
||||
"",
|
||||
f"**用户问题**: {user_message}",
|
||||
""
|
||||
]
|
||||
|
||||
# 添加用户关注点
|
||||
if specific_concerns:
|
||||
prompt_parts.append(f"**用户特别关注**: {', '.join(specific_concerns)}")
|
||||
prompt_parts.append("")
|
||||
|
||||
# 添加数据部分
|
||||
prompt_parts.append("## 数据信息")
|
||||
prompt_parts.append("")
|
||||
|
||||
# 根据维度添加相应数据
|
||||
if dimensions.get('price_trend') and 'market_data' in data:
|
||||
prompt_parts.append(self._format_market_data_section(data['market_data']))
|
||||
|
||||
if dimensions.get('technical') and 'technical_analysis' in data:
|
||||
prompt_parts.append(self._format_technical_section(data['technical_analysis']))
|
||||
|
||||
if dimensions.get('fundamental') and 'fundamental' in data:
|
||||
prompt_parts.append(self._format_fundamental_section(data['fundamental']))
|
||||
|
||||
if dimensions.get('valuation') or dimensions.get('money_flow'):
|
||||
if 'advanced_data' in data:
|
||||
prompt_parts.append(self._format_advanced_section(data['advanced_data']))
|
||||
|
||||
# 分析要求
|
||||
prompt_parts.append("")
|
||||
prompt_parts.append("## 分析要求")
|
||||
prompt_parts.append("")
|
||||
|
||||
# 根据时间范围调整
|
||||
if time_scope.get('short_term'):
|
||||
prompt_parts.append("- 重点分析短期走势(1-2周)")
|
||||
if time_scope.get('medium_term'):
|
||||
prompt_parts.append("- 分析中期趋势(1-3个月)")
|
||||
if time_scope.get('long_term'):
|
||||
prompt_parts.append("- 评估长期投资价值(半年以上)")
|
||||
|
||||
# 根据用户风格调整
|
||||
if user_style.get('tone') == 'casual':
|
||||
prompt_parts.append("- 使用通俗易懂的语言,避免过多专业术语")
|
||||
else:
|
||||
prompt_parts.append("- 使用专业的金融术语和分析方法")
|
||||
|
||||
if user_style.get('detail_level') == 'brief':
|
||||
prompt_parts.append("- 简洁回答,控制在200-300字")
|
||||
else:
|
||||
prompt_parts.append("- 详细分析,控制在500-600字")
|
||||
|
||||
# 输出格式
|
||||
prompt_parts.append("")
|
||||
prompt_parts.append("请直接开始分析,不要添加日期标题。最后声明:\"以上分析仅供参考,不构成投资建议。\"")
|
||||
|
||||
return "\n".join(prompt_parts)
|
||||
|
||||
def _format_market_data_section(self, data: Dict) -> str:
|
||||
"""格式化行情数据部分"""
|
||||
if 'error' in data:
|
||||
return "**行情数据**: 暂时无法获取"
|
||||
|
||||
return f"""**行情数据**:
|
||||
- 最新价: {data.get('close', 0):.2f}元
|
||||
- 涨跌幅: {data.get('pct_chg', 0):+.2f}%
|
||||
- 成交量: {data.get('vol', 0):.0f}手
|
||||
- 成交额: {data.get('amount', 0):.0f}千元
|
||||
"""
|
||||
|
||||
def _format_technical_section(self, data: Dict) -> str:
|
||||
"""格式化技术指标部分"""
|
||||
if 'error' in data:
|
||||
return "**技术指标**: 暂时无法获取"
|
||||
|
||||
indicators = data.get('indicators', {})
|
||||
parts = ["**技术指标**:"]
|
||||
|
||||
if 'ma' in indicators:
|
||||
ma = indicators['ma']
|
||||
parts.append(f"- 均线: MA5={ma.get('ma5', 0):.2f}, MA10={ma.get('ma10', 0):.2f}, MA20={ma.get('ma20', 0):.2f}")
|
||||
|
||||
if 'macd' in indicators:
|
||||
macd = indicators['macd']
|
||||
parts.append(f"- MACD: DIF={macd.get('dif', 0):.4f}, DEA={macd.get('dea', 0):.4f}, MACD={macd.get('macd', 0):.4f}")
|
||||
|
||||
if 'rsi' in indicators:
|
||||
rsi = indicators['rsi']
|
||||
parts.append(f"- RSI: RSI6={rsi.get('rsi6', 0):.2f}, RSI12={rsi.get('rsi12', 0):.2f}")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
def _format_fundamental_section(self, data: Dict) -> str:
|
||||
"""格式化基本面部分"""
|
||||
if 'error' in data:
|
||||
return "**基本面**: 暂时无法获取"
|
||||
|
||||
return f"""**基本面**:
|
||||
- 公司名称: {data.get('name', '')}
|
||||
- 所属行业: {data.get('industry', '')}
|
||||
- 所属地域: {data.get('area', '')}
|
||||
- 上市市场: {data.get('market', '')}
|
||||
"""
|
||||
|
||||
def _format_advanced_section(self, data: Dict) -> str:
|
||||
"""格式化高级数据部分"""
|
||||
if 'error' in data:
|
||||
return "**高级数据**: 暂时无法获取"
|
||||
|
||||
parts = ["**高级数据**:"]
|
||||
|
||||
if 'valuation' in data:
|
||||
val = data['valuation']
|
||||
parts.append(f"- 估值: PE={val.get('pe', 0):.2f}, PB={val.get('pb', 0):.2f}")
|
||||
|
||||
if 'money_flow' in data and data['money_flow']:
|
||||
mf = data['money_flow'][0] if isinstance(data['money_flow'], list) else data['money_flow']
|
||||
parts.append(f"- 资金流向: 净流入={mf.get('net_mf_amount', 0):.2f}万元")
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
def _calculate_max_tokens(self, intent: Dict[str, Any]) -> int:
|
||||
"""根据意图计算max_tokens"""
|
||||
depth = intent.get('analysis_depth', 'standard')
|
||||
detail_level = intent.get('user_style', {}).get('detail_level', 'detailed')
|
||||
|
||||
if depth == 'quick' or detail_level == 'brief':
|
||||
return 800
|
||||
elif depth == 'deep' or detail_level == 'detailed':
|
||||
return 2000
|
||||
else:
|
||||
return 1500
|
||||
|
||||
def _format_fallback_response(self, data: Dict, stock_name: str) -> str:
|
||||
"""降级响应格式"""
|
||||
parts = [f"【{stock_name}】分析报告\n"]
|
||||
|
||||
if 'market_data' in data and 'error' not in data['market_data']:
|
||||
md = data['market_data']
|
||||
parts.append(f"最新价: {md.get('close', 0):.2f}元")
|
||||
parts.append(f"涨跌幅: {md.get('pct_chg', 0):+.2f}%\n")
|
||||
|
||||
parts.append("以上分析仅供参考,不构成投资建议。")
|
||||
return "\n".join(parts)
|
||||
|
||||
async def _handle_stock_analysis_stream(
|
||||
self,
|
||||
intent: Dict[str, Any],
|
||||
message: str
|
||||
):
|
||||
"""
|
||||
流式处理股票分析请求(智能模式)
|
||||
|
||||
Args:
|
||||
intent: 问题意图
|
||||
message: 用户消息
|
||||
|
||||
Yields:
|
||||
响应文本片段
|
||||
"""
|
||||
target = intent.get('target', {})
|
||||
stock_code = target.get('stock_code')
|
||||
stock_name = target.get('stock_name')
|
||||
market = target.get('market', 'A股')
|
||||
|
||||
# 检测是否为美股
|
||||
is_us_stock = market == '美股' or self._is_us_stock(stock_name or stock_code or '', market)
|
||||
|
||||
# 如果是美股,直接使用美股处理流程
|
||||
if is_us_stock:
|
||||
async for chunk in self._handle_us_stock_stream(stock_name or stock_code, message):
|
||||
yield chunk
|
||||
return
|
||||
|
||||
# A股处理流程
|
||||
# 如果没有股票代码,尝试匹配
|
||||
if not stock_code and stock_name:
|
||||
stock_info = await self._match_stock_with_llm(stock_name)
|
||||
if not stock_info:
|
||||
yield f"抱歉,未找到股票\"{stock_name}\"。请确认名称或代码是否正确。"
|
||||
return
|
||||
|
||||
stock_code = stock_info['code']
|
||||
stock_name = stock_info['name']
|
||||
|
||||
if not stock_code:
|
||||
yield "抱歉,我没有识别到您提到的股票。请提供更明确的股票代码或名称。"
|
||||
return
|
||||
|
||||
logger.info(f"[智能模式-流式] 分析股票: {stock_name}({stock_code})")
|
||||
|
||||
# 1. 技能规划
|
||||
plan = self.skill_planner.plan_skills(intent)
|
||||
logger.info(f"[智能模式-流式] 技能规划: {[s['name'] for s in plan['skills']]}")
|
||||
|
||||
# 2. 执行技能
|
||||
execution_results = await skill_manager.execute_plan(
|
||||
plan=plan,
|
||||
stock_code=stock_code
|
||||
)
|
||||
|
||||
if execution_results['errors']:
|
||||
logger.warning(f"[智能模式-流式] 技能执行有错误: {execution_results['errors']}")
|
||||
|
||||
# 3. 智能生成回答(流式)
|
||||
async for chunk in self._generate_intelligent_response_stream(
|
||||
intent=intent,
|
||||
execution_results=execution_results['results'],
|
||||
stock_code=stock_code,
|
||||
stock_name=stock_name,
|
||||
user_message=message
|
||||
):
|
||||
yield chunk
|
||||
|
||||
async def _generate_intelligent_response_stream(
|
||||
self,
|
||||
intent: Dict[str, Any],
|
||||
execution_results: Dict[str, Any],
|
||||
stock_code: str,
|
||||
stock_name: str,
|
||||
user_message: str
|
||||
):
|
||||
"""
|
||||
智能生成回答(流式) - 根据用户意图定制
|
||||
|
||||
Args:
|
||||
intent: 问题意图
|
||||
execution_results: 技能执行结果
|
||||
stock_code: 股票代码
|
||||
stock_name: 股票名称
|
||||
user_message: 用户消息
|
||||
|
||||
Yields:
|
||||
响应文本片段
|
||||
"""
|
||||
# 1. 构建动态prompt
|
||||
prompt = self._build_dynamic_prompt(
|
||||
intent=intent,
|
||||
data=execution_results,
|
||||
stock_code=stock_code,
|
||||
stock_name=stock_name,
|
||||
user_message=user_message
|
||||
)
|
||||
|
||||
# 2. 调用LLM流式生成
|
||||
if self.use_llm:
|
||||
stream = llm_service.chat_stream(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
temperature=0.7,
|
||||
max_tokens=self._calculate_max_tokens(intent)
|
||||
)
|
||||
|
||||
for chunk in stream:
|
||||
yield chunk
|
||||
else:
|
||||
# 降级到规则化格式
|
||||
fallback = self._format_fallback_response(execution_results, stock_name)
|
||||
for char in fallback:
|
||||
yield char
|
||||
|
||||
def _resolve_context_references(
|
||||
self,
|
||||
intent: Dict[str, Any],
|
||||
context_info: Dict
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
解析上下文引用(代词解析)
|
||||
|
||||
Args:
|
||||
intent: 问题意图
|
||||
context_info: 上下文信息
|
||||
|
||||
Returns:
|
||||
更新后的意图
|
||||
"""
|
||||
target = intent.get('target', {})
|
||||
|
||||
# 如果用户说"这只股票"、"它"等,从上下文中提取
|
||||
if not target.get('stock_code') and context_info.get('last_stock'):
|
||||
target['stock_code'] = context_info['last_stock']
|
||||
intent['target'] = target
|
||||
logger.info(f"[智能模式] 从上下文解析股票代码: {target['stock_code']}")
|
||||
|
||||
return intent
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
smart_agent = SmartStockAgent()
|
||||
|
||||
87
backend/tests/test_intelligent_agent.py
Normal file
87
backend/tests/test_intelligent_agent.py
Normal file
@ -0,0 +1,87 @@
|
||||
"""
|
||||
测试智能AI Agent
|
||||
"""
|
||||
import asyncio
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加项目路径
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from app.agent.smart_agent import SmartStockAgent
|
||||
|
||||
|
||||
async def test_intelligent_mode():
|
||||
"""测试智能模式"""
|
||||
print("=" * 60)
|
||||
print("测试智能AI Agent")
|
||||
print("=" * 60)
|
||||
|
||||
# 创建Agent实例
|
||||
agent = SmartStockAgent()
|
||||
|
||||
# 测试用例
|
||||
test_cases = [
|
||||
{
|
||||
"name": "简单股票查询",
|
||||
"message": "贵州茅台怎么样",
|
||||
"session_id": "test_session_1"
|
||||
},
|
||||
{
|
||||
"name": "技术分析关注",
|
||||
"message": "比亚迪的MACD指标怎么样,有没有金叉",
|
||||
"session_id": "test_session_2"
|
||||
},
|
||||
{
|
||||
"name": "基本面关注",
|
||||
"message": "宁德时代的基本面如何,盈利能力强吗",
|
||||
"session_id": "test_session_3"
|
||||
}
|
||||
]
|
||||
|
||||
for i, test_case in enumerate(test_cases, 1):
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"测试用例 {i}: {test_case['name']}")
|
||||
print(f"{'=' * 60}")
|
||||
print(f"用户问题: {test_case['message']}")
|
||||
print()
|
||||
|
||||
try:
|
||||
# 处理消息
|
||||
response = await agent.process_message(
|
||||
message=test_case['message'],
|
||||
session_id=test_case['session_id']
|
||||
)
|
||||
|
||||
# 打印结果
|
||||
print("响应:")
|
||||
print(response.get('message', ''))
|
||||
print()
|
||||
|
||||
# 打印元数据
|
||||
metadata = response.get('metadata', {})
|
||||
if metadata:
|
||||
print("元数据:")
|
||||
print(f" 类型: {metadata.get('type')}")
|
||||
if 'intent' in metadata:
|
||||
intent = metadata['intent']
|
||||
print(f" 意图类型: {intent.get('type')}")
|
||||
print(f" 关注维度: {intent.get('dimensions')}")
|
||||
print(f" 分析深度: {intent.get('analysis_depth')}")
|
||||
if 'plan' in metadata:
|
||||
plan = metadata['plan']
|
||||
skills = [s['name'] for s in plan.get('skills', [])]
|
||||
print(f" 调用技能: {skills}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
print("测试完成")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(test_intelligent_mode())
|
||||
Loading…
Reference in New Issue
Block a user