990 lines
41 KiB
Python
990 lines
41 KiB
Python
"""
|
||
ETH 合约 AI 策略训练 + 回测
|
||
|
||
规则:
|
||
- 标的: ETH 合约 (BitMart 1分钟K线)
|
||
- 同一时间仅 1 个仓位 (多或空)
|
||
- 每笔固定: 100U 保证金 × 100 倍 = 10000U 名义价值
|
||
- 手续费 90% 返佣,目标: 月均净利 >= 1000 USDT
|
||
|
||
流程:
|
||
1. 训练: 使用 2020 年全年数据训练 LightGBM(特征含多种可调参数指标)
|
||
2. 回测: 使用 2021 年全年数据做严格样本外真实回测(模型从未见过 2021 数据)
|
||
3. 参数搜索在「2021 年月均净利」上选最佳并保存
|
||
"""
|
||
|
||
import datetime
|
||
import json
|
||
import sqlite3
|
||
import time as _time
|
||
from pathlib import Path
|
||
from collections import defaultdict
|
||
import numpy as np
|
||
import pandas as pd
|
||
import lightgbm as lgb
|
||
|
||
import warnings
|
||
warnings.filterwarnings('ignore')
|
||
|
||
# ---- Position sizing: ETH perpetual, fixed 100 USDT margin at 100x leverage ----
MARGIN_PER_TRADE: float = 100.0  # USDT of margin posted per trade
LEVERAGE: int = 100
NOTIONAL_PER_TRADE: float = MARGIN_PER_TRADE * LEVERAGE  # 10000 USDT notional per trade

# ---- Fees ----
TAKER_FEE_RATE: float = 0.0006
REBATE_RATE: float = 0.90
# Taker fee on both legs (x2). The full fee is charged when a trade closes; the
# rebate share is only credited later (next day, 08:00), which is why drawdown
# is also tracked on the *available* balance.
FULL_FEE_PER_TRADE: float = NOTIONAL_PER_TRADE * TAKER_FEE_RATE * 2
REBATE_PER_TRADE: float = FULL_FEE_PER_TRADE * REBATE_RATE
NET_FEE_PER_TRADE: float = FULL_FEE_PER_TRADE - REBATE_PER_TRADE  # fee actually kept after the rebate

# ---- Capital and risk goal ----
INITIAL_BALANCE: float = 10000.0  # starting capital, USDT
MAX_DD_TARGET: float = 500.0  # goal: max drawdown of available funds <= 500 USDT

# ---- Train / out-of-sample test windows (start inclusive, end exclusive) ----
TRAIN_START = '2020-01-01'
TRAIN_END_EXCLUSIVE = '2021-01-01'  # training rows: [2020-01-01, 2021-01-01), i.e. all of 2020
TEST_START = '2021-01-01'
TEST_END_EXCLUSIVE = '2022-01-01'  # backtest rows: [2021-01-01, 2022-01-01), i.e. all of 2021
|
||
|
||
|
||
def get_db_path():
    """Return the path of the K-line SQLite database: ``models/database.db`` beside this file."""
    script_dir = Path(__file__).parent
    return script_dir.joinpath('models', 'database.db')
|
||
|
||
|
||
# Candle period -> SQLite table name. Keep in sync with the fetch script
# (抓取多周期K线.py), which stores each period in table bitmart_eth_{suffix}.
PERIOD_TABLES = {
    suffix: f'bitmart_eth_{suffix}'
    for suffix in ('1m', '3m', '5m', '15m', '30m', '1h')
}
|
||
|
||
|
||
def _normalize_period(period) -> str:
|
||
"""5 -> '5m', 15 -> '15m', '5' -> '5m', '15m' -> '15m'"""
|
||
if isinstance(period, int):
|
||
return f'{period}m' if period != 60 else '1h'
|
||
s = str(period).strip().lower()
|
||
if s in PERIOD_TABLES:
|
||
return s
|
||
if s.endswith('m'):
|
||
return s
|
||
if s == '60' or s == '1h':
|
||
return '1h'
|
||
return f'{s}m' if s.isdigit() else s
|
||
|
||
|
||
# ==================== 数据加载 ====================
|
||
def load_klines(start_date: str = '2025-01-01', end_date: str = '2026-02-01', period: str = '1m'):
    """
    Load candles of one period from the local SQLite DB for [start_date, end_date).

    Args:
        start_date: inclusive start, 'YYYY-MM-DD', interpreted as UTC midnight.
        end_date: exclusive end, 'YYYY-MM-DD', interpreted as UTC midnight.
        period: '1m'|'3m'|'5m'|'15m'|'30m'|'1h' or 1|3|5|15|30|60.

    Returns:
        DataFrame with columns ts, open, high, low, close, indexed by a naive
        datetime built from the millisecond ``id`` column (``unit='ms'``, i.e. UTC).

    Raises:
        ValueError: unsupported period, or a date not in 'YYYY-MM-DD' form.
        FileNotFoundError: database file missing, or no rows in the range.
    """
    period = _normalize_period(period)
    table = PERIOD_TABLES.get(period)
    if not table:
        raise ValueError(f"不支持的周期: {period},可选: {list(PERIOD_TABLES.keys())}")
    db = get_db_path()
    if not db.exists():
        raise FileNotFoundError(f"数据库不存在: {db},请先运行 抓取多周期K线.py 拉取数据")

    def _utc_ms(day: str) -> int:
        # A naive datetime's .timestamp() uses the machine's local timezone, which
        # shifted the query window against the UTC index built below. Pin to UTC.
        d = datetime.datetime.strptime(day, '%Y-%m-%d').replace(tzinfo=datetime.timezone.utc)
        return int(d.timestamp()) * 1000

    start_ts = _utc_ms(start_date)
    end_ts = _utc_ms(end_date)

    conn = sqlite3.connect(str(db))
    try:
        # `table` comes from the PERIOD_TABLES whitelist, so interpolating it is safe;
        # the range bounds are bound as parameters.
        df = pd.read_sql_query(
            f"SELECT id as ts, open, high, low, close FROM {table} "
            "WHERE id >= ? AND id < ? ORDER BY id",
            conn, params=(start_ts, end_ts))
    finally:
        # sqlite3's connection context manager does not close; close explicitly.
        conn.close()

    if len(df) == 0:
        raise FileNotFoundError(f"表 {table} 中无 {start_date}~{end_date} 数据,请先抓取该周期 K 线")
    df['datetime'] = pd.to_datetime(df['ts'], unit='ms')
    df.set_index('datetime', inplace=True)
    return df
|
||
|
||
|
||
# ==================== 指标参数(类布林带/均值回归,均可训练优化) ====================
|
||
def default_indicator_params():
    """
    Return a fresh dict of default indicator parameters.

    Periods are window lengths in bars; ``*_std``/``*_mult`` values are multiples of
    a rolling standard deviation or ATR. Callers may override any key (the random
    parameter search overrides most of them). Covers Bollinger, Keltner, Donchian,
    stochastic, CCI, Williams %R, RSI, ATR band, z-score, linear-regression channel,
    median band, percent rank, Chandelier, std band and Elder ray.
    """
    return dict(
        # Bollinger bands
        bb_period=20,
        bb_std=2.0,
        # Keltner channel
        keltner_period=20,
        keltner_atr_mult=2.0,
        # Donchian channel
        donchian_period=20,
        # Stochastic %K / %D
        stoch_k_period=14,
        stoch_d_period=3,
        # Oscillators
        cci_period=20,
        willr_period=14,
        rsi_period=14,
        # ATR band
        atr_band_period=20,
        atr_band_mult=2.0,
        # Mean-reversion family
        zscore_period=20,
        lr_period=20,
        lr_std_mult=2.0,
        median_band_period=20,
        pct_rank_period=20,
        chandelier_period=22,
        chandelier_mult=3.0,
        std_band_period=14,
        std_band_mult=2.0,
        elder_period=13,
    )
