297 lines
9.3 KiB
Python
297 lines
9.3 KiB
Python
"""
Technical indicator calculations, implemented with plain pandas/numpy.

Each indicator function returns a pandas Series or a tuple of Series aligned
to the input index, so results can be assigned directly as columns of the
source DataFrame. ``compute_all_indicators`` does exactly that and returns a
copy of the input DataFrame with the indicator columns added.
"""
|
||
import numpy as np
|
||
import pandas as pd
|
||
|
||
|
||
# ============================================================
|
||
# 辅助函数
|
||
# ============================================================
|
||
|
||
def _ema(series: pd.Series, period: int) -> pd.Series:
    """Exponential moving average with ``alpha = 2 / (period + 1)``.

    Uses the recursive form (``adjust=False``), so the first output equals the
    first input and each later value blends the previous EMA with the new bar.
    """
    smoother = series.ewm(span=period, adjust=False)
    return smoother.mean()
|
||
|
||
|
||
def _sma(series: pd.Series, period: int) -> pd.Series:
    """Simple moving average over a trailing window of ``period`` bars.

    Values are NaN until a full window of observations is available.
    """
    trailing = series.rolling(window=period)
    return trailing.mean()
|
||
|
||
|
||
def _wma(series: pd.Series, period: int) -> pd.Series:
    """Linearly weighted moving average over a trailing window of ``period`` bars.

    Within each window the oldest bar gets weight 1 and the newest gets weight
    ``period``; the weighted sum is normalised by the total weight.
    """
    linear_weights = np.arange(1, period + 1, dtype=float)
    weight_total = linear_weights.sum()

    def _weighted_mean(window: np.ndarray) -> float:
        # raw=True passes the window oldest -> newest as a plain ndarray.
        return np.dot(window, linear_weights) / weight_total

    return series.rolling(window=period).apply(_weighted_mean, raw=True)
|
||
|
||
|
||
def _true_range(high: pd.Series, low: pd.Series, close: pd.Series) -> pd.Series:
    """Per-bar true range: the largest of high-low, |high-prev_close|, |low-prev_close|."""
    prior_close = close.shift(1)
    candidates = pd.concat(
        [high - low, (high - prior_close).abs(), (low - prior_close).abs()],
        axis=1,
    )
    # max() skips the NaNs produced by shift(1), so the first bar's true range
    # falls back to plain high - low.
    return candidates.max(axis=1)
|
||
|
||
|
||
def _atr(high, low, close, period: int) -> pd.Series:
    """Average true range: the true range smoothed with this module's EMA (span=period).

    NOTE: the classic Wilder ATR uses alpha = 1/period; this helper uses the
    EMA's alpha = 2/(period + 1), so values react faster than Wilder's version.
    """
    return _ema(_true_range(high, low, close), period)
|
||
|
||
|
||
# ============================================================
|
||
# 布林带 (Bollinger Bands)
|
||
# ============================================================
|
||
|
||
def bollinger_bands(close: pd.Series, period: int = 20, std_dev: float = 2.0):
    """Bollinger Bands and %B.

    Args:
        close: Close prices.
        period: Rolling window length for the middle band and the deviation.
        std_dev: Band half-width, in standard deviations.

    Returns:
        ``(mid, upper, lower, pct_b)`` Series aligned to ``close``. ``mid`` is the
        rolling mean and the deviation is pandas' rolling ``std`` (sample
        standard deviation, ddof=1). ``pct_b`` is the position of ``close``
        inside the bands (0 = lower band, 1 = upper band). All outputs are NaN
        until ``period`` bars are available.
    """
    window = close.rolling(window=period)
    mid = window.mean()
    std = window.std()
    upper = mid + std_dev * std
    lower = mid - std_dev * std

    # A perfectly flat window has zero band width, which would make %B 0/0
    # (NaN), or +/-inf from rounding noise in the mean. Guard it the same way
    # stochastic / williams_r / cci in this module guard their denominators.
    width = (upper - lower).replace(0, 1e-10)
    pct_b = (close - lower) / width
    return mid, upper, lower, pct_b
