257 lines
8.6 KiB
Python
257 lines
8.6 KiB
Python
"""
|
||
技能规划器 - 根据用户意图智能选择技能组合
|
||
"""
|
||
from typing import Dict, Any, List, Set
|
||
from app.utils.logger import logger
|
||
|
||
|
||
class SkillPlanner:
|
||
"""智能技能规划器 - 根据问题意图动态选择技能"""
|
||
|
||
# 维度到技能的映射
|
||
DIMENSION_SKILL_MAP = {
|
||
'price_trend': {
|
||
'required': ['market_data'],
|
||
'optional': []
|
||
},
|
||
'technical': {
|
||
'required': ['market_data', 'technical_analysis'],
|
||
'optional': ['visualization']
|
||
},
|
||
'fundamental': {
|
||
'required': ['fundamental'],
|
||
'optional': []
|
||
},
|
||
'valuation': {
|
||
'required': ['advanced_data'],
|
||
'optional': []
|
||
},
|
||
'money_flow': {
|
||
'required': ['advanced_data'],
|
||
'optional': []
|
||
},
|
||
'risk': {
|
||
'required': ['technical_analysis', 'advanced_data'],
|
||
'optional': []
|
||
}
|
||
}
|
||
|
||
# 技能依赖关系
|
||
SKILL_DEPENDENCIES = {
|
||
'technical_analysis': ['market_data'], # 技术分析依赖行情数据
|
||
'visualization': ['market_data'], # 可视化依赖行情数据
|
||
}
|
||
|
||
# 技能优先级(数字越小优先级越高)
|
||
SKILL_PRIORITY = {
|
||
'market_data': 1, # 最高优先级
|
||
'fundamental': 1,
|
||
'technical_analysis': 2,
|
||
'advanced_data': 2,
|
||
'visualization': 3, # 最低优先级
|
||
'us_stock_analysis': 1
|
||
}
|
||
|
||
# 分析深度策略
|
||
DEPTH_STRATEGY = {
|
||
'quick': {
|
||
'max_skills': 2,
|
||
'include_optional': False,
|
||
'use_cache': True
|
||
},
|
||
'standard': {
|
||
'max_skills': 4,
|
||
'include_optional': True,
|
||
'use_cache': True
|
||
},
|
||
'deep': {
|
||
'max_skills': None, # 无限制
|
||
'include_optional': True,
|
||
'use_cache': False
|
||
}
|
||
}
|
||
|
||
def __init__(self):
|
||
"""初始化技能规划器"""
|
||
logger.info("技能规划器初始化")
|
||
|
||
def plan_skills(self, intent: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""
|
||
根据意图规划技能执行
|
||
|
||
Args:
|
||
intent: 问题意图(来自QuestionAnalyzer)
|
||
|
||
Returns:
|
||
SkillExecutionPlan: {
|
||
'skills': [
|
||
{
|
||
'name': 'market_data',
|
||
'params': {...},
|
||
'priority': 1,
|
||
'required': True,
|
||
'reason': '用户关注价格走势'
|
||
},
|
||
...
|
||
],
|
||
'execution_strategy': 'parallel' | 'sequential',
|
||
'cache_strategy': 'use' | 'bypass'
|
||
}
|
||
"""
|
||
# 1. 根据维度映射技能
|
||
skills = self._map_dimensions_to_skills(intent.get('dimensions', {}))
|
||
|
||
# 2. 根据分析深度调整
|
||
depth = intent.get('analysis_depth', 'standard')
|
||
skills = self._apply_depth_strategy(skills, depth)
|
||
|
||
# 3. 解析依赖关系
|
||
skills = self._resolve_dependencies(skills)
|
||
|
||
# 4. 去重
|
||
skills = list(set(skills))
|
||
|
||
# 5. 排序(按优先级)
|
||
sorted_skills = self._sort_by_priority(skills)
|
||
|
||
# 6. 构建执行计划
|
||
plan = {
|
||
'skills': [
|
||
{
|
||
'name': skill,
|
||
'params': self._get_skill_params(skill, intent),
|
||
'priority': self.SKILL_PRIORITY.get(skill, 5),
|
||
'required': True,
|
||
'reason': self._get_skill_reason(skill, intent)
|
||
}
|
||
for skill in sorted_skills
|
||
],
|
||
'execution_strategy': self._determine_strategy(sorted_skills),
|
||
'cache_strategy': 'use' if self.DEPTH_STRATEGY[depth]['use_cache'] else 'bypass'
|
||
}
|
||
|
||
logger.info(f"技能规划完成: {[s['name'] for s in plan['skills']]}, 策略: {plan['execution_strategy']}")
|
||
return plan
|
||
|
||
def _map_dimensions_to_skills(self, dimensions: Dict[str, bool]) -> List[str]:
|
||
"""将用户关注维度映射到技能"""
|
||
skills = []
|
||
|
||
for dimension, enabled in dimensions.items():
|
||
if enabled and dimension in self.DIMENSION_SKILL_MAP:
|
||
mapping = self.DIMENSION_SKILL_MAP[dimension]
|
||
skills.extend(mapping['required'])
|
||
# 可选技能稍后根据深度策略添加
|
||
|
||
return skills
|
||
|
||
def _apply_depth_strategy(self, skills: List[str], depth: str) -> List[str]:
|
||
"""根据分析深度调整技能列表"""
|
||
strategy = self.DEPTH_STRATEGY.get(depth, self.DEPTH_STRATEGY['standard'])
|
||
|
||
# 如果有最大技能数限制
|
||
if strategy['max_skills'] is not None and len(skills) > strategy['max_skills']:
|
||
# 按优先级保留前N个
|
||
sorted_skills = self._sort_by_priority(skills)
|
||
skills = sorted_skills[:strategy['max_skills']]
|
||
|
||
return skills
|
||
|
||
def _resolve_dependencies(self, skills: List[str]) -> List[str]:
|
||
"""解析技能依赖关系,自动添加依赖的技能"""
|
||
resolved_skills = set(skills)
|
||
|
||
for skill in skills:
|
||
if skill in self.SKILL_DEPENDENCIES:
|
||
dependencies = self.SKILL_DEPENDENCIES[skill]
|
||
resolved_skills.update(dependencies)
|
||
|
||
return list(resolved_skills)
|
||
|
||
def _sort_by_priority(self, skills: List[str]) -> List[str]:
|
||
"""按优先级排序技能"""
|
||
return sorted(skills, key=lambda s: self.SKILL_PRIORITY.get(s, 999))
|
||
|
||
def _determine_strategy(self, skills: List[str]) -> str:
|
||
"""确定执行策略(并行/串行)"""
|
||
# 如果技能数量少于等于3,使用并行
|
||
if len(skills) <= 3:
|
||
return 'parallel'
|
||
else:
|
||
# 技能较多时,按优先级分组并行
|
||
return 'parallel'
|
||
|
||
def _get_skill_params(self, skill_name: str, intent: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""获取技能执行参数"""
|
||
params = {}
|
||
|
||
if skill_name == 'market_data':
|
||
params['data_type'] = 'quote'
|
||
|
||
elif skill_name == 'technical_analysis':
|
||
# 根据用户关注点决定指标
|
||
indicators = ['ma', 'macd']
|
||
specific_concerns = intent.get('specific_concerns', [])
|
||
|
||
if any('rsi' in concern.lower() for concern in specific_concerns):
|
||
indicators.append('rsi')
|
||
if any('kdj' in concern.lower() for concern in specific_concerns):
|
||
indicators.append('kdj')
|
||
if any('布林' in concern or 'boll' in concern.lower() for concern in specific_concerns):
|
||
indicators.append('boll')
|
||
|
||
params['indicators'] = indicators
|
||
|
||
elif skill_name == 'advanced_data':
|
||
# 根据维度决定数据类型
|
||
data_types = []
|
||
dimensions = intent.get('dimensions', {})
|
||
|
||
if dimensions.get('valuation'):
|
||
data_types.append('valuation')
|
||
if dimensions.get('money_flow'):
|
||
data_types.append('money_flow')
|
||
|
||
if not data_types:
|
||
data_types = ['valuation', 'money_flow']
|
||
|
||
params['data_types'] = data_types
|
||
|
||
return params
|
||
|
||
def _get_skill_reason(self, skill_name: str, intent: Dict[str, Any]) -> str:
|
||
"""获取调用该技能的原因"""
|
||
dimensions = intent.get('dimensions', {})
|
||
reasons = []
|
||
|
||
if skill_name == 'market_data':
|
||
if dimensions.get('price_trend'):
|
||
reasons.append('用户关注价格走势')
|
||
else:
|
||
reasons.append('获取基础行情数据')
|
||
|
||
elif skill_name == 'technical_analysis':
|
||
if dimensions.get('technical'):
|
||
reasons.append('用户关注技术指标')
|
||
else:
|
||
reasons.append('提供技术面分析')
|
||
|
||
elif skill_name == 'fundamental':
|
||
if dimensions.get('fundamental'):
|
||
reasons.append('用户关注基本面')
|
||
else:
|
||
reasons.append('提供公司基本信息')
|
||
|
||
elif skill_name == 'advanced_data':
|
||
if dimensions.get('valuation'):
|
||
reasons.append('用户关注估值')
|
||
if dimensions.get('money_flow'):
|
||
reasons.append('用户关注资金流向')
|
||
if not reasons:
|
||
reasons.append('提供高级财务数据')
|
||
|
||
elif skill_name == 'visualization':
|
||
reasons.append('生成K线图表')
|
||
|
||
return ', '.join(reasons) if reasons else '提供分析数据'
|