279 lines
9.3 KiB
Python
279 lines
9.3 KiB
Python
"""
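Computation engine for 52 technical indicators, built on the ``ta`` library.
Covers five categories: trend, momentum, volatility, volume, and custom derived features.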
"""
|
||
import pandas as pd
|
||
import numpy as np
|
||
import ta
|
||
from .config import INDICATOR_PARAMS as P
|
||
|
||
|
||
def compute_all_indicators(df: pd.DataFrame, has_volume: bool = False) -> pd.DataFrame:
    """Compute every indicator group and return it alongside the input columns.

    The input frame is copied first, so the caller's ``df`` is left untouched.

    :param df: frame that must contain ``open``, ``high``, ``low`` and
        ``close`` columns; a ``volume`` column is optional.
    :param has_volume: whether volume data is available. Volume-based
        indicators are added only when this is true *and* ``df`` actually
        has a ``volume`` column.
    :return: copy of ``df`` with the indicator columns appended.
    """
    result = df.copy()
    open_, high, low, close = (
        result[column] for column in ('open', 'high', 'low', 'close')
    )

    # Volume is used only when the caller asks for it and the column exists.
    volume = None
    if has_volume and 'volume' in result.columns:
        volume = result['volume']

    # Trend group.
    result = _add_trend(result, open_, high, low, close)

    # Momentum group (adds MFI only when volume is available).
    result = _add_momentum(result, high, low, close, volume)

    # Volatility group.
    result = _add_volatility(result, high, low, close)

    # Volume group; `volume` is non-None only when `has_volume` was truthy.
    if volume is not None:
        result = _add_volume(result, high, low, close, volume)

    # Custom derived features.
    return _add_custom(result, open_, high, low, close)
|
||
|
||
|
||
def _add_trend(out, o, h, l, c):
    """Append the trend-indicator columns to ``out`` and return it.

    Window lengths and other settings are read from ``INDICATOR_PARAMS``
    (imported as ``P``). Columns are written onto ``out`` in place.

    :param out: frame that receives the new columns.
    :param o: open series (not used by this function).
    :param h: high series.
    :param l: low series.
    :param c: close series.
    :return: ``out`` with the trend columns added.
    """
    # SMA: one column per configured window.
    for span in P['sma_windows']:
        out[f'sma_{span}'] = ta.trend.sma_indicator(c, window=span)

    # EMA: one column per configured window.
    for span in P['ema_windows']:
        out[f'ema_{span}'] = ta.trend.ema_indicator(c, window=span)

    # MACD line, signal line and histogram.
    macd_ind = ta.trend.MACD(
        c,
        window_slow=P['macd_slow'],
        window_fast=P['macd_fast'],
        window_sign=P['macd_signal'],
    )
    out['macd'] = macd_ind.macd()
    out['macd_signal'] = macd_ind.macd_signal()
    out['macd_hist'] = macd_ind.macd_diff()

    # ADX with its +DI / -DI components.
    adx_ind = ta.trend.ADXIndicator(h, l, c, window=P['adx_window'])
    out['adx'] = adx_ind.adx()
    out['di_plus'] = adx_ind.adx_pos()
    out['di_minus'] = adx_ind.adx_neg()

    # Ichimoku: conversion line, base line, span A and span B.
    ichimoku = ta.trend.IchimokuIndicator(
        h,
        l,
        window1=P['ichimoku_conversion'],
        window2=P['ichimoku_base'],
        window3=P['ichimoku_span_b'],
    )
    out['ichimoku_conv'] = ichimoku.ichimoku_conversion_line()
    out['ichimoku_base'] = ichimoku.ichimoku_base_line()
    out['ichimoku_a'] = ichimoku.ichimoku_a()
    out['ichimoku_b'] = ichimoku.ichimoku_b()

    # TRIX.
    out['trix'] = ta.trend.trix(c, window=P['trix_window'])

    # Aroon up / down.
    aroon_ind = ta.trend.AroonIndicator(h, l, window=P['aroon_window'])
    out['aroon_up'] = aroon_ind.aroon_up()
    out['aroon_down'] = aroon_ind.aroon_down()

    # CCI.
    out['cci'] = ta.trend.cci(h, l, c, window=P['cci_window'])

    # DPO.
    out['dpo'] = ta.trend.dpo(c, window=P['dpo_window'])

    # KST.
    kst_ind = ta.trend.KSTIndicator(
        c,
        roc1=P['kst_roc1'],
        roc2=P['kst_roc2'],
        roc3=P['kst_roc3'],
        roc4=P['kst_roc4'],
    )
    out['kst'] = kst_ind.kst()

    # Vortex positive / negative.
    vortex_ind = ta.trend.VortexIndicator(h, l, c, window=P['vortex_window'])
    out['vortex_pos'] = vortex_ind.vortex_indicator_pos()
    out['vortex_neg'] = vortex_ind.vortex_indicator_neg()

    return out
|
||
|
||
|
||
def _add_momentum(out, h, l, c, v):
    """Append the momentum-indicator columns to ``out`` and return it.

    Settings are read from ``INDICATOR_PARAMS`` (imported as ``P``). Columns
    are written onto ``out`` in place.

    :param out: frame that receives the new columns.
    :param h: high series.
    :param l: low series.
    :param c: close series.
    :param v: volume series, or ``None``; the ``mfi`` column is added only
        when a volume series is supplied.
    :return: ``out`` with the momentum columns added.
    """
    # RSI.
    out['rsi'] = ta.momentum.rsi(c, window=P['rsi_window'])

    # Stochastic oscillator %K / %D.
    stochastic = ta.momentum.StochasticOscillator(
        h,
        l,
        c,
        window=P['stoch_window'],
        smooth_window=P['stoch_smooth'],
    )
    out['stoch_k'] = stochastic.stoch()
    out['stoch_d'] = stochastic.stoch_signal()

    # Williams %R.
    out['williams_r'] = ta.momentum.williams_r(h, l, c, lbp=P['williams_window'])

    # ROC.
    out['roc'] = ta.momentum.roc(c, window=P['roc_window'])

    # MFI needs volume, so it is skipped when none was provided.
    if v is not None:
        out['mfi'] = ta.volume.money_flow_index(h, l, c, v, window=P['mfi_window'])

    # TSI.
    out['tsi'] = ta.momentum.tsi(
        c,
        window_slow=P['tsi_slow'],
        window_fast=P['tsi_fast'],
    )

    # Ultimate Oscillator.
    out['uo'] = ta.momentum.ultimate_oscillator(
        h,
        l,
        c,
        window1=P['uo_short'],
        window2=P['uo_medium'],
        window3=P['uo_long'],
    )

    # Awesome Oscillator.
    out['ao'] = ta.momentum.awesome_oscillator(
        h,
        l,
        window1=P['ao_short'],
        window2=P['ao_long'],
    )

    # KAMA.
    out['kama'] = ta.momentum.kama(c, window=P['kama_window'])

    # PPO.
    out['ppo'] = ta.momentum.ppo(
        c,
        window_slow=P['ppo_slow'],
        window_fast=P['ppo_fast'],
    )

    # Stochastic RSI %K / %D; both smoothing windows use the same setting.
    stochastic_rsi = ta.momentum.StochRSIIndicator(
        c,
        window=P['stoch_rsi_window'],
        smooth1=P['stoch_rsi_smooth'],
        smooth2=P['stoch_rsi_smooth'],
    )
    out['stoch_rsi_k'] = stochastic_rsi.stochrsi_k()
    out['stoch_rsi_d'] = stochastic_rsi.stochrsi_d()

    return out
|
||
|
||
|
||
def _add_volatility(out, h, l, c):
    """Append the volatility-indicator columns to ``out`` and return it.

    Adds twelve columns in total: five Bollinger Band columns, ATR, three
    Keltner Channel columns and three Donchian Channel columns. Settings are
    read from ``INDICATOR_PARAMS`` (imported as ``P``).

    :param out: frame that receives the new columns (modified in place).
    :param h: high series.
    :param l: low series.
    :param c: close series.
    :return: ``out`` with the volatility columns added.
    """
    # Bollinger Bands: upper, middle, lower, band width and percent band.
    bollinger = ta.volatility.BollingerBands(
        c,
        window=P['bb_window'],
        window_dev=P['bb_std'],
    )
    out['bb_upper'] = bollinger.bollinger_hband()
    out['bb_mid'] = bollinger.bollinger_mavg()
    out['bb_lower'] = bollinger.bollinger_lband()
    out['bb_width'] = bollinger.bollinger_wband()
    out['bb_pband'] = bollinger.bollinger_pband()

    # ATR.
    out['atr'] = ta.volatility.average_true_range(h, l, c, window=P['atr_window'])

    # Keltner Channel: upper, middle, lower.
    keltner = ta.volatility.KeltnerChannel(h, l, c, window=P['kc_window'])
    out['kc_upper'] = keltner.keltner_channel_hband()
    out['kc_mid'] = keltner.keltner_channel_mband()
    out['kc_lower'] = keltner.keltner_channel_lband()

    # Donchian Channel: upper, middle, lower.
    donchian = ta.volatility.DonchianChannel(h, l, c, window=P['dc_window'])
    out['dc_upper'] = donchian.donchian_channel_hband()
    out['dc_mid'] = donchian.donchian_channel_mband()
    out['dc_lower'] = donchian.donchian_channel_lband()

    return out
|
||
|
||
|
||
def _add_volume(out, h, l, c, v):
    """Append the eight volume-indicator columns to ``out`` and return it.

    Windowed indicators read their window from ``INDICATOR_PARAMS``
    (imported as ``P``); the others are called without a window argument.

    :param out: frame that receives the new columns (modified in place).
    :param h: high series.
    :param l: low series.
    :param c: close series.
    :param v: volume series.
    :return: ``out`` with the volume columns added.
    """
    volume_api = ta.volume

    # OBV.
    out['obv'] = volume_api.on_balance_volume(c, v)

    # VWAP (no window passed, so the library default applies).
    out['vwap'] = volume_api.volume_weighted_average_price(h, l, c, v)

    # CMF.
    out['cmf'] = volume_api.chaikin_money_flow(h, l, c, v, window=P['cmf_window'])

    # ADI (Accumulation/Distribution Index).
    out['adi'] = volume_api.acc_dist_index(h, l, c, v)

    # EMV (Ease of Movement).
    out['emv'] = volume_api.ease_of_movement(h, l, v, window=P['emv_window'])

    # Force Index.
    out['fi'] = volume_api.force_index(c, v, window=P['fi_window'])

    # VPT (Volume Price Trend).
    out['vpt'] = volume_api.volume_price_trend(c, v)

    # NVI (Negative Volume Index).
    out['nvi'] = volume_api.negative_volume_index(c, v)

    return out
|
||
|
||
|
||
def _add_custom(out, o, h, l, c):
    """Append the ten custom derived price features to ``out`` and return it.

    Columns added, in order: ``price_change_pct``, ``high_low_range``,
    ``body_ratio``, ``upper_shadow``, ``lower_shadow``, ``volatility_ratio``,
    ``close_sma20_ratio``, ``close_ema12_ratio``, ``momentum_3`` and
    ``momentum_5``.

    Every explicit division in this function replaces a zero denominator
    with NaN first, so those ratios come out as NaN instead of +/-inf.
    (``price_change_pct`` is delegated to ``Series.pct_change`` and is not
    guarded here.)

    :param out: frame that receives the new columns (modified in place).
    :param o: open series.
    :param h: high series.
    :param l: low series.
    :param c: close series.
    :return: ``out`` with the custom feature columns added.
    """
    # Zero denominators become NaN so the divisions below never yield inf.
    close_nonzero = c.replace(0, np.nan)
    hl_range = (h - l).replace(0, np.nan)

    # Open/close pair, built once and reused for both shadow calculations.
    open_close = pd.concat([o, c], axis=1)

    # Period-over-period close change.
    out['price_change_pct'] = c.pct_change()

    # Amplitude: high-low range relative to close.
    out['high_low_range'] = (h - l) / close_nonzero

    # Body ratio: |close - open| / (high - low).
    out['body_ratio'] = (c - o).abs() / hl_range

    # Upper shadow: high minus the larger of open/close, over the range.
    out['upper_shadow'] = (h - open_close.max(axis=1)) / hl_range

    # Lower shadow: the smaller of open/close minus low, over the range.
    out['lower_shadow'] = (open_close.min(axis=1) - l) / hl_range

    # Volatility ratio: ATR / close. The window is fixed at 14 here and does
    # not follow P['atr_window'].
    atr = ta.volatility.average_true_range(h, l, c, window=14)
    out['volatility_ratio'] = atr / close_nonzero

    # Close relative to its 20-period SMA.
    sma20 = ta.trend.sma_indicator(c, window=20)
    out['close_sma20_ratio'] = c / sma20.replace(0, np.nan)

    # Close relative to its 12-period EMA.
    ema12 = ta.trend.ema_indicator(c, window=12)
    out['close_ema12_ratio'] = c / ema12.replace(0, np.nan)

    # Absolute close change over 3 and 5 periods.
    out['momentum_3'] = c - c.shift(3)
    out['momentum_5'] = c - c.shift(5)

    return out
|
||
|
||
|
||
def get_indicator_names(has_volume: bool = False) -> list:
    """Return the names of all indicator columns.

    The list is grouped as trend, momentum, volatility, volume, custom.
    SMA/EMA names are derived from the windows configured in
    ``INDICATOR_PARAMS`` (imported as ``P``).

    :param has_volume: when true, ``mfi`` (appended at the end of the
        momentum group) and the volume group are included.
    :return: list of column-name strings.
    """
    # Trend.
    trend = [f'sma_{w}' for w in P['sma_windows']]
    trend.extend(f'ema_{w}' for w in P['ema_windows'])
    trend.extend([
        'macd', 'macd_signal', 'macd_hist', 'adx', 'di_plus', 'di_minus',
        'ichimoku_conv', 'ichimoku_base', 'ichimoku_a', 'ichimoku_b',
        'trix', 'aroon_up', 'aroon_down', 'cci', 'dpo', 'kst',
        'vortex_pos', 'vortex_neg',
    ])

    # Momentum.
    momentum = [
        'rsi', 'stoch_k', 'stoch_d', 'williams_r', 'roc', 'tsi', 'uo', 'ao',
        'kama', 'ppo', 'stoch_rsi_k', 'stoch_rsi_d',
    ]
    if has_volume:
        momentum.append('mfi')

    # Volatility.
    volatility = [
        'bb_upper', 'bb_mid', 'bb_lower', 'bb_width', 'bb_pband', 'atr',
        'kc_upper', 'kc_mid', 'kc_lower', 'dc_upper', 'dc_mid', 'dc_lower',
    ]

    # Volume (only when volume data is available).
    volume = []
    if has_volume:
        volume = ['obv', 'vwap', 'cmf', 'adi', 'emv', 'fi', 'vpt', 'nvi']

    # Custom derived features.
    custom = [
        'price_change_pct', 'high_low_range', 'body_ratio', 'upper_shadow',
        'lower_shadow', 'volatility_ratio', 'close_sma20_ratio',
        'close_ema12_ratio', 'momentum_3', 'momentum_5',
    ]

    return trend + momentum + volatility + volume + custom
|