340 lines
9.7 KiB
Python
340 lines
9.7 KiB
Python
"""
|
||
技能管理器
|
||
管理所有技能的注册、发现和调用
|
||
"""
|
||
import asyncio
|
||
from typing import Dict, Optional, List, Type, Any
|
||
from app.skills.base import BaseSkill
|
||
from app.utils.logger import logger
|
||
|
||
|
||
class SkillManager:
|
||
"""技能管理器"""
|
||
|
||
def __init__(self):
|
||
"""初始化技能管理器"""
|
||
self._skills: Dict[str, BaseSkill] = {}
|
||
logger.info("技能管理器初始化")
|
||
|
||
def register(self, skill: BaseSkill) -> bool:
|
||
"""
|
||
注册技能
|
||
|
||
Args:
|
||
skill: 技能实例
|
||
|
||
Returns:
|
||
是否成功
|
||
"""
|
||
if not skill.name:
|
||
logger.error("技能名称不能为空")
|
||
return False
|
||
|
||
if skill.name in self._skills:
|
||
logger.warning(f"技能已存在,将被覆盖: {skill.name}")
|
||
|
||
self._skills[skill.name] = skill
|
||
logger.info(f"技能注册成功: {skill.name}")
|
||
return True
|
||
|
||
def unregister(self, skill_name: str) -> bool:
|
||
"""
|
||
注销技能
|
||
|
||
Args:
|
||
skill_name: 技能名称
|
||
|
||
Returns:
|
||
是否成功
|
||
"""
|
||
if skill_name in self._skills:
|
||
del self._skills[skill_name]
|
||
logger.info(f"技能注销成功: {skill_name}")
|
||
return True
|
||
|
||
logger.warning(f"技能不存在: {skill_name}")
|
||
return False
|
||
|
||
def get_skill(self, skill_name: str) -> Optional[BaseSkill]:
|
||
"""
|
||
获取技能
|
||
|
||
Args:
|
||
skill_name: 技能名称
|
||
|
||
Returns:
|
||
技能实例或None
|
||
"""
|
||
return self._skills.get(skill_name)
|
||
|
||
def get_all_skills(self) -> List[BaseSkill]:
|
||
"""
|
||
获取所有技能
|
||
|
||
Returns:
|
||
技能列表
|
||
"""
|
||
return list(self._skills.values())
|
||
|
||
def get_enabled_skills(self) -> List[BaseSkill]:
|
||
"""
|
||
获取所有启用的技能
|
||
|
||
Returns:
|
||
启用的技能列表
|
||
"""
|
||
return [skill for skill in self._skills.values() if skill.enabled]
|
||
|
||
async def execute_skill(self, skill_name: str, **kwargs) -> Dict:
|
||
"""
|
||
执行技能
|
||
|
||
Args:
|
||
skill_name: 技能名称
|
||
**kwargs: 技能参数
|
||
|
||
Returns:
|
||
执行结果
|
||
"""
|
||
skill = self.get_skill(skill_name)
|
||
|
||
if not skill:
|
||
logger.error(f"❌ 技能不存在: {skill_name}")
|
||
return {
|
||
"success": False,
|
||
"error": f"技能不存在: {skill_name}"
|
||
}
|
||
|
||
if not skill.enabled:
|
||
logger.warning(f"⚠️ 技能已禁用: {skill_name}")
|
||
return {
|
||
"success": False,
|
||
"error": f"技能已禁用: {skill_name}"
|
||
}
|
||
|
||
# 验证参数
|
||
valid, error = skill.validate_params(**kwargs)
|
||
if not valid:
|
||
logger.error(f"❌ 技能参数验证失败 {skill_name}: {error}")
|
||
return {
|
||
"success": False,
|
||
"error": error
|
||
}
|
||
|
||
# 执行技能
|
||
try:
|
||
logger.info(f"🚀 开始执行技能: {skill_name}, 参数: {kwargs}")
|
||
result = await skill.execute(**kwargs)
|
||
logger.info(f"✅ 技能执行成功: {skill_name}")
|
||
return {
|
||
"success": True,
|
||
"data": result
|
||
}
|
||
except Exception as e:
|
||
logger.error(f"❌ 技能执行失败 {skill_name}: {e}")
|
||
return {
|
||
"success": False,
|
||
"error": str(e)
|
||
}
|
||
|
||
def enable_skill(self, skill_name: str) -> bool:
|
||
"""
|
||
启用技能
|
||
|
||
Args:
|
||
skill_name: 技能名称
|
||
|
||
Returns:
|
||
是否成功
|
||
"""
|
||
skill = self.get_skill(skill_name)
|
||
if skill:
|
||
skill.enable()
|
||
logger.info(f"技能已启用: {skill_name}")
|
||
return True
|
||
return False
|
||
|
||
def disable_skill(self, skill_name: str) -> bool:
|
||
"""
|
||
禁用技能
|
||
|
||
Args:
|
||
skill_name: 技能名称
|
||
|
||
Returns:
|
||
是否成功
|
||
"""
|
||
skill = self.get_skill(skill_name)
|
||
if skill:
|
||
skill.disable()
|
||
logger.info(f"技能已禁用: {skill_name}")
|
||
return True
|
||
return False
|
||
|
||
def get_skills_info(self) -> List[Dict]:
|
||
"""
|
||
获取所有技能信息
|
||
|
||
Returns:
|
||
技能信息列表
|
||
"""
|
||
return [skill.get_info() for skill in self._skills.values()]
|
||
|
||
async def execute_plan(
|
||
self,
|
||
plan: Dict[str, Any],
|
||
stock_code: str
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
执行技能规划
|
||
|
||
Args:
|
||
plan: 技能执行计划(来自SkillPlanner)
|
||
stock_code: 股票代码
|
||
|
||
Returns:
|
||
{
|
||
'results': {
|
||
'market_data': {...},
|
||
'technical_analysis': {...},
|
||
...
|
||
},
|
||
'execution_time': float,
|
||
'errors': List[str]
|
||
}
|
||
"""
|
||
import time
|
||
start_time = time.time()
|
||
|
||
skills = plan.get('skills', [])
|
||
strategy = plan.get('execution_strategy', 'parallel')
|
||
|
||
logger.info(f"开始执行技能规划: {len(skills)}个技能, 策略: {strategy}")
|
||
|
||
if strategy == 'parallel':
|
||
results = await self._execute_parallel(skills, stock_code)
|
||
else:
|
||
results = await self._execute_sequential(skills, stock_code)
|
||
|
||
execution_time = time.time() - start_time
|
||
logger.info(f"技能规划执行完成,耗时: {execution_time:.2f}秒")
|
||
|
||
return {
|
||
'results': results['results'],
|
||
'execution_time': execution_time,
|
||
'errors': results['errors']
|
||
}
|
||
|
||
async def _execute_parallel(
|
||
self,
|
||
skills: List[Dict[str, Any]],
|
||
stock_code: str
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
并行执行技能(按优先级分组)
|
||
|
||
Args:
|
||
skills: 技能列表
|
||
stock_code: 股票代码
|
||
|
||
Returns:
|
||
执行结果
|
||
"""
|
||
# 按优先级分组
|
||
priority_groups = {}
|
||
for skill_info in skills:
|
||
priority = skill_info['priority']
|
||
if priority not in priority_groups:
|
||
priority_groups[priority] = []
|
||
priority_groups[priority].append(skill_info)
|
||
|
||
all_results = {}
|
||
all_errors = []
|
||
|
||
# 按优先级顺序执行
|
||
for priority in sorted(priority_groups.keys()):
|
||
skill_group = priority_groups[priority]
|
||
logger.info(f"📋 执行优先级 {priority} 的技能: {[s['name'] for s in skill_group]}")
|
||
|
||
# 同一优先级的技能并行执行
|
||
tasks = []
|
||
for skill_info in skill_group:
|
||
params = skill_info['params'].copy()
|
||
params['stock_code'] = stock_code
|
||
logger.info(f" ➡️ 准备执行技能: {skill_info['name']}, 原因: {skill_info.get('reason', '未知')}")
|
||
task = self.execute_skill(skill_info['name'], **params)
|
||
tasks.append((skill_info['name'], task))
|
||
|
||
# 等待所有任务完成
|
||
results = await asyncio.gather(*[task for _, task in tasks], return_exceptions=True)
|
||
|
||
# 处理结果
|
||
for (skill_name, _), result in zip(tasks, results):
|
||
if isinstance(result, Exception):
|
||
logger.error(f"❌ 技能执行异常: {skill_name}, {result}")
|
||
all_errors.append(f"{skill_name}: {str(result)}")
|
||
all_results[skill_name] = {'error': str(result)}
|
||
elif result.get('success'):
|
||
# 不在这里记录成功日志,因为 execute_skill 已经记录了
|
||
all_results[skill_name] = result.get('data', {})
|
||
else:
|
||
error_msg = result.get('error', '未知错误')
|
||
logger.error(f"❌ 技能执行失败: {skill_name}, {error_msg}")
|
||
all_errors.append(f"{skill_name}: {error_msg}")
|
||
all_results[skill_name] = {'error': error_msg}
|
||
|
||
return {
|
||
'results': all_results,
|
||
'errors': all_errors
|
||
}
|
||
|
||
async def _execute_sequential(
|
||
self,
|
||
skills: List[Dict[str, Any]],
|
||
stock_code: str
|
||
) -> Dict[str, Any]:
|
||
"""
|
||
串行执行技能
|
||
|
||
Args:
|
||
skills: 技能列表
|
||
stock_code: 股票代码
|
||
|
||
Returns:
|
||
执行结果
|
||
"""
|
||
all_results = {}
|
||
all_errors = []
|
||
|
||
for skill_info in skills:
|
||
skill_name = skill_info['name']
|
||
params = skill_info['params'].copy()
|
||
params['stock_code'] = stock_code
|
||
|
||
logger.info(f"执行技能: {skill_name}")
|
||
|
||
try:
|
||
result = await self.execute_skill(skill_name, **params)
|
||
|
||
if result.get('success'):
|
||
all_results[skill_name] = result.get('data', {})
|
||
else:
|
||
error_msg = result.get('error', '未知错误')
|
||
logger.error(f"技能执行失败: {skill_name}, {error_msg}")
|
||
all_errors.append(f"{skill_name}: {error_msg}")
|
||
all_results[skill_name] = {'error': error_msg}
|
||
|
||
except Exception as e:
|
||
logger.error(f"技能执行异常: {skill_name}, {e}")
|
||
all_errors.append(f"{skill_name}: {str(e)}")
|
||
all_results[skill_name] = {'error': str(e)}
|
||
|
||
return {
|
||
'results': all_results,
|
||
'errors': all_errors
|
||
}
|
||
|
||
|
||
# 创建全局技能管理器实例
|
||
skill_manager = SkillManager()
|