Files
codex_jxs_code/strategy/train.py
2026-02-23 04:09:34 +08:00

227 lines
8.4 KiB
Python

"""
Optuna 训练入口 - 在 2020-2022 数据上搜索最优参数
"""
import json
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import optuna
from optuna.samplers import TPESampler
import numpy as np
from strategy.data_loader import load_klines
from strategy.indicators import compute_all_indicators
from strategy.strategy_signal import (
generate_indicator_signals, compute_composite_score,
apply_htf_filter, WEIGHT_KEYS,
)
from strategy.backtest_engine import BacktestEngine
# ============================================================
# Load the training data once at import time; every trial reuses
# these module-level frames instead of reloading them.
# ============================================================
_TRAIN_RANGE = ('2020-01-01', '2023-01-01')  # (start, end) passed to load_klines

print("正在加载 2020-2022 训练数据...")
DF_5M = load_klines('5m', *_TRAIN_RANGE)
DF_1H = load_klines('1h', *_TRAIN_RANGE)
print(f" 5m: {len(DF_5M)} 条, 1h: {len(DF_1H)}")
print("数据加载完成。\n")
def build_params(trial: optuna.Trial) -> dict:
    """Sample one complete parameter dictionary from an Optuna trial.

    Parameters are requested from ``trial`` in a fixed order: indicator
    settings, signal thresholds, one weight per entry of ``WEIGHT_KEYS``,
    and finally the backtest settings. The upper bound of
    ``stop_loss_pct`` depends on the sampled ``max_positions``.

    Args:
        trial: The Optuna trial used to suggest every value.

    Returns:
        A dict mapping each parameter name to its sampled value.
    """
    # Each row is (name, low, high, step). A step of None means the
    # parameter is an integer; otherwise it is a stepped float.
    indicator_space = (
        ('bb_period', 10, 50, None),
        ('bb_std', 1.0, 3.5, 0.1),
        ('kc_period', 10, 50, None),
        ('kc_mult', 0.5, 3.0, 0.1),
        ('dc_period', 10, 50, None),
        ('ema_fast', 3, 20, None),
        ('ema_slow', 15, 60, None),
        ('macd_fast', 6, 20, None),
        ('macd_slow', 18, 40, None),
        ('macd_signal', 5, 15, None),
        ('adx_period', 7, 30, None),
        ('st_period', 5, 20, None),
        ('st_mult', 1.0, 5.0, 0.1),
        ('rsi_period', 7, 28, None),
        ('stoch_k', 5, 21, None),
        ('stoch_d', 2, 7, None),
        ('stoch_smooth', 2, 7, None),
        ('cci_period', 10, 40, None),
        ('wr_period', 7, 28, None),
        ('wma_period', 10, 50, None),
    )
    threshold_space = (
        ('bb_oversold', -0.3, 0.3, 0.05),
        ('bb_overbought', 0.7, 1.3, 0.05),
        ('kc_oversold', -0.3, 0.3, 0.05),
        ('kc_overbought', 0.7, 1.3, 0.05),
        ('dc_oversold', 0.0, 0.3, 0.05),
        ('dc_overbought', 0.7, 1.0, 0.05),
        ('adx_threshold', 15, 35, 1),
        ('rsi_overbought', 60, 85, 1),
        ('rsi_oversold', 15, 40, 1),
        ('stoch_overbought', 70, 90, 1),
        ('stoch_oversold', 10, 30, 1),
        ('cci_overbought', 80, 200, 5),
        ('cci_oversold', -200, -80, 5),
        ('wr_overbought', -30, -10, 1),
        ('wr_oversold', -90, -70, 1),
    )

    sampled = {}

    def draw(name, low, high, step=None):
        # Ask the trial for one value and record it under its own name.
        if step is None:
            sampled[name] = trial.suggest_int(name, low, high)
        else:
            sampled[name] = trial.suggest_float(name, low, high, step=step)

    # --- Indicator parameters ---
    for spec in indicator_space:
        draw(*spec)

    # --- Signal threshold parameters ---
    for spec in threshold_space:
        draw(*spec)

    # --- Weights ---
    for weight_key in WEIGHT_KEYS:
        draw(weight_key, 0.0, 1.0, 0.05)

    # --- Backtest parameters ---
    draw('open_threshold', 0.1, 0.6, 0.02)
    draw('max_positions', 1, 3)
    draw('take_profit_pct', 0.003, 0.025, 0.001)

    # Stop-loss constraint: N positions stopped out at once, plus fees,
    # must not exceed 50U:
    #     N * 1250 * sl_pct + N * 1.25 <= 50
    #     sl_pct <= (50 - N * 1.25) / (N * 1250)
    # NOTE(review): 1250 and 1.25 look like per-position notional and
    # fee -- confirm against the backtest engine settings.
    position_count = sampled['max_positions']
    stop_loss_cap = (50.0 - position_count * 1.25) / (position_count * 1250.0)
    stop_loss_cap = round(max(stop_loss_cap, 0.002), 3)  # never below 0.2%
    draw('stop_loss_pct', 0.002, stop_loss_cap, 0.001)

    return sampled
