Files
codex_jxs_code/strategy/indicators.py
2026-02-23 04:09:34 +08:00

297 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
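Technical indicator module - every indicator is implemented in plain pandas/numpy.
The indicator functions return pandas Series (or tuples of Series) that can be attached directly as columns of the main DataFrame; compute_all_indicators does this for all of them.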
"""
import numpy as np
import pandas as pd
# ============================================================
# 辅助函数
# ============================================================
def _ema(series: pd.Series, period: int) -> pd.Series:
    """Exponential moving average with span ``period`` (recursive form, ``adjust=False``)."""
    smoother = series.ewm(span=period, adjust=False)
    return smoother.mean()
def _sma(series: pd.Series, period: int) -> pd.Series:
    """Simple moving average over a trailing window of ``period`` bars (NaN until the window is full)."""
    window = series.rolling(window=period)
    return window.mean()
def _wma(series: pd.Series, period: int) -> pd.Series:
    """Linearly weighted moving average over ``period`` bars.

    Within each window the oldest bar gets weight 1 and the newest bar gets
    weight ``period``. The output is NaN until ``period`` observations are available.
    """
    weights = np.arange(1, period + 1, dtype=float)
    # The normalizer is identical for every window, so compute it once instead
    # of re-summing the weight vector inside the per-window callback.
    weight_total = weights.sum()
    return series.rolling(window=period).apply(
        lambda window: np.dot(window, weights) / weight_total, raw=True
    )
def _true_range(high: pd.Series, low: pd.Series, close: pd.Series) -> pd.Series:
    """Per-bar true range: the largest of high-low, |high - prev close| and |low - prev close|.

    On the first bar the previous close is missing, so only high-low contributes
    (``DataFrame.max`` skips NaN by default).
    """
    prev = close.shift(1)
    spans = (high - low, (high - prev).abs(), (low - prev).abs())
    return pd.concat(spans, axis=1).max(axis=1)
def _atr(high, low, close, period: int) -> pd.Series:
    """Average true range: the true range smoothed by ``_ema`` with span ``period``.

    Note this is an EMA with alpha = 2 / (period + 1), not Wilder's 1 / period smoothing.
    """
    return _ema(_true_range(high, low, close), period)
# ============================================================
# 布林带 (Bollinger Bands)
# ============================================================
def bollinger_bands(close: pd.Series, period: int = 20, std_dev: float = 2.0):
    """Bollinger Bands.

    Returns ``(mid, upper, lower, pct_b)``:
      * ``mid``   - ``period``-bar simple moving average of ``close``;
      * ``upper`` / ``lower`` - ``mid`` +/- ``std_dev`` rolling standard deviations
        (pandas default, i.e. sample std with ddof=1);
      * ``pct_b`` - position of ``close`` inside the band (0 at ``lower``, 1 at
        ``upper``). Zero band width is not guarded against.
    """
    mid = _sma(close, period)
    offset = std_dev * close.rolling(window=period).std()
    upper, lower = mid + offset, mid - offset
    pct_b = (close - lower) / (upper - lower)
    return mid, upper, lower, pct_b
# ============================================================
# 肯特纳通道 (Keltner Channel)
# ============================================================
def keltner_channel(high, low, close, period: int = 20, atr_mult: float = 1.5):
    """Keltner Channel.

    Returns ``(mid, upper, lower, pct)`` where ``mid`` is the ``period``-bar EMA of
    ``close``, the bands are ``mid`` +/- ``atr_mult`` * ATR(``period``), and ``pct``
    is the position of ``close`` inside the channel (0 at ``lower``, 1 at ``upper``).
    """
    center = _ema(close, period)
    half_width = atr_mult * _atr(high, low, close, period)
    upper = center + half_width
    lower = center - half_width
    pct = (close - lower) / (upper - lower)
    return center, upper, lower, pct
# ============================================================
# 唐奇安通道 (Donchian Channel)
# ============================================================
def donchian_channel(high, low, period: int = 20):
    """Donchian Channel.

    Returns ``(mid, upper, lower, pct)``:
      * ``upper`` / ``lower`` - rolling ``period``-bar highest high / lowest low
        (the current bar is included in the window);
      * ``mid`` - average of ``upper`` and ``lower``;
      * ``pct`` - position of the current bar's midpoint ``(high + low) / 2``
        inside the channel. For well-formed bars (high >= low) it lies in [0, 1].
        It is NaN while the window is filling and when the channel has zero width.
    """
    upper = high.rolling(window=period).max()
    lower = low.rolling(window=period).min()
    mid = (upper + lower) / 2
    # Vectorized bar midpoint; avoids a per-element Python callback.
    bar_mid = (high + low) / 2
    pct = (bar_mid - lower) / (upper - lower)
    return mid, upper, lower, pct
# ============================================================
# EMA 交叉
# ============================================================
def ema_cross(close, fast_period: int = 9, slow_period: int = 21):
    """Fast/slow EMA pair for crossover signals.

    Returns ``(fast, slow, diff)`` with ``diff = fast - slow``; ``diff`` is positive
    while the fast EMA sits above the slow one.
    """
    fast_line, slow_line = (_ema(close, span) for span in (fast_period, slow_period))
    return fast_line, slow_line, fast_line - slow_line
# ============================================================
# MACD
# ============================================================
def macd(close, fast: int = 12, slow: int = 26, signal: int = 9):
    """MACD.

    Returns ``(macd_line, signal_line, histogram)`` where ``macd_line`` is
    EMA(``fast``) - EMA(``slow``) of ``close``, ``signal_line`` is its EMA over
    ``signal`` bars, and ``histogram`` is their difference.
    """
    line = _ema(close, fast) - _ema(close, slow)
    trigger = _ema(line, signal)
    return line, trigger, line - trigger
# ============================================================
# ADX (Average Directional Index)
# ============================================================
def adx(high, low, close, period: int = 14):
    """Average Directional Index together with the +DI / -DI lines.

    Returns ``(adx_val, plus_di, minus_di)``. Directional movement, the DIs and
    the ADX are all smoothed with ``_ema`` (span ``period``), and the DIs are
    normalized by ``_atr`` over the same period.
    """
    up_move = high - high.shift(1)
    down_move = low.shift(1) - low
    # +DM counts only when the up-move beats the down-move (and is positive);
    # -DM is the mirror image. Every other bar contributes 0.
    plus_dm = pd.Series(
        np.where(up_move > down_move, np.maximum(up_move, 0), 0), index=high.index
    )
    minus_dm = pd.Series(
        np.where(down_move > up_move, np.maximum(down_move, 0), 0), index=high.index
    )
    atr = _atr(high, low, close, period)
    plus_di, minus_di = (100 * _ema(dm, period) / atr for dm in (plus_dm, minus_dm))
    # Guard against 0/0 when both DIs are zero.
    di_sum = (plus_di + minus_di).replace(0, 1e-10)
    dx = 100 * (plus_di - minus_di).abs() / di_sum
    return _ema(dx, period), plus_di, minus_di
# ============================================================
# Supertrend
# ============================================================
def supertrend(high, low, close, period: int = 10, multiplier: float = 3.0):
    """Supertrend indicator.

    Builds bands ``hl2 +/- multiplier * ATR(period)`` and walks the bars once,
    flipping the trend when the close crosses the previous bar's band.

    Returns:
        (st, direction): ``st`` is the active trailing band (the lower band while
        the trend is up, the upper band while it is down); ``direction`` is 1
        for up and -1 for down (int8). Both are indexed like ``close``.

    NOTE(review): an empty input raises IndexError at ``st[0]``.
    """
    hl2 = (high + low) / 2
    atr = _atr(high, low, close, period)
    # Copies, because the loop below ratchets these arrays in place.
    ub = (hl2 + multiplier * atr).values.copy()
    lb = (hl2 - multiplier * atr).values.copy()
    c = close.values
    n = len(c)
    # The first bar starts in an uptrend, trailing the lower band.
    direction = np.ones(n, dtype=np.int8)
    st = np.empty(n)
    st[0] = lb[0]
    for i in range(1, n):
        # Breakouts are measured against the previous bar's (possibly ratcheted) bands.
        if c[i] > ub[i - 1]:
            direction[i] = 1
        elif c[i] < lb[i - 1]:
            direction[i] = -1
        else:
            direction[i] = direction[i - 1]
        if direction[i] == 1:
            # While the uptrend continues, the lower band may only move up.
            if direction[i - 1] == 1:
                lb[i] = max(lb[i], lb[i - 1])
            st[i] = lb[i]
        else:
            # While the downtrend continues, the upper band may only move down.
            if direction[i - 1] == -1:
                ub[i] = min(ub[i], ub[i - 1])
            st[i] = ub[i]
    return (pd.Series(st, index=close.index),
            pd.Series(direction, index=close.index))
# ============================================================
# RSI
# ============================================================
def rsi(close, period: int = 14):
    """Relative Strength Index on a 0-100 scale.

    Gains and losses are averaged with ``_ema`` (span ``period``), not Wilder's
    smoothing. A zero average loss is replaced by 1e-10 to avoid division by
    zero. This pushes the RSI towards 100, or to 0 when the average gain is also 0.
    """
    change = close.diff()
    up = change.clip(lower=0)
    down = (-change).clip(lower=0)
    relative_strength = _ema(up, period) / _ema(down, period).replace(0, 1e-10)
    return 100 - 100 / (1 + relative_strength)
# ============================================================
# Stochastic Oscillator
# ============================================================
def stochastic(high, low, close, k_period: int = 14, d_period: int = 3, smooth: int = 3):
    """Stochastic oscillator.

    Returns ``(k, d)``: raw %K over ``k_period`` bars smoothed by a ``smooth``-bar
    SMA, and %D as the ``d_period``-bar SMA of that %K. A zero high-low range is
    replaced by 1e-10, so flat windows give %K = 0.
    """
    window_low = low.rolling(window=k_period).min()
    window_range = high.rolling(window=k_period).max() - window_low
    fast_k = 100 * (close - window_low) / window_range.replace(0, 1e-10)
    slow_k = _sma(fast_k, smooth)
    return slow_k, _sma(slow_k, d_period)
# ============================================================
# CCI (Commodity Channel Index)
# ============================================================
def cci(high, low, close, period: int = 20):
    """Commodity Channel Index.

    Measures the deviation of the typical price ``(high + low + close) / 3`` from
    its ``period``-bar SMA, divided by 0.015 times the rolling mean absolute
    deviation. A zero mean absolute deviation is replaced by 1e-10.
    """
    typical = (high + low + close) / 3

    def _mean_abs_dev(values):
        return np.abs(values - values.mean()).mean()

    mean_dev = typical.rolling(window=period).apply(_mean_abs_dev, raw=True)
    deviation = typical - _sma(typical, period)
    return deviation / (0.015 * mean_dev.replace(0, 1e-10))
# ============================================================
# Williams %R
# ============================================================
def williams_r(high, low, close, period: int = 14):
    """Williams %R over ``period`` bars.

    For well-formed bars the value lies in [-100, 0]: -100 when ``close`` is at
    the window low and 0 when it is at the window high. A zero range is replaced
    by 1e-10.
    """
    hh = high.rolling(window=period).max()
    span = hh - low.rolling(window=period).min()
    return -100 * (hh - close) / span.replace(0, 1e-10)
# ============================================================
# WMA (used instead of VWMA because there is no volume data)
# ============================================================
def wma(close, period: int = 20):
    """Public linearly weighted moving average of ``close`` over ``period`` bars.

    Thin wrapper around ``_wma``; the strategy uses it in place of a VWMA
    because there is no volume data.
    """
    weighted = _wma(close, period)
    return weighted
# ============================================================
# 统一计算所有指标并拼接到 DataFrame
# ============================================================
def compute_all_indicators(df: pd.DataFrame, params: dict) -> pd.DataFrame:
    """
    Compute every indicator configured in ``params`` and return a copy of ``df``
    with the indicator columns appended. ``df`` itself is not modified.

    ``df`` must provide ``high``, ``low`` and ``close`` columns. ``params`` must
    contain every key below; a missing key raises ``KeyError``.

    Example ``params``::

        {
            'bb_period': 20, 'bb_std': 2.0,
            'kc_period': 20, 'kc_mult': 1.5,
            'dc_period': 20,
            'ema_fast': 9, 'ema_slow': 21,
            'macd_fast': 12, 'macd_slow': 26, 'macd_signal': 9,
            'adx_period': 14,
            'st_period': 10, 'st_mult': 3.0,
            'rsi_period': 14,
            'stoch_k': 14, 'stoch_d': 3, 'stoch_smooth': 3,
            'cci_period': 20,
            'wr_period': 14,
            'wma_period': 20,
        }

    Added columns, in order: bb_pct, kc_pct, dc_pct, ema_diff, macd_hist, adx,
    di_diff, st_dir, rsi, stoch_k, stoch_d, cci, wr, wma, wma_diff.
    """
    out = df.copy()
    high, low, close = out['high'], out['low'], out['close']

    # Channels: only the relative position inside each channel is kept.
    _, _, _, bb_pct = bollinger_bands(close, params['bb_period'], params['bb_std'])
    out['bb_pct'] = bb_pct
    _, _, _, kc_pct = keltner_channel(high, low, close, params['kc_period'], params['kc_mult'])
    out['kc_pct'] = kc_pct
    _, _, _, dc_pct = donchian_channel(high, low, params['dc_period'])
    out['dc_pct'] = dc_pct

    # Trend: moving-average spreads and directional strength.
    _, _, ema_diff = ema_cross(close, params['ema_fast'], params['ema_slow'])
    out['ema_diff'] = ema_diff
    _, _, macd_hist = macd(close, params['macd_fast'], params['macd_slow'], params['macd_signal'])
    out['macd_hist'] = macd_hist
    adx_val, plus_di, minus_di = adx(high, low, close, params['adx_period'])
    out['adx'] = adx_val
    out['di_diff'] = plus_di - minus_di
    # Only the Supertrend direction (+1 / -1) is used, not the band level.
    _, st_dir = supertrend(high, low, close, params['st_period'], params['st_mult'])
    out['st_dir'] = st_dir

    # Oscillators.
    out['rsi'] = rsi(close, params['rsi_period'])
    stoch_k, stoch_d = stochastic(
        high, low, close, params['stoch_k'], params['stoch_d'], params['stoch_smooth']
    )
    out['stoch_k'] = stoch_k
    out['stoch_d'] = stoch_d
    out['cci'] = cci(high, low, close, params['cci_period'])
    out['wr'] = williams_r(high, low, close, params['wr_period'])

    # Weighted moving average level and the close's distance from it.
    out['wma'] = wma(close, params['wma_period'])
    out['wma_diff'] = close - out['wma']
    return out