第一版策略

This commit is contained in:
ddrwode
2026-02-23 04:09:34 +08:00
parent e83d15f127
commit d504f720d5
10 changed files with 5246 additions and 0 deletions

298
strategy/backtest_engine.py Normal file
View File

@@ -0,0 +1,298 @@
"""
回测引擎 - 完整模拟手续费、返佣延迟到账、每日回撤限制、持仓时间约束
支持同时持有多单并发,严格控制每日最大回撤
"""
import datetime
import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import List, Optional
@dataclass
class Trade:
entry_time: pd.Timestamp
exit_time: Optional[pd.Timestamp] = None
direction: int = 0
entry_price: float = 0.0
exit_price: float = 0.0
pnl: float = 0.0
fee: float = 0.0
rebate: float = 0.0
holding_bars: int = 0
@dataclass
class OpenPosition:
direction: int = 0
entry_price: float = 0.0
entry_time: pd.Timestamp = None
hold_bars: int = 0
class BacktestEngine:
def __init__(
self,
initial_capital: float = 1000.0,
margin_per_trade: float = 25.0,
leverage: int = 50,
fee_rate: float = 0.0005,
rebate_ratio: float = 0.70,
max_daily_drawdown: float = 50.0,
min_hold_bars: int = 1,
stop_loss_pct: float = 0.005,
take_profit_pct: float = 0.01,
max_positions: int = 3,
):
self.initial_capital = initial_capital
self.margin = margin_per_trade
self.leverage = leverage
self.notional = margin_per_trade * leverage
self.fee_rate = fee_rate
self.rebate_ratio = rebate_ratio
self.max_daily_dd = max_daily_drawdown
self.min_hold_bars = min_hold_bars
self.sl_pct = stop_loss_pct
self.tp_pct = take_profit_pct
self.max_positions = max_positions
def _close_position(self, pos, exit_price, t, today, trades, pending_rebates):
"""平仓一个持仓,返回 net_pnl"""
qty = self.notional / pos.entry_price
if pos.direction == 1:
raw_pnl = qty * (exit_price - pos.entry_price)
else:
raw_pnl = qty * (pos.entry_price - exit_price)
close_fee = self.notional * self.fee_rate
net_pnl = raw_pnl - close_fee
total_fee = self.notional * self.fee_rate * 2
rebate = total_fee * self.rebate_ratio
rebate_date = today + datetime.timedelta(days=1)
pending_rebates.append((rebate_date, rebate))
trades.append(Trade(
entry_time=pos.entry_time, exit_time=t,
direction=pos.direction, entry_price=pos.entry_price,
exit_price=exit_price, pnl=net_pnl, fee=total_fee,
rebate=rebate, holding_bars=pos.hold_bars,
))
return net_pnl
def _worst_unrealized(self, positions, h, lo):
"""计算所有持仓在本K线内的最坏浮动亏损用 high/low"""
worst = 0.0
for pos in positions:
qty = self.notional / pos.entry_price
if pos.direction == 1:
# 多单最坏情况: 价格跌到 low
worst += qty * (lo - pos.entry_price)
else:
# 空单最坏情况: 价格涨到 high
worst += qty * (pos.entry_price - h)
return worst
def run(self, df: pd.DataFrame, score: pd.Series, open_threshold: float = 0.3) -> dict:
capital = self.initial_capital
trades: List[Trade] = []
daily_pnl = {}
pending_rebates = []
positions: List[OpenPosition] = []
used_margin = 0.0
current_date = None
day_pnl = 0.0
day_stopped = False
close_arr = df['close'].values
high_arr = df['high'].values
low_arr = df['low'].values
times = df.index
scores = score.values
for i in range(len(df)):
t = times[i]
c = close_arr[i]
h = high_arr[i]
lo = low_arr[i]
s = scores[i]
today = t.date()
# --- 日切换 ---
if today != current_date:
if current_date is not None:
daily_pnl[current_date] = day_pnl
current_date = today
day_pnl = 0.0
day_stopped = False
arrived = []
remaining = []
for rd, ra in pending_rebates:
if today >= rd:
arrived.append(ra)
else:
remaining.append((rd, ra))
if arrived:
capital += sum(arrived)
pending_rebates = remaining
if day_stopped:
for pos in positions:
pos.hold_bars += 1
continue
# --- 正常止损止盈逻辑 ---
closed_indices = []
for pi, pos in enumerate(positions):
pos.hold_bars += 1
qty = self.notional / pos.entry_price
if pos.direction == 1:
sl_price = pos.entry_price * (1 - self.sl_pct)
tp_price = pos.entry_price * (1 + self.tp_pct)
hit_sl = lo <= sl_price
hit_tp = h >= tp_price
else:
sl_price = pos.entry_price * (1 + self.sl_pct)
tp_price = pos.entry_price * (1 - self.tp_pct)
hit_sl = h >= sl_price
hit_tp = lo <= tp_price
should_close = False
exit_price = c
# 止损始终生效(不受持仓时间限制)
if hit_sl:
should_close = True
exit_price = sl_price
elif pos.hold_bars >= self.min_hold_bars:
# 止盈和信号反转需要满足最小持仓时间
if hit_tp:
should_close = True
exit_price = tp_price
elif (pos.direction == 1 and s < -open_threshold) or \
(pos.direction == -1 and s > open_threshold):
should_close = True
exit_price = c
if should_close:
net = self._close_position(pos, exit_price, t, today, trades, pending_rebates)
capital += net
used_margin -= self.margin
day_pnl += net
closed_indices.append(pi)
# 每笔平仓后立即检查日回撤
if day_pnl < -self.max_daily_dd:
# 熔断剩余持仓
for pj, pos2 in enumerate(positions):
if pj not in closed_indices:
pos2.hold_bars += 1
net2 = self._close_position(pos2, c, t, today, trades, pending_rebates)
capital += net2
used_margin -= self.margin
day_pnl += net2
closed_indices.append(pj)
day_stopped = True
break
for pi in sorted(set(closed_indices), reverse=True):
positions.pop(pi)
if day_stopped:
continue
# --- 开仓 ---
if len(positions) < self.max_positions:
if np.isnan(s):
continue
# 开仓前检查: 当前所有持仓 + 新仓同时止损的最大亏损
n_after = len(positions) + 1
worst_total_sl = n_after * (self.notional * self.sl_pct + self.notional * self.fee_rate * 2)
if day_pnl - worst_total_sl < -self.max_daily_dd:
continue # 风险敞口太大
open_fee = self.notional * self.fee_rate
if capital - used_margin < self.margin + open_fee:
continue
new_dir = 0
if s > open_threshold:
new_dir = 1
elif s < -open_threshold:
new_dir = -1
if new_dir != 0:
positions.append(OpenPosition(
direction=new_dir, entry_price=c,
entry_time=t, hold_bars=0,
))
capital -= open_fee
used_margin += self.margin
day_pnl -= open_fee
# 最后一天
if current_date is not None:
daily_pnl[current_date] = day_pnl
# 强制平仓
if positions and len(df) > 0:
last_close = close_arr[-1]
for pos in positions:
qty = self.notional / pos.entry_price
if pos.direction == 1:
raw_pnl = qty * (last_close - pos.entry_price)
else:
raw_pnl = qty * (pos.entry_price - last_close)
fee = self.notional * self.fee_rate
net_pnl = raw_pnl - fee
capital += net_pnl
trades.append(Trade(
entry_time=pos.entry_time, exit_time=times[-1],
direction=pos.direction, entry_price=pos.entry_price,
exit_price=last_close, pnl=net_pnl,
fee=self.notional * self.fee_rate * 2,
rebate=0, holding_bars=pos.hold_bars,
))
remaining_rebate = sum(amt for _, amt in pending_rebates)
capital += remaining_rebate
return self._build_result(trades, daily_pnl, capital)
def _build_result(self, trades, daily_pnl, final_capital):
if not trades:
return {
'total_pnl': 0, 'final_capital': final_capital,
'num_trades': 0, 'win_rate': 0, 'avg_pnl': 0,
'max_daily_dd': 0, 'avg_daily_pnl': 0,
'profit_factor': 0, 'trades': [], 'daily_pnl': daily_pnl,
'total_fee': 0, 'total_rebate': 0,
}
pnls = [t.pnl for t in trades]
wins = [p for p in pnls if p > 0]
losses = [p for p in pnls if p <= 0]
daily_vals = list(daily_pnl.values())
total_fee = sum(t.fee for t in trades)
total_rebate = sum(t.rebate for t in trades)
gross_profit = sum(wins) if wins else 0
gross_loss = abs(sum(losses)) if losses else 1e-10
return {
'total_pnl': sum(pnls) + total_rebate,
'final_capital': final_capital,
'num_trades': len(trades),
'win_rate': len(wins) / len(trades) if trades else 0,
'avg_pnl': np.mean(pnls),
'max_daily_dd': min(daily_vals) if daily_vals else 0,
'avg_daily_pnl': np.mean(daily_vals) if daily_vals else 0,
'profit_factor': gross_profit / gross_loss,
'total_fee': total_fee,
'total_rebate': total_rebate,
'trades': trades,
'daily_pnl': daily_pnl,
}