Files
lm_code/bitmart/atr_param_optimizer.py

758 lines
26 KiB
Python
Raw Normal View History

2026-02-10 11:31:12 +08:00
import argparse
import csv
import itertools
import json
from dataclasses import asdict, dataclass
from pathlib import Path
@dataclass()
class Bar:
    """A single OHLC candle.

    ``ts`` comes from the CSV ``id`` column; ``resample_to_minutes`` buckets it
    with ``ts // (minutes * 60)``, i.e. treats it as a timestamp in seconds.
    """

    ts: int  # candle timestamp (seconds)
    open: float  # opening price
    high: float  # high of the candle
    low: float  # low of the candle
    close: float  # closing price
@dataclass()
class StrategyParams:
    """Tunable parameters of the breakout / reversal strategy.

    ``*_pct`` fields are percentages of a reference price (``0.35`` means
    0.35 %). ``*_atr_mult`` fields are multiples of the ATR. When
    ``use_atr_dynamic_threshold`` is true and an ATR value is available, the
    matching thresholds are widened to at least ``ATR * mult``.
    """

    # Minimum candle body, in percent of that candle's open, for a bar to be
    # used as the reference ("prev") candle.
    min_prev_entity_pct: float = 0.1
    # Fixed percentage thresholds.
    breakout_buffer_pct: float = 0.03
    shadow_threshold_pct: float = 0.15
    stop_loss_pct: float = 0.35
    take_profit_pct: float = 0.8
    trailing_start_pct: float = 0.5
    trailing_backoff_pct: float = 0.25
    # ATR-based widening of the thresholds above.
    use_atr_dynamic_threshold: bool = True
    atr_length: int = 14  # number of true ranges averaged into one ATR value
    breakout_buffer_atr_mult: float = 0.12
    shadow_threshold_atr_mult: float = 0.18
    stop_loss_atr_mult: float = 0.45
    take_profit_atr_mult: float = 0.95
    trailing_start_atr_mult: float = 0.6
    trailing_backoff_atr_mult: float = 0.3
@dataclass()
class BacktestResult:
    """Metrics produced by one ``backtest_strategy`` run."""

    score: float  # total return % minus half the max drawdown %, less any min-trade penalty
    total_return_pct: float  # compounded equity change, in percent
    max_drawdown_pct: float  # largest peak-to-trough drop of closed-trade equity, in percent
    win_rate_pct: float  # share of trades with a positive net return, in percent
    profit_factor: float  # gross profit / |gross loss|; 999.0 when there were no losses
    trades: int  # number of closed trades
    wins: int  # closed trades with a positive net return
    params: StrategyParams  # the parameters that produced these metrics
@dataclass()
class RobustEvalResult:
    """Train / validation / full-range evaluation of one parameter set."""

    robust_score: float  # ranking metric from compute_robust_score (higher is better)
    consistency_score: float  # 1 - |train_ra - valid_ra| / (|train_ra| + |valid_ra|), floored at 0
    overfit_gap: float  # |train_risk_adj - valid_risk_adj|
    train_risk_adj: float  # risk-adjusted return on the training slice
    valid_risk_adj: float  # risk-adjusted return on the validation slice
    train: BacktestResult  # backtest on the training slice
    valid: BacktestResult  # backtest on the validation slice
    full: BacktestResult  # backtest on all bars passed in as the full range
    params: StrategyParams  # the evaluated parameter set
