Files
codex_jxs_code/strategy/bb_midline_backtest.py
2026-02-26 19:05:17 +08:00

295 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""布林带均线策略回测(优化版)
策略逻辑:
- 阳线 + 碰到布林带均线 → 开多(可选 1m 线过滤:先涨碰到才开)
- 持多: 碰到上轨 → 止盈(无下轨止损)
- 阴线 + 碰到布林带均线 → 平多开空(可选 1m 线过滤:先跌碰到才开)
- 持空: 碰到下轨 → 止盈(无上轨止损)
全仓模式 | 200U | 1% 权益/单 | 万五手续费 | 90%返佣次日8点到账 | 100x杠杆
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional
import numpy as np
import pandas as pd
from .indicators import bollinger
@dataclass
class BBMidlineConfig:
    """Strategy and account parameters for ``run_bb_midline_backtest``.

    The field order defines the positional constructor, so it must stay stable.
    """

    # Bollinger band look-back length (bars) and standard-deviation multiplier.
    bb_period: int = 20
    bb_std: float = 2.0
    # Starting wallet balance in quote currency.
    initial_capital: float = 200.0
    # Fraction of current equity used as margin for each new position
    # (the backtest additionally caps it at 95% of the wallet balance).
    margin_pct: float = 0.01
    # Notional = margin * leverage.
    leverage: float = 100.0
    # Not read by the visible backtest code; kept for callers/config files.
    cross_margin: bool = True
    # Fee rate applied to the notional on every open and every close.
    fee_rate: float = 0.0005
    # Share of a day's fees credited back to the balance on a later day.
    rebate_pct: float = 0.90
    # The rebate is credited on the first bar whose index hour is >= this value.
    # NOTE(review): the name assumes the bar index is in UTC; confirm upstream.
    rebate_hour_utc: int = 0
    # Proportional adverse price adjustment applied to every fill.
    slippage_pct: float = 0.0
    # True: take-profit exits fill at the bar close; False: at the band price.
    fill_at_close: bool = True
    # Use 1m bars to require "rose into" (long) / "fell into" (short) the midline.
    use_1m_touch_filter: bool = True
    # Main bar interval in minutes (e.g. 5/15/30), used to bucket the 1m bars.
    kline_step_min: int = 5
@dataclass
class BBTrade:
    """One closed round-trip trade produced by the backtest."""

    # "long" or "short".
    side: str
    # Fill prices after slippage.
    entry_price: float
    exit_price: float
    # Index labels of the entry and exit bars.
    entry_time: object
    exit_time: object
    # Margin committed at entry and the leverage it was opened with.
    margin: float
    leverage: float
    # Position size in base units (notional / entry price).
    qty: float
    # Price PnL before fees, fees attributed to the trade, and gross_pnl - fee.
    gross_pnl: float
    fee: float
    net_pnl: float
    # Why the trade was closed, e.g. "tp_upper", "tp_lower",
    # "flip_to_short", "flip_to_long" or "end".
    exit_reason: str
@dataclass
class BBMidlineResult:
    """Everything ``run_bb_midline_backtest`` returns."""

    # Per-bar frame with "equity", "balance", "price" and "position" columns,
    # indexed like the input bars.
    equity_curve: pd.DataFrame
    # Closed trades in chronological order.
    trades: List[BBTrade]
    # Per-day "equity" (last value of the day) and "pnl".
    daily_stats: pd.DataFrame
    # Sum of all trading fees paid and of all fee rebates credited.
    total_fee: float
    total_rebate: float
    # The configuration the run used.
    config: BBMidlineConfig
def _build_daily_stats(eq_df: pd.DataFrame, initial_capital: float) -> pd.DataFrame:
    """Collapse the per-bar equity curve into one row per calendar day.

    ``equity`` is the last equity value of each day; ``pnl`` is its change
    versus the previous day, with the first day measured against
    ``initial_capital`` so that ``pnl`` sums to the last recorded equity minus
    the starting capital. If the index cannot be resampled by day
    (``resample`` raises ``TypeError``, e.g. for a plain integer index), an
    empty frame with the same columns is returned instead of failing the run.
    """
    try:
        last_per_day = eq_df["equity"].resample("1D").last()
    except TypeError:
        return pd.DataFrame(
            {"equity": pd.Series(dtype=float), "pnl": pd.Series(dtype=float)}
        )
    daily = last_per_day.dropna().to_frame("equity")
    # Only the first row has no predecessor; measure it from the start capital
    # instead of silently zeroing it.
    daily["pnl"] = daily["equity"].diff().fillna(daily["equity"] - initial_capital)
    return daily


