299 lines
11 KiB
Python
299 lines
11 KiB
Python
"""
回测引擎 - 完整模拟手续费、返佣延迟到账、每日回撤限制、持仓时间约束

支持同时持有多个仓位,并基于当日已实现盈亏(含手续费)严格控制每日最大回撤
"""
|
||
import datetime
|
||
import numpy as np
|
||
import pandas as pd
|
||
from dataclasses import dataclass
|
||
from typing import List, Optional
|
||
|
||
|
||
@dataclass
class Trade:
    """A single completed round trip produced by the backtest engine.

    Fields keep their original order, so positional construction
    ``Trade(entry_time, exit_time, direction, ...)`` is unaffected.
    """

    # --- timing ---
    entry_time: pd.Timestamp  # bar timestamp at which the position was opened
    exit_time: Optional[pd.Timestamp] = None  # bar timestamp of the exit

    # --- side and prices ---
    direction: int = 0  # +1 for long, -1 for short
    entry_price: float = 0.0  # fill price at entry
    exit_price: float = 0.0  # fill price at exit

    # --- money ---
    pnl: float = 0.0  # realized PnL as recorded by the engine (net of the closing fee)
    fee: float = 0.0  # round-trip fee (open + close) attributed to the trade
    rebate: float = 0.0  # fee rebate credited for this trade (may be 0)

    # --- duration ---
    holding_bars: int = 0  # number of bars the position was held
|
||
|
||
|
||
@dataclass
class OpenPosition:
    """A position that is currently held during a backtest.

    ``hold_bars`` counts how many bars have elapsed since entry; the backtest
    loop increments it once per bar before evaluating exits.
    """

    direction: int = 0  # +1 long, -1 short (0 is only the unset default)
    entry_price: float = 0.0  # fill price at entry (the bar close in the engine)
    # Optional because the default is None until the engine sets the entry bar.
    entry_time: Optional[pd.Timestamp] = None
    hold_bars: int = 0  # bars held so far