|
||
|
||
|
||
# ==================== 特征工程(所有带带/通道类指标参数可调) ====================
|
||
def add_features(df: pd.DataFrame, ind: dict = None) -> pd.DataFrame:
    """
    Add technical-indicator feature columns to ``df`` (in place) and return it.

    ``df`` needs open/high/low/close columns and a DatetimeIndex (the hour-of-day
    features read ``df.index.hour``). ``ind`` holds indicator parameters; when
    omitted, ``default_indicator_params()`` is used. The optional key
    ``median_band_mult`` (default 2.0) sets the median-band width.

    Every feature uses only the current and earlier bars (rolling/ewm windows,
    positive shifts, pct_change), so no future information leaks in.
    """
    if ind is None:
        ind = default_indicator_params()
    c = df['close']
    h = df['high']
    l = df['low']
    o = df['open']
    cp = c.replace(0, np.nan)  # close with zeros masked, used as a safe denominator

    # --- Bollinger bands ---
    bp = ind['bb_period']
    bstd = ind['bb_std']
    mid = c.rolling(bp).mean()
    std = c.rolling(bp).std()
    df['bb_upper'] = mid + bstd * std
    df['bb_lower'] = mid - bstd * std
    df['bb_mid'] = mid
    df['bb_pct'] = (c - df['bb_lower']) / (df['bb_upper'] - df['bb_lower']).replace(0, np.nan)
    df['bb_width'] = (df['bb_upper'] - df['bb_lower']) / mid.replace(0, np.nan)

    # --- Keltner channel ---
    kp = ind['keltner_period']
    katr = ind['keltner_atr_mult']
    # True range, shared by every ATR-based indicator below.
    tr = pd.concat([h - l, (h - c.shift(1)).abs(), (l - c.shift(1)).abs()], axis=1).max(axis=1)
    atr_k = tr.rolling(kp).mean()
    k_mid = c.ewm(span=kp, adjust=False).mean()
    df['keltner_upper'] = k_mid + katr * atr_k
    df['keltner_lower'] = k_mid - katr * atr_k
    df['keltner_mid'] = k_mid
    df['keltner_pct'] = (c - df['keltner_lower']) / (df['keltner_upper'] - df['keltner_lower']).replace(0, np.nan)
    df['keltner_width'] = (df['keltner_upper'] - df['keltner_lower']) / k_mid.replace(0, np.nan)

    # --- Donchian channel ---
    dp = ind['donchian_period']
    du = h.rolling(dp).max()
    dd = l.rolling(dp).min()
    dm = (du + dd) / 2
    df['donchian_upper'] = du
    df['donchian_lower'] = dd
    df['donchian_mid'] = dm
    df['donchian_pct'] = (c - dd) / (du - dd).replace(0, np.nan)
    df['donchian_width'] = (du - dd) / dm.replace(0, np.nan)

    # --- Stochastic %K / %D ---
    sk, sd = ind['stoch_k_period'], ind['stoch_d_period']
    low_k = l.rolling(sk).min()
    high_k = h.rolling(sk).max()
    df['stoch_k'] = (c - low_k) / (high_k - low_k).replace(0, np.nan) * 100
    df['stoch_d'] = df['stoch_k'].rolling(sd).mean()

    # --- CCI (rolling std used in place of mean absolute deviation) ---
    cci_p = ind['cci_period']
    typical = (h + l + c) / 3
    cci_m = typical.rolling(cci_p).mean()
    cci_s = typical.rolling(cci_p).std()
    df['cci'] = (typical - cci_m) / (0.015 * cci_s.replace(0, np.nan))

    # --- Williams %R ---
    wp = ind['willr_period']
    high_w = h.rolling(wp).max()
    low_w = l.rolling(wp).min()
    df['willr'] = -100 * (high_w - c) / (high_w - low_w).replace(0, np.nan)

    # --- RSI (simple-moving-average variant) ---
    rp = ind['rsi_period']
    delta = c.diff()
    gain = delta.clip(lower=0)
    loss = (-delta).clip(lower=0)
    avg_gain = gain.rolling(rp).mean()
    avg_loss = loss.rolling(rp).mean()
    rs = avg_gain / avg_loss.replace(0, np.nan)
    df['rsi'] = 100 - 100 / (1 + rs)

    # --- ATR band (SMA mid line, ATR-multiple width) ---
    abp, abm = ind['atr_band_period'], ind['atr_band_mult']
    atr_ab = tr.rolling(abp).mean()
    ab_mid = c.rolling(abp).mean()
    df['atr_band_upper'] = ab_mid + abm * atr_ab
    df['atr_band_lower'] = ab_mid - abm * atr_ab
    df['atr_band_mid'] = ab_mid
    df['atr_band_pct'] = (c - df['atr_band_lower']) / (df['atr_band_upper'] - df['atr_band_lower']).replace(0, np.nan)
    df['atr_band_width'] = (df['atr_band_upper'] - df['atr_band_lower']) / ab_mid.replace(0, np.nan)

    # --- Z-score: distance from the rolling mean in rolling-std units ---
    zp = ind['zscore_period']
    z_mid = c.rolling(zp).mean()
    z_std = c.rolling(zp).std()
    df['zscore'] = (c - z_mid) / z_std.replace(0, np.nan)
    df['zscore_abs'] = df['zscore'].abs()

    # --- Linear regression channel (vectorised OLS over the last lrp bars) ---
    # x = j counts bars *back* from the current one (j = 0 is the current bar),
    # so the fitted line y = intercept + slope * x evaluated at the current bar
    # is simply ``intercept``.
    lrp = ind['lr_period']
    lr_mult = ind['lr_std_mult']
    sum_x = lrp * (lrp - 1) / 2.0
    sum_x2 = lrp * (lrp - 1) * (2 * lrp - 1) / 6.0
    s_xy = sum(c.shift(j) * j for j in range(lrp))
    s_y = c.rolling(lrp).sum()
    denom = lrp * sum_x2 - sum_x * sum_x
    slope = (lrp * s_xy - sum_x * s_y) / (denom or 1e-10)
    intercept = s_y / lrp - slope * (lrp - 1) / 2.0
    # Fit at the current bar. Using intercept + slope * (lrp - 1) would give the fit
    # at the *oldest* bar of the window, making the channel lag by lrp - 1 bars.
    lr_mid = intercept
    # Band half-width uses the rolling std of close (not the regression residual std).
    lr_resid = c.rolling(lrp).std()
    df['lr_upper'] = lr_mid + lr_mult * lr_resid
    df['lr_lower'] = lr_mid - lr_mult * lr_resid
    df['lr_mid'] = lr_mid
    df['lr_pct'] = (c - df['lr_lower']) / (df['lr_upper'] - df['lr_lower']).replace(0, np.nan)
    df['lr_width'] = (df['lr_upper'] - df['lr_lower']) / lr_mid.replace(0, np.nan)

    # --- Median band (robust mean-reversion reference) ---
    mdp = ind['median_band_period']
    mdm = ind.get('median_band_mult', 2.0)
    med = c.rolling(mdp).median()
    mstd = c.rolling(mdp).std()
    df['median_band_mid'] = med
    df['price_vs_median'] = (c - med) / mstd.replace(0, np.nan)
    df['median_band_upper'] = med + mdm * mstd
    df['median_band_lower'] = med - mdm * mstd
    df['median_band_pct'] = (c - df['median_band_lower']) / (df['median_band_upper'] - df['median_band_lower']).replace(0, np.nan)

    # --- Percent rank approximated by position within the rolling close range (0..1) ---
    prp = ind['pct_rank_period']
    rmin = c.rolling(prp).min()
    rmax = c.rolling(prp).max()
    df['pct_rank'] = (c - rmin) / (rmax - rmin).replace(0, np.nan)

    # --- Chandelier exit lines ---
    cep, cem = ind['chandelier_period'], ind['chandelier_mult']
    atr_ce = tr.rolling(cep).mean()
    high_max = h.rolling(cep).max()
    low_min = l.rolling(cep).min()
    df['chandelier_upper'] = high_max - cem * atr_ce
    df['chandelier_lower'] = low_min + cem * atr_ce
    df['chandelier_mid'] = (df['chandelier_upper'] + df['chandelier_lower']) / 2
    df['chandelier_pct'] = (c - df['chandelier_lower']) / (df['chandelier_upper'] - df['chandelier_lower']).replace(0, np.nan)
    df['chandelier_width'] = (df['chandelier_upper'] - df['chandelier_lower']) / df['chandelier_mid'].replace(0, np.nan)

    # --- Std band (Bollinger-style, with its own period/multiplier) ---
    sbp, sbm = ind['std_band_period'], ind['std_band_mult']
    sb_mid = c.rolling(sbp).mean()
    sb_std = c.rolling(sbp).std()
    df['std_band_upper'] = sb_mid + sbm * sb_std
    df['std_band_lower'] = sb_mid - sbm * sb_std
    df['std_band_mid'] = sb_mid
    df['std_band_pct'] = (c - df['std_band_lower']) / (df['std_band_upper'] - df['std_band_lower']).replace(0, np.nan)
    df['std_band_width'] = (df['std_band_upper'] - df['std_band_lower']) / sb_mid.replace(0, np.nan)

    # --- Elder ray (high/low/close distance from an EMA, relative to close) ---
    ep = ind['elder_period']
    ema_elder = c.ewm(span=ep, adjust=False).mean()
    df['elder_bull'] = (h - ema_elder) / cp
    df['elder_bear'] = (l - ema_elder) / cp
    df['elder_dist'] = (c - ema_elder) / cp

    # --- Fixed-period EMAs (auxiliary) ---
    for p in [5, 8, 13, 21, 50, 120]:
        df[f'ema_{p}'] = c.ewm(span=p, adjust=False).mean()
    df['ema_fast_slow'] = (df['ema_8'] - df['ema_21']) / cp
    df['price_vs_ema120'] = (c - df['ema_120']) / cp
    df['ema8_slope'] = df['ema_8'].pct_change(5)

    # --- MACD, normalised by close ---
    ema12 = c.ewm(span=12, adjust=False).mean()
    ema26 = c.ewm(span=26, adjust=False).mean()
    df['macd'] = (ema12 - ema26) / cp
    df['macd_signal'] = df['macd'].ewm(span=9, adjust=False).mean()
    df['macd_hist'] = df['macd'] - df['macd_signal']

    # --- ATR as a fraction of price ---
    df['atr_14'] = tr.rolling(14).mean()
    df['atr_pct'] = df['atr_14'] / cp

    # --- Momentum and volatility ---
    for p in [1, 3, 5, 10, 20]:
        df[f'ret_{p}'] = c.pct_change(p)
    df['vol_5'] = c.pct_change().rolling(5).std()
    df['vol_20'] = c.pct_change().rolling(20).std()

    # --- Candle shape and time of day ---
    body = (c - o).abs()
    df['body_pct'] = body / cp
    df['price_position_20'] = (c - l.rolling(20).min()) / (h.rolling(20).max() - l.rolling(20).min()).replace(0, np.nan)
    df['hour_sin'] = np.sin(2 * np.pi * df.index.hour / 24)
    df['hour_cos'] = np.cos(2 * np.pi * df.index.hour / 24)

    return df