|
||
|
||
|
||
# ============================================================
|
||
# 肯特纳通道 (Keltner Channel)
|
||
# ============================================================
|
||
|
||
def keltner_channel(high, low, close, period: int = 20, atr_mult: float = 1.5):
    """Keltner Channel and the close's position inside it.

    Args:
        high, low, close: Price Series sharing one index.
        period: Span of the EMA middle line and of the ATR.
        atr_mult: Channel half-width in ATR units.

    Returns:
        ``(mid, upper, lower, pct)`` Series aligned to ``close``. ``mid`` is the
        EMA of ``close``, the bands are ``mid +/- atr_mult * ATR``, and ``pct``
        is the close's position inside the channel (0 = lower, 1 = upper).
    """
    mid = _ema(close, period)
    atr = _atr(high, low, close, period)
    upper = mid + atr_mult * atr
    lower = mid - atr_mult * atr

    # ATR is 0 on perfectly flat stretches (e.g. a first bar with high == low),
    # which would make pct 0/0 = NaN; guard like the module's oscillators do.
    width = (upper - lower).replace(0, 1e-10)
    pct = (close - lower) / width
    return mid, upper, lower, pct
|
||
|
||
|
||
# ============================================================
|
||
# 唐奇安通道 (Donchian Channel)
|
||
# ============================================================
|
||
|
||
def donchian_channel(high, low, period: int = 20):
    """Donchian Channel and the bar midpoint's position inside it.

    Args:
        high, low: Price Series sharing one index.
        period: Lookback window for the channel extremes.

    Returns:
        ``(mid, upper, lower, pct)`` Series aligned to ``high``: ``upper`` is the
        rolling max of ``high``, ``lower`` the rolling min of ``low``, ``mid``
        their average, and ``pct = ((high + low) / 2 - lower) / (upper - lower)``.
        The bar midpoint stands in for the close because no close is passed in.
    """
    upper = high.rolling(window=period).max()
    lower = low.rolling(window=period).min()
    mid = (upper + lower) / 2

    # Vectorized bar midpoint (avoids a per-element Python callback).
    bar_mid = (high + low) / 2
    # A flat window gives zero width -> 0/0; guard like the module's oscillators.
    width = (upper - lower).replace(0, 1e-10)
    pct = (bar_mid - lower) / width
    return mid, upper, lower, pct
|
||
|
||
|
||
# ============================================================
|
||
# EMA 交叉
|
||
# ============================================================
|
||
|
||
def ema_cross(close, fast_period: int = 9, slow_period: int = 21):
    """Fast/slow EMA pair and their spread.

    Returns:
        ``(fast, slow, diff)`` where ``diff = fast - slow``; a sign change in
        ``diff`` marks a crossover of the two averages.
    """
    fast_line, slow_line = (_ema(close, span) for span in (fast_period, slow_period))
    spread = fast_line - slow_line
    return fast_line, slow_line, spread
|
||
|
||
|
||
# ============================================================
|
||
# MACD
|
||
# ============================================================
|
||
|
||
def macd(close, fast: int = 12, slow: int = 26, signal: int = 9):
    """MACD line, signal line and histogram.

    Returns:
        ``(macd_line, signal_line, histogram)``: ``macd_line`` is
        EMA(fast) - EMA(slow) of ``close``, ``signal_line`` is the EMA(signal)
        of ``macd_line``, and ``histogram`` is their difference.
    """
    macd_line = _ema(close, fast) - _ema(close, slow)
    signal_line = _ema(macd_line, signal)
    return macd_line, signal_line, macd_line - signal_line
|
||
|
||
|
||
# ============================================================
|
||
# ADX (Average Directional Index)
|
||
# ============================================================
|
||
|
||
def adx(high, low, close, period: int = 14):
    """Average Directional Index with the +DI / -DI lines.

    Directional movement and ATR are smoothed with this module's EMA helper
    (span = ``period``), not with Wilder's 1/period smoothing.

    Args:
        high, low, close: Price Series sharing one index.
        period: Smoothing span for DM, ATR and DX.

    Returns:
        ``(adx_val, plus_di, minus_di)`` Series aligned to ``high``.
    """
    up_move = high - high.shift(1)
    down_move = low.shift(1) - low

    # +DM counts only when the up-move dominates (and is positive); -DM mirrors
    # it. On the first bar both moves are NaN, the comparison is False, so DM = 0.
    plus_dm = pd.Series(
        np.where(up_move > down_move, np.maximum(up_move, 0), 0), index=high.index
    )
    minus_dm = pd.Series(
        np.where(down_move > up_move, np.maximum(down_move, 0), 0), index=high.index
    )

    # ATR is 0 on perfectly flat stretches (e.g. a first bar with high == low);
    # without this guard +DI/-DI would be 0/0 = NaN for those bars.
    atr = _atr(high, low, close, period).replace(0, 1e-10)
    plus_di = 100 * _ema(plus_dm, period) / atr
    minus_di = 100 * _ema(minus_dm, period) / atr

    dx = 100 * (plus_di - minus_di).abs() / (plus_di + minus_di).replace(0, 1e-10)
    adx_val = _ema(dx, period)
    return adx_val, plus_di, minus_di
