93 lines
2.6 KiB
Python
93 lines
2.6 KiB
Python
"""技术指标计算

基于 pandas 自研所有技术指标,不依赖第三方指标库。

支持:MA, EMA, MACD, RSI, BOLL, 量价关系。
"""
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
|
||
|
||
def calc_ma(series: pd.Series, period: int) -> pd.Series:
    """Simple moving average over a trailing window of ``period`` rows.

    Because ``min_periods=1``, the leading rows (with fewer than ``period``
    observations) are averaged over whatever is available instead of NaN.
    """
    trailing = series.rolling(window=period, min_periods=1)
    return trailing.mean()
|
||
|
||
|
||
def calc_ema(series: pd.Series, period: int) -> pd.Series:
    """Exponential moving average with span ``period``.

    Uses the recursive form (``adjust=False``):
    ``ema[0] = x[0]`` and ``ema[t] = alpha * x[t] + (1 - alpha) * ema[t-1]``,
    where ``alpha = 2 / (period + 1)``.
    """
    smoother = series.ewm(span=period, adjust=False)
    return smoother.mean()
|
||
|
||
|
||
def calc_macd(close: pd.Series, fast: int = 12, slow: int = 26, signal: int = 9) -> dict:
    """MACD indicator built from ``calc_ema``.

    Returns a dict of Series aligned with ``close``:
      - ``"dif"``:  EMA(close, fast) - EMA(close, slow)
      - ``"dea"``:  EMA(dif, signal)
      - ``"macd"``: histogram, ``2 * (dif - dea)``
    """
    dif = calc_ema(close, fast) - calc_ema(close, slow)
    dea = calc_ema(dif, signal)
    return {
        "dif": dif,
        "dea": dea,
        # The histogram is the doubled DIF/DEA gap.
        "macd": 2 * (dif - dea),
    }
|
||
|
||
|
||
def calc_rsi(close: pd.Series, period: int = 14) -> pd.Series:
    """Relative Strength Index on a 0-100 scale.

    Average gain and average loss are simple rolling means over ``period``
    rows, not Wilder's smoothing. Because ``min_periods=1``, early rows use a
    shorter window.

    Windows with gains but no losses yield 100, since RS is infinite.
    Windows with neither gains nor losses (a flat price, and always the first
    row) yield NaN, since RS is undefined.
    """
    delta = close.diff()
    # The first diff is NaN. Both comparisons are False there, so it counts
    # as 0 gain and 0 loss.
    gain = delta.where(delta > 0, 0.0)
    loss = (-delta).where(delta < 0, 0.0)

    avg_gain = gain.rolling(window=period, min_periods=1).mean()
    avg_loss = loss.rolling(window=period, min_periods=1).mean()

    # Avoid division by zero; zero-loss windows are resolved below.
    rs = avg_gain / avg_loss.replace(0, np.nan)
    rsi = 100 - (100 / (1 + rs))
    # Gains only: RS -> infinity, so RSI is 100 by definition rather than NaN.
    return rsi.mask((avg_loss == 0) & (avg_gain > 0), 100.0)
|
||
|
||
|
||
def calc_boll(
    close: pd.Series,
    period: int = 20,
    num_std: float = 2.0,
    ddof: int = 1,
) -> dict:
    """Bollinger Bands.

    Args:
        close: price series.
        period: rolling window length for the middle band and the std.
        num_std: band half-width, in standard deviations.
        ddof: delta degrees of freedom for the rolling std. The default 1
            (sample std) preserves the previous behaviour; pass 0 for the
            population std.

    Returns:
        dict of Series aligned with ``close``: ``"upper"``, ``"mid"``,
        ``"lower"`` and ``"bandwidth"``. Bandwidth is
        ``(upper - lower) / mid * 100`` and is NaN where ``mid`` is 0.
    """
    mid = calc_ma(close, period)
    # min_periods=1 matches calc_ma. With ddof=1 the first row's std is NaN.
    std = close.rolling(window=period, min_periods=1).std(ddof=ddof)
    upper = mid + num_std * std
    lower = mid - num_std * std
    # Guard the divisor so a zero middle band gives NaN instead of +/-inf.
    bandwidth = ((upper - lower) / mid.replace(0, np.nan)) * 100
    return {"upper": upper, "mid": mid, "lower": lower, "bandwidth": bandwidth}
|
||
|
||
|
||
def calc_volume_ma(volume: pd.Series, period: int = 5) -> pd.Series:
    """Simple moving average of volume over ``period`` rows (default 5).

    Delegates to ``calc_ma`` so price and volume averages share one
    implementation (same ``min_periods=1`` warm-up behaviour).
    """
    return calc_ma(volume, period)
|
||
|
||
|
||
def add_all_indicators(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with all technical-indicator columns added.

    Required input columns: ``close`` and ``vol`` (volume). Other columns are
    carried through untouched but are not read.

    Added columns:
        ma5, ma10, ma20, ma60            simple moving averages of close
        dif, dea, macd_hist              MACD with default (12, 26, 9)
        rsi14                            14-period RSI
        boll_upper, boll_mid,
        boll_lower, boll_bw              Bollinger(20, 2.0) and bandwidth
        vol_ma5, vol_ma10, vol_ma20      moving averages of vol
        pct_chg                          close % change, only if absent

    Frames that are empty or have fewer than 5 rows come back as a plain copy
    with no indicator columns. The input is never mutated, and the result is
    never the same object as ``df``.
    """
    if df.empty or len(df) < 5:
        # Copy here too, so the caller gets the same aliasing guarantee as
        # on the normal path.
        return df.copy()

    df = df.copy()
    close = df["close"]
    vol = df["vol"]

    # Price moving averages.
    for p in (5, 10, 20, 60):
        df[f"ma{p}"] = calc_ma(close, p)

    # MACD
    macd = calc_macd(close)
    df["dif"] = macd["dif"]
    df["dea"] = macd["dea"]
    df["macd_hist"] = macd["macd"]

    # RSI
    df["rsi14"] = calc_rsi(close, 14)

    # BOLL
    boll = calc_boll(close, 20)
    df["boll_upper"] = boll["upper"]
    df["boll_mid"] = boll["mid"]
    df["boll_lower"] = boll["lower"]
    df["boll_bw"] = boll["bandwidth"]

    # Volume moving averages.
    for p in (5, 10, 20):
        df[f"vol_ma{p}"] = calc_volume_ma(vol, p)

    # Percent change, only when the source did not already provide one.
    if "pct_chg" not in df.columns:
        df["pct_chg"] = close.pct_change() * 100

    return df
|