stock-ai-agent/backend/app/utils/indicators.py
2026-02-03 10:08:15 +08:00

159 lines
3.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Technical indicator calculation module.
Provides functions for computing commonly used technical indicators
(moving averages, MACD, RSI, KDJ, Bollinger Bands, volume MA) on pandas Series.
"""
import pandas as pd
import numpy as np
from typing import Tuple
def calculate_ma(data: pd.Series, period: int = 5) -> pd.Series:
    """
    Compute the simple moving average (MA).

    Args:
        data: Price series.
        period: Window length, in number of observations.

    Returns:
        Rolling mean aligned with ``data``'s index. Entries whose window is
        not yet full (at least the first ``period - 1``) are NaN.
    """
    window = data.rolling(period)
    return window.mean()
def calculate_ema(data: pd.Series, period: int = 12) -> pd.Series:
    """
    Compute the exponential moving average (EMA).

    Args:
        data: Price series.
        period: EMA span; the smoothing factor is ``2 / (period + 1)``.

    Returns:
        EMA series computed recursively (``adjust=False``), seeded with the
        first available value of ``data``.
    """
    smoother = data.ewm(span=period, adjust=False)
    return smoother.mean()
def calculate_macd(
    data: pd.Series,
    fast_period: int = 12,
    slow_period: int = 26,
    signal_period: int = 9
) -> Tuple[pd.Series, pd.Series, pd.Series]:
    """
    Compute the MACD indicator.

    Args:
        data: Price series.
        fast_period: Span of the fast EMA.
        slow_period: Span of the slow EMA.
        signal_period: Span of the EMA applied to DIF to form DEA (signal line).

    Returns:
        Tuple ``(DIF, DEA, MACD histogram)`` where
        DIF = EMA(data, fast) - EMA(data, slow),
        DEA = EMA(DIF, signal_period), and
        histogram = 2 * (DIF - DEA).
    """
    dif = calculate_ema(data, fast_period) - calculate_ema(data, slow_period)
    # The signal line uses the same recursive EMA (span, adjust=False) as DIF's inputs.
    dea = calculate_ema(dif, signal_period)
    # The histogram is the DIF/DEA spread scaled by 2.
    histogram = 2 * (dif - dea)
    return dif, dea, histogram
def calculate_rsi(data: pd.Series, period: int = 14) -> pd.Series:
    """
    Compute the Relative Strength Index (RSI).

    Average gain and average loss are simple rolling means over ``period``
    price changes (not Wilder's exponential smoothing).

    Args:
        data: Price series.
        period: Number of price changes in each averaging window.

    Returns:
        RSI values in [0, 100]. The first ``period`` entries are NaN, because
        ``period`` price changes need ``period + 1`` prices. A window with
        gains but no losses gives 100. A window with neither gains nor losses
        (flat prices) gives 0 / 0, which is NaN.
    """
    delta = data.diff()
    # clip() keeps the leading NaN produced by diff(). Replacing it with 0
    # would let the first window be emitted with only period - 1 real changes.
    gain = delta.clip(lower=0).rolling(window=period).mean()
    loss = (-delta).clip(lower=0).rolling(window=period).mean()
    rs = gain / loss
    # loss == 0 with gain > 0 gives rs = inf, which maps to RSI = 100.
    rsi = 100 - (100 / (1 + rs))
    return rsi
def calculate_kdj(
    high: pd.Series,
    low: pd.Series,
    close: pd.Series,
    period: int = 9,
    m1: int = 3,
    m2: int = 3
) -> Tuple[pd.Series, pd.Series, pd.Series]:
    """
    Compute the KDJ (stochastic) indicator.

    Args:
        high: High prices.
        low: Low prices.
        close: Close prices.
        period: Look-back window for the highest high / lowest low.
        m1: Smoothing parameter for K (alpha = 1 / m1).
        m2: Smoothing parameter for D (alpha = 1 / m2).

    Returns:
        Tuple ``(K, D, J)`` where
        RSV = (close - lowest low) / (highest high - lowest low) * 100,
        K = EMA(RSV), D = EMA(K), and J = 3K - 2D.
        K and D use ``adjust=False``, so each recursion starts from its
        first available input value rather than a fixed seed.
        Note: if the highest high equals the lowest low over a window, RSV
        divides by zero (NaN or inf).
    """
    lowest = low.rolling(window=period).min()
    highest = high.rolling(window=period).max()
    spread = highest - lowest
    rsv = (close - lowest) / spread * 100

    # pandas com = m - 1 corresponds to alpha = 1 / m.
    k_line = rsv.ewm(com=m1 - 1, adjust=False).mean()
    d_line = k_line.ewm(com=m2 - 1, adjust=False).mean()
    j_line = 3 * k_line - 2 * d_line
    return k_line, d_line, j_line
def calculate_boll(
    data: pd.Series,
    period: int = 20,
    std_dev: float = 2.0
) -> Tuple[pd.Series, pd.Series, pd.Series]:
    """
    Compute Bollinger Bands (BOLL).

    Args:
        data: Price series.
        period: Rolling window length for the mean and standard deviation.
        std_dev: Band width, as a multiple of the rolling standard deviation.

    Returns:
        Tuple ``(upper, middle, lower)``. ``middle`` is the rolling mean.
        The bands are ``middle +/- std_dev * rolling std``, using pandas'
        default sample standard deviation (ddof=1).
    """
    window = data.rolling(window=period)
    middle = window.mean()
    band = window.std() * std_dev
    return middle + band, middle, middle - band
def calculate_volume_ma(volume: pd.Series, period: int = 5) -> pd.Series:
    """
    Compute a simple moving average of trading volume.

    Args:
        volume: Volume series.
        period: Window length, in number of observations.

    Returns:
        Rolling mean of ``volume``. Entries whose window is not yet full
        (at least the first ``period - 1``) are NaN.
    """
    rolling_volume = volume.rolling(period)
    averaged = rolling_volume.mean()
    return averaged