def objective(trial: optuna.Trial) -> float:
    """Evaluate one sampled parameter set by backtesting it on the training data.

    The value returned is the backtest's total PnL, multiplied by 1.3 when
    the average daily PnL is at least 50, or by 1.15 when it is at least 30.
    The sentinel ``-1e6`` is returned instead when:

    * ``ema_slow <= ema_fast`` or ``macd_slow <= macd_fast``;
    * the backtest produced fewer than 50 trades;
    * any exception is raised while computing indicators, signals or the
      backtest (the exception is printed, not re-raised).

    On success the headline metrics are stored on the trial as user attrs.

    Args:
        trial: The Optuna trial supplying the parameters.

    Returns:
        The score to maximize, or ``-1e6`` for a rejected trial.
    """
    rejected = -1e6
    params = build_params(trial)

    # The slow EMA / MACD periods must be strictly longer than the fast ones.
    if (params['ema_slow'] <= params['ema_fast']
            or params['macd_slow'] <= params['macd_fast']):
        return rejected

    try:
        # Compute indicators for both timeframes.
        frame_5m = compute_all_indicators(DF_5M, params)
        frame_1h = compute_all_indicators(DF_1H, params)

        # Turn indicators into per-indicator signals.
        frame_5m = generate_indicator_signals(frame_5m, params)
        frame_1h = generate_indicator_signals(frame_1h, params)

        # Composite score on 5m data, then the higher-timeframe filter.
        composite = compute_composite_score(frame_5m, params)
        composite = apply_htf_filter(composite, frame_1h, params)

        # Backtest.
        engine_config = {
            'initial_capital': 1000.0,
            'margin_per_trade': 25.0,
            'leverage': 50,
            'fee_rate': 0.0005,
            'rebate_ratio': 0.70,
            'max_daily_drawdown': 50.0,
            'min_hold_bars': 1,
            'stop_loss_pct': params['stop_loss_pct'],
            'take_profit_pct': params['take_profit_pct'],
            'max_positions': params['max_positions'],
        }
        backtester = BacktestEngine(**engine_config)
        result = backtester.run(
            frame_5m, composite, open_threshold=params['open_threshold']
        )

        trade_count = result['num_trades']
        if trade_count < 50:
            return rejected  # too few trades to be reliable

        pnl = result['total_pnl']
        # Per the original author's note this value is negative and the
        # engine keeps it >= -50; not verifiable from this file.
        worst_daily_dd = result['max_daily_dd']
        daily_avg = result['avg_daily_pnl']

        # The engine already applies a daily 50U drawdown circuit breaker
        # (per the original note), so no hard constraint is added here.
        # Goal: maximize total PnL, with a bonus for a high daily average.
        fitness = pnl
        if daily_avg >= 50:
            fitness = pnl * 1.3
        elif daily_avg >= 30:
            fitness = pnl * 1.15

        trial.set_user_attr('total_pnl', pnl)
        trial.set_user_attr('num_trades', trade_count)
        trial.set_user_attr('win_rate', result['win_rate'])
        trial.set_user_attr('max_daily_dd', worst_daily_dd)
        trial.set_user_attr('avg_daily_pnl', daily_avg)
        trial.set_user_attr('profit_factor', result['profit_factor'])
        return fitness
    except Exception as e:
        print(f"Trial {trial.number} 异常: {e}")
        return rejected
