205 lines
7.2 KiB
Python
205 lines
7.2 KiB
Python
|
|
"""
|
|||
|
|
特征工程 — 标准化、多周期融合、滞后特征、标签生成
|
|||
|
|
"""
|
|||
|
|
import pandas as pd
|
|||
|
|
import numpy as np
|
|||
|
|
from sklearn.preprocessing import StandardScaler
|
|||
|
|
from loguru import logger
|
|||
|
|
|
|||
|
|
from .config import FEATURE_CONFIG as FC, PRIMARY_PERIOD, AUX_PERIODS
|
|||
|
|
from .data_loader import load_kline
|
|||
|
|
from .indicators import compute_all_indicators, get_indicator_names
|
|||
|
|
|
|||
|
|
|
|||
|
|
def build_features(primary_df: pd.DataFrame,
                   aux_dfs: dict = None,
                   has_volume: bool = False) -> pd.DataFrame:
    """
    Assemble the full feature matrix for the primary timeframe.

    The matrix is built in four steps: indicators on the primary K-line
    data, lagged copies of those indicators, a handful of indicators from
    each auxiliary timeframe forward-filled onto the primary index, and
    finally removal of columns that contain nothing but NaN.

    :param primary_df: primary-period K-line DataFrame
    :param aux_dfs: optional ``{period: DataFrame}`` of auxiliary-period data
    :param has_volume: forwarded to the indicator helpers
    :return: feature matrix DataFrame
    """
    # Step 1: indicators on the primary timeframe.
    logger.info("计算主周期指标...")
    features = compute_all_indicators(primary_df, has_volume=has_volume)

    # Step 2: one shifted copy of every available indicator column per
    # configured lag, suffixed with ``_lag<N>``.
    logger.info("生成滞后特征...")
    present = [name for name in get_indicator_names(has_volume)
               if name in features.columns]
    shifted = [
        features[present].shift(step).add_suffix(f'_lag{step}')
        for step in FC['lookback_lags']
    ]
    if shifted:
        features = pd.concat([features, *shifted], axis=1)

    # Step 3: selected indicators from each auxiliary timeframe, aligned to
    # the primary index by forward-filling the last known value.
    if aux_dfs:
        wanted = ('rsi', 'macd', 'adx', 'bb_pband', 'atr', 'cci')
        fused = []
        for period, frame in aux_dfs.items():
            logger.info(f"融合 {period}分钟 辅助周期特征...")
            enriched = compute_all_indicators(frame, has_volume=has_volume)
            for name in wanted:
                if name not in enriched.columns:
                    continue
                series = enriched[name].reindex(features.index, method='ffill')
                fused.append(series.rename(f'{name}_{period}m'))
        if fused:
            features = pd.concat([features, *fused], axis=1)

    # Step 4: drop columns that are entirely NaN and report how many went.
    n_before = features.shape[1]
    features.dropna(axis=1, how='all', inplace=True)
    n_dropped = n_before - features.shape[1]
    if n_dropped:
        logger.info(f"移除 {n_dropped} 个全NaN列")

    logger.info(f"特征矩阵: {features.shape[0]} 行 x {features.shape[1]} 列")
    return features
