This commit is contained in:
ddrwode
2026-02-21 15:20:33 +08:00
parent 1c74dadb7e
commit 2d9914ec19
7 changed files with 91 additions and 10 deletions

View File

@@ -1 +1 @@
["macd", "macd_signal", "macd_hist", "adx", "di_plus", "di_minus", "trix", "aroon_up", "aroon_down", "cci", "dpo", "kst", "vortex_pos", "vortex_neg", "rsi", "stoch_k", "stoch_d", "williams_r", "roc", "tsi", "uo", "ao", "ppo", "stoch_rsi_k", "stoch_rsi_d", "bb_width", "bb_pband", "atr", "price_change_pct", "high_low_range", "body_ratio", "upper_shadow", "lower_shadow", "volatility_ratio", "close_sma20_ratio", "close_ema12_ratio", "macd_lag1", "macd_signal_lag1", "macd_hist_lag1", "adx_lag1", "di_plus_lag1", "di_minus_lag1", "trix_lag1", "aroon_up_lag1", "aroon_down_lag1", "cci_lag1", "dpo_lag1", "kst_lag1", "vortex_pos_lag1", "vortex_neg_lag1", "rsi_lag1", "stoch_k_lag1", "stoch_d_lag1", "williams_r_lag1", "roc_lag1", "tsi_lag1", "uo_lag1", "ao_lag1", "ppo_lag1", "stoch_rsi_k_lag1", "stoch_rsi_d_lag1", "bb_width_lag1", "bb_pband_lag1", "atr_lag1", "price_change_pct_lag1", "high_low_range_lag1", "body_ratio_lag1", "upper_shadow_lag1", "lower_shadow_lag1", "volatility_ratio_lag1", "close_sma20_ratio_lag1", "close_ema12_ratio_lag1", "macd_lag3", "macd_signal_lag3", "macd_hist_lag3", "adx_lag3", "di_plus_lag3", "di_minus_lag3", "trix_lag3", "aroon_up_lag3", "aroon_down_lag3", "cci_lag3", "dpo_lag3", "kst_lag3", "vortex_pos_lag3", "vortex_neg_lag3", "rsi_lag3", "stoch_k_lag3", "stoch_d_lag3", "williams_r_lag3", "roc_lag3", "tsi_lag3", "uo_lag3", "ao_lag3", "ppo_lag3", "stoch_rsi_k_lag3", "stoch_rsi_d_lag3", "bb_width_lag3", "bb_pband_lag3", "atr_lag3", "price_change_pct_lag3", "high_low_range_lag3", "body_ratio_lag3", "upper_shadow_lag3", "lower_shadow_lag3", "volatility_ratio_lag3", "close_sma20_ratio_lag3", "close_ema12_ratio_lag3", "macd_lag5", "macd_signal_lag5", "macd_hist_lag5", "adx_lag5", "di_plus_lag5", "di_minus_lag5", "trix_lag5", "aroon_up_lag5", "aroon_down_lag5", "cci_lag5", "dpo_lag5", "kst_lag5", "vortex_pos_lag5", "vortex_neg_lag5", "rsi_lag5", "stoch_k_lag5", "stoch_d_lag5", "williams_r_lag5", "roc_lag5", "tsi_lag5", "uo_lag5", "ao_lag5", "ppo_lag5", "stoch_rsi_k_lag5", "stoch_rsi_d_lag5", "bb_width_lag5", "bb_pband_lag5", "atr_lag5", "price_change_pct_lag5", "high_low_range_lag5", "body_ratio_lag5", "upper_shadow_lag5", "lower_shadow_lag5", "volatility_ratio_lag5", "close_sma20_ratio_lag5", "close_ema12_ratio_lag5", "rsi_5m", "macd_5m", "adx_5m", "bb_pband_5m", "atr_5m", "cci_5m", "rsi_60m", "macd_60m", "adx_60m", "bb_pband_60m", "atr_60m", "cci_60m"]
["macd", "macd_signal", "macd_hist", "adx", "di_plus", "di_minus", "trix", "aroon_up", "aroon_down", "cci", "dpo", "kst", "vortex_pos", "vortex_neg", "rsi", "stoch_k", "stoch_d", "williams_r", "roc", "tsi", "uo", "ao", "ppo", "stoch_rsi_k", "stoch_rsi_d", "bb_width", "bb_pband", "atr", "price_change_pct", "high_low_range", "body_ratio", "upper_shadow", "lower_shadow", "volatility_ratio", "close_sma20_ratio", "close_ema12_ratio", "macd_lag1", "macd_signal_lag1", "macd_hist_lag1", "adx_lag1", "di_plus_lag1", "di_minus_lag1", "trix_lag1", "aroon_up_lag1", "aroon_down_lag1", "cci_lag1", "dpo_lag1", "kst_lag1", "vortex_pos_lag1", "vortex_neg_lag1", "rsi_lag1", "stoch_k_lag1", "stoch_d_lag1", "williams_r_lag1", "roc_lag1", "tsi_lag1", "uo_lag1", "ao_lag1", "ppo_lag1", "stoch_rsi_k_lag1", "stoch_rsi_d_lag1", "bb_width_lag1", "bb_pband_lag1", "atr_lag1", "price_change_pct_lag1", "high_low_range_lag1", "body_ratio_lag1", "upper_shadow_lag1", "lower_shadow_lag1", "volatility_ratio_lag1", "close_sma20_ratio_lag1", "close_ema12_ratio_lag1", "macd_lag3", "macd_signal_lag3", "macd_hist_lag3", "adx_lag3", "di_plus_lag3", "di_minus_lag3", "trix_lag3", "aroon_up_lag3", "aroon_down_lag3", "cci_lag3", "dpo_lag3", "kst_lag3", "vortex_pos_lag3", "vortex_neg_lag3", "rsi_lag3", "stoch_k_lag3", "stoch_d_lag3", "williams_r_lag3", "roc_lag3", "tsi_lag3", "uo_lag3", "ao_lag3", "ppo_lag3", "stoch_rsi_k_lag3", "stoch_rsi_d_lag3", "bb_width_lag3", "bb_pband_lag3", "atr_lag3", "price_change_pct_lag3", "high_low_range_lag3", "body_ratio_lag3", "upper_shadow_lag3", "lower_shadow_lag3", "volatility_ratio_lag3", "close_sma20_ratio_lag3", "close_ema12_ratio_lag3", "macd_lag5", "macd_signal_lag5", "macd_hist_lag5", "adx_lag5", "di_plus_lag5", "di_minus_lag5", "trix_lag5", "aroon_up_lag5", "aroon_down_lag5", "cci_lag5", "dpo_lag5", "kst_lag5", "vortex_pos_lag5", "vortex_neg_lag5", "rsi_lag5", "stoch_k_lag5", "stoch_d_lag5", "williams_r_lag5", "roc_lag5", "tsi_lag5", "uo_lag5", "ao_lag5", "ppo_lag5", "stoch_rsi_k_lag5", "stoch_rsi_d_lag5", "bb_width_lag5", "bb_pband_lag5", "atr_lag5", "price_change_pct_lag5", "high_low_range_lag5", "body_ratio_lag5", "upper_shadow_lag5", "lower_shadow_lag5", "volatility_ratio_lag5", "close_sma20_ratio_lag5", "close_ema12_ratio_lag5", "rsi_1m", "macd_1m", "adx_1m", "bb_pband_1m", "atr_1m", "cci_1m", "rsi_5m", "macd_5m", "adx_5m", "bb_pband_5m", "atr_5m", "cci_5m", "rsi_60m", "macd_60m", "adx_60m", "bb_pband_60m", "atr_60m", "cci_60m", "pullback_ratio_1m", "recovery_ratio_1m"]

