159 lines
3.2 KiB
Python
159 lines
3.2 KiB
Python
"""
Technical indicator calculation module.

Provides calculation functions for common technical-analysis indicators
computed on pandas Series: MA, EMA, MACD, RSI, KDJ, Bollinger Bands (BOLL)
and volume moving average.
"""
|
||
import pandas as pd
|
||
import numpy as np
|
||
from typing import Tuple
|
||
|
||
|
||
def calculate_ma(data: pd.Series, period: int = 5) -> pd.Series:
    """
    Compute the simple moving average (MA).

    Args:
        data: Price series.
        period: Window length. A full window is required, so the first
            ``period - 1`` entries of the result are NaN.

    Returns:
        Rolling arithmetic mean, aligned with ``data``.
    """
    window = data.rolling(period)
    return window.mean()


def calculate_ema(data: pd.Series, period: int = 12) -> pd.Series:
    """
    Compute the exponential moving average (EMA).

    Uses the recursive form ``ema[t] = a * x[t] + (1 - a) * ema[t-1]`` with
    ``a = 2 / (period + 1)`` (pandas ``span``), seeded with the first
    observation (``adjust=False``).

    Args:
        data: Price series.
        period: EMA span.

    Returns:
        EMA series aligned with ``data``.
    """
    smoother = data.ewm(span=period, adjust=False)
    return smoother.mean()


def calculate_macd(
    data: pd.Series,
    fast_period: int = 12,
    slow_period: int = 26,
    signal_period: int = 9
) -> Tuple[pd.Series, pd.Series, pd.Series]:
    """
    Compute the MACD indicator.

    Args:
        data: Price series.
        fast_period: Span of the fast EMA.
        slow_period: Span of the slow EMA.
        signal_period: Span of the EMA applied to DIF (the signal line).

    Returns:
        Tuple ``(dif, dea, macd_hist)``:
            dif: fast EMA minus slow EMA.
            dea: EMA of ``dif`` (``adjust=False``) over ``signal_period``.
            macd_hist: twice the difference ``dif - dea``.
    """
    dif = calculate_ema(data, fast_period) - calculate_ema(data, slow_period)
    signal_line = dif.ewm(span=signal_period, adjust=False).mean()
    histogram = 2 * (dif - signal_line)
    return dif, signal_line, histogram


def calculate_rsi(data: pd.Series, period: int = 14) -> pd.Series:
    """
    Compute the Relative Strength Index (RSI).

    Average gain and average loss are simple rolling means of the last
    ``period`` price changes (not Wilder's exponential smoothing).

    Args:
        data: Price series.
        period: Number of price changes in each averaging window.

    Returns:
        RSI values in [0, 100], aligned with ``data``.
        - The first ``period`` entries are NaN, because ``period`` price
          changes need ``period + 1`` prices.
        - A window with no losses yields 100.
        - A window with no price movement at all yields NaN (0 / 0).
        - A missing price makes every window containing its change NaN.
    """
    delta = data.diff()

    # clip() keeps the leading NaN produced by diff() (and any interior NaN).
    # Replacing it with 0 would count a fictitious zero change, emitting the
    # first RSI value one bar early from only period - 1 real changes.
    gain = delta.clip(lower=0).rolling(window=period).mean()
    loss = (-delta).clip(lower=0).rolling(window=period).mean()

    # loss == 0 gives rs = inf, which maps to RSI = 100 below.
    rs = gain / loss
    rsi = 100 - (100 / (1 + rs))

    return rsi


def calculate_kdj(
    high: pd.Series,
    low: pd.Series,
    close: pd.Series,
    period: int = 9,
    m1: int = 3,
    m2: int = 3
) -> Tuple[pd.Series, pd.Series, pd.Series]:
    """
    Compute the KDJ (stochastic) indicator.

    RSV is ``(close - lowest low) / (highest high - lowest low) * 100`` over
    the last ``period`` bars. K smooths RSV and D smooths K with pandas
    ``ewm(com=m - 1, adjust=False)``, i.e. ``alpha = 1 / m``. J is ``3K - 2D``.

    Args:
        high: High prices.
        low: Low prices.
        close: Close prices.
        period: Look-back window for the highest high / lowest low.
        m1: Smoothing parameter for K.
        m2: Smoothing parameter for D.

    Returns:
        Tuple ``(K, D, J)``.

    Note:
        No special handling is applied when the highest high equals the
        lowest low within a window. In that case the RSV division is 0/0
        (NaN) or x/0 (+/-inf).
    """
    lowest = low.rolling(period).min()
    highest = high.rolling(period).max()

    rsv = (close - lowest) / (highest - lowest) * 100

    k_com = m1 - 1
    d_com = m2 - 1
    k_line = rsv.ewm(com=k_com, adjust=False).mean()
    d_line = k_line.ewm(com=d_com, adjust=False).mean()
    j_line = 3 * k_line - 2 * d_line

    return k_line, d_line, j_line


def calculate_boll(
    data: pd.Series,
    period: int = 20,
    std_dev: float = 2.0,
    *,
    ddof: int = 1,
) -> Tuple[pd.Series, pd.Series, pd.Series]:
    """
    Compute Bollinger Bands (BOLL).

    Args:
        data: Price series.
        period: Rolling window length.
        std_dev: Band half-width as a multiple of the rolling standard
            deviation.
        ddof: Delta degrees of freedom for the rolling standard deviation.
            The default 1 (sample std) preserves the original behaviour.
            Pass 0 for the population std, e.g. to match implementations
            that define the bands that way.

    Returns:
        Tuple ``(upper, middle, lower)``. ``middle`` is the rolling mean and
        the bands are ``middle +/- std_dev * rolling_std``. The first
        ``period - 1`` entries are NaN.
    """
    window = data.rolling(window=period)
    middle = window.mean()
    std = window.std(ddof=ddof)

    offset = std * std_dev
    upper = middle + offset
    lower = middle - offset

    return upper, middle, lower


def calculate_volume_ma(volume: pd.Series, period: int = 5) -> pd.Series:
    """
    Compute the simple moving average of trading volume.

    Args:
        volume: Volume series.
        period: Window length. A full window is required, so the first
            ``period - 1`` entries are NaN.

    Returns:
        Rolling mean of ``volume``, aligned with the input.
    """
    averaged = volume.rolling(period).mean()
    return averaged