|
||
|
||
|
||
def get_feature_cols(df: pd.DataFrame):
    """
    Return the model feature columns of ``df``, in column order.

    Drops raw OHLC/timestamp columns, the label, the price-level channel lines
    (upper/lower/mid of each band), the raw ATR and the raw EMAs, keeping only
    derived columns whose dtype is float64, float32, int64 or int32.
    """
    price_level_cols = {
        'ts', 'open', 'high', 'low', 'close', 'label',
        'bb_upper', 'bb_lower', 'bb_mid', 'atr_14',
        'keltner_upper', 'keltner_lower', 'keltner_mid',
        'donchian_upper', 'donchian_lower', 'donchian_mid',
        'atr_band_upper', 'atr_band_lower', 'atr_band_mid',
        'lr_upper', 'lr_lower', 'lr_mid',
        'median_band_upper', 'median_band_lower', 'median_band_mid',
        'chandelier_upper', 'chandelier_lower', 'chandelier_mid',
        'std_band_upper', 'std_band_lower', 'std_band_mid',
        'ema_5', 'ema_8', 'ema_13', 'ema_21', 'ema_50', 'ema_120',
    }
    numeric_dtypes = ('float64', 'float32', 'int64', 'int32')
    selected = []
    for name in df.columns:
        if name in price_level_cols:
            continue
        if df[name].dtype in numeric_dtypes:
            selected.append(name)
    return selected
|
||
|
||
|
||
# ==================== 标签 ====================
|
||
def add_labels(df: pd.DataFrame, forward_bars: int = 10, threshold: float = 0.002) -> pd.DataFrame:
    """
    Return a copy of ``df`` with a ``label`` column built from the forward return.

    The return over the next ``forward_bars`` bars is computed as
    close[t + forward_bars] / close[t] - 1. Label 1.0 (long) when it is
    > threshold, -1.0 (short) when it is < -threshold, otherwise 0.0.

    Rows whose forward return is unknown (the last ``forward_bars`` rows) get NaN
    instead of a fake neutral 0, so ``dropna(subset=[..., 'label'])`` in training
    removes them. The label column is therefore float.
    """
    df = df.copy()
    close = df['close']
    future_ret = close.shift(-forward_bars) / close - 1
    df['label'] = 0.0
    df.loc[future_ret > threshold, 'label'] = 1.0
    df.loc[future_ret < -threshold, 'label'] = -1.0
    # No future price available: leave the target undefined.
    df.loc[future_ret.isna(), 'label'] = np.nan
    return df
