This commit is contained in:
ddrwode
2026-02-21 15:38:23 +08:00
parent 2d9914ec19
commit f0fe26acbf
5 changed files with 126 additions and 40 deletions

View File

@@ -9,8 +9,9 @@ import numpy as np
from pathlib import Path
from loguru import logger
from sklearn.utils.class_weight import compute_sample_weight
from sklearn.preprocessing import StandardScaler
from .config import MODEL_CONFIG as MC, PRIMARY_PERIOD, PROJECT_ROOT
from .config import MODEL_CONFIG as MC, FEATURE_CONFIG as FC, PRIMARY_PERIOD, PROJECT_ROOT
from .feature_engine import prepare_dataset, get_latest_feature_row
from .backtest import BacktestEngine, print_metrics
@@ -29,6 +30,8 @@ class AIStrategy:
"""
self.model_type = model_type
self.models = [] # 存储每个窗口训练的模型
self.scalers = [] # 与每个窗口模型对应的 scaler若启用标准化
self.last_scaler = None
self.feature_importance = None
def _create_model(self):
@@ -45,7 +48,8 @@ class AIStrategy:
raise ValueError(f"不支持的模型类型: {self.model_type}")
def walk_forward_train(self, X: pd.DataFrame, y: pd.Series,
confidence_threshold: float = 0.45) -> pd.Series:
confidence_threshold: float = 0.45,
normalize: bool = True) -> pd.Series:
"""
Walk-Forward 滚动训练与预测
:param confidence_threshold: 概率阈值低于此值的预测设为0观望
@@ -56,27 +60,47 @@ class AIStrategy:
step = MC['walk_forward_step']
n = len(X)
all_preds = pd.Series(dtype=float)
all_preds = pd.Series(dtype=int)
window_count = 0
logger.info(f"Walk-Forward: 数据量={n}, 训练窗口={train_size}, "
f"测试窗口={test_size}, 步长={step}, 置信阈值={confidence_threshold}")
f"测试窗口={test_size}, 步长={step}, 置信阈值={confidence_threshold}, "
f"窗口内标准化={'' if normalize else ''}")
start = 0
while start + train_size + test_size <= n:
train_end = start + train_size
test_end = min(train_end + test_size, n)
X_train = X.iloc[start:train_end]
X_train_raw = X.iloc[start:train_end]
y_train = y.iloc[start:train_end]
X_test = X.iloc[train_end:test_end]
X_test_raw = X.iloc[train_end:test_end]
y_test = y.iloc[train_end:test_end]
# 每个窗口单独拟合 scaler避免未来数据泄露
if normalize:
scaler = StandardScaler()
X_train = pd.DataFrame(
scaler.fit_transform(X_train_raw),
columns=X.columns,
index=X_train_raw.index,
)
X_test = pd.DataFrame(
scaler.transform(X_test_raw),
columns=X.columns,
index=X_test_raw.index,
)
else:
scaler = None
X_train = X_train_raw
X_test = X_test_raw
# 训练:使用类别平衡权重,避免模型偏向做多、很少出做空
model = self._create_model()
sw = compute_sample_weight('balanced', y_train)
model.fit(X_train, y_train, sample_weight=sw)
self.models.append(model)
self.scalers.append(scaler)
# 预测概率 + 置信度过滤
proba = model.predict_proba(X_test)
@@ -107,6 +131,7 @@ class AIStrategy:
self.feature_importance = pd.Series(
last_model.feature_importances_, index=X.columns
).sort_values(ascending=False)
self.last_scaler = self.scalers[-1] if self.scalers else None
logger.info(f"Walk-Forward 完成: {window_count} 个窗口, "
f"{len(all_preds)} 条预测")
@@ -141,10 +166,16 @@ class AIStrategy:
load_start = load_start_ts.strftime('%Y-%m-%d')
logger.info(f"回测区间 [{start_date} ~ {end_date}],向前加载 {warm_months} 月至 {load_start} 用于训练")
X, y, feature_names, scaler = prepare_dataset(period, load_start, load_end)
# AI 策略使用窗口内标准化prepare_dataset 此处返回未标准化特征,避免全样本泄露
X, y, feature_names, _ = prepare_dataset(period, load_start, load_end, normalize=False)
# 2. Walk-Forward 训练
predictions = self.walk_forward_train(X, y)
predictions = self.walk_forward_train(
X,
y,
confidence_threshold=MC.get('confidence_threshold', 0.45),
normalize=bool(FC.get('normalize', True)),
)
# 3. 回测仅用用户指定区间将预测对齐到该区间的每根K线
df = load_kline(period, start_date, end_date)
@@ -172,13 +203,18 @@ class AIStrategy:
logger.info("-" * 50)
# 5. 保存最后一窗模型、scaler、特征列供实盘 get_live_signal 使用)
if self.models and scaler is not None:
if self.models:
SCHEME_B_MODEL_DIR.mkdir(parents=True, exist_ok=True)
joblib.dump(self.models[-1], SCHEME_B_MODEL_FILE)
joblib.dump(scaler, SCHEME_B_SCALER_FILE)
# 若未启用标准化,保存 None实盘推理时将自动跳过 transform
joblib.dump(self.last_scaler, SCHEME_B_SCALER_FILE)
with open(SCHEME_B_FEATURES_FILE, 'w', encoding='utf-8') as f:
json.dump(feature_names, f, ensure_ascii=False)
logger.info(f"已保存方案B模型: {SCHEME_B_MODEL_FILE}, scaler, {len(feature_names)} 个特征")
logger.info(
f"已保存方案B模型: {SCHEME_B_MODEL_FILE}, "
f"scaler={'启用' if self.last_scaler is not None else ''}, "
f"{len(feature_names)} 个特征"
)
# 6. 输出特征重要性
top_feat = self.get_top_features(15)
@@ -210,18 +246,23 @@ def get_live_signal(period: int = None, model_type: str = 'lightgbm',
"""
if period is None:
period = PRIMARY_PERIOD
if not SCHEME_B_MODEL_FILE.exists() or not SCHEME_B_SCALER_FILE.exists() or not SCHEME_B_FEATURES_FILE.exists():
if not SCHEME_B_MODEL_FILE.exists() or not SCHEME_B_FEATURES_FILE.exists():
logger.warning("方案B模型未找到请先运行 AI 策略训练保存模型")
return 0
model = joblib.load(SCHEME_B_MODEL_FILE)
scaler = joblib.load(SCHEME_B_SCALER_FILE)
scaler = None
if SCHEME_B_SCALER_FILE.exists():
scaler = joblib.load(SCHEME_B_SCALER_FILE)
with open(SCHEME_B_FEATURES_FILE, 'r', encoding='utf-8') as f:
feature_cols = json.load(f)
X_last = get_latest_feature_row(period, feature_cols, start_date, end_date)
if X_last.empty:
return 0
X_scaled = scaler.transform(X_last)
X_scaled_df = pd.DataFrame(X_scaled, columns=feature_cols, index=X_last.index)
if scaler is not None:
X_scaled = scaler.transform(X_last)
X_scaled_df = pd.DataFrame(X_scaled, columns=feature_cols, index=X_last.index)
else:
X_scaled_df = X_last.copy()
proba = model.predict_proba(X_scaled_df)[0] # (p0, p1, p2)
pred = model.predict(X_scaled_df)[0]
confidence_threshold = MC.get('confidence_threshold', 0.45)

View File

@@ -19,13 +19,14 @@ def _add_intra15m_1m_pullback_features(df: pd.DataFrame, primary_df: pd.DataFram
"""
if one_min_df.empty or 'high' not in one_min_df.columns or 'low' not in one_min_df.columns:
return df
# 每根 1m K 线归属到其所在 15m 的结束时间
# 每根 1m K 线归属到其所在 15m 的起始时间(数据库 15m id 与 bar 起始时刻对齐)
one_min = one_min_df[['high', 'low']].copy()
one_min['15m_end'] = one_min.index.ceil('15min')
agg = one_min.groupby('15m_end').agg(intra_high_1m=('high', 'max'), intra_low_1m=('low', 'min'))
one_min['15m_start'] = one_min.index.floor('15min')
agg = one_min.groupby('15m_start').agg(intra_high_1m=('high', 'max'), intra_low_1m=('low', 'min'))
# 对齐到主周期索引
agg = agg.reindex(primary_df.index)
agg = agg.ffill().bfill() # 边界缺失时前后填充
# 仅前向填充,避免使用未来数据回填到过去
agg = agg.ffill()
h15 = primary_df['high'].values
l15 = primary_df['low'].values
c15 = primary_df['close'].values
@@ -126,10 +127,10 @@ def generate_labels(df: pd.DataFrame, forward_periods: int = None,
def prepare_dataset(period: int = None, start_date: str = None, end_date: str = None,
has_volume: bool = False) -> tuple:
has_volume: bool = False, normalize: bool = None) -> tuple:
"""
一键准备训练数据集
:return: (X, y, feature_names) — 已去NaN、已标准化
:return: (X, y, feature_names, scaler) — 已去NaN;是否标准化由 normalize 决定
"""
if period is None:
period = PRIMARY_PERIOD
@@ -193,9 +194,11 @@ def prepare_dataset(period: int = None, start_date: str = None, end_date: str =
X = df[feature_cols].copy()
y = df['label'].astype(int)
# 标准化
# 标准化(可选)
scaler = None
if FC['normalize']:
if normalize is None:
normalize = FC['normalize']
if normalize:
scaler = StandardScaler()
X = pd.DataFrame(scaler.fit_transform(X), columns=feature_cols, index=X.index)
logger.info("特征已标准化")
@@ -228,12 +231,14 @@ def get_latest_feature_row(period: int = None, feature_cols: list = None,
except Exception:
pass
df = build_features(primary_df, aux_dfs, has_volume=False)
labels = generate_labels(df)
df = df.copy()
df['label'] = labels
df.dropna(inplace=True)
missing = [c for c in feature_cols if c not in df.columns]
if missing:
logger.warning(f"get_latest_feature_row: 缺少特征列 {missing[:5]}{'...' if len(missing) > 5 else ''}")
return pd.DataFrame()
return df[feature_cols].iloc[-1:].copy()
X = df[feature_cols].copy()
X.replace([np.inf, -np.inf], np.nan, inplace=True)
X.dropna(inplace=True)
if X.empty:
logger.warning("get_latest_feature_row: 可用于预测的特征为空数据不足或存在NaN")
return pd.DataFrame()
return X.iloc[-1:].copy()