|
||
|
||
|
||
# ============================================================
|
||
# Supertrend
|
||
# ============================================================
|
||
|
||
def supertrend(high, low, close, period: int = 10, multiplier: float = 3.0):
    """Supertrend line and trend direction.

    The basic bands are ``(high + low) / 2 +/- multiplier * ATR``. While the
    trend is up (direction 1), the lower band may only rise on consecutive
    up-bars, and it is the Supertrend value. While the trend is down
    (direction -1), the upper band may only fall, and it is the value. The
    trend flips to 1 when close exceeds the previous bar's upper band, and to
    -1 when close falls below the previous bar's lower band.

    Returns:
        ``(st, direction)`` Series aligned to ``close``. ``direction`` is int8
        with values 1 / -1 and starts at 1. Empty input yields empty Series.
    """
    hl2 = (high + low) / 2
    atr = _atr(high, low, close, period)

    # Mutable copies: the loop ratchets these bands in place, and later bars
    # read the ratcheted values through ub[i - 1] / lb[i - 1].
    ub = (hl2 + multiplier * atr).values.copy()
    lb = (hl2 - multiplier * atr).values.copy()
    c = close.values
    n = len(c)

    direction = np.ones(n, dtype=np.int8)
    st = np.empty(n)
    if n > 0:  # indexing st[0] / lb[0] on empty input would raise IndexError
        st[0] = lb[0]

    for i in range(1, n):
        # The upper-band test runs first, so it wins if both conditions hold.
        if c[i] > ub[i - 1]:
            direction[i] = 1
        elif c[i] < lb[i - 1]:
            direction[i] = -1
        else:
            direction[i] = direction[i - 1]

        if direction[i] == 1:
            if direction[i - 1] == 1:
                lb[i] = max(lb[i], lb[i - 1])
            st[i] = lb[i]
        else:
            if direction[i - 1] == -1:
                ub[i] = min(ub[i], ub[i - 1])
            st[i] = ub[i]

    return (pd.Series(st, index=close.index),
            pd.Series(direction, index=close.index))
|
||
|
||
|
||
# ============================================================
|
||
# RSI
|
||
# ============================================================
|
||
|
||
def rsi(close, period: int = 14):
    """Relative Strength Index on a 0-100 scale.

    Gains and losses are smoothed with this module's EMA helper (span=period),
    not Wilder's smoothing. A zero average loss is replaced by 1e-10 to avoid
    division by zero.
    NOTE(review): when both averages are zero (a fully flat stretch) this
    evaluates to 0 rather than the conventional 50 — confirm that is intended.
    """
    change = close.diff()
    up_moves = change.clip(lower=0)
    down_moves = (-change).clip(lower=0)
    smoothed_up = _ema(up_moves, period)
    smoothed_down = _ema(down_moves, period).replace(0, 1e-10)
    relative_strength = smoothed_up / smoothed_down
    return 100 - 100 / (1 + relative_strength)
|
||
|
||
|
||
# ============================================================
|
||
# Stochastic Oscillator
|
||
# ============================================================
|
||
|
||
def stochastic(high, low, close, k_period: int = 14, d_period: int = 3, smooth: int = 3):
    """Slow stochastic oscillator (%K, %D) on a 0-100 scale.

    Raw %K is the close's position in the ``k_period`` high/low range. It is
    smoothed by an SMA of length ``smooth`` to give %K, and %D is an SMA of %K
    of length ``d_period``. A zero-width range is replaced by 1e-10, so flat
    windows yield 0.
    """
    window_low = low.rolling(window=k_period).min()
    window_high = high.rolling(window=k_period).max()
    span = (window_high - window_low).replace(0, 1e-10)
    fast_k = 100 * (close - window_low) / span
    slow_k = _sma(fast_k, smooth)
    return slow_k, _sma(slow_k, d_period)
|
||
|
||
|
||
# ============================================================
|
||
# CCI (Commodity Channel Index)
|
||
# ============================================================
|
||
|
||
def cci(high, low, close, period: int = 20):
    """Commodity Channel Index.

    The typical price ``(high + low + close) / 3`` is compared to its rolling
    mean and scaled by 0.015 times its rolling mean absolute deviation. A zero
    deviation is replaced by 1e-10 to avoid division by zero.
    """
    typical = (high + low + close) / 3
    window = typical.rolling(window=period)
    mean_tp = window.mean()

    def _mean_abs_dev(values: np.ndarray) -> float:
        return np.abs(values - values.mean()).mean()

    mean_dev = window.apply(_mean_abs_dev, raw=True)
    return (typical - mean_tp) / (0.015 * mean_dev.replace(0, 1e-10))
