Files
jyx_code4/strategy/indicators.py
ddrwode 21f2adc4a4 哈哈
2026-02-20 20:57:25 +08:00

279 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

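"""
Technical-indicator computation engine built on the `ta` library
(nominally "52 indicators").

Covers five groups: trend, momentum, volatility, volume, and custom derived
features. The exact number of output columns depends on INDICATOR_PARAMS
(e.g. how many SMA/EMA windows are configured) and on whether volume data
is available.
"""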
import pandas as pd
import numpy as np
import ta
from .config import INDICATOR_PARAMS as P
def compute_all_indicators(df: pd.DataFrame, has_volume: bool = False) -> pd.DataFrame:
    """Compute every indicator group and return it alongside the input columns.

    ``df`` is copied first, so the caller's frame is never modified.

    :param df: must contain ``open``, ``high``, ``low`` and ``close`` columns;
        a ``volume`` column is optional.
    :param has_volume: whether volume data should be used. Volume-based
        columns (MFI and the volume group) are added only when this is true
        *and* ``df`` actually has a ``volume`` column.
    :return: a copy of ``df`` with trend, momentum, volatility, (optionally)
        volume and custom-feature columns appended, in that order. The exact
        column count depends on INDICATOR_PARAMS and on volume availability.
    """
    frame = df.copy()
    open_, high, low, close = (frame[col] for col in ('open', 'high', 'low', 'close'))

    # Volume is used only when it is both requested and present in the data;
    # otherwise it stays None and every volume-dependent step is skipped.
    volume = None
    if has_volume and 'volume' in frame.columns:
        volume = frame['volume']

    frame = _add_trend(frame, open_, high, low, close)
    frame = _add_momentum(frame, high, low, close, volume)
    frame = _add_volatility(frame, high, low, close)
    if volume is not None:
        frame = _add_volume(frame, high, low, close, volume)
    frame = _add_custom(frame, open_, high, low, close)
    return frame
def _add_trend(out, o, h, l, c):
    """Append trend-indicator columns to ``out`` in place and return it.

    Columns are added in this order: ``sma_<w>`` for every window in
    ``P['sma_windows']``, ``ema_<w>`` for every window in ``P['ema_windows']``,
    MACD line/signal/histogram, ADX with +DI/-DI, the four Ichimoku lines,
    TRIX, Aroon up/down, CCI, DPO, KST and Vortex +/-.

    :param out: DataFrame that receives the new columns (mutated).
    :param o: open prices; accepted for signature symmetry, not used here.
    :param h: high prices.
    :param l: low prices.
    :param c: close prices.
    :return: ``out`` itself.
    """
    trend = ta.trend

    # Moving averages: one column per configured window
    for window in P['sma_windows']:
        out[f'sma_{window}'] = trend.sma_indicator(c, window=window)
    for window in P['ema_windows']:
        out[f'ema_{window}'] = trend.ema_indicator(c, window=window)

    # Multi-output indicators: build the indicator object once, then copy
    # each requested output into its own column.
    macd = trend.MACD(
        c,
        window_slow=P['macd_slow'],
        window_fast=P['macd_fast'],
        window_sign=P['macd_signal'],
    )
    for column, output in (
        ('macd', macd.macd),
        ('macd_signal', macd.macd_signal),
        ('macd_hist', macd.macd_diff),
    ):
        out[column] = output()

    adx = trend.ADXIndicator(h, l, c, window=P['adx_window'])
    for column, output in (
        ('adx', adx.adx),
        ('di_plus', adx.adx_pos),
        ('di_minus', adx.adx_neg),
    ):
        out[column] = output()

    ichimoku = trend.IchimokuIndicator(
        h, l,
        window1=P['ichimoku_conversion'],
        window2=P['ichimoku_base'],
        window3=P['ichimoku_span_b'],
    )
    for column, output in (
        ('ichimoku_conv', ichimoku.ichimoku_conversion_line),
        ('ichimoku_base', ichimoku.ichimoku_base_line),
        ('ichimoku_a', ichimoku.ichimoku_a),
        ('ichimoku_b', ichimoku.ichimoku_b),
    ):
        out[column] = output()

    out['trix'] = trend.trix(c, window=P['trix_window'])

    # NOTE(review): the (high, low) constructor signature matches recent `ta`
    # releases; older releases took a single close series — confirm the pin.
    aroon = trend.AroonIndicator(h, l, window=P['aroon_window'])
    for column, output in (
        ('aroon_up', aroon.aroon_up),
        ('aroon_down', aroon.aroon_down),
    ):
        out[column] = output()

    out['cci'] = trend.cci(h, l, c, window=P['cci_window'])
    out['dpo'] = trend.dpo(c, window=P['dpo_window'])

    kst = trend.KSTIndicator(
        c,
        roc1=P['kst_roc1'], roc2=P['kst_roc2'],
        roc3=P['kst_roc3'], roc4=P['kst_roc4'],
    )
    out['kst'] = kst.kst()

    vortex = trend.VortexIndicator(h, l, c, window=P['vortex_window'])
    for column, output in (
        ('vortex_pos', vortex.vortex_indicator_pos),
        ('vortex_neg', vortex.vortex_indicator_neg),
    ):
        out[column] = output()

    return out
def _add_momentum(out, h, l, c, v):
    """Append momentum-indicator columns to ``out`` in place and return it.

    Columns, in order: rsi, stoch_k, stoch_d, williams_r, roc, mfi (only when
    ``v`` is not None), tsi, uo, ao, kama, ppo, stoch_rsi_k, stoch_rsi_d —
    12 columns, or 13 when MFI is included.

    :param out: DataFrame that receives the new columns (mutated).
    :param h: high prices.
    :param l: low prices.
    :param c: close prices.
    :param v: volume series, or None when volume is unavailable.
    :return: ``out`` itself.
    """
    mom = ta.momentum

    def _columns():
        # Yields (column name, series) pairs in final column order. Each
        # series is computed lazily, so it is assigned before the next one
        # is computed.
        yield 'rsi', mom.rsi(c, window=P['rsi_window'])

        stoch = mom.StochasticOscillator(
            h, l, c,
            window=P['stoch_window'],
            smooth_window=P['stoch_smooth'],
        )
        yield 'stoch_k', stoch.stoch()
        yield 'stoch_d', stoch.stoch_signal()

        yield 'williams_r', mom.williams_r(h, l, c, lbp=P['williams_window'])
        yield 'roc', mom.roc(c, window=P['roc_window'])

        # MFI needs volume, so it is only produced when volume is available
        if v is not None:
            yield 'mfi', ta.volume.money_flow_index(h, l, c, v, window=P['mfi_window'])

        yield 'tsi', mom.tsi(c, window_slow=P['tsi_slow'], window_fast=P['tsi_fast'])
        yield 'uo', mom.ultimate_oscillator(
            h, l, c,
            window1=P['uo_short'],
            window2=P['uo_medium'],
            window3=P['uo_long'],
        )
        yield 'ao', mom.awesome_oscillator(
            h, l,
            window1=P['ao_short'],
            window2=P['ao_long'],
        )
        yield 'kama', mom.kama(c, window=P['kama_window'])
        yield 'ppo', mom.ppo(c, window_slow=P['ppo_slow'], window_fast=P['ppo_fast'])

        # Both smoothing passes deliberately share the same configured length
        stoch_rsi = mom.StochRSIIndicator(
            c,
            window=P['stoch_rsi_window'],
            smooth1=P['stoch_rsi_smooth'],
            smooth2=P['stoch_rsi_smooth'],
        )
        yield 'stoch_rsi_k', stoch_rsi.stochrsi_k()
        yield 'stoch_rsi_d', stoch_rsi.stochrsi_d()

    for name, series in _columns():
        out[name] = series
    return out
def _add_volatility(out, h, l, c):
    """Append volatility-indicator columns to ``out`` in place and return it.

    Four indicator families, 12 columns in total, in this order:

    * Bollinger Bands — bb_upper, bb_mid, bb_lower, bb_width, bb_pband
    * ATR — atr
    * Keltner Channel — kc_upper, kc_mid, kc_lower
    * Donchian Channel — dc_upper, dc_mid, dc_lower

    :param out: DataFrame that receives the new columns (mutated).
    :param h: high prices.
    :param l: low prices.
    :param c: close prices.
    :return: ``out`` itself.
    """
    vty = ta.volatility

    bb = vty.BollingerBands(c, window=P['bb_window'], window_dev=P['bb_std'])
    for suffix, output in (
        ('upper', bb.bollinger_hband),
        ('mid', bb.bollinger_mavg),
        ('lower', bb.bollinger_lband),
        ('width', bb.bollinger_wband),
        ('pband', bb.bollinger_pband),
    ):
        out[f'bb_{suffix}'] = output()

    out['atr'] = vty.average_true_range(h, l, c, window=P['atr_window'])

    kc = vty.KeltnerChannel(h, l, c, window=P['kc_window'])
    for suffix, output in (
        ('upper', kc.keltner_channel_hband),
        ('mid', kc.keltner_channel_mband),
        ('lower', kc.keltner_channel_lband),
    ):
        out[f'kc_{suffix}'] = output()

    dc = vty.DonchianChannel(h, l, c, window=P['dc_window'])
    for suffix, output in (
        ('upper', dc.donchian_channel_hband),
        ('mid', dc.donchian_channel_mband),
        ('lower', dc.donchian_channel_lband),
    ):
        out[f'dc_{suffix}'] = output()

    return out
def _add_volume(out, h, l, c, v):
    """Append the eight volume-indicator columns to ``out`` in place and return it.

    Columns, in order: obv (On-Balance Volume), vwap (volume-weighted average
    price, using ta's default rolling window), cmf (Chaikin Money Flow),
    adi (Accumulation/Distribution Index), emv (Ease of Movement),
    fi (Force Index), vpt (Volume Price Trend), nvi (Negative Volume Index).

    :param out: DataFrame that receives the new columns (mutated).
    :param h: high prices.
    :param l: low prices.
    :param c: close prices.
    :param v: volume series (required; callers skip this step without volume).
    :return: ``out`` itself.
    """
    tv = ta.volume

    # Ordered (column, builder) table; each builder runs only when its turn
    # comes, so computation and assignment stay interleaved.
    builders = (
        ('obv', lambda: tv.on_balance_volume(c, v)),
        ('vwap', lambda: tv.volume_weighted_average_price(h, l, c, v)),
        ('cmf', lambda: tv.chaikin_money_flow(h, l, c, v, window=P['cmf_window'])),
        ('adi', lambda: tv.acc_dist_index(h, l, c, v)),
        ('emv', lambda: tv.ease_of_movement(h, l, v, window=P['emv_window'])),
        ('fi', lambda: tv.force_index(c, v, window=P['fi_window'])),
        ('vpt', lambda: tv.volume_price_trend(c, v)),
        ('nvi', lambda: tv.negative_volume_index(c, v)),
    )
    for column, build in builders:
        out[column] = build()
    return out
def _add_custom(out, o, h, l, c):
    """Append 10 hand-crafted price/candle features to ``out`` in place and return it.

    Columns, in order:

    * ``price_change_pct``  – close-to-close percentage change
    * ``high_low_range``    – (high - low) / close
    * ``body_ratio``        – |close - open| / (high - low)
    * ``upper_shadow``      – (high - max(open, close)) / (high - low)
    * ``lower_shadow``      – (min(open, close) - low) / (high - low)
    * ``volatility_ratio``  – ATR(window=14) / close
    * ``close_sma20_ratio`` – close / SMA(20)
    * ``close_ema12_ratio`` – close / EMA(12)
    * ``momentum_3``        – close minus close 3 rows earlier
    * ``momentum_5``        – close minus close 5 rows earlier

    In the explicit ratio features above (everything except
    ``price_change_pct`` and the momentum columns), a zero denominator is
    replaced by NaN, so the feature becomes NaN rather than ±inf.

    :param out: DataFrame that receives the new columns (mutated).
    :param o: open prices.
    :param h: high prices.
    :param l: low prices.
    :param c: close prices.
    :return: ``out`` itself.
    """
    # Zero denominators -> NaN. Previously high_low_range and volatility_ratio
    # divided by raw close and produced ±inf on a zero close, unlike the
    # other ratios here.
    close_nz = c.replace(0, np.nan)
    hl_range = (h - l).replace(0, np.nan)

    out['price_change_pct'] = c.pct_change()
    out['high_low_range'] = (h - l) / close_nz

    # Candle anatomy: body and shadows, each relative to the bar's full range
    open_close = pd.concat([o, c], axis=1)
    body_top = open_close.max(axis=1)
    body_bottom = open_close.min(axis=1)
    out['body_ratio'] = (c - o).abs() / hl_range
    out['upper_shadow'] = (h - body_top) / hl_range
    out['lower_shadow'] = (body_bottom - l) / hl_range

    # Fixed 14-bar ATR, independent of P['atr_window'] used for the 'atr' column
    atr = ta.volatility.average_true_range(h, l, c, window=14)
    out['volatility_ratio'] = atr / close_nz

    # Fixed windows are intentional: they are baked into the column names
    sma20 = ta.trend.sma_indicator(c, window=20)
    out['close_sma20_ratio'] = c / sma20.replace(0, np.nan)
    ema12 = ta.trend.ema_indicator(c, window=12)
    out['close_ema12_ratio'] = c / ema12.replace(0, np.nan)

    out['momentum_3'] = c - c.shift(3)
    out['momentum_5'] = c - c.shift(5)
    return out
def get_indicator_names(has_volume: bool = False) -> list:
    """Return the names of all indicator columns as a new list.

    Groups appear in this order: trend, momentum, ``mfi`` (volume only),
    volatility, the eight volume columns (volume only), custom features.

    :param has_volume: include the volume-dependent names (``mfi`` plus the
        volume group).
    :return: list of column-name strings.

    NOTE(review): ``mfi`` is listed after ``stoch_rsi_d`` here, whereas
    compute_all_indicators inserts the ``mfi`` column right after ``roc``, so
    the names match but their order differs from the DataFrame's column
    order. Left as-is because downstream feature order may depend on this
    list — confirm before reordering. Names are also chosen from
    ``has_volume`` alone, even when the input data lacked a ``volume``
    column.
    """
    trend = [f'sma_{w}' for w in P['sma_windows']]
    trend.extend(f'ema_{w}' for w in P['ema_windows'])
    trend.extend([
        'macd', 'macd_signal', 'macd_hist', 'adx', 'di_plus', 'di_minus',
        'ichimoku_conv', 'ichimoku_base', 'ichimoku_a', 'ichimoku_b',
        'trix', 'aroon_up', 'aroon_down', 'cci', 'dpo', 'kst',
        'vortex_pos', 'vortex_neg',
    ])

    momentum = [
        'rsi', 'stoch_k', 'stoch_d', 'williams_r', 'roc', 'tsi', 'uo', 'ao',
        'kama', 'ppo', 'stoch_rsi_k', 'stoch_rsi_d',
    ]
    volatility = [
        'bb_upper', 'bb_mid', 'bb_lower', 'bb_width', 'bb_pband', 'atr',
        'kc_upper', 'kc_mid', 'kc_lower', 'dc_upper', 'dc_mid', 'dc_lower',
    ]
    volume = ['obv', 'vwap', 'cmf', 'adi', 'emv', 'fi', 'vpt', 'nvi']
    custom = [
        'price_change_pct', 'high_low_range', 'body_ratio', 'upper_shadow',
        'lower_shadow', 'volatility_ratio', 'close_sma20_ratio',
        'close_ema12_ratio', 'momentum_3', 'momentum_5',
    ]

    if not has_volume:
        return trend + momentum + volatility + custom
    return trend + momentum + ['mfi'] + volatility + volume + custom