stock-ai-agent/backend/app/agent/skill_planner.py
2026-02-05 18:23:22 +08:00

397 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
技能规划器 - 根据用户意图智能选择技能组合
"""
from typing import Dict, Any, List, Set
from app.utils.logger import logger
class SkillPlanner:
    """Intelligent skill planner: picks a skill combination from the question intent.

    The intent dict (produced upstream by QuestionAnalyzer) is read for these
    keys: ``target`` (with ``market``, ``stock_code``, ``stock_name``),
    ``dimensions`` (dimension name -> bool), ``analysis_depth`` and
    ``specific_concerns``.
    """

    # A-share: analysis dimension -> skills it needs ('required') and extra
    # skills ('optional') that are only added when the depth strategy sets
    # include_optional.
    A_STOCK_DIMENSION_SKILL_MAP = {
        'price_trend': {
            'required': ['market_data', 'brave_search'],
            'optional': []
        },
        'technical': {
            'required': ['market_data', 'technical_analysis', 'brave_search'],
            'optional': ['visualization']
        },
        'fundamental': {
            'required': ['fundamental', 'brave_search'],
            'optional': []
        },
        'valuation': {
            'required': ['advanced_data', 'brave_search'],
            'optional': []
        },
        'money_flow': {
            'required': ['advanced_data', 'brave_search'],
            'optional': []
        },
        'risk': {
            'required': ['technical_analysis', 'advanced_data', 'brave_search'],
            'optional': []
        },
        'news': {
            'required': ['brave_search'],
            'optional': []
        }
    }

    # US / HK stocks: every data dimension is served by the 'us_stock_analysis'
    # skill (yfinance-based) plus web search.
    INTL_STOCK_DIMENSION_SKILL_MAP = {
        'price_trend': {
            'required': ['us_stock_analysis', 'brave_search'],
            'optional': []
        },
        'technical': {
            'required': ['us_stock_analysis', 'brave_search'],
            'optional': []
        },
        'fundamental': {
            'required': ['us_stock_analysis', 'brave_search'],
            'optional': []
        },
        'valuation': {
            'required': ['us_stock_analysis', 'brave_search'],
            'optional': []
        },
        'money_flow': {
            'required': ['us_stock_analysis', 'brave_search'],
            'optional': []
        },
        'risk': {
            'required': ['us_stock_analysis', 'brave_search'],
            'optional': []
        },
        'news': {
            'required': ['brave_search'],
            'optional': []
        }
    }

    # Skill -> skills that must also run for it to work (A-share only).
    SKILL_DEPENDENCIES = {
        'technical_analysis': ['market_data'],
        'visualization': ['market_data'],
    }

    # Skill priority: a smaller number means a higher priority.
    SKILL_PRIORITY = {
        'market_data': 1,
        'fundamental': 1,
        'brave_search': 1,
        'us_stock_analysis': 1,
        'technical_analysis': 2,
        'advanced_data': 2,
        'visualization': 3,
    }

    # Per-depth planning strategy:
    #   max_skills       - cap on planned skills before dependency resolution
    #                      (None = unlimited)
    #   include_optional - whether 'optional' skills from the dimension map are added
    #   use_cache        - drives the plan's 'cache_strategy' ('use' / 'bypass')
    DEPTH_STRATEGY = {
        'quick': {
            'max_skills': 2,
            'include_optional': False,
            'use_cache': True
        },
        'standard': {
            'max_skills': 4,
            'include_optional': True,
            'use_cache': True
        },
        'deep': {
            'max_skills': None,
            'include_optional': True,
            'use_cache': False
        }
    }

    def __init__(self):
        """Initialise the planner (stateless; only logs)."""
        logger.info("技能规划器初始化")

    def plan_skills(self, intent: Dict[str, Any]) -> Dict[str, Any]:
        """Plan which skills to execute for a question intent.

        Args:
            intent: Question intent produced by QuestionAnalyzer.

        Returns:
            A skill execution plan dict with keys ``skills`` (list of dicts with
            ``name``, ``params``, ``priority``, ``required``, ``reason``),
            ``execution_strategy`` and ``cache_strategy`` ('use' or 'bypass').
        """
        # A missing (or None) target is treated as an A-share request.
        target = intent.get('target') or {}
        market = target.get('market', 'A股')
        stock_code = target.get('stock_code', '')
        stock_name = target.get('stock_name', '')

        # US / HK stocks use a different data source and skill map.
        if market in ('美股', '港股'):
            return self._plan_intl_stock_skills(intent, market, stock_code, stock_name)
        return self._plan_a_stock_skills(intent)

    def _plan_a_stock_skills(self, intent: Dict[str, Any]) -> Dict[str, Any]:
        """Build the execution plan for an A-share question."""
        depth = intent.get('analysis_depth', 'standard')
        # Unknown depths fall back to 'standard' instead of raising KeyError.
        strategy = self._get_depth_strategy(depth)

        # 1. Map enabled dimensions to skills (optional ones only if allowed).
        skills = self._map_dimensions_to_skills(
            intent.get('dimensions') or {},
            self.A_STOCK_DIMENSION_SKILL_MAP,
            include_optional=strategy['include_optional'],
        )
        # 2. Cap the number of skills according to the analysis depth.
        skills = self._apply_depth_strategy(skills, depth)
        # 3. Add the skills that the kept skills depend on.
        skills = self._resolve_dependencies(skills)
        # 4. De-duplicate (order-preserving, so equal-priority skills keep a
        #    stable order across runs) and sort by priority.
        sorted_skills = self._sort_by_priority(self._dedupe(skills))

        # 5. Build the execution plan.
        plan = {
            'skills': [
                {
                    'name': skill,
                    'params': self._get_skill_params(skill, intent),
                    'priority': self.SKILL_PRIORITY.get(skill, 5),
                    'required': True,
                    'reason': self._get_skill_reason(skill, intent)
                }
                for skill in sorted_skills
            ],
            'execution_strategy': self._determine_strategy(sorted_skills),
            'cache_strategy': 'use' if strategy['use_cache'] else 'bypass'
        }
        logger.info(f"[A股] 技能规划完成: {[s['name'] for s in plan['skills']]}, 策略: {plan['execution_strategy']}")
        return plan

    def _plan_intl_stock_skills(self, intent: Dict[str, Any], market: str, stock_code: str, stock_name: str) -> Dict[str, Any]:
        """Build the execution plan for a US / HK stock question."""
        # Unknown depths fall back to 'standard' instead of raising KeyError.
        strategy = self._get_depth_strategy(intent.get('analysis_depth', 'standard'))

        # 1. Map enabled dimensions to skills.
        skills = self._map_dimensions_to_skills(
            intent.get('dimensions') or {},
            self.INTL_STOCK_DIMENSION_SKILL_MAP,
            include_optional=strategy['include_optional'],
        )
        # 2. us_stock_analysis is always planned, even if no data dimension
        #    was requested.
        if 'us_stock_analysis' not in skills:
            skills.append('us_stock_analysis')
        # 3. De-duplicate (order-preserving) and sort by priority.
        sorted_skills = self._sort_by_priority(self._dedupe(skills))

        # 4. Build the execution plan.
        plan = {
            'skills': [
                {
                    'name': skill,
                    'params': self._get_intl_skill_params(skill, stock_code, stock_name),
                    'priority': self.SKILL_PRIORITY.get(skill, 5),
                    'required': skill == 'us_stock_analysis',
                    'reason': self._get_intl_skill_reason(skill, market)
                }
                for skill in sorted_skills
            ],
            'execution_strategy': 'parallel',
            'cache_strategy': 'use' if strategy['use_cache'] else 'bypass'
        }
        logger.info(f"[{market}] 技能规划完成: {[s['name'] for s in plan['skills']]}, 策略: {plan['execution_strategy']}")
        return plan

    def _get_intl_skill_params(self, skill_name: str, stock_code: str, stock_name: str) -> Dict[str, Any]:
        """Return execution parameters for a US / HK stock skill ({} if unknown)."""
        if skill_name == 'us_stock_analysis':
            return {
                'symbol': stock_code,
                'analysis_type': 'comprehensive'
            }
        if skill_name == 'brave_search':
            # Fall back to the ticker when no name is available, so the query
            # never starts with a blank subject.
            subject = stock_name or stock_code
            return {
                'query': f'{subject} 最新动态 财报' if subject else '最新动态 财报',
                'search_type': 'news',
                'count': 5,
                'freshness': 'pw'  # past week
            }
        return {}

    def _get_intl_skill_reason(self, skill_name: str, market: str) -> str:
        """Return a human-readable reason for calling a US / HK stock skill."""
        if skill_name == 'us_stock_analysis':
            return f'获取{market}基础数据和技术指标'
        if skill_name == 'brave_search':
            return '获取最新市场资讯和舆情'
        return '提供分析数据'

    def _map_dimensions_to_skills(self, dimensions: Dict[str, bool], skill_map: Dict,
                                  include_optional: bool = True) -> List[str]:
        """Map the user's enabled dimensions to skill names.

        Args:
            dimensions: Dimension name -> whether the user cares about it.
            skill_map: One of the ``*_DIMENSION_SKILL_MAP`` tables.
            include_optional: Also add each dimension's 'optional' skills.

        Returns:
            Skill names in dimension order; may contain duplicates.
        """
        skills: List[str] = []
        for dimension, enabled in (dimensions or {}).items():
            mapping = skill_map.get(dimension)
            if not enabled or mapping is None:
                continue
            skills.extend(mapping['required'])
            if include_optional:
                skills.extend(mapping['optional'])
        return skills

    def _get_depth_strategy(self, depth: str) -> Dict[str, Any]:
        """Return the DEPTH_STRATEGY entry for ``depth``, or 'standard' if unknown."""
        return self.DEPTH_STRATEGY.get(depth, self.DEPTH_STRATEGY['standard'])

    def _apply_depth_strategy(self, skills: List[str], depth: str) -> List[str]:
        """Cap the skill list according to the analysis depth.

        Returns the de-duplicated skills; when they exceed the depth's
        ``max_skills``, only the highest-priority ones are kept.
        """
        strategy = self._get_depth_strategy(depth)
        max_skills = strategy['max_skills']
        # De-duplicate first: several dimensions list the same skill (e.g.
        # brave_search), and those copies must not use up the budget.
        unique_skills = self._dedupe(skills)
        if max_skills is not None and len(unique_skills) > max_skills:
            # Keep the N highest-priority skills.
            return self._sort_by_priority(unique_skills)[:max_skills]
        return unique_skills

    def _resolve_dependencies(self, skills: List[str]) -> List[str]:
        """Return ``skills`` plus every skill they (transitively) depend on.

        The result has no duplicates; original skills keep their order and
        dependencies are appended after them.
        """
        resolved = self._dedupe(skills)
        seen = set(resolved)
        index = 0
        # Worklist over the growing list, so dependencies of dependencies are
        # also pulled in; `seen` prevents loops on cyclic declarations.
        while index < len(resolved):
            for dependency in self.SKILL_DEPENDENCIES.get(resolved[index], ()):
                if dependency not in seen:
                    seen.add(dependency)
                    resolved.append(dependency)
            index += 1
        return resolved

    @staticmethod
    def _dedupe(skills: List[str]) -> List[str]:
        """Remove duplicates while keeping first-seen order (unlike set(), deterministic)."""
        return list(dict.fromkeys(skills))

    def _sort_by_priority(self, skills: List[str]) -> List[str]:
        """Sort skills by priority; unknown skills go last. The sort is stable."""
        return sorted(skills, key=lambda s: self.SKILL_PRIORITY.get(s, 999))

    def _determine_strategy(self, skills: List[str]) -> str:
        """Choose the execution strategy (parallel / serial) for the planned skills.

        The skill count does not currently affect the choice: every plan runs
        in parallel. A priority-grouped strategy for larger plans would go here.
        """
        return 'parallel'

    def _get_skill_params(self, skill_name: str, intent: Dict[str, Any]) -> Dict[str, Any]:
        """Return execution parameters for an A-share skill ({} if unknown)."""
        params: Dict[str, Any] = {}
        dimensions = intent.get('dimensions') or {}

        if skill_name == 'market_data':
            params['data_type'] = 'quote'

        elif skill_name == 'technical_analysis':
            # MA and MACD are always computed; other indicators are added only
            # when the user's concerns mention them.
            concerns = [str(c).lower() for c in (intent.get('specific_concerns') or [])]
            indicators = ['ma', 'macd']
            if any('rsi' in c for c in concerns):
                indicators.append('rsi')
            if any('kdj' in c for c in concerns):
                indicators.append('kdj')
            if any('布林' in c or 'boll' in c for c in concerns):
                indicators.append('boll')
            params['indicators'] = indicators

        elif skill_name == 'advanced_data':
            data_types = [dt for dt in ('valuation', 'money_flow') if dimensions.get(dt)]
            # Fetch both when neither was requested explicitly (e.g. the skill
            # was planned for the 'risk' dimension).
            params['data_types'] = data_types or ['valuation', 'money_flow']

        elif skill_name == 'brave_search':
            # Query = stock name (or code) + dimension-specific keywords.
            target = intent.get('target') or {}
            subject = target.get('stock_name', '') or target.get('stock_code', '')
            search_keywords = [subject] if subject else []
            if dimensions.get('fundamental'):
                search_keywords.append('财报 业绩')
            if dimensions.get('news'):
                search_keywords.append('最新消息')
            if dimensions.get('risk'):
                search_keywords.append('风险 预警')
            # With no enabled dimension at all, search for general news.
            if not any(dimensions.values()):
                search_keywords.append('最新动态')

            params['query'] = ' '.join(search_keywords)
            params['search_type'] = 'news'  # news search by default
            params['count'] = 5
            params['freshness'] = 'pw'  # past week

        return params

    def _get_skill_reason(self, skill_name: str, intent: Dict[str, Any]) -> str:
        """Return a human-readable reason for calling an A-share skill."""
        dimensions = intent.get('dimensions') or {}
        reasons: List[str] = []

        if skill_name == 'market_data':
            reasons.append('用户关注价格走势' if dimensions.get('price_trend') else '获取基础行情数据')
        elif skill_name == 'technical_analysis':
            reasons.append('用户关注技术指标' if dimensions.get('technical') else '提供技术面分析')
        elif skill_name == 'fundamental':
            reasons.append('用户关注基本面' if dimensions.get('fundamental') else '提供公司基本信息')
        elif skill_name == 'advanced_data':
            if dimensions.get('valuation'):
                reasons.append('用户关注估值')
            if dimensions.get('money_flow'):
                reasons.append('用户关注资金流向')
            if not reasons:
                reasons.append('提供高级财务数据')
        elif skill_name == 'visualization':
            reasons.append('生成K线图表')
        elif skill_name == 'brave_search':
            # First matching dimension wins: news > fundamental > risk.
            if dimensions.get('news'):
                reasons.append('用户关注最新新闻')
            elif dimensions.get('fundamental'):
                reasons.append('搜索公司最新动态和财报信息')
            elif dimensions.get('risk'):
                reasons.append('搜索风险预警信息')
            else:
                reasons.append('获取最新市场资讯和舆情')

        return ', '.join(reasons) if reasons else '提供分析数据'