Files
codex_jxs_code/strategy/backtest_engine.py
2026-02-23 04:09:34 +08:00

299 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
回测引擎 - 完整模拟手续费、返佣延迟到账、每日回撤限制、持仓时间约束
支持同时持有多单并发,严格控制每日最大回撤
"""
import datetime
import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import List, Optional
@dataclass()
class Trade:
    """Record of one closed trade produced by the backtest engine.

    Only ``entry_time`` is required; every other field has a neutral default.
    Field order is significant because the dataclass also accepts positional
    construction.
    """

    # Timestamp of the bar on which the position was opened.
    entry_time: pd.Timestamp
    # Timestamp of the closing bar; defaults to None.
    exit_time: Optional[pd.Timestamp] = None
    # Side of the position; the engine in this file opens with 1 (long) or -1 (short).
    direction: int = 0
    entry_price: float = 0.0
    exit_price: float = 0.0
    # Realized profit/loss attributed to this trade.
    pnl: float = 0.0
    # Trading fees attributed to this trade.
    fee: float = 0.0
    # Fee rebate attributed to this trade.
    rebate: float = 0.0
    # Number of bars the position was held before closing.
    holding_bars: int = 0
@dataclass
class OpenPosition:
    """Mutable state of a position that is still open during a backtest run."""

    # Side of the position: 1 for long, -1 for short (0 is only the default).
    direction: int = 0
    # Price at which the position was opened.
    entry_price: float = 0.0
    # Bar timestamp at which the position was opened. None only for a
    # default-constructed instance, so the annotation must say Optional
    # (implicit Optional is no longer accepted by PEP 484 type checkers).
    entry_time: Optional[pd.Timestamp] = None
    # Bars elapsed since entry; mutated in place while the position is open.
    hold_bars: int = 0
class BacktestEngine:
    """Bar-by-bar backtester for a score-driven strategy with fixed-size positions.

    What ``run`` simulates:
      * every position has the same notional, ``margin_per_trade * leverage``;
      * a fee of ``fee_rate * notional`` is charged on each side (open and close);
      * ``rebate_ratio`` of each round trip's fees is rebated; the rebate becomes
        due one calendar day after the close and is credited on the first bar of
        a day on/after that date (rebates still pending when the data ends are
        credited at the end);
      * a daily realized-loss limit: new positions are refused if every open
        position plus the new one stopping out could breach it, and if realized
        day PnL does breach it, all positions are flattened at the bar close and
        no further trading happens that day;
      * at most ``max_positions`` concurrent positions.
    """

    def __init__(
        self,
        initial_capital: float = 1000.0,
        margin_per_trade: float = 25.0,
        leverage: int = 50,
        fee_rate: float = 0.0005,
        rebate_ratio: float = 0.70,
        max_daily_drawdown: float = 50.0,
        min_hold_bars: int = 1,
        stop_loss_pct: float = 0.005,
        take_profit_pct: float = 0.01,
        max_positions: int = 3,
    ):
        """Store the simulation parameters.

        Args:
            initial_capital: Starting account balance.
            margin_per_trade: Margin reserved for each open position.
            leverage: Multiplier applied to the margin to get each position's notional.
            fee_rate: Fee as a fraction of notional, charged per side.
            rebate_ratio: Fraction of a round trip's fees that is rebated.
            max_daily_drawdown: Positive amount; realized day PnL below
                ``-max_daily_drawdown`` triggers the daily circuit breaker.
            min_hold_bars: Bars a position must be held before take-profit or
                signal-reversal exits may fire (stop loss always may).
            stop_loss_pct: Stop distance from entry, as a fraction of entry price.
            take_profit_pct: Target distance from entry, as a fraction of entry price.
            max_positions: Maximum number of concurrent open positions.
        """
        self.initial_capital = initial_capital
        self.margin = margin_per_trade
        self.leverage = leverage
        self.notional = margin_per_trade * leverage
        self.fee_rate = fee_rate
        self.rebate_ratio = rebate_ratio
        self.max_daily_dd = max_daily_drawdown
        self.min_hold_bars = min_hold_bars
        self.sl_pct = stop_loss_pct
        self.tp_pct = take_profit_pct
        self.max_positions = max_positions

    def _close_position(self, pos, exit_price, t, today, trades, pending_rebates) -> float:
        """Close ``pos`` at ``exit_price``, record the trade and schedule its rebate.

        Side effects:
            * appends a ``Trade`` to ``trades``; its ``pnl`` is net of BOTH the
              opening and the closing fee, so that ``sum(trade.pnl) + rebates``
              reconciles with the change in capital;
            * appends ``(today + 1 day, rebate)`` to ``pending_rebates``.

        Returns:
            The cash to credit to capital and to the day's PnL now: raw PnL minus
            the closing fee only. The opening fee was already charged to capital
            and day PnL when the position was opened.
        """
        fee_per_side = self.notional * self.fee_rate
        qty = self.notional / pos.entry_price
        if pos.direction == 1:
            raw_pnl = qty * (exit_price - pos.entry_price)
        else:
            raw_pnl = qty * (pos.entry_price - exit_price)
        total_fee = fee_per_side * 2
        rebate = total_fee * self.rebate_ratio
        pending_rebates.append((today + datetime.timedelta(days=1), rebate))
        trades.append(Trade(
            entry_time=pos.entry_time, exit_time=t,
            direction=pos.direction, entry_price=pos.entry_price,
            exit_price=exit_price, pnl=raw_pnl - total_fee, fee=total_fee,
            rebate=rebate, holding_bars=pos.hold_bars,
        ))
        return raw_pnl - fee_per_side

    def _exit_price(self, pos, h, lo, c, s, open_threshold) -> Optional[float]:
        """Return the price at which ``pos`` exits on this bar, or None to keep it.

        The stop loss is checked first and ignores ``min_hold_bars``; if both the
        stop and the target lie inside the bar's high/low range the stop is taken,
        since the intrabar order is unknown. Take profit (at the target price) and
        signal reversal (at the bar close ``c``) require
        ``pos.hold_bars >= self.min_hold_bars``.
        """
        entry = pos.entry_price
        if pos.direction == 1:
            sl_price = entry * (1 - self.sl_pct)
            tp_price = entry * (1 + self.tp_pct)
            hit_sl = lo <= sl_price
            hit_tp = h >= tp_price
        else:
            sl_price = entry * (1 + self.sl_pct)
            tp_price = entry * (1 - self.tp_pct)
            hit_sl = h >= sl_price
            hit_tp = lo <= tp_price

        if hit_sl:
            return sl_price
        if pos.hold_bars < self.min_hold_bars:
            return None
        if hit_tp:
            return tp_price
        signal_reversed = (
            (pos.direction == 1 and s < -open_threshold)
            or (pos.direction == -1 and s > open_threshold)
        )
        return c if signal_reversed else None

    def _worst_unrealized(self, positions, h, lo):
        """Sum each position's PnL at the bar's adverse extreme (fees excluded).

        Longs are valued at ``lo`` and shorts at ``h``; a negative result is a
        loss. Not called by ``run`` in this class.
        """
        worst = 0.0
        for pos in positions:
            qty = self.notional / pos.entry_price
            if pos.direction == 1:
                # Long: worst case is price falling to the bar low.
                worst += qty * (lo - pos.entry_price)
            else:
                # Short: worst case is price rising to the bar high.
                worst += qty * (pos.entry_price - h)
        return worst

    def run(self, df: pd.DataFrame, score: pd.Series, open_threshold: float = 0.3) -> dict:
        """Simulate the strategy over ``df`` one bar at a time.

        Args:
            df: Bars with 'close', 'high' and 'low' columns. Each index element
                must provide ``.date()`` (e.g. a DatetimeIndex); days roll over
                on that date.
            score: Per-bar signal. Only ``score.values`` is read, so it must be
                positionally aligned with ``df`` (its index is ignored).
                ``score > open_threshold`` opens a long, ``score < -open_threshold``
                opens a short, NaN never opens; an opposite signal beyond the
                threshold closes an existing position once ``min_hold_bars`` is met.
            open_threshold: Signal magnitude required to open / reverse.

        Returns:
            The summary dict from ``_build_result``. Positions still open after
            the last bar are force-closed at its close and booked to the last day.
        """
        fee_per_side = self.notional * self.fee_rate
        capital = self.initial_capital
        trades: List[Trade] = []
        daily_pnl = {}
        pending_rebates = []  # (due_date, amount) pairs not yet credited
        positions: List[OpenPosition] = []
        used_margin = 0.0
        current_date = None
        day_pnl = 0.0
        day_stopped = False

        close_arr = df['close'].values
        high_arr = df['high'].values
        low_arr = df['low'].values
        times = df.index
        scores = score.values

        for i in range(len(df)):
            t = times[i]
            c = close_arr[i]
            h = high_arr[i]
            lo = low_arr[i]
            s = scores[i]
            today = t.date()

            # --- Day rollover: store yesterday's PnL, reset limits, credit due rebates ---
            if today != current_date:
                if current_date is not None:
                    daily_pnl[current_date] = day_pnl
                current_date = today
                day_pnl = 0.0
                day_stopped = False
                arrived = [amt for due, amt in pending_rebates if today >= due]
                if arrived:
                    capital += sum(arrived)
                pending_rebates = [(due, amt) for due, amt in pending_rebates if today < due]

            if day_stopped:
                # Circuit breaker already fired today: no exits or entries.
                for pos in positions:
                    pos.hold_bars += 1
                continue

            # --- Exits: stop loss / take profit / signal reversal ---
            survivors = []
            for idx, pos in enumerate(positions):
                pos.hold_bars += 1
                exit_price = self._exit_price(pos, h, lo, c, s, open_threshold)
                if exit_price is None:
                    survivors.append(pos)
                    continue

                cash = self._close_position(pos, exit_price, t, today, trades, pending_rebates)
                capital += cash
                used_margin -= self.margin
                day_pnl += cash

                if day_pnl < -self.max_daily_dd:
                    # Daily limit breached: flatten everything else at the bar
                    # close. Only positions not yet visited this bar still need
                    # their hold counter advanced; survivors already got it above.
                    unvisited = positions[idx + 1:]
                    for other in unvisited:
                        other.hold_bars += 1
                    for other in survivors + unvisited:
                        cash = self._close_position(other, c, t, today, trades, pending_rebates)
                        capital += cash
                        used_margin -= self.margin
                        day_pnl += cash
                    survivors = []
                    day_stopped = True
                    break
            positions = survivors

            if day_stopped:
                continue

            # --- Entries ---
            if len(positions) >= self.max_positions or np.isnan(s):
                continue
            # Worst case if every open position plus the new one stops out today.
            n_after = len(positions) + 1
            worst_total_sl = n_after * (self.notional * self.sl_pct + fee_per_side * 2)
            if day_pnl - worst_total_sl < -self.max_daily_dd:
                continue  # exposure would risk breaching the daily limit
            if capital - used_margin < self.margin + fee_per_side:
                continue  # not enough free capital for margin plus opening fee

            if s > open_threshold:
                new_dir = 1
            elif s < -open_threshold:
                new_dir = -1
            else:
                continue
            positions.append(OpenPosition(
                direction=new_dir, entry_price=c,
                entry_time=t, hold_bars=0,
            ))
            capital -= fee_per_side
            used_margin += self.margin
            day_pnl -= fee_per_side

        # Force-close leftovers at the final close via the same path as every other
        # exit, and book them to the final day so daily stats reconcile with capital.
        if positions:
            last_close = close_arr[-1]
            for pos in positions:
                cash = self._close_position(
                    pos, last_close, times[-1], current_date, trades, pending_rebates)
                capital += cash
                day_pnl += cash

        if current_date is not None:
            daily_pnl[current_date] = day_pnl

        # Rebates still pending when the data ends are assumed to arrive eventually.
        capital += sum(amt for _, amt in pending_rebates)
        return self._build_result(trades, daily_pnl, capital)

    def _build_result(self, trades, daily_pnl, final_capital):
        """Summarize trades and daily PnL into a result dict.

        Keys: total_pnl (sum of trade pnl plus rebates), final_capital,
        num_trades, win_rate (share of trades with pnl > 0), avg_pnl,
        max_daily_dd (the minimum daily PnL value; positive if every day was
        profitable), avg_daily_pnl, profit_factor (gross profit / gross loss,
        with 1e-10 substituted when gross loss is zero), total_fee,
        total_rebate, trades, daily_pnl.
        """
        if not trades:
            return {
                'total_pnl': 0, 'final_capital': final_capital,
                'num_trades': 0, 'win_rate': 0, 'avg_pnl': 0,
                'max_daily_dd': 0, 'avg_daily_pnl': 0,
                'profit_factor': 0, 'trades': [], 'daily_pnl': daily_pnl,
                'total_fee': 0, 'total_rebate': 0,
            }
        pnls = [t.pnl for t in trades]
        wins = [p for p in pnls if p > 0]
        losses = [p for p in pnls if p <= 0]
        daily_vals = list(daily_pnl.values())
        total_fee = sum(t.fee for t in trades)
        total_rebate = sum(t.rebate for t in trades)
        gross_profit = sum(wins)
        # `or` also covers losing trades whose pnl sums to exactly 0, which
        # would otherwise raise ZeroDivisionError.
        gross_loss = abs(sum(losses)) or 1e-10
        return {
            'total_pnl': sum(pnls) + total_rebate,
            'final_capital': final_capital,
            'num_trades': len(trades),
            'win_rate': len(wins) / len(trades),
            'avg_pnl': np.mean(pnls),
            'max_daily_dd': min(daily_vals) if daily_vals else 0,
            'avg_daily_pnl': np.mean(daily_vals) if daily_vals else 0,
            'profit_factor': gross_profit / gross_loss,
            'total_fee': total_fee,
            'total_rebate': total_rebate,
            'trades': trades,
            'daily_pnl': daily_pnl,
        }