stock-ai-agent/backend/app/agent/skill_planner.py
2026-02-05 12:35:20 +08:00

304 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
技能规划器 - 根据用户意图智能选择技能组合
"""
from typing import Dict, Any, List, Set
from app.utils.logger import logger
class SkillPlanner:
"""智能技能规划器 - 根据问题意图动态选择技能"""
# 维度到技能的映射
DIMENSION_SKILL_MAP = {
'price_trend': {
'required': ['market_data', 'brave_search'], # brave_search 必需
'optional': []
},
'technical': {
'required': ['market_data', 'technical_analysis', 'brave_search'], # brave_search 必需
'optional': ['visualization']
},
'fundamental': {
'required': ['fundamental', 'brave_search'], # brave_search 必需
'optional': []
},
'valuation': {
'required': ['advanced_data', 'brave_search'], # brave_search 必需
'optional': []
},
'money_flow': {
'required': ['advanced_data', 'brave_search'], # brave_search 必需
'optional': []
},
'risk': {
'required': ['technical_analysis', 'advanced_data', 'brave_search'], # brave_search 必需
'optional': []
},
'news': { # 新闻维度
'required': ['brave_search'],
'optional': []
}
}
# 技能依赖关系
SKILL_DEPENDENCIES = {
'technical_analysis': ['market_data'], # 技术分析依赖行情数据
'visualization': ['market_data'], # 可视化依赖行情数据
}
# 技能优先级(数字越小优先级越高)
SKILL_PRIORITY = {
'market_data': 1, # 最高优先级
'fundamental': 1,
'brave_search': 1, # 新闻搜索也是高优先级
'technical_analysis': 2,
'advanced_data': 2,
'visualization': 3, # 最低优先级
'us_stock_analysis': 1
}
# 分析深度策略
DEPTH_STRATEGY = {
'quick': {
'max_skills': 2,
'include_optional': False,
'use_cache': True
},
'standard': {
'max_skills': 4,
'include_optional': True,
'use_cache': True
},
'deep': {
'max_skills': None, # 无限制
'include_optional': True,
'use_cache': False
}
}
def __init__(self):
"""初始化技能规划器"""
logger.info("技能规划器初始化")
def plan_skills(self, intent: Dict[str, Any]) -> Dict[str, Any]:
"""
根据意图规划技能执行
Args:
intent: 问题意图来自QuestionAnalyzer
Returns:
SkillExecutionPlan: {
'skills': [
{
'name': 'market_data',
'params': {...},
'priority': 1,
'required': True,
'reason': '用户关注价格走势'
},
...
],
'execution_strategy': 'parallel' | 'sequential',
'cache_strategy': 'use' | 'bypass'
}
"""
# 1. 根据维度映射技能
skills = self._map_dimensions_to_skills(intent.get('dimensions', {}))
# 2. 根据分析深度调整
depth = intent.get('analysis_depth', 'standard')
skills = self._apply_depth_strategy(skills, depth)
# 3. 解析依赖关系
skills = self._resolve_dependencies(skills)
# 4. 去重
skills = list(set(skills))
# 5. 排序(按优先级)
sorted_skills = self._sort_by_priority(skills)
# 6. 构建执行计划
plan = {
'skills': [
{
'name': skill,
'params': self._get_skill_params(skill, intent),
'priority': self.SKILL_PRIORITY.get(skill, 5),
'required': True,
'reason': self._get_skill_reason(skill, intent)
}
for skill in sorted_skills
],
'execution_strategy': self._determine_strategy(sorted_skills),
'cache_strategy': 'use' if self.DEPTH_STRATEGY[depth]['use_cache'] else 'bypass'
}
logger.info(f"技能规划完成: {[s['name'] for s in plan['skills']]}, 策略: {plan['execution_strategy']}")
return plan
def _map_dimensions_to_skills(self, dimensions: Dict[str, bool]) -> List[str]:
"""将用户关注维度映射到技能"""
skills = []
for dimension, enabled in dimensions.items():
if enabled and dimension in self.DIMENSION_SKILL_MAP:
mapping = self.DIMENSION_SKILL_MAP[dimension]
skills.extend(mapping['required'])
# 默认也添加可选技能(特别是 brave_search
skills.extend(mapping['optional'])
return skills
def _apply_depth_strategy(self, skills: List[str], depth: str) -> List[str]:
"""根据分析深度调整技能列表"""
strategy = self.DEPTH_STRATEGY.get(depth, self.DEPTH_STRATEGY['standard'])
# 如果有最大技能数限制
if strategy['max_skills'] is not None and len(skills) > strategy['max_skills']:
# 按优先级保留前N个
sorted_skills = self._sort_by_priority(skills)
skills = sorted_skills[:strategy['max_skills']]
return skills
def _resolve_dependencies(self, skills: List[str]) -> List[str]:
"""解析技能依赖关系,自动添加依赖的技能"""
resolved_skills = set(skills)
for skill in skills:
if skill in self.SKILL_DEPENDENCIES:
dependencies = self.SKILL_DEPENDENCIES[skill]
resolved_skills.update(dependencies)
return list(resolved_skills)
def _sort_by_priority(self, skills: List[str]) -> List[str]:
"""按优先级排序技能"""
return sorted(skills, key=lambda s: self.SKILL_PRIORITY.get(s, 999))
def _determine_strategy(self, skills: List[str]) -> str:
"""确定执行策略(并行/串行)"""
# 如果技能数量少于等于3使用并行
if len(skills) <= 3:
return 'parallel'
else:
# 技能较多时,按优先级分组并行
return 'parallel'
def _get_skill_params(self, skill_name: str, intent: Dict[str, Any]) -> Dict[str, Any]:
"""获取技能执行参数"""
params = {}
if skill_name == 'market_data':
params['data_type'] = 'quote'
elif skill_name == 'technical_analysis':
# 根据用户关注点决定指标
indicators = ['ma', 'macd']
specific_concerns = intent.get('specific_concerns', [])
if any('rsi' in concern.lower() for concern in specific_concerns):
indicators.append('rsi')
if any('kdj' in concern.lower() for concern in specific_concerns):
indicators.append('kdj')
if any('布林' in concern or 'boll' in concern.lower() for concern in specific_concerns):
indicators.append('boll')
params['indicators'] = indicators
elif skill_name == 'advanced_data':
# 根据维度决定数据类型
data_types = []
dimensions = intent.get('dimensions', {})
if dimensions.get('valuation'):
data_types.append('valuation')
if dimensions.get('money_flow'):
data_types.append('money_flow')
if not data_types:
data_types = ['valuation', 'money_flow']
params['data_types'] = data_types
elif skill_name == 'brave_search':
# 构建搜索查询
target = intent.get('target', {})
stock_name = target.get('stock_name', '')
stock_code = target.get('stock_code', '')
dimensions = intent.get('dimensions', {})
# 根据维度构建搜索关键词
search_keywords = []
if stock_name:
search_keywords.append(stock_name)
elif stock_code:
search_keywords.append(stock_code)
# 添加维度相关关键词
if dimensions.get('fundamental'):
search_keywords.append('财报 业绩')
if dimensions.get('news'):
search_keywords.append('最新消息')
if dimensions.get('risk'):
search_keywords.append('风险 预警')
# 如果没有特定维度,搜索一般新闻
if not any(dimensions.values()):
search_keywords.append('最新动态')
params['query'] = ' '.join(search_keywords)
params['search_type'] = 'news' # 默认搜索新闻
params['count'] = 5
params['freshness'] = 'pw' # 过去一周
return params
def _get_skill_reason(self, skill_name: str, intent: Dict[str, Any]) -> str:
"""获取调用该技能的原因"""
dimensions = intent.get('dimensions', {})
reasons = []
if skill_name == 'market_data':
if dimensions.get('price_trend'):
reasons.append('用户关注价格走势')
else:
reasons.append('获取基础行情数据')
elif skill_name == 'technical_analysis':
if dimensions.get('technical'):
reasons.append('用户关注技术指标')
else:
reasons.append('提供技术面分析')
elif skill_name == 'fundamental':
if dimensions.get('fundamental'):
reasons.append('用户关注基本面')
else:
reasons.append('提供公司基本信息')
elif skill_name == 'advanced_data':
if dimensions.get('valuation'):
reasons.append('用户关注估值')
if dimensions.get('money_flow'):
reasons.append('用户关注资金流向')
if not reasons:
reasons.append('提供高级财务数据')
elif skill_name == 'visualization':
reasons.append('生成K线图表')
elif skill_name == 'brave_search':
if dimensions.get('news'):
reasons.append('用户关注最新新闻')
elif dimensions.get('fundamental'):
reasons.append('搜索公司最新动态和财报信息')
elif dimensions.get('risk'):
reasons.append('搜索风险预警信息')
else:
reasons.append('获取最新市场资讯和舆情')
return ', '.join(reasons) if reasons else '提供分析数据'