def load_csv_bars(path: Path) -> list[Bar]:
    """Load OHLC bars from a CSV file and return them sorted by timestamp.

    The header must contain at least ``id`` (timestamp), ``open``, ``high``,
    ``low`` and ``close``. Rows whose values cannot be parsed as numbers are
    skipped silently.

    Args:
        path: CSV file to read (UTF-8, with or without a byte-order mark).

    Returns:
        Parsed bars in ascending ``ts`` order.

    Raises:
        ValueError: if the header lacks any of the required columns; the
            message names the missing ones.
    """
    required = ("id", "open", "high", "low", "close")
    bars: list[Bar] = []
    # "utf-8-sig" strips a leading BOM if present (common in spreadsheet
    # exports); with plain "utf-8" the first header would read "\ufeffid" and
    # the file would be rejected. newline="" is required by the csv module.
    with path.open("r", encoding="utf-8-sig", newline="") as f:
        reader = csv.DictReader(f)
        header = reader.fieldnames or ()
        missing = [col for col in required if col not in header]
        if missing:
            raise ValueError(f"CSV missing columns: {', '.join(missing)}")
        for row in reader:
            try:
                bar = Bar(
                    ts=int(float(row["id"])),
                    open=float(row["open"]),
                    high=float(row["high"]),
                    low=float(row["low"]),
                    close=float(row["close"]),
                )
            except (ValueError, TypeError, OverflowError):
                # ValueError: non-numeric text (or NaN id); TypeError: short row
                # (DictReader fills absent cells with None); OverflowError: an
                # infinite id cannot be converted to int.
                continue
            bars.append(bar)
    bars.sort(key=lambda b: b.ts)
    return bars
