"""
技能规划器 - 根据用户意图智能选择技能组合

Skill planner: selects a combination of skills according to the user's intent.
"""
|
||
from typing import Dict, Any, List, Set
|
||
from app.utils.logger import logger
|
||
|
||
|
||
class SkillPlanner:
    """Intelligent skill planner that picks skills dynamically from a question intent."""

    # A-share (A股) mapping: analysis dimension -> skills it needs.
    # 'required' skills are planned whenever the dimension is enabled;
    # 'optional' skills only when the depth strategy's 'include_optional' is True.
    A_STOCK_DIMENSION_SKILL_MAP = {
        'price_trend': {
            'required': ['market_data', 'brave_search'],
            'optional': []
        },
        'technical': {
            'required': ['market_data', 'technical_analysis', 'brave_search'],
            'optional': ['visualization']
        },
        'fundamental': {
            'required': ['fundamental', 'brave_search'],
            'optional': []
        },
        'valuation': {
            'required': ['advanced_data', 'brave_search'],
            'optional': []
        },
        'money_flow': {
            'required': ['advanced_data', 'brave_search'],
            'optional': []
        },
        'risk': {
            'required': ['technical_analysis', 'advanced_data', 'brave_search'],
            'optional': []
        },
        'news': {
            'required': ['brave_search'],
            'optional': []
        }
    }

    # US / HK stock mapping: analysis dimension -> skills (backed by yfinance).
    INTL_STOCK_DIMENSION_SKILL_MAP = {
        'price_trend': {
            'required': ['us_stock_analysis', 'brave_search'],
            'optional': []
        },
        'technical': {
            'required': ['us_stock_analysis', 'brave_search'],
            'optional': []
        },
        'fundamental': {
            'required': ['us_stock_analysis', 'brave_search'],
            'optional': []
        },
        'valuation': {
            'required': ['us_stock_analysis', 'brave_search'],
            'optional': []
        },
        'money_flow': {
            'required': ['us_stock_analysis', 'brave_search'],
            'optional': []
        },
        'risk': {
            'required': ['us_stock_analysis', 'brave_search'],
            'optional': []
        },
        'news': {
            'required': ['brave_search'],
            'optional': []
        }
    }

    # Skill dependencies (A-share only): planning a key skill also plans the
    # listed skills.
    SKILL_DEPENDENCIES = {
        'technical_analysis': ['market_data'],
        'visualization': ['market_data'],
    }

    # Skill priority (smaller number = higher priority).
    SKILL_PRIORITY = {
        'market_data': 1,
        'fundamental': 1,
        'brave_search': 1,
        'us_stock_analysis': 1,
        'technical_analysis': 2,
        'advanced_data': 2,
        'visualization': 3,
    }

    # Per-depth planning strategy:
    #   max_skills       - cap on planned skills before dependency resolution (None = no cap)
    #   include_optional - whether dimension 'optional' skills are planned
    #   use_cache        - becomes the plan's 'cache_strategy' ('use' / 'bypass')
    DEPTH_STRATEGY = {
        'quick': {
            'max_skills': 2,
            'include_optional': False,
            'use_cache': True
        },
        'standard': {
            'max_skills': 4,
            'include_optional': True,
            'use_cache': True
        },
        'deep': {
            'max_skills': None,
            'include_optional': True,
            'use_cache': False
        }
    }

    def __init__(self):
        """Initialize the skill planner (it holds no state; this only logs)."""
        logger.info("技能规划器初始化")

    def plan_skills(self, intent: Dict[str, Any]) -> Dict[str, Any]:
        """Plan which skills to execute for a question intent.

        Args:
            intent: Question intent produced by QuestionAnalyzer. Keys read by
                the planner: ``target`` (``market``, ``stock_code``,
                ``stock_name``), ``dimensions``, ``analysis_depth`` and
                ``specific_concerns``.

        Returns:
            A SkillExecutionPlan dict with ``skills`` (a list of dicts with
            ``name``, ``params``, ``priority``, ``required``, ``reason``),
            ``execution_strategy`` and ``cache_strategy``.
        """
        # `or {}` tolerates an explicit None as well as a missing key.
        target = intent.get('target') or {}
        market = target.get('market', 'A股')
        stock_code = target.get('stock_code', '')
        stock_name = target.get('stock_name', '')

        # US and HK stocks are planned from a separate skill map.
        if market in ('美股', '港股'):
            return self._plan_intl_stock_skills(intent, market, stock_code, stock_name)
        return self._plan_a_stock_skills(intent)

    def _plan_a_stock_skills(self, intent: Dict[str, Any]) -> Dict[str, Any]:
        """Build the execution plan for an A-share question."""
        depth = intent.get('analysis_depth', 'standard')
        strategy = self._get_depth_strategy(depth)

        # 1. Map enabled dimensions to skills (optional ones only if the depth allows).
        skills = self._map_dimensions_to_skills(
            intent.get('dimensions') or {},
            self.A_STOCK_DIMENSION_SKILL_MAP,
            include_optional=strategy['include_optional'],
        )

        # 2. Apply the depth's skill-count limit.
        skills = self._apply_depth_strategy(skills, depth)

        # 3. Add dependencies (these may exceed the limit on purpose).
        skills = self._resolve_dependencies(skills)

        # 4. Deduplicate deterministically, then sort by priority.
        sorted_skills = self._sort_by_priority(self._dedupe(skills))

        # 5. Build the execution plan.
        plan = {
            'skills': [
                {
                    'name': skill,
                    'params': self._get_skill_params(skill, intent),
                    'priority': self.SKILL_PRIORITY.get(skill, 5),
                    'required': True,
                    'reason': self._get_skill_reason(skill, intent)
                }
                for skill in sorted_skills
            ],
            'execution_strategy': self._determine_strategy(sorted_skills),
            'cache_strategy': self._cache_strategy(depth)
        }

        logger.info(f"[A股] 技能规划完成: {[s['name'] for s in plan['skills']]}, 策略: {plan['execution_strategy']}")
        return plan

    def _plan_intl_stock_skills(self, intent: Dict[str, Any], market: str, stock_code: str, stock_name: str) -> Dict[str, Any]:
        """Build the execution plan for a US / HK stock question.

        Unlike the A-share path, no ``max_skills`` truncation or dependency
        resolution is applied here.
        """
        depth = intent.get('analysis_depth', 'standard')
        strategy = self._get_depth_strategy(depth)

        # 1. Map enabled dimensions to skills.
        skills = self._map_dimensions_to_skills(
            intent.get('dimensions') or {},
            self.INTL_STOCK_DIMENSION_SKILL_MAP,
            include_optional=strategy['include_optional'],
        )

        # 2. us_stock_analysis is always planned, even if no dimension asked for it.
        if 'us_stock_analysis' not in skills:
            skills.append('us_stock_analysis')

        # 3. Deduplicate deterministically, then sort by priority.
        sorted_skills = self._sort_by_priority(self._dedupe(skills))

        # 4. Build the execution plan.
        plan = {
            'skills': [
                {
                    'name': skill,
                    'params': self._get_intl_skill_params(skill, stock_code, stock_name),
                    'priority': self.SKILL_PRIORITY.get(skill, 5),
                    # Only us_stock_analysis is flagged as required for these markets.
                    'required': skill == 'us_stock_analysis',
                    'reason': self._get_intl_skill_reason(skill, market)
                }
                for skill in sorted_skills
            ],
            'execution_strategy': 'parallel',
            'cache_strategy': self._cache_strategy(depth)
        }

        logger.info(f"[{market}] 技能规划完成: {[s['name'] for s in plan['skills']]}, 策略: {plan['execution_strategy']}")
        return plan

    def _get_intl_skill_params(self, skill_name: str, stock_code: str, stock_name: str) -> Dict[str, Any]:
        """Return execution parameters for a US / HK skill ({} for unknown skills)."""
        if skill_name == 'us_stock_analysis':
            return {
                'symbol': stock_code,
                'analysis_type': 'comprehensive'
            }
        if skill_name == 'brave_search':
            # Prefer the stock name as search subject; fall back to the ticker so
            # the query never starts with an empty subject.
            subject = stock_name or stock_code
            return {
                'query': ' '.join(part for part in (subject, '最新动态 财报') if part),
                'search_type': 'news',
                'count': 5,
                'freshness': 'pw'
            }
        return {}

    def _get_intl_skill_reason(self, skill_name: str, market: str) -> str:
        """Return the human-readable reason for invoking a US / HK skill."""
        if skill_name == 'us_stock_analysis':
            return f'获取{market}基础数据和技术指标'
        if skill_name == 'brave_search':
            return '获取最新市场资讯和舆情'
        return '提供分析数据'

    def _map_dimensions_to_skills(self, dimensions: Dict[str, bool], skill_map: Dict,
                                  include_optional: bool = True) -> List[str]:
        """Map the user's enabled dimensions to skills.

        Args:
            dimensions: dimension name -> enabled flag; unknown names are ignored.
            skill_map: one of the ``*_DIMENSION_SKILL_MAP`` tables.
            include_optional: also add each dimension's 'optional' skills.

        Returns:
            Skills in mapping order; may contain duplicates.
        """
        skills: List[str] = []
        for dimension, enabled in dimensions.items():
            if not enabled or dimension not in skill_map:
                continue
            mapping = skill_map[dimension]
            skills.extend(mapping['required'])
            if include_optional:
                skills.extend(mapping.get('optional', []))
        return skills

    def _get_depth_strategy(self, depth: str) -> Dict[str, Any]:
        """Return the DEPTH_STRATEGY entry for *depth*, defaulting to 'standard'."""
        return self.DEPTH_STRATEGY.get(depth, self.DEPTH_STRATEGY['standard'])

    def _cache_strategy(self, depth: str) -> str:
        """Return 'use' or 'bypass' according to the depth's 'use_cache' flag."""
        return 'use' if self._get_depth_strategy(depth)['use_cache'] else 'bypass'

    @staticmethod
    def _dedupe(skills: List[str]) -> List[str]:
        """Remove duplicates while keeping first-occurrence order (deterministic, unlike set())."""
        return list(dict.fromkeys(skills))

    def _apply_depth_strategy(self, skills: List[str], depth: str) -> List[str]:
        """Limit the skill list according to the analysis depth.

        Duplicates are removed first, so a skill shared by several dimensions
        (e.g. brave_search) consumes only one slot of ``max_skills``. When the
        limit is exceeded, the highest-priority skills are kept.
        """
        max_skills = self._get_depth_strategy(depth)['max_skills']
        unique_skills = self._dedupe(skills)

        if max_skills is not None and len(unique_skills) > max_skills:
            return self._sort_by_priority(unique_skills)[:max_skills]
        return unique_skills

    def _resolve_dependencies(self, skills: List[str]) -> List[str]:
        """Return *skills* plus their SKILL_DEPENDENCIES, deduplicated in a stable order."""
        resolved = dict.fromkeys(skills)
        for skill in skills:
            for dependency in self.SKILL_DEPENDENCIES.get(skill, []):
                resolved.setdefault(dependency, None)
        return list(resolved)

    def _sort_by_priority(self, skills: List[str]) -> List[str]:
        """Sort skills by SKILL_PRIORITY (stable); unknown skills sort last."""
        return sorted(skills, key=lambda s: self.SKILL_PRIORITY.get(s, 999))

    def _determine_strategy(self, skills: List[str]) -> str:
        """Choose the execution strategy for the planned skills.

        Always returns ``'parallel'``. Priority-grouped execution for larger
        skill sets is not implemented; ``skills`` is accepted so a size-based
        policy can be added without changing callers.
        """
        return 'parallel'

    def _get_skill_params(self, skill_name: str, intent: Dict[str, Any]) -> Dict[str, Any]:
        """Return execution parameters for an A-share skill ({} for unknown skills)."""
        params: Dict[str, Any] = {}

        if skill_name == 'market_data':
            params['data_type'] = 'quote'

        elif skill_name == 'technical_analysis':
            # MA and MACD are always requested; RSI / KDJ / BOLL are added when a
            # specific concern mentions them (case-insensitive).
            concerns = [concern.lower() for concern in (intent.get('specific_concerns') or [])]
            indicators = ['ma', 'macd']
            for keywords, indicator in (
                (('rsi',), 'rsi'),
                (('kdj',), 'kdj'),
                (('布林', 'boll'), 'boll'),
            ):
                if any(keyword in concern for concern in concerns for keyword in keywords):
                    indicators.append(indicator)
            params['indicators'] = indicators

        elif skill_name == 'advanced_data':
            dimensions = intent.get('dimensions') or {}
            data_types = [dt for dt in ('valuation', 'money_flow') if dimensions.get(dt)]
            # Neither dimension requested explicitly: fetch both.
            params['data_types'] = data_types or ['valuation', 'money_flow']

        elif skill_name == 'brave_search':
            target = intent.get('target') or {}
            dimensions = intent.get('dimensions') or {}

            # Search subject: stock name if present, otherwise stock code.
            subject = target.get('stock_name', '') or target.get('stock_code', '')
            search_keywords = [subject] if subject else []

            # Add dimension-specific keywords.
            if dimensions.get('fundamental'):
                search_keywords.append('财报 业绩')
            if dimensions.get('news'):
                search_keywords.append('最新消息')
            if dimensions.get('risk'):
                search_keywords.append('风险 预警')

            # No dimension enabled at all: search general news.
            if not any(dimensions.values()):
                search_keywords.append('最新动态')

            params['query'] = ' '.join(search_keywords)
            params['search_type'] = 'news'  # default: news search
            params['count'] = 5
            params['freshness'] = 'pw'  # past week

        return params

    def _get_skill_reason(self, skill_name: str, intent: Dict[str, Any]) -> str:
        """Return the human-readable reason for invoking an A-share skill."""
        dimensions = intent.get('dimensions') or {}
        reasons: List[str] = []

        if skill_name == 'market_data':
            reasons.append('用户关注价格走势' if dimensions.get('price_trend') else '获取基础行情数据')

        elif skill_name == 'technical_analysis':
            reasons.append('用户关注技术指标' if dimensions.get('technical') else '提供技术面分析')

        elif skill_name == 'fundamental':
            reasons.append('用户关注基本面' if dimensions.get('fundamental') else '提供公司基本信息')

        elif skill_name == 'advanced_data':
            if dimensions.get('valuation'):
                reasons.append('用户关注估值')
            if dimensions.get('money_flow'):
                reasons.append('用户关注资金流向')
            if not reasons:
                reasons.append('提供高级财务数据')

        elif skill_name == 'visualization':
            reasons.append('生成K线图表')

        elif skill_name == 'brave_search':
            if dimensions.get('news'):
                reasons.append('用户关注最新新闻')
            elif dimensions.get('fundamental'):
                reasons.append('搜索公司最新动态和财报信息')
            elif dimensions.get('risk'):
                reasons.append('搜索风险预警信息')
            else:
                reasons.append('获取最新市场资讯和舆情')

        return ', '.join(reasons) if reasons else '提供分析数据'
|