|
|||
|
|
|
|||
|
|
|
|||
|
|
def generate_labels(df: pd.DataFrame, forward_periods: int = None,
                    threshold: float = None) -> pd.Series:
    """
    Generate trading labels from the forward return of ``close``.

    The forward return is ``close[t + forward_periods] / close[t] - 1``.
    Rows whose forward return cannot be computed (the trailing
    ``forward_periods`` rows, or any row where the division yields NaN)
    receive NaN instead of a class so that callers can drop them.

    :param df: DataFrame containing a ``close`` column
    :param forward_periods: how many bars ahead to look; defaults to
        ``FC['label_forward_periods']``
    :param threshold: return threshold separating hold from long/short;
        defaults to ``FC['label_threshold']``
    :return: float Series named ``label`` with 0=hold, 1=long, 2=short,
        NaN=unknown
    :raises ValueError: if ``forward_periods`` is smaller than 1
    """
    if forward_periods is None:
        forward_periods = FC['label_forward_periods']
    if threshold is None:
        threshold = FC['label_threshold']
    if forward_periods < 1:
        # A zero/negative horizon would silently yield all-unknown or
        # backward-looking labels.
        raise ValueError(f"forward_periods must be >= 1, got {forward_periods}")

    future_return = df['close'].shift(-forward_periods) / df['close'] - 1

    # Build the labels as floats from the start so NaN can be stored without
    # relying on implicit int -> float upcasting. The short condition comes
    # first so that it wins if both conditions hold (only possible with a
    # negative threshold), matching assignment order long-then-short.
    is_long = (future_return > threshold).to_numpy()
    is_short = (future_return < -threshold).to_numpy()
    values = np.select([is_short, is_long], [2.0, 1.0], default=0.0)

    # Unknown forward return -> unknown label (not "hold").
    values[future_return.isna().to_numpy()] = np.nan
    labels = pd.Series(values, index=df.index, name='label')

    dist = labels.value_counts().to_dict()
    logger.info(f"标签分布: 观望={dist.get(0, 0)}, 做多={dist.get(1, 0)}, 做空={dist.get(2, 0)}")
    return labels
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Name prefixes of indicators that carry absolute price/volume levels.
# They are kept out of the model inputs because they leak the absolute
# price level and encourage overfitting.
_PRICE_LEVEL_PATTERNS = (
    'sma_', 'ema_', 'bb_upper', 'bb_mid', 'bb_lower',
    'kc_upper', 'kc_mid', 'kc_lower', 'dc_upper', 'dc_mid', 'dc_lower',
    'ichimoku_conv', 'ichimoku_base', 'ichimoku_a', 'ichimoku_b',
    'kama', 'vwap', 'obv', 'adi', 'vpt', 'nvi',
    'momentum_3', 'momentum_5',
)

# Period suffixes that have always been stripped before pattern matching.
_KNOWN_PERIOD_SUFFIXES = ('_5m', '_60m', '_3m', '_15m', '_30m', '_1m')


def _select_feature_columns(columns, exclude_cols):
    """Return the columns usable as model inputs, in their original order."""
    # Also strip suffixes for every configured auxiliary period, so a newly
    # configured period is handled without editing the hard-coded list.
    suffixes = _KNOWN_PERIOD_SUFFIXES + tuple(
        f'_{p}m' for p in AUX_PERIODS
        if f'_{p}m' not in _KNOWN_PERIOD_SUFFIXES
    )
    # 'sma_' must also match a column named exactly 'sma'.
    exact_names = {p.rstrip('_') for p in _PRICE_LEVEL_PATTERNS}
    excluded = set(exclude_cols)

    selected = []
    for col in columns:
        if col in excluded:
            continue
        # Reduce e.g. 'ema_20_lag3' or 'rsi_5m' to the bare indicator name.
        base = col.split('_lag')[0]
        for suffix in suffixes:
            if base.endswith(suffix):
                base = base[:-len(suffix)]
                break
        if base.startswith(_PRICE_LEVEL_PATTERNS) or base in exact_names:
            continue
        selected.append(col)
    return selected


