Files
jyx_code4/strategy/feature_engine.py
ddrwode f0fe26acbf haha
2026-02-21 15:38:23 +08:00

245 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
特征工程 — 标准化、多周期融合、滞后特征、标签生成
"""
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from loguru import logger
from .config import FEATURE_CONFIG as FC, PRIMARY_PERIOD, AUX_PERIODS
from .data_loader import load_kline
from .indicators import compute_all_indicators, get_indicator_names
def _add_intra15m_1m_pullback_features(df: pd.DataFrame, primary_df: pd.DataFrame,
                                       one_min_df: pd.DataFrame,
                                       freq: str = '15min') -> pd.DataFrame:
    """
    Add intra-bar "spike then pullback" / "dip then recovery" features from 1m data.

    The 1-minute bars are bucketed into the primary bar that contains them. The
    highest 1m high and lowest 1m low of each bucket are compared with that primary
    bar's close, normalised by the primary bar's high-low range:

    - ``pullback_ratio_1m`` = (max 1m high - close) / (high - low)
    - ``recovery_ratio_1m`` = (close - min 1m low) / (high - low)

    Both are clipped to [0, 1]. Bars with a non-positive range, or with no 1m bars
    inside them, get 0. This lets the model recognise bars that rallied part of the
    way and then pulled back (or dipped and recovered).

    :param df: feature matrix to extend; rows are matched to ``primary_df`` by index.
    :param primary_df: primary K-lines with ``high``/``low``/``close`` columns; its
        index is assumed to hold bar start timestamps.
    :param one_min_df: 1-minute K-lines with ``high``/``low`` and a DatetimeIndex.
    :param freq: pandas offset alias of the primary bar length used to bucket the
        1m bars; defaults to ``'15min'``, the period this feature was designed for.
    :return: a copy of ``df`` with the two new columns, or ``df`` itself when
        ``one_min_df`` is empty or lacks ``high``/``low``.
    """
    if one_min_df.empty or 'high' not in one_min_df.columns or 'low' not in one_min_df.columns:
        return df

    # Tag every 1m bar with the start of the primary bar it belongs to
    # (the primary index is assumed to be aligned with bar start times).
    one_min = one_min_df[['high', 'low']].copy()
    one_min['bucket_start'] = one_min.index.floor(freq)
    agg = one_min.groupby('bucket_start').agg(intra_high_1m=('high', 'max'),
                                              intra_low_1m=('low', 'min'))

    # Look up each primary bar's own bucket. Deliberately NO forward fill: a bar
    # without 1m coverage must not borrow an earlier bar's extremes (that would
    # compare this close with another bar's high/low). It stays NaN and becomes 0 below.
    bar_keys = primary_df.index
    if isinstance(bar_keys, pd.DatetimeIndex):
        bar_keys = bar_keys.floor(freq)
    agg = agg.reindex(bar_keys)
    agg.index = primary_df.index

    high = primary_df['high']
    low = primary_df['low']
    close = primary_df['close']
    bar_range = (high - low).where(high - low > 0)  # non-positive range -> NaN
    pullback = ((agg['intra_high_1m'] - close) / bar_range).fillna(0).clip(0, 1)
    recovery = ((close - agg['intra_low_1m']) / bar_range).fillna(0).clip(0, 1)

    df = df.copy()
    # Index-aligned assignment: rows are matched by timestamp, not by position.
    df['pullback_ratio_1m'] = pullback
    df['recovery_ratio_1m'] = recovery
    logger.info("已加入 1m 周期内冲高回落特征: pullback_ratio_1m, recovery_ratio_1m")
    return df
# Indicators pulled from every auxiliary period into the primary feature matrix.
_AUX_KEY_INDICATORS = ('rsi', 'macd', 'adx', 'bb_pband', 'atr', 'cci')


def _primary_bar_length(index: pd.Index, primary_period: int = None) -> pd.Timedelta:
    """Return the primary bar length: ``primary_period`` minutes when given, otherwise
    the median spacing of a DatetimeIndex (``Timedelta(0)`` if it cannot be inferred)."""
    if primary_period is not None:
        return pd.Timedelta(minutes=primary_period)
    if not isinstance(index, pd.DatetimeIndex) or len(index) < 2:
        return pd.Timedelta(0)
    step = index.to_series().diff().dropna().median()
    return pd.Timedelta(0) if pd.isna(step) else step


def _align_aux_to_primary(aux_col: pd.Series, aux_period: int,
                          primary_index: pd.Index,
                          primary_bar: pd.Timedelta) -> pd.Series:
    """Map an auxiliary-period series onto the primary index without look-ahead.

    Both indexes are assumed to hold bar start times. An aux bar is usable only once
    it has closed, i.e. ``aux_start + aux_period <= primary_start + primary_bar``.
    The most recent such bar is taken (forward fill), so a primary row never sees
    the value of a higher-timeframe bar that is still open.
    """
    if not (isinstance(primary_index, pd.DatetimeIndex)
            and isinstance(aux_col.index, pd.DatetimeIndex)):
        # Close times cannot be derived without timestamps; keep start-time alignment.
        return aux_col.reindex(primary_index, method='ffill')
    closed_at = pd.Series(aux_col.to_numpy(),
                          index=aux_col.index + pd.Timedelta(minutes=aux_period))
    aligned = closed_at.reindex(primary_index + primary_bar, method='ffill')
    aligned.index = primary_index
    return aligned


def build_features(primary_df: pd.DataFrame,
                   aux_dfs: dict = None,
                   has_volume: bool = False,
                   primary_period: int = None) -> pd.DataFrame:
    """
    Build the full feature matrix for the primary period.

    Steps:
    1. indicators on the primary K-lines (``compute_all_indicators``);
    2. lagged copies of the indicator columns for each lag in ``FC['lookback_lags']``;
    3. key indicators (``_AUX_KEY_INDICATORS``) from each auxiliary period, aligned
       so a primary row only uses aux bars that had closed by the primary bar's close;
    3.5. 1m intra-bar pullback/recovery features when 1-minute data is supplied;
    4. columns that are entirely NaN are dropped.

    :param primary_df: primary-period K-line DataFrame.
    :param aux_dfs: optional ``{period_in_minutes: DataFrame}`` of auxiliary K-lines.
    :param has_volume: whether volume-based indicators are computed.
    :param primary_period: primary bar length in minutes; when omitted it is inferred
        from the spacing of the primary index (used only for aux alignment).
    :return: feature matrix DataFrame.
    """
    # 1. Primary-period indicators
    logger.info("计算主周期指标...")
    df = compute_all_indicators(primary_df, has_volume=has_volume)

    # 2. Lagged features
    logger.info("生成滞后特征...")
    indicator_cols = get_indicator_names(has_volume)
    existing_cols = [col for col in indicator_cols if col in df.columns]
    lag_frames = [df[existing_cols].shift(lag).add_suffix(f'_lag{lag}')
                  for lag in FC['lookback_lags']]
    if lag_frames:
        df = pd.concat([df] + lag_frames, axis=1)

    # 3. Multi-timeframe fusion
    if aux_dfs:
        primary_bar = _primary_bar_length(df.index, primary_period)
        aux_frames = []
        for period, aux_df in aux_dfs.items():
            logger.info(f"融合 {period}分钟 辅助周期特征...")
            aux_with_ind = compute_all_indicators(aux_df, has_volume=has_volume)
            for ind in _AUX_KEY_INDICATORS:
                if ind in aux_with_ind.columns:
                    aligned = _align_aux_to_primary(aux_with_ind[ind], period,
                                                    df.index, primary_bar)
                    aux_frames.append(aligned.rename(f'{ind}_{period}m'))
        if aux_frames:
            df = pd.concat([df] + aux_frames, axis=1)

    # 3.5 1m intra-bar "spike then pullback" features (1m high/low inside each
    #     primary bar vs. its close, to spot bars that rallied and then retraced)
    if aux_dfs and 1 in aux_dfs:
        df = _add_intra15m_1m_pullback_features(df, primary_df, aux_dfs[1])

    # 4. Drop columns that are entirely NaN
    before_cols = len(df.columns)
    df = df.dropna(axis=1, how='all')
    after_cols = len(df.columns)
    if before_cols != after_cols:
        logger.info(f"移除 {before_cols - after_cols} 个全NaN列")
    logger.info(f"特征矩阵: {df.shape[0]} 行 x {df.shape[1]}")
    return df
def generate_labels(df: pd.DataFrame, forward_periods: int = None,
                    threshold: float = None) -> pd.Series:
    """
    Generate trading labels from the forward return of ``close``.

    The forward return is ``close[t + forward_periods] / close[t] - 1``:

    - ``> threshold``  -> 1 (long)
    - ``< -threshold`` -> 2 (short)
    - otherwise        -> 0 (stay flat)

    Rows whose forward return cannot be computed get NaN so callers can drop them.
    That covers the last ``forward_periods`` rows and rows involving a missing close.

    :param df: DataFrame containing a ``close`` column.
    :param forward_periods: look-ahead horizon in bars (default
        ``FC['label_forward_periods']``); must be >= 1.
    :param threshold: return threshold (default ``FC['label_threshold']``).
    :return: float Series named ``'label'`` with values 0/1/2 or NaN, indexed like ``df``.
    :raises ValueError: if ``forward_periods`` is smaller than 1.
    """
    if forward_periods is None:
        forward_periods = FC['label_forward_periods']
    if threshold is None:
        threshold = FC['label_threshold']
    if forward_periods < 1:
        # 0 would mark every row NaN (iloc[-0:] is the whole frame); negatives look backwards.
        raise ValueError(f"forward_periods must be >= 1, got {forward_periods}")

    future_return = (df['close'].shift(-forward_periods) / df['close'] - 1).to_numpy(dtype=float)
    # Short must win when both conditions hold (only possible with a negative
    # threshold), matching the original "assign long, then short" order.
    values = np.select([future_return < -threshold, future_return > threshold],
                       [2.0, 1.0], default=0.0)
    # Built as float from the start, so no deprecated int -> float upcast on NaN.
    values[np.isnan(future_return)] = np.nan
    labels = pd.Series(values, index=df.index, name='label')

    dist = labels.value_counts().to_dict()
    logger.info(f"标签分布: 观望={dist.get(0, 0)}, 做多={dist.get(1, 0)}, 做空={dist.get(2, 0)}")
    return labels
# Column-name prefixes of features that encode absolute price/volume levels; they
# leak the absolute price and cause overfitting, so they are excluded from X.
_PRICE_LEVEL_PATTERNS = (
    'sma_', 'ema_', 'bb_upper', 'bb_mid', 'bb_lower',
    'kc_upper', 'kc_mid', 'kc_lower', 'dc_upper', 'dc_mid', 'dc_lower',
    'ichimoku_conv', 'ichimoku_base', 'ichimoku_a', 'ichimoku_b',
    'kama', 'vwap', 'obv', 'adi', 'vpt', 'nvi',
    'momentum_3', 'momentum_5',
)
# Auxiliary-period suffixes ("<indicator>_<period>m") stripped before pattern matching.
_PERIOD_SUFFIXES = ('_5m', '_60m', '_3m', '_15m', '_30m', '_1m')
# Raw K-line columns and the target; never used as features.
_NON_FEATURE_COLS = frozenset({'open', 'high', 'low', 'close', 'timestamp', 'label', 'volume'})


def _is_price_level_feature(col: str) -> bool:
    """Return True if ``col``, with lag and period suffixes removed, is a price-level feature."""
    base_name = col.split('_lag')[0]
    for suffix in _PERIOD_SUFFIXES:
        if base_name.endswith(suffix):
            base_name = base_name[:-len(suffix)]
            break
    return any(base_name.startswith(p) or base_name == p.rstrip('_')
               for p in _PRICE_LEVEL_PATTERNS)


def prepare_dataset(period: int = None, start_date: str = None, end_date: str = None,
                    has_volume: bool = False, normalize: bool = None) -> tuple:
    """
    Prepare a training dataset in one call.

    Loads the primary and auxiliary K-lines, builds features and labels, and drops
    raw/price-level columns. Rows incomplete in the used columns are removed, with
    ±inf treated as missing.

    :param period: primary period in minutes (default ``PRIMARY_PERIOD``).
    :param start_date: passed through to ``load_kline``.
    :param end_date: passed through to ``load_kline``.
    :param has_volume: whether volume-based indicators are computed.
    :param normalize: standardise X with ``StandardScaler`` (default ``FC['normalize']``).
    :return: ``(X, y, feature_names, scaler)``; ``scaler`` is None when not normalising.
    :raises ValueError: if the primary K-line data is empty.
    """
    if period is None:
        period = PRIMARY_PERIOD

    # Primary period
    primary_df = load_kline(period, start_date, end_date)
    if primary_df.empty:
        raise ValueError(f"{period}分钟 K线数据为空")

    # Auxiliary periods are best effort: a failed load is logged and skipped.
    aux_dfs = {}
    for aux_p in AUX_PERIODS:
        try:
            aux_df = load_kline(aux_p, start_date, end_date)
        except Exception as e:
            logger.warning(f"加载 {aux_p}分钟 辅助数据失败: {e}")
            continue
        if not aux_df.empty:
            aux_dfs[aux_p] = aux_df

    df = build_features(primary_df, aux_dfs, has_volume=has_volume)
    df = df.copy()
    df['label'] = generate_labels(df)

    # Feature selection depends only on column names, so do it before dropping rows.
    feature_cols = [c for c in df.columns
                    if c not in _NON_FEATURE_COLS and not _is_price_level_feature(c)]
    logger.info(f"排除价格级别特征后剩余 {len(feature_cols)} 个特征")

    # Treat ±inf as missing, as get_latest_feature_row does; StandardScaler rejects
    # infinity. Only drop rows that are incomplete in the columns actually used.
    data = df[feature_cols + ['label']].replace([np.inf, -np.inf], np.nan).dropna()
    logger.info(f"去NaN后剩余 {len(data)}")

    X = data[feature_cols].copy()
    y = data['label'].astype(int)

    # Optional standardisation
    scaler = None
    if normalize is None:
        normalize = FC['normalize']
    if normalize:
        # NOTE(review): the scaler is fitted on the whole dataset. If the caller later
        # splits X into train/test, test statistics leak into training. Confirm with the caller.
        scaler = StandardScaler()
        X = pd.DataFrame(scaler.fit_transform(X), columns=feature_cols, index=X.index)
        logger.info("特征已标准化")
    logger.info(f"最终数据集: X={X.shape}, y={y.shape}, 特征数={len(feature_cols)}")
    return X, y, feature_cols, scaler
def get_latest_feature_row(period: int = None, feature_cols: list = None,
                           start_date: str = None, end_date: str = None,
                           has_volume: bool = False) -> pd.DataFrame:
    """
    Build features and return the most recent row, for live/paper-trading prediction.

    ``feature_cols`` must be the column list saved at training time so that the row
    matches the model's inputs. The row is NOT standardised. If the model was trained
    with ``normalize=True``, apply the scaler returned by ``prepare_dataset``.

    :param period: primary period in minutes (default ``PRIMARY_PERIOD``).
    :param feature_cols: feature columns used in training; an empty DataFrame is
        returned when it is empty or None.
    :param start_date: passed through to ``load_kline``.
    :param end_date: passed through to ``load_kline``.
    :param has_volume: must match the value used in training so that volume-based
        feature columns are produced.
    :return: a 1-row DataFrame with columns ``feature_cols`` for the LATEST bar. An
        empty DataFrame is returned if data is missing, a column is missing, or the
        latest row contains NaN/inf. An older row is never substituted.
    """
    if period is None:
        period = PRIMARY_PERIOD
    if not feature_cols:
        return pd.DataFrame()

    primary_df = load_kline(period, start_date, end_date)
    if primary_df.empty:
        return pd.DataFrame()

    # Auxiliary periods are best effort. Failures are logged, and any features they
    # would have produced show up as missing columns below.
    aux_dfs = {}
    for aux_p in AUX_PERIODS:
        try:
            aux_df = load_kline(aux_p, start_date, end_date)
        except Exception as e:
            logger.warning(f"加载 {aux_p}分钟 辅助数据失败: {e}")
            continue
        if not aux_df.empty:
            aux_dfs[aux_p] = aux_df

    df = build_features(primary_df, aux_dfs, has_volume=has_volume)
    missing = [c for c in feature_cols if c not in df.columns]
    if missing:
        logger.warning(f"get_latest_feature_row: 缺少特征列 {missing[:5]}{'...' if len(missing) > 5 else ''}")
        return pd.DataFrame()

    # Only the latest bar is meaningful for a live prediction. If it is incomplete,
    # refuse rather than silently fall back to an older, stale row.
    latest = df[feature_cols].iloc[-1:].replace([np.inf, -np.inf], np.nan)
    if latest.empty or latest.isna().to_numpy().any():
        logger.warning("get_latest_feature_row: 可用于预测的特征为空数据不足或存在NaN")
        return pd.DataFrame()
    return latest.copy()