|
||
|
||
|
||
# ==================== 滚动训练 ====================
|
||
def train_predict_walkforward(
    df: pd.DataFrame,
    feature_cols: list,
    train_months: int = 3,
    lgb_rounds: int = 250,
):
    """
    Walk-forward prediction: for each calendar month, train LightGBM on the
    preceding ``train_months`` months and predict that month.

    ``df`` needs a DatetimeIndex, the ``feature_cols`` and an integer-like
    ``label`` column in {-1, 0, 1}. A month is skipped (its probabilities stay 0)
    when it has fewer than 100 complete test rows or its training window has
    fewer than 1000 labelled rows.

    NOTE(review): labels are computed by the caller on the full frame, so the
    last rows of each training window may use prices from the test month.

    Returns:
        (proba_long, proba_short, last_model): two Series aligned with df.index,
        and the model of the last trained month (None if none was trained).
    """
    df = df.copy()
    df['month'] = df.index.to_period('M')
    months = sorted(df['month'].unique())
    all_proba_long = pd.Series(0.0, index=df.index)
    all_proba_short = pd.Series(0.0, index=df.index)
    last_model = None

    # Loop-invariant LightGBM settings. 'subsample' (bagging_fraction) is ignored by
    # LightGBM unless 'subsample_freq' (bagging_freq) > 0, so enable it each iteration.
    params = {
        'objective': 'multiclass',
        'num_class': 3,
        'metric': 'multi_logloss',
        'learning_rate': 0.05,
        'num_leaves': 31,
        'max_depth': 6,
        'min_child_samples': 50,
        'subsample': 0.8,
        'subsample_freq': 1,
        'colsample_bytree': 0.8,
        'reg_alpha': 0.1,
        'reg_lambda': 0.1,
        'verbose': -1,
        'n_jobs': -1,
        'seed': 42,
    }

    for i in range(train_months, len(months)):
        test_month = months[i]
        train_start = months[i - train_months]
        train_mask = (df['month'] >= train_start) & (df['month'] < test_month)
        test_mask = df['month'] == test_month
        train_df = df[train_mask].dropna(subset=feature_cols + ['label'])
        test_df = df[test_mask].dropna(subset=feature_cols)
        if len(train_df) < 1000 or len(test_df) < 100:
            continue
        X_train = train_df[feature_cols].values
        y_train = (train_df['label'].values + 1).astype(int)  # -1,0,1 -> 0,1,2
        X_test = test_df[feature_cols].values

        dtrain = lgb.Dataset(X_train, label=y_train)
        model = lgb.train(params, dtrain, num_boost_round=lgb_rounds)
        last_model = model
        proba = model.predict(X_test)  # (n, 3) -> [P(short), P(neutral), P(long)]
        test_idx = test_df.index
        all_proba_short.loc[test_idx] = proba[:, 0]
        all_proba_long.loc[test_idx] = proba[:, 2]

    return all_proba_long, all_proba_short, last_model
|
||
|
||
|
||
def train_on_period_predict_on_other(
    df_train: pd.DataFrame,
    df_test: pd.DataFrame,
    feature_cols: list,
    forward_bars: int,
    label_threshold: float,
    lgb_rounds: int = 250,
):
    """
    Train one LightGBM model on ``df_train`` and predict ``df_test`` (strictly out of sample).

    ``df_train`` must already contain a ``label`` column in {-1, 0, 1}; ``df_test``
    only needs ``feature_cols``. ``forward_bars`` and ``label_threshold`` are not
    used here (labels are precomputed) and are kept for interface compatibility.

    Returns:
        (proba_long, proba_short, model): Series aligned with df_test.index, 0.0 on
        rows with missing features. (None, None, None) when there are fewer than
        2000 usable training rows or fewer than 100 usable test rows.
    """
    train_df = df_train.dropna(subset=feature_cols + ['label'])
    if len(train_df) < 2000:
        return None, None, None
    X_train = train_df[feature_cols].values
    y_train = (train_df['label'].values + 1).astype(int)  # -1,0,1 -> 0,1,2
    test_df = df_test.dropna(subset=feature_cols)
    if len(test_df) < 100:
        return None, None, None
    X_test = test_df[feature_cols].values

    params = {
        'objective': 'multiclass',
        'num_class': 3,
        'metric': 'multi_logloss',
        'learning_rate': 0.05,
        'num_leaves': 31,
        'max_depth': 6,
        'min_child_samples': 50,
        'subsample': 0.8,
        # LightGBM ignores 'subsample' (bagging_fraction) unless bagging_freq > 0.
        'subsample_freq': 1,
        'colsample_bytree': 0.8,
        'reg_alpha': 0.1,
        'reg_lambda': 0.1,
        'verbose': -1,
        'n_jobs': -1,
        'seed': 42,
    }
    dtrain = lgb.Dataset(X_train, label=y_train)
    model = lgb.train(params, dtrain, num_boost_round=lgb_rounds)
    proba = model.predict(X_test)  # (n, 3) -> [P(short), P(neutral), P(long)]

    all_proba_long = pd.Series(0.0, index=df_test.index)
    all_proba_short = pd.Series(0.0, index=df_test.index)
    all_proba_short.loc[test_df.index] = proba[:, 0]
    all_proba_long.loc[test_df.index] = proba[:, 2]
    return all_proba_long, all_proba_short, model
|
||
|
||
|
||
# ==================== 回测(单仓位、100U×100倍) ====================
|
||
def backtest(
    df: pd.DataFrame,
    proba_long: pd.Series,
    proba_short: pd.Series,
    notional: float = NOTIONAL_PER_TRADE,
    prob_threshold: float = 0.45,
    min_hold_seconds: int = 180,
    max_hold_seconds: int = 1800,
    sl_pct: float = 0.004,
    tp_pct: float = 0.006,
) -> list:
    """
    Bar-by-bar backtest holding at most one position at a time, each of fixed
    ``notional`` size (default 100U margin x 100 leverage = 10000U).

    Entries and exits are evaluated on each bar's close price only (no intrabar
    high/low). ``proba_long``/``proba_short`` are read positionally (``.iloc``), so
    they must be in the same row order as ``df``.

    Exit rules, in order, for an open position:
      * hard stop at 1.5 x sl_pct loss, checked on every bar;
      * after ``min_hold_seconds``: stop loss, take profit, timeout at
        ``max_hold_seconds``, then a model-reversal exit when the opposite
        probability exceeds prob_threshold + 0.05.
    After an SL/TP/timeout/hard-stop exit the bar is skipped; after a reversal exit
    a new position may be opened on the same bar. Any position still open at the
    end is closed at the last close.

    Returns:
        list of tuples (side, open_price, close_price, pnl_usdt, hold_sec, reason,
        open_time, close_time); side is 1 (long) or -1 (short) and pnl_usdt is
        gross of fees (notional * return).
    """
    pos = 0  # 1 = long, -1 = short, 0 = flat
    open_price = 0.0
    open_time = None
    trades = []

    # NOTE(review): per-row .iloc access is slow on large frames; hoisting the
    # columns to numpy arrays would be a cheap speed-up.
    for i in range(len(df)):
        dt = df.index[i]
        price = df['close'].iloc[i]
        pl = proba_long.iloc[i]
        ps = proba_short.iloc[i]

        if pos != 0 and open_time is not None:
            # Return in the position's direction, measured on this bar's close.
            pnl_pct = (price - open_price) / open_price if pos == 1 else (open_price - price) / open_price
            hold_sec = (dt - open_time).total_seconds()

            # Hard stop: 1.5x the normal stop distance, applied even before min hold.
            if -pnl_pct >= sl_pct * 1.5:
                pnl_usdt = notional * pnl_pct
                trades.append((pos, open_price, price, pnl_usdt, hold_sec, '硬止损', open_time, dt))
                pos = 0
                continue
            if hold_sec >= min_hold_seconds:
                if -pnl_pct >= sl_pct:
                    pnl_usdt = notional * pnl_pct
                    trades.append((pos, open_price, price, pnl_usdt, hold_sec, '止损', open_time, dt))
                    pos = 0
                    continue
                if pnl_pct >= tp_pct:
                    pnl_usdt = notional * pnl_pct
                    trades.append((pos, open_price, price, pnl_usdt, hold_sec, '止盈', open_time, dt))
                    pos = 0
                    continue
                if hold_sec >= max_hold_seconds:
                    pnl_usdt = notional * pnl_pct
                    trades.append((pos, open_price, price, pnl_usdt, hold_sec, '超时', open_time, dt))
                    pos = 0
                    continue
                # Model-reversal exit: the opposite side's probability clears the entry
                # threshold by a 0.05 margin. No `continue`, so the entry block below
                # may open the opposite position on this same bar.
                # NOTE(review): these reversal exits sit under the min-hold gate —
                # confirm that is intended.
                if pos == 1 and ps > prob_threshold + 0.05:
                    pnl_usdt = notional * pnl_pct
                    trades.append((pos, open_price, price, pnl_usdt, hold_sec, 'AI反转', open_time, dt))
                    pos = 0
                elif pos == -1 and pl > prob_threshold + 0.05:
                    pnl_usdt = notional * pnl_pct
                    trades.append((pos, open_price, price, pnl_usdt, hold_sec, 'AI反转', open_time, dt))
                    pos = 0

        # Flat: open on the side whose probability clears the threshold and beats the other.
        if pos == 0:
            if pl > prob_threshold and pl > ps:
                pos = 1
                open_price = price
                open_time = dt
            elif ps > prob_threshold and ps > pl:
                pos = -1
                open_price = price
                open_time = dt

    # Close any position still open at the final bar.
    if pos != 0:
        price = df['close'].iloc[-1]
        dt = df.index[-1]
        pnl_pct = (price - open_price) / open_price if pos == 1 else (open_price - price) / open_price
        hold_sec = (dt - open_time).total_seconds()
        trades.append((pos, open_price, price, notional * pnl_pct, hold_sec, '结束', open_time, dt))

    return trades
