""" 回测引擎 — 多空双向、手续费、滑点、绩效统计 每笔固定名义 100U、100 倍杠杆;同一时间仅一个仓位;最大回撤 300U 硬约束;手续费 90% 返佣 """ import pandas as pd import numpy as np from loguru import logger from .config import TRADE_CONFIG as TC, SIGNAL_MAP class BacktestEngine: """回测引擎:固定每笔 100U 名义、100x,单仓位,最大回撤 300U 内,实付手续费(90% 返佣)""" def __init__(self, commission: float = None, slippage: float = None, initial_capital: float = None, position_size: float = None, position_notional_usd: float = None, max_drawdown_limit: float = None, commission_rebate: float = None): raw_commission = commission if commission is not None else TC['commission'] rebate = commission_rebate if commission_rebate is not None else TC.get('commission_rebate', 0) self.commission = raw_commission * (1 - rebate) # 实付手续费(90% 返佣) self.slippage = slippage or TC['slippage'] self.initial_capital = initial_capital or TC['initial_capital'] self.position_size = position_size or TC['position_size'] self.position_notional_usd = position_notional_usd if position_notional_usd is not None else TC.get('position_notional_usd', self.initial_capital * self.position_size) self.max_drawdown_limit = max_drawdown_limit if max_drawdown_limit is not None else TC.get('max_drawdown_limit', float('inf')) def run(self, prices: pd.Series, signals: pd.Series) -> dict: """ 执行回测 :param prices: 收盘价 Series :param signals: 信号 Series,值: 0=观望, 1=做多, 2=做空 :return: 回测结果字典 """ df = pd.DataFrame({'price': prices, 'signal': signals}).dropna() if df.empty: logger.warning("回测数据为空") return self._empty_result() n = len(df) capital = self.initial_capital position = 0 # 持仓数量(正=多头,负=空头) entry_price = 0.0 direction = 0 # 当前方向: 0=空仓, 1=多, 2=空 trades = [] equity_curve = np.zeros(n) peak_equity = self.initial_capital # 用于 300U 最大回撤约束 for i in range(n): price = df.iloc[i]['price'] signal = int(df.iloc[i]['signal']) # 计算当前权益 if position > 0: equity_curve[i] = capital + position * (price - entry_price) elif position < 0: equity_curve[i] = capital + position * (price - entry_price) else: equity_curve[i] = capital current_equity = equity_curve[i] peak_equity = max(peak_equity, current_equity) # 最大回撤硬约束:从峰值回落超过 300U 则不再开新仓(只允许平仓) drawdown_usd = peak_equity - current_equity can_open = drawdown_usd < self.max_drawdown_limit # 强制止损:持仓浮亏超过初始资金的20%时强制平仓 if position != 0: unrealized = position * (price - entry_price) if unrealized < -self.initial_capital * 0.20: capital, trade = self._close_position( capital, position, entry_price, price, df.index[i]) trades.append(trade) position = 0 direction = 0 continue # 资金不足时不开新仓;同一时间仅一个仓位(已有持仓则只能先平再开) min_capital = self.initial_capital * 0.05 can_trade = capital > min_capital and can_open # 跳过:信号与当前持仓方向相同 if signal == direction: continue # 每笔固定名义 100U:qty = position_notional_usd / price notional = self.position_notional_usd qty = notional / price if price > 0 else 0 # 需要换方向或平仓 if signal == 1 and direction != 1: # 先平仓 if position != 0: capital, trade = self._close_position( capital, position, entry_price, price, df.index[i]) trades.append(trade) position = 0 if not can_trade or qty <= 0: direction = 0 continue # 开多:固定 100U 名义,实付手续费 cost = notional * (self.commission + self.slippage) capital -= cost position = qty entry_price = price direction = 1 elif signal == 2 and direction != 2: # 先平仓 if position != 0: capital, trade = self._close_position( capital, position, entry_price, price, df.index[i]) trades.append(trade) position = 0 if not can_trade or qty <= 0: direction = 0 continue # 开空:固定 100U 名义 cost = notional * (self.commission + self.slippage) capital -= cost position = -qty entry_price = price direction = 2 elif signal == 0 and position != 0: # 平仓 capital, trade = self._close_position( capital, position, entry_price, price, df.index[i]) trades.append(trade) position = 0 direction = 0 # 最终平仓 if position != 0: price = df.iloc[-1]['price'] capital, trade = self._close_position( capital, position, entry_price, price, df.index[-1]) trades.append(trade) equity = pd.Series(equity_curve, index=df.index) trades_df = pd.DataFrame(trades) if trades else pd.DataFrame() metrics = self._calc_metrics(equity, trades_df) monthly_pnl = self._monthly_pnl(equity) return { 'equity_curve': equity, 'trades': trades_df, 'metrics': metrics, 'final_capital': capital, 'monthly_pnl': monthly_pnl, } def _close_position(self, capital, position, entry_price, exit_price, time): """平仓并返回更新后的capital和交易记录""" if position > 0: pnl = position * (exit_price - entry_price) cost = position * exit_price * (self.commission + self.slippage) trade_type = '平多' else: pnl = -position * (entry_price - exit_price) cost = abs(position) * exit_price * (self.commission + self.slippage) trade_type = '平空' capital += pnl - cost trade = { 'type': trade_type, 'entry': entry_price, 'exit': exit_price, 'pnl': pnl - cost, 'return_pct': (exit_price / entry_price - 1) * (1 if position > 0 else -1), 'time': time, } return capital, trade def _calc_metrics(self, equity: pd.Series, trades: pd.DataFrame) -> dict: """计算绩效指标""" if equity.empty: return self._empty_metrics() total_return = (equity.iloc[-1] / self.initial_capital) - 1 n_bars = len(equity) if n_bars > 1: # 按日聚合收益率计算夏普 daily_equity = equity.resample('D').last().dropna() if len(daily_equity) > 1: daily_returns = daily_equity.pct_change().dropna() n_days = len(daily_returns) annualized_return = (1 + total_return) ** (365 / max(n_days, 1)) - 1 if total_return > -1 else -1.0 sharpe = (daily_returns.mean() / daily_returns.std() * np.sqrt(365) if daily_returns.std() > 0 else 0) else: annualized_return = 0 sharpe = 0 else: annualized_return = 0 sharpe = 0 # 最大回撤(比例与绝对 USDT) cummax = equity.cummax() drawdown = (equity - cummax) / cummax.replace(0, np.nan) max_drawdown = drawdown.min() if not drawdown.empty else 0 max_drawdown_usd = (cummax - equity).max() if not equity.empty else 0 # 按月收益(自然月):用于目标「每月盈利 ≥ 1000U」 monthly_pnl_usd = None months_above_1000 = 0 if not equity.empty and hasattr(equity.index, 'to_period'): try: monthly = equity.resample('ME').last().dropna() if len(monthly) > 0: monthly_pnl = monthly.diff() monthly_pnl.iloc[0] = monthly.iloc[0] - self.initial_capital monthly_pnl_usd = monthly_pnl months_above_1000 = (monthly_pnl >= 1000).sum() except Exception: pass # 交易统计 n_trades = len(trades) if n_trades > 0: wins = trades[trades['pnl'] > 0] losses = trades[trades['pnl'] <= 0] win_rate = len(wins) / n_trades avg_win = wins['pnl'].mean() if len(wins) > 0 else 0 avg_loss = abs(losses['pnl'].mean()) if len(losses) > 0 else 0 profit_factor = (wins['pnl'].sum() / abs(losses['pnl'].sum()) if len(losses) > 0 and losses['pnl'].sum() != 0 else float('inf')) else: win_rate = 0 avg_win = 0 avg_loss = 0 profit_factor = 0 out = { '总收益率': f'{total_return:.2%}', '年化收益率': f'{annualized_return:.2%}', '最大回撤': f'{max_drawdown:.2%}', '最大回撤(U)': f'{max_drawdown_usd:.2f}', '夏普比率': f'{sharpe:.2f}', '总交易次数': n_trades, '胜率': f'{win_rate:.2%}', '平均盈利': f'{avg_win:.2f}', '平均亏损': f'{avg_loss:.2f}', '盈亏比': f'{profit_factor:.2f}', '最终资金': f'{equity.iloc[-1]:.2f}', } out['月盈利≥1000U的月数'] = months_above_1000 if monthly_pnl_usd is not None else 0 if monthly_pnl_usd is not None and len(monthly_pnl_usd) > 0: out['月均盈利(U)'] = f'{monthly_pnl_usd.mean():.2f}' else: out['月均盈利(U)'] = '0.00' return out def _monthly_pnl(self, equity: pd.Series): """按自然月汇总收益(USDT),首月为当月权益 - 初始资金""" if equity.empty or not hasattr(equity.index, 'to_period'): return None try: monthly = equity.resample('ME').last().dropna() if len(monthly) == 0: return None pnl = monthly.diff() pnl.iloc[0] = monthly.iloc[0] - self.initial_capital return pnl except Exception: return None def _empty_result(self): return { 'equity_curve': pd.Series(dtype=float), 'trades': pd.DataFrame(), 'metrics': self._empty_metrics(), 'final_capital': self.initial_capital, 'monthly_pnl': None, } def _empty_metrics(self): return { '总收益率': '0.00%', '年化收益率': '0.00%', '最大回撤': '0.00%', '最大回撤(U)': '0.00', '夏普比率': '0.00', '总交易次数': 0, '胜率': '0.00%', '平均盈利': '0.00', '平均亏损': '0.00', '盈亏比': '0.00', '最终资金': f'{self.initial_capital:.2f}', '月盈利≥1000U的月数': 0, '月均盈利(U)': '0.00', } def print_metrics(metrics: dict, title: str = "回测结果"): """打印绩效指标""" logger.info(f"\n{'='*50}") logger.info(f" {title}") logger.info(f"{'='*50}") for k, v in metrics.items(): logger.info(f" {k:>12}: {v}") logger.info(f"{'='*50}")