def run_bb_midline_backtest(
    df: pd.DataFrame,
    cfg: BBMidlineConfig,
    df_1m: Optional[pd.DataFrame] = None,
    arr_touch_dir_override: Optional[np.ndarray] = None,
) -> BBMidlineResult:
    """Simulate the Bollinger-band midline strategy bar by bar.

    Rules (at most one action per bar):
      * long + bar high reaches the upper band  -> take profit ("tp_upper")
      * short + bar low reaches the lower band  -> take profit ("tp_lower")
      * long + bearish bar touching the midline -> close and reverse to short
      * short + bullish bar touching the midline -> close and reverse to long
      * flat + bullish / bearish bar touching the midline -> open long / short
    Midline entries can additionally require a 1m touch direction (see below).
    Any position still open after the last bar is closed there ("end").

    Parameters
    ----------
    df:
        Main-timeframe bars with ``open``/``high``/``low``/``close`` columns.
        Trade timestamps come from its index; when index entries expose
        ``.date()``/``.hour`` they also drive the daily fee rebate.
    cfg:
        Strategy and account parameters.
    df_1m:
        Optional 1-minute bars. Only used when ``cfg.use_1m_touch_filter`` is
        true and no override is supplied.
    arr_touch_dir_override:
        Precomputed touch directions, one per row of ``df``
        (1 = rose into the midline, -1 = fell into it, 0 = no touch).
        Takes precedence over ``df_1m``.

    Returns
    -------
    BBMidlineResult
        Per-bar equity curve, closed trades, daily stats, total fees and total
        rebates. Each trade's ``fee`` covers both its entry and exit fee and
        ``net_pnl = gross_pnl - fee``, so the trades' ``net_pnl`` plus
        ``total_rebate`` equals the final balance minus ``initial_capital``.

    Raises
    ------
    ValueError
        If ``arr_touch_dir_override`` does not have exactly one entry per bar.
    """
    close = df["close"].astype(float)
    high = df["high"].astype(float)
    low = df["low"].astype(float)
    open_ = df["open"].astype(float)
    n = len(df)

    bb_mid, bb_upper, bb_lower, _ = bollinger(close, cfg.bb_period, cfg.bb_std)
    arr_mid = bb_mid.values

    # Touch direction per bar: 1 = rose into the midline first, -1 = fell into
    # it first, 0 = no touch. ``None`` means the 1m filter is disabled.
    arr_touch_dir = None
    if arr_touch_dir_override is not None:
        arr_touch_dir = np.asarray(arr_touch_dir_override, dtype=np.int32)
        if len(arr_touch_dir) != n:
            raise ValueError(f"arr_touch_dir_override 长度不匹配: {len(arr_touch_dir)} != {n}")
    elif cfg.use_1m_touch_filter and df_1m is not None and len(df_1m) > 0:
        from .data_loader import get_1m_touch_direction
        arr_touch_dir = get_1m_touch_direction(df, df_1m, arr_mid, kline_step_min=cfg.kline_step_min)

    arr_close = close.values
    arr_high = high.values
    arr_low = low.values
    arr_open = open_.values
    arr_upper = bb_upper.values
    arr_lower = bb_lower.values
    ts_index = df.index

    # Account state. ``balance`` is the wallet balance (realised PnL, fees and
    # rebates); margin is only used to size orders and is not locked out of it.
    balance = cfg.initial_capital
    position = 0  # 1 = long, -1 = short, 0 = flat
    entry_price = 0.0
    entry_time = None
    entry_margin = 0.0
    entry_qty = 0.0
    entry_fee = 0.0  # fee paid when the current position was opened
    trades: List[BBTrade] = []
    total_fee = 0.0
    total_rebate = 0.0

    # Rebate bookkeeping: a calendar day's fees become pending at the next day
    # change and are credited on the first bar with hour >= rebate_hour_utc.
    current_day = None
    today_fees = 0.0
    pending_rebate = 0.0
    rebate_applied_today = False

    out_equity = np.full(n, np.nan)
    out_balance = np.full(n, np.nan)
    out_position = np.zeros(n)

    def unrealised(price):
        """Mark-to-market PnL of the open position at ``price`` (0.0 when flat)."""
        if position == 0:
            return 0.0
        if position == 1:
            return entry_qty * (price - entry_price)
        return entry_qty * (entry_price - price)

    def close_position(exit_price, exit_idx, reason: str):
        """Close the open position at ``exit_price`` (pre-slippage) and record the trade."""
        nonlocal balance, position, entry_price, entry_time, entry_margin, entry_qty, entry_fee
        nonlocal total_fee, today_fees
        if position == 0:
            return
        # Slippage always works against the position being closed.
        if position == 1:
            exit_price = exit_price * (1 - cfg.slippage_pct)
            gross = entry_qty * (exit_price - entry_price)
        else:
            exit_price = exit_price * (1 + cfg.slippage_pct)
            gross = entry_qty * (entry_price - exit_price)
        exit_fee = entry_qty * exit_price * cfg.fee_rate
        # Attribute both legs' fees to the trade so per-trade net PnL is honest.
        trade_fee = entry_fee + exit_fee
        trades.append(BBTrade(
            side="long" if position == 1 else "short",
            entry_price=entry_price,
            exit_price=exit_price,
            entry_time=entry_time,
            exit_time=ts_index[exit_idx],
            margin=entry_margin,
            leverage=cfg.leverage,
            qty=entry_qty,
            gross_pnl=gross,
            fee=trade_fee,
            net_pnl=gross - trade_fee,
            exit_reason=reason,
        ))
        # The entry fee already left the balance in open_position.
        balance += gross - exit_fee
        total_fee += exit_fee
        today_fees += exit_fee
        position = 0
        entry_price = 0.0
        entry_time = None
        entry_margin = 0.0
        entry_qty = 0.0
        entry_fee = 0.0

    def open_position(side, price, idx):
        """Open a ``side`` position at ``price`` (pre-slippage); False if no margin is available."""
        nonlocal position, entry_price, entry_time, entry_margin, entry_qty, entry_fee
        nonlocal balance, total_fee, today_fees
        # Slippage always works against the position being opened.
        if side == "long":
            price = price * (1 + cfg.slippage_pct)
        else:
            price = price * (1 - cfg.slippage_pct)
        equity = balance + unrealised(price)
        margin = min(equity * cfg.margin_pct, balance * 0.95)
        if margin <= 0:
            return False
        notional = margin * cfg.leverage
        fee = notional * cfg.fee_rate
        balance -= fee
        total_fee += fee
        today_fees += fee
        position = 1 if side == "long" else -1
        entry_price = price
        entry_time = ts_index[idx]
        entry_margin = margin
        entry_qty = notional / price
        entry_fee = fee
        return True

    for i in range(n):
        ts = ts_index[i]
        bar_day = ts.date() if hasattr(ts, "date") else None
        bar_hour = ts.hour if hasattr(ts, "hour") else 0

        if bar_day is not None and bar_day != current_day:
            # New calendar day: the previous day's fees become a pending rebate.
            pending_rebate += today_fees * cfg.rebate_pct
            today_fees = 0.0
            rebate_applied_today = False
            current_day = bar_day
        if (
            cfg.rebate_pct > 0
            and not rebate_applied_today
            and bar_hour >= cfg.rebate_hour_utc
            and pending_rebate > 0
        ):
            balance += pending_rebate
            total_rebate += pending_rebate
            pending_rebate = 0.0
            rebate_applied_today = True

        px_close = arr_close[i]
        mid, upper, lower = arr_mid[i], arr_upper[i], arr_lower[i]

        if np.isnan(upper) or np.isnan(lower) or np.isnan(mid):
            # Bands undefined (e.g. during warm-up): no signals, only mark to market.
            out_equity[i] = balance + unrealised(px_close)
            out_balance[i] = balance
            out_position[i] = position
            continue

        bullish = px_close > arr_open[i]
        bearish = px_close < arr_open[i]
        # A level counts as touched when the bar's high-low range reaches or spans it.
        touched_mid = arr_low[i] <= mid <= arr_high[i]
        touched_upper = arr_high[i] >= upper
        touched_lower = arr_low[i] <= lower
        # Take-profit fill price: the bar close, or the band itself.
        exec_upper = px_close if cfg.fill_at_close else upper
        exec_lower = px_close if cfg.fill_at_close else lower
        # 1m filter: longs need a rise into the midline, shorts a fall into it.
        touch_up_ok = arr_touch_dir is None or arr_touch_dir[i] == 1
        touch_down_ok = arr_touch_dir is None or arr_touch_dir[i] == -1

        # Exactly one action per bar; take-profits have priority over flips.
        if position == 1 and touched_upper:
            close_position(exec_upper, i, "tp_upper")
        elif position == -1 and touched_lower:
            close_position(exec_lower, i, "tp_lower")
        elif position == 1 and bearish and touched_mid and touch_down_ok:
            close_position(px_close, i, "flip_to_short")
            if balance > 0:
                open_position("short", px_close, i)
        elif position == -1 and bullish and touched_mid and touch_up_ok:
            close_position(px_close, i, "flip_to_long")
            if balance > 0:
                open_position("long", px_close, i)
        elif position == 0 and bullish and touched_mid and touch_up_ok:
            open_position("long", px_close, i)
        elif position == 0 and bearish and touched_mid and touch_down_ok:
            open_position("short", px_close, i)

        out_equity[i] = balance + unrealised(px_close)
        out_balance[i] = balance
        out_position[i] = position

    # Force-close anything still open at the last bar's close.
    if position != 0:
        close_position(arr_close[n - 1], n - 1, "end")
        out_equity[n - 1] = balance
        out_balance[n - 1] = balance
        out_position[n - 1] = 0

    eq_df = pd.DataFrame({
        "equity": out_equity,
        "balance": out_balance,
        "price": arr_close,
        "position": out_position,
    }, index=ts_index)
    return BBMidlineResult(
        equity_curve=eq_df,
        trades=trades,
        daily_stats=_build_daily_stats(eq_df, cfg.initial_capital),
        total_fee=total_fee,
        total_rebate=total_rebate,
        config=cfg,
    )