|
||
|
||
|
||
# ==================== 结果分析 ====================
|
||
def analyze_trades(trades: list, notional: float = NOTIONAL_PER_TRADE, initial_balance: float = INITIAL_BALANCE) -> dict:
    """
    Summarise closed trades: net profit, win rate, drawdowns and monthly net.

    Each trade is a tuple (side, open_price, close_price, pnl_usdt, hold_sec,
    reason, open_time, close_time) with pnl_usdt gross of fees.

    Two drawdowns are reported:
      * ``dd``: on cumulative net P&L, charging NET_FEE_PER_TRADE at each close
        (in list order, ignoring rebate timing).
      * ``dd_available``: on the available balance starting at ``initial_balance``;
        the full fee is charged at close and the rebate is credited one day later
        (simplifying the next-day 08:00 payout), so this figure is more conservative.

    Args:
        trades: closed trades.
        notional: unused; kept for interface compatibility.
        initial_balance: starting balance for the available-funds curve.

    Returns:
        dict with keys n, net, wr (win rate in %), dd, dd_available, total_pnl,
        avg_pnl, monthly_net (average over months with at least one close), months
        and monthly_detail ({'YYYY-MM': net}). The empty-input result has the same
        keys, so callers can index any of them unconditionally.
    """
    if not trades:
        return {
            'n': 0, 'net': 0.0, 'wr': 0.0, 'dd': 0.0, 'dd_available': 0.0,
            'total_pnl': 0.0, 'avg_pnl': 0.0, 'monthly_net': 0.0, 'months': 0,
            'monthly_detail': {},
        }

    n = len(trades)
    total_pnl = sum(t[3] for t in trades)
    net = total_pnl - NET_FEE_PER_TRADE * n
    wins = sum(1 for t in trades if t[3] > 0)
    wr = wins / n * 100

    # Simple drawdown of cumulative net P&L (net fee already deducted, no rebate delay).
    cum = 0.0
    peak = 0.0
    dd = 0.0
    for t in trades:
        cum += t[3] - NET_FEE_PER_TRADE
        peak = max(peak, cum)
        dd = max(dd, peak - cum)

    # Available-funds curve: full fee leaves at close, rebate arrives a day later.
    events = []  # (time, balance delta)
    for t in trades:
        close_time = t[7]
        events.append((close_time, t[3] - FULL_FEE_PER_TRADE))
        try:
            rebate_time = close_time + datetime.timedelta(days=1)
        except (TypeError, ValueError, OverflowError):
            # close_time does not support datetime arithmetic (or overflows): credit immediately.
            rebate_time = close_time
        events.append((rebate_time, REBATE_PER_TRADE))
    # Stable sort: at equal timestamps a close is applied before a rebate queued after it.
    events.sort(key=lambda e: e[0])

    balance = initial_balance
    peak_bal = balance
    dd_available = 0.0
    for _, delta in events:
        balance += delta
        peak_bal = max(peak_bal, balance)
        dd_available = max(dd_available, peak_bal - balance)

    # Net profit per calendar month of the closing time.
    monthly_net = defaultdict(float)
    for t in trades:
        close_time = t[7]
        month_key = close_time.strftime('%Y-%m') if hasattr(close_time, 'strftime') else str(close_time)[:7]
        monthly_net[month_key] += t[3] - NET_FEE_PER_TRADE
    num_months = len(monthly_net) or 1
    avg_monthly_net = sum(monthly_net.values()) / num_months

    return {
        'n': n,
        'net': net,
        'wr': wr,
        'dd': dd,
        'dd_available': dd_available,
        'total_pnl': total_pnl,
        'avg_pnl': net / n,
        'monthly_net': avg_monthly_net,
        'months': num_months,
        'monthly_detail': dict(monthly_net),
    }
|
||
|
||
|
||
TARGET_MONTHLY_NET: float = 1000.0  # goal: average net profit per month, USDT
|
||
|
||
def print_report(trades: list, label: str = ""):
    """
    Print a backtest summary for ``trades``: capital, trade count, net profit, win
    rate, available-funds max drawdown (rebate credited next day), average monthly
    net, whether each goal is met, and a count of exits per close reason.
    """
    if not trades:
        print(f" [{label}] 无交易", flush=True)
        return

    stats = analyze_trades(trades)
    reason_counts = defaultdict(int)
    for trade in trades:
        reason_counts[trade[5]] += 1

    dd_verdict = '达标' if stats['dd_available'] <= MAX_DD_TARGET else '未达标'
    monthly_verdict = '达标' if stats['monthly_net'] >= TARGET_MONTHLY_NET else '未达标'

    print(f"\n === {label} ===", flush=True)
    print(f" 本金: {INITIAL_BALANCE:.0f} USDT | 交易: {stats['n']} 笔 | 总净利: {stats['net']:+.2f} USDT | 胜率: {stats['wr']:.1f}%", flush=True)
    print(f" 可用资金最大回撤: {stats['dd_available']:.2f} USDT (返佣次日到账) | 目标 ≤{MAX_DD_TARGET:.0f}U: {dd_verdict}", flush=True)
    print(f" 月均净利: {stats['monthly_net']:+.2f} USDT (共 {stats['months']} 个月) | 目标 ≥{TARGET_MONTHLY_NET:.0f}U/月: {monthly_verdict}", flush=True)
    print(f" 平仓原因: {dict(reason_counts)}", flush=True)
|
||
|
||
|
||
# ==================== 模型保存 ====================
|
||
def save_model_and_params(model, feature_cols: list, params: dict, path: Path = None):
    """
    Persist a trained model and its metadata into one directory.

    Writes ``model.txt`` (via ``model.save_model``), ``feature_cols.json`` and
    ``strategy_params.json`` (UTF-8, non-ASCII kept, indented). ``path`` defaults to
    ``models/eth_ai_strategy`` beside this file and is created if missing.
    """
    target = Path(__file__).parent / 'models' / 'eth_ai_strategy' if path is None else Path(path)
    target.mkdir(parents=True, exist_ok=True)
    model.save_model(str(target / 'model.txt'))
    for file_name, payload in (('feature_cols.json', feature_cols), ('strategy_params.json', params)):
        with open(target / file_name, 'w', encoding='utf-8') as fh:
            json.dump(payload, fh, ensure_ascii=False, indent=2)
    print(f" 模型与参数已保存至: {target}", flush=True)
