diff --git a/models/scheme_b_features.json b/models/scheme_b_features.json index 3806700..288d29e 100644 --- a/models/scheme_b_features.json +++ b/models/scheme_b_features.json @@ -1 +1 @@ -["macd", "macd_signal", "macd_hist", "adx", "di_plus", "di_minus", "trix", "aroon_up", "aroon_down", "cci", "dpo", "kst", "vortex_pos", "vortex_neg", "rsi", "stoch_k", "stoch_d", "williams_r", "roc", "tsi", "uo", "ao", "ppo", "stoch_rsi_k", "stoch_rsi_d", "bb_width", "bb_pband", "atr", "price_change_pct", "high_low_range", "body_ratio", "upper_shadow", "lower_shadow", "volatility_ratio", "close_sma20_ratio", "close_ema12_ratio", "macd_lag1", "macd_signal_lag1", "macd_hist_lag1", "adx_lag1", "di_plus_lag1", "di_minus_lag1", "trix_lag1", "aroon_up_lag1", "aroon_down_lag1", "cci_lag1", "dpo_lag1", "kst_lag1", "vortex_pos_lag1", "vortex_neg_lag1", "rsi_lag1", "stoch_k_lag1", "stoch_d_lag1", "williams_r_lag1", "roc_lag1", "tsi_lag1", "uo_lag1", "ao_lag1", "ppo_lag1", "stoch_rsi_k_lag1", "stoch_rsi_d_lag1", "bb_width_lag1", "bb_pband_lag1", "atr_lag1", "price_change_pct_lag1", "high_low_range_lag1", "body_ratio_lag1", "upper_shadow_lag1", "lower_shadow_lag1", "volatility_ratio_lag1", "close_sma20_ratio_lag1", "close_ema12_ratio_lag1", "macd_lag3", "macd_signal_lag3", "macd_hist_lag3", "adx_lag3", "di_plus_lag3", "di_minus_lag3", "trix_lag3", "aroon_up_lag3", "aroon_down_lag3", "cci_lag3", "dpo_lag3", "kst_lag3", "vortex_pos_lag3", "vortex_neg_lag3", "rsi_lag3", "stoch_k_lag3", "stoch_d_lag3", "williams_r_lag3", "roc_lag3", "tsi_lag3", "uo_lag3", "ao_lag3", "ppo_lag3", "stoch_rsi_k_lag3", "stoch_rsi_d_lag3", "bb_width_lag3", "bb_pband_lag3", "atr_lag3", "price_change_pct_lag3", "high_low_range_lag3", "body_ratio_lag3", "upper_shadow_lag3", "lower_shadow_lag3", "volatility_ratio_lag3", "close_sma20_ratio_lag3", "close_ema12_ratio_lag3", "macd_lag5", "macd_signal_lag5", "macd_hist_lag5", "adx_lag5", "di_plus_lag5", "di_minus_lag5", "trix_lag5", "aroon_up_lag5", "aroon_down_lag5", "cci_lag5", "dpo_lag5", "kst_lag5", "vortex_pos_lag5", "vortex_neg_lag5", "rsi_lag5", "stoch_k_lag5", "stoch_d_lag5", "williams_r_lag5", "roc_lag5", "tsi_lag5", "uo_lag5", "ao_lag5", "ppo_lag5", "stoch_rsi_k_lag5", "stoch_rsi_d_lag5", "bb_width_lag5", "bb_pband_lag5", "atr_lag5", "price_change_pct_lag5", "high_low_range_lag5", "body_ratio_lag5", "upper_shadow_lag5", "lower_shadow_lag5", "volatility_ratio_lag5", "close_sma20_ratio_lag5", "close_ema12_ratio_lag5", "rsi_5m", "macd_5m", "adx_5m", "bb_pband_5m", "atr_5m", "cci_5m", "rsi_60m", "macd_60m", "adx_60m", "bb_pband_60m", "atr_60m", "cci_60m"] \ No newline at end of file +["macd", "macd_signal", "macd_hist", "adx", "di_plus", "di_minus", "trix", "aroon_up", "aroon_down", "cci", "dpo", "kst", "vortex_pos", "vortex_neg", "rsi", "stoch_k", "stoch_d", "williams_r", "roc", "tsi", "uo", "ao", "ppo", "stoch_rsi_k", "stoch_rsi_d", "bb_width", "bb_pband", "atr", "price_change_pct", "high_low_range", "body_ratio", "upper_shadow", "lower_shadow", "volatility_ratio", "close_sma20_ratio", "close_ema12_ratio", "macd_lag1", "macd_signal_lag1", "macd_hist_lag1", "adx_lag1", "di_plus_lag1", "di_minus_lag1", "trix_lag1", "aroon_up_lag1", "aroon_down_lag1", "cci_lag1", "dpo_lag1", "kst_lag1", "vortex_pos_lag1", "vortex_neg_lag1", "rsi_lag1", "stoch_k_lag1", "stoch_d_lag1", "williams_r_lag1", "roc_lag1", "tsi_lag1", "uo_lag1", "ao_lag1", "ppo_lag1", "stoch_rsi_k_lag1", "stoch_rsi_d_lag1", "bb_width_lag1", "bb_pband_lag1", "atr_lag1", "price_change_pct_lag1", "high_low_range_lag1", "body_ratio_lag1", "upper_shadow_lag1", "lower_shadow_lag1", "volatility_ratio_lag1", "close_sma20_ratio_lag1", "close_ema12_ratio_lag1", "macd_lag3", "macd_signal_lag3", "macd_hist_lag3", "adx_lag3", "di_plus_lag3", "di_minus_lag3", "trix_lag3", "aroon_up_lag3", "aroon_down_lag3", "cci_lag3", "dpo_lag3", "kst_lag3", "vortex_pos_lag3", "vortex_neg_lag3", "rsi_lag3", "stoch_k_lag3", "stoch_d_lag3", "williams_r_lag3", "roc_lag3", "tsi_lag3", "uo_lag3", "ao_lag3", "ppo_lag3", "stoch_rsi_k_lag3", "stoch_rsi_d_lag3", "bb_width_lag3", "bb_pband_lag3", "atr_lag3", "price_change_pct_lag3", "high_low_range_lag3", "body_ratio_lag3", "upper_shadow_lag3", "lower_shadow_lag3", "volatility_ratio_lag3", "close_sma20_ratio_lag3", "close_ema12_ratio_lag3", "macd_lag5", "macd_signal_lag5", "macd_hist_lag5", "adx_lag5", "di_plus_lag5", "di_minus_lag5", "trix_lag5", "aroon_up_lag5", "aroon_down_lag5", "cci_lag5", "dpo_lag5", "kst_lag5", "vortex_pos_lag5", "vortex_neg_lag5", "rsi_lag5", "stoch_k_lag5", "stoch_d_lag5", "williams_r_lag5", "roc_lag5", "tsi_lag5", "uo_lag5", "ao_lag5", "ppo_lag5", "stoch_rsi_k_lag5", "stoch_rsi_d_lag5", "bb_width_lag5", "bb_pband_lag5", "atr_lag5", "price_change_pct_lag5", "high_low_range_lag5", "body_ratio_lag5", "upper_shadow_lag5", "lower_shadow_lag5", "volatility_ratio_lag5", "close_sma20_ratio_lag5", "close_ema12_ratio_lag5", "rsi_1m", "macd_1m", "adx_1m", "bb_pband_1m", "atr_1m", "cci_1m", "rsi_5m", "macd_5m", "adx_5m", "bb_pband_5m", "atr_5m", "cci_5m", "rsi_60m", "macd_60m", "adx_60m", "bb_pband_60m", "atr_60m", "cci_60m", "pullback_ratio_1m", "recovery_ratio_1m"] \ No newline at end of file diff --git a/models/scheme_b_last_model.joblib b/models/scheme_b_last_model.joblib index a92beb2..f1ff2e6 100644 Binary files a/models/scheme_b_last_model.joblib and b/models/scheme_b_last_model.joblib differ diff --git a/models/scheme_b_scaler.joblib b/models/scheme_b_scaler.joblib index 465dcce..3efa0f2 100644 Binary files a/models/scheme_b_scaler.joblib and b/models/scheme_b_scaler.joblib differ diff --git a/run_scheme_b_train.py b/run_scheme_b_train.py index e42e327..1a3e52b 100644 --- a/run_scheme_b_train.py +++ b/run_scheme_b_train.py @@ -1,14 +1,26 @@ #!/usr/bin/env python """ -仅运行方案B(LightGBM)训练并保存模型,供实盘脚本 get_live_signal 使用。 +仅运行方案B(LightGBM/XGBoost)训练并保存模型,供实盘脚本 get_live_signal 使用。 需先保证 models/database.db 中有 15m/5m/1h K 线(例如运行 抓取多周期K线.py)。 +macOS Apple Silicon 若 LightGBM 报 libomp 错误,可用 --model xgboost 或自动会 fallback 到 XGBoost。 """ import argparse import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).resolve().parent)) -from strategy.ai_strategy import run_ai_strategy + + +def _lightgbm_available(): + try: + import lightgbm as lgb # noqa: F401 + return True + except OSError as e: + if 'libomp' in str(e).lower() or 'libomp.dylib' in str(e): + return False + raise + except Exception: + return False if __name__ == '__main__': @@ -16,6 +28,17 @@ if __name__ == '__main__': p.add_argument('--period', type=int, default=15, help='主周期分钟,默认 15') p.add_argument('--start', type=str, default=None, help='回测开始日期,如 2024-01-01') p.add_argument('--end', type=str, default=None, help='回测结束日期,如 2024-12-31') + p.add_argument('--model', type=str, default=None, choices=['lightgbm', 'xgboost'], + help='模型: lightgbm 或 xgboost;不指定时优先 lightgbm,失败则用 xgboost') args = p.parse_args() - run_ai_strategy(model_type='lightgbm', period=args.period, + + model_type = args.model + if model_type is None: + model_type = 'lightgbm' + if not _lightgbm_available(): + print('LightGBM 不可用(常见于 macOS 缺 libomp),改用 XGBoost。若需 LightGBM 请安装: brew install libomp') + model_type = 'xgboost' + + from strategy.ai_strategy import run_ai_strategy + run_ai_strategy(model_type=model_type, period=args.period, start_date=args.start, end_date=args.end) diff --git a/strategy/ai_strategy.py b/strategy/ai_strategy.py index be4f0df..6e0d1b2 100644 --- a/strategy/ai_strategy.py +++ b/strategy/ai_strategy.py @@ -8,6 +8,7 @@ import pandas as pd import numpy as np from pathlib import Path from loguru import logger +from sklearn.utils.class_weight import compute_sample_weight from .config import MODEL_CONFIG as MC, PRIMARY_PERIOD, PROJECT_ROOT from .feature_engine import prepare_dataset, get_latest_feature_row @@ -71,9 +72,10 @@ class AIStrategy: X_test = X.iloc[train_end:test_end] y_test = y.iloc[train_end:test_end] - # 训练 + # 训练:使用类别平衡权重,避免模型偏向做多、很少出做空 model = self._create_model() - model.fit(X_train, y_train) + sw = compute_sample_weight('balanced', y_train) + model.fit(X_train, y_train, sample_weight=sw) self.models.append(model) # 预测概率 + 置信度过滤 @@ -160,6 +162,15 @@ class AIStrategy: print_metrics(result['metrics'], f"方案B: {self.model_type} AI策略") + # 4.1 打印每月收益(USDT) + if result.get('monthly_pnl') is not None and not result['monthly_pnl'].empty: + mp = result['monthly_pnl'].astype(float).round(2) + logger.info("\n" + "-" * 50) + logger.info(" 方案B 每月收益 (USDT)") + logger.info("-" * 50) + logger.info(f"\n{mp.to_string()}") + logger.info("-" * 50) + # 5. 保存最后一窗模型、scaler、特征列(供实盘 get_live_signal 使用) if self.models and scaler is not None: SCHEME_B_MODEL_DIR.mkdir(parents=True, exist_ok=True) @@ -214,8 +225,16 @@ def get_live_signal(period: int = None, model_type: str = 'lightgbm', proba = model.predict_proba(X_scaled_df)[0] # (p0, p1, p2) pred = model.predict(X_scaled_df)[0] confidence_threshold = MC.get('confidence_threshold', 0.45) - logger.info(f"方案B 预测概率: 观望={proba[0]:.2f} 做多={proba[1]:.2f} 做空={proba[2]:.2f} -> {int(pred)}") - if proba.max() < confidence_threshold: - logger.info(f"置信度 {proba.max():.2f} < {confidence_threshold},视为观望") + # 做空可单独使用更低阈值,避免模型偏向做多导致从不开空 + threshold_short = MC.get('confidence_threshold_short') + pred_int = int(pred) + use_threshold = (threshold_short if threshold_short is not None and pred_int == 2 else confidence_threshold) + proba_pred = proba[pred_int] + logger.info(f"方案B 预测概率: 观望={proba[0]:.2f} 做多={proba[1]:.2f} 做空={proba[2]:.2f} -> {pred_int}") + if proba_pred < use_threshold: + if pred_int == 2: + logger.info(f"做空概率 {proba[2]:.2f} < {use_threshold},视为观望(可在 config 中设置 confidence_threshold_short=0.40 或重训)") + else: + logger.info(f"预测类别置信度 {proba_pred:.2f} < {use_threshold},视为观望") return 0 - return int(pred) + return pred_int diff --git a/strategy/config.py b/strategy/config.py index bcb2fb1..f9a3833 100644 --- a/strategy/config.py +++ b/strategy/config.py @@ -134,6 +134,10 @@ MODEL_CONFIG = { 'num_class': 3, 'verbosity': 0, }, + + # 实盘信号置信度阈值(get_live_signal) + 'confidence_threshold': 0.45, # 观望/做多/做空 通用下限,低于则输出观望 + 'confidence_threshold_short': None, # 若设置(如 0.40),做空时用此阈值,便于多出做空信号 } # ============ 统计筛选参数 ============ diff --git a/strategy/feature_engine.py b/strategy/feature_engine.py index 433c475..3ca18fa 100644 --- a/strategy/feature_engine.py +++ b/strategy/feature_engine.py @@ -11,6 +11,37 @@ from .data_loader import load_kline from .indicators import compute_all_indicators, get_indicator_names +def _add_intra15m_1m_pullback_features(df: pd.DataFrame, primary_df: pd.DataFrame, + one_min_df: pd.DataFrame) -> pd.DataFrame: + """ + 用 1 分钟数据在每根 15m K 线内的最高/最低,计算「冲高回落」「探底回升」程度, + 便于模型识别 15 分钟涨到一半又回调的情况。 + """ + if one_min_df.empty or 'high' not in one_min_df.columns or 'low' not in one_min_df.columns: + return df + # 每根 1m K 线归属到其所在 15m 的结束时间 + one_min = one_min_df[['high', 'low']].copy() + one_min['15m_end'] = one_min.index.ceil('15min') + agg = one_min.groupby('15m_end').agg(intra_high_1m=('high', 'max'), intra_low_1m=('low', 'min')) + # 对齐到主周期索引 + agg = agg.reindex(primary_df.index) + agg = agg.ffill().bfill() # 边界缺失时前后填充 + h15 = primary_df['high'].values + l15 = primary_df['low'].values + c15 = primary_df['close'].values + range_15 = h15 - l15 + range_15 = np.where(range_15 <= 0, np.nan, range_15) + pullback = (agg['intra_high_1m'].values - c15) / range_15 # 收盘相对 15m 内 1m 最高回落比例 + recovery = (c15 - agg['intra_low_1m'].values) / range_15 # 收盘相对 15m 内 1m 最低回升比例 + pullback = np.clip(np.nan_to_num(pullback, nan=0), 0, 1) + recovery = np.clip(np.nan_to_num(recovery, nan=0), 0, 1) + df = df.copy() + df['pullback_ratio_1m'] = pullback + df['recovery_ratio_1m'] = recovery + logger.info("已加入 1m 周期内冲高回落特征: pullback_ratio_1m, recovery_ratio_1m") + return df + + def build_features(primary_df: pd.DataFrame, aux_dfs: dict = None, has_volume: bool = False) -> pd.DataFrame: @@ -51,6 +82,10 @@ def build_features(primary_df: pd.DataFrame, if aux_frames: df = pd.concat([df] + aux_frames, axis=1) + # 3.5 1 分钟周期内「冲高回落」特征(15m 内 1m 最高/最低 vs 15m 收盘,用于判断涨到一半又回调) + if aux_dfs and 1 in aux_dfs: + df = _add_intra15m_1m_pullback_features(df, primary_df, aux_dfs[1]) + # 4. 去除全NaN列 before_cols = len(df.columns) df.dropna(axis=1, how='all', inplace=True)