# stock-ai-agent/backend/app/skills/technical_analysis.py
# Snapshot: 2026-02-03 10:08:15 +08:00 — 203 lines, 6.5 KiB, Python
# NOTE: the source viewer flagged this file as containing ambiguous Unicode
# characters (characters that might be confused with other characters).
# If a string literal behaves unexpectedly, check it for look-alike or
# stripped characters.

"""
技术分析技能
提供技术指标计算和分析
"""
import pandas as pd
from typing import Dict, Any
from app.skills.base import BaseSkill, SkillParameter
from app.services.tushare_service import tushare_service
from app.utils.indicators import (
calculate_ma, calculate_macd, calculate_rsi,
calculate_kdj, calculate_boll
)
from app.utils.logger import logger
class TechnicalAnalysisSkill(BaseSkill):
    """Skill that computes technical indicators (MA, MACD, RSI, KDJ, BOLL) for a stock.

    For each requested indicator only the most recent value of every series is
    reported, rounded to two decimals (or None when it is unavailable).
    """

    def __init__(self):
        super().__init__()
        self.name = "technical_analysis"
        self.description = "计算股票技术指标MA、MACD、RSI、KDJ、BOLL等"
        self.parameters = [
            SkillParameter(
                name="stock_code",
                type="string",
                description="股票代码",
                required=True
            ),
            SkillParameter(
                name="indicators",
                type="array",
                description="要计算的指标列表ma、macd、rsi、kdj、boll",
                required=False,
                default=["ma", "macd"]
            ),
            SkillParameter(
                name="period",
                type="integer",
                description="数据周期(天数)",
                required=False,
                default=60
            )
        ]

    async def execute(self, **kwargs) -> Dict[str, Any]:
        """Run technical analysis for one stock.

        Args:
            stock_code: Stock code passed to ``tushare_service.get_kline_data``.
            indicators: Indicator names to compute; any of "ma", "macd", "rsi",
                "kdj", "boll". Accepts a list/tuple/set or a comma-separated
                string; matching is case-insensitive. Defaults to
                ["ma", "macd"] when omitted or None.
            period: Declared as the data window in days (default 60).
                NOTE(review): it is currently NOT applied — indicators are
                computed over every row returned by ``get_kline_data``.
                Confirm the intended semantics before wiring it in; trimming
                rows would change EMA-based values such as MACD.

        Returns:
            ``{"stock_code": ..., "indicators": {name: {...}}}`` on success, or
            ``{"error": message}`` when no K-line data is found or any
            calculation fails.
        """
        stock_code = kwargs.get("stock_code")
        indicators = self._normalize_indicators(kwargs.get("indicators"))
        logger.info(f"技术分析: {stock_code}, 指标: {indicators}")

        # Fetch K-line (candlestick) rows for the stock.
        kline_data = tushare_service.get_kline_data(stock_code)
        # Explicit None/len check: `not kline_data` would raise ValueError if
        # the service ever returned a DataFrame.
        if kline_data is None or len(kline_data) == 0:
            return {
                "error": f"未找到K线数据: {stock_code}"
            }

        df = pd.DataFrame(kline_data)
        result: Dict[str, Any] = {
            "stock_code": stock_code,
            "indicators": {}
        }

        # Fixed order so the output key order does not depend on how the
        # caller ordered the requested indicators.
        calculators = (
            ("ma", self._calculate_ma),
            ("macd", self._calculate_macd),
            ("rsi", self._calculate_rsi),
            ("kdj", self._calculate_kdj),
            ("boll", self._calculate_boll),
        )

        try:
            # Every calculator treats the LAST row as the latest value, so rows
            # must be in ascending date order. NOTE(review): Tushare daily
            # endpoints typically return newest-first; sorting on trade_date
            # (when present) is a no-op for already-ascending data — confirm
            # what tushare_service.get_kline_data actually returns.
            if "trade_date" in df.columns:
                df = df.sort_values("trade_date", kind="mergesort").reset_index(drop=True)

            for name, calculate in calculators:
                if name in indicators:
                    result["indicators"][name] = calculate(df)
            return result
        except Exception as e:
            # Skill boundary: report the failure to the caller instead of raising.
            logger.error(f"技术指标计算失败: {e}")
            return {
                "error": f"技术指标计算失败: {str(e)}"
            }

    @staticmethod
    def _normalize_indicators(raw: Any) -> list:
        """Turn the raw ``indicators`` argument into a list of lowercase names.

        None -> the default ["ma", "macd"]; a string is split on commas (so
        "macd" is not mistaken for containing "ma" via substring matching);
        lists/tuples/sets are used as-is; any other value is wrapped in a list.
        """
        if raw is None:
            return ["ma", "macd"]
        if isinstance(raw, str):
            items = raw.split(",")
        elif isinstance(raw, (list, tuple, set)):
            items = raw
        else:
            items = [raw]
        return [str(item).strip().lower() for item in items]

    @staticmethod
    def _latest_value(series: pd.Series) -> Any:
        """Return the last element of *series* rounded to 2 decimals, or None.

        None is returned when the series is None/empty or its last value is
        missing (NaN) — presumably the case when there are fewer rows than the
        indicator's look-back window. Mapping NaN to None keeps the result
        JSON-serializable, and the explicit check keeps a legitimate 0.0
        instead of discarding it as falsy.
        """
        if series is None or len(series) == 0:
            return None
        value = series.iloc[-1]
        if pd.isna(value):
            return None
        return round(float(value), 2)

    def _calculate_ma(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Latest 5/10/20/60-period moving averages of the close price."""
        close = df['close']
        values: Dict[str, Any] = {
            f"ma{window}": self._latest_value(calculate_ma(close, window))
            for window in (5, 10, 20, 60)
        }
        values["description"] = "移动平均线"
        return values

    def _calculate_macd(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Latest values of the three series returned by calculate_macd (DIF, DEA, MACD)."""
        dif, dea, macd = calculate_macd(df['close'])
        return {
            "dif": self._latest_value(dif),
            "dea": self._latest_value(dea),
            "macd": self._latest_value(macd),
            "description": "MACD指标"
        }

    def _calculate_rsi(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Latest 6/12/24-period RSI of the close price."""
        close = df['close']
        values: Dict[str, Any] = {
            f"rsi{window}": self._latest_value(calculate_rsi(close, window))
            for window in (6, 12, 24)
        }
        values["description"] = "相对强弱指标"
        return values

    def _calculate_kdj(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Latest K, D and J values computed from the high/low/close columns."""
        k, d, j = calculate_kdj(df['high'], df['low'], df['close'])
        return {
            "k": self._latest_value(k),
            "d": self._latest_value(d),
            "j": self._latest_value(j),
            "description": "KDJ指标"
        }

    def _calculate_boll(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Latest upper/middle/lower Bollinger band values of the close price."""
        upper, middle, lower = calculate_boll(df['close'])
        return {
            "upper": self._latest_value(upper),
            "middle": self._latest_value(middle),
            "lower": self._latest_value(lower),
            "description": "布林带"
        }