|
||
|
||
|
||
# ==================== 主流程:2020 训练 + 2021 真实回测 ====================
|
||
def run_single(
    df_train: pd.DataFrame,
    df_test: pd.DataFrame,
    ind_params: dict,
    forward_bars: int,
    label_threshold: float,
    prob_threshold: float,
    sl_pct: float,
    tp_pct: float,
    min_hold: int,
    max_hold: int,
    warmup_bars: int = 500,
):
    """
    Train on ``df_train`` (2020) and backtest on ``df_test`` (2021), strictly out of sample.

    Only labelled ``df_train`` rows are used for training. Test-set features are
    computed on the last ``warmup_bars`` training bars that precede ``df_test``
    followed by ``df_test``, and then cut back to the test rows. Every feature in
    add_features looks only at current/past bars, so no future data is used. This
    stops the test-period rolling windows from starting empty and the EMAs from being
    seeded with the first test close. ``warmup_bars=0`` restores the cold-start behaviour.

    Returns:
        (trades, model, feature_cols); ([], None, None) when training or
        prediction was skipped for lack of data.
    """
    df_tr = add_features(df_train.copy(), ind=ind_params)

    n_warm = 0
    prior = None
    if warmup_bars > 0 and len(df_test) > 0:
        prior = df_train[df_train.index < df_test.index[0]].tail(warmup_bars)
        n_warm = len(prior)
    if n_warm:
        combined = pd.concat([prior, df_test])
        df_te = add_features(combined, ind=ind_params).iloc[n_warm:].copy()
    else:
        df_te = add_features(df_test.copy(), ind=ind_params)

    feature_cols = get_feature_cols(df_tr)
    df_tr = add_labels(df_tr, forward_bars=forward_bars, threshold=label_threshold)
    proba_long, proba_short, model = train_on_period_predict_on_other(
        df_tr, df_te, feature_cols, forward_bars, label_threshold,
    )
    if model is None:
        return [], None, None
    trades = backtest(
        df_te, proba_long, proba_short,
        notional=NOTIONAL_PER_TRADE,
        prob_threshold=prob_threshold,
        min_hold_seconds=min_hold,
        max_hold_seconds=max_hold,
        sl_pct=sl_pct,
        tp_pct=tp_pct,
    )
    return trades, model, feature_cols
|
||
|
||
|
||
def _grid_indicator_and_strategy(max_configs: int = 220, seed: int = 42):
    """
    Randomly sample ``max_configs`` joint indicator + strategy configurations.

    Each indicator parameter and each strategy knob is drawn independently from a
    fixed candidate list, using a numpy Generator seeded with ``seed`` so the sample
    is reproducible. Every item is a dict with keys ``ind`` (default indicator
    params with the sampled overrides), ``forward_bars``, ``label_threshold``,
    ``sl_pct``, ``tp_pct``, ``prob_threshold``, ``min_hold`` and ``max_hold``.
    """
    rng = np.random.default_rng(seed)
    base = default_indicator_params()

    def pick(options):
        # Exactly one rng draw per call; the call order below fixes the sample.
        return options[rng.integers(0, len(options))]

    # Band / oscillator candidates: period or (period, multiplier).
    bb_choices = [(20, 2.0), (20, 2.5), (30, 2.0), (15, 2.0), (14, 2.0)]
    keltner_choices = [(20, 2.0), (20, 2.5), (14, 2.0), (22, 2.5)]
    donchian_choices = [14, 20, 30]
    stoch_choices = [(14, 3), (10, 3), (20, 5)]
    cci_choices = [14, 20]
    rsi_choices = [7, 14, 21]
    atr_band_choices = [(20, 2.0), (14, 2.0), (22, 2.5)]
    # Mean-reversion family.
    zscore_choices = [14, 20, 30]
    lr_choices = [(20, 2.0), (14, 2.0), (30, 2.5)]
    median_choices = [14, 20, 30]
    pct_rank_choices = [14, 20, 30]
    chandelier_choices = [(22, 3.0), (14, 2.5), (30, 3.5)]
    std_band_choices = [(14, 2.0), (20, 2.0), (14, 2.5)]
    elder_choices = [10, 13, 20]
    # Strategy knobs, including tighter stops / higher confidence to limit drawdown.
    label_choices = [(8, 0.0015), (10, 0.002), (10, 0.003), (15, 0.002), (20, 0.003)]
    sl_tp_choices = [
        (0.002, 0.004), (0.0025, 0.005), (0.003, 0.005), (0.003, 0.006),
        (0.004, 0.006), (0.004, 0.008), (0.005, 0.008), (0.005, 0.010),
    ]
    prob_choices = [0.42, 0.45, 0.48, 0.50]

    configs = []
    for _ in range(max_configs):
        # Keep this draw order unchanged, or the same seed yields a different sample.
        bb_period, bb_std = pick(bb_choices)
        keltner_period, keltner_mult = pick(keltner_choices)
        donchian_period = pick(donchian_choices)
        stoch_k, stoch_d = pick(stoch_choices)
        cci_period = pick(cci_choices)
        rsi_period = pick(rsi_choices)
        atr_band_period, atr_band_mult = pick(atr_band_choices)
        zscore_period = pick(zscore_choices)
        lr_period, lr_mult = pick(lr_choices)
        median_period = pick(median_choices)
        pct_rank_period = pick(pct_rank_choices)
        chandelier_period, chandelier_mult = pick(chandelier_choices)
        std_band_period, std_band_mult = pick(std_band_choices)
        elder_period = pick(elder_choices)
        forward_bars, label_threshold = pick(label_choices)
        sl_pct, tp_pct = pick(sl_tp_choices)
        prob_threshold = pick(prob_choices)

        indicator_params = {
            **base,
            'bb_period': bb_period, 'bb_std': bb_std,
            'keltner_period': keltner_period, 'keltner_atr_mult': keltner_mult,
            'donchian_period': donchian_period,
            'stoch_k_period': stoch_k, 'stoch_d_period': stoch_d,
            'cci_period': cci_period, 'rsi_period': rsi_period,
            'atr_band_period': atr_band_period, 'atr_band_mult': atr_band_mult,
            'zscore_period': zscore_period,
            'lr_period': lr_period, 'lr_std_mult': lr_mult,
            'median_band_period': median_period,
            'pct_rank_period': pct_rank_period,
            'chandelier_period': chandelier_period, 'chandelier_mult': chandelier_mult,
            'std_band_period': std_band_period, 'std_band_mult': std_band_mult,
            'elder_period': elder_period,
        }
        configs.append({
            'ind': indicator_params,
            'forward_bars': forward_bars,
            'label_threshold': label_threshold,
            'sl_pct': sl_pct,
            'tp_pct': tp_pct,
            'prob_threshold': prob_threshold,
            'min_hold': 180,
            'max_hold': 1800,
        })
    return configs