def resample_to_minutes(bars: list[Bar], minutes: int) -> list[Bar]:
    """Aggregate bars into fixed ``minutes``-wide OHLC candles.

    Bars are grouped by ``ts // (minutes * 60)`` (``ts`` in seconds). Each
    output candle is stamped with the start of its bucket and takes:
    - the first bar's open,
    - the max high and the min low,
    - the last bar's close.

    Only *consecutive* bars that share a bucket are merged, so ``bars`` should
    be sorted by ``ts`` (``load_csv_bars`` returns them sorted).

    Args:
        bars: Input bars, ascending by ``ts``.
        minutes: Bucket width in minutes; must be positive.

    Returns:
        The resampled bars (empty if ``bars`` is empty).

    Raises:
        ValueError: if ``minutes`` is not positive. Zero would otherwise divide
            by zero, and a negative width silently mis-aligns bucket boundaries.
    """
    if minutes <= 0:
        raise ValueError(f"minutes must be positive, got {minutes}")
    bucket_sec = minutes * 60
    grouped: list[Bar] = []
    for bucket, members in itertools.groupby(bars, key=lambda bar: bar.ts // bucket_sec):
        chunk = list(members)
        grouped.append(
            Bar(
                ts=bucket * bucket_sec,
                open=chunk[0].open,
                high=max(b.high for b in chunk),
                low=min(b.low for b in chunk),
                close=chunk[-1].close,
            )
        )
    return grouped
def compute_atr_series(bars: list[Bar], length: int) -> list[float | None]:
    """Return a simple-moving-average ATR aligned index-for-index with ``bars``.

    The true range of bar ``i >= 1`` is
    ``max(high - low, |high - prev_close|, |low - prev_close|)``.

    ``atr[i]`` is the plain mean of the ``length`` true ranges of bars
    ``i - length + 1 .. i``, so the first defined value sits at index ``length``.
    This is a simple average, not Wilder's smoothing.

    Earlier entries are ``None``. Every entry is ``None`` when there are fewer
    than ``length + 1`` bars.
    """
    size = len(bars)
    result: list[float | None] = [None] * size
    if size < length + 1:
        return result
    # Slot 0 has no previous close; it is a placeholder that no window reaches.
    true_ranges = [0.0]
    for before, bar in zip(bars, bars[1:]):
        true_ranges.append(
            max(
                bar.high - bar.low,
                abs(bar.high - before.close),
                abs(bar.low - before.close),
            )
        )
    for end in range(length, size):
        result[end] = sum(true_ranges[end - length + 1 : end + 1]) / length
    return result
def resolve_dynamic_distance(
base_price: float,
fixed_pct: float,
atr_value: float | None,
atr_mult: float,
use_atr_dynamic_threshold: bool,
) -> float:
fixed_distance = base_price * fixed_pct / 100
if use_atr_dynamic_threshold and atr_value and atr_value > 0:
return max(fixed_distance, atr_value * atr_mult)
return fixed_distance
def kline_entity_abs(bar: Bar) -> float:
    """Return the size of the candle body, ``|close - open|``."""
    body = bar.close - bar.open
    return abs(body)
def kline_entity_edges(bar: Bar) -> tuple[float, float]:
    """Return ``(body_top, body_bottom)``: the higher and lower of open and close."""
    o, c = bar.open, bar.close
    top = c if c > o else o
    bottom = c if c < o else o
    return top, bottom
def upper_shadow_abs(bar: Bar) -> float:
    """Return the upper wick length (high minus body top), never negative."""
    body_top = bar.close if bar.close > bar.open else bar.open
    wick = bar.high - body_top
    return wick if wick > 0.0 else 0.0
def lower_shadow_abs(bar: Bar) -> float:
    """Return the lower wick length (body bottom minus low), never negative."""
    body_bottom = bar.close if bar.close < bar.open else bar.open
    wick = body_bottom - bar.low
    return wick if wick > 0.0 else 0.0
def close_position(
    equity: float,
    side: int,
    entry_price: float,
    exit_price: float,
    fee_rate: float,
) -> tuple[float, float]:
    """Settle one trade and compound its net return into ``equity``.

    Args:
        equity: Equity before the trade.
        side: +1 for a long position, -1 for a short.
        entry_price: Fill price of the entry.
        exit_price: Fill price of the exit.
        fee_rate: Fee per side; charged twice (entry and exit).

    Returns:
        ``(new_equity, net_return)``, where ``net_return`` is the signed price
        move relative to ``entry_price`` minus ``2 * fee_rate``.
    """
    price_move = side * (exit_price - entry_price) / entry_price
    net_ret = price_move - 2 * fee_rate
    return equity * (1 + net_ret), net_ret
def backtest_strategy(
    bars: list[Bar],
    params: StrategyParams,
    fee_rate: float = 0.0004,
    min_trades: int = 20,
) -> BacktestResult:
    """Simulate the breakout / reversal strategy bar by bar on closing prices.

    For each bar ``i`` the reference candle ``prev`` is the most recent earlier
    bar whose body exceeds ``params.min_prev_entity_pct`` percent of its open.
    Entry and reversal levels are offset by one third of ``prev``'s body.
    Thresholds may be widened by the ATR of bar ``i - 1``. Every fill happens
    at ``bars[i].close``, and a position still open after the last bar is
    closed at that bar's close.

    Per bar, in this order:
      1. With an open position: check stop-loss, take-profit, then the trailing
         exit. On exit the bar is finished (no new entry on the same bar).
      2. When flat: enter long/short on a close beyond the buffered breakout
         level, unless the current candle opposes ``prev``.
      3. When in a position: reverse on the opposite trigger level, or when
         ``prev`` has a long wick against the position and the close is
         through ``prev``'s body.

    Equity compounds per trade via ``close_position``. Drawdown is measured on
    realized (closed-trade) equity only.

    Args:
        bars: Time-ordered bars.
        params: Strategy parameters.
        fee_rate: Fee per side, charged twice per round trip.
        min_trades: ``score`` is reduced by 0.8 for every trade below this count.

    Returns:
        Aggregated metrics of the run.
    """
    atr_series = compute_atr_series(bars, params.atr_length)
    use_atr = params.use_atr_dynamic_threshold

    position = 0  # +1 long, -1 short, 0 flat
    entry_price = None
    max_favorable_price = None  # best close seen while long
    min_favorable_price = None  # best close seen while short

    equity = 1.0
    peak_equity = 1.0
    max_drawdown = 0.0
    trades = 0
    wins = 0
    gross_profit = 0.0
    gross_loss = 0.0

    def settle(exit_price: float) -> None:
        """Close the open position at ``exit_price`` and fold the trade into the stats.

        Reads the current ``position`` / ``entry_price``; the caller resets or
        replaces them afterwards.
        """
        nonlocal equity, peak_equity, max_drawdown, trades, wins, gross_profit, gross_loss
        equity, net_ret = close_position(
            equity=equity,
            side=position,
            entry_price=entry_price,
            exit_price=exit_price,
            fee_rate=fee_rate,
        )
        trades += 1
        if net_ret > 0:
            wins += 1
            gross_profit += net_ret
        else:
            gross_loss += net_ret
        peak_equity = max(peak_equity, equity)
        max_drawdown = max(max_drawdown, (peak_equity - equity) / peak_equity)

    ref_idx = None  # most recent bar index (< i) whose body passes the size filter

    for i in range(1, len(bars)):
        # Bar i-1 is the only bar not yet examined by an earlier iteration.
        # Checking it alone keeps ref_idx equal to "latest qualifying bar
        # before i" without rescanning history, which would be O(n) per bar
        # through long runs of small-bodied bars.
        candidate = bars[i - 1]
        candidate_entity = kline_entity_abs(candidate)
        candidate_pct = (candidate_entity / candidate.open * 100) if candidate.open else 0
        if candidate_pct > params.min_prev_entity_pct:
            ref_idx = i - 1
        if ref_idx is None:
            continue

        current = bars[i]
        current_price = current.close
        # ATR of the previous bar, so bar i's own range does not feed its thresholds.
        atr_value = atr_series[i - 1]
        atr_active = bool(use_atr and atr_value and atr_value > 0)

        prev = bars[ref_idx]
        prev_entity = kline_entity_abs(prev)
        third = prev_entity / 3
        prev_entity_upper, prev_entity_lower = kline_entity_edges(prev)
        prev_is_bullish = prev.close > prev.open
        prev_is_bearish = prev.close < prev.open
        current_is_bullish = current.close > current.open
        current_is_bearish = current.close < current.open

        # When the current bar opens beyond prev's close in the direction of
        # prev's body, every level is measured from the current open instead
        # of prev's body edges.
        if (prev_is_bullish and current.open > prev.close) or (
            prev_is_bearish and current.open < prev.close
        ):
            long_trigger = long_breakout = current.open + third
            short_trigger = short_breakout = current.open - third
        else:
            long_trigger = prev_entity_lower + third
            short_trigger = prev_entity_upper - third
            long_breakout = prev_entity_upper + third
            short_breakout = prev_entity_lower - third

        breakout_buffer = max(
            prev_entity * 0.1,
            current_price * params.breakout_buffer_pct / 100,
            (atr_value * params.breakout_buffer_atr_mult) if atr_active else 0,
        )
        long_breakout_effective = long_breakout + breakout_buffer
        short_breakout_effective = short_breakout - breakout_buffer

        # A current candle opposing prev blocks new positions in prev's direction.
        skip_short_by_upper_third = prev_is_bearish and current_is_bullish
        skip_long_by_lower_third = prev_is_bullish and current_is_bearish

        if position != 0 and entry_price is not None:
            sl_distance = resolve_dynamic_distance(
                entry_price, params.stop_loss_pct, atr_value, params.stop_loss_atr_mult, use_atr
            )
            tp_distance = resolve_dynamic_distance(
                entry_price, params.take_profit_pct, atr_value, params.take_profit_atr_mult, use_atr
            )
            trail_start_distance = resolve_dynamic_distance(
                entry_price,
                params.trailing_start_pct,
                atr_value,
                params.trailing_start_atr_mult,
                use_atr,
            )
            trail_backoff_distance = resolve_dynamic_distance(
                entry_price,
                params.trailing_backoff_pct,
                atr_value,
                params.trailing_backoff_atr_mult,
                use_atr,
            )
            if position == 1:
                max_favorable_price = max(max_favorable_price or entry_price, current_price)
                profit_distance = current_price - entry_price
                loss_distance = entry_price - current_price
                retrace = max_favorable_price - current_price
            else:
                min_favorable_price = min(min_favorable_price or entry_price, current_price)
                profit_distance = entry_price - current_price
                loss_distance = current_price - entry_price
                retrace = current_price - min_favorable_price
            # The trailing exit requires the *current* profit to still reach the
            # start distance while price has retraced from its best close.
            should_close = (
                loss_distance >= sl_distance
                or profit_distance >= tp_distance
                or (
                    profit_distance >= trail_start_distance
                    and retrace >= trail_backoff_distance
                )
            )
            if should_close:
                settle(current_price)
                position = 0
                entry_price = None
                max_favorable_price = None
                min_favorable_price = None
                continue

        signal = None
        if position == 0:
            if current_price >= long_breakout_effective and not skip_long_by_lower_third:
                signal = "long"
            elif current_price <= short_breakout_effective and not skip_short_by_upper_third:
                signal = "short"
        elif position == 1:
            if current_price <= short_trigger and not skip_short_by_upper_third:
                signal = "reverse_short"
            else:
                # Long upper wick on prev plus a close at/below its body -> flip short.
                upper_thr = resolve_dynamic_distance(
                    prev_entity_upper,
                    params.shadow_threshold_pct,
                    atr_value,
                    params.shadow_threshold_atr_mult,
                    use_atr,
                )
                if upper_shadow_abs(prev) > upper_thr and current_price <= prev_entity_lower:
                    signal = "reverse_short"
        elif position == -1:
            if current_price >= long_trigger and not skip_long_by_lower_third:
                signal = "reverse_long"
            else:
                # Long lower wick on prev plus a close at/above its body -> flip long.
                lower_thr = resolve_dynamic_distance(
                    prev_entity_lower,
                    params.shadow_threshold_pct,
                    atr_value,
                    params.shadow_threshold_atr_mult,
                    use_atr,
                )
                if lower_shadow_abs(prev) > lower_thr and current_price >= prev_entity_upper:
                    signal = "reverse_long"

        if signal in ("reverse_long", "reverse_short"):
            settle(current_price)
        if signal in ("long", "reverse_long"):
            position = 1
            entry_price = current_price
            max_favorable_price = current_price
            min_favorable_price = None
        elif signal in ("short", "reverse_short"):
            position = -1
            entry_price = current_price
            min_favorable_price = current_price
            max_favorable_price = None

    # Close any position still open at the final bar's close.
    if position != 0 and entry_price is not None:
        settle(bars[-1].close)

    total_return_pct = (equity - 1) * 100
    win_rate_pct = (wins / trades * 100) if trades else 0.0
    loss_abs = abs(gross_loss)
    # 999.0 stands in for an unbounded ratio when no trade lost money
    # (this includes runs with no trades at all).
    profit_factor = (gross_profit / loss_abs) if loss_abs > 1e-12 else 999.0
    score = total_return_pct - max_drawdown * 100 * 0.5
    if trades < min_trades:
        score -= (min_trades - trades) * 0.8
    return BacktestResult(
        score=score,
        total_return_pct=total_return_pct,
        max_drawdown_pct=max_drawdown * 100,
        win_rate_pct=win_rate_pct,
        profit_factor=profit_factor,
        trades=trades,
        wins=wins,
        params=params,
    )
def split_train_valid(
    bars: list[Bar],
    train_ratio: float = 0.7,
    gap_bars: int = 0,
) -> tuple[list[Bar], list[Bar]]:
    """Split ``bars`` in time order: the earlier part trains, the later part validates.

    Args:
        bars: Time-ordered bars.
        train_ratio: Fraction of bars used for training. It is clamped to
            [0.5, 0.9], and the split index is kept within
            ``[1, len(bars) - 1]``.
        gap_bars: Bars skipped between the training and validation slices, to
            reduce leakage between adjacent samples. Negative values count as 0.

    Returns:
        ``(train_bars, valid_bars)``; both are empty when ``bars`` is empty.
    """
    if not bars:
        return [], []
    total = len(bars)
    ratio = min(max(train_ratio, 0.5), 0.9)
    cut = max(1, min(int(total * ratio), total - 1))
    skip = gap_bars if gap_bars > 0 else 0
    valid_from = min(total, cut + skip)
    return bars[:cut], bars[valid_from:]
def risk_adjusted_return(result: BacktestResult) -> float:
    """Return a simple risk-adjusted return: total return % minus 0.6 x max drawdown %."""
    drawdown_penalty = 0.6 * result.max_drawdown_pct
    return result.total_return_pct - drawdown_penalty
def compute_robust_score(
    train_result: BacktestResult,
    valid_result: BacktestResult,
    min_train_trades: int = 20,
    min_valid_trades: int = 10,
) -> tuple[float, float, float, float, float]:
    """Score a parameter set for robustness across the train/validation split.

    Higher is more robust. The score:
      - is driven mainly by the validation risk-adjusted return (weight 0.75;
        the training one gets 0.25);
      - rewards train/validation consistency;
      - penalizes overfitting (good on train, bad on validation);
      - penalizes too few trades, more heavily on the validation slice;
      - adds a capped profit-factor bonus, but only if validation traded;
      - adds a bonus for a validation win rate above 50 %.

    Returns:
        ``(robust_score, consistency, overfit_gap, train_risk_adj, valid_risk_adj)``.
    """
    train_ra = risk_adjusted_return(train_result)
    valid_ra = risk_adjusted_return(valid_result)
    overfit_gap = abs(train_ra - valid_ra)
    denom = abs(train_ra) + abs(valid_ra) + 1e-9  # epsilon avoids 0/0 when both are 0
    consistency = max(0.0, 1.0 - overfit_gap / denom)
    train_trade_penalty = max(0, min_train_trades - train_result.trades) * 0.4
    valid_trade_penalty = max(0, min_valid_trades - valid_result.trades) * 1.2
    # backtest_strategy reports profit_factor=999.0 whenever no trade lost money,
    # which includes runs with no trades at all. Without this guard a silent
    # (zero-trade) validation run would collect the maximum bonus.
    if valid_result.trades > 0:
        pf_bonus = min(valid_result.profit_factor, 3.0) * 2.0
    else:
        pf_bonus = 0.0
    win_bonus = max(0.0, (valid_result.win_rate_pct - 50.0) * 0.08)
    direction_penalty = 0.0
    if train_ra > 0 and valid_ra <= 0:
        direction_penalty += 12.0
    if train_result.total_return_pct > 0 and valid_result.total_return_pct < 0:
        direction_penalty += 8.0
    robust_score = (
        0.75 * valid_ra
        + 0.25 * train_ra
        + 10.0 * consistency
        + pf_bonus
        + win_bonus
        - 0.2 * overfit_gap
        - train_trade_penalty
        - valid_trade_penalty
        - direction_penalty
    )
    return robust_score, consistency, overfit_gap, train_ra, valid_ra
def evaluate_param_set(
    train_bars: list[Bar],
    valid_bars: list[Bar],
    full_bars: list[Bar],
    params: StrategyParams,
    fee_rate: float,
    min_train_trades: int,
    min_valid_trades: int,
) -> RobustEvalResult:
    """Backtest ``params`` on the train, validation and full bar sets, then score robustness.

    The three backtests run with ``min_trades=0``, so their own ``score`` has
    no trade-count penalty. Trade-count requirements are applied once, in
    ``compute_robust_score``, through ``min_train_trades`` / ``min_valid_trades``.
    """
    train_result, valid_result, full_result = (
        backtest_strategy(bars=segment, params=params, fee_rate=fee_rate, min_trades=0)
        for segment in (train_bars, valid_bars, full_bars)
    )
    robust_score, consistency, overfit_gap, train_ra, valid_ra = compute_robust_score(
        train_result,
        valid_result,
        min_train_trades,
        min_valid_trades,
    )
    return RobustEvalResult(
        robust_score=robust_score,
        consistency_score=consistency,
        overfit_gap=overfit_gap,
        train_risk_adj=train_ra,
        valid_risk_adj=valid_ra,
        train=train_result,
        valid=valid_result,
        full=full_result,
        params=params,
    )
def quick_grid() -> dict[str, list[float | int]]:
return {
"atr_length": [10, 14, 20],
"breakout_buffer_atr_mult": [0.08, 0.12],
"shadow_threshold_atr_mult": [0.14, 0.2],
"stop_loss_atr_mult": [0.35, 0.5],
"take_profit_atr_mult": [0.8, 1.0],
"trailing_start_atr_mult": [0.45, 0.65],
"trailing_backoff_atr_mult": [0.2, 0.3],
}
def full_grid() -> dict[str, list[float | int]]:
return {
"atr_length": [10, 14, 20],
"breakout_buffer_atr_mult": [0.08, 0.12, 0.16],
"shadow_threshold_atr_mult": [0.12, 0.18, 0.24],
"stop_loss_atr_mult": [0.35, 0.5, 0.65],
"take_profit_atr_mult": [0.8, 1.0, 1.2],
"trailing_start_atr_mult": [0.45, 0.65, 0.85],
"trailing_backoff_atr_mult": [0.2, 0.3, 0.4],
}
def iter_param_combos(base: StrategyParams, grid: dict[str, list[float | int]]):
    """Yield one ``StrategyParams`` per point of the Cartesian product of ``grid``.

    Fields not named in ``grid`` keep their value from ``base``. A grid key
    that is not a ``StrategyParams`` field makes construction raise TypeError.
    """
    names = tuple(grid)
    for point in itertools.product(*(grid[name] for name in names)):
        fields = asdict(base)
        fields.update(zip(names, point))
        yield StrategyParams(**fields)
def result_metrics_dict(result: BacktestResult) -> dict[str, float | int]:
    """Return the scalar metrics of ``result`` (everything except ``params``) as a dict.

    Keys appear in a fixed order, which keeps the JSON written by
    ``save_best_params`` stable.
    """
    metric_names = (
        "score",
        "total_return_pct",
        "max_drawdown_pct",
        "win_rate_pct",
        "profit_factor",
        "trades",
        "wins",
    )
    return {name: getattr(result, name) for name in metric_names}
def save_best_params(
    best: RobustEvalResult,
    out_path: Path,
    train_ratio: float,
    split_gap_bars: int,
) -> None:
    """Write the winning parameter set and its metrics to ``out_path`` as UTF-8 JSON.

    ``apply_live`` is always written as ``False``. Per the note printed by
    ``main``, it must be switched on by hand before the trading script
    auto-loads the file.

    Missing parent directories of ``out_path`` are created, so a long
    optimization run is not lost at the very end to a non-existent output
    directory.

    Args:
        best: Top-ranked evaluation result.
        out_path: Destination JSON file (overwritten if it exists).
        train_ratio: Train ratio requested for the split, recorded for reference.
        split_gap_bars: Gap between train and validation, recorded for reference.
    """
    payload = {
        "apply_live": False,
        "selection_metric": "robust_score",
        "robust_score": best.robust_score,
        "consistency_score": best.consistency_score,
        "overfit_gap": best.overfit_gap,
        "train_risk_adj": best.train_risk_adj,
        "valid_risk_adj": best.valid_risk_adj,
        "split": {
            "train_ratio": train_ratio,
            "split_gap_bars": split_gap_bars,
        },
        "train_metrics": result_metrics_dict(best.train),
        "valid_metrics": result_metrics_dict(best.valid),
        "full_metrics": result_metrics_dict(best.full),
        "params_for_trade_py": asdict(best.params),
    }
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
def main():
    """Command-line entry point: grid-search ATR parameters and save the best set.

    Steps:
      1. Load and resample the CSV.
      2. Keep the latest ``--limit-bars`` bars.
      3. Split them chronologically into train/validation sets.
      4. Backtest every grid combination on train, validation and the full
         range.
      5. Rank by ``robust_score``, print the top ``--top-n`` rows and write the
         winner to ``--out-json``.
    """
    parser = argparse.ArgumentParser(description="ATR dynamic parameter optimizer for bitmart strategy")
    parser.add_argument(
        "--csv",
        default=str(Path(__file__).resolve().parent / "数据" / "kline_1.csv"),
        help="path to csv with id/open/high/low/close",
    )
    parser.add_argument("--interval-min", type=int, default=5, help="resample interval minutes")
    parser.add_argument("--mode", choices=["quick", "full"], default="quick", help="grid size")
    parser.add_argument("--limit-bars", type=int, default=30000, help="use latest N bars after resample")
    parser.add_argument("--train-ratio", type=float, default=0.7, help="train split ratio in time order")
    parser.add_argument("--split-gap-bars", type=int, default=0, help="gap bars between train and valid")
    parser.add_argument("--min-trades", type=int, default=20, help="deprecated: kept for compatibility")
    parser.add_argument("--min-train-trades", type=int, default=20, help="min train trades for robustness")
    parser.add_argument("--min-valid-trades", type=int, default=10, help="min valid trades for robustness")
    parser.add_argument("--fee-rate", type=float, default=0.0004, help="round-trip half fee per side")
    parser.add_argument("--top-n", type=int, default=10, help="print top N results")
    parser.add_argument(
        "--out-json",
        default=str(Path(__file__).resolve().parent / "atr_best_params.json"),
        help="best result json output path",
    )
    args = parser.parse_args()
    # A non-positive interval is meaningless for resampling. Reject it up front
    # with a usage error instead of failing (or mis-bucketing) deep in the run.
    if args.interval_min <= 0:
        parser.error(f"--interval-min must be a positive integer, got {args.interval_min}")

    csv_path = Path(args.csv).resolve()
    if not csv_path.exists():
        raise FileNotFoundError(f"CSV not found: {csv_path}")
    bars = load_csv_bars(csv_path)
    bars = resample_to_minutes(bars, args.interval_min)
    if args.limit_bars > 0:
        bars = bars[-args.limit_bars:]  # keep only the most recent bars
    if len(bars) < 400:
        raise ValueError("not enough bars for optimization")

    train_bars, valid_bars = split_train_valid(
        bars=bars,
        train_ratio=args.train_ratio,
        gap_bars=args.split_gap_bars,
    )
    if len(train_bars) < 200 or len(valid_bars) < 100:
        raise ValueError(
            f"insufficient split bars: train={len(train_bars)} valid={len(valid_bars)}; "
            "consider increasing --limit-bars or adjusting --train-ratio"
        )

    base = StrategyParams()
    grid = quick_grid() if args.mode == "quick" else full_grid()
    combos = list(iter_param_combos(base, grid))
    print(
        f"bars={len(bars)} train={len(train_bars)} valid={len(valid_bars)} "
        f"| combos={len(combos)} | mode={args.mode} | train_ratio={args.train_ratio:.2f} gap={args.split_gap_bars}"
    )

    results: list[RobustEvalResult] = []
    for idx, params in enumerate(combos, 1):
        result = evaluate_param_set(
            train_bars=train_bars,
            valid_bars=valid_bars,
            full_bars=bars,
            params=params,
            fee_rate=args.fee_rate,
            min_train_trades=args.min_train_trades,
            min_valid_trades=args.min_valid_trades,
        )
        results.append(result)
        if idx % 50 == 0 or idx == len(combos):
            print(f"progress {idx}/{len(combos)}")

    results.sort(key=lambda x: x.robust_score, reverse=True)
    top_n = max(1, args.top_n)
    top = results[:top_n]
    for i, r in enumerate(top, 1):
        p = r.params
        print(
            f"[{i}] robust={r.robust_score:.2f} consistency={r.consistency_score:.3f} gap={r.overfit_gap:.2f} | "
            f"train(ret={r.train.total_return_pct:.2f}% dd={r.train.max_drawdown_pct:.2f}% trades={r.train.trades}) | "
            f"valid(ret={r.valid.total_return_pct:.2f}% dd={r.valid.max_drawdown_pct:.2f}% trades={r.valid.trades}) | "
            f"full(ret={r.full.total_return_pct:.2f}% dd={r.full.max_drawdown_pct:.2f}% trades={r.full.trades}) | "
            f"atr={p.atr_length} brk={p.breakout_buffer_atr_mult:.2f} shadow={p.shadow_threshold_atr_mult:.2f} "
            f"sl={p.stop_loss_atr_mult:.2f} tp={p.take_profit_atr_mult:.2f} "
            f"ts={p.trailing_start_atr_mult:.2f} tb={p.trailing_backoff_atr_mult:.2f}"
        )

    out_path = Path(args.out_json).resolve()
    save_best_params(
        top[0],
        out_path,
        train_ratio=args.train_ratio,
        split_gap_bars=args.split_gap_bars,
    )
    print(f"saved best params -> {out_path}")
    print("note: set apply_live=true in json when you want 交易.py to auto-load it")
if __name__ == "__main__":
main()