stock-ai-agent/backend/app/agent/skill_manager.py
2026-02-03 23:50:48 +08:00

333 lines
9.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
技能管理器
管理所有技能的注册、发现和调用
"""
import asyncio
from typing import Dict, Optional, List, Type, Any
from app.skills.base import BaseSkill
from app.utils.logger import logger
class SkillManager:
    """Registry and dispatcher for skills.

    Skills are stored by ``skill.name``. A single skill is run with
    ``execute_skill``; a multi-skill plan (as produced by the skill planner)
    is run with ``execute_plan``. Failures raised while validating or
    executing a skill are reported as ``{"success": False, "error": ...}``
    dicts (or per-skill ``{"error": ...}`` entries in a plan result) instead
    of being propagated to the caller.
    """

    def __init__(self):
        """Initialize an empty skill registry."""
        self._skills: Dict[str, BaseSkill] = {}
        logger.info("技能管理器初始化")

    def register(self, skill: BaseSkill) -> bool:
        """Register a skill under its ``name``.

        A skill already registered under the same name is replaced (a warning
        is logged).

        Args:
            skill: Skill instance to register.

        Returns:
            True on success; False if the skill's name is empty.
        """
        if not skill.name:
            logger.error("技能名称不能为空")
            return False
        if skill.name in self._skills:
            logger.warning(f"技能已存在,将被覆盖: {skill.name}")
        self._skills[skill.name] = skill
        logger.info(f"技能注册成功: {skill.name}")
        return True

    def unregister(self, skill_name: str) -> bool:
        """Remove a registered skill.

        Args:
            skill_name: Name the skill was registered under.

        Returns:
            True if the skill was removed; False if it was not registered.
        """
        if skill_name not in self._skills:
            logger.warning(f"技能不存在: {skill_name}")
            return False
        del self._skills[skill_name]
        logger.info(f"技能注销成功: {skill_name}")
        return True

    def get_skill(self, skill_name: str) -> Optional[BaseSkill]:
        """Return the skill registered as ``skill_name``, or None."""
        return self._skills.get(skill_name)

    def get_all_skills(self) -> List[BaseSkill]:
        """Return a new list of every registered skill."""
        return list(self._skills.values())

    def get_enabled_skills(self) -> List[BaseSkill]:
        """Return a new list of the registered skills whose ``enabled`` is truthy."""
        return [skill for skill in self._skills.values() if skill.enabled]

    async def execute_skill(self, skill_name: str, **kwargs) -> Dict:
        """Validate and run one skill.

        Args:
            skill_name: Name of the registered skill to run.
            **kwargs: Parameters passed to both ``validate_params`` and
                ``execute``.

        Returns:
            ``{"success": True, "data": <execute() result>}`` on success,
            otherwise ``{"success": False, "error": <message>}`` when the skill
            is unknown, disabled, fails validation, or raises an ``Exception``.
        """
        skill = self.get_skill(skill_name)
        if skill is None:
            return {"success": False, "error": f"技能不存在: {skill_name}"}
        if not skill.enabled:
            return {"success": False, "error": f"技能已禁用: {skill_name}"}

        try:
            # Validation runs inside the try so a validator that raises is
            # reported through the same error dict as a failing execute().
            valid, error = skill.validate_params(**kwargs)
            if not valid:
                return {"success": False, "error": error}
            result = await skill.execute(**kwargs)
        except Exception as e:
            logger.error(f"技能执行失败 {skill_name}: {e}")
            # Fall back to the exception type when its message is empty.
            return {"success": False, "error": str(e) or type(e).__name__}
        return {"success": True, "data": result}

    def enable_skill(self, skill_name: str) -> bool:
        """Call ``enable()`` on a registered skill.

        Returns:
            True if the skill exists; False otherwise.
        """
        skill = self.get_skill(skill_name)
        if skill is None:
            return False
        skill.enable()
        logger.info(f"技能已启用: {skill_name}")
        return True

    def disable_skill(self, skill_name: str) -> bool:
        """Call ``disable()`` on a registered skill.

        Returns:
            True if the skill exists; False otherwise.
        """
        skill = self.get_skill(skill_name)
        if skill is None:
            return False
        skill.disable()
        logger.info(f"技能已禁用: {skill_name}")
        return True

    def get_skills_info(self) -> List[Dict]:
        """Return ``skill.get_info()`` for every registered skill."""
        return [skill.get_info() for skill in self._skills.values()]

    async def execute_plan(
        self,
        plan: Dict[str, Any],
        stock_code: str
    ) -> Dict[str, Any]:
        """Execute a skill plan (as produced by SkillPlanner).

        Args:
            plan: Mapping with ``'skills'``, a list of entries that each carry
                ``'name'``, an optional ``'params'`` dict and, for the
                parallel strategy, ``'priority'``. It may also carry
                ``'execution_strategy'``: ``'parallel'`` is the default, and
                any other value runs the skills sequentially.
            stock_code: Injected into every skill's params as ``stock_code``.

        Returns:
            ``{'results': {skill_name: data or {'error': msg}},
               'execution_time': elapsed seconds (float),
               'errors': ["<skill_name>: <msg>", ...]}``
        """
        import time

        # perf_counter is monotonic: the measured duration cannot jump or go
        # negative if the wall clock is adjusted mid-run (time.time() can).
        start_time = time.perf_counter()
        # Tolerate a missing *or* null 'skills' entry.
        skills = plan.get('skills') or []
        strategy = plan.get('execution_strategy', 'parallel')
        logger.info(f"开始执行技能规划: {len(skills)}个技能, 策略: {strategy}")

        if strategy == 'parallel':
            outcome = await self._execute_parallel(skills, stock_code)
        else:
            outcome = await self._execute_sequential(skills, stock_code)

        execution_time = time.perf_counter() - start_time
        logger.info(f"技能规划执行完成,耗时: {execution_time:.2f}")
        return {
            'results': outcome['results'],
            'execution_time': execution_time,
            'errors': outcome['errors'],
        }

    @staticmethod
    def _build_params(skill_info: Dict[str, Any], stock_code: str) -> Dict[str, Any]:
        """Shallow-copy a plan entry's params (missing/None -> {}) and set stock_code."""
        params = dict(skill_info.get('params') or {})
        # The plan-level stock code always wins over any value in the entry.
        params['stock_code'] = stock_code
        return params

    @staticmethod
    def _record_outcome(
        skill_name: str,
        outcome: Dict[str, Any],
        all_results: Dict[str, Any],
        all_errors: List[str]
    ) -> None:
        """Merge one execute_skill() result dict into the aggregated results/errors."""
        if outcome.get('success'):
            all_results[skill_name] = outcome.get('data', {})
            return
        error_msg = outcome.get('error', '未知错误')
        logger.error(f"技能执行失败: {skill_name}, {error_msg}")
        all_errors.append(f"{skill_name}: {error_msg}")
        all_results[skill_name] = {'error': error_msg}

    async def _execute_parallel(
        self,
        skills: List[Dict[str, Any]],
        stock_code: str
    ) -> Dict[str, Any]:
        """Run skills grouped by ``'priority'``.

        Groups run one after another in ascending priority order; skills
        within one group run concurrently via ``asyncio.gather``.

        Args:
            skills: Plan entries; each must contain ``'name'`` and ``'priority'``.
            stock_code: Injected into every skill's params.

        Returns:
            ``{'results': {...}, 'errors': [...]}``. If a skill name appears
            more than once, its later outcome overwrites the earlier one.
        """
        # Group entries by priority (grouping happens before any coroutine is
        # created, so a malformed entry cannot leave coroutines un-awaited).
        priority_groups: Dict[Any, List[Dict[str, Any]]] = {}
        for skill_info in skills:
            priority_groups.setdefault(skill_info['priority'], []).append(skill_info)

        all_results: Dict[str, Any] = {}
        all_errors: List[str] = []

        for priority in sorted(priority_groups):
            skill_group = priority_groups[priority]
            names = [s['name'] for s in skill_group]
            logger.info(f"执行优先级 {priority} 的技能: {names}")

            param_sets = [self._build_params(s, stock_code) for s in skill_group]
            outcomes = await asyncio.gather(
                *(self.execute_skill(name, **params)
                  for name, params in zip(names, param_sets)),
                return_exceptions=True,
            )

            for skill_name, outcome in zip(names, outcomes):
                # With return_exceptions=True, raised errors come back as
                # values. CancelledError derives from BaseException (not
                # Exception) since Python 3.8, so it must be matched here or
                # it would fall through to outcome.get() and crash.
                if isinstance(outcome, BaseException):
                    error_msg = str(outcome) or type(outcome).__name__
                    logger.error(f"技能执行失败: {skill_name}, {outcome}")
                    all_errors.append(f"{skill_name}: {error_msg}")
                    all_results[skill_name] = {'error': error_msg}
                else:
                    self._record_outcome(skill_name, outcome, all_results, all_errors)

        return {
            'results': all_results,
            'errors': all_errors
        }

    async def _execute_sequential(
        self,
        skills: List[Dict[str, Any]],
        stock_code: str
    ) -> Dict[str, Any]:
        """Run skills one at a time, in list order.

        Args:
            skills: Plan entries; each must contain ``'name'``.
            stock_code: Injected into every skill's params.

        Returns:
            ``{'results': {...}, 'errors': [...]}``.
        """
        all_results: Dict[str, Any] = {}
        all_errors: List[str] = []

        for skill_info in skills:
            skill_name = skill_info['name']
            params = self._build_params(skill_info, stock_code)
            logger.info(f"执行技能: {skill_name}")
            try:
                outcome = await self.execute_skill(skill_name, **params)
            except Exception as e:
                # Safety net: execute_skill reports validation/execution
                # failures itself; this only catches errors raised outside
                # its try block (e.g. while reading ``skill.enabled``).
                error_msg = str(e) or type(e).__name__
                logger.error(f"技能执行异常: {skill_name}, {e}")
                all_errors.append(f"{skill_name}: {error_msg}")
                all_results[skill_name] = {'error': error_msg}
            else:
                self._record_outcome(skill_name, outcome, all_results, all_errors)

        return {
            'results': all_results,
            'errors': all_errors
        }
# Module-level (global) SkillManager instance, created once when this module
# is first imported; constructing it emits the manager's initialization log.
skill_manager: SkillManager = SkillManager()