astock-agent/backend/app/analysis/technical.py
2026-04-10 23:38:37 +08:00

93 lines
2.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""技术指标计算
基于 pandas 自研所有技术指标,不依赖第三方指标库。
支持MA, EMA, MACD, RSI, BOLL, 量价关系。
"""
import numpy as np
import pandas as pd
def calc_ma(series: pd.Series, period: int) -> pd.Series:
    """Return the simple moving average of ``series`` over ``period`` rows.

    ``min_periods=1`` makes the leading rows average over however many
    values are available so far, instead of being NaN.
    """
    trailing = series.rolling(period, min_periods=1)
    return trailing.mean()
def calc_ema(series: pd.Series, period: int) -> pd.Series:
    """Return the exponential moving average of ``series`` with span ``period``.

    ``adjust=False`` selects the recursive form
    ``ema[t] = alpha * x[t] + (1 - alpha) * ema[t-1]`` with
    ``alpha = 2 / (period + 1)``, starting from the first value.
    """
    smoother = series.ewm(span=period, adjust=False)
    return smoother.mean()
def calc_macd(
    close: pd.Series, fast: int = 12, slow: int = 26, signal: int = 9
) -> dict:
    """Compute MACD lines from a closing-price series.

    Returns a dict of Series:
      - "dif":  EMA(close, fast) - EMA(close, slow)
      - "dea":  EMA(dif, signal), the signal line
      - "macd": histogram, 2 * (dif - dea)
    """
    dif = calc_ema(close, fast) - calc_ema(close, slow)
    dea = calc_ema(dif, signal)
    return {
        "dif": dif,
        "dea": dea,
        "macd": 2 * (dif - dea),
    }
def calc_rsi(close: pd.Series, period: int = 14) -> pd.Series:
    """Relative Strength Index of ``close`` over ``period`` rows.

    Computed as ``100 * avg_gain / (avg_gain + avg_loss)``. This is
    algebraically equal to ``100 - 100 / (1 + avg_gain / avg_loss)`` but stays
    defined when the window contains no losses: such rows are 100 rather than
    NaN. Rows where both averages are zero (no price movement in the window,
    e.g. the first row, whose diff is NaN) have no defined RSI and are NaN.

    Gains and losses are averaged with simple rolling means
    (``min_periods=1``), not Wilder's smoothing.
    """
    delta = close.diff()
    # NaN deltas (first row, missing closes) compare False, so they count as
    # 0 for both gains and losses.
    gain = delta.where(delta > 0, 0.0)
    loss = (-delta).where(delta < 0, 0.0)
    avg_gain = gain.rolling(window=period, min_periods=1).mean()
    avg_loss = loss.rolling(window=period, min_periods=1).mean()
    # Only a window with no movement at all is undefined.
    total = (avg_gain + avg_loss).replace(0, np.nan)
    return 100 * avg_gain / total
def calc_boll(
    close: pd.Series, period: int = 20, num_std: float = 2.0, ddof: int = 1
) -> dict:
    """Bollinger Bands of ``close``.

    Args:
        close: price series.
        period: rolling window length (partial windows allowed at the start).
        num_std: band half-width, in rolling standard deviations.
        ddof: delta degrees of freedom for the rolling standard deviation.
            The default 1 (sample std) preserves the original behaviour; pass
            0 for population std.

    Returns:
        dict with "upper", "mid", "lower" and "bandwidth" Series.
        "bandwidth" is ``(upper - lower) / mid * 100`` and is NaN where mid
        is 0. With ``ddof=1``, the first row's std (a single observation) is
        NaN, so upper/lower/bandwidth are NaN there.
    """
    # Mean and std come from one window definition so the centre line and the
    # band width always describe the same rows; the mean equals
    # calc_ma(close, period).
    window = close.rolling(window=period, min_periods=1)
    mid = window.mean()
    std = window.std(ddof=ddof)
    upper = mid + num_std * std
    lower = mid - num_std * std
    # Avoid +/-inf when the centre line is exactly zero.
    bandwidth = (upper - lower) / mid.replace(0, np.nan) * 100
    return {"upper": upper, "mid": mid, "lower": lower, "bandwidth": bandwidth}
def calc_volume_ma(volume: pd.Series, period: int = 5) -> pd.Series:
    """Return the simple moving average of traded volume over ``period`` rows.

    Delegates to ``calc_ma`` (whose body this previously duplicated) so
    volume and price averages always share the same windowing rule.
    """
    return calc_ma(volume, period)
def add_all_indicators(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with technical-indicator columns appended.

    The input is expected to contain close, high, low, open and vol (volume);
    only ``close`` and ``vol`` are read here.

    Added columns:
      ma5, ma10, ma20, ma60                    simple moving averages of close
      dif, dea, macd_hist                      MACD (12, 26, 9)
      rsi14                                    14-period RSI
      boll_upper, boll_mid, boll_lower, boll_bw  Bollinger Bands (20, 2)
      vol_ma5, vol_ma10, vol_ma20              moving averages of vol
      pct_chg                                  percent change of close; only
                                               added if the column is absent

    Frames with fewer than 5 rows (including empty ones) are returned as an
    unmodified copy without indicator columns. The input is never mutated.

    NOTE(review): the rolling/EWM calculations run top to bottom, so rows are
    presumably in ascending date order; confirm that callers sort before
    calling (some data sources return newest rows first).
    """
    if df.empty or len(df) < 5:
        # Return a copy so callers never get their own object back; the
        # normal path below also returns a new frame.
        return df.copy()

    df = df.copy()
    close = df["close"]
    vol = df["vol"]

    # Price moving averages
    for p in (5, 10, 20, 60):
        df[f"ma{p}"] = calc_ma(close, p)

    # MACD
    macd = calc_macd(close)
    df["dif"] = macd["dif"]
    df["dea"] = macd["dea"]
    df["macd_hist"] = macd["macd"]

    # RSI
    df["rsi14"] = calc_rsi(close, 14)

    # Bollinger Bands
    boll = calc_boll(close, 20)
    df["boll_upper"] = boll["upper"]
    df["boll_mid"] = boll["mid"]
    df["boll_lower"] = boll["lower"]
    df["boll_bw"] = boll["bandwidth"]

    # Volume moving averages
    for p in (5, 10, 20):
        df[f"vol_ma{p}"] = calc_volume_ma(vol, p)

    # Percent change, only when the source data did not supply it.
    if "pct_chg" not in df.columns:
        # Same result as close.pct_change() with its default forward-fill of
        # missing closes. It is written out explicitly because that implicit
        # fill is deprecated in pandas >= 2.1.
        filled = close.ffill()
        df["pct_chg"] = (filled / filled.shift(1) - 1) * 100

    return df