Files
jyx_code4/strategy/feature_engine.py

205 lines
7.2 KiB
Python
Raw Normal View History

2026-02-20 20:57:25 +08:00
"""
Feature engineering: standardization, multi-period fusion, lag features, and label generation.
"""
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from loguru import logger
from .config import FEATURE_CONFIG as FC, PRIMARY_PERIOD, AUX_PERIODS
from .data_loader import load_kline
from .indicators import compute_all_indicators, get_indicator_names
def build_features(primary_df: pd.DataFrame,
                   aux_dfs: dict = None,
                   has_volume: bool = False) -> pd.DataFrame:
    """
    Assemble the full feature matrix for the primary period.

    Steps: indicators on the primary frame, lagged copies of those indicators,
    a handful of indicators from each auxiliary period aligned onto the primary
    index, and finally removal of columns that are entirely NaN.

    :param primary_df: primary-period K-line DataFrame
    :param aux_dfs: optional ``{period: DataFrame}`` mapping of auxiliary-period data
    :param has_volume: forwarded to the indicator helpers
    :return: feature matrix DataFrame
    """
    # 1. Indicators on the primary period.
    logger.info("计算主周期指标...")
    features = compute_all_indicators(primary_df, has_volume=has_volume)

    # 2. Lagged copies of every indicator column that is actually present.
    logger.info("生成滞后特征...")
    present = [name for name in get_indicator_names(has_volume)
               if name in features.columns]
    shifted = [
        features[present].shift(n).add_suffix(f'_lag{n}')
        for n in FC['lookback_lags']
    ]
    if shifted:
        features = pd.concat([features, *shifted], axis=1)

    # 3. Multi-period fusion: forward-fill selected auxiliary indicators onto
    #    the primary index and name them "<indicator>_<period>m".
    # NOTE(review): if auxiliary bars are indexed by their open time, the
    # forward fill exposes a bar's final indicator value before that bar has
    # closed - confirm the timestamp convention of load_kline.
    if aux_dfs:
        wanted = ('rsi', 'macd', 'adx', 'bb_pband', 'atr', 'cci')
        aligned_cols = []
        for period, frame in aux_dfs.items():
            logger.info(f"融合 {period}分钟 辅助周期特征...")
            enriched = compute_all_indicators(frame, has_volume=has_volume)
            aligned_cols.extend(
                enriched[name]
                .reindex(features.index, method='ffill')
                .rename(f'{name}_{period}m')
                for name in wanted
                if name in enriched.columns
            )
        if aligned_cols:
            features = pd.concat([features, *aligned_cols], axis=1)

    # 4. Drop columns that contain nothing but NaN and report how many went.
    n_before = len(features.columns)
    features.dropna(axis=1, how='all', inplace=True)
    removed = n_before - len(features.columns)
    if removed:
        logger.info(f"移除 {removed} 个全NaN列")

    logger.info(f"特征矩阵: {features.shape[0]} 行 x {features.shape[1]}")
    return features