|
||
|
||
|
||
# ============================================================
|
||
# Williams %R
|
||
# ============================================================
|
||
|
||
def williams_r(high, low, close, period: int = 14):
    """Williams %R on a -100 (at the range low) to 0 (at the range high) scale.

    A zero-width ``period`` high/low range is replaced by 1e-10 to avoid
    division by zero.
    """
    top = high.rolling(window=period).max()
    bottom = low.rolling(window=period).min()
    span = (top - bottom).replace(0, 1e-10)
    return -100 * (top - close) / span
|
||
|
||
|
||
# ============================================================
|
||
# WMA (替代 VWMA,因为没有 volume)
|
||
# ============================================================
|
||
|
||
def wma(close, period: int = 20):
    """Linearly weighted moving average of ``close``.

    Used in place of a volume-weighted MA (VWMA) because there is no volume data.
    """
    return _wma(series=close, period=period)
|
||
|
||
|
||
# ============================================================
|
||
# 统一计算所有指标并拼接到 DataFrame
|
||
# ============================================================
|
||
|
||
def compute_all_indicators(df: pd.DataFrame, params: dict) -> pd.DataFrame:
    """Compute every indicator and return a copy of ``df`` with indicator columns added.

    Args:
        df: Price data; must contain ``'high'``, ``'low'`` and ``'close'`` columns.
        params: Indicator settings. Any missing key falls back to the default
            below, which matches the individual functions' own defaults:

            {
                'bb_period': 20, 'bb_std': 2.0,
                'kc_period': 20, 'kc_mult': 1.5,
                'dc_period': 20,
                'ema_fast': 9, 'ema_slow': 21,
                'macd_fast': 12, 'macd_slow': 26, 'macd_signal': 9,
                'adx_period': 14,
                'st_period': 10, 'st_mult': 3.0,
                'rsi_period': 14,
                'stoch_k': 14, 'stoch_d': 3, 'stoch_smooth': 3,
                'cci_period': 20,
                'wr_period': 14,
                'wma_period': 20,
            }

    Returns:
        A copy of ``df`` with the columns bb_pct, kc_pct, dc_pct, ema_diff,
        macd_hist, adx, di_diff, st_dir, rsi, stoch_k, stoch_d, cci, wr, wma
        and wma_diff added. The input frame is not modified.
    """
    defaults = {
        'bb_period': 20, 'bb_std': 2.0,
        'kc_period': 20, 'kc_mult': 1.5,
        'dc_period': 20,
        'ema_fast': 9, 'ema_slow': 21,
        'macd_fast': 12, 'macd_slow': 26, 'macd_signal': 9,
        'adx_period': 14,
        'st_period': 10, 'st_mult': 3.0,
        'rsi_period': 14,
        'stoch_k': 14, 'stoch_d': 3, 'stoch_smooth': 3,
        'cci_period': 20,
        'wr_period': 14,
        'wma_period': 20,
    }
    # Caller-supplied values win; missing keys fall back to the defaults.
    p = {**defaults, **(params or {})}

    out = df.copy()
    high, low, close = out['high'], out['low'], out['close']

    # Bollinger Bands
    _, _, _, out['bb_pct'] = bollinger_bands(close, p['bb_period'], p['bb_std'])

    # Keltner Channel
    _, _, _, out['kc_pct'] = keltner_channel(high, low, close, p['kc_period'], p['kc_mult'])

    # Donchian Channel
    _, _, _, out['dc_pct'] = donchian_channel(high, low, p['dc_period'])

    # EMA crossover
    _, _, out['ema_diff'] = ema_cross(close, p['ema_fast'], p['ema_slow'])

    # MACD
    _, _, out['macd_hist'] = macd(close, p['macd_fast'], p['macd_slow'], p['macd_signal'])

    # ADX
    adx_val, plus_di, minus_di = adx(high, low, close, p['adx_period'])
    out['adx'] = adx_val
    out['di_diff'] = plus_di - minus_di

    # Supertrend
    _, out['st_dir'] = supertrend(high, low, close, p['st_period'], p['st_mult'])

    # RSI
    out['rsi'] = rsi(close, p['rsi_period'])

    # Stochastic
    out['stoch_k'], out['stoch_d'] = stochastic(
        high, low, close, p['stoch_k'], p['stoch_d'], p['stoch_smooth']
    )

    # CCI
    out['cci'] = cci(high, low, close, p['cci_period'])

    # Williams %R
    out['wr'] = williams_r(high, low, close, p['wr_period'])

    # WMA and the close's distance from it
    out['wma'] = wma(close, p['wma_period'])
    out['wma_diff'] = close - out['wma']

    return out
|