|
||
|
||
|
||
def run_cycle_compare(periods: list, do_save_best: bool = True):
    """
    Compare K-line periods using one fixed parameter set: for each period, train on
    2020 and backtest on 2021 (out of sample), then print a comparison table.

    Args:
        periods: period specs accepted by ``_normalize_period``, e.g. [5, 15] for
            5m/15m, or [1, 3, 5, 15, 30, 60] for every stored period.
        do_save_best: when True, save the selected period's model and parameters.

    The best period is picked by, in order: drawdown goal met, monthly goal met,
    higher average monthly net, lower available-funds drawdown.
    """
    t0 = _time.time()
    period_labels = [_normalize_period(p) for p in periods]
    print("=" * 60, flush=True)
    print(" ETH 合约 — 多周期回测对比 (2020 训练 / 2021 回测)", flush=True)
    print(f" 参与周期: {', '.join(period_labels)}", flush=True)
    print("=" * 60, flush=True)

    ind = default_indicator_params()
    # Single source for the strategy knobs: used for every backtest below AND for
    # the saved params, so the two cannot drift apart.
    strategy = {
        'forward_bars': 10,
        'label_threshold': 0.002,
        'prob_threshold': 0.45,
        'sl_pct': 0.004,
        'tp_pct': 0.006,
        'min_hold': 180,
        'max_hold': 1800,
    }

    results = []
    for i, period in enumerate(period_labels):
        print(f"\n [{i+1}/{len(period_labels)}] 周期 {period} ...", flush=True)
        try:
            df_2020 = load_klines(TRAIN_START, TRAIN_END_EXCLUSIVE, period=period)
            df_2021 = load_klines(TEST_START, TEST_END_EXCLUSIVE, period=period)
        except Exception as e:
            # Deliberately broad: any data problem for one period (missing DB,
            # table or rows) is reported and that period is skipped.
            print(f" 跳过 {period}: {e}", flush=True)
            results.append({'period': period, 'ok': False, 'error': str(e)})
            continue
        print(f" 数据: 2020 {len(df_2020):,} 根 | 2021 {len(df_2021):,} 根", flush=True)

        trades, model, feature_cols = run_single(df_2020, df_2021, ind, **strategy)
        if not trades or model is None:
            print(f" {period}: 无有效交易", flush=True)
            results.append({'period': period, 'ok': False})
            continue

        r = analyze_trades(trades)
        results.append({
            'period': period,
            'ok': True,
            'n': r['n'],
            'net': r['net'],
            'monthly_net': r['monthly_net'],
            'wr': r['wr'],
            'dd': r['dd'],
            'dd_available': r['dd_available'],
            'trades': trades,
            'model': model,
            'feature_cols': feature_cols,
        })
        print(f" {period}: 交易 {r['n']} 笔 | 月均 {r['monthly_net']:+.2f} USDT | 可用回撤 {r['dd_available']:.2f} USDT | 胜率 {r['wr']:.1f}%", flush=True)

    # Comparison table.
    ok_results = [x for x in results if x.get('ok')]
    print("\n" + "=" * 60, flush=True)
    print(" 多周期对比结果 (2021 年样本外)", flush=True)
    print("=" * 60, flush=True)
    if not ok_results:
        print(" 无有效回测结果。", flush=True)
        return
    print(f" {'周期':<6} {'交易数':>8} {'总净利':>12} {'月均净利':>12} {'可用回撤':>10} {'胜率':>8}", flush=True)
    print(" " + "-" * 62, flush=True)
    for x in ok_results:
        print(f" {x['period']:<6} {x['n']:>8} {x['net']:>+12.2f} {x['monthly_net']:>+12.2f} {x['dd_available']:>10.2f} {x['wr']:>7.1f}%", flush=True)

    # Preference: drawdown <= target and monthly >= target; otherwise best monthly net.
    def _cmp_cycle(a):
        dd_ok = a['dd_available'] <= MAX_DD_TARGET
        mo_ok = a['monthly_net'] >= TARGET_MONTHLY_NET
        return (dd_ok, mo_ok, a['monthly_net'], -a['dd_available'])

    best = max(ok_results, key=_cmp_cycle)
    print(" " + "-" * 62, flush=True)
    b_dd = best['dd_available']
    print(f" 最佳周期: {best['period']} (月均 {best['monthly_net']:+.2f} USDT, 可用回撤 {b_dd:.2f} USDT)", flush=True)
    if b_dd <= MAX_DD_TARGET and best['monthly_net'] >= TARGET_MONTHLY_NET:
        print(f" 双目标达标: 回撤≤{MAX_DD_TARGET:.0f}U 且 月盈≥{TARGET_MONTHLY_NET:.0f}U", flush=True)

    if do_save_best and best.get('model') is not None:
        flat_params = {
            **ind,
            'forward_bars': strategy['forward_bars'],
            'label_threshold': strategy['label_threshold'],
            'prob_threshold': strategy['prob_threshold'],
            'sl_pct': strategy['sl_pct'],
            'tp_pct': strategy['tp_pct'],
            'min_hold_seconds': strategy['min_hold'],
            'max_hold_seconds': strategy['max_hold'],
            'margin_per_trade': MARGIN_PER_TRADE, 'leverage': LEVERAGE,
            'notional': NOTIONAL_PER_TRADE, 'rebate_rate': REBATE_RATE,
            'train_period': f'{TRAIN_START} ~ {TRAIN_END_EXCLUSIVE}',
            'test_period': f'{TEST_START} ~ {TEST_END_EXCLUSIVE}',
            'kline_period': best['period'],
            'initial_balance': INITIAL_BALANCE,
            'max_dd_target': MAX_DD_TARGET,
        }
        save_model_and_params(best['model'], best.get('feature_cols', []), flat_params)
        print(f" 已保存最佳周期 {best['period']} 的模型与参数。", flush=True)

    print(f"\n 总耗时: {_time.time() - t0:.1f}s", flush=True)
    print("=" * 60, flush=True)
