"""Backtest engine.

Simulates trading fees, fee rebates credited on the next calendar day, a
per-day drawdown limit (circuit breaker) and a minimum holding-time
constraint. Several positions may be held concurrently (up to
``max_positions``) while the daily drawdown is kept under control.
"""
import datetime
import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Trade:
    """A completed (closed) trade."""

    entry_time: pd.Timestamp
    exit_time: Optional[pd.Timestamp] = None
    # 1 = long, -1 = short.
    direction: int = 0
    entry_price: float = 0.0
    exit_price: float = 0.0
    # Price PnL minus BOTH the open and the close fee; the rebate is not included.
    pnl: float = 0.0
    # Round-trip fee (open + close).
    fee: float = 0.0
    # Fee rebate for this trade (0 for positions force-closed at the end of data).
    rebate: float = 0.0
    holding_bars: int = 0


@dataclass
class OpenPosition:
    """A position that is still open."""

    # 1 = long, -1 = short.
    direction: int = 0
    entry_price: float = 0.0
    entry_time: Optional[pd.Timestamp] = None
    # Bars processed since entry (0 on the entry bar itself).
    hold_bars: int = 0


class BacktestEngine:
    """Bar-by-bar backtester for a score-driven long/short strategy.

    Every position uses the same margin ``margin_per_trade`` at ``leverage``,
    i.e. a fixed notional of ``margin_per_trade * leverage``. Fees are
    ``fee_rate`` of that notional per side; ``rebate_ratio`` of the round-trip
    fee is credited the calendar day after a position is closed.
    """

    def __init__(
        self,
        initial_capital: float = 1000.0,
        margin_per_trade: float = 25.0,
        leverage: int = 50,
        fee_rate: float = 0.0005,
        rebate_ratio: float = 0.70,
        max_daily_drawdown: float = 50.0,
        min_hold_bars: int = 1,
        stop_loss_pct: float = 0.005,
        take_profit_pct: float = 0.01,
        max_positions: int = 3,
    ):
        self.initial_capital = initial_capital
        self.margin = margin_per_trade
        self.leverage = leverage
        self.notional = margin_per_trade * leverage
        self.fee_rate = fee_rate
        self.rebate_ratio = rebate_ratio
        self.max_daily_dd = max_daily_drawdown
        self.min_hold_bars = min_hold_bars
        self.sl_pct = stop_loss_pct
        self.tp_pct = take_profit_pct
        self.max_positions = max_positions

    def _close_position(self, pos, exit_price, t, today, trades, pending_rebates,
                        credit_rebate=True):
        """Close ``pos`` at ``exit_price`` and append the resulting Trade.

        Args:
            pos: the OpenPosition being closed.
            exit_price: fill price of the exit.
            t: timestamp of the exit bar.
            today: calendar date of the exit bar; the rebate is scheduled for
                ``today + 1 day``.
            trades: list the new Trade is appended to.
            pending_rebates: list of ``(date, amount)`` tuples; the rebate is
                appended here when ``credit_rebate`` is true.
            credit_rebate: when false, no rebate is scheduled and the Trade
                records ``rebate=0``.

        Returns:
            Price PnL minus the close fee only. The open fee was already
            deducted from capital when the position was opened, so this is
            exactly the amount to add to capital now. ``Trade.pnl`` also
            subtracts the open fee, making it the full round-trip result.
        """
        qty = self.notional / pos.entry_price
        if pos.direction == 1:
            raw_pnl = qty * (exit_price - pos.entry_price)
        else:
            raw_pnl = qty * (pos.entry_price - exit_price)
        side_fee = self.notional * self.fee_rate
        net_pnl = raw_pnl - side_fee
        total_fee = side_fee * 2

        if credit_rebate:
            rebate = total_fee * self.rebate_ratio
            pending_rebates.append((today + datetime.timedelta(days=1), rebate))
        else:
            rebate = 0

        trades.append(Trade(
            entry_time=pos.entry_time,
            exit_time=t,
            direction=pos.direction,
            entry_price=pos.entry_price,
            exit_price=exit_price,
            # Charge the open fee to the trade as well, so that pnl is
            # consistent with ``fee`` being the round-trip fee.
            pnl=net_pnl - side_fee,
            fee=total_fee,
            rebate=rebate,
            holding_bars=pos.hold_bars,
        ))
        return net_pnl

    def _worst_unrealized(self, positions, h, lo):
        """Return the summed worst-case floating PnL of ``positions`` within one bar.

        Longs are marked at the bar's low ``lo`` and shorts at its high ``h``.
        Fees are not included.
        """
        worst = 0.0
        for pos in positions:
            qty = self.notional / pos.entry_price
            if pos.direction == 1:
                # Worst case for a long: price falls to the low.
                worst += qty * (lo - pos.entry_price)
            else:
                # Worst case for a short: price rises to the high.
                worst += qty * (pos.entry_price - h)
        return worst

    def _exit_decision(self, pos, c, h, lo, s, open_threshold):
        """Decide whether ``pos`` exits on the current bar.

        Returns ``(should_close, exit_price)``.
        - The stop loss is always armed and wins if SL and TP are both
          touched in the same bar, which is the conservative choice.
        - Take profit and signal reversal only apply once ``min_hold_bars``
          is reached.
        """
        if pos.direction == 1:
            sl_price = pos.entry_price * (1 - self.sl_pct)
            tp_price = pos.entry_price * (1 + self.tp_pct)
            hit_sl = lo <= sl_price
            hit_tp = h >= tp_price
        else:
            sl_price = pos.entry_price * (1 + self.sl_pct)
            tp_price = pos.entry_price * (1 - self.tp_pct)
            hit_sl = h >= sl_price
            hit_tp = lo <= tp_price

        if hit_sl:
            return True, sl_price
        if pos.hold_bars < self.min_hold_bars:
            return False, c
        if hit_tp:
            return True, tp_price
        reversed_signal = (
            (pos.direction == 1 and s < -open_threshold)
            or (pos.direction == -1 and s > open_threshold)
        )
        return reversed_signal, c

    @staticmethod
    def _settle_rebates(pending_rebates, today):
        """Split pending rebates into (total due on or before ``today``, still pending)."""
        arrived = 0.0
        remaining = []
        for due_date, amount in pending_rebates:
            if today >= due_date:
                arrived += amount
            else:
                remaining.append((due_date, amount))
        return arrived, remaining

    def run(self, df: pd.DataFrame, score: pd.Series, open_threshold: float = 0.3) -> dict:
        """Run the backtest over ``df``, driven by ``score``.

        Args:
            df: bars whose index entries support ``.date()`` (e.g. a
                DatetimeIndex), with ``close``, ``high`` and ``low`` columns.
            score: signal per bar. Only ``score.values`` is used, so it must
                be aligned row-for-row with ``df``; the index is not checked.
                How the score is used:
                - ``score > open_threshold`` opens a long.
                - ``score < -open_threshold`` opens a short.
                - The opposite signal exits a position once
                  ``min_hold_bars`` is reached.
                - NaN never opens a position.
            open_threshold: absolute score needed to open, or to
                reverse-exit, a position.

        Returns:
            The summary dict produced by ``_build_result``.
        """
        capital = self.initial_capital
        trades: List[Trade] = []
        daily_pnl = {}
        pending_rebates = []
        positions: List[OpenPosition] = []
        used_margin = 0.0
        current_date = None
        day_pnl = 0.0
        day_stopped = False
        open_fee = self.notional * self.fee_rate

        close_arr = df['close'].values
        high_arr = df['high'].values
        low_arr = df['low'].values
        times = df.index
        scores = score.values

        for i in range(len(df)):
            t = times[i]
            c = close_arr[i]
            h = high_arr[i]
            lo = low_arr[i]
            s = scores[i]
            today = t.date()

            # --- Day rollover: store yesterday's PnL, reset, credit due rebates ---
            if today != current_date:
                if current_date is not None:
                    daily_pnl[current_date] = day_pnl
                current_date = today
                day_pnl = 0.0
                day_stopped = False
                arrived, pending_rebates = self._settle_rebates(pending_rebates, today)
                capital += arrived

            if day_stopped:
                for pos in positions:
                    pos.hold_bars += 1
                continue

            # --- Stop loss / take profit / signal exits ---
            closed = set()
            for pi, pos in enumerate(positions):
                pos.hold_bars += 1
                should_close, exit_price = self._exit_decision(
                    pos, c, h, lo, s, open_threshold)
                if not should_close:
                    continue

                net = self._close_position(pos, exit_price, t, today, trades, pending_rebates)
                capital += net
                used_margin -= self.margin
                day_pnl += net
                closed.add(pi)

                # Check the daily drawdown after every close; on breach,
                # flatten everything at the bar close and stop for the day.
                if day_pnl < -self.max_daily_dd:
                    for pj, pos2 in enumerate(positions):
                        if pj in closed:
                            continue
                        # Positions before ``pi`` were already counted this bar.
                        if pj > pi:
                            pos2.hold_bars += 1
                        net2 = self._close_position(pos2, c, t, today, trades, pending_rebates)
                        capital += net2
                        used_margin -= self.margin
                        day_pnl += net2
                        closed.add(pj)
                    day_stopped = True
                    break

            positions = [p for k, p in enumerate(positions) if k not in closed]

            if day_stopped:
                continue

            # --- Open a new position ---
            if len(positions) >= self.max_positions or np.isnan(s):
                continue

            # Pre-trade risk check: assume every current position plus the new
            # one stops out at once (including round-trip fees).
            n_after = len(positions) + 1
            worst_total_sl = n_after * (self.notional * self.sl_pct + self.notional * self.fee_rate * 2)
            if day_pnl - worst_total_sl < -self.max_daily_dd:
                continue  # exposure too large

            if capital - used_margin < self.margin + open_fee:
                continue  # not enough free capital for margin + fee

            if s > open_threshold:
                new_dir = 1
            elif s < -open_threshold:
                new_dir = -1
            else:
                continue

            positions.append(OpenPosition(
                direction=new_dir,
                entry_price=c,
                entry_time=t,
                hold_bars=0,
            ))
            capital -= open_fee
            used_margin += self.margin
            day_pnl -= open_fee

        # Force-close anything still open at the last close. This runs before
        # the final day is recorded, so that day's PnL includes these exits.
        # No rebate is scheduled for forced closes.
        if positions:
            last_close = close_arr[-1]
            last_time = times[-1]
            for pos in positions:
                net = self._close_position(
                    pos, last_close, last_time, current_date, trades,
                    pending_rebates, credit_rebate=False)
                capital += net
                day_pnl += net

        if current_date is not None:
            daily_pnl[current_date] = day_pnl

        # Rebates still outstanding at the end of the data are credited in full.
        capital += sum(amt for _, amt in pending_rebates)

        return self._build_result(trades, daily_pnl, capital)

    def _build_result(self, trades, daily_pnl, final_capital):
        """Summarise a run into a result dict.

        ``total_pnl`` is the round-trip trade PnL plus rebates. It equals
        ``final_capital - initial_capital``.

        ``max_daily_dd`` is the lowest entry in ``daily_pnl``. Daily values
        exclude rebates.

        ``profit_factor`` is gross profit divided by gross loss. A zero gross
        loss is replaced by 1e-10 to avoid dividing by zero.
        """
        if not trades:
            return {
                'total_pnl': 0, 'final_capital': final_capital,
                'num_trades': 0, 'win_rate': 0, 'avg_pnl': 0,
                'max_daily_dd': 0, 'avg_daily_pnl': 0, 'profit_factor': 0,
                'trades': [], 'daily_pnl': daily_pnl,
                'total_fee': 0, 'total_rebate': 0,
            }

        pnls = [t.pnl for t in trades]
        wins = [p for p in pnls if p > 0]
        losses = [p for p in pnls if p <= 0]
        daily_vals = list(daily_pnl.values())
        total_fee = sum(t.fee for t in trades)
        total_rebate = sum(t.rebate for t in trades)

        gross_profit = sum(wins) if wins else 0
        # Guard against an exactly-zero gross loss (e.g. only break-even losers).
        gross_loss = abs(sum(losses)) or 1e-10

        return {
            'total_pnl': sum(pnls) + total_rebate,
            'final_capital': final_capital,
            'num_trades': len(trades),
            'win_rate': len(wins) / len(trades),
            'avg_pnl': np.mean(pnls),
            'max_daily_dd': min(daily_vals) if daily_vals else 0,
            'avg_daily_pnl': np.mean(daily_vals) if daily_vals else 0,
            'profit_factor': gross_profit / gross_loss,
            'total_fee': total_fee,
            'total_rebate': total_rebate,
            'trades': trades,
            'daily_pnl': daily_pnl,
        }