def generate_labels(df: pd.DataFrame, forward_periods: int = None,
                    threshold: float = None) -> pd.Series:
    """
    Generate three-class trading labels from forward returns.

    The forward return of a row is
    ``close[t + forward_periods] / close[t] - 1``.

    :param df: DataFrame containing a ``close`` column
    :param forward_periods: look-ahead horizon in bars (expected to be
        positive); defaults to ``FC['label_forward_periods']``
    :param threshold: return threshold separating hold from long/short;
        defaults to ``FC['label_threshold']``
    :return: float Series named ``label`` on ``df.index`` with
        0 = hold, 1 = long, 2 = short, and NaN wherever the forward return
        cannot be computed (the last ``forward_periods`` rows, or rows where
        either close involved is NaN)
    """
    if forward_periods is None:
        forward_periods = FC['label_forward_periods']
    if threshold is None:
        threshold = FC['label_threshold']

    future_return = df['close'].shift(-forward_periods) / df['close'] - 1

    # Allocate as float up front: the Series must hold NaN, and writing NaN
    # into an int64 Series relies on silent upcasting, which pandas deprecated.
    labels = pd.Series(0.0, index=df.index, name='label')  # default: hold
    labels[future_return > threshold] = 1.0    # long
    labels[future_return < -threshold] = 2.0   # short

    # Mark rows without a computable forward return as unlabeled. Masking on
    # NaN (rather than slicing ``iloc[-forward_periods:]``) covers the tail
    # rows, avoids wiping every label when forward_periods is 0 (``-0`` slices
    # from the start), and keeps rows with a missing close from being
    # labelled "hold" by default.
    labels[future_return.isna()] = np.nan

    dist = labels.value_counts().to_dict()
    logger.info(f"标签分布: 观望={dist.get(0, 0)}, 做多={dist.get(1, 0)}, 做空={dist.get(2, 0)}")
    return labels
def prepare_dataset(period: int = None, start_date: str = None, end_date: str = None,
                    has_volume: bool = False) -> tuple:
    """
    Prepare the training dataset in one call.

    Loads the primary and auxiliary K-lines, builds features and labels,
    drops every row containing NaN, removes raw price columns and
    price-level indicators, and optionally standardizes the features.

    :param period: primary period; defaults to ``PRIMARY_PERIOD``
    :param start_date: passed through to ``load_kline``
    :param end_date: passed through to ``load_kline``
    :param has_volume: forwarded to ``build_features``
    :return: ``(X, y, feature_cols, scaler)`` - ``X`` is the feature
        DataFrame (standardized when ``FC['normalize']`` is truthy), ``y`` the
        integer label Series, ``feature_cols`` the ordered list of feature
        names, and ``scaler`` the fitted ``StandardScaler`` or ``None`` when
        normalization is disabled
    :raises ValueError: if the primary-period K-line data is empty
    """
    if period is None:
        period = PRIMARY_PERIOD

    # Primary period is mandatory.
    primary_df = load_kline(period, start_date, end_date)
    if primary_df.empty:
        raise ValueError(f"{period}分钟 K线数据为空")

    # Auxiliary periods are best-effort: a failed load is logged and skipped.
    aux_dfs = {}
    for aux_p in AUX_PERIODS:
        try:
            aux_df = load_kline(aux_p, start_date, end_date)
            if not aux_df.empty:
                aux_dfs[aux_p] = aux_df
        except Exception as e:
            logger.warning(f"加载 {aux_p}分钟 辅助数据失败: {e}")

    # Features, then labels attached to a private copy of the frame.
    df = build_features(primary_df, aux_dfs, has_volume=has_volume)
    labels = generate_labels(df)
    df = df.copy()
    df['label'] = labels

    # Drop rows with any NaN (indicator warm-up, lags, unlabeled tail).
    df = df.dropna()
    logger.info(f"去NaN后剩余 {len(df)}")

    # Raw OHLC(V), timestamp and the label itself are never features.
    exclude_cols = {'open', 'high', 'low', 'close', 'timestamp', 'label', 'volume'}

    # Price-level features are excluded: they carry absolute price information
    # and encourage overfitting.
    price_level_patterns = (
        'sma_', 'ema_', 'bb_upper', 'bb_mid', 'bb_lower',
        'kc_upper', 'kc_mid', 'kc_lower', 'dc_upper', 'dc_mid', 'dc_lower',
        'ichimoku_conv', 'ichimoku_base', 'ichimoku_a', 'ichimoku_b',
        'kama', 'vwap', 'obv', 'adi', 'vpt', 'nvi',
        'momentum_3', 'momentum_5',
    )
    # A pattern such as 'sma_' must also match a column named exactly 'sma'.
    exact_names = {p.rstrip('_') for p in price_level_patterns}

    # Auxiliary columns are named "<indicator>_<period>m" by build_features.
    # Keep the historical suffixes and add any period configured in
    # AUX_PERIODS so the list cannot drift from the configuration.
    period_suffixes = ['_5m', '_60m', '_3m', '_15m', '_30m', '_1m']
    for aux_p in AUX_PERIODS:
        suffix = f'_{aux_p}m'
        if suffix not in period_suffixes:
            period_suffixes.append(suffix)

    feature_cols = []
    for c in df.columns:
        if c in exclude_cols:
            continue
        # Reduce "<name>_lagN" / "<name>_<period>m" to the base indicator name.
        base_name = c.split('_lag')[0]
        for suffix in period_suffixes:
            if base_name.endswith(suffix):
                base_name = base_name[:-len(suffix)]
                break
        if base_name.startswith(price_level_patterns) or base_name in exact_names:
            continue
        feature_cols.append(c)
    logger.info(f"排除价格级别特征后剩余 {len(feature_cols)} 个特征")

    X = df[feature_cols].copy()
    y = df['label'].astype(int)

    # Standardization.
    # NOTE(review): the scaler is fitted on every remaining row; if callers
    # split X into train/test afterwards, the test rows have already
    # influenced the mean/variance - consider fitting on the training split.
    scaler = None
    if FC['normalize']:
        scaler = StandardScaler()
        X = pd.DataFrame(scaler.fit_transform(X), columns=feature_cols, index=X.index)
        logger.info("特征已标准化")

    logger.info(f"最终数据集: X={X.shape}, y={y.shape}, 特征数={len(feature_cols)}")
    return X, y, feature_cols, scaler
