haha
This commit is contained in:
@@ -11,6 +11,37 @@ from .data_loader import load_kline
|
||||
from .indicators import compute_all_indicators, get_indicator_names
|
||||
|
||||
|
||||
def _add_intra15m_1m_pullback_features(df: pd.DataFrame, primary_df: pd.DataFrame,
|
||||
one_min_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""
|
||||
用 1 分钟数据在每根 15m K 线内的最高/最低,计算「冲高回落」「探底回升」程度,
|
||||
便于模型识别 15 分钟涨到一半又回调的情况。
|
||||
"""
|
||||
if one_min_df.empty or 'high' not in one_min_df.columns or 'low' not in one_min_df.columns:
|
||||
return df
|
||||
# 每根 1m K 线归属到其所在 15m 的结束时间
|
||||
one_min = one_min_df[['high', 'low']].copy()
|
||||
one_min['15m_end'] = one_min.index.ceil('15min')
|
||||
agg = one_min.groupby('15m_end').agg(intra_high_1m=('high', 'max'), intra_low_1m=('low', 'min'))
|
||||
# 对齐到主周期索引
|
||||
agg = agg.reindex(primary_df.index)
|
||||
agg = agg.ffill().bfill() # 边界缺失时前后填充
|
||||
h15 = primary_df['high'].values
|
||||
l15 = primary_df['low'].values
|
||||
c15 = primary_df['close'].values
|
||||
range_15 = h15 - l15
|
||||
range_15 = np.where(range_15 <= 0, np.nan, range_15)
|
||||
pullback = (agg['intra_high_1m'].values - c15) / range_15 # 收盘相对 15m 内 1m 最高回落比例
|
||||
recovery = (c15 - agg['intra_low_1m'].values) / range_15 # 收盘相对 15m 内 1m 最低回升比例
|
||||
pullback = np.clip(np.nan_to_num(pullback, nan=0), 0, 1)
|
||||
recovery = np.clip(np.nan_to_num(recovery, nan=0), 0, 1)
|
||||
df = df.copy()
|
||||
df['pullback_ratio_1m'] = pullback
|
||||
df['recovery_ratio_1m'] = recovery
|
||||
logger.info("已加入 1m 周期内冲高回落特征: pullback_ratio_1m, recovery_ratio_1m")
|
||||
return df
|
||||
|
||||
|
||||
def build_features(primary_df: pd.DataFrame,
|
||||
aux_dfs: dict = None,
|
||||
has_volume: bool = False) -> pd.DataFrame:
|
||||
@@ -51,6 +82,10 @@ def build_features(primary_df: pd.DataFrame,
|
||||
if aux_frames:
|
||||
df = pd.concat([df] + aux_frames, axis=1)
|
||||
|
||||
# 3.5 1 分钟周期内「冲高回落」特征(15m 内 1m 最高/最低 vs 15m 收盘,用于判断涨到一半又回调)
|
||||
if aux_dfs and 1 in aux_dfs:
|
||||
df = _add_intra15m_1m_pullback_features(df, primary_df, aux_dfs[1])
|
||||
|
||||
# 4. 去除全NaN列
|
||||
before_cols = len(df.columns)
|
||||
df.dropna(axis=1, how='all', inplace=True)
|
||||
|
||||
Reference in New Issue
Block a user