# Listing metadata: 245 lines, 9.3 KiB, Python
"""
Feature engineering — standardization, multi-period fusion, lag features, label generation.
"""
|
||
import pandas as pd
|
||
import numpy as np
|
||
from sklearn.preprocessing import StandardScaler
|
||
from loguru import logger
|
||
|
||
from .config import FEATURE_CONFIG as FC, PRIMARY_PERIOD, AUX_PERIODS
|
||
from .data_loader import load_kline
|
||
from .indicators import compute_all_indicators, get_indicator_names
|
||
|
||
|
||
def _add_intra15m_1m_pullback_features(df: pd.DataFrame, primary_df: pd.DataFrame,
                                       one_min_df: pd.DataFrame) -> pd.DataFrame:
    """
    Add "rally then pull back" / "dip then recover" ratios derived from 1-minute bars.

    For each primary bar, the highest high and the lowest low of the 1-minute
    bars that floor to the same 15-minute bucket are compared with the
    primary bar's close:

    * ``pullback_ratio_1m`` = (intra-bar 1m high - close) / (high - low)
    * ``recovery_ratio_1m`` = (close - intra-bar 1m low) / (high - low)

    Both ratios are clipped to [0, 1]; a bar whose range is not positive gets 0.

    :param df: feature frame to extend. The new columns are assigned
        positionally, so it is assumed to have the same rows as
        ``primary_df`` (TODO confirm against ``compute_all_indicators``).
    :param primary_df: primary-period K-line frame with ``high``, ``low`` and
        ``close`` columns.
    :param one_min_df: 1-minute K-line frame; its index must support
        ``.floor('15min')``.
    :return: ``df`` itself when the 1-minute data is empty or lacks
        ``high``/``low``; otherwise a copy of ``df`` with the two new columns.
    """
    if one_min_df.empty or not {'high', 'low'}.issubset(one_min_df.columns):
        return df

    # Bucket every 1m bar by the start of the 15m bar that contains it
    # (assumes primary bars are indexed by their start time — verify against the loader).
    one_min = one_min_df[['high', 'low']].assign(
        bar_start=one_min_df.index.floor('15min')
    )
    intra = (
        one_min.groupby('bar_start')
        .agg(intra_high_1m=('high', 'max'), intra_low_1m=('low', 'min'))
        .reindex(primary_df.index)
    )

    # Bars without any 1m data fall back to their own high/low. Forward-filling
    # would compare this bar's close with the extremes of an earlier bar, which
    # is not an intra-bar quantity. The fallback uses no future information.
    intra_high = intra['intra_high_1m'].fillna(primary_df['high']).to_numpy()
    intra_low = intra['intra_low_1m'].fillna(primary_df['low']).to_numpy()

    high = primary_df['high'].to_numpy()
    low = primary_df['low'].to_numpy()
    close = primary_df['close'].to_numpy()

    # Non-positive ranges become NaN so the division cannot blow up; the NaN
    # results are mapped to 0 below.
    bar_range = high - low
    bar_range = np.where(bar_range > 0, bar_range, np.nan)

    pullback = (intra_high - close) / bar_range   # how far the close fell from the 1m high
    recovery = (close - intra_low) / bar_range    # how far the close rose from the 1m low
    pullback = np.clip(np.nan_to_num(pullback, nan=0), 0, 1)
    recovery = np.clip(np.nan_to_num(recovery, nan=0), 0, 1)

    df = df.copy()
    df['pullback_ratio_1m'] = pullback
    df['recovery_ratio_1m'] = recovery
    logger.info("已加入 1m 周期内冲高回落特征: pullback_ratio_1m, recovery_ratio_1m")
    return df