def get_latest_feature_row(period: int = None, feature_cols: list = None,
                           start_date: str = None, end_date: str = None,
                           has_volume: bool = False) -> pd.DataFrame:
    """
    Build features and return the most recent complete feature row, for
    live / paper-trading prediction.

    ``feature_cols`` must be the list saved at training time so the columns
    match the trained model. The returned values are NOT standardized here;
    no scaler is applied inside this function.

    :param period: primary period; defaults to ``PRIMARY_PERIOD``
    :param feature_cols: feature names saved at training time
    :param start_date: passed through to ``load_kline``
    :param end_date: passed through to ``load_kline``
    :param has_volume: forwarded to ``build_features``; should match the value
        used when the training set was prepared
    :return: one-row DataFrame with columns ``feature_cols`` (the last row in
        which none of those features is NaN); an empty DataFrame when
        ``feature_cols`` is empty, the primary data is empty, or a requested
        column is missing
    """
    if period is None:
        period = PRIMARY_PERIOD
    if not feature_cols:
        return pd.DataFrame()

    primary_df = load_kline(period, start_date, end_date)
    if primary_df.empty:
        return pd.DataFrame()

    # Auxiliary periods are best-effort, but a failure is logged rather than
    # silently swallowed: a missing period surfaces later as missing columns.
    aux_dfs = {}
    for aux_p in AUX_PERIODS:
        try:
            aux_df = load_kline(aux_p, start_date, end_date)
            if not aux_df.empty:
                aux_dfs[aux_p] = aux_df
        except Exception as e:
            logger.warning(f"加载 {aux_p}分钟 辅助数据失败: {e}")

    df = build_features(primary_df, aux_dfs, has_volume=has_volume)

    missing = [c for c in feature_cols if c not in df.columns]
    if missing:
        logger.warning(f"get_latest_feature_row: 缺少特征列 {missing[:5]}{'...' if len(missing) > 5 else ''}")
        return pd.DataFrame()

    # No labels are generated here: a label needs future bars, so attaching
    # one and dropping NaN rows would always discard the newest bars and
    # return a stale row. NaN filtering is limited to the requested features
    # so unrelated columns cannot evict the latest bar either.
    complete = df[list(feature_cols)].dropna()
    return complete.iloc[-1:].copy()