"""
BitMart rebate-strategy backtest: side-by-side comparison of two strategies.

Strategy A: Grid Trading (with trend filter)
    - Geometric grid lines spaced ``grid_pct`` apart; a position is opened
      when a bar touches the grid line next to the previous close
    - Fixed take-profit (``tp_grids`` grids) and stop-loss (``sl_grids`` grids)
    - Trend filter: only trade in the direction of a long EMA

Strategy B: EMA Trend Following
    - Long on a fast/slow EMA golden cross, short on a death cross
    - On an opposite cross the position is closed and, if filters allow,
      reversed
    - A higher-timeframe EMA and an ATR floor filter new entries

Both strategies:
    - Hold every position for at least ``min_hold_sec`` (200 s by default)
      before any non-emergency exit; only the hard stop-loss and the
      end-of-backtest close may exit earlier
    - Credit a fee rebate of ``rebate_rate`` (90% by default) on every trade
    - Print a detailed report per strategy plus a comparison summary
"""
import time
import datetime
import math
import statistics
import sqlite3
from collections import deque
from contextlib import closing
from pathlib import Path
from dataclasses import dataclass
from typing import List


# ========================= Minimal logger =========================
class _L:
    """Tiny stdout logger that prefixes each message with its level."""

    @staticmethod
    def info(m):
        print(f"[INFO] {m}")

    @staticmethod
    def ok(m):
        print(f"[ OK ] {m}")

    @staticmethod
    def warn(m):
        print(f"[WARN] {m}")

    @staticmethod
    def err(m):
        print(f"[ERR ] {m}")


log = _L()


# ========================= Trade record =========================
@dataclass
class Trade:
    """One completed round-trip trade."""
    open_time: datetime.datetime
    close_time: datetime.datetime
    direction: str       # 'long' or 'short'
    open_price: float
    close_price: float
    size: float          # position notional (balance * risk_pct * leverage)
    pnl: float           # directional P&L, before fees and rebate
    pnl_pct: float       # favourable price move as a fraction of open_price
    fee: float           # opening + closing taker fee
    rebate: float        # fee * rebate_rate, credited back to the balance
    hold_seconds: float
    close_reason: str


# ========================= Data loading =========================
def load_1m_klines(start_date='2025-01-01', end_date='2025-12-31', db_path=None):
    """Load 1-minute ETH klines from the local SQLite database.

    Args:
        start_date: first day to load, ``YYYY-MM-DD`` (inclusive).
        end_date: last day to load, ``YYYY-MM-DD`` (inclusive).
        db_path: optional path to the SQLite file. Defaults to
            ``models/database.db`` one directory above this script.

    Returns:
        A list of dicts with keys ``datetime``, ``open``, ``high``, ``low``,
        ``close``, ordered by bar id ascending.
    """
    if db_path is None:
        db_path = Path(__file__).parent.parent / 'models' / 'database.db'
    start_dt = datetime.datetime.strptime(start_date, '%Y-%m-%d')
    # +1 day because the query uses ``id < end_ms``, making end_date inclusive
    end_dt = datetime.datetime.strptime(end_date, '%Y-%m-%d') + datetime.timedelta(days=1)
    # Bar ids are millisecond timestamps; dates are interpreted in local time
    start_ms = int(start_dt.timestamp()) * 1000
    end_ms = int(end_dt.timestamp()) * 1000

    # closing() releases the connection even if the query raises
    with closing(sqlite3.connect(str(db_path))) as conn:
        rows = conn.execute(
            "SELECT id, open, high, low, close FROM bitmart_eth_1m "
            "WHERE id >= ? AND id < ? ORDER BY id",
            (start_ms, end_ms),
        ).fetchall()

    data = []
    for r in rows:
        dt = datetime.datetime.fromtimestamp(r[0] / 1000.0)
        data.append({
            'datetime': dt,
            'open': r[1],
            'high': r[2],
            'low': r[3],
            'close': r[4],
        })
    log.info(f"Loaded {len(data)} bars ({start_date} ~ {end_date})")
    return data


