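Align 1m intrabar aggregation to the 15m bar start time, forward-fill only (no back-fill from future data), make normalization optional in prepare_dataset (now also returns the scaler), and build the latest feature row without generating labels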
This commit is contained in:
@@ -19,13 +19,14 @@ def _add_intra15m_1m_pullback_features(df: pd.DataFrame, primary_df: pd.DataFram
|
||||
"""
|
||||
if one_min_df.empty or 'high' not in one_min_df.columns or 'low' not in one_min_df.columns:
|
||||
return df
|
||||
# 每根 1m K 线归属到其所在 15m 的结束时间
|
||||
# 每根 1m K 线归属到其所在 15m 的起始时间(数据库 15m id 与 bar 起始时刻对齐)
|
||||
one_min = one_min_df[['high', 'low']].copy()
|
||||
one_min['15m_end'] = one_min.index.ceil('15min')
|
||||
agg = one_min.groupby('15m_end').agg(intra_high_1m=('high', 'max'), intra_low_1m=('low', 'min'))
|
||||
one_min['15m_start'] = one_min.index.floor('15min')
|
||||
agg = one_min.groupby('15m_start').agg(intra_high_1m=('high', 'max'), intra_low_1m=('low', 'min'))
|
||||
# 对齐到主周期索引
|
||||
agg = agg.reindex(primary_df.index)
|
||||
agg = agg.ffill().bfill() # 边界缺失时前后填充
|
||||
# 仅前向填充,避免使用未来数据回填到过去
|
||||
agg = agg.ffill()
|
||||
h15 = primary_df['high'].values
|
||||
l15 = primary_df['low'].values
|
||||
c15 = primary_df['close'].values
|
||||
@@ -126,10 +127,10 @@ def generate_labels(df: pd.DataFrame, forward_periods: int = None,
|
||||
|
||||
|
||||
def prepare_dataset(period: int = None, start_date: str = None, end_date: str = None,
|
||||
has_volume: bool = False) -> tuple:
|
||||
has_volume: bool = False, normalize: bool = None) -> tuple:
|
||||
"""
|
||||
一键准备训练数据集
|
||||
:return: (X, y, feature_names) — 已去NaN、已标准化
|
||||
:return: (X, y, feature_names, scaler) — 已去NaN;是否标准化由 normalize 决定
|
||||
"""
|
||||
if period is None:
|
||||
period = PRIMARY_PERIOD
|
||||
@@ -193,9 +194,11 @@ def prepare_dataset(period: int = None, start_date: str = None, end_date: str =
|
||||
X = df[feature_cols].copy()
|
||||
y = df['label'].astype(int)
|
||||
|
||||
# 标准化
|
||||
# 标准化(可选)
|
||||
scaler = None
|
||||
if FC['normalize']:
|
||||
if normalize is None:
|
||||
normalize = FC['normalize']
|
||||
if normalize:
|
||||
scaler = StandardScaler()
|
||||
X = pd.DataFrame(scaler.fit_transform(X), columns=feature_cols, index=X.index)
|
||||
logger.info("特征已标准化")
|
||||
@@ -228,12 +231,14 @@ def get_latest_feature_row(period: int = None, feature_cols: list = None,
|
||||
except Exception:
|
||||
pass
|
||||
df = build_features(primary_df, aux_dfs, has_volume=False)
|
||||
labels = generate_labels(df)
|
||||
df = df.copy()
|
||||
df['label'] = labels
|
||||
df.dropna(inplace=True)
|
||||
missing = [c for c in feature_cols if c not in df.columns]
|
||||
if missing:
|
||||
logger.warning(f"get_latest_feature_row: 缺少特征列 {missing[:5]}{'...' if len(missing) > 5 else ''}")
|
||||
return pd.DataFrame()
|
||||
return df[feature_cols].iloc[-1:].copy()
|
||||
X = df[feature_cols].copy()
|
||||
X.replace([np.inf, -np.inf], np.nan, inplace=True)
|
||||
X.dropna(inplace=True)
|
||||
if X.empty:
|
||||
logger.warning("get_latest_feature_row: 可用于预测的特征为空(数据不足或存在NaN)")
|
||||
return pd.DataFrame()
|
||||
return X.iloc[-1:].copy()
|
||||
|
||||
Reference in New Issue
Block a user