|
||
|
||
|
||
class BacktestEngine:
    """Bar-by-bar backtester for a score-driven strategy with concurrent positions.

    Every position uses the same fixed margin and leverage, so its notional is
    ``margin_per_trade * leverage``. A fee of ``fee_rate * notional`` is charged
    on open and again on close; ``rebate_ratio`` of the round-trip fee is
    returned as a rebate credited on the first bar of a later calendar day
    (any rebate still pending when the data ends is credited at the end).

    The daily loss limit works on realized PnL, including fees. A new position
    is refused if all open positions plus the new one stopping out would
    breach the limit. If realized losses do breach it, every open position is
    closed and trading halts until the next calendar day.
    """

    def __init__(
        self,
        initial_capital: float = 1000.0,
        margin_per_trade: float = 25.0,
        leverage: int = 50,
        fee_rate: float = 0.0005,
        rebate_ratio: float = 0.70,
        max_daily_drawdown: float = 50.0,
        min_hold_bars: int = 1,
        stop_loss_pct: float = 0.005,
        take_profit_pct: float = 0.01,
        max_positions: int = 3,
    ):
        """
        Args:
            initial_capital: starting account balance.
            margin_per_trade: margin locked per position.
            leverage: multiplier turning margin into notional.
            fee_rate: fee per side as a fraction of notional.
            rebate_ratio: fraction of the round-trip fee returned as rebate.
            max_daily_drawdown: daily realized-loss limit, in account currency.
            min_hold_bars: bars a position must be held before take-profit or
                signal-reversal exits are allowed (stop-loss is never gated).
            stop_loss_pct: stop distance as a fraction of the entry price.
            take_profit_pct: target distance as a fraction of the entry price.
            max_positions: maximum number of simultaneously open positions.
        """
        self.initial_capital = initial_capital
        self.margin = margin_per_trade
        self.leverage = leverage
        self.notional = margin_per_trade * leverage
        self.fee_rate = fee_rate
        self.rebate_ratio = rebate_ratio
        self.max_daily_dd = max_daily_drawdown
        self.min_hold_bars = min_hold_bars
        self.sl_pct = stop_loss_pct
        self.tp_pct = take_profit_pct
        self.max_positions = max_positions

    def _close_position(self, pos, exit_price, t, today, trades, pending_rebates):
        """Close ``pos`` at ``exit_price``, record the trade and schedule its rebate.

        Args:
            pos: the OpenPosition being closed.
            exit_price: fill price of the exit.
            t: bar timestamp stored as the trade's exit time.
            today: calendar date of the exit; the rebate becomes due one day later.
            trades: list that receives the new Trade (mutated).
            pending_rebates: list of ``(due_date, amount)`` tuples (mutated).

        Returns:
            Realized PnL net of the closing fee only. The opening fee was
            already deducted from capital when the position was opened.
        """
        qty = self.notional / pos.entry_price
        if pos.direction == 1:
            raw_pnl = qty * (exit_price - pos.entry_price)
        else:
            raw_pnl = qty * (pos.entry_price - exit_price)

        close_fee = self.notional * self.fee_rate
        net_pnl = raw_pnl - close_fee
        total_fee = self.notional * self.fee_rate * 2
        rebate = total_fee * self.rebate_ratio

        rebate_date = today + datetime.timedelta(days=1)
        pending_rebates.append((rebate_date, rebate))

        trades.append(Trade(
            entry_time=pos.entry_time, exit_time=t,
            direction=pos.direction, entry_price=pos.entry_price,
            exit_price=exit_price, pnl=net_pnl, fee=total_fee,
            rebate=rebate, holding_bars=pos.hold_bars,
        ))
        return net_pnl

    def _worst_unrealized(self, positions, h, lo):
        """Return the worst combined unrealized PnL of ``positions`` within a bar.

        Longs are marked at the bar low ``lo`` and shorts at the bar high ``h``.
        NOTE(review): not called by ``run``; kept for callers elsewhere.
        """
        worst = 0.0
        for pos in positions:
            qty = self.notional / pos.entry_price
            if pos.direction == 1:
                # Worst case for a long: price trades down to the low.
                worst += qty * (lo - pos.entry_price)
            else:
                # Worst case for a short: price trades up to the high.
                worst += qty * (pos.entry_price - h)
        return worst

    @staticmethod
    def _settle_rebates(pending_rebates, today):
        """Split pending rebates into the amount due on ``today`` and the rest.

        Returns:
            ``(amount_due, still_pending)``. ``amount_due`` sums every rebate
            whose due date is on or before ``today``.
        """
        due = []
        still_pending = []
        for due_date, amount in pending_rebates:
            if today >= due_date:
                due.append(amount)
            else:
                still_pending.append((due_date, amount))
        return sum(due), still_pending

    def _exit_decision(self, pos, c, h, lo, s, open_threshold):
        """Decide whether ``pos`` exits on this bar and at what price.

        The stop-loss applies regardless of holding time and takes precedence
        when stop and target are both touched in the same bar. Take-profit
        and signal-reversal exits require ``pos.hold_bars >= min_hold_bars``.

        Returns:
            ``(should_close, exit_price)``. ``exit_price`` is the bar close ``c``
            unless a stop or target level was hit.
        """
        if pos.direction == 1:
            sl_price = pos.entry_price * (1 - self.sl_pct)
            tp_price = pos.entry_price * (1 + self.tp_pct)
            hit_sl = lo <= sl_price
            hit_tp = h >= tp_price
        else:
            sl_price = pos.entry_price * (1 + self.sl_pct)
            tp_price = pos.entry_price * (1 - self.tp_pct)
            hit_sl = h >= sl_price
            hit_tp = lo <= tp_price

        if hit_sl:
            return True, sl_price
        if pos.hold_bars < self.min_hold_bars:
            return False, c
        if hit_tp:
            return True, tp_price
        signal_reversed = (
            (pos.direction == 1 and s < -open_threshold)
            or (pos.direction == -1 and s > open_threshold)
        )
        return signal_reversed, c

    def run(self, df: pd.DataFrame, score: pd.Series, open_threshold: float = 0.3) -> dict:
        """Simulate the strategy over ``df`` and return summary statistics.

        Args:
            df: bars with ``close``, ``high`` and ``low`` columns. Each index
                element's ``.date()`` is used to detect day boundaries, so a
                timestamp index is expected.
            score: per-bar signal read positionally (``score.values[i]`` pairs
                with row ``i`` of ``df``). It should match ``df`` in length and
                order. NaN scores never open positions.
            open_threshold: a score above ``+open_threshold`` opens a long and a
                score below ``-open_threshold`` opens a short. The opposite
                signal closes a position once ``min_hold_bars`` is met.

        Returns:
            The summary dict built by ``_build_result``.
        """
        capital = self.initial_capital
        trades: List[Trade] = []
        daily_pnl = {}
        pending_rebates = []
        positions: List[OpenPosition] = []
        used_margin = 0.0

        current_date = None
        day_pnl = 0.0
        day_stopped = False

        close_arr = df['close'].values
        high_arr = df['high'].values
        low_arr = df['low'].values
        times = df.index
        scores = score.values

        for i in range(len(df)):
            t = times[i]
            c = close_arr[i]
            h = high_arr[i]
            lo = low_arr[i]
            s = scores[i]
            today = t.date()

            # --- Day rollover: store the finished day, reset the limit, credit due rebates ---
            if today != current_date:
                if current_date is not None:
                    daily_pnl[current_date] = day_pnl
                current_date = today
                day_pnl = 0.0
                day_stopped = False

                arrived, pending_rebates = self._settle_rebates(pending_rebates, today)
                capital += arrived

            if day_stopped:
                # Trading is halted for the rest of the day. The halt closed every
                # position, so this normally has nothing to age.
                for pos in positions:
                    pos.hold_bars += 1
                continue

            # --- Exits: stop-loss / take-profit / signal reversal ---
            closed = set()
            for pi, pos in enumerate(positions):
                pos.hold_bars += 1
                should_close, exit_price = self._exit_decision(pos, c, h, lo, s, open_threshold)
                if not should_close:
                    continue

                net = self._close_position(pos, exit_price, t, today, trades, pending_rebates)
                capital += net
                used_margin -= self.margin
                day_pnl += net
                closed.add(pi)

                # Check the daily loss limit after every realized close.
                if day_pnl < -self.max_daily_dd:
                    # Circuit breaker: flatten every remaining position at the bar close.
                    # NOTE(review): positions after ``pi`` are flattened at ``c`` without
                    # first checking whether their own stop/target was hit in this bar.
                    for pj, pos2 in enumerate(positions):
                        if pj in closed:
                            continue
                        if pj > pi:
                            # Not yet visited in this bar, so it has not been aged yet;
                            # positions before ``pi`` were already incremented above.
                            pos2.hold_bars += 1
                        net2 = self._close_position(pos2, c, t, today, trades, pending_rebates)
                        capital += net2
                        used_margin -= self.margin
                        day_pnl += net2
                        closed.add(pj)
                    day_stopped = True
                    break

            if closed:
                positions = [p for k, p in enumerate(positions) if k not in closed]

            if day_stopped:
                continue

            # --- Entries ---
            if len(positions) >= self.max_positions or np.isnan(s):
                continue

            # Refuse if all current positions plus the new one could stop out (loss plus
            # round-trip fees each) and push today's realized PnL past the limit.
            n_after = len(positions) + 1
            worst_total_sl = n_after * (self.notional * self.sl_pct + self.notional * self.fee_rate * 2)
            if day_pnl - worst_total_sl < -self.max_daily_dd:
                continue

            open_fee = self.notional * self.fee_rate
            if capital - used_margin < self.margin + open_fee:
                continue  # not enough free capital for margin plus the opening fee

            if s > open_threshold:
                new_dir = 1
            elif s < -open_threshold:
                new_dir = -1
            else:
                continue

            positions.append(OpenPosition(
                direction=new_dir, entry_price=c,
                entry_time=t, hold_bars=0,
            ))
            capital -= open_fee
            used_margin += self.margin
            day_pnl -= open_fee

        # Force-close anything still open at the last bar's close. This goes through
        # the normal exit path so fees, the rebate and the day's PnL stay consistent.
        if positions and len(df) > 0:
            last_close = close_arr[-1]
            last_time = times[-1]
            for pos in positions:
                net = self._close_position(pos, last_close, last_time, current_date, trades, pending_rebates)
                capital += net
                used_margin -= self.margin
                day_pnl += net
            positions = []

        # Store the final day only after the forced close, so its PnL is included.
        if current_date is not None:
            daily_pnl[current_date] = day_pnl

        # Rebates that would only arrive after the data ends are credited now.
        capital += sum(amt for _, amt in pending_rebates)

        return self._build_result(trades, daily_pnl, capital)

    def _build_result(self, trades, daily_pnl, final_capital):
        """Summarize a finished run.

        ``total_pnl`` is ``final_capital - initial_capital``, so it includes
        opening fees and every credited rebate. Each trade's ``pnl`` is net of
        its closing fee only, and ``win_rate``, ``avg_pnl`` and
        ``profit_factor`` are computed on that basis. ``max_daily_dd`` is the
        lowest single-day PnL, which is positive if every day was profitable.
        """
        if not trades:
            return {
                'total_pnl': 0, 'final_capital': final_capital,
                'num_trades': 0, 'win_rate': 0, 'avg_pnl': 0,
                'max_daily_dd': 0, 'avg_daily_pnl': 0,
                'profit_factor': 0, 'trades': [], 'daily_pnl': daily_pnl,
                'total_fee': 0, 'total_rebate': 0,
            }

        pnls = [t.pnl for t in trades]
        wins = [p for p in pnls if p > 0]
        losses = [p for p in pnls if p <= 0]
        daily_vals = list(daily_pnl.values())
        total_fee = sum(t.fee for t in trades)
        total_rebate = sum(t.rebate for t in trades)
        gross_profit = sum(wins) if wins else 0
        # Use a tiny floor when there are no losses, or when the only "losses" are
        # exact break-evens (pnl == 0), to avoid ZeroDivisionError.
        gross_loss = abs(sum(losses)) or 1e-10

        return {
            'total_pnl': final_capital - self.initial_capital,
            'final_capital': final_capital,
            'num_trades': len(trades),
            'win_rate': len(wins) / len(trades),
            'avg_pnl': np.mean(pnls),
            'max_daily_dd': min(daily_vals) if daily_vals else 0,
            'avg_daily_pnl': np.mean(daily_vals) if daily_vals else 0,
            'profit_factor': gross_profit / gross_loss,
            'total_fee': total_fee,
            'total_rebate': total_rebate,
            'trades': trades,
            'daily_pnl': daily_pnl,
        }
|