# ========================= EMA helper =========================
class EMA:
    """Incremental exponential moving average seeded with the first price."""

    def __init__(self, period):
        self.period = period
        self.k = 2.0 / (period + 1)
        self.value = None

    def update(self, price):
        """Feed one price and return the updated EMA value."""
        if self.value is None:
            self.value = price
        else:
            self.value = price * self.k + self.value * (1 - self.k)
        return self.value


# ========================= Base backtest engine =========================
class BaseBacktest:
    """Single-position backtest engine with taker fees and a fee rebate.

    Position notional is ``balance * risk_pct * leverage`` at open time.
    The opening fee is charged at open. The closing fee, directional P&L and
    rebate (``rebate_rate`` of both fees) are settled at close.
    """

    def __init__(self, name, initial_balance=1000.0, leverage=50,
                 risk_pct=0.005, taker_fee=0.0006, rebate_rate=0.90,
                 min_hold_sec=200, max_hold_sec=1800):
        self.name = name
        self.initial_balance = initial_balance
        self.balance = initial_balance
        self.leverage = leverage
        self.risk_pct = risk_pct
        self.taker_fee = taker_fee
        self.rebate_rate = rebate_rate
        self.min_hold_sec = min_hold_sec
        self.max_hold_sec = max_hold_sec

        self.position = 0        # 1 = long, -1 = short, 0 = flat
        self.open_price = 0.0
        self.open_time = None
        self.pos_size = 0.0      # notional of the open position

        self.trades: List[Trade] = []
        self.equity_curve = []
        self.peak_equity = initial_balance
        self.max_dd_pct = 0.0

    def _open(self, direction, price, dt):
        """Open a ``'long'``/``'short'`` position at ``price``.

        Returns False, and stays flat, if the notional would be below 1.
        """
        size = self.balance * self.risk_pct * self.leverage
        if size < 1:
            return False
        self.pos_size = size
        self.balance -= size * self.taker_fee
        self.position = 1 if direction == 'long' else -1
        self.open_price = price
        self.open_time = dt
        return True

    def _close(self, price, dt, reason):
        """Close the open position at ``price`` and record a Trade.

        Returns the Trade, or None if there was no open position.
        """
        if self.position == 0:
            return None

        pnl_pct = self.cur_pnl_pct(price)
        pnl = self.pos_size * pnl_pct
        # Closing notional = quantity * close price. This is the same formula
        # for longs and shorts; (1 + pnl_pct) would be wrong for shorts.
        close_val = self.pos_size * price / self.open_price
        close_fee = close_val * self.taker_fee
        open_fee = self.pos_size * self.taker_fee
        total_fee = open_fee + close_fee
        rebate = total_fee * self.rebate_rate
        # The opening fee was already deducted in _open
        self.balance += pnl - close_fee + rebate

        hold_sec = (dt - self.open_time).total_seconds()
        trade = Trade(
            open_time=self.open_time, close_time=dt,
            direction='long' if self.position == 1 else 'short',
            open_price=self.open_price, close_price=price,
            size=self.pos_size, pnl=pnl, pnl_pct=pnl_pct,
            fee=total_fee, rebate=rebate,
            hold_seconds=hold_sec, close_reason=reason,
        )
        self.trades.append(trade)

        self.position = 0
        self.open_price = 0.0
        self.open_time = None
        self.pos_size = 0.0
        return trade

    def hold_seconds(self, dt):
        """Seconds the current position has been open (0 when flat)."""
        if self.open_time is None:
            return 0
        return (dt - self.open_time).total_seconds()

    def can_close(self, dt):
        """True once the minimum holding time has elapsed."""
        return self.hold_seconds(dt) >= self.min_hold_sec

    def cur_pnl_pct(self, price):
        """Unrealised favourable move as a fraction of open_price (0 when flat)."""
        if self.position == 1:
            return (price - self.open_price) / self.open_price
        elif self.position == -1:
            return (self.open_price - price) / self.open_price
        return 0

    def track_equity(self, dt, price, every_n=60, bar_idx=0):
        """Sample equity every ``every_n`` bars and update max drawdown.

        Drawdown is measured only at sampled points, so it can understate
        the true intra-sample drawdown.
        """
        if bar_idx % every_n != 0:
            return
        eq = self.balance
        if self.position != 0 and self.open_price > 0:
            eq += self.pos_size * self.cur_pnl_pct(price)
        self.equity_curve.append({'datetime': dt, 'equity': eq})
        if eq > self.peak_equity:
            self.peak_equity = eq
        dd = (self.peak_equity - eq) / self.peak_equity if self.peak_equity > 0 else 0
        if dd > self.max_dd_pct:
            self.max_dd_pct = dd