|
||
|
||
|
||
def build_features(primary_df: pd.DataFrame,
                   aux_dfs: dict = None,
                   has_volume: bool = False) -> pd.DataFrame:
    """
    Build the complete feature matrix for the primary period.

    Steps, in order:

    1. indicators of the primary period (``compute_all_indicators``);
    2. lagged copies of the indicator columns, one set per entry of
       ``FC['lookback_lags']``, suffixed ``_lag{n}``;
    3. a fixed set of key indicators from every auxiliary period, aligned to
       the primary index with forward fill and suffixed ``_{period}m``;
    4. the 1-minute pullback/recovery ratios when ``aux_dfs`` has key ``1``;
    5. removal of columns that are entirely NaN.

    :param primary_df: primary-period K-line DataFrame
    :param aux_dfs: optional ``{period: DataFrame}`` of auxiliary-period data
    :param has_volume: whether volume data is available (forwarded to the
        indicator helpers)
    :return: feature matrix DataFrame
    """
    # 1. Primary-period indicators
    logger.info("计算主周期指标...")
    df = compute_all_indicators(primary_df, has_volume=has_volume)

    # 2. Lag features
    logger.info("生成滞后特征...")
    indicator_cols = get_indicator_names(has_volume)
    existing_cols = [col for col in indicator_cols if col in df.columns]
    lag_frames = [
        df[existing_cols].shift(lag).add_suffix(f'_lag{lag}')
        for lag in FC['lookback_lags']
    ]
    if lag_frames:
        # A single concat avoids fragmenting the frame column by column.
        df = pd.concat([df] + lag_frames, axis=1)

    # 3. Multi-period fusion
    if aux_dfs:
        key_indicators = ('rsi', 'macd', 'adx', 'bb_pband', 'atr', 'cci')
        aux_frames = []
        for period, aux_df in aux_dfs.items():
            logger.info(f"融合 {period}分钟 辅助周期特征...")
            aux_with_ind = compute_all_indicators(aux_df, has_volume=has_volume)
            for ind in key_indicators:
                if ind in aux_with_ind.columns:
                    # NOTE(review): forward fill takes the latest aux bar whose
                    # label is <= the primary timestamp. If bars are labelled by
                    # start time, a longer-period bar may still be forming at
                    # that moment — confirm this does not leak future data.
                    aligned = aux_with_ind[ind].reindex(df.index, method='ffill')
                    aux_frames.append(aligned.rename(f'{ind}_{period}m'))
        if aux_frames:
            df = pd.concat([df] + aux_frames, axis=1)

    # 3.5 Intra-bar 1m pullback features (1m high/low inside the bar vs. the bar close)
    if aux_dfs and 1 in aux_dfs:
        df = _add_intra15m_1m_pullback_features(df, primary_df, aux_dfs[1])

    # 4. Drop all-NaN columns. Done without inplace=True because `df` may still
    #    be the very object returned by compute_all_indicators (no lags, no aux
    #    data), which this function should not mutate.
    before_cols = df.shape[1]
    df = df.dropna(axis=1, how='all')
    after_cols = df.shape[1]
    if before_cols != after_cols:
        logger.info(f"移除 {before_cols - after_cols} 个全NaN列")

    logger.info(f"特征矩阵: {df.shape[0]} 行 x {df.shape[1]} 列")
    return df
|
||
|
||
|
||
def generate_labels(df: pd.DataFrame, forward_periods: int = None,
                    threshold: float = None) -> pd.Series:
    """
    Generate trading labels from the forward return of ``close``.

    The forward return is ``close[t + forward_periods] / close[t] - 1``.

    :param df: DataFrame containing a ``close`` column
    :param forward_periods: number of bars to look ahead; defaults to
        ``FC['label_forward_periods']``
    :param threshold: return threshold for a long/short label; defaults to
        ``FC['label_threshold']``
    :return: float Series named ``label`` with 0 = hold, 1 = long, 2 = short,
        and NaN wherever the forward return cannot be computed (the trailing
        ``forward_periods`` rows and rows with a missing close on either end)
    """
    if forward_periods is None:
        forward_periods = FC['label_forward_periods']
    if threshold is None:
        threshold = FC['label_threshold']

    close = df['close']
    future_return = close.shift(-forward_periods) / close - 1

    # Float dtype from the start: the series must be able to hold NaN, and
    # writing NaN into an integer Series is a deprecated silent upcast in
    # recent pandas.
    labels = pd.Series(0.0, index=df.index, name='label')  # default: hold
    labels[future_return > threshold] = 1.0    # long
    labels[future_return < -threshold] = 2.0   # short

    # Mask on the forward return instead of slicing the tail with
    # iloc[-forward_periods:], which blanks every row when forward_periods == 0
    # and misses rows whose future close is missing.
    labels[future_return.isna()] = np.nan

    dist = labels.value_counts().to_dict()
    logger.info(f"标签分布: 观望={dist.get(0, 0)}, 做多={dist.get(1, 0)}, 做空={dist.get(2, 0)}")
    return labels
