203 lines
6.5 KiB
Python
203 lines
6.5 KiB
Python
"""
技术分析技能 (technical analysis skill)

提供技术指标计算和分析 (provides technical indicator calculation and analysis)
"""
|
||
import pandas as pd
|
||
from typing import Dict, Any
|
||
from app.skills.base import BaseSkill, SkillParameter
|
||
from app.services.tushare_service import tushare_service
|
||
from app.utils.indicators import (
|
||
calculate_ma, calculate_macd, calculate_rsi,
|
||
calculate_kdj, calculate_boll
|
||
)
|
||
from app.utils.logger import logger
|
||
|
||
|
||
class TechnicalAnalysisSkill(BaseSkill):
    """Technical analysis skill.

    Fetches K-line data for one stock through ``tushare_service`` and reports
    the latest value of each requested indicator (MA, MACD, RSI, KDJ, BOLL).
    """

    # Indicators computed when the caller does not request any.
    _DEFAULT_INDICATORS = ("ma", "macd")

    # Columns that, when present, are used to put rows in ascending date order
    # before the "latest" (last-row) indicator values are read.
    _DATE_COLUMNS = ("trade_date", "date")

    def __init__(self):
        super().__init__()
        self.name = "technical_analysis"
        self.description = "计算股票技术指标(MA、MACD、RSI、KDJ、BOLL等)"
        self.parameters = [
            SkillParameter(
                name="stock_code",
                type="string",
                description="股票代码",
                required=True,
            ),
            SkillParameter(
                name="indicators",
                type="array",
                description="要计算的指标列表(ma、macd、rsi、kdj、boll)",
                required=False,
                default=list(self._DEFAULT_INDICATORS),
            ),
            SkillParameter(
                name="period",
                type="integer",
                description="数据周期(天数)",
                required=False,
                default=60,
            ),
        ]

    async def execute(self, **kwargs) -> Dict[str, Any]:
        """Compute technical indicators for one stock.

        Keyword Args:
            stock_code: Stock code forwarded to ``tushare_service.get_kline_data``.
            indicators: Indicator names to compute, any of ``ma``, ``macd``,
                ``rsi``, ``kdj``, ``boll`` (case-insensitive). A comma-separated
                string is also accepted. ``None`` or omission means
                ``["ma", "macd"]``. Unknown names are ignored.
            period: Declared in ``self.parameters`` but not used yet. The
                indicators are computed over whatever rows ``get_kline_data``
                returns.

        Returns:
            ``{"stock_code": ..., "indicators": {name: values, ...}}`` on
            success. ``{"error": message}`` when the stock code is missing, no
            K-line data is found, or an indicator calculation raises.

        Exceptions raised by ``tushare_service.get_kline_data`` itself are not
        caught here and propagate to the caller.
        """
        stock_code = kwargs.get("stock_code")
        indicators = kwargs.get("indicators", list(self._DEFAULT_INDICATORS))
        # NOTE(review): ``period`` is read by no code path. Wire it into the data
        # fetch (or trim the frame) once the intended semantics are confirmed.

        if not stock_code:
            return {"error": "缺少参数: stock_code"}

        logger.info(f"技术分析: {stock_code}, 指标: {indicators}")

        # Fetch K-line data.
        # NOTE(review): this is a synchronous call inside a coroutine, so it
        # blocks the event loop while fetching. Consider asyncio.to_thread if
        # get_kline_data is thread-safe.
        kline_data = tushare_service.get_kline_data(stock_code)

        # Plain truthiness is ambiguous (raises) for a DataFrame, so check that case explicitly.
        if isinstance(kline_data, pd.DataFrame):
            no_data = kline_data.empty
        else:
            no_data = not kline_data
        if no_data:
            return {"error": f"未找到K线数据: {stock_code}"}

        df = pd.DataFrame(kline_data)
        result: Dict[str, Any] = {"stock_code": stock_code, "indicators": {}}

        try:
            wanted = self._normalize_indicators(indicators)
            df = self._sort_by_date(df)

            # Fixed order keeps the output key order stable regardless of the
            # order in which the caller listed the indicators.
            calculators = (
                ("ma", self._calculate_ma),
                ("macd", self._calculate_macd),
                ("rsi", self._calculate_rsi),
                ("kdj", self._calculate_kdj),
                ("boll", self._calculate_boll),
            )
            for name, calculate in calculators:
                if name in wanted:
                    result["indicators"][name] = calculate(df)

            return result

        except Exception as e:
            logger.error(f"技术指标计算失败: {e}")
            return {"error": f"技术指标计算失败: {str(e)}"}

    @classmethod
    def _normalize_indicators(cls, indicators: Any) -> set:
        """Return the requested indicator names as a lower-cased set.

        ``None`` selects the defaults. A string is split on commas, so
        ``"ma,macd"`` works and ``"macd"`` no longer also matches ``"ma"``
        through substring search.
        """
        if indicators is None:
            indicators = cls._DEFAULT_INDICATORS
        elif isinstance(indicators, str):
            indicators = indicators.split(",")
        return {str(name).strip().lower() for name in indicators}

    @classmethod
    def _sort_by_date(cls, df: pd.DataFrame) -> pd.DataFrame:
        """Return *df* in ascending date order when a known date column exists.

        Every ``_calculate_*`` method reports ``iloc[-1]`` as the latest value,
        which is only correct for oldest-first rows. If the frame is already
        ascending, this is a no-op. If no known date column exists, the frame
        is returned unchanged.

        NOTE(review): Tushare's daily-bar endpoints are documented as returning
        newest-first. Confirm whether ``get_kline_data`` already re-sorts.
        """
        for column in cls._DATE_COLUMNS:
            if column in df.columns:
                # mergesort is stable, so rows sharing a date keep their order.
                return df.sort_values(column, kind="mergesort").reset_index(drop=True)
        return df

    @staticmethod
    def _latest(series: pd.Series) -> Any:
        """Return the last value of *series* rounded to 2 decimals, or ``None``.

        ``None`` is returned for an empty series or a trailing NaN. Indicator
        helpers presumably emit NaN when there is not enough history. A
        genuine ``0.0`` is kept, unlike a truthiness test. The value is
        converted to a plain ``float`` so the result is JSON-serialisable.
        """
        if series.empty:
            return None
        value = series.iloc[-1]
        if pd.isna(value):
            return None
        return round(float(value), 2)

    def _calculate_ma(self, df: pd.DataFrame, windows: tuple = (5, 10, 20, 60)) -> Dict[str, Any]:
        """Return the latest moving averages of the ``close`` column.

        Args:
            df: K-line frame with a ``close`` column.
            windows: Window lengths in rows. Each produces a ``ma{window}`` key.
                Defaults to 5/10/20/60.
        """
        close = df["close"]
        result: Dict[str, Any] = {
            f"ma{window}": self._latest(calculate_ma(close, window))
            for window in windows
        }
        result["description"] = "移动平均线"
        return result

    def _calculate_macd(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Return the latest MACD ``dif``/``dea``/``macd`` values of ``close``."""
        dif, dea, macd = calculate_macd(df["close"])
        return {
            "dif": self._latest(dif),
            "dea": self._latest(dea),
            "macd": self._latest(macd),
            "description": "MACD指标",
        }

    def _calculate_rsi(self, df: pd.DataFrame, periods: tuple = (6, 12, 24)) -> Dict[str, Any]:
        """Return the latest RSI values of the ``close`` column.

        Args:
            df: K-line frame with a ``close`` column.
            periods: RSI look-back lengths. Each produces a ``rsi{period}`` key.
                Defaults to 6/12/24.
        """
        close = df["close"]
        result: Dict[str, Any] = {
            f"rsi{period}": self._latest(calculate_rsi(close, period))
            for period in periods
        }
        result["description"] = "相对强弱指标"
        return result

    def _calculate_kdj(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Return the latest K/D/J values from ``high``, ``low`` and ``close``."""
        k, d, j = calculate_kdj(df["high"], df["low"], df["close"])
        return {
            "k": self._latest(k),
            "d": self._latest(d),
            "j": self._latest(j),
            "description": "KDJ指标",
        }

    def _calculate_boll(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Return the latest Bollinger band (upper/middle/lower) values of ``close``."""
        upper, middle, lower = calculate_boll(df["close"])
        return {
            "upper": self._latest(upper),
            "middle": self._latest(middle),
            "lower": self._latest(lower),
            "description": "布林带",
        }
|