# ========================= Strategy A: grid trading =========================
class GridStrategy(BaseBacktest):
    """Grid trading with a trend filter.

    - Trend direction: close vs EMA(trend_ema_period).
    - Grid lines are geometric: consecutive lines differ by (1 + grid_pct).
    - Long: close above the EMA and the bar's low touches the grid line below
      the previous close. Short: close below the EMA and the bar's high
      touches the grid line above the previous close. Entries fill at the
      bar close.
    - TP = tp_grids * grid_pct and SL = sl_grids * grid_pct, both checked only
      after min_hold_sec. Timeout exits are also checked only then.
    - Hard SL = (sl_grids + 1) * grid_pct, checked on every bar.
    - 120 s cooldown after any stop-loss.
    """

    def __init__(self, grid_pct=0.0020, tp_grids=1, sl_grids=3,
                 trend_ema_period=120, **kwargs):
        super().__init__(name="Grid+Trend", **kwargs)
        self.grid_pct = grid_pct
        self.tp_pct = grid_pct * tp_grids
        self.sl_pct = grid_pct * sl_grids
        self.hard_sl_pct = grid_pct * (sl_grids + 1)
        self.trend_ema = EMA(trend_ema_period)
        self.last_grid_cross = None    # currently unused
        self.cooldown_until = None     # no trading before this datetime

    def get_grid_level(self, price, direction='below'):
        """Return the fixed grid line at/below (``'below'``) or above ``price``.

        Lines sit at ``(1 + grid_pct) ** k``, so they do not move with price.
        Returns ``price`` unchanged when the price or grid_pct is not positive.
        """
        if price <= 0 or self.grid_pct <= 0:
            return price
        ratio = 1.0 + self.grid_pct
        k = math.floor(math.log(price) / math.log(ratio))
        # Correct for floating-point error when price sits (almost) on a line
        while ratio ** (k + 1) <= price:
            k += 1
        while ratio ** k > price:
            k -= 1
        if direction == 'below':
            return ratio ** k
        return ratio ** (k + 1)

    def run(self, data):
        """Run the grid strategy over ``data`` bars and return the trade list."""
        log.info(f"[{self.name}] Starting... grid={self.grid_pct*100:.2f}% TP={self.tp_pct*100:.2f}% SL={self.sl_pct*100:.2f}%")
        t0 = time.time()
        # Progress roughly every 10%; 0 disables it for very short inputs
        progress_step = len(data) // 10
        prev_close = None

        for i, bar in enumerate(data):
            price = bar['close']
            high = bar['high']
            low = bar['low']
            dt = bar['datetime']
            ema_val = self.trend_ema.update(price)

            # Cooldown after a stop-loss
            if self.cooldown_until and dt < self.cooldown_until:
                self.track_equity(dt, price, bar_idx=i)
                prev_close = price
                continue
            self.cooldown_until = None

            # === In a position: check exits ===
            if self.position != 0:
                p = self.cur_pnl_pct(price)
                hs = self.hold_seconds(dt)

                # Hard stop-loss ignores the minimum holding time
                if -p >= self.hard_sl_pct:
                    self._close(price, dt, f"hard_SL ({p*100:+.3f}%)")
                    self.cooldown_until = dt + datetime.timedelta(seconds=120)
                    self.track_equity(dt, price, bar_idx=i)
                    prev_close = price
                    continue

                if self.can_close(dt):
                    if p >= self.tp_pct:
                        self._close(price, dt, f"TP ({p*100:+.3f}%)")
                        self.track_equity(dt, price, bar_idx=i)
                        prev_close = price
                        continue

                    if -p >= self.sl_pct:
                        self._close(price, dt, f"SL ({p*100:+.3f}%)")
                        self.cooldown_until = dt + datetime.timedelta(seconds=120)
                        self.track_equity(dt, price, bar_idx=i)
                        prev_close = price
                        continue

                    if hs >= self.max_hold_sec:
                        self._close(price, dt, f"timeout ({hs:.0f}s)")
                        self.track_equity(dt, price, bar_idx=i)
                        prev_close = price
                        continue

            # === Flat: check entries ===
            if self.position == 0 and prev_close is not None:
                grid_below = self.get_grid_level(prev_close, 'below')
                grid_above = self.get_grid_level(prev_close, 'above')

                # Uptrend: pullback to the grid line below -> long
                if price > ema_val and low <= grid_below and prev_close > grid_below:
                    self._open('long', price, dt)
                # Downtrend: bounce to the grid line above -> short
                elif price < ema_val and high >= grid_above and prev_close < grid_above:
                    self._open('short', price, dt)

            self.track_equity(dt, price, bar_idx=i)
            prev_close = price

            if progress_step and i > 0 and i % progress_step == 0:
                log.info(f"  [{self.name}] {i/len(data)*100:.0f}% | bal={self.balance:.2f} | trades={len(self.trades)}")

        # Force-close anything still open at the end of the data
        if self.position != 0:
            self._close(data[-1]['close'], data[-1]['datetime'], "backtest_end")

        log.ok(f"[{self.name}] Done in {time.time()-t0:.1f}s | {len(self.trades)} trades")
        return self.trades


