""" 技能管理器 管理所有技能的注册、发现和调用 """ import asyncio from typing import Dict, Optional, List, Type, Any from app.skills.base import BaseSkill from app.utils.logger import logger class SkillManager: """技能管理器""" def __init__(self): """初始化技能管理器""" self._skills: Dict[str, BaseSkill] = {} logger.info("技能管理器初始化") def register(self, skill: BaseSkill) -> bool: """ 注册技能 Args: skill: 技能实例 Returns: 是否成功 """ if not skill.name: logger.error("技能名称不能为空") return False if skill.name in self._skills: logger.warning(f"技能已存在,将被覆盖: {skill.name}") self._skills[skill.name] = skill logger.info(f"技能注册成功: {skill.name}") return True def unregister(self, skill_name: str) -> bool: """ 注销技能 Args: skill_name: 技能名称 Returns: 是否成功 """ if skill_name in self._skills: del self._skills[skill_name] logger.info(f"技能注销成功: {skill_name}") return True logger.warning(f"技能不存在: {skill_name}") return False def get_skill(self, skill_name: str) -> Optional[BaseSkill]: """ 获取技能 Args: skill_name: 技能名称 Returns: 技能实例或None """ return self._skills.get(skill_name) def get_all_skills(self) -> List[BaseSkill]: """ 获取所有技能 Returns: 技能列表 """ return list(self._skills.values()) def get_enabled_skills(self) -> List[BaseSkill]: """ 获取所有启用的技能 Returns: 启用的技能列表 """ return [skill for skill in self._skills.values() if skill.enabled] async def execute_skill(self, skill_name: str, **kwargs) -> Dict: """ 执行技能 Args: skill_name: 技能名称 **kwargs: 技能参数 Returns: 执行结果 """ skill = self.get_skill(skill_name) if not skill: return { "success": False, "error": f"技能不存在: {skill_name}" } if not skill.enabled: return { "success": False, "error": f"技能已禁用: {skill_name}" } # 验证参数 valid, error = skill.validate_params(**kwargs) if not valid: return { "success": False, "error": error } # 执行技能 try: result = await skill.execute(**kwargs) return { "success": True, "data": result } except Exception as e: logger.error(f"技能执行失败 {skill_name}: {e}") return { "success": False, "error": str(e) } def enable_skill(self, skill_name: str) -> bool: """ 启用技能 Args: skill_name: 技能名称 Returns: 是否成功 """ skill = self.get_skill(skill_name) if skill: skill.enable() logger.info(f"技能已启用: {skill_name}") return True return False def disable_skill(self, skill_name: str) -> bool: """ 禁用技能 Args: skill_name: 技能名称 Returns: 是否成功 """ skill = self.get_skill(skill_name) if skill: skill.disable() logger.info(f"技能已禁用: {skill_name}") return True return False def get_skills_info(self) -> List[Dict]: """ 获取所有技能信息 Returns: 技能信息列表 """ return [skill.get_info() for skill in self._skills.values()] async def execute_plan( self, plan: Dict[str, Any], stock_code: str ) -> Dict[str, Any]: """ 执行技能规划 Args: plan: 技能执行计划(来自SkillPlanner) stock_code: 股票代码 Returns: { 'results': { 'market_data': {...}, 'technical_analysis': {...}, ... }, 'execution_time': float, 'errors': List[str] } """ import time start_time = time.time() skills = plan.get('skills', []) strategy = plan.get('execution_strategy', 'parallel') logger.info(f"开始执行技能规划: {len(skills)}个技能, 策略: {strategy}") if strategy == 'parallel': results = await self._execute_parallel(skills, stock_code) else: results = await self._execute_sequential(skills, stock_code) execution_time = time.time() - start_time logger.info(f"技能规划执行完成,耗时: {execution_time:.2f}秒") return { 'results': results['results'], 'execution_time': execution_time, 'errors': results['errors'] } async def _execute_parallel( self, skills: List[Dict[str, Any]], stock_code: str ) -> Dict[str, Any]: """ 并行执行技能(按优先级分组) Args: skills: 技能列表 stock_code: 股票代码 Returns: 执行结果 """ # 按优先级分组 priority_groups = {} for skill_info in skills: priority = skill_info['priority'] if priority not in priority_groups: priority_groups[priority] = [] priority_groups[priority].append(skill_info) all_results = {} all_errors = [] # 按优先级顺序执行 for priority in sorted(priority_groups.keys()): skill_group = priority_groups[priority] logger.info(f"执行优先级 {priority} 的技能: {[s['name'] for s in skill_group]}") # 同一优先级的技能并行执行 tasks = [] for skill_info in skill_group: params = skill_info['params'].copy() params['stock_code'] = stock_code task = self.execute_skill(skill_info['name'], **params) tasks.append((skill_info['name'], task)) # 等待所有任务完成 results = await asyncio.gather(*[task for _, task in tasks], return_exceptions=True) # 处理结果 for (skill_name, _), result in zip(tasks, results): if isinstance(result, Exception): logger.error(f"技能执行失败: {skill_name}, {result}") all_errors.append(f"{skill_name}: {str(result)}") all_results[skill_name] = {'error': str(result)} elif result.get('success'): all_results[skill_name] = result.get('data', {}) else: error_msg = result.get('error', '未知错误') logger.error(f"技能执行失败: {skill_name}, {error_msg}") all_errors.append(f"{skill_name}: {error_msg}") all_results[skill_name] = {'error': error_msg} return { 'results': all_results, 'errors': all_errors } async def _execute_sequential( self, skills: List[Dict[str, Any]], stock_code: str ) -> Dict[str, Any]: """ 串行执行技能 Args: skills: 技能列表 stock_code: 股票代码 Returns: 执行结果 """ all_results = {} all_errors = [] for skill_info in skills: skill_name = skill_info['name'] params = skill_info['params'].copy() params['stock_code'] = stock_code logger.info(f"执行技能: {skill_name}") try: result = await self.execute_skill(skill_name, **params) if result.get('success'): all_results[skill_name] = result.get('data', {}) else: error_msg = result.get('error', '未知错误') logger.error(f"技能执行失败: {skill_name}, {error_msg}") all_errors.append(f"{skill_name}: {error_msg}") all_results[skill_name] = {'error': error_msg} except Exception as e: logger.error(f"技能执行异常: {skill_name}, {e}") all_errors.append(f"{skill_name}: {str(e)}") all_results[skill_name] = {'error': str(e)} return { 'results': all_results, 'errors': all_errors } # 创建全局技能管理器实例 skill_manager = SkillManager()