stock-ai-agent/backend/app/agent/skill_planner.py
2026-02-03 23:50:48 +08:00

257 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
技能规划器 - 根据用户意图智能选择技能组合
"""
from typing import Dict, Any, List, Set
from app.utils.logger import logger
class SkillPlanner:
"""智能技能规划器 - 根据问题意图动态选择技能"""
# 维度到技能的映射
DIMENSION_SKILL_MAP = {
'price_trend': {
'required': ['market_data'],
'optional': []
},
'technical': {
'required': ['market_data', 'technical_analysis'],
'optional': ['visualization']
},
'fundamental': {
'required': ['fundamental'],
'optional': []
},
'valuation': {
'required': ['advanced_data'],
'optional': []
},
'money_flow': {
'required': ['advanced_data'],
'optional': []
},
'risk': {
'required': ['technical_analysis', 'advanced_data'],
'optional': []
}
}
# 技能依赖关系
SKILL_DEPENDENCIES = {
'technical_analysis': ['market_data'], # 技术分析依赖行情数据
'visualization': ['market_data'], # 可视化依赖行情数据
}
# 技能优先级(数字越小优先级越高)
SKILL_PRIORITY = {
'market_data': 1, # 最高优先级
'fundamental': 1,
'technical_analysis': 2,
'advanced_data': 2,
'visualization': 3, # 最低优先级
'us_stock_analysis': 1
}
# 分析深度策略
DEPTH_STRATEGY = {
'quick': {
'max_skills': 2,
'include_optional': False,
'use_cache': True
},
'standard': {
'max_skills': 4,
'include_optional': True,
'use_cache': True
},
'deep': {
'max_skills': None, # 无限制
'include_optional': True,
'use_cache': False
}
}
def __init__(self):
"""初始化技能规划器"""
logger.info("技能规划器初始化")
def plan_skills(self, intent: Dict[str, Any]) -> Dict[str, Any]:
"""
根据意图规划技能执行
Args:
intent: 问题意图来自QuestionAnalyzer
Returns:
SkillExecutionPlan: {
'skills': [
{
'name': 'market_data',
'params': {...},
'priority': 1,
'required': True,
'reason': '用户关注价格走势'
},
...
],
'execution_strategy': 'parallel' | 'sequential',
'cache_strategy': 'use' | 'bypass'
}
"""
# 1. 根据维度映射技能
skills = self._map_dimensions_to_skills(intent.get('dimensions', {}))
# 2. 根据分析深度调整
depth = intent.get('analysis_depth', 'standard')
skills = self._apply_depth_strategy(skills, depth)
# 3. 解析依赖关系
skills = self._resolve_dependencies(skills)
# 4. 去重
skills = list(set(skills))
# 5. 排序(按优先级)
sorted_skills = self._sort_by_priority(skills)
# 6. 构建执行计划
plan = {
'skills': [
{
'name': skill,
'params': self._get_skill_params(skill, intent),
'priority': self.SKILL_PRIORITY.get(skill, 5),
'required': True,
'reason': self._get_skill_reason(skill, intent)
}
for skill in sorted_skills
],
'execution_strategy': self._determine_strategy(sorted_skills),
'cache_strategy': 'use' if self.DEPTH_STRATEGY[depth]['use_cache'] else 'bypass'
}
logger.info(f"技能规划完成: {[s['name'] for s in plan['skills']]}, 策略: {plan['execution_strategy']}")
return plan
def _map_dimensions_to_skills(self, dimensions: Dict[str, bool]) -> List[str]:
"""将用户关注维度映射到技能"""
skills = []
for dimension, enabled in dimensions.items():
if enabled and dimension in self.DIMENSION_SKILL_MAP:
mapping = self.DIMENSION_SKILL_MAP[dimension]
skills.extend(mapping['required'])
# 可选技能稍后根据深度策略添加
return skills
def _apply_depth_strategy(self, skills: List[str], depth: str) -> List[str]:
"""根据分析深度调整技能列表"""
strategy = self.DEPTH_STRATEGY.get(depth, self.DEPTH_STRATEGY['standard'])
# 如果有最大技能数限制
if strategy['max_skills'] is not None and len(skills) > strategy['max_skills']:
# 按优先级保留前N个
sorted_skills = self._sort_by_priority(skills)
skills = sorted_skills[:strategy['max_skills']]
return skills
def _resolve_dependencies(self, skills: List[str]) -> List[str]:
"""解析技能依赖关系,自动添加依赖的技能"""
resolved_skills = set(skills)
for skill in skills:
if skill in self.SKILL_DEPENDENCIES:
dependencies = self.SKILL_DEPENDENCIES[skill]
resolved_skills.update(dependencies)
return list(resolved_skills)
def _sort_by_priority(self, skills: List[str]) -> List[str]:
"""按优先级排序技能"""
return sorted(skills, key=lambda s: self.SKILL_PRIORITY.get(s, 999))
def _determine_strategy(self, skills: List[str]) -> str:
"""确定执行策略(并行/串行)"""
# 如果技能数量少于等于3使用并行
if len(skills) <= 3:
return 'parallel'
else:
# 技能较多时,按优先级分组并行
return 'parallel'
def _get_skill_params(self, skill_name: str, intent: Dict[str, Any]) -> Dict[str, Any]:
"""获取技能执行参数"""
params = {}
if skill_name == 'market_data':
params['data_type'] = 'quote'
elif skill_name == 'technical_analysis':
# 根据用户关注点决定指标
indicators = ['ma', 'macd']
specific_concerns = intent.get('specific_concerns', [])
if any('rsi' in concern.lower() for concern in specific_concerns):
indicators.append('rsi')
if any('kdj' in concern.lower() for concern in specific_concerns):
indicators.append('kdj')
if any('布林' in concern or 'boll' in concern.lower() for concern in specific_concerns):
indicators.append('boll')
params['indicators'] = indicators
elif skill_name == 'advanced_data':
# 根据维度决定数据类型
data_types = []
dimensions = intent.get('dimensions', {})
if dimensions.get('valuation'):
data_types.append('valuation')
if dimensions.get('money_flow'):
data_types.append('money_flow')
if not data_types:
data_types = ['valuation', 'money_flow']
params['data_types'] = data_types
return params
def _get_skill_reason(self, skill_name: str, intent: Dict[str, Any]) -> str:
"""获取调用该技能的原因"""
dimensions = intent.get('dimensions', {})
reasons = []
if skill_name == 'market_data':
if dimensions.get('price_trend'):
reasons.append('用户关注价格走势')
else:
reasons.append('获取基础行情数据')
elif skill_name == 'technical_analysis':
if dimensions.get('technical'):
reasons.append('用户关注技术指标')
else:
reasons.append('提供技术面分析')
elif skill_name == 'fundamental':
if dimensions.get('fundamental'):
reasons.append('用户关注基本面')
else:
reasons.append('提供公司基本信息')
elif skill_name == 'advanced_data':
if dimensions.get('valuation'):
reasons.append('用户关注估值')
if dimensions.get('money_flow'):
reasons.append('用户关注资金流向')
if not reasons:
reasons.append('提供高级财务数据')
elif skill_name == 'visualization':
reasons.append('生成K线图表')
return ', '.join(reasons) if reasons else '提供分析数据'