# ========================= Strategy B: EMA trend following =========================
class EMATrendStrategy(BaseBacktest):
    """EMA crossover trend following.

    - Signals: EMA(fast_period) crossing EMA(slow_period).
    - Entries: golden cross with close above EMA(big_period) -> long; death
      cross with close below it -> short. ATR/close must be >= atr_min_pct.
    - Exits: hard SL (hard_sl_pct) on every bar. After min_hold_sec: SL
      (stop_loss_pct), timeout, or opposite cross (reversing if filters pass).
    - An opposite cross during the minimum hold is stored as a pending signal
      and executed once the hold time is satisfied.
    """

    def __init__(self, fast_period=8, slow_period=21, big_period=120,
                 atr_period=14, atr_min_pct=0.0003,
                 stop_loss_pct=0.004, hard_sl_pct=0.006, **kwargs):
        super().__init__(name="EMA-Trend", **kwargs)
        self.ema_fast = EMA(fast_period)
        self.ema_slow = EMA(slow_period)
        self.ema_big = EMA(big_period)
        self.atr_period = atr_period
        self.atr_min_pct = atr_min_pct   # minimum volatility to open a trade
        self.stop_loss_pct = stop_loss_pct
        self.hard_sl_pct = hard_sl_pct
        # Only the last atr_period + 1 bars are needed for the ATR
        history = max(atr_period + 1, 1)
        self.highs = deque(maxlen=history)
        self.lows = deque(maxlen=history)
        self.closes = deque(maxlen=history)
        self.prev_fast = None
        self.prev_slow = None
        self.pending_signal = None   # 'close_long' / 'close_short' awaiting min hold

    def _close(self, price, dt, reason):
        """Close via the base engine and drop any pending delayed signal.

        A pending signal belongs to the position it was raised on. Without
        this reset, a signal left over from a stopped-out position could
        close the next, unrelated position as "delayed_cross".
        """
        self.pending_signal = None
        return super()._close(price, dt, reason)

    def calc_atr(self):
        """Simple average true range over the last atr_period bars, or None."""
        if self.atr_period <= 0 or len(self.highs) < self.atr_period + 1:
            return None
        trs = []
        for i in range(-self.atr_period, 0):
            h = self.highs[i]
            lo = self.lows[i]
            pc = self.closes[i - 1]
            trs.append(max(h - lo, abs(h - pc), abs(lo - pc)))
        return sum(trs) / len(trs)

    def run(self, data):
        """Run the EMA strategy over ``data`` bars and return the trade list."""
        log.info(
            f"[{self.name}] Starting... fast=EMA{self.ema_fast.period} "
            f"slow=EMA{self.ema_slow.period} big=EMA{self.ema_big.period}"
        )
        t0 = time.time()
        stop_loss_pct = self.stop_loss_pct
        hard_sl_pct = self.hard_sl_pct
        # Progress roughly every 10%; 0 disables it for very short inputs
        progress_step = len(data) // 10

        for i, bar in enumerate(data):
            price = bar['close']
            dt = bar['datetime']
            self.highs.append(bar['high'])
            self.lows.append(bar['low'])
            self.closes.append(price)

            fast = self.ema_fast.update(price)
            slow = self.ema_slow.update(price)
            big = self.ema_big.update(price)

            # ATR volatility filter (applies to new entries only)
            atr = self.calc_atr()
            atr_pct = atr / price if atr is not None and price > 0 else 0
            vol_ok = atr_pct >= self.atr_min_pct

            # Cross detection against the previous bar's EMAs
            cross_up = (self.prev_fast is not None and
                        self.prev_fast <= self.prev_slow and fast > slow)
            cross_down = (self.prev_fast is not None and
                          self.prev_fast >= self.prev_slow and fast < slow)
            self.prev_fast = fast
            self.prev_slow = slow

            # === In a position ===
            if self.position != 0:
                p = self.cur_pnl_pct(price)

                # Hard stop-loss ignores the minimum holding time
                if -p >= hard_sl_pct:
                    self._close(price, dt, f"hard_SL ({p*100:+.3f}%)")
                    self.track_equity(dt, price, bar_idx=i)
                    continue

                if self.can_close(dt):
                    if -p >= stop_loss_pct:
                        self._close(price, dt, f"SL ({p*100:+.3f}%)")
                        self.track_equity(dt, price, bar_idx=i)
                        continue

                    hs = self.hold_seconds(dt)
                    if hs >= self.max_hold_sec:
                        self._close(price, dt, f"timeout ({hs:.0f}s)")
                        self.track_equity(dt, price, bar_idx=i)
                        continue

                    # Long meets a death cross -> close, maybe reverse short
                    if self.position == 1 and cross_down:
                        self._close(price, dt, "cross_reverse")
                        if price < big and vol_ok:
                            self._open('short', price, dt)
                        self.track_equity(dt, price, bar_idx=i)
                        continue

                    # Short meets a golden cross -> close, maybe reverse long
                    if self.position == -1 and cross_up:
                        self._close(price, dt, "cross_reverse")
                        if price > big and vol_ok:
                            self._open('long', price, dt)
                        self.track_equity(dt, price, bar_idx=i)
                        continue
                else:
                    # Minimum hold not reached: remember the signal for later
                    if self.position == 1 and cross_down:
                        self.pending_signal = 'close_long'
                    elif self.position == -1 and cross_up:
                        self.pending_signal = 'close_short'

                # Execute a pending signal once the hold time is satisfied.
                # Re-entry uses the current EMA ordering, not a fresh cross.
                if self.pending_signal and self.can_close(dt):
                    if self.pending_signal == 'close_long' and self.position == 1:
                        self._close(price, dt, "delayed_cross")
                        if fast < slow and price < big and vol_ok:
                            self._open('short', price, dt)
                    elif self.pending_signal == 'close_short' and self.position == -1:
                        self._close(price, dt, "delayed_cross")
                        if fast > slow and price > big and vol_ok:
                            self._open('long', price, dt)
                    self.pending_signal = None

            # === Flat: check entries ===
            if self.position == 0 and vol_ok:
                if cross_up and price > big:
                    self._open('long', price, dt)
                elif cross_down and price < big:
                    self._open('short', price, dt)

            self.track_equity(dt, price, bar_idx=i)

            if progress_step and i > 0 and i % progress_step == 0:
                log.info(f"  [{self.name}] {i/len(data)*100:.0f}% | bal={self.balance:.2f} | trades={len(self.trades)}")

        # Force-close anything still open at the end of the data
        if self.position != 0:
            self._close(data[-1]['close'], data[-1]['datetime'], "backtest_end")

        log.ok(f"[{self.name}] Done in {time.time()-t0:.1f}s | {len(self.trades)} trades")
        return self.trades


