From 1d68fcd925517fb8c7a3e6e11619cd4002ed11e6 Mon Sep 17 00:00:00 2001 From: ddrwode <34234@3来 34> Date: Sat, 28 Feb 2026 16:48:40 +0800 Subject: [PATCH] =?UTF-8?q?=E7=AC=AC=E4=B8=80=E7=89=88=E7=AD=96=E7=95=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- run_bb_full_grid_search.py | 227 ++++++++++++++++++------------------- 1 file changed, 111 insertions(+), 116 deletions(-) diff --git a/run_bb_full_grid_search.py b/run_bb_full_grid_search.py index d7fd280..7267733 100644 --- a/run_bb_full_grid_search.py +++ b/run_bb_full_grid_search.py @@ -145,9 +145,9 @@ def _init_worker(df_path: str, df_1m_path: str | None, use_1m: bool, step_min: i G_STEP_MIN = int(step_min) -def _eval_single_task(args: tuple[int, float]) -> dict: - """单个 (period, std) 组合的回测任务""" - period, std = args +def _eval_period_task(args: tuple[int, list[float]]) -> list[dict]: + """一个 period 组的批量回测(同一 period 只算一次布林带 + 1m touch)""" + period, std_list = args assert G_DF is not None arr_touch_dir = None @@ -156,67 +156,64 @@ def _eval_single_task(args: tuple[int, float]) -> dict: bb_mid, _, _, _ = bollinger(close, period, 1.0) arr_touch_dir = get_1m_touch_direction(G_DF, G_DF_1M, bb_mid.values, kline_step_min=G_STEP_MIN) - cfg = BBMidlineConfig( - bb_period=period, - bb_std=float(std), - initial_capital=200.0, - margin_pct=0.01, - leverage=100.0, - cross_margin=True, - fee_rate=0.0005, - rebate_pct=0.90, - rebate_hour_utc=0, - fill_at_close=True, - use_1m_touch_filter=G_USE_1M, - kline_step_min=G_STEP_MIN, - ) - result = run_bb_midline_backtest( - G_DF, - cfg, - df_1m=G_DF_1M if G_USE_1M else None, - arr_touch_dir_override=arr_touch_dir, - ) + rows: list[dict] = [] + for std in std_list: + cfg = BBMidlineConfig( + bb_period=period, + bb_std=float(std), + initial_capital=200.0, + margin_pct=0.01, + leverage=100.0, + cross_margin=True, + fee_rate=0.0005, + rebate_pct=0.90, + rebate_hour_utc=0, + fill_at_close=True, + use_1m_touch_filter=G_USE_1M, + kline_step_min=G_STEP_MIN, + ) + result = run_bb_midline_backtest( + G_DF, + cfg, + df_1m=G_DF_1M if G_USE_1M else None, + arr_touch_dir_override=arr_touch_dir, + ) - eq = result.equity_curve["equity"].dropna() - if len(eq) == 0: - final_eq = 0.0 - ret_pct = -100.0 - dd_u = -200.0 - dd_pct = 100.0 - else: - final_eq = float(eq.iloc[-1]) - ret_pct = (final_eq - cfg.initial_capital) / cfg.initial_capital * 100.0 - dd_u = float((eq.astype(float) - eq.astype(float).cummax()).min()) - dd_pct = abs(dd_u) / cfg.initial_capital * 100.0 + eq = result.equity_curve["equity"].dropna() + if len(eq) == 0: + final_eq = 0.0 + ret_pct = -100.0 + dd_u = -200.0 + dd_pct = 100.0 + else: + final_eq = float(eq.iloc[-1]) + ret_pct = (final_eq - cfg.initial_capital) / cfg.initial_capital * 100.0 + dd_u = float((eq.astype(float) - eq.astype(float).cummax()).min()) + dd_pct = abs(dd_u) / cfg.initial_capital * 100.0 - n_trades = len(result.trades) - win_rate = ( - sum(1 for t in result.trades if t.net_pnl > 0) / n_trades * 100.0 - if n_trades > 0 - else 0.0 - ) - pnl = result.daily_stats["pnl"].astype(float) - sharpe = float(pnl.mean() / pnl.std()) * math.sqrt(365.0) if pnl.std() > 0 else 0.0 - stable_score = score_stable(ret_pct, sharpe, dd_pct, n_trades) + n_trades = len(result.trades) + win_rate = ( + sum(1 for t in result.trades if t.net_pnl > 0) / n_trades * 100.0 + if n_trades > 0 + else 0.0 + ) + pnl = result.daily_stats["pnl"].astype(float) + sharpe = float(pnl.mean() / pnl.std()) * math.sqrt(365.0) if pnl.std() > 0 else 0.0 + stable_score = score_stable(ret_pct, sharpe, dd_pct, n_trades) - return { - "period": period, - "std": round(float(std), 2), - "final_eq": final_eq, - "ret_pct": ret_pct, - "n_trades": n_trades, - "win_rate": win_rate, - "sharpe": sharpe, - "max_dd_u": dd_u, - "max_dd_pct": dd_pct, - "stable_score": stable_score, - } - - -def _eval_period_task(args: tuple[int, list[float]]) -> list[dict]: - """兼容旧接口:一个 period 组的批量回测""" - period, std_list = args - return [_eval_single_task((period, s)) for s in std_list] + rows.append({ + "period": period, + "std": round(float(std), 2), + "final_eq": final_eq, + "ret_pct": ret_pct, + "n_trades": n_trades, + "win_rate": win_rate, + "sharpe": sharpe, + "max_dd_u": dd_u, + "max_dd_pct": dd_pct, + "stable_score": stable_score, + }) + return rows @@ -262,86 +259,83 @@ def run_grid_search( meta: dict | None = None, checkpoint_interval: int = 5, ) -> pd.DataFrame: - total_combos = len(params) + from collections import defaultdict + + by_period: dict[int, set[float]] = defaultdict(set) + for p, s in params: + by_period[int(p)].add(round(float(s), 2)) + + tasks = [(p, sorted(stds)) for p, stds in sorted(by_period.items())] + total_periods = len(tasks) + total_combos = sum(len(stds) for _, stds in tasks) rows: list[dict] = list(existing_rows) if existing_rows else [] - print(f"待运行: {total_combos} 组合, workers={workers}" + ( + print(f"待运行: {total_combos} 组合 ({total_periods} period组), workers={workers}" + ( f", 断点续跑 (已有 {len(rows)} 条)" if rows else "" )) t_start = time.time() + done_periods = 0 done_combos = 0 - last_save_count = 0 + last_save_periods = 0 last_top_time = t_start - # checkpoint 按组合数保存,默认每 500 个组合保存一次 - ckpt_combo_interval = max(100, checkpoint_interval * 50) - # Top N 排行榜刷新间隔(秒) - top_n_interval = 30.0 - def maybe_save(): - nonlocal last_save_count - if ckpt_path and meta_path and meta and done_combos > 0: - if done_combos - last_save_count >= ckpt_combo_interval: + def on_period_done(res: list[dict], period: int, n_stds: int): + nonlocal done_periods, done_combos, last_save_periods, last_top_time + # 逐条打印该 period 组的结果 + for row in res: + rows.append(row) + done_combos += 1 + elapsed = time.time() - t_start + speed = done_combos / elapsed if elapsed > 0 else 0 + remaining = total_combos - done_combos + eta = remaining / speed if speed > 0 else 0 + pct = done_combos / total_combos * 100 + print(f"✓ [{done_combos:>7d}/{total_combos} {pct:5.1f}% ETA {_format_eta(eta)}] " + f"p={int(row['period']):4d} s={row['std']:7.2f} | " + f"收益:{row['ret_pct']:+8.2f}% 回撤:{row['max_dd_pct']:6.2f}% " + f"夏普:{row['sharpe']:7.3f} 交易:{int(row['n_trades']):5d} " + f"评分:{row['stable_score']:8.1f}", flush=True) + + done_periods += 1 + + # 每完成 checkpoint_interval 个 period 组保存一次 + if ckpt_path and meta_path and meta: + if done_periods - last_save_periods >= checkpoint_interval: save_checkpoint(ckpt_path, meta_path, meta, rows) - last_save_count = done_combos + last_save_periods = done_periods + print(f" 💾 断点已保存 ({done_combos} 条)", flush=True) - def maybe_print_top(): - nonlocal last_top_time + # 每 60 秒打印一次 Top 10 now = time.time() - if now - last_top_time >= top_n_interval: + if now - last_top_time >= 60.0: _print_top_n(rows) last_top_time = now - def on_result(row: dict): - nonlocal done_combos - rows.append(row) - done_combos += 1 - elapsed = time.time() - t_start - speed = done_combos / elapsed if elapsed > 0 else 0 - remaining = total_combos - done_combos - eta = remaining / speed if speed > 0 else 0 - pct = done_combos / total_combos * 100 - - # 每个结果都实时打印 - print(f"✓ [{done_combos:>7d}/{total_combos} {pct:5.1f}% ETA {_format_eta(eta)}] " - f"p={int(row['period']):4d} s={row['std']:7.2f} | " - f"收益:{row['ret_pct']:+8.2f}% 回撤:{row['max_dd_pct']:6.2f}% " - f"夏普:{row['sharpe']:7.3f} 交易:{int(row['n_trades']):5d} " - f"评分:{row['stable_score']:8.1f}", flush=True) - - maybe_save() - maybe_print_top() + def _run_sequential(): + _init_worker(df_path, df_1m_path, use_1m, step_min) + for task in tasks: + res = _eval_period_task(task) + on_period_done(res, task[0], len(task[1])) if workers <= 1: - _init_worker(df_path, df_1m_path, use_1m, step_min) - for p, s in params: - row = _eval_single_task((p, s)) - on_result(row) + _run_sequential() else: - # 多进程:逐个 (period, std) 提交,实时返回 - # 为避免提交 200 万个 future 占用过多内存,分批提交 - batch_size = workers * 20 # 每批提交的任务数 try: with ProcessPoolExecutor( max_workers=workers, initializer=_init_worker, initargs=(df_path, df_1m_path, use_1m, step_min), ) as ex: - idx = 0 - while idx < total_combos: - batch_end = min(idx + batch_size, total_combos) - batch = params[idx:batch_end] - future_map = {ex.submit(_eval_single_task, (p, s)): (p, s) for p, s in batch} - for fut in as_completed(future_map): - row = fut.result() - on_result(row) - idx = batch_end + future_map = {ex.submit(_eval_period_task, task): task for task in tasks} + for fut in as_completed(future_map): + period, stds = future_map[fut] + res = fut.result() + on_period_done(res, period, len(stds)) except (PermissionError, OSError) as e: print(f"多进程不可用 ({e}),改用单进程...") - _init_worker(df_path, df_1m_path, use_1m, step_min) - for p, s in params[done_combos:]: - row = _eval_single_task((p, s)) - on_result(row) + _run_sequential() + # 最终保存 if ckpt_path and meta_path and meta and rows: save_checkpoint(ckpt_path, meta_path, meta, rows) @@ -349,7 +343,8 @@ def run_grid_search( _print_top_n(rows, n=20, label="最终 Top 20") df = pd.DataFrame(rows) - print(f"完成, 总用时 {time.time() - t_start:.1f}s, 平均 {total_combos / (time.time() - t_start):.1f} 组合/秒") + elapsed_total = time.time() - t_start + print(f"完成, 总用时 {elapsed_total:.1f}s, 平均 {done_combos / elapsed_total:.1f} 组合/秒") return df