Binary file not shown.

Binary file not shown.

View File

@@ -1,14 +1,26 @@
#!/usr/bin/env python
"""
仅运行方案BLightGBM训练并保存模型供实盘脚本 get_live_signal 使用。
仅运行方案BLightGBM/XGBoost)训练并保存模型,供实盘脚本 get_live_signal 使用。
需先保证 models/database.db 中有 15m/5m/1h K 线(例如运行 抓取多周期K线.py
macOS Apple Silicon 若 LightGBM 报 libomp 错误,可用 --model xgboost 或自动会 fallback 到 XGBoost。
"""
import argparse
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent))
from strategy.ai_strategy import run_ai_strategy
def _lightgbm_available():
try:
import lightgbm as lgb # noqa: F401
return True
except OSError as e:
if 'libomp' in str(e).lower() or 'libomp.dylib' in str(e):
return False
raise
except Exception:
return False
if __name__ == '__main__':
@@ -16,6 +28,17 @@ if __name__ == '__main__':
p.add_argument('--period', type=int, default=15, help='主周期分钟,默认 15')
p.add_argument('--start', type=str, default=None, help='回测开始日期,如 2024-01-01')
p.add_argument('--end', type=str, default=None, help='回测结束日期,如 2024-12-31')
p.add_argument('--model', type=str, default=None, choices=['lightgbm', 'xgboost'],
help='模型: lightgbm 或 xgboost不指定时优先 lightgbm失败则用 xgboost')
args = p.parse_args()
run_ai_strategy(model_type='lightgbm', period=args.period,
model_type = args.model
if model_type is None:
model_type = 'lightgbm'
if not _lightgbm_available():
print('LightGBM 不可用(常见于 macOS 缺 libomp改用 XGBoost。若需 LightGBM 请安装: brew install libomp')
model_type = 'xgboost'
from strategy.ai_strategy import run_ai_strategy
run_ai_strategy(model_type=model_type, period=args.period,
start_date=args.start, end_date=args.end)

View File

@@ -8,6 +8,7 @@ import pandas as pd
import numpy as np
from pathlib import Path
from loguru import logger
from sklearn.utils.class_weight import compute_sample_weight
from .config import MODEL_CONFIG as MC, PRIMARY_PERIOD, PROJECT_ROOT
from .feature_engine import prepare_dataset, get_latest_feature_row
@@ -71,9 +72,10 @@ class AIStrategy:
X_test = X.iloc[train_end:test_end]
y_test = y.iloc[train_end:test_end]
# 训练
# 训练:使用类别平衡权重,避免模型偏向做多、很少出做空
model = self._create_model()
model.fit(X_train, y_train)
sw = compute_sample_weight('balanced', y_train)
model.fit(X_train, y_train, sample_weight=sw)
self.models.append(model)
# 预测概率 + 置信度过滤
@@ -160,6 +162,15 @@ class AIStrategy:
print_metrics(result['metrics'], f"方案B: {self.model_type} AI策略")
# 4.1 打印每月收益USDT
if result.get('monthly_pnl') is not None and not result['monthly_pnl'].empty:
mp = result['monthly_pnl'].astype(float).round(2)
logger.info("\n" + "-" * 50)
logger.info(" 方案B 每月收益 (USDT)")
logger.info("-" * 50)
logger.info(f"\n{mp.to_string()}")
logger.info("-" * 50)
# 5. 保存最后一窗模型、scaler、特征列供实盘 get_live_signal 使用)
if self.models and scaler is not None:
SCHEME_B_MODEL_DIR.mkdir(parents=True, exist_ok=True)
@@ -214,8 +225,16 @@ def get_live_signal(period: int = None, model_type: str = 'lightgbm',
proba = model.predict_proba(X_scaled_df)[0] # (p0, p1, p2)
pred = model.predict(X_scaled_df)[0]
confidence_threshold = MC.get('confidence_threshold', 0.45)
logger.info(f"方案B 预测概率: 观望={proba[0]:.2f} 做多={proba[1]:.2f} 做空={proba[2]:.2f} -> {int(pred)}")
if proba.max() < confidence_threshold:
logger.info(f"置信度 {proba.max():.2f} < {confidence_threshold},视为观望")
# 做空可单独使用更低阈值,避免模型偏向做多导致从不开空
threshold_short = MC.get('confidence_threshold_short')
pred_int = int(pred)
use_threshold = (threshold_short if threshold_short is not None and pred_int == 2 else confidence_threshold)
proba_pred = proba[pred_int]
logger.info(f"方案B 预测概率: 观望={proba[0]:.2f} 做多={proba[1]:.2f} 做空={proba[2]:.2f} -> {pred_int}")
if proba_pred < use_threshold:
if pred_int == 2:
logger.info(f"做空概率 {proba[2]:.2f} < {use_threshold},视为观望(可在 config 中设置 confidence_threshold_short=0.40 或重训)")
else:
logger.info(f"预测类别置信度 {proba_pred:.2f} < {use_threshold},视为观望")
return 0
return int(pred)
return pred_int

View File

@@ -134,6 +134,10 @@ MODEL_CONFIG = {
'num_class': 3,
'verbosity': 0,
},
# 实盘信号置信度阈值get_live_signal
'confidence_threshold': 0.45, # 观望/做多/做空 通用下限,低于则输出观望
'confidence_threshold_short': None, # 若设置(如 0.40),做空时用此阈值,便于多出做空信号
}
# ============ 统计筛选参数 ============

View File

@@ -11,6 +11,37 @@ from .data_loader import load_kline
from .indicators import compute_all_indicators, get_indicator_names
def _add_intra15m_1m_pullback_features(df: pd.DataFrame, primary_df: pd.DataFrame,
one_min_df: pd.DataFrame) -> pd.DataFrame:
"""
用 1 分钟数据在每根 15m K 线内的最高/最低,计算「冲高回落」「探底回升」程度,
便于模型识别 15 分钟涨到一半又回调的情况。
"""
if one_min_df.empty or 'high' not in one_min_df.columns or 'low' not in one_min_df.columns:
return df
# 每根 1m K 线归属到其所在 15m 的结束时间
one_min = one_min_df[['high', 'low']].copy()
one_min['15m_end'] = one_min.index.ceil('15min')
agg = one_min.groupby('15m_end').agg(intra_high_1m=('high', 'max'), intra_low_1m=('low', 'min'))
# 对齐到主周期索引
agg = agg.reindex(primary_df.index)
agg = agg.ffill().bfill() # 边界缺失时前后填充
h15 = primary_df['high'].values
l15 = primary_df['low'].values
c15 = primary_df['close'].values
range_15 = h15 - l15
range_15 = np.where(range_15 <= 0, np.nan, range_15)
pullback = (agg['intra_high_1m'].values - c15) / range_15 # 收盘相对 15m 内 1m 最高回落比例
recovery = (c15 - agg['intra_low_1m'].values) / range_15 # 收盘相对 15m 内 1m 最低回升比例
pullback = np.clip(np.nan_to_num(pullback, nan=0), 0, 1)
recovery = np.clip(np.nan_to_num(recovery, nan=0), 0, 1)
df = df.copy()
df['pullback_ratio_1m'] = pullback
df['recovery_ratio_1m'] = recovery
logger.info("已加入 1m 周期内冲高回落特征: pullback_ratio_1m, recovery_ratio_1m")
return df
def build_features(primary_df: pd.DataFrame,
aux_dfs: dict = None,
has_volume: bool = False) -> pd.DataFrame:
@@ -51,6 +82,10 @@ def build_features(primary_df: pd.DataFrame,
if aux_frames:
df = pd.concat([df] + aux_frames, axis=1)
# 3.5 1 分钟周期内「冲高回落」特征15m 内 1m 最高/最低 vs 15m 收盘,用于判断涨到一半又回调)
if aux_dfs and 1 in aux_dfs:
df = _add_intra15m_1m_pullback_features(df, primary_df, aux_dfs[1])
# 4. 去除全NaN列
before_cols = len(df.columns)
df.dropna(axis=1, how='all', inplace=True)