# ========================= Reporting =========================
def print_report(strategy: BaseBacktest):
    """Print a detailed report for a finished backtest and save its trades.

    Trades are written to ``<strategy.name>_trades.csv`` one directory above
    this script. Prints a short notice and returns if there are no trades.
    """
    trades = strategy.trades
    if not trades:
        print(f"\n[{strategy.name}] No trades.")
        return

    n = len(trades)
    wins = [t for t in trades if t.pnl > 0]
    losses = [t for t in trades if t.pnl <= 0]   # break-even counts as a loss
    wr = len(wins) / n * 100

    total_pnl = sum(t.pnl for t in trades)
    total_fee = sum(t.fee for t in trades)
    total_rebate = sum(t.rebate for t in trades)
    net = strategy.balance - strategy.initial_balance
    # Open + close legs, approximating the close leg by the open notional
    total_vol = sum(t.size for t in trades) * 2

    avg_pnl = total_pnl / n
    avg_win = statistics.mean([t.pnl for t in wins]) if wins else 0
    avg_loss = statistics.mean([t.pnl for t in losses]) if losses else 0
    avg_hold = statistics.mean([t.hold_seconds for t in trades])

    pf_num = sum(t.pnl for t in wins)
    pf_den = abs(sum(t.pnl for t in losses))
    pf = pf_num / pf_den if pf_den > 0 else float('inf')

    # Longest run of consecutive non-winning trades
    max_streak = 0
    cur = 0
    for t in trades:
        cur = cur + 1 if t.pnl <= 0 else 0
        max_streak = max(max_streak, cur)

    long_t = [t for t in trades if t.direction == 'long']
    short_t = [t for t in trades if t.direction == 'short']
    long_wr = len([t for t in long_t if t.pnl > 0]) / len(long_t) * 100 if long_t else 0
    short_wr = len([t for t in short_t if t.pnl > 0]) / len(short_t) * 100 if short_t else 0

    # Close reasons, with the "(+x.xxx%)" detail stripped
    reasons = {}
    for t in trades:
        r = t.close_reason.split(' (')[0]
        reasons[r] = reasons.get(r, 0) + 1

    under_3m = len([t for t in trades if t.hold_seconds < 180])

    def row(label, value):
        # Two-space indent, 16-char label column
        print(f"  {label:<16}{value}")

    w = 65
    print(f"\n{'='*w}")
    print(f"  [{strategy.name}] Backtest Report")
    print(f"{'='*w}")

    print(f"\n--- Account ---")
    row("Initial:", f"{strategy.initial_balance:>12.2f} USDT")
    row("Final:", f"{strategy.balance:>12.2f} USDT")
    row("Net P&L:", f"{net:>+12.2f} USDT ({net/strategy.initial_balance*100:+.2f}%)")
    row("Max Drawdown:", f"{strategy.max_dd_pct*100:>11.2f}%")

    print(f"\n--- Trades ---")
    row("Total:", f"{n:>8}")
    row("Wins:", f"{len(wins):>8} ({wr:.1f}%)")
    row("Losses:", f"{len(losses):>8} ({100-wr:.1f}%)")
    row("Long:", f"{len(long_t):>8} (WR {long_wr:.1f}%)")
    row("Short:", f"{len(short_t):>8} (WR {short_wr:.1f}%)")
    row("Profit Factor:", f"{pf:>8.2f}")
    row("Max Loss Streak:", f"{max_streak:>8}")

    print(f"\n--- P&L ---")
    row("Direction P&L:", f"{total_pnl:>+12.4f} USDT")
    row("Avg per trade:", f"{avg_pnl:>+12.4f} USDT")
    row("Avg win:", f"{avg_win:>+12.4f} USDT")
    row("Avg loss:", f"{avg_loss:>+12.4f} USDT")
    row("Best trade:", f"{max(t.pnl for t in trades):>+12.4f} USDT")
    row("Worst trade:", f"{min(t.pnl for t in trades):>+12.4f} USDT")

    print(f"\n--- Fees & Rebate ---")
    row("Volume:", f"{total_vol:>12.2f} USDT")
    row("Total Fees:", f"{total_fee:>12.4f} USDT")
    row(f"Rebate ({strategy.rebate_rate*100:.0f}%):", f"{total_rebate:>+12.4f} USDT")
    row("Net Fee Cost:", f"{total_fee - total_rebate:>12.4f} USDT")

    print(f"\n--- Hold Time ---")
    row("Average:", f"{avg_hold:>8.0f}s ({avg_hold/60:.1f}min)")
    row("Shortest:", f"{min(t.hold_seconds for t in trades):>8.0f}s")
    row("Longest:", f"{max(t.hold_seconds for t in trades):>8.0f}s")
    row("Under 3min:", f"{under_3m:>8} (expected: hard SL / backtest_end only)")

    print(f"\n--- Close Reasons ---")
    for r, c in sorted(reasons.items(), key=lambda x: -x[1]):
        print(f"  {r:<22} {c:>6} ({c/n*100:.1f}%)")

    # Monthly breakdown, grouped by close time
    print(f"\n--- Monthly ---")
    print(f"  {'Month':<10} {'Trades':>6} {'Dir PnL':>10} {'Rebate':>10} {'Net':>10} {'WR':>6}")
    print(f"  {'-'*54}")
    monthly = {}
    for t in trades:
        k = t.close_time.strftime('%Y-%m')
        m = monthly.setdefault(k, {'n': 0, 'pnl': 0, 'rebate': 0, 'fee': 0, 'wins': 0})
        m['n'] += 1
        m['pnl'] += t.pnl
        m['rebate'] += t.rebate
        m['fee'] += t.fee
        if t.pnl > 0:
            m['wins'] += 1
    for month in sorted(monthly):
        m = monthly[month]
        # Same accounting as the balance: P&L - all fees + rebate
        net_m = m['pnl'] - m['fee'] + m['rebate']
        wr_m = m['wins'] / m['n'] * 100 if m['n'] > 0 else 0
        print(f"  {month:<10} {m['n']:>6} {m['pnl']:>+10.2f} {m['rebate']:>10.2f} {net_m:>+10.2f} {wr_m:>5.1f}%")

    print(f"{'='*w}")

    # Save trades as CSV (close reasons produced here contain no commas)
    csv_path = Path(__file__).parent.parent / f'{strategy.name}_trades.csv'
    with open(csv_path, 'w', encoding='utf-8-sig') as f:
        f.write("open_time,close_time,dir,open_px,close_px,size,pnl,pnl_pct,fee,rebate,hold_sec,reason\n")
        for t in trades:
            f.write(f"{t.open_time},{t.close_time},{t.direction},"
                    f"{t.open_price:.2f},{t.close_price:.2f},{t.size:.2f},"
                    f"{t.pnl:.4f},{t.pnl_pct*100:.4f}%,{t.fee:.4f},{t.rebate:.4f},"
                    f"{t.hold_seconds:.0f},{t.close_reason}\n")
    log.ok(f"Trades saved: {csv_path}")


