Files
jyx_code4/strategy/indicators.py

279 lines
9.3 KiB
Python
Raw Normal View History

2026-02-20 20:57:25 +08:00
"""
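Technical-indicator computation engine built on the `ta` library (described as "52 indicators";
the exact column count depends on INDICATOR_PARAMS and on whether volume data is available).
Covers five families: trend, momentum, volatility, volume, and custom derived features.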
"""
import warnings

import numpy as np
import pandas as pd
import ta

from .config import INDICATOR_PARAMS as P
def compute_all_indicators(df: pd.DataFrame, has_volume: bool = False) -> pd.DataFrame:
    """
    Compute the full technical-indicator feature set on a copy of ``df``.

    :param df: must contain ``open``, ``high``, ``low`` and ``close`` columns;
        a ``volume`` column is optional.
    :param has_volume: whether to compute volume-based features (MFI plus the
        volume family). Only honoured when ``df`` actually has a ``volume``
        column; otherwise a ``RuntimeWarning`` is emitted and those features
        are skipped.
    :return: a new DataFrame with the original columns plus one column per
        indicator. ``df`` itself is not modified. The number of indicator
        columns depends on ``INDICATOR_PARAMS`` (e.g. the SMA/EMA window lists).
    """
    out = df.copy()
    o, h, l, c = out['open'], out['high'], out['low'], out['close']

    v = None
    if has_volume:
        if 'volume' in out.columns:
            v = out['volume']
        else:
            # get_indicator_names(has_volume=True) still lists the volume
            # columns, so surface the skip instead of failing later with a
            # confusing KeyError downstream.
            warnings.warn(
                "has_volume=True but the DataFrame has no 'volume' column; "
                "volume-based indicators will not be computed",
                RuntimeWarning,
                stacklevel=2,
            )

    # Trend family
    out = _add_trend(out, o, h, l, c)
    # Momentum family (MFI is added only when volume is available)
    out = _add_momentum(out, h, l, c, v)
    # Volatility family
    out = _add_volatility(out, h, l, c)
    # Volume family (only when volume is available)
    if v is not None:
        out = _add_volume(out, h, l, c, v)
    # Custom derived features
    out = _add_custom(out, o, h, l, c)
    return out
def _add_trend(out, o, h, l, c):
    """Add trend-family indicator columns to ``out`` and return it.

    Columns are written in this order:
    ``sma_<w>`` for each ``w`` in ``P['sma_windows']``, ``ema_<w>`` for each
    ``w`` in ``P['ema_windows']``, ``macd``/``macd_signal``/``macd_hist``,
    ``adx``/``di_plus``/``di_minus``, ``ichimoku_conv``/``ichimoku_base``/
    ``ichimoku_a``/``ichimoku_b``, ``trix``, ``aroon_up``/``aroon_down``,
    ``cci``, ``dpo``, ``kst``, ``vortex_pos``/``vortex_neg``.

    :param out: DataFrame receiving the new columns (mutated in place).
    :param o: open prices; accepted for a uniform helper signature, unused here.
    :param h: high prices.
    :param l: low prices.
    :param c: close prices.
    :return: ``out``.
    """
    new_cols = {}

    # Moving averages over every configured window.
    for span in P['sma_windows']:
        new_cols[f'sma_{span}'] = ta.trend.sma_indicator(c, window=span)
    for span in P['ema_windows']:
        new_cols[f'ema_{span}'] = ta.trend.ema_indicator(c, window=span)

    macd_ind = ta.trend.MACD(
        c,
        window_slow=P['macd_slow'],
        window_fast=P['macd_fast'],
        window_sign=P['macd_signal'],
    )
    new_cols['macd'] = macd_ind.macd()
    new_cols['macd_signal'] = macd_ind.macd_signal()
    new_cols['macd_hist'] = macd_ind.macd_diff()

    adx_ind = ta.trend.ADXIndicator(h, l, c, window=P['adx_window'])
    new_cols['adx'] = adx_ind.adx()
    new_cols['di_plus'] = adx_ind.adx_pos()
    new_cols['di_minus'] = adx_ind.adx_neg()

    ichimoku_ind = ta.trend.IchimokuIndicator(
        h,
        l,
        window1=P['ichimoku_conversion'],
        window2=P['ichimoku_base'],
        window3=P['ichimoku_span_b'],
    )
    new_cols['ichimoku_conv'] = ichimoku_ind.ichimoku_conversion_line()
    new_cols['ichimoku_base'] = ichimoku_ind.ichimoku_base_line()
    new_cols['ichimoku_a'] = ichimoku_ind.ichimoku_a()
    new_cols['ichimoku_b'] = ichimoku_ind.ichimoku_b()

    new_cols['trix'] = ta.trend.trix(c, window=P['trix_window'])

    aroon_ind = ta.trend.AroonIndicator(h, l, window=P['aroon_window'])
    new_cols['aroon_up'] = aroon_ind.aroon_up()
    new_cols['aroon_down'] = aroon_ind.aroon_down()

    new_cols['cci'] = ta.trend.cci(h, l, c, window=P['cci_window'])
    new_cols['dpo'] = ta.trend.dpo(c, window=P['dpo_window'])

    kst_ind = ta.trend.KSTIndicator(
        c,
        roc1=P['kst_roc1'],
        roc2=P['kst_roc2'],
        roc3=P['kst_roc3'],
        roc4=P['kst_roc4'],
    )
    new_cols['kst'] = kst_ind.kst()

    vortex_ind = ta.trend.VortexIndicator(h, l, c, window=P['vortex_window'])
    new_cols['vortex_pos'] = vortex_ind.vortex_indicator_pos()
    new_cols['vortex_neg'] = vortex_ind.vortex_indicator_neg()

    # Dict insertion order preserves the original column order.
    for name, series in new_cols.items():
        out[name] = series
    return out
def _add_momentum(out, h, l, c, v):
    """Add momentum-family indicator columns to ``out`` and return it.

    Columns are written in this order: ``rsi``, ``stoch_k``, ``stoch_d``,
    ``williams_r``, ``roc``, ``mfi`` (only when ``v`` is not None), ``tsi``,
    ``uo``, ``ao``, ``kama``, ``ppo``, ``stoch_rsi_k``, ``stoch_rsi_d``.

    :param out: DataFrame receiving the new columns (mutated in place).
    :param h: high prices.
    :param l: low prices.
    :param c: close prices.
    :param v: volume series, or None to skip the Money Flow Index.
    :return: ``out``.
    """
    new_cols = {'rsi': ta.momentum.rsi(c, window=P['rsi_window'])}

    stochastic = ta.momentum.StochasticOscillator(
        h, l, c,
        window=P['stoch_window'],
        smooth_window=P['stoch_smooth'],
    )
    new_cols['stoch_k'] = stochastic.stoch()
    new_cols['stoch_d'] = stochastic.stoch_signal()

    new_cols['williams_r'] = ta.momentum.williams_r(h, l, c, lbp=P['williams_window'])
    new_cols['roc'] = ta.momentum.roc(c, window=P['roc_window'])

    # MFI is the only momentum feature that depends on volume.
    if v is not None:
        new_cols['mfi'] = ta.volume.money_flow_index(h, l, c, v, window=P['mfi_window'])

    new_cols['tsi'] = ta.momentum.tsi(c, window_slow=P['tsi_slow'], window_fast=P['tsi_fast'])
    new_cols['uo'] = ta.momentum.ultimate_oscillator(
        h, l, c,
        window1=P['uo_short'],
        window2=P['uo_medium'],
        window3=P['uo_long'],
    )
    new_cols['ao'] = ta.momentum.awesome_oscillator(
        h, l,
        window1=P['ao_short'],
        window2=P['ao_long'],
    )
    new_cols['kama'] = ta.momentum.kama(c, window=P['kama_window'])
    new_cols['ppo'] = ta.momentum.ppo(c, window_slow=P['ppo_slow'], window_fast=P['ppo_fast'])

    # Both smoothing passes of Stochastic RSI share one configured length.
    smooth_len = P['stoch_rsi_smooth']
    stochastic_rsi = ta.momentum.StochRSIIndicator(
        c,
        window=P['stoch_rsi_window'],
        smooth1=smooth_len,
        smooth2=smooth_len,
    )
    new_cols['stoch_rsi_k'] = stochastic_rsi.stochrsi_k()
    new_cols['stoch_rsi_d'] = stochastic_rsi.stochrsi_d()

    for name, series in new_cols.items():
        out[name] = series
    return out
def _add_volatility(out, h, l, c):
    """Add volatility-family columns to ``out`` and return it.

    Adds 12 columns from four indicators, in this order:
    Bollinger Bands (``bb_upper``, ``bb_mid``, ``bb_lower``, ``bb_width``,
    ``bb_pband``), ``atr``, Keltner Channel (``kc_upper``, ``kc_mid``,
    ``kc_lower``), Donchian Channel (``dc_upper``, ``dc_mid``, ``dc_lower``).

    :param out: DataFrame receiving the new columns (mutated in place).
    :param h: high prices.
    :param l: low prices.
    :param c: close prices.
    :return: ``out``.
    """
    bollinger = ta.volatility.BollingerBands(c, window=P['bb_window'], window_dev=P['bb_std'])
    true_range_avg = ta.volatility.average_true_range(h, l, c, window=P['atr_window'])
    keltner = ta.volatility.KeltnerChannel(h, l, c, window=P['kc_window'])
    donchian = ta.volatility.DonchianChannel(h, l, c, window=P['dc_window'])

    new_cols = {
        'bb_upper': bollinger.bollinger_hband(),
        'bb_mid': bollinger.bollinger_mavg(),
        'bb_lower': bollinger.bollinger_lband(),
        'bb_width': bollinger.bollinger_wband(),
        'bb_pband': bollinger.bollinger_pband(),
        'atr': true_range_avg,
        'kc_upper': keltner.keltner_channel_hband(),
        'kc_mid': keltner.keltner_channel_mband(),
        'kc_lower': keltner.keltner_channel_lband(),
        'dc_upper': donchian.donchian_channel_hband(),
        'dc_mid': donchian.donchian_channel_mband(),
        'dc_lower': donchian.donchian_channel_lband(),
    }
    for name, series in new_cols.items():
        out[name] = series
    return out
def _add_volume(out, h, l, c, v):
    """Add the eight volume-family columns to ``out`` and return it.

    Columns are written in this order: ``obv`` (On-Balance Volume), ``vwap``,
    ``cmf`` (Chaikin Money Flow), ``adi`` (Accumulation/Distribution Index),
    ``emv`` (Ease of Movement), ``fi`` (Force Index), ``vpt`` (Volume Price
    Trend), ``nvi`` (Negative Volume Index).

    :param out: DataFrame receiving the new columns (mutated in place).
    :param h: high prices.
    :param l: low prices.
    :param c: close prices.
    :param v: volume series (must not be None).
    :return: ``out``.
    """
    vol = ta.volume
    new_cols = {
        'obv': vol.on_balance_volume(c, v),
        'vwap': vol.volume_weighted_average_price(h, l, c, v),
        'cmf': vol.chaikin_money_flow(h, l, c, v, window=P['cmf_window']),
        'adi': vol.acc_dist_index(h, l, c, v),
        'emv': vol.ease_of_movement(h, l, v, window=P['emv_window']),
        'fi': vol.force_index(c, v, window=P['fi_window']),
        'vpt': vol.volume_price_trend(c, v),
        'nvi': vol.negative_volume_index(c, v),
    }
    for name, series in new_cols.items():
        out[name] = series
    return out
def _add_custom(out, o, h, l, c):
    """Add ten hand-crafted price/candle features to ``out`` and return it.

    Columns are written in this order: ``price_change_pct``, ``high_low_range``,
    ``body_ratio``, ``upper_shadow``, ``lower_shadow``, ``volatility_ratio``,
    ``close_sma20_ratio``, ``close_ema12_ratio``, ``momentum_3``, ``momentum_5``.

    Every ratio maps a zero denominator to NaN. An infinite ``price_change_pct``
    (which arises after a zero close) is also mapped to NaN. For finite inputs,
    therefore, no feature contains ``inf``; ``inf`` values break many
    downstream ML estimators.

    The ATR (14), SMA (20) and EMA (12) windows are fixed here and are not
    read from ``INDICATOR_PARAMS``; the column names encode them.

    :param out: DataFrame receiving the new columns (mutated in place).
    :param o: open prices.
    :param h: high prices.
    :param l: low prices.
    :param c: close prices.
    :return: ``out``.
    """
    # Zero denominators become NaN so the ratios are undefined rather than inf.
    close_nz = c.replace(0, np.nan)
    hl_range = (h - l).replace(0, np.nan)  # zero-range bars (high == low)

    # Top/bottom of the candle body, regardless of candle direction.
    candle_top = pd.concat([o, c], axis=1).max(axis=1)
    candle_bottom = pd.concat([o, c], axis=1).min(axis=1)

    # Period-over-period percentage change; a zero previous close would give inf.
    out['price_change_pct'] = c.pct_change().replace([np.inf, -np.inf], np.nan)
    # Bar range (high - low) relative to close.
    out['high_low_range'] = (h - l) / close_nz
    # Body, upper-shadow and lower-shadow sizes as fractions of the bar range.
    out['body_ratio'] = (c - o).abs() / hl_range
    out['upper_shadow'] = (h - candle_top) / hl_range
    out['lower_shadow'] = (candle_bottom - l) / hl_range

    # ATR(14) normalised by close.
    atr = ta.volatility.average_true_range(h, l, c, window=14)
    out['volatility_ratio'] = atr / close_nz

    # Close relative to its SMA(20) and EMA(12).
    sma20 = ta.trend.sma_indicator(c, window=20)
    out['close_sma20_ratio'] = c / sma20.replace(0, np.nan)
    ema12 = ta.trend.ema_indicator(c, window=12)
    out['close_ema12_ratio'] = c / ema12.replace(0, np.nan)

    # Absolute price momentum over 3 and 5 periods.
    out['momentum_3'] = c - c.shift(3)
    out['momentum_5'] = c - c.shift(5)
    return out
def get_indicator_names(has_volume: bool = False) -> list:
    """Return the indicator column names added by the indicator helpers.

    The list is grouped as: trend, momentum (plus ``mfi`` when ``has_volume``),
    volatility, volume (only when ``has_volume``), and custom features.

    Note: ``mfi`` is listed after ``stoch_rsi_d`` here. In the computed
    DataFrame, ``_add_momentum`` inserts it right after ``roc``. Selecting
    ``df[get_indicator_names(...)]`` yields this list's order.

    If ``compute_all_indicators`` is called with ``has_volume=True`` on a frame
    that lacks a ``volume`` column, the volume columns listed here will be
    absent from its result.

    :param has_volume: include MFI and the volume-family column names.
    :return: a new list of column names.
    """
    trend_cols = (
        [f'sma_{win}' for win in P['sma_windows']]
        + [f'ema_{win}' for win in P['ema_windows']]
        + [
            'macd', 'macd_signal', 'macd_hist', 'adx', 'di_plus', 'di_minus',
            'ichimoku_conv', 'ichimoku_base', 'ichimoku_a', 'ichimoku_b',
            'trix', 'aroon_up', 'aroon_down', 'cci', 'dpo', 'kst',
            'vortex_pos', 'vortex_neg',
        ]
    )
    momentum_cols = [
        'rsi', 'stoch_k', 'stoch_d', 'williams_r', 'roc', 'tsi', 'uo', 'ao',
        'kama', 'ppo', 'stoch_rsi_k', 'stoch_rsi_d',
    ]
    volatility_cols = [
        'bb_upper', 'bb_mid', 'bb_lower', 'bb_width', 'bb_pband', 'atr',
        'kc_upper', 'kc_mid', 'kc_lower', 'dc_upper', 'dc_mid', 'dc_lower',
    ]
    volume_cols = ['obv', 'vwap', 'cmf', 'adi', 'emv', 'fi', 'vpt', 'nvi']
    custom_cols = [
        'price_change_pct', 'high_low_range', 'body_ratio', 'upper_shadow',
        'lower_shadow', 'volatility_ratio', 'close_sma20_ratio',
        'close_ema12_ratio', 'momentum_3', 'momentum_5',
    ]

    return (
        trend_cols
        + momentum_cols
        + (['mfi'] if has_volume else [])
        + volatility_cols
        + (volume_cols if has_volume else [])
        + custom_cols
    )