Files
codex_jxs_code/run_bb_midline_hier_search.py
2026-02-26 19:05:17 +08:00

423 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
布林带中轨策略参数分层搜索2020-2025
说明:
- 全区间覆盖: period 1~1000, std 0.5~1000
- 分层搜索: 先粗扫全区间,再在候选周围细化,最终细化到 std=0.5 步长
- 使用 1m 触及方向过滤(先涨/先跌)时,按 period 复用触及方向以提速
"""
from __future__ import annotations
import argparse
import math
import os
import tempfile
import time
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
import numpy as np
import pandas as pd
from strategy.bb_midline_backtest import BBMidlineConfig, run_bb_midline_backtest
from strategy.data_loader import get_1m_touch_direction, load_klines
from strategy.indicators import bollinger
G_DF: pd.DataFrame | None = None
G_DF_1M: pd.DataFrame | None = None
G_USE_1M: bool = True
G_STEP_MIN: int = 5
def frange(start: float, end: float, step: float) -> list[float]:
    """Return an inclusive list of floats from *start* to *end* in *step* increments.

    Each value is rounded to 6 decimal places.  The element count is computed
    up front and values are generated as ``start + i * step`` instead of
    repeatedly accumulating ``x += step``, which avoids floating-point drift
    dropping (or duplicating) the final endpoint on long ranges.

    Raises:
        ValueError: if *step* is not positive (the previous accumulating loop
            would have spun forever for step <= 0 when start <= end).
    """
    if step <= 0:
        raise ValueError("step must be positive")
    # The +1e-9 tolerance mirrors the original loop condition `x <= end + 1e-9`.
    count = int(math.floor((end - start) / step + 1e-9)) + 1
    return [round(start + i * step, 6) for i in range(max(0, count))]
def build_grid(
    p_start: float,
    p_end: float,
    p_step: float,
    s_start: float,
    s_end: float,
    s_step: float,
) -> list[tuple[int, float]]:
    """Build the full (period, std) cartesian grid.

    Periods are clamped to [1, 1000] and std values to [0.5, 1000.0]
    (rounded to 2 decimals); the far corner (1000, 1000.0) is always
    included so the coarse sweep covers the extreme of the search space.
    """
    period_set = {max(1, min(1000, int(round(p)))) for p in frange(p_start, p_end, p_step)}
    std_set = {round(max(0.5, min(1000.0, s)), 2) for s in frange(s_start, s_end, s_step)}
    # Guarantee the upper bounds are represented even if the step skips them.
    period_set.add(1000)
    std_set.add(1000.0)
    return sorted((p, s) for p in period_set for s in std_set)
def build_local_grid(
    centers: pd.DataFrame,
    p_window: int,
    p_step: int,
    s_window: float,
    s_step: float,
) -> list[tuple[int, float]]:
    """Expand each candidate row of *centers* into a local refinement grid.

    For every (period, std) center a window of ±p_window periods and
    ±s_window std is sampled at the given steps, clamped to the global
    bounds ([1, 1000] for period, [0.5, 1000.0] for std).  Overlapping
    windows are de-duplicated via the combo set.
    """
    combos: set[tuple[int, float]] = set()
    for _, center in centers.iterrows():
        base_p = int(center["period"])
        base_s = float(center["std"])
        lo_p = max(1, base_p - p_window)
        hi_p = min(1000, base_p + p_window)
        lo_s = max(0.5, base_s - s_window)
        hi_s = min(1000.0, base_s + s_window)
        local_periods = {max(1, min(1000, int(round(v)))) for v in frange(lo_p, hi_p, p_step)}
        local_stds = {round(max(0.5, min(1000.0, v)), 2) for v in frange(lo_s, hi_s, s_step)}
        combos.update((p, s) for p in local_periods for s in local_stds)
    return sorted(combos)
def score_row(ret_pct: float, sharpe: float, dd_pct: float, n_trades: int) -> float:
    """Composite "stable return" score for ranking parameter combinations.

    Return and Sharpe add to the score, drawdown subtracts, and very sparse
    trading (< 200 trades) is penalized outright so thin samples don't win.
    """
    penalty = 0.0
    if n_trades < 200:
        penalty = -5.0
    base = ret_pct + 12.0 * sharpe - 0.8 * dd_pct
    return base + penalty
def _init_worker(df_path: str, df_1m_path: str | None, use_1m: bool, step_min: int):
    """Process-pool initializer: load pickled kline frames into worker globals.

    Runs once per worker process so every task in that worker reuses the
    same loaded dataframes instead of re-reading them per task.
    """
    global G_DF, G_DF_1M, G_USE_1M, G_STEP_MIN
    G_DF = pd.read_pickle(df_path)
    # Only bother loading the 1m frame when the touch filter is enabled.
    if use_1m and df_1m_path:
        G_DF_1M = pd.read_pickle(df_1m_path)
    else:
        G_DF_1M = None
    G_USE_1M = bool(use_1m)
    G_STEP_MIN = int(step_min)
def _eval_period_task(args: tuple[int, list[float]]) -> list[dict]:
    """Worker task: backtest every std in the list for one Bollinger period.

    Relies on the module globals (G_DF, G_DF_1M, G_USE_1M, G_STEP_MIN) set up
    by _init_worker in this process.  The 1m touch-direction array is computed
    once per period and reused for every std value, since the midline depends
    only on the period, not on the band width.

    Returns one metric dict per (period, std) combination.
    """
    period, std_list = args
    assert G_DF is not None
    arr_touch_dir = None
    if G_USE_1M and G_DF_1M is not None:
        close = G_DF["close"].astype(float)
        # The std multiplier does not affect the midline, so 1.0 is arbitrary here.
        bb_mid, _, _, _ = bollinger(close, period, 1.0)
        arr_touch_dir = get_1m_touch_direction(G_DF, G_DF_1M, bb_mid.values, kline_step_min=G_STEP_MIN)
    rows: list[dict] = []
    for std in std_list:
        cfg = BBMidlineConfig(
            bb_period=period,
            bb_std=float(std),
            initial_capital=200.0,
            margin_pct=0.01,
            leverage=100.0,
            cross_margin=True,
            fee_rate=0.0005,
            rebate_pct=0.90,
            rebate_hour_utc=0,
            fill_at_close=True,
            use_1m_touch_filter=G_USE_1M,
            kline_step_min=G_STEP_MIN,
        )
        result = run_bb_midline_backtest(
            G_DF,
            cfg,
            df_1m=G_DF_1M if G_USE_1M else None,
            arr_touch_dir_override=arr_touch_dir,
        )
        eq = result.equity_curve["equity"].dropna()
        if len(eq) == 0:
            # No equity points at all: record a total loss with sentinel drawdown.
            final_eq = 0.0
            ret_pct = -100.0
            dd_u = -200.0
            dd_pct = 100.0
        else:
            final_eq = float(eq.iloc[-1])
            ret_pct = (final_eq - cfg.initial_capital) / cfg.initial_capital * 100.0
            # Max drawdown in account units: most negative gap below the running peak.
            dd_u = float((eq.astype(float) - eq.astype(float).cummax()).min())
            dd_pct = abs(dd_u) / cfg.initial_capital * 100.0
        n_trades = len(result.trades)
        win_rate = (
            sum(1 for t in result.trades if t.net_pnl > 0) / n_trades * 100.0
            if n_trades > 0
            else 0.0
        )
        pnl = result.daily_stats["pnl"].astype(float)
        # Annualized daily Sharpe (no risk-free rate); 0 when pnl has no variance.
        sharpe = float(pnl.mean() / pnl.std()) * math.sqrt(365.0) if pnl.std() > 0 else 0.0
        stable_score = score_row(ret_pct, sharpe, dd_pct, n_trades)
        rows.append(
            {
                "period": period,
                "std": round(float(std), 2),
                "final_eq": final_eq,
                "ret_pct": ret_pct,
                "n_trades": n_trades,
                "win_rate": win_rate,
                "sharpe": sharpe,
                "max_dd_u": dd_u,
                "max_dd_pct": dd_pct,
                "stable_score": stable_score,
                "use_1m_filter": int(G_USE_1M),
            }
        )
    return rows
def evaluate_grid(
    params: list[tuple[int, float]],
    *,
    workers: int,
    df_path: str,
    df_1m_path: str | None,
    use_1m: bool,
    step_min: int,
    label: str,
) -> pd.DataFrame:
    """Evaluate (period, std) combinations across a process pool.

    Combinations are grouped by period — one pool task per period — so each
    worker computes the 1m touch-direction array once and reuses it for every
    std of that period.  Workers load the pickled dataframes from *df_path* /
    *df_1m_path* in their initializer.

    Returns a DataFrame of metric rows; empty DataFrame when *params* is empty.
    """
    # Group std values by period so each submitted task covers one full period.
    by_period: dict[int, set[float]] = defaultdict(set)
    for p, s in params:
        by_period[int(p)].add(round(float(s), 2))
    tasks = [(p, sorted(stds)) for p, stds in sorted(by_period.items())]
    total_periods = len(tasks)
    total_combos = sum(len(stds) for _, stds in tasks)
    if total_combos == 0:
        return pd.DataFrame()
    print(f"[{label}] period组数={total_periods}, 参数组合={total_combos}, workers={workers}")
    start = time.time()
    rows: list[dict] = []
    done_periods = 0
    done_combos = 0
    with ProcessPoolExecutor(
        max_workers=workers,
        initializer=_init_worker,
        initargs=(df_path, df_1m_path, use_1m, step_min),
    ) as ex:
        future_map = {ex.submit(_eval_period_task, task): task for task in tasks}
        for fut in as_completed(future_map):
            period, stds = future_map[fut]
            res = fut.result()
            rows.extend(res)
            done_periods += 1
            done_combos += len(stds)
            # Print progress roughly every 10% of periods, and always on the last one.
            if (
                done_periods == total_periods
                or done_periods % max(1, total_periods // 10) == 0
            ):
                elapsed = time.time() - start
                print(
                    f"[{label}] 进度 {done_combos}/{total_combos} 组合 "
                    f"({done_periods}/{total_periods} periods), {elapsed:.0f}s"
                )
    df = pd.DataFrame(rows)
    print(f"[{label}] 完成, 用时 {time.time() - start:.1f}s")
    return df
def summarize_yearly(
    eq: pd.Series,
    initial_capital: float = 200.0,
    *,
    year_start: int = 2020,
    year_end: int = 2025,
) -> pd.DataFrame:
    """Summarize an equity curve into per-calendar-year end equity and return.

    Args:
        eq: equity series indexed by (convertible-to) datetime.
        initial_capital: baseline for the first covered year's return.
        year_start, year_end: inclusive calendar-year span to summarize;
            defaults preserve the original hard-coded 2020-2025 range.

    Returns a DataFrame with columns year / year_end_equity / year_return_pct.
    Years with no equity points are skipped; each year's return is measured
    against the previous covered year's end equity.
    """
    s = eq.dropna().copy()
    s.index = pd.to_datetime(s.index)
    out_rows: list[dict] = []
    prev = initial_capital
    for year in range(year_start, year_end + 1):
        sub = s[s.index.year == year]
        if len(sub) == 0:
            continue
        ye = float(sub.iloc[-1])
        # Guard against a non-positive base (e.g. blown-up account) to avoid division issues.
        ret = (ye - prev) / prev * 100.0 if prev > 0 else 0.0
        out_rows.append({"year": year, "year_end_equity": ye, "year_return_pct": ret})
        prev = ye
    return pd.DataFrame(out_rows)
def main():
    """Run the four-stage hierarchical (coarse-to-fine) parameter search.

    Stage 1 sweeps the full (period, std) range coarsely; stages 2-4
    progressively refine around the best candidates ranked by stable_score,
    ending at a 0.5 std step.  Dataframes are pickled to temp files so each
    pool worker loads them once in its initializer; the temp files are
    removed in the finally block.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-p", "--period", default="5m", choices=["5m", "15m", "30m"])
    parser.add_argument("--start", default="2020-01-01")
    parser.add_argument("--end", default="2026-01-01")
    parser.add_argument("-j", "--workers", type=int, default=max(1, (os.cpu_count() or 4) - 1))
    parser.add_argument("--no-1m", action="store_true", help="禁用 1m 方向过滤")
    args = parser.parse_args()
    use_1m = not args.no_1m
    # Kline step in minutes, e.g. "15m" -> 15.
    step_min = int(args.period.replace("m", ""))
    out_dir = Path(__file__).resolve().parent / "strategy" / "results"
    out_dir.mkdir(parents=True, exist_ok=True)
    print(f"加载数据: {args.period} {args.start}~{args.end}")
    t0 = time.time()
    df = load_klines(args.period, args.start, args.end)
    df_1m = load_klines("1m", args.start, args.end) if use_1m else None
    print(
        f" {args.period}: {len(df):,}"
        + (f", 1m: {len(df_1m):,}" if df_1m is not None else "")
        + f", {time.time()-t0:.1f}s\n"
    )
    # Pickle the frames once; workers re-load them via _init_worker.
    with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as f_df:
        df.to_pickle(f_df.name)
        df_path = f_df.name
    df_1m_path = None
    if df_1m is not None:
        with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as f_1m:
            df_1m.to_pickle(f_1m.name)
            df_1m_path = f_1m.name
    try:
        # Combos already scored in an earlier stage are skipped in later stages.
        evaluated: set[tuple[int, float]] = set()
        all_parts: list[pd.DataFrame] = []
        # Stage 1: coarse sweep of the full parameter range.
        stage1 = build_grid(1, 1000, 50, 0.5, 1000, 50)
        stage1 = [x for x in stage1 if x not in evaluated]
        df1 = evaluate_grid(
            stage1,
            workers=args.workers,
            df_path=df_path,
            df_1m_path=df_1m_path,
            use_1m=use_1m,
            step_min=step_min,
            label="stage1-global",
        )
        if not df1.empty:
            all_parts.append(df1)
            evaluated.update((int(r["period"]), float(r["std"])) for _, r in df1.iterrows())
        seed1 = (
            df1.sort_values("stable_score", ascending=False).head(6)
            if not df1.empty
            else pd.DataFrame(columns=["period", "std"])
        )
        # Stage 2: medium-step refinement around the top candidates.
        stage2 = build_local_grid(seed1, p_window=25, p_step=5, s_window=50, s_step=10)
        stage2 = [x for x in stage2 if x not in evaluated]
        df2 = evaluate_grid(
            stage2,
            workers=args.workers,
            df_path=df_path,
            df_1m_path=df_1m_path,
            use_1m=use_1m,
            step_min=step_min,
            label="stage2-local",
        )
        if not df2.empty:
            all_parts.append(df2)
            evaluated.update((int(r["period"]), float(r["std"])) for _, r in df2.iterrows())
        pool2 = pd.concat([d for d in [df1, df2] if not d.empty], ignore_index=True)
        seed2 = (
            pool2.sort_values("stable_score", ascending=False).head(4)
            if len(pool2) > 0
            else pd.DataFrame(columns=["period", "std"])
        )
        # Stage 3: finer refinement around the surviving candidates.
        stage3 = build_local_grid(seed2, p_window=8, p_step=1, s_window=10, s_step=1)
        stage3 = [x for x in stage3 if x not in evaluated]
        df3 = evaluate_grid(
            stage3,
            workers=args.workers,
            df_path=df_path,
            df_1m_path=df_1m_path,
            use_1m=use_1m,
            step_min=step_min,
            label="stage3-fine",
        )
        if not df3.empty:
            all_parts.append(df3)
            evaluated.update((int(r["period"]), float(r["std"])) for _, r in df3.iterrows())
        pool3 = pd.concat([d for d in [df1, df2, df3] if not d.empty], ignore_index=True)
        seed3 = (
            pool3.sort_values("stable_score", ascending=False).head(2)
            if len(pool3) > 0
            else pd.DataFrame(columns=["period", "std"])
        )
        # Stage 4: final refinement (std step down to 0.5).
        stage4 = build_local_grid(seed3, p_window=3, p_step=1, s_window=4, s_step=0.5)
        stage4 = [x for x in stage4 if x not in evaluated]
        df4 = evaluate_grid(
            stage4,
            workers=args.workers,
            df_path=df_path,
            df_1m_path=df_1m_path,
            use_1m=use_1m,
            step_min=step_min,
            label="stage4-final",
        )
        if not df4.empty:
            all_parts.append(df4)
            evaluated.update((int(r["period"]), float(r["std"])) for _, r in df4.iterrows())
        if not all_parts:
            raise RuntimeError("未得到任何评估结果")
        all_df = pd.concat(all_parts, ignore_index=True)
        all_df = all_df.drop_duplicates(subset=["period", "std"], keep="last")
        best_stable = all_df.sort_values("stable_score", ascending=False).iloc[0]
        best_return = all_df.sort_values("ret_pct", ascending=False).iloc[0]
        # Re-run the best stable parameters once to export per-year returns.
        cfg = BBMidlineConfig(
            bb_period=int(best_stable["period"]),
            bb_std=float(best_stable["std"]),
            initial_capital=200.0,
            margin_pct=0.01,
            leverage=100.0,
            cross_margin=True,
            fee_rate=0.0005,
            rebate_pct=0.90,
            rebate_hour_utc=0,
            fill_at_close=True,
            use_1m_touch_filter=use_1m,
            kline_step_min=step_min,
        )
        final_res = run_bb_midline_backtest(df, cfg, df_1m=df_1m if use_1m else None)
        final_eq = final_res.equity_curve["equity"].dropna()
        yearly = summarize_yearly(final_eq, initial_capital=200.0)
        stamp = time.strftime("%Y%m%d_%H%M%S")
        all_path = out_dir / f"bb_midline_hier_search_{args.period}_{stamp}.csv"
        yearly_path = out_dir / f"bb_midline_hier_search_{args.period}_{stamp}_yearly.csv"
        all_df.sort_values("stable_score", ascending=False).to_csv(all_path, index=False)
        yearly.to_csv(yearly_path, index=False)
        print("\n" + "=" * 96)
        print("分层搜索完成")
        print(
            f"最佳稳定参数: period={int(best_stable['period'])}, std={float(best_stable['std']):.2f} | "
            f"final={best_stable['final_eq']:.4f}U | ret={best_stable['ret_pct']:+.2f}% | "
            f"dd={best_stable['max_dd_pct']:.2f}% | sharpe={best_stable['sharpe']:.3f} | "
            f"trades={int(best_stable['n_trades'])}"
        )
        print(
            f"最高收益参数: period={int(best_return['period'])}, std={float(best_return['std']):.2f} | "
            f"final={best_return['final_eq']:.4f}U | ret={best_return['ret_pct']:+.2f}% | "
            f"dd={best_return['max_dd_pct']:.2f}% | sharpe={best_return['sharpe']:.3f} | "
            f"trades={int(best_return['n_trades'])}"
        )
        print(f"结果文件: {all_path}")
        print(f"逐年文件: {yearly_path}")
        print("=" * 96)
    finally:
        # Always clean up the temp pickles, even if a stage raised.
        Path(df_path).unlink(missing_ok=True)
        if df_1m_path:
            Path(df_1m_path).unlink(missing_ok=True)
# Standard script entry guard: run the search only when executed directly.
if __name__ == "__main__":
    main()