import argparse import csv import itertools import json from dataclasses import asdict, dataclass from pathlib import Path @dataclass class Bar: ts: int open: float high: float low: float close: float @dataclass class StrategyParams: min_prev_entity_pct: float = 0.1 breakout_buffer_pct: float = 0.03 shadow_threshold_pct: float = 0.15 stop_loss_pct: float = 0.35 take_profit_pct: float = 0.8 trailing_start_pct: float = 0.5 trailing_backoff_pct: float = 0.25 use_atr_dynamic_threshold: bool = True atr_length: int = 14 breakout_buffer_atr_mult: float = 0.12 shadow_threshold_atr_mult: float = 0.18 stop_loss_atr_mult: float = 0.45 take_profit_atr_mult: float = 0.95 trailing_start_atr_mult: float = 0.6 trailing_backoff_atr_mult: float = 0.3 @dataclass class BacktestResult: score: float total_return_pct: float max_drawdown_pct: float win_rate_pct: float profit_factor: float trades: int wins: int params: StrategyParams @dataclass class RobustEvalResult: robust_score: float consistency_score: float overfit_gap: float train_risk_adj: float valid_risk_adj: float train: BacktestResult valid: BacktestResult full: BacktestResult params: StrategyParams def load_csv_bars(path: Path) -> list[Bar]: bars: list[Bar] = [] with path.open("r", encoding="utf-8") as f: reader = csv.DictReader(f) required = {"id", "open", "high", "low", "close"} if not required.issubset(reader.fieldnames or set()): raise ValueError(f"CSV missing columns: {required}") for row in reader: try: bars.append( Bar( ts=int(float(row["id"])), open=float(row["open"]), high=float(row["high"]), low=float(row["low"]), close=float(row["close"]), ) ) except (ValueError, TypeError): continue bars.sort(key=lambda x: x.ts) return bars def resample_to_minutes(bars: list[Bar], minutes: int) -> list[Bar]: if not bars: return [] bucket_sec = minutes * 60 grouped: list[Bar] = [] cur_bucket = None cur_open = cur_high = cur_low = cur_close = None for bar in bars: b = bar.ts // bucket_sec if cur_bucket is None or b != cur_bucket: if cur_bucket is not None: grouped.append( Bar( ts=cur_bucket * bucket_sec, open=cur_open, high=cur_high, low=cur_low, close=cur_close, ) ) cur_bucket = b cur_open = bar.open cur_high = bar.high cur_low = bar.low cur_close = bar.close else: cur_high = max(cur_high, bar.high) cur_low = min(cur_low, bar.low) cur_close = bar.close if cur_bucket is not None: grouped.append( Bar( ts=cur_bucket * bucket_sec, open=cur_open, high=cur_high, low=cur_low, close=cur_close, ) ) return grouped def compute_atr_series(bars: list[Bar], length: int) -> list[float | None]: atr: list[float | None] = [None] * len(bars) if len(bars) < length + 1: return atr tr_list: list[float] = [0.0] * len(bars) for i in range(1, len(bars)): high = bars[i].high low = bars[i].low prev_close = bars[i - 1].close tr = max(high - low, abs(high - prev_close), abs(low - prev_close)) tr_list[i] = tr for i in range(length, len(bars)): window = tr_list[i - length + 1:i + 1] atr[i] = sum(window) / length return atr def resolve_dynamic_distance( base_price: float, fixed_pct: float, atr_value: float | None, atr_mult: float, use_atr_dynamic_threshold: bool, ) -> float: fixed_distance = base_price * fixed_pct / 100 if use_atr_dynamic_threshold and atr_value and atr_value > 0: return max(fixed_distance, atr_value * atr_mult) return fixed_distance def kline_entity_abs(bar: Bar) -> float: return abs(bar.close - bar.open) def kline_entity_edges(bar: Bar) -> tuple[float, float]: return max(bar.open, bar.close), min(bar.open, bar.close) def upper_shadow_abs(bar: Bar) -> float: return max(0.0, bar.high - max(bar.open, bar.close)) def lower_shadow_abs(bar: Bar) -> float: return max(0.0, min(bar.open, bar.close) - bar.low) def close_position( equity: float, side: int, entry_price: float, exit_price: float, fee_rate: float, ) -> tuple[float, float]: net_ret = side * (exit_price - entry_price) / entry_price - 2 * fee_rate equity *= (1 + net_ret) return equity, net_ret def backtest_strategy( bars: list[Bar], params: StrategyParams, fee_rate: float = 0.0004, min_trades: int = 20, ) -> BacktestResult: atr_series = compute_atr_series(bars, params.atr_length) position = 0 entry_price = None max_favorable_price = None min_favorable_price = None equity = 1.0 peak_equity = 1.0 max_drawdown = 0.0 trades = 0 wins = 0 gross_profit = 0.0 gross_loss = 0.0 for i in range(1, len(bars)): current = bars[i] current_price = current.close atr_value = atr_series[i - 1] prev_idx = None for j in range(i - 1, -1, -1): prev_bar = bars[j] entity = kline_entity_abs(prev_bar) entity_pct = (entity / prev_bar.open * 100) if prev_bar.open else 0 if entity_pct > params.min_prev_entity_pct: prev_idx = j break if prev_idx is None: continue prev = bars[prev_idx] prev_entity = kline_entity_abs(prev) prev_entity_upper, prev_entity_lower = kline_entity_edges(prev) prev_is_bullish_for_calc = prev.close > prev.open prev_is_bearish_for_calc = prev.close < prev.open current_open_above_prev_close = current.open > prev.close current_open_below_prev_close = current.open < prev.close use_current_open_as_base = ( (prev_is_bullish_for_calc and current_open_above_prev_close) or (prev_is_bearish_for_calc and current_open_below_prev_close) ) if use_current_open_as_base: calc_lower = current.open calc_upper = current.open long_trigger = calc_lower + prev_entity / 3 short_trigger = calc_upper - prev_entity / 3 long_breakout = calc_upper + prev_entity / 3 short_breakout = calc_lower - prev_entity / 3 else: long_trigger = prev_entity_lower + prev_entity / 3 short_trigger = prev_entity_upper - prev_entity / 3 long_breakout = prev_entity_upper + prev_entity / 3 short_breakout = prev_entity_lower - prev_entity / 3 breakout_buffer = max( prev_entity * 0.1, current_price * params.breakout_buffer_pct / 100, (atr_value * params.breakout_buffer_atr_mult) if (params.use_atr_dynamic_threshold and atr_value and atr_value > 0) else 0, ) long_breakout_effective = long_breakout + breakout_buffer short_breakout_effective = short_breakout - breakout_buffer prev_is_bearish = prev.close < prev.open current_is_bullish = current.close > current.open skip_short_by_upper_third = prev_is_bearish and current_is_bullish prev_is_bullish = prev.close > prev.open current_is_bearish = current.close < current.open skip_long_by_lower_third = prev_is_bullish and current_is_bearish if position != 0 and entry_price is not None: sl_distance = resolve_dynamic_distance( base_price=entry_price, fixed_pct=params.stop_loss_pct, atr_value=atr_value, atr_mult=params.stop_loss_atr_mult, use_atr_dynamic_threshold=params.use_atr_dynamic_threshold, ) tp_distance = resolve_dynamic_distance( base_price=entry_price, fixed_pct=params.take_profit_pct, atr_value=atr_value, atr_mult=params.take_profit_atr_mult, use_atr_dynamic_threshold=params.use_atr_dynamic_threshold, ) trail_start_distance = resolve_dynamic_distance( base_price=entry_price, fixed_pct=params.trailing_start_pct, atr_value=atr_value, atr_mult=params.trailing_start_atr_mult, use_atr_dynamic_threshold=params.use_atr_dynamic_threshold, ) trail_backoff_distance = resolve_dynamic_distance( base_price=entry_price, fixed_pct=params.trailing_backoff_pct, atr_value=atr_value, atr_mult=params.trailing_backoff_atr_mult, use_atr_dynamic_threshold=params.use_atr_dynamic_threshold, ) should_close = False if position == 1: max_favorable_price = max(max_favorable_price or entry_price, current_price) profit_distance = current_price - entry_price loss_distance = entry_price - current_price if loss_distance >= sl_distance: should_close = True elif profit_distance >= tp_distance: should_close = True elif ( profit_distance >= trail_start_distance and (max_favorable_price - current_price) >= trail_backoff_distance ): should_close = True else: min_favorable_price = min(min_favorable_price or entry_price, current_price) profit_distance = entry_price - current_price loss_distance = current_price - entry_price if loss_distance >= sl_distance: should_close = True elif profit_distance >= tp_distance: should_close = True elif ( profit_distance >= trail_start_distance and (current_price - min_favorable_price) >= trail_backoff_distance ): should_close = True if should_close: equity, net_ret = close_position( equity=equity, side=position, entry_price=entry_price, exit_price=current_price, fee_rate=fee_rate, ) trades += 1 if net_ret > 0: wins += 1 gross_profit += net_ret else: gross_loss += net_ret position = 0 entry_price = None max_favorable_price = None min_favorable_price = None peak_equity = max(peak_equity, equity) max_drawdown = max(max_drawdown, (peak_equity - equity) / peak_equity) continue signal = None if position == 0: if current_price >= long_breakout_effective and not skip_long_by_lower_third: signal = "long" elif current_price <= short_breakout_effective and not skip_short_by_upper_third: signal = "short" elif position == 1: if current_price <= short_trigger and not skip_short_by_upper_third: signal = "reverse_short" else: upper_abs = upper_shadow_abs(prev) upper_thr = resolve_dynamic_distance( base_price=max(prev.open, prev.close), fixed_pct=params.shadow_threshold_pct, atr_value=atr_value, atr_mult=params.shadow_threshold_atr_mult, use_atr_dynamic_threshold=params.use_atr_dynamic_threshold, ) if upper_abs > upper_thr and current_price <= prev_entity_lower: signal = "reverse_short" elif position == -1: if current_price >= long_trigger and not skip_long_by_lower_third: signal = "reverse_long" else: lower_abs = lower_shadow_abs(prev) lower_thr = resolve_dynamic_distance( base_price=min(prev.open, prev.close), fixed_pct=params.shadow_threshold_pct, atr_value=atr_value, atr_mult=params.shadow_threshold_atr_mult, use_atr_dynamic_threshold=params.use_atr_dynamic_threshold, ) if lower_abs > lower_thr and current_price >= prev_entity_upper: signal = "reverse_long" if signal == "long": position = 1 entry_price = current_price max_favorable_price = current_price min_favorable_price = None elif signal == "short": position = -1 entry_price = current_price min_favorable_price = current_price max_favorable_price = None elif signal == "reverse_long": equity, net_ret = close_position( equity=equity, side=position, entry_price=entry_price, exit_price=current_price, fee_rate=fee_rate, ) trades += 1 if net_ret > 0: wins += 1 gross_profit += net_ret else: gross_loss += net_ret position = 1 entry_price = current_price max_favorable_price = current_price min_favorable_price = None peak_equity = max(peak_equity, equity) max_drawdown = max(max_drawdown, (peak_equity - equity) / peak_equity) elif signal == "reverse_short": equity, net_ret = close_position( equity=equity, side=position, entry_price=entry_price, exit_price=current_price, fee_rate=fee_rate, ) trades += 1 if net_ret > 0: wins += 1 gross_profit += net_ret else: gross_loss += net_ret position = -1 entry_price = current_price min_favorable_price = current_price max_favorable_price = None peak_equity = max(peak_equity, equity) max_drawdown = max(max_drawdown, (peak_equity - equity) / peak_equity) if position != 0 and entry_price is not None: equity, net_ret = close_position( equity=equity, side=position, entry_price=entry_price, exit_price=bars[-1].close, fee_rate=fee_rate, ) trades += 1 if net_ret > 0: wins += 1 gross_profit += net_ret else: gross_loss += net_ret peak_equity = max(peak_equity, equity) max_drawdown = max(max_drawdown, (peak_equity - equity) / peak_equity) total_return_pct = (equity - 1) * 100 win_rate_pct = (wins / trades * 100) if trades else 0.0 loss_abs = abs(gross_loss) profit_factor = (gross_profit / loss_abs) if loss_abs > 1e-12 else 999.0 score = total_return_pct - max_drawdown * 100 * 0.5 if trades < min_trades: score -= (min_trades - trades) * 0.8 return BacktestResult( score=score, total_return_pct=total_return_pct, max_drawdown_pct=max_drawdown * 100, win_rate_pct=win_rate_pct, profit_factor=profit_factor, trades=trades, wins=wins, params=params, ) def split_train_valid( bars: list[Bar], train_ratio: float = 0.7, gap_bars: int = 0, ) -> tuple[list[Bar], list[Bar]]: """ 时间序列分离:前段训练、后段验证。 gap_bars 用于在训练与验证之间留空,降低相邻样本泄漏。 """ if not bars: return [], [] train_ratio = min(max(train_ratio, 0.5), 0.9) split_idx = int(len(bars) * train_ratio) split_idx = max(1, min(split_idx, len(bars) - 1)) valid_start = min(len(bars), split_idx + max(0, gap_bars)) train_bars = bars[:split_idx] valid_bars = bars[valid_start:] return train_bars, valid_bars def risk_adjusted_return(result: BacktestResult) -> float: """简单风险调整收益:收益 - 0.6 * 回撤。""" return result.total_return_pct - 0.6 * result.max_drawdown_pct def compute_robust_score( train_result: BacktestResult, valid_result: BacktestResult, min_train_trades: int = 20, min_valid_trades: int = 10, ) -> tuple[float, float, float, float, float]: """ 稳健性分数(越高越稳健): - 以验证集风险调整收益为主 - 奖励训练/验证一致性 - 惩罚过拟合(训练好、验证差) - 惩罚验证成交次数过少 """ train_ra = risk_adjusted_return(train_result) valid_ra = risk_adjusted_return(valid_result) overfit_gap = abs(train_ra - valid_ra) denom = abs(train_ra) + abs(valid_ra) + 1e-9 consistency = max(0.0, 1.0 - overfit_gap / denom) train_trade_penalty = max(0, min_train_trades - train_result.trades) * 0.4 valid_trade_penalty = max(0, min_valid_trades - valid_result.trades) * 1.2 pf_bonus = min(valid_result.profit_factor, 3.0) * 2.0 win_bonus = max(0.0, (valid_result.win_rate_pct - 50.0) * 0.08) direction_penalty = 0.0 if train_ra > 0 and valid_ra <= 0: direction_penalty += 12.0 if train_result.total_return_pct > 0 and valid_result.total_return_pct < 0: direction_penalty += 8.0 robust_score = ( 0.75 * valid_ra + 0.25 * train_ra + 10.0 * consistency + pf_bonus + win_bonus - 0.2 * overfit_gap - train_trade_penalty - valid_trade_penalty - direction_penalty ) return robust_score, consistency, overfit_gap, train_ra, valid_ra def evaluate_param_set( train_bars: list[Bar], valid_bars: list[Bar], full_bars: list[Bar], params: StrategyParams, fee_rate: float, min_train_trades: int, min_valid_trades: int, ) -> RobustEvalResult: train_result = backtest_strategy( bars=train_bars, params=params, fee_rate=fee_rate, min_trades=0, ) valid_result = backtest_strategy( bars=valid_bars, params=params, fee_rate=fee_rate, min_trades=0, ) full_result = backtest_strategy( bars=full_bars, params=params, fee_rate=fee_rate, min_trades=0, ) robust_score, consistency, overfit_gap, train_ra, valid_ra = compute_robust_score( train_result=train_result, valid_result=valid_result, min_train_trades=min_train_trades, min_valid_trades=min_valid_trades, ) return RobustEvalResult( robust_score=robust_score, consistency_score=consistency, overfit_gap=overfit_gap, train_risk_adj=train_ra, valid_risk_adj=valid_ra, train=train_result, valid=valid_result, full=full_result, params=params, ) def quick_grid() -> dict[str, list[float | int]]: return { "atr_length": [10, 14, 20], "breakout_buffer_atr_mult": [0.08, 0.12], "shadow_threshold_atr_mult": [0.14, 0.2], "stop_loss_atr_mult": [0.35, 0.5], "take_profit_atr_mult": [0.8, 1.0], "trailing_start_atr_mult": [0.45, 0.65], "trailing_backoff_atr_mult": [0.2, 0.3], } def full_grid() -> dict[str, list[float | int]]: return { "atr_length": [10, 14, 20], "breakout_buffer_atr_mult": [0.08, 0.12, 0.16], "shadow_threshold_atr_mult": [0.12, 0.18, 0.24], "stop_loss_atr_mult": [0.35, 0.5, 0.65], "take_profit_atr_mult": [0.8, 1.0, 1.2], "trailing_start_atr_mult": [0.45, 0.65, 0.85], "trailing_backoff_atr_mult": [0.2, 0.3, 0.4], } def iter_param_combos(base: StrategyParams, grid: dict[str, list[float | int]]): keys = list(grid.keys()) values = [grid[k] for k in keys] for combo in itertools.product(*values): data = asdict(base) for k, v in zip(keys, combo): data[k] = v yield StrategyParams(**data) def result_metrics_dict(result: BacktestResult) -> dict[str, float | int]: return { "score": result.score, "total_return_pct": result.total_return_pct, "max_drawdown_pct": result.max_drawdown_pct, "win_rate_pct": result.win_rate_pct, "profit_factor": result.profit_factor, "trades": result.trades, "wins": result.wins, } def save_best_params( best: RobustEvalResult, out_path: Path, train_ratio: float, split_gap_bars: int, ): payload = { "apply_live": False, "selection_metric": "robust_score", "robust_score": best.robust_score, "consistency_score": best.consistency_score, "overfit_gap": best.overfit_gap, "train_risk_adj": best.train_risk_adj, "valid_risk_adj": best.valid_risk_adj, "split": { "train_ratio": train_ratio, "split_gap_bars": split_gap_bars, }, "train_metrics": result_metrics_dict(best.train), "valid_metrics": result_metrics_dict(best.valid), "full_metrics": result_metrics_dict(best.full), "params_for_trade_py": asdict(best.params), } out_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8") def main(): parser = argparse.ArgumentParser(description="ATR dynamic parameter optimizer for bitmart strategy") parser.add_argument( "--csv", default=str(Path(__file__).resolve().parent / "数据" / "kline_1.csv"), help="path to csv with id/open/high/low/close", ) parser.add_argument("--interval-min", type=int, default=5, help="resample interval minutes") parser.add_argument("--mode", choices=["quick", "full"], default="quick", help="grid size") parser.add_argument("--limit-bars", type=int, default=30000, help="use latest N bars after resample") parser.add_argument("--train-ratio", type=float, default=0.7, help="train split ratio in time order") parser.add_argument("--split-gap-bars", type=int, default=0, help="gap bars between train and valid") parser.add_argument("--min-trades", type=int, default=20, help="deprecated: kept for compatibility") parser.add_argument("--min-train-trades", type=int, default=20, help="min train trades for robustness") parser.add_argument("--min-valid-trades", type=int, default=10, help="min valid trades for robustness") parser.add_argument("--fee-rate", type=float, default=0.0004, help="round-trip half fee per side") parser.add_argument("--top-n", type=int, default=10, help="print top N results") parser.add_argument( "--out-json", default=str(Path(__file__).resolve().parent / "atr_best_params.json"), help="best result json output path", ) args = parser.parse_args() csv_path = Path(args.csv).resolve() if not csv_path.exists(): raise FileNotFoundError(f"CSV not found: {csv_path}") bars = load_csv_bars(csv_path) bars = resample_to_minutes(bars, args.interval_min) if args.limit_bars and args.limit_bars > 0: bars = bars[-args.limit_bars:] if len(bars) < 400: raise ValueError("not enough bars for optimization") train_bars, valid_bars = split_train_valid( bars=bars, train_ratio=args.train_ratio, gap_bars=args.split_gap_bars, ) if len(train_bars) < 200 or len(valid_bars) < 100: raise ValueError( f"insufficient split bars: train={len(train_bars)} valid={len(valid_bars)}; " "consider increasing --limit-bars or adjusting --train-ratio" ) base = StrategyParams() grid = quick_grid() if args.mode == "quick" else full_grid() combos = list(iter_param_combos(base, grid)) print( f"bars={len(bars)} train={len(train_bars)} valid={len(valid_bars)} " f"| combos={len(combos)} | mode={args.mode} | train_ratio={args.train_ratio:.2f} gap={args.split_gap_bars}" ) results: list[RobustEvalResult] = [] for idx, params in enumerate(combos, 1): result = evaluate_param_set( train_bars=train_bars, valid_bars=valid_bars, full_bars=bars, params=params, fee_rate=args.fee_rate, min_train_trades=args.min_train_trades, min_valid_trades=args.min_valid_trades, ) results.append(result) if idx % 50 == 0 or idx == len(combos): print(f"progress {idx}/{len(combos)}") results.sort(key=lambda x: x.robust_score, reverse=True) top_n = max(1, args.top_n) top = results[:top_n] for i, r in enumerate(top, 1): p = r.params print( f"[{i}] robust={r.robust_score:.2f} consistency={r.consistency_score:.3f} gap={r.overfit_gap:.2f} | " f"train(ret={r.train.total_return_pct:.2f}% dd={r.train.max_drawdown_pct:.2f}% trades={r.train.trades}) | " f"valid(ret={r.valid.total_return_pct:.2f}% dd={r.valid.max_drawdown_pct:.2f}% trades={r.valid.trades}) | " f"full(ret={r.full.total_return_pct:.2f}% dd={r.full.max_drawdown_pct:.2f}% trades={r.full.trades}) | " f"atr={p.atr_length} brk={p.breakout_buffer_atr_mult:.2f} shadow={p.shadow_threshold_atr_mult:.2f} " f"sl={p.stop_loss_atr_mult:.2f} tp={p.take_profit_atr_mult:.2f} " f"ts={p.trailing_start_atr_mult:.2f} tb={p.trailing_backoff_atr_mult:.2f}" ) out_path = Path(args.out_json).resolve() save_best_params( top[0], out_path, train_ratio=args.train_ratio, split_gap_bars=args.split_gap_bars, ) print(f"saved best params -> {out_path}") print("note: set apply_live=true in json when you want 交易.py to auto-load it") if __name__ == "__main__": main()