def main():
    """Run the Optuna search and save the best parameter set to disk.

    Creates a maximizing study with a seeded TPE sampler, runs 1000 trials
    of ``objective``, prints a progress line on every 10th trial and a
    final summary, then writes ``best.params`` as JSON to
    ``best_params_2020_2022.json`` in this file's directory.

    Raises:
        ValueError: Propagated from ``study.best_trial`` after the search
            if no trial completed successfully.
    """
    study = optuna.create_study(
        direction='maximize',
        sampler=TPESampler(seed=42, n_startup_trials=30),
        study_name='eth_strategy_v1',
    )
    # Set the Optuna log level.
    optuna.logging.set_verbosity(optuna.logging.WARNING)
    n_trials = 1000
    print(f"开始 Optuna 优化, 共 {n_trials} 次试验 (多单并发版)...")
    print("=" * 60)

    def callback(study, trial):
        # Report progress on every 10th trial only.
        if trial.number % 10 != 0:
            return
        try:
            best = study.best_trial
        except ValueError:
            # No trial has completed yet (e.g. the first one failed), so
            # there is nothing to report; do not abort the optimization.
            return
        # A failed trial (such as one whose objective returned NaN) has no
        # value; formatting None with ':.2f' would raise and stop the run.
        current_text = "N/A" if trial.value is None else f"{trial.value:.2f}"
        print(f"[Trial {trial.number:>4d}] "
              f"当前值={current_text} | "
              f"最佳值={best.value:.2f} | "
              f"PnL={best.user_attrs.get('total_pnl', 0):.1f}U | "
              f"胜率={best.user_attrs.get('win_rate', 0):.1%} | "
              f"日均={best.user_attrs.get('avg_daily_pnl', 0):.1f}U | "
              f"最大日回撤={best.user_attrs.get('max_daily_dd', 0):.1f}U")

    study.optimize(objective, n_trials=n_trials, callbacks=[callback], show_progress_bar=True)

    # Print the best result.
    best = study.best_trial
    print("\n" + "=" * 60)
    print("训练完成!最佳参数:")
    print("=" * 60)
    print(f" 目标值: {best.value:.4f}")
    print(f" 总收益: {best.user_attrs.get('total_pnl', 0):.2f}U")
    print(f" 交易次数: {best.user_attrs.get('num_trades', 0)}")
    print(f" 胜率: {best.user_attrs.get('win_rate', 0):.2%}")
    print(f" 日均收益: {best.user_attrs.get('avg_daily_pnl', 0):.2f}U")
    print(f" 最大日回撤: {best.user_attrs.get('max_daily_dd', 0):.2f}U")
    print(f" 盈亏比: {best.user_attrs.get('profit_factor', 0):.2f}")

    # Save the best parameters. The encoding is explicit because
    # ensure_ascii=False writes non-ASCII characters verbatim, and the
    # platform default encoding is locale-dependent.
    output_path = os.path.join(os.path.dirname(__file__), 'best_params_2020_2022.json')
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(best.params, f, indent=2, ensure_ascii=False)
    print(f"\n最佳参数已保存到: {output_path}")
# Script entry point: run the optimization only when executed directly.
if __name__ == "__main__":
    main()