Files
jyx_code4/strategy/backtest.py

299 lines
12 KiB
Python
Raw Normal View History

2026-02-20 20:57:25 +08:00
"""
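Backtest engine: long/short trading with commission, slippage and performance statistics.
Each trade uses a fixed 100U notional at 100x leverage; only one position is open at a time;
a hard 300U maximum-drawdown limit applies; 90% of the commission is rebated.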
"""
import pandas as pd
import numpy as np
from loguru import logger
from .config import TRADE_CONFIG as TC, SIGNAL_MAP
class BacktestEngine:
    """Single-position long/short backtest engine with fixed-notional sizing.

    Every entry uses the same notional (``position_notional_usd``), at most one
    position is open at a time, new entries are blocked while the equity
    drawdown from its peak is at or above ``max_drawdown_limit``, and the
    commission actually charged is the raw commission reduced by the rebate.
    Any argument left as ``None`` is read from ``TRADE_CONFIG``.
    """

    # A floating loss larger than this fraction of the initial capital forces a close.
    _STOP_LOSS_FRACTION = 0.20
    # No new position is opened once capital is at or below this fraction of the initial capital.
    _MIN_CAPITAL_FRACTION = 0.05

    def __init__(self, commission: float = None, slippage: float = None,
                 initial_capital: float = None, position_size: float = None,
                 position_notional_usd: float = None, max_drawdown_limit: float = None,
                 commission_rebate: float = None):
        """Configure the engine.

        :param commission: raw commission rate per side (before rebate)
        :param slippage: slippage rate per side; an explicit ``0`` is honoured
        :param initial_capital: starting capital
        :param position_size: fraction of capital used to derive the default notional
        :param position_notional_usd: fixed notional per trade
        :param max_drawdown_limit: absolute drawdown at which new entries stop
        :param commission_rebate: fraction of the commission that is rebated
        """
        raw_commission = commission if commission is not None else TC['commission']
        rebate = commission_rebate if commission_rebate is not None else TC.get('commission_rebate', 0)
        self.commission = raw_commission * (1 - rebate)  # commission actually paid after rebate
        # ``is not None`` (not ``or``): zero slippage is a legitimate setting.
        self.slippage = slippage if slippage is not None else TC['slippage']
        # Zero capital / zero size are not meaningful, so a falsy value still falls back to config.
        self.initial_capital = initial_capital or TC['initial_capital']
        self.position_size = position_size or TC['position_size']
        if position_notional_usd is not None:
            self.position_notional_usd = position_notional_usd
        else:
            self.position_notional_usd = TC.get(
                'position_notional_usd', self.initial_capital * self.position_size)
        if max_drawdown_limit is not None:
            self.max_drawdown_limit = max_drawdown_limit
        else:
            self.max_drawdown_limit = TC.get('max_drawdown_limit', float('inf'))

    def run(self, prices: pd.Series, signals: pd.Series) -> dict:
        """Run the backtest.

        :param prices: close-price Series
        :param signals: signal Series: 0=flat, 1=long, 2=short
        :return: dict with ``equity_curve``, ``trades``, ``metrics``,
                 ``final_capital`` and ``monthly_pnl``
        """
        df = pd.DataFrame({'price': prices, 'signal': signals}).dropna()
        if df.empty:
            logger.warning("回测数据为空")
            return self._empty_result()

        # Plain arrays avoid building a throwaway Series per bar (df.iloc[i]).
        price_values = df['price'].to_numpy()
        signal_values = df['signal'].to_numpy()
        timestamps = df.index

        fee_rate = self.commission + self.slippage
        notional = self.position_notional_usd
        stop_loss_floor = -self.initial_capital * self._STOP_LOSS_FRACTION
        min_capital = self.initial_capital * self._MIN_CAPITAL_FRACTION

        capital = self.initial_capital
        position = 0  # signed quantity: positive = long, negative = short
        entry_price = 0.0
        direction = 0  # 0 = flat, 1 = long, 2 = short
        trades = []
        equity_curve = np.zeros(len(df))
        peak_equity = self.initial_capital  # reference point for the drawdown limit

        for i, (price, raw_signal) in enumerate(zip(price_values, signal_values)):
            signal = int(raw_signal)
            when = timestamps[i]

            # Mark-to-market equity at the start of the bar (before any trade on it).
            if position != 0:
                equity_curve[i] = capital + position * (price - entry_price)
            else:
                equity_curve[i] = capital
            current_equity = equity_curve[i]
            peak_equity = max(peak_equity, current_equity)

            # Hard drawdown limit: once breached only closing is allowed.
            can_open = (peak_equity - current_equity) < self.max_drawdown_limit

            # Forced stop-loss on an oversized floating loss.
            if position != 0 and position * (price - entry_price) < stop_loss_floor:
                capital, trade = self._close_position(
                    capital, position, entry_price, price, when)
                trades.append(trade)
                position = 0
                direction = 0
                continue

            # Evaluated before a reversal's close, so the close result does not affect it.
            can_trade = capital > min_capital and can_open

            if signal == direction:
                continue  # already positioned as requested

            if signal in (1, 2):
                # Fresh entry or reversal: flatten first, then open the new side.
                if position != 0:
                    capital, trade = self._close_position(
                        capital, position, entry_price, price, when)
                    trades.append(trade)
                    position = 0
                qty = notional / price if price > 0 else 0
                if not can_trade or qty <= 0:
                    direction = 0
                    continue
                capital -= notional * fee_rate
                position = qty if signal == 1 else -qty
                entry_price = price
                direction = signal
            elif signal == 0 and position != 0:
                capital, trade = self._close_position(
                    capital, position, entry_price, price, when)
                trades.append(trade)
                position = 0
                direction = 0

        # Liquidate whatever is still open at the last bar.
        if position != 0:
            capital, trade = self._close_position(
                capital, position, entry_price, price_values[-1], timestamps[-1])
            trades.append(trade)
        # The last point was recorded before the last bar's fees and the final
        # liquidation; sync it so the reported final equity equals final_capital.
        equity_curve[-1] = capital

        equity = pd.Series(equity_curve, index=df.index)
        trades_df = pd.DataFrame(trades) if trades else pd.DataFrame()
        return {
            'equity_curve': equity,
            'trades': trades_df,
            'metrics': self._calc_metrics(equity, trades_df),
            'final_capital': capital,
            'monthly_pnl': self._monthly_pnl(equity),
        }

    def _close_position(self, capital, position, entry_price, exit_price, time):
        """Close ``position`` at ``exit_price``; return ``(new_capital, trade_record)``."""
        is_long = position > 0
        # Signed quantity makes one formula cover both sides.
        gross_pnl = position * (exit_price - entry_price)
        cost = abs(position) * exit_price * (self.commission + self.slippage)
        net_pnl = gross_pnl - cost
        trade = {
            'type': '平多' if is_long else '平空',
            'entry': entry_price,
            'exit': exit_price,
            'pnl': net_pnl,
            'return_pct': (exit_price / entry_price - 1) * (1 if is_long else -1),
            'time': time,
        }
        return capital + net_pnl, trade

    def _calc_metrics(self, equity: pd.Series, trades: pd.DataFrame) -> dict:
        """Build the performance-metric dict from an equity curve and a trade table."""
        if equity.empty:
            return self._empty_metrics()

        total_return = (equity.iloc[-1] / self.initial_capital) - 1

        # Annualized return and Sharpe ratio from daily-aggregated equity.
        annualized_return = 0
        sharpe = 0
        if len(equity) > 1:
            try:
                daily_equity = equity.resample('D').last().dropna()
            except TypeError:
                # Index is not datetime-like: daily statistics are unavailable.
                daily_equity = None
            if daily_equity is not None and len(daily_equity) > 1:
                daily_returns = daily_equity.pct_change().dropna()
                n_days = len(daily_returns)
                if total_return > -1:
                    annualized_return = (1 + total_return) ** (365 / max(n_days, 1)) - 1
                else:
                    annualized_return = -1.0
                returns_std = daily_returns.std()
                if returns_std > 0:
                    sharpe = daily_returns.mean() / returns_std * np.sqrt(365)

        # Maximum drawdown, as a ratio of the running peak and in absolute terms.
        running_peak = equity.cummax()
        max_drawdown = ((equity - running_peak) / running_peak.replace(0, np.nan)).min()
        max_drawdown_usd = (running_peak - equity).max()

        # Calendar-month PnL, used for the "monthly profit >= 1000U" target.
        monthly_pnl_usd = self._monthly_pnl(equity)
        if monthly_pnl_usd is not None:
            months_above_1000 = int((monthly_pnl_usd >= 1000).sum())
        else:
            months_above_1000 = 0

        # Trade statistics.
        n_trades = len(trades)
        if n_trades > 0:
            wins = trades[trades['pnl'] > 0]
            losses = trades[trades['pnl'] <= 0]
            win_rate = len(wins) / n_trades
            avg_win = wins['pnl'].mean() if len(wins) > 0 else 0
            avg_loss = abs(losses['pnl'].mean()) if len(losses) > 0 else 0
            loss_total = losses['pnl'].sum()
            if len(losses) > 0 and loss_total != 0:
                profit_factor = wins['pnl'].sum() / abs(loss_total)
            else:
                profit_factor = float('inf')
        else:
            win_rate = 0
            avg_win = 0
            avg_loss = 0
            profit_factor = 0

        out = {
            '总收益率': f'{total_return:.2%}',
            '年化收益率': f'{annualized_return:.2%}',
            '最大回撤': f'{max_drawdown:.2%}',
            '最大回撤(U)': f'{max_drawdown_usd:.2f}',
            '夏普比率': f'{sharpe:.2f}',
            '总交易次数': n_trades,
            '胜率': f'{win_rate:.2%}',
            '平均盈利': f'{avg_win:.2f}',
            '平均亏损': f'{avg_loss:.2f}',
            '盈亏比': f'{profit_factor:.2f}',
            '最终资金': f'{equity.iloc[-1]:.2f}',
        }
        out['月盈利≥1000U的月数'] = months_above_1000
        if monthly_pnl_usd is not None and len(monthly_pnl_usd) > 0:
            out['月均盈利(U)'] = f'{monthly_pnl_usd.mean():.2f}'
        else:
            out['月均盈利(U)'] = '0.00'
        return out

    @staticmethod
    def _month_end_equity(equity: pd.Series) -> pd.Series:
        """Return the last equity value of each calendar month."""
        try:
            return equity.resample('ME').last().dropna()
        except ValueError:
            # pandas < 2.2 only knows the month-end alias as 'M'.
            return equity.resample('M').last().dropna()

    def _monthly_pnl(self, equity: pd.Series):
        """Return PnL per calendar month, or ``None`` when it cannot be computed.

        The first month is measured against the initial capital.
        """
        if equity.empty or not hasattr(equity.index, 'to_period'):
            return None
        try:
            monthly = self._month_end_equity(equity)
        except (TypeError, ValueError) as exc:
            # Best-effort statistic: report the failure instead of hiding it.
            logger.warning("月度收益统计失败: {}", exc)
            return None
        if len(monthly) == 0:
            return None
        pnl = monthly.diff()
        pnl.iloc[0] = monthly.iloc[0] - self.initial_capital
        return pnl

    def _empty_result(self):
        """Result returned by ``run`` when there is no usable input data."""
        return {
            'equity_curve': pd.Series(dtype=float),
            'trades': pd.DataFrame(),
            'metrics': self._empty_metrics(),
            'final_capital': self.initial_capital,
            'monthly_pnl': None,
        }

    def _empty_metrics(self):
        """Metric dict with every statistic zeroed and capital at its initial value."""
        return {
            '总收益率': '0.00%', '年化收益率': '0.00%', '最大回撤': '0.00%',
            '最大回撤(U)': '0.00', '夏普比率': '0.00', '总交易次数': 0, '胜率': '0.00%',
            '平均盈利': '0.00', '平均亏损': '0.00', '盈亏比': '0.00',
            '最终资金': f'{self.initial_capital:.2f}', '月盈利≥1000U的月数': 0,
            '月均盈利(U)': '0.00',
        }
def print_metrics(metrics: dict, title: str = "回测结果"):
    """Log the title and every metric, framed by separator rules."""
    rule = '=' * 50
    logger.info("\n" + rule)
    logger.info(f" {title}")
    logger.info(rule)
    for name, value in metrics.items():
        # Right-align the metric name in a 12-character column.
        logger.info(f" {name:>12}: {value}")
    logger.info(rule)