Files
codex_jxs_code/run_bb_full_grid_search.py
2026-02-28 13:27:54 +08:00

447 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
布林带中轨策略 - 全参数网格搜索 (2020-2025)
策略逻辑:
- 阳线 + 碰到布林带均线(先涨碰到) → 开多,碰上轨止盈
- 阴线 + 碰到布林带均线(先跌碰到) → 平多开空,碰下轨止盈
- 使用 1m 线判断当前K线是先跌碰均线还是先涨碰均线
- 每根K线只能操作一次
参数范围: period 1~1000, std 0.5~1000按 (0.5,0.5),(0.5,1)...(0.5,1000),(1,0.5),(1,1)...(1000,1000) 顺序遍历
回测设置: 200U本金 | 全仓 | 1%权益/单 | 万五手续费 | 90%返佣次日8点到账 | 100x杠杆
数据来源: 抓取多周期K线.py 抓取并存入 models/database.db 的 bitmart_eth_5m / bitmart_eth_1m
"""
from __future__ import annotations
import argparse
import hashlib
import json
import math
import os
import tempfile
import time
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
import numpy as np
import pandas as pd
from strategy.bb_midline_backtest import BBMidlineConfig, run_bb_midline_backtest
from strategy.data_loader import get_1m_touch_direction, load_klines
from strategy.indicators import bollinger
def frange(start: float, end: float, step: float) -> list[float]:
    """Return floats from ``start`` to ``end`` (inclusive) in ``step`` increments."""
    values: list[float] = []
    current = float(start)
    # Small epsilon keeps the endpoint included despite float accumulation drift.
    while current <= end + 1e-9:
        values.append(round(current, 6))
        current += step
    return values
def build_full_grid(
    p_start: float = 1,
    p_end: float = 1000,
    p_step: float = 1,
    s_start: float = 0.5,
    s_end: float = 1000,
    s_step: float = 0.5,
) -> list[tuple[int, float]]:
    """Build the full (period, std) grid, ordered (0.5,0.5),(0.5,1)...(0.5,1000),(1,0.5)..."""
    # Clamp periods into [1, 1000] as ints and stds into [0.5, 1000.0] at 2-decimal
    # precision; sets drop any duplicates produced by clamping.
    period_values = {max(1, min(1000, int(round(v)))) for v in frange(p_start, p_end, p_step)}
    std_values = {round(max(0.5, min(1000.0, v)), 2) for v in frange(s_start, s_end, s_step)}
    combos = {(p, s) for p in period_values for s in std_values}
    return sorted(combos)
def score_stable(ret_pct: float, sharpe: float, dd_pct: float, n_trades: int) -> float:
    """Stability score: reward return and Sharpe, penalize drawdown and sparse trading."""
    # Fewer than 200 trades is too thin a sample to trust — flat -5 penalty.
    penalty = 0.0 if n_trades >= 200 else -5.0
    return ret_pct + 12.0 * sharpe - 0.8 * dd_pct + penalty
def _checkpoint_meta(
period: str,
start: str,
end: str,
p_step: float,
s_step: float,
sample: bool,
focus: bool,
fine: bool,
) -> dict:
return {
"period": period,
"start": start,
"end": end,
"p_step": p_step,
"s_step": s_step,
"sample": sample,
"focus": focus,
"fine": fine,
}
def _checkpoint_path(out_dir: Path, meta: dict) -> tuple[Path, Path]:
"""返回 checkpoint 数据文件和 meta 文件路径"""
h = hashlib.md5(json.dumps(meta, sort_keys=True).encode()).hexdigest()[:12]
return (
out_dir / f"bb_full_grid_{meta['period']}_resume_{h}.csv",
out_dir / f"bb_full_grid_{meta['period']}_resume_{h}.meta.json",
)
def load_checkpoint(
    ckpt_path: Path,
    meta_path: Path,
    meta: dict,
) -> tuple[pd.DataFrame, set[tuple[int, float]]]:
    """
    Load checkpoint data.

    If both files exist and the saved meta equals ``meta``, return
    (completed-results DataFrame, set of completed (period, std) pairs).
    Otherwise return (empty DataFrame, empty set).
    """
    if not ckpt_path.exists() or not meta_path.exists():
        return pd.DataFrame(), set()
    try:
        with open(meta_path, "r", encoding="utf-8") as f:
            saved = json.load(f)
        # A checkpoint from a different grid/date-range must not be resumed here.
        if saved != meta:
            return pd.DataFrame(), set()
    except (json.JSONDecodeError, OSError):
        return pd.DataFrame(), set()
    try:
        df = pd.read_csv(ckpt_path)
        if "period" not in df.columns or "std" not in df.columns:
            return pd.DataFrame(), set()
        # Vectorized column access instead of per-row iterrows() — far faster
        # on large checkpoints. Round std to match the grid's key precision.
        done = {
            (int(p), round(float(s), 2))
            for p, s in zip(df["period"].tolist(), df["std"].tolist())
        }
        return df, done
    except Exception:
        # Deliberate best-effort: a corrupt checkpoint just means starting fresh.
        return pd.DataFrame(), set()
def save_checkpoint(ckpt_path: Path, meta_path: Path, meta: dict, rows: list[dict]) -> None:
    """Persist the checkpoint, overwriting any previous snapshot.

    The CSV is written to a temp file in the target directory and then
    atomically renamed into place with ``os.replace``, so a crash mid-save
    can never leave a truncated/corrupt checkpoint behind (the previous
    snapshot survives intact).
    """
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=2)
    df = pd.DataFrame(rows)
    # Temp file must live on the same filesystem as the target for the rename
    # to be atomic, hence dir=ckpt_path.parent.
    fd, tmp_name = tempfile.mkstemp(suffix=".csv.tmp", dir=str(ckpt_path.parent))
    try:
        os.close(fd)
        df.to_csv(tmp_name, index=False)
        os.replace(tmp_name, ckpt_path)
    except BaseException:
        Path(tmp_name).unlink(missing_ok=True)
        raise
def _init_worker(df_path: str, df_1m_path: str | None, use_1m: bool, step_min: int):
global G_DF, G_DF_1M, G_USE_1M, G_STEP_MIN
G_DF = pd.read_pickle(df_path)
G_DF_1M = pd.read_pickle(df_1m_path) if (use_1m and df_1m_path) else None
G_USE_1M = bool(use_1m)
G_STEP_MIN = int(step_min)
def _eval_period_task(args: tuple[int, list[float]]) -> list[dict]:
    """Backtest one Bollinger period against a list of std multipliers.

    Runs inside a worker process and reads the kline data that
    ``_init_worker`` loaded into module globals. Returns one metrics dict
    per (period, std) combination.
    """
    period, std_list = args
    assert G_DF is not None  # _init_worker must have run in this process
    # The 1m touch-direction array depends only on the midline (i.e. the
    # period), not on std — compute it once and reuse it for every std.
    arr_touch_dir = None
    if G_USE_1M and G_DF_1M is not None:
        close = G_DF["close"].astype(float)
        bb_mid, _, _, _ = bollinger(close, period, 1.0)
        arr_touch_dir = get_1m_touch_direction(G_DF, G_DF_1M, bb_mid.values, kline_step_min=G_STEP_MIN)
    rows: list[dict] = []
    for std in std_list:
        # Fixed backtest settings per the module docstring: 200U capital,
        # cross margin, 1% equity per trade, 100x leverage, 0.05% fee,
        # 90% rebate credited at hour 0 UTC.
        cfg = BBMidlineConfig(
            bb_period=period,
            bb_std=float(std),
            initial_capital=200.0,
            margin_pct=0.01,
            leverage=100.0,
            cross_margin=True,
            fee_rate=0.0005,
            rebate_pct=0.90,
            rebate_hour_utc=0,
            fill_at_close=True,
            use_1m_touch_filter=G_USE_1M,
            kline_step_min=G_STEP_MIN,
        )
        result = run_bb_midline_backtest(
            G_DF,
            cfg,
            df_1m=G_DF_1M if G_USE_1M else None,
            arr_touch_dir_override=arr_touch_dir,
        )
        eq = result.equity_curve["equity"].dropna()
        if len(eq) == 0:
            # No equity points at all: score as a total loss via sentinels.
            final_eq = 0.0
            ret_pct = -100.0
            dd_u = -200.0
            dd_pct = 100.0
        else:
            final_eq = float(eq.iloc[-1])
            ret_pct = (final_eq - cfg.initial_capital) / cfg.initial_capital * 100.0
            # Max drawdown in USD: most negative gap below the running peak.
            dd_u = float((eq.astype(float) - eq.astype(float).cummax()).min())
            dd_pct = abs(dd_u) / cfg.initial_capital * 100.0
        n_trades = len(result.trades)
        win_rate = (
            sum(1 for t in result.trades if t.net_pnl > 0) / n_trades * 100.0
            if n_trades > 0
            else 0.0
        )
        pnl = result.daily_stats["pnl"].astype(float)
        # Annualized Sharpe from daily PnL (365 days — crypto trades daily);
        # zero when PnL has no variance to avoid division by zero.
        sharpe = float(pnl.mean() / pnl.std()) * math.sqrt(365.0) if pnl.std() > 0 else 0.0
        stable_score = score_stable(ret_pct, sharpe, dd_pct, n_trades)
        rows.append({
            "period": period,
            "std": round(float(std), 2),
            "final_eq": final_eq,
            "ret_pct": ret_pct,
            "n_trades": n_trades,
            "win_rate": win_rate,
            "sharpe": sharpe,
            "max_dd_u": dd_u,
            "max_dd_pct": dd_pct,
            "stable_score": stable_score,
        })
    return rows
def _print_result_rows(res: list[dict]) -> None:
    """Print one formatted result line per (period, std) combination."""
    for row in res:
        print(f"✓ period={int(row['period']):4d}, std={float(row['std']):7.2f} | "
              f"收益: {row['ret_pct']:+7.2f}% | 回撤: {row['max_dd_pct']:6.2f}% | "
              f"夏普: {row['sharpe']:7.3f} | 交易: {int(row['n_trades']):6d} | "
              f"评分: {row['stable_score']:7.1f}")


def _print_progress(done_combos: int, total_combos: int, done_periods: int,
                    total_periods: int, start: float) -> None:
    """Print progress roughly every 5% of period groups, and at the very end."""
    if done_periods % max(1, total_periods // 20) == 0 or done_periods == total_periods:
        elapsed = time.time() - start
        print(f"进度 {done_combos}/{total_combos} ({done_periods}/{total_periods} periods), 用时 {elapsed:.0f}s")


def run_grid_search(
    params: list[tuple[int, float]],
    *,
    workers: int,
    df_path: str,
    df_1m_path: str | None,
    use_1m: bool,
    step_min: int,
    existing_rows: list[dict] | None = None,
    ckpt_path: Path | None = None,
    meta_path: Path | None = None,
    meta: dict | None = None,
    checkpoint_interval: int = 5,
) -> pd.DataFrame:
    """Backtest every (period, std) combination and return a results DataFrame.

    Tasks are grouped by period (one task per period, all its stds) so each
    worker computes the shared 1m touch-direction array only once per period.
    Results are checkpointed every ``checkpoint_interval`` completed period
    groups when ``ckpt_path``/``meta_path``/``meta`` are given. Falls back to
    single-process execution when the process pool is unavailable.
    """
    by_period: dict[int, set[float]] = defaultdict(set)
    for p, s in params:
        by_period[int(p)].add(round(float(s), 2))
    tasks = [(p, sorted(stds)) for p, stds in sorted(by_period.items())]
    total_periods = len(tasks)
    total_combos = sum(len(stds) for _, stds in tasks)
    # Seed with checkpointed rows so resumed runs save complete checkpoints.
    rows: list[dict] = list(existing_rows) if existing_rows else []
    print(f"待运行: {total_combos} 组合 ({total_periods} period组), workers={workers}" + (
        f", 断点续跑 (已有 {len(rows)} 条)" if rows else ""
    ))
    start = time.time()
    last_save = 0

    def maybe_save(n_done_periods: int):
        # Persist a checkpoint once `checkpoint_interval` more groups finished.
        nonlocal last_save
        if ckpt_path and meta_path and meta and n_done_periods > 0:
            if n_done_periods - last_save >= checkpoint_interval:
                save_checkpoint(ckpt_path, meta_path, meta, rows)
                last_save = n_done_periods

    def run_serial() -> None:
        # Single-process path: evaluate the tasks in order in this process.
        # (Callers must have run _init_worker first.)
        done_periods = 0
        done_combos = 0
        for task in tasks:
            res = _eval_period_task(task)
            _print_result_rows(res)
            rows.extend(res)
            done_periods += 1
            done_combos += len(task[1])
            maybe_save(done_periods)
            _print_progress(done_combos, total_combos, done_periods, total_periods, start)

    if workers <= 1:
        _init_worker(df_path, df_1m_path, use_1m, step_min)
        run_serial()
    else:
        done_periods = 0
        done_combos = 0
        try:
            with ProcessPoolExecutor(
                max_workers=workers,
                initializer=_init_worker,
                initargs=(df_path, df_1m_path, use_1m, step_min),
            ) as ex:
                future_map = {ex.submit(_eval_period_task, task): task for task in tasks}
                for fut in as_completed(future_map):
                    _, stds = future_map[fut]
                    res = fut.result()
                    _print_result_rows(res)
                    rows.extend(res)
                    done_periods += 1
                    done_combos += len(stds)
                    maybe_save(done_periods)
                    _print_progress(done_combos, total_combos, done_periods, total_periods, start)
        except (PermissionError, OSError) as e:
            # Sandboxed environments can forbid forking: redo everything
            # serially. Any duplicate rows from partially-completed pool work
            # are dropped later by the caller's drop_duplicates.
            print(f"多进程不可用 ({e}),改用单进程...")
            _init_worker(df_path, df_1m_path, use_1m, step_min)
            run_serial()
    if ckpt_path and meta_path and meta and rows:
        save_checkpoint(ckpt_path, meta_path, meta, rows)
    df = pd.DataFrame(rows)
    print(f"完成, 总用时 {time.time() - start:.1f}s")
    return df
def main():
    """CLI entry point: parse args, build the grid, run the search, report best params."""
    parser = argparse.ArgumentParser(description="布林带全参数网格搜索 (1-1000, 0.5-1000)")
    parser.add_argument("-p", "--period", default="5m", choices=["5m", "15m", "30m"])
    parser.add_argument("--start", default="2020-01-01")
    parser.add_argument("--end", default="2026-01-01")
    parser.add_argument("-j", "--workers", type=int, default=max(1, (os.cpu_count() or 4) - 1))
    parser.add_argument("--no-1m", action="store_true", help="禁用 1m 触及方向过滤")
    parser.add_argument("--p-step", type=float, default=5, help="period 步长 (默认5, 全量用1)")
    parser.add_argument("--s-step", type=float, default=5, help="std 步长 (默认5, 全量用0.5或1)")
    parser.add_argument("--quick", action="store_true", help="快速模式: p-step=20, s-step=20")
    parser.add_argument("--sample", action="store_true", help="采样模式: 仅用2022-2024两年加速")
    parser.add_argument("--focus", action="store_true", help="聚焦模式: 仅在period 50-400, std 100-800 细搜")
    parser.add_argument("--fine", action="store_true", help="精细模式: 在period 280-310, std 450-550 细搜")
    parser.add_argument("--no-resume", action="store_true", help="禁用断点续跑,重新开始")
    parser.add_argument("--checkpoint-interval", type=int, default=10,
                        help="每完成 N 个 period 组保存一次断点 (默认 10)")
    args = parser.parse_args()
    use_1m = not args.no_1m
    # Kline step in minutes, derived from the period string ("5m" -> 5).
    step_min = int(args.period.replace("m", ""))
    if args.sample:
        args.start, args.end = "2022-01-01", "2024-01-01"
        print("采样模式: 使用 2022-2024 数据加速")
    if args.quick:
        p_step, s_step = 20.0, 20.0
    else:
        p_step, s_step = args.p_step, args.s_step
    if args.fine:
        params = build_full_grid(p_start=280, p_end=300, p_step=2, s_start=480, s_end=510, s_step=2)
        print("精细模式: period 280-300 step=2, std 480-510 step=2 (~176组合)")
    elif args.focus:
        params = build_full_grid(p_start=50, p_end=400, p_step=25, s_start=100, s_end=800, s_step=50)
        print("聚焦模式: period 50-400 step=25, std 100-800 step=50 (~225组合)")
    else:
        params = build_full_grid(p_step=p_step, s_step=s_step)
        # BUGFIX: the original f-string ran s_step and len(params) together
        # with no separator ("step=5.0200 组合"), garbling the output —
        # presumably a stripped "ambiguous Unicode" separator character.
        print(f"网格参数: period 1-1000 step={p_step}, std 0.5-1000 step={s_step}, {len(params)} 组合")
    out_dir = Path(__file__).resolve().parent / "strategy" / "results"
    out_dir.mkdir(parents=True, exist_ok=True)
    # Checkpoint identity = run settings; a different grid gets its own files.
    meta = _checkpoint_meta(
        args.period, args.start, args.end, p_step, s_step,
        args.sample, args.focus, args.fine,
    )
    ckpt_path, meta_path = _checkpoint_path(out_dir, meta)
    existing_rows: list[dict] = []
    params_to_run = params
    if not args.no_resume:
        ckpt_df, done_set = load_checkpoint(ckpt_path, meta_path, meta)
        if len(done_set) > 0:
            existing_rows = ckpt_df.to_dict("records")
            params_to_run = [(p, s) for p, s in params if (int(p), round(float(s), 2)) not in done_set]
            print(f"断点续跑: 已完成 {len(done_set)} 组合,剩余 {len(params_to_run)} 组合")
    if not params_to_run:
        print("无待运行组合,直接使用断点结果")
        all_df = pd.DataFrame(existing_rows)
    else:
        print(f"\n加载数据: {args.period} + 1m, {args.start} ~ {args.end}")
        t0 = time.time()
        df = load_klines(args.period, args.start, args.end)
        df_1m = load_klines("1m", args.start, args.end) if use_1m else None
        print(f" {args.period}: {len(df):,}" + (f", 1m: {len(df_1m):,}" if df_1m is not None else "") + f", {time.time()-t0:.1f}s\n")
        # Spill the dataframes to disk once; each worker loads them in its
        # initializer instead of pickling them over the pipe per task.
        with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as f_df:
            df.to_pickle(f_df.name)
            df_path = f_df.name
        df_1m_path = None
        if df_1m is not None:
            with tempfile.NamedTemporaryFile(suffix=".pkl", delete=False) as f_1m:
                df_1m.to_pickle(f_1m.name)
                df_1m_path = f_1m.name
        try:
            all_df = run_grid_search(
                params_to_run,
                workers=args.workers,
                df_path=df_path,
                df_1m_path=df_1m_path,
                use_1m=use_1m,
                step_min=step_min,
                existing_rows=existing_rows,
                ckpt_path=ckpt_path,
                meta_path=meta_path,
                meta=meta,
                checkpoint_interval=args.checkpoint_interval,
            )
        finally:
            # Always clean up the temp pickles, even if the search raised.
            Path(df_path).unlink(missing_ok=True)
            if df_1m_path:
                Path(df_1m_path).unlink(missing_ok=True)
    # Fallback reruns can reproduce combos; keep the latest row per combo.
    all_df = all_df.drop_duplicates(subset=["period", "std"], keep="last")
    best_stable = all_df.sort_values("stable_score", ascending=False).iloc[0]
    best_return = all_df.sort_values("ret_pct", ascending=False).iloc[0]
    stamp = time.strftime("%Y%m%d_%H%M%S")
    out_path = out_dir / f"bb_full_grid_{args.period}_{stamp}.csv"
    all_df.sort_values("stable_score", ascending=False).to_csv(out_path, index=False)
    print("\n" + "=" * 80)
    print("全参数网格搜索完成")
    print(f"最佳稳定参数: period={int(best_stable['period'])}, std={float(best_stable['std']):.2f}")
    print(f" 最终权益: {best_stable['final_eq']:.2f} U | 收益: {best_stable['ret_pct']:+.2f}%")
    print(f" 最大回撤: {best_stable['max_dd_pct']:.2f}% | 夏普: {best_stable['sharpe']:.3f} | 交易数: {int(best_stable['n_trades'])}")
    print()
    print(f"最高收益参数: period={int(best_return['period'])}, std={float(best_return['std']):.2f}")
    print(f" 最终权益: {best_return['final_eq']:.2f} U | 收益: {best_return['ret_pct']:+.2f}%")
    print(f" 最大回撤: {best_return['max_dd_pct']:.2f}% | 夏普: {best_return['sharpe']:.3f} | 交易数: {int(best_return['n_trades'])}")
    print(f"\n结果已保存: {out_path}")
    print(f"断点文件: {ckpt_path.name} (可用 --no-resume 重新开始)")
    print("=" * 80)
# Script entry point: run the grid search only when executed directly,
# so importing this module has no side effects.
if __name__ == "__main__":
    main()