|
||
|
||
|
||
def prepare_dataset(period: int = None, start_date: str = None, end_date: str = None,
                    has_volume: bool = False, normalize: bool = None) -> tuple:
    """
    Prepare a training dataset in one call.

    Loads the primary and auxiliary K-line data, builds features and labels,
    drops rows containing NaN or ±inf, removes raw price and price-level
    columns, and optionally standardizes the features.

    :param period: primary period; defaults to ``PRIMARY_PERIOD``
    :param start_date: forwarded to ``load_kline``
    :param end_date: forwarded to ``load_kline``
    :param has_volume: whether volume data is available
    :param normalize: whether to standardize; defaults to ``FC['normalize']``
    :return: ``(X, y, feature_names, scaler)``; ``scaler`` is ``None`` when
        standardization is off
    :raises ValueError: if the primary-period K-line data is empty
    """
    if period is None:
        period = PRIMARY_PERIOD

    # Primary period
    primary_df = load_kline(period, start_date, end_date)
    if primary_df.empty:
        raise ValueError(f"{period}分钟 K线数据为空")

    # Auxiliary periods are best-effort: a failed load is logged and skipped.
    aux_dfs = {}
    for aux_p in AUX_PERIODS:
        try:
            aux_df = load_kline(aux_p, start_date, end_date)
            if not aux_df.empty:
                aux_dfs[aux_p] = aux_df
        except Exception as e:
            logger.warning(f"加载 {aux_p}分钟 辅助数据失败: {e}")

    # Features and labels
    df = build_features(primary_df, aux_dfs, has_volume=has_volume)
    labels = generate_labels(df)
    df = df.copy()
    df['label'] = labels

    # Drop incomplete rows. ±inf is mapped to NaN first so those rows go too:
    # StandardScaler rejects infinities, and get_latest_feature_row applies
    # the same inf -> NaN rule at prediction time.
    df = df.replace([np.inf, -np.inf], np.nan).dropna()
    logger.info(f"去NaN后剩余 {len(df)} 行")

    # Columns that are never features: raw OHLC(V), timestamp and the target.
    excluded = {'open', 'high', 'low', 'close', 'timestamp', 'label', 'volume'}

    # Price-level indicators are excluded as well (they expose the absolute
    # price level, which encourages overfitting).
    price_level_patterns = (
        'sma_', 'ema_', 'bb_upper', 'bb_mid', 'bb_lower',
        'kc_upper', 'kc_mid', 'kc_lower', 'dc_upper', 'dc_mid', 'dc_lower',
        'ichimoku_conv', 'ichimoku_base', 'ichimoku_a', 'ichimoku_b',
        'kama', 'vwap', 'obv', 'adi', 'vpt', 'nvi',
        'momentum_3', 'momentum_5',
    )
    period_suffixes = ('_5m', '_60m', '_3m', '_15m', '_30m', '_1m')

    def _is_price_level(column):
        # Reduce the column to its base indicator name: strip the lag suffix,
        # then at most one period suffix.
        base_name = column.split('_lag')[0]
        for suffix in period_suffixes:
            if base_name.endswith(suffix):
                base_name = base_name[:-len(suffix)]
                break
        return any(
            base_name.startswith(p) or base_name == p.rstrip('_')
            for p in price_level_patterns
        )

    feature_cols = [
        c for c in df.columns
        if c not in excluded and not _is_price_level(c)
    ]
    logger.info(f"排除价格级别特征后剩余 {len(feature_cols)} 个特征")

    X = df[feature_cols].copy()
    y = df['label'].astype(int)

    # Optional standardization
    scaler = None
    if normalize is None:
        normalize = FC['normalize']
    if normalize:
        scaler = StandardScaler()
        X = pd.DataFrame(scaler.fit_transform(X), columns=feature_cols, index=X.index)
        logger.info("特征已标准化")

    logger.info(f"最终数据集: X={X.shape}, y={y.shape}, 特征数={len(feature_cols)}")
    return X, y, feature_cols, scaler
|
||
|
||
|
||
def get_latest_feature_row(period: int = None, feature_cols: list = None,
                           start_date: str = None, end_date: str = None,
                           has_volume: bool = False) -> pd.DataFrame:
    """
    Build features and return the last usable feature row (live / paper-trading prediction).

    ``feature_cols`` must be the list saved at training time so the columns
    match the trained model.

    :param period: primary period; defaults to ``PRIMARY_PERIOD``
    :param feature_cols: feature column names saved at training time
    :param start_date: forwarded to ``load_kline``
    :param end_date: forwarded to ``load_kline``
    :param has_volume: whether volume data is available; should match the
        value used when the dataset was prepared for training
    :return: one-row DataFrame with columns ``feature_cols``. An empty
        DataFrame is returned when ``feature_cols`` is empty, the primary data
        is empty, a requested column is missing, or no row is free of NaN/inf.
    """
    if not feature_cols:
        return pd.DataFrame()
    if period is None:
        period = PRIMARY_PERIOD

    primary_df = load_kline(period, start_date, end_date)
    if primary_df.empty:
        return pd.DataFrame()

    # Auxiliary periods stay best-effort, but a failure is logged rather than
    # swallowed, matching prepare_dataset. A silently missing aux period
    # otherwise only surfaces later as "missing feature columns".
    aux_dfs = {}
    for aux_p in AUX_PERIODS:
        try:
            aux_df = load_kline(aux_p, start_date, end_date)
            if not aux_df.empty:
                aux_dfs[aux_p] = aux_df
        except Exception as e:
            logger.warning(f"加载 {aux_p}分钟 辅助数据失败: {e}")

    df = build_features(primary_df, aux_dfs, has_volume=has_volume)

    missing = [c for c in feature_cols if c not in df.columns]
    if missing:
        logger.warning(f"get_latest_feature_row: 缺少特征列 {missing[:5]}{'...' if len(missing) > 5 else ''}")
        return pd.DataFrame()

    # Treat ±inf like NaN and keep only fully populated rows.
    X = df[feature_cols].replace([np.inf, -np.inf], np.nan).dropna()
    if X.empty:
        logger.warning("get_latest_feature_row: 可用于预测的特征为空(数据不足或存在NaN)")
        return pd.DataFrame()

    latest = X.iloc[-1:].copy()
    # The newest bar may have been dropped above; warn so callers know the
    # prediction would be based on an older bar.
    if latest.index[-1] != df.index[-1]:
        logger.warning(
            f"get_latest_feature_row: 最新K线特征含NaN/inf,返回较早的一行 "
            f"{latest.index[-1]}(最新为 {df.index[-1]})"
        )
    return latest
|