def prepare_dataset(period: int = None, start_date: str = None, end_date: str = None,
                    has_volume: bool = False) -> tuple:
    """
    Load K-line data and turn it into a ready-to-train dataset.

    Steps: load the primary and auxiliary periods, build features, generate
    labels, drop every row containing NaN, remove raw OHLC(V) and
    price-level columns, and optionally standardise the features.

    :param period: primary period; defaults to ``PRIMARY_PERIOD``
    :param start_date: forwarded to ``load_kline``
    :param end_date: forwarded to ``load_kline``
    :param has_volume: forwarded to ``build_features``
    :return: ``(X, y, feature_cols, scaler)`` where ``scaler`` is the fitted
        ``StandardScaler`` or ``None`` when ``FC['normalize']`` is falsy
    :raises ValueError: if the primary data is empty, if no rows survive
        NaN removal, or if no feature columns remain
    """
    if period is None:
        period = PRIMARY_PERIOD

    # Primary period is mandatory.
    primary_df = load_kline(period, start_date, end_date)
    if primary_df.empty:
        raise ValueError(f"{period}分钟 K线数据为空")

    # Auxiliary periods are best-effort: a failure is logged and skipped.
    aux_dfs = {}
    for aux_p in AUX_PERIODS:
        try:
            aux_df = load_kline(aux_p, start_date, end_date)
            if not aux_df.empty:
                aux_dfs[aux_p] = aux_df
        except Exception as e:
            logger.warning(f"加载 {aux_p}分钟 辅助数据失败: {e}")

    df = build_features(primary_df, aux_dfs, has_volume=has_volume)

    labels = generate_labels(df)
    df = df.copy()
    df['label'] = labels

    # Drop warm-up rows (NaN indicators/lags) and rows without a label.
    df.dropna(inplace=True)
    logger.info(f"去NaN后剩余 {len(df)} 行")
    if df.empty:
        # Fail here with a clear message rather than inside the scaler, or
        # by silently returning an empty dataset when normalisation is off.
        raise ValueError(f"{period}分钟 数据去NaN后为空,无法构建训练集")

    exclude_cols = ['open', 'high', 'low', 'close', 'timestamp', 'label', 'volume']
    feature_cols = _select_feature_columns(df.columns, exclude_cols)
    logger.info(f"排除价格级别特征后剩余 {len(feature_cols)} 个特征")
    if not feature_cols:
        raise ValueError("排除价格级别特征后没有可用特征列")

    X = df[feature_cols].copy()
    y = df['label'].astype(int)

    # NOTE(review): the scaler is fitted on every row returned here; callers
    # that split train/test afterwards should be aware the statistics were
    # computed over the whole range.
    scaler = None
    if FC['normalize']:
        scaler = StandardScaler()
        X = pd.DataFrame(scaler.fit_transform(X), columns=feature_cols, index=X.index)
        logger.info("特征已标准化")

    logger.info(f"最终数据集: X={X.shape}, y={y.shape}, 特征数={len(feature_cols)}")
    return X, y, feature_cols, scaler
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_latest_feature_row(period: int = None, feature_cols: list = None,
                           start_date: str = None, end_date: str = None,
                           has_volume: bool = False) -> pd.DataFrame:
    """
    Build features and return the most recent usable row for prediction.

    ``feature_cols`` must be the list saved at training time so the column
    set and order match the trained model. No labels are generated here:
    labels need future bars, and requiring them would discard the newest
    rows, which are exactly the ones a live prediction needs.

    :param period: primary period; defaults to ``PRIMARY_PERIOD``
    :param feature_cols: feature column names saved at training time
    :param start_date: forwarded to ``load_kline``
    :param end_date: forwarded to ``load_kline``
    :param has_volume: forwarded to ``build_features``; should match the
        value used when the dataset was prepared for training
    :return: DataFrame with at most one row and columns ``feature_cols``
        (the latest row in which every requested feature is non-NaN); an
        empty DataFrame if ``feature_cols`` is empty, the primary data is
        empty, or a requested column is missing
    """
    if period is None:
        period = PRIMARY_PERIOD
    if not feature_cols:
        return pd.DataFrame()

    primary_df = load_kline(period, start_date, end_date)
    if primary_df.empty:
        return pd.DataFrame()

    # Auxiliary periods are best-effort, but a failure is logged instead of
    # being swallowed so that missing columns later can be diagnosed.
    aux_dfs = {}
    for aux_p in AUX_PERIODS:
        try:
            aux_df = load_kline(aux_p, start_date, end_date)
            if not aux_df.empty:
                aux_dfs[aux_p] = aux_df
        except Exception as e:
            logger.warning(f"加载 {aux_p}分钟 辅助数据失败: {e}")

    df = build_features(primary_df, aux_dfs, has_volume=has_volume)

    missing = [c for c in feature_cols if c not in df.columns]
    if missing:
        logger.warning(f"get_latest_feature_row: 缺少特征列 {missing[:5]}{'...' if len(missing) > 5 else ''}")
        return pd.DataFrame()

    # Only the model's input columns need to be complete; NaN in unrelated
    # columns must not disqualify the newest bar.
    complete = df[feature_cols].dropna()
    return complete.iloc[-1:].copy()
|