|
||
|
||
|
||
def main(do_grid_search: bool = True, period: str = '1m'):
|
||
"""
|
||
2020 年全年训练,2021 年全年真实回测(严格样本外)。
|
||
period: K 线周期 '1m'|'5m'|'15m' 等。
|
||
参数搜索时在 2021 年月均净利上选最佳并保存。
|
||
"""
|
||
period = _normalize_period(period)
|
||
t0 = _time.time()
|
||
print("=" * 60, flush=True)
|
||
print(f" ETH 合约 AI 策略 — 2020 训练 / 2021 真实回测 | K线周期 {period} | 单仓 100U×100倍 | 90% 返佣", flush=True)
|
||
print(f" 本金: {INITIAL_BALANCE:.0f} USDT | 返佣次日 8 点到账,回撤按可用资金计算", flush=True)
|
||
print(f" 目标: 可用资金最大回撤 ≤ {MAX_DD_TARGET:.0f} USDT,月均净利 ≥ {TARGET_MONTHLY_NET:.0f} USDT", flush=True)
|
||
print("=" * 60, flush=True)
|
||
|
||
print(f"\n[1/4] 加载 K 线 (周期 {period}, 训练 2020 / 回测 2021)...", flush=True)
|
||
df_2020 = load_klines(TRAIN_START, TRAIN_END_EXCLUSIVE, period=period)
|
||
df_2021 = load_klines(TEST_START, TEST_END_EXCLUSIVE, period=period)
|
||
print(f" 2020 训练: {len(df_2020):,} 根 | 2021 回测: {len(df_2021):,} 根", flush=True)
|
||
if len(df_2020) < 10000:
|
||
print(" 警告: 2020 年数据不足 10000 根,请先运行 抓取多周期K线.py 拉取 2020 年数据。", flush=True)
|
||
if len(df_2021) < 1000:
|
||
print(" 警告: 2021 年数据不足,回测结果可能不可靠。", flush=True)
|
||
|
||
# 优选:回撤≤500 且 月盈≥1000;其次回撤≤500;再其次月盈≥1000;否则取综合最优
|
||
def _score(r):
|
||
dd_ok = r['dd_available'] <= MAX_DD_TARGET
|
||
monthly_ok = r['monthly_net'] >= TARGET_MONTHLY_NET
|
||
return (dd_ok, monthly_ok, r['monthly_net'], -r['dd_available'])
|
||
|
||
best_score = (-1, -1, -1e9, 1e9)
|
||
best_result = None
|
||
|
||
if do_grid_search:
|
||
configs = list(_grid_indicator_and_strategy())
|
||
print(f"\n[2/4] 参数搜索 (共 {len(configs)} 组): 目标 回撤≤{MAX_DD_TARGET:.0f}U 且 月盈≥{TARGET_MONTHLY_NET:.0f}U ...", flush=True)
|
||
for i, cfg in enumerate(configs):
|
||
_t_start = _time.time()
|
||
trades, model, feature_cols = run_single(
|
||
df_2020,
|
||
df_2021,
|
||
cfg['ind'],
|
||
cfg['forward_bars'],
|
||
cfg['label_threshold'],
|
||
cfg['prob_threshold'],
|
||
cfg['sl_pct'],
|
||
cfg['tp_pct'],
|
||
cfg['min_hold'],
|
||
cfg['max_hold'],
|
||
)
|
||
_elapsed = _time.time() - _t_start
|
||
if not trades or model is None:
|
||
if (i + 1) % 10 == 0 or i == 0:
|
||
print(f" 组 {i+1}/{len(configs)} 完成 (本组 {_elapsed:.0f}s,无有效交易)", flush=True)
|
||
continue
|
||
r = analyze_trades(trades)
|
||
score = _score(r)
|
||
if score > best_score:
|
||
best_score = score
|
||
best_result = (trades, model, feature_cols, {
|
||
'indicator_params': cfg['ind'],
|
||
'forward_bars': cfg['forward_bars'],
|
||
'label_threshold': cfg['label_threshold'],
|
||
'prob_threshold': cfg['prob_threshold'],
|
||
'sl_pct': cfg['sl_pct'],
|
||
'tp_pct': cfg['tp_pct'],
|
||
'min_hold_seconds': cfg['min_hold'],
|
||
'max_hold_seconds': cfg['max_hold'],
|
||
'margin_per_trade': MARGIN_PER_TRADE,
|
||
'leverage': LEVERAGE,
|
||
'notional': NOTIONAL_PER_TRADE,
|
||
'rebate_rate': REBATE_RATE,
|
||
'train_period': f'{TRAIN_START} ~ {TRAIN_END_EXCLUSIVE}',
|
||
'test_period': f'{TEST_START} ~ {TEST_END_EXCLUSIVE}',
|
||
'kline_period': period,
|
||
'initial_balance': INITIAL_BALANCE,
|
||
'max_dd_target': MAX_DD_TARGET,
|
||
})
|
||
_total_so_far = _time.time() - t0
|
||
_avg_per = _total_so_far / (i + 1)
|
||
_left = _avg_per * (len(configs) - i - 1)
|
||
if (i + 1) % 10 == 0 or i == 0:
|
||
print(f" 组 {i+1}/{len(configs)} 完成 | 本组 {_elapsed:.0f}s | 回撤 {r['dd_available']:.0f} 月均 {r['monthly_net']:+.0f} | 预计剩余 ~{_left/60:.0f}min", flush=True)
|
||
else:
|
||
print("\n[2/4] 使用默认参数: 2020 训练 -> 2021 回测...", flush=True)
|
||
ind = default_indicator_params()
|
||
trades, model, feature_cols = run_single(
|
||
df_2020, df_2021, ind,
|
||
forward_bars=10, label_threshold=0.002,
|
||
prob_threshold=0.45, sl_pct=0.004, tp_pct=0.006,
|
||
min_hold=180, max_hold=1800,
|
||
)
|
||
if trades and model is not None:
|
||
r = analyze_trades(trades)
|
||
best_monthly = r['monthly_net']
|
||
reached_target = r['monthly_net'] >= TARGET_MONTHLY_NET
|
||
best_result = (trades, model, feature_cols, {
|
||
'indicator_params': ind,
|
||
'forward_bars': 10,
|
||
'label_threshold': 0.002,
|
||
'prob_threshold': 0.45,
|
||
'sl_pct': 0.004,
|
||
'tp_pct': 0.006,
|
||
'min_hold_seconds': 180,
|
||
'max_hold_seconds': 1800,
|
||
'margin_per_trade': MARGIN_PER_TRADE,
|
||
'leverage': LEVERAGE,
|
||
'notional': NOTIONAL_PER_TRADE,
|
||
'rebate_rate': REBATE_RATE,
|
||
'train_period': f'{TRAIN_START} ~ {TRAIN_END_EXCLUSIVE}',
|
||
'test_period': f'{TEST_START} ~ {TEST_END_EXCLUSIVE}',
|
||
'kline_period': period,
|
||
'initial_balance': INITIAL_BALANCE,
|
||
'max_dd_target': MAX_DD_TARGET,
|
||
})
|
||
|
||
if best_result is None:
|
||
print(" 未得到有效模型/交易(可能 2020 或 2021 数据不足)。", flush=True)
|
||
return
|
||
trades, model, feature_cols, full_params = best_result
|
||
print("\n[3/4] 2021 年真实回测结果 (最佳参数)...", flush=True)
|
||
print_report(trades, "2021 年样本外回测")
|
||
flat_params = {**full_params.get('indicator_params', {}), **{k: v for k, v in full_params.items() if k != 'indicator_params'}}
|
||
save_model_and_params(model, feature_cols, flat_params)
|
||
r = analyze_trades(trades)
|
||
elapsed = _time.time() - t0
|
||
dd_ok = r['dd_available'] <= MAX_DD_TARGET
|
||
monthly_ok = r['monthly_net'] >= TARGET_MONTHLY_NET
|
||
print(f"\n[4/4] 总耗时: {elapsed:.1f}s | 可用资金回撤: {r['dd_available']:.2f} USDT (目标≤{MAX_DD_TARGET:.0f}): {'达标' if dd_ok else '未达标'} | 月均净利: {r['monthly_net']:+.2f} USDT (目标≥{TARGET_MONTHLY_NET:.0f}): {'达标' if monthly_ok else '未达标'}", flush=True)
|
||
print("=" * 60, flush=True)
|
||
|
||
|
||
if __name__ == '__main__':
    import sys

    cli_args = sys.argv
    if '-c' in cli_args or '--compare' in cli_args:
        # Multi-period comparison: 5m and 15m by default; with --all compare
        # 1m, 3m, 5m, 15m, 30m and 1h.
        if '--all' in cli_args:
            compare_periods = [1, 3, 5, 15, 30, 60]
        else:
            compare_periods = [5, 15]
        run_cycle_compare(compare_periods, do_save_best=True)
    else:
        # Single period: the value after the first '--period' / '-p' flag
        # (e.g. '5m' or '15') selects the K-line period; defaults to '1m'.
        # A trailing flag with no value is ignored.
        chosen_period = next(
            (cli_args[idx + 1]
             for idx, flag in enumerate(cli_args[:-1])
             if flag in ('--period', '-p')),
            '1m',
        )
        main(do_grid_search='--no-search' not in cli_args, period=chosen_period)
|