def _summarize(strategy: BaseBacktest):
    """Collect the headline metrics shown in the comparison table."""
    trades = strategy.trades
    net = strategy.balance - strategy.initial_balance
    return {
        'net': net,
        'net_pct': net / strategy.initial_balance * 100,
        'max_dd': strategy.max_dd_pct * 100,
        'n': len(trades),
        'wr': len([t for t in trades if t.pnl > 0]) / len(trades) * 100 if trades else 0,
        'dir_pnl': sum(t.pnl for t in trades),
        'vol': sum(t.size for t in trades) * 2,
        'fee': sum(t.fee for t in trades),
        'rebate': sum(t.rebate for t in trades),
    }


# ========================= Entry point =========================
def main():
    """Load 2025 data, run both strategies, and print reports and a comparison."""
    data = load_1m_klines('2025-01-01', '2025-12-31')
    if not data:
        log.err("No data!")
        return

    common = dict(
        initial_balance=1000.0,
        leverage=50,
        risk_pct=0.005,
        taker_fee=0.0006,
        rebate_rate=0.90,
        min_hold_sec=200,
        max_hold_sec=1800,
    )

    # === Strategy A: grid trading ===
    grid = GridStrategy(
        grid_pct=0.0020,        # 0.20% grid spacing
        tp_grids=1,             # take-profit 1 grid (0.20%)
        sl_grids=3,             # stop-loss 3 grids (0.60%)
        trend_ema_period=120,   # 120-bar (2 h on 1m data) trend EMA
        **common,
    )
    grid.run(data)
    print_report(grid)

    # === Strategy B: EMA trend following ===
    ema = EMATrendStrategy(
        fast_period=8,
        slow_period=21,
        big_period=120,
        atr_period=14,
        atr_min_pct=0.0003,     # minimum ATR/close to allow entries
        **common,
    )
    ema.run(data)
    print_report(ema)

    # === Comparison summary ===
    g, e = _summarize(grid), _summarize(ema)
    print(f"\n{'='*65}")
    print(f"  COMPARISON SUMMARY")
    print(f"{'='*65}")
    print(f"  {'Metric':<25} {grid.name:>18} {ema.name:>18}")
    print(f"  {'-'*61}")

    rows = [
        ("Net P&L (USDT)", f"{g['net']:+.2f}", f"{e['net']:+.2f}"),
        ("Net P&L (%)", f"{g['net_pct']:+.2f}%", f"{e['net_pct']:+.2f}%"),
        ("Max Drawdown", f"{g['max_dd']:.2f}%", f"{e['max_dd']:.2f}%"),
        ("Total Trades", f"{g['n']}", f"{e['n']}"),
        ("Win Rate", f"{g['wr']:.1f}%", f"{e['wr']:.1f}%"),
        ("Direction P&L", f"{g['dir_pnl']:+.2f}", f"{e['dir_pnl']:+.2f}"),
        ("Total Volume", f"{g['vol']:,.0f}", f"{e['vol']:,.0f}"),
        ("Total Fees", f"{g['fee']:.2f}", f"{e['fee']:.2f}"),
        ("Rebate Income", f"{g['rebate']:+.2f}", f"{e['rebate']:+.2f}"),
    ]
    for label, v1, v2 in rows:
        print(f"  {label:<25} {v1:>18} {v2:>18}")
    print(f"{'='*65}")


if __name__ == '__main__':
    main()