Files
lm_code/交易/bitmart-新策略回测.py
Your Name b5af5b07f3 哈哈
2026-02-15 02:16:45 +08:00

674 lines
24 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
BitMart 返佣策略回测 — 双策略对比
策略A: 网格交易 (Grid Trading)
- 围绕EMA中轨设定网格价格触及网格线时开仓
- 固定止盈(1格)、固定止损(3格)
- 趋势过滤:只在趋势方向开仓
策略B: EMA趋势跟随 (EMA Trend Following)
- 快慢EMA金叉做多、死叉做空
- 始终持仓,信号反转时反手
- 大级别趋势过滤避免逆势
两个策略都:
- 严格执行 >3分钟最低持仓
- 计算90%返佣收益
- 输出详细对比报告
"""
import csv
import datetime
import math
import sqlite3
import statistics
import time
from collections import Counter, deque
from contextlib import closing
from dataclasses import dataclass
from pathlib import Path
from typing import List
# ========================= 简易 Logger =========================
class _L:
    """Minimal console logger: every level prints one line prefixed by a fixed tag."""

    @staticmethod
    def _emit(tag, m):
        # Shared formatter so every level renders as "[TAG] message".
        print(f"[{tag}] {m}")

    @staticmethod
    def info(m):
        """Print an informational message."""
        _L._emit("INFO", m)

    @staticmethod
    def ok(m):
        """Print a success message."""
        _L._emit(" OK ", m)

    @staticmethod
    def warn(m):
        """Print a warning message."""
        _L._emit("WARN", m)

    @staticmethod
    def err(m):
        """Print an error message."""
        _L._emit("ERR ", m)


log = _L()
# ========================= 交易记录 =========================
@dataclass
class Trade:
    """One completed round-trip trade, as recorded by BaseBacktest._close()."""

    open_time: datetime.datetime   # bar datetime at which the position was opened
    close_time: datetime.datetime  # bar datetime at which the position was closed
    direction: str                 # 'long' or 'short'
    open_price: float
    close_price: float
    size: float                    # position notional (balance * risk_pct * leverage at entry)
    pnl: float                     # directional P&L = size * pnl_pct, before fees and rebate
    pnl_pct: float                 # favourable price move as a fraction (0.01 == 1%)
    fee: float                     # taker fees for both legs (open + close)
    rebate: float                  # fee * rebate_rate, credited back to the balance
    hold_seconds: float            # (close_time - open_time) in seconds
    close_reason: str              # e.g. "TP (+0.213%)", "timeout (1800s)", "backtest_end"
# ========================= 数据加载 =========================
def load_1m_klines(start_date='2025-01-01', end_date='2025-12-31', db_path=None):
    """Load ETH 1-minute klines from the local SQLite database.

    Args:
        start_date: First day to include, formatted ``'YYYY-MM-DD'``.
        end_date: Last day to include, formatted ``'YYYY-MM-DD'``; the query
            runs up to (but excluding) midnight of the following day.
        db_path: Path to the SQLite file. Defaults to
            ``<parent of this script's directory>/models/database.db``.

    Returns:
        A list of dicts with keys ``datetime``, ``open``, ``high``, ``low``
        and ``close``, ordered by bar id (millisecond timestamp).

    Raises:
        FileNotFoundError: if the database file does not exist.
    """
    if db_path is None:
        db_path = Path(__file__).parent.parent / 'models' / 'database.db'
    db_path = Path(db_path)
    # sqlite3.connect() would silently create an empty file for a missing
    # path and then fail later with a confusing "no such table" error.
    if not db_path.is_file():
        raise FileNotFoundError(f"K-line database not found: {db_path}")

    # Naive datetimes: .timestamp() interprets them in the machine's local
    # timezone, and fromtimestamp() below converts back the same way.
    start_dt = datetime.datetime.strptime(start_date, '%Y-%m-%d')
    end_dt = datetime.datetime.strptime(end_date, '%Y-%m-%d') + datetime.timedelta(days=1)
    start_ms = int(start_dt.timestamp()) * 1000
    end_ms = int(end_dt.timestamp()) * 1000

    # closing() releases the connection even if the query raises; the
    # Connection's own context manager only commits/rolls back.
    with closing(sqlite3.connect(str(db_path))) as conn:
        rows = conn.execute(
            "SELECT id, open, high, low, close FROM bitmart_eth_1m "
            "WHERE id >= ? AND id < ? ORDER BY id",
            (start_ms, end_ms),
        ).fetchall()

    data = [
        {
            'datetime': datetime.datetime.fromtimestamp(ts / 1000.0),
            'open': op, 'high': hi, 'low': lo, 'close': cl,
        }
        for ts, op, hi, lo, cl in rows
    ]
    log.info(f"Loaded {len(data)} bars ({start_date} ~ {end_date})")
    return data
# ========================= EMA 工具 =========================
class EMA:
    """Incrementally updated exponential moving average.

    The first sample seeds the average; every later sample is blended in
    with weight ``k = 2 / (period + 1)``.
    """

    def __init__(self, period):
        self.period = period
        self.k = 2.0 / (period + 1)
        self.value = None  # stays None until the first update() seeds it

    def update(self, price):
        """Fold *price* into the average and return the new value."""
        previous = self.value
        if previous is None:
            updated = price
        else:
            updated = price * self.k + previous * (1 - self.k)
        self.value = updated
        return updated
# ========================= 基础回测引擎 =========================
class BaseBacktest:
    """Shared single-position account bookkeeping for the strategies below.

    Tracks the balance, one open position (long, short or flat), completed
    trades, a sampled equity curve and the maximum drawdown. Taker fees are
    charged on both legs, and ``rebate_rate`` of them is credited back when a
    trade is closed.
    """

    def __init__(self, name, initial_balance=1000.0, leverage=50,
                 risk_pct=0.005, taker_fee=0.0006, rebate_rate=0.90,
                 min_hold_sec=200, max_hold_sec=1800):
        self.name = name
        self.initial_balance = initial_balance
        self.balance = initial_balance
        self.leverage = leverage
        self.risk_pct = risk_pct
        self.taker_fee = taker_fee
        self.rebate_rate = rebate_rate
        self.min_hold_sec = min_hold_sec
        self.max_hold_sec = max_hold_sec
        self.position = 0          # 1 = long, -1 = short, 0 = flat
        self.open_price = 0.0
        self.open_time = None
        self.pos_size = 0.0        # notional of the open position
        self.trades: List[Trade] = []
        self.equity_curve = []     # sampled {'datetime', 'equity'} points
        self.peak_equity = initial_balance
        self.max_dd_pct = 0.0      # fraction, e.g. 0.05 == 5%

    def _open(self, direction, price, dt):
        """Open a position with notional ``balance * risk_pct * leverage``.

        The opening taker fee is deducted immediately. Returns False and
        leaves the account untouched when the notional would be below 1.
        """
        size = self.balance * self.risk_pct * self.leverage
        if size < 1:
            return False
        self.pos_size = size
        self.balance -= size * self.taker_fee
        self.position = 1 if direction == 'long' else -1
        self.open_price = price
        self.open_time = dt
        return True

    def _close_notional(self, price):
        """Return the notional value of the open position marked at *price*."""
        # Quantity is pos_size / open_price, so the exit notional is
        # quantity * price for longs AND shorts (a short that gained 1% closes
        # 0.99x its entry notional, not 1.01x).
        return self.pos_size * price / self.open_price

    def _close(self, price, dt, reason):
        """Close the open position at *price*, book fees/rebate and record a Trade.

        Returns the Trade, or None when there is no open position.
        """
        if self.position == 0:
            return None
        pnl_pct = self.cur_pnl_pct(price)
        pnl = self.pos_size * pnl_pct
        open_fee = self.pos_size * self.taker_fee
        close_fee = self._close_notional(price) * self.taker_fee
        total_fee = open_fee + close_fee
        rebate = total_fee * self.rebate_rate
        # The open fee was already deducted in _open(); only the close fee is
        # charged here, while the rebate covers both legs.
        self.balance += pnl - close_fee + rebate
        trade = Trade(
            open_time=self.open_time, close_time=dt,
            direction='long' if self.position == 1 else 'short',
            open_price=self.open_price, close_price=price,
            size=self.pos_size, pnl=pnl, pnl_pct=pnl_pct,
            fee=total_fee, rebate=rebate,
            hold_seconds=self.hold_seconds(dt), close_reason=reason,
        )
        self.trades.append(trade)
        self.position = 0
        self.open_price = 0.0
        self.open_time = None
        self.pos_size = 0.0
        return trade

    def hold_seconds(self, dt):
        """Seconds the current position has been held at *dt* (0 when flat)."""
        if self.open_time is None:
            return 0
        return (dt - self.open_time).total_seconds()

    def can_close(self, dt):
        """True once the minimum holding time has elapsed."""
        return self.hold_seconds(dt) >= self.min_hold_sec

    def cur_pnl_pct(self, price):
        """Unrealised favourable move of the open position as a fraction (0 when flat)."""
        if self.position == 1:
            return (price - self.open_price) / self.open_price
        elif self.position == -1:
            return (self.open_price - price) / self.open_price
        return 0

    def track_equity(self, dt, price, every_n=60, bar_idx=0):
        """Mark the account to market at *price*.

        The peak equity and maximum drawdown are updated on every call so that
        intra-sample drawdowns are not missed; a point is appended to
        ``equity_curve`` only every *every_n* bars to keep it small.
        """
        eq = self.balance
        if self.position != 0 and self.open_price > 0:
            eq += self.pos_size * self.cur_pnl_pct(price)
        if bar_idx % every_n == 0:
            self.equity_curve.append({'datetime': dt, 'equity': eq})
        if eq > self.peak_equity:
            self.peak_equity = eq
        dd = (self.peak_equity - eq) / self.peak_equity if self.peak_equity > 0 else 0
        if dd > self.max_dd_pct:
            self.max_dd_pct = dd
# ========================= 策略A: 网格交易 =========================
class GridStrategy(BaseBacktest):
    """Grid trading with an EMA trend filter.

    - EMA(trend_ema_period) decides the trend direction.
    - Grid lines are fixed geometric levels ``(1 + grid_pct) ** k``, so
      adjacent lines are ``grid_pct`` (e.g. 0.20%) apart and do not move with
      the price being classified.
    - With-trend entries only: in an uptrend (close above the EMA) go long when
      the bar's low touches the grid line below the previous close; in a
      downtrend go short when the bar's high touches the line above it.
    - TP: ``tp_grids`` grids and SL: ``sl_grids`` grids, both only after the
      minimum holding time; a hard SL of ``sl_grids + 1`` grids applies at any
      time.
    - After a stop-loss (hard or regular) the strategy does nothing for
      ``cooldown_sec`` seconds.
    """

    def __init__(self, grid_pct=0.0020, tp_grids=1, sl_grids=3,
                 trend_ema_period=120, cooldown_sec=120, **kwargs):
        super().__init__(name="Grid+Trend", **kwargs)
        self.grid_pct = grid_pct
        self.tp_pct = grid_pct * tp_grids
        self.sl_pct = grid_pct * sl_grids
        self.hard_sl_pct = grid_pct * (sl_grids + 1)
        self.cooldown_sec = cooldown_sec
        self.trend_ema = EMA(trend_ema_period)
        self.last_grid_cross = None  # not used by run(); kept for compatibility
        self.cooldown_until = None   # bars before this datetime are skipped

    def get_grid_level(self, price, direction='below'):
        """Return the grid line at/below *price* ('below') or strictly above it.

        Grid lines are ``(1 + grid_pct) ** k`` for integer ``k``. They must not
        be derived from *price* itself: with ``grid_size = price * grid_pct``
        the remainder ``price % grid_size`` is only floating-point noise.
        A non-positive *price* or ``grid_pct`` returns *price* unchanged.
        """
        if price <= 0 or self.grid_pct <= 0:
            return price
        ratio = 1.0 + self.grid_pct
        k = math.floor(math.log(price) / math.log(ratio))
        # Nudge k so that ratio**k <= price < ratio**(k + 1) despite rounding.
        if ratio ** k > price:
            k -= 1
        elif ratio ** (k + 1) <= price:
            k += 1
        return ratio ** k if direction == 'below' else ratio ** (k + 1)

    def _check_exit(self, price, dt):
        """Close the open position if an exit rule fires at this bar's close."""
        p = self.cur_pnl_pct(price)
        hs = self.hold_seconds(dt)
        # Hard stop-loss ignores the minimum holding time.
        if -p >= self.hard_sl_pct:
            self._close(price, dt, f"hard_SL ({p*100:+.3f}%)")
            self.cooldown_until = dt + datetime.timedelta(seconds=self.cooldown_sec)
            return
        if not self.can_close(dt):
            return
        if p >= self.tp_pct:
            self._close(price, dt, f"TP ({p*100:+.3f}%)")
        elif -p >= self.sl_pct:
            self._close(price, dt, f"SL ({p*100:+.3f}%)")
            self.cooldown_until = dt + datetime.timedelta(seconds=self.cooldown_sec)
        elif hs >= self.max_hold_sec:
            self._close(price, dt, f"timeout ({hs:.0f}s)")

    def _check_entry(self, price, high, low, prev_close, ema_val, dt):
        """Open a with-trend position when this bar touches the adjacent grid line."""
        grid_below = self.get_grid_level(prev_close, 'below')
        grid_above = self.get_grid_level(prev_close, 'above')
        # Uptrend: price dips back to the grid line below -> long.
        if price > ema_val and low <= grid_below < prev_close:
            self._open('long', price, dt)
        # Downtrend: price bounces up to the grid line above -> short.
        elif price < ema_val and high >= grid_above > prev_close:
            self._open('short', price, dt)

    def run(self, data):
        """Run the backtest over *data* (bars with datetime/high/low/close) and return the trades."""
        log.info(f"[{self.name}] Starting... grid={self.grid_pct*100:.2f}% TP={self.tp_pct*100:.2f}% SL={self.sl_pct*100:.2f}%")
        t0 = time.time()
        # Log progress roughly every 10%; max() avoids modulo-by-zero when
        # fewer than 10 bars are supplied.
        step = max(1, len(data) // 10)
        prev_close = None
        for i, bar in enumerate(data):
            price = bar['close']
            dt = bar['datetime']
            ema_val = self.trend_ema.update(price)

            in_cooldown = self.cooldown_until is not None and dt < self.cooldown_until
            if not in_cooldown:
                self.cooldown_until = None
                # A bar either manages an existing position or looks for an
                # entry, never both (no re-entry on the bar that closed).
                if self.position != 0:
                    self._check_exit(price, dt)
                elif prev_close is not None:
                    self._check_entry(price, bar['high'], bar['low'],
                                      prev_close, ema_val, dt)

            self.track_equity(dt, price, bar_idx=i)
            prev_close = price
            if i > 0 and i % step == 0:
                log.info(f" [{self.name}] {i/len(data)*100:.0f}% | bal={self.balance:.2f} | trades={len(self.trades)}")

        # Force-close whatever is still open at the end of the data.
        if self.position != 0:
            self._close(data[-1]['close'], data[-1]['datetime'], "backtest_end")
        log.ok(f"[{self.name}] Done in {time.time()-t0:.1f}s | {len(self.trades)} trades")
        return self.trades
# ========================= 策略B: EMA趋势跟随 =========================
class EMATrendStrategy(BaseBacktest):
    """EMA trend following.

    - Crossovers of fast EMA(fast_period) and slow EMA(slow_period) are signals.
    - Higher-timeframe filter EMA(big_period): longs only above it, shorts only
      below it.
    - Golden cross above the big EMA -> long; death cross below it -> short.
    - An opposite cross reverses the position once the minimum holding time
      has elapsed; a cross arriving earlier is remembered as a pending signal
      and acted on as soon as the holding time is satisfied.
    - ATR(atr_period) volatility filter: no new positions while
      ATR / price < atr_min_pct.
    - Exits: stop_loss_pct after the minimum hold, hard_sl_pct at any time,
      and a timeout at max_hold_sec.
    """

    def __init__(self, fast_period=8, slow_period=21, big_period=120,
                 atr_period=14, atr_min_pct=0.0003,
                 stop_loss_pct=0.004, hard_sl_pct=0.006, **kwargs):
        super().__init__(name="EMA-Trend", **kwargs)
        self.ema_fast = EMA(fast_period)
        self.ema_slow = EMA(slow_period)
        self.ema_big = EMA(big_period)
        self.atr_period = atr_period
        self.atr_min_pct = atr_min_pct      # minimum ATR/price to open positions
        self.stop_loss_pct = stop_loss_pct  # stop-loss after the minimum hold
        self.hard_sl_pct = hard_sl_pct      # stop-loss applied at any time
        # The ATR only needs the last atr_period + 1 bars, so the histories are
        # bounded instead of growing with the whole dataset.
        window = max(1, atr_period + 1)
        self.highs = deque(maxlen=window)
        self.lows = deque(maxlen=window)
        self.closes = deque(maxlen=window)
        self.prev_fast = None
        self.prev_slow = None
        self.pending_signal = None  # 'close_long' / 'close_short' awaiting min hold

    def _close(self, price, dt, reason):
        """Close via the base class and drop any pending signal of the old position.

        Without this, a signal recorded for a position that was later closed by
        a stop or timeout would close the *next* position as "delayed_cross".
        """
        self.pending_signal = None
        return super()._close(price, dt, reason)

    def calc_atr(self):
        """Return the simple average true range of the last atr_period bars, or None."""
        if self.atr_period <= 0 or len(self.highs) < self.atr_period + 1:
            return None
        trs = []
        for i in range(-self.atr_period, 0):
            h = self.highs[i]
            lo = self.lows[i]
            pc = self.closes[i - 1]
            trs.append(max(h - lo, abs(h - pc), abs(lo - pc)))
        return sum(trs) / len(trs)

    def _manage_position(self, price, dt, fast, slow, big, cross_up, cross_down, vol_ok):
        """Apply exit / reversal rules to the open position.

        Returns True when the bar is fully handled (the caller must skip the
        entry check), False otherwise.
        """
        p = self.cur_pnl_pct(price)
        # Hard stop-loss ignores the minimum holding time.
        if -p >= self.hard_sl_pct:
            self._close(price, dt, f"hard_SL ({p*100:+.3f}%)")
            return True

        if self.can_close(dt):
            if -p >= self.stop_loss_pct:
                self._close(price, dt, f"SL ({p*100:+.3f}%)")
                return True
            hs = self.hold_seconds(dt)
            if hs >= self.max_hold_sec:
                self._close(price, dt, f"timeout ({hs:.0f}s)")
                return True
            # Long meets a death cross -> close, and reverse if the filters allow.
            if self.position == 1 and cross_down:
                self._close(price, dt, "cross_reverse")
                if price < big and vol_ok:
                    self._open('short', price, dt)
                return True
            # Short meets a golden cross -> close, and reverse if the filters allow.
            if self.position == -1 and cross_up:
                self._close(price, dt, "cross_reverse")
                if price > big and vol_ok:
                    self._open('long', price, dt)
                return True
        else:
            # Minimum hold not reached yet: remember the reverse signal.
            if self.position == 1 and cross_down:
                self.pending_signal = 'close_long'
            elif self.position == -1 and cross_up:
                self.pending_signal = 'close_short'

        # Act on a remembered signal as soon as the holding time is satisfied.
        if self.pending_signal and self.can_close(dt):
            if self.pending_signal == 'close_long' and self.position == 1:
                self._close(price, dt, "delayed_cross")
                if fast < slow and price < big and vol_ok:
                    self._open('short', price, dt)
            elif self.pending_signal == 'close_short' and self.position == -1:
                self._close(price, dt, "delayed_cross")
                if fast > slow and price > big and vol_ok:
                    self._open('long', price, dt)
            self.pending_signal = None
        return False

    def run(self, data):
        """Run the backtest over *data* (bars with datetime/high/low/close) and return the trades."""
        log.info(f"[{self.name}] Starting... fast=EMA{self.ema_fast.period} "
                 f"slow=EMA{self.ema_slow.period} big=EMA{self.ema_big.period}")
        t0 = time.time()
        # Log progress roughly every 10%; max() avoids modulo-by-zero when
        # fewer than 10 bars are supplied.
        step = max(1, len(data) // 10)
        for i, bar in enumerate(data):
            price = bar['close']
            dt = bar['datetime']
            self.highs.append(bar['high'])
            self.lows.append(bar['low'])
            self.closes.append(price)
            fast = self.ema_fast.update(price)
            slow = self.ema_slow.update(price)
            big = self.ema_big.update(price)

            # ATR volatility filter.
            atr = self.calc_atr()
            atr_pct = atr / price if atr is not None and price > 0 else 0
            vol_ok = atr_pct >= self.atr_min_pct

            # Crossover detection against the previous bar's EMAs.
            has_prev = self.prev_fast is not None
            cross_up = has_prev and self.prev_fast <= self.prev_slow and fast > slow
            cross_down = has_prev and self.prev_fast >= self.prev_slow and fast < slow
            self.prev_fast = fast
            self.prev_slow = slow

            handled = self.position != 0 and self._manage_position(
                price, dt, fast, slow, big, cross_up, cross_down, vol_ok)

            # Flat: open on a filtered crossover.
            if not handled and self.position == 0 and vol_ok:
                if cross_up and price > big:
                    self._open('long', price, dt)
                elif cross_down and price < big:
                    self._open('short', price, dt)

            self.track_equity(dt, price, bar_idx=i)
            if i > 0 and i % step == 0:
                log.info(f" [{self.name}] {i/len(data)*100:.0f}% | bal={self.balance:.2f} | trades={len(self.trades)}")

        # Force-close whatever is still open at the end of the data.
        if self.position != 0:
            self._close(data[-1]['close'], data[-1]['datetime'], "backtest_end")
        log.ok(f"[{self.name}] Done in {time.time()-t0:.1f}s | {len(self.trades)} trades")
        return self.trades
# ========================= 报告生成 =========================
def _write_trades_csv(strategy):
    """Write *strategy*'s trades to ``<parent of script dir>/<name>_trades.csv`` and return the path."""
    csv_path = Path(__file__).parent.parent / f'{strategy.name}_trades.csv'
    # csv.writer quotes fields when needed (e.g. a close reason with a comma),
    # which the previous hand-joined lines did not.
    with open(csv_path, 'w', encoding='utf-8-sig', newline='') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerow(["open_time", "close_time", "dir", "open_px", "close_px", "size",
                         "pnl", "pnl_pct", "fee", "rebate", "hold_sec", "reason"])
        writer.writerows(
            [t.open_time, t.close_time, t.direction,
             f"{t.open_price:.2f}", f"{t.close_price:.2f}", f"{t.size:.2f}",
             f"{t.pnl:.4f}", f"{t.pnl_pct*100:.4f}%", f"{t.fee:.4f}", f"{t.rebate:.4f}",
             f"{t.hold_seconds:.0f}", t.close_reason]
            for t in strategy.trades
        )
    return csv_path


def print_report(strategy: BaseBacktest):
    """Print a detailed backtest report for *strategy* and save its trades as CSV.

    When the strategy has no trades only a one-line notice is printed and no
    CSV is written.
    """
    trades = strategy.trades
    if not trades:
        print(f"\n[{strategy.name}] No trades.")
        return
    n = len(trades)
    wins = [t for t in trades if t.pnl > 0]
    losses = [t for t in trades if t.pnl <= 0]
    wr = len(wins) / n * 100
    total_pnl = sum(t.pnl for t in trades)
    total_fee = sum(t.fee for t in trades)
    total_rebate = sum(t.rebate for t in trades)
    net = strategy.balance - strategy.initial_balance
    total_vol = sum(t.size for t in trades) * 2  # each trade has an open and a close leg
    avg_pnl = total_pnl / n
    avg_win = statistics.mean([t.pnl for t in wins]) if wins else 0
    avg_loss = statistics.mean([t.pnl for t in losses]) if losses else 0
    avg_hold = statistics.mean([t.hold_seconds for t in trades])
    pf_num = sum(t.pnl for t in wins)
    pf_den = abs(sum(t.pnl for t in losses))
    pf = pf_num / pf_den if pf_den > 0 else float('inf')

    # Longest run of consecutive non-winning trades.
    max_streak = 0
    cur = 0
    for t in trades:
        if t.pnl <= 0:
            cur += 1
            max_streak = max(max_streak, cur)
        else:
            cur = 0

    long_t = [t for t in trades if t.direction == 'long']
    short_t = [t for t in trades if t.direction == 'short']
    long_wr = len([t for t in long_t if t.pnl > 0]) / len(long_t) * 100 if long_t else 0
    short_wr = len([t for t in short_t if t.pnl > 0]) / len(short_t) * 100 if short_t else 0

    # Close reasons without the "(value)" suffix, e.g. "TP (+0.2%)" -> "TP".
    reasons = Counter(t.close_reason.split(' (')[0] for t in trades)
    under_3m = len([t for t in trades if t.hold_seconds < 180])

    w = 65
    print(f"\n{'='*w}")
    print(f" [{strategy.name}] Backtest Report")
    print(f"{'='*w}")
    print(f"\n--- Account ---")
    print(f" Initial: {strategy.initial_balance:>12.2f} USDT")
    print(f" Final: {strategy.balance:>12.2f} USDT")
    print(f" Net P&L: {net:>+12.2f} USDT ({net/strategy.initial_balance*100:+.2f}%)")
    print(f" Max Drawdown: {strategy.max_dd_pct*100:>11.2f}%")
    print(f"\n--- Trades ---")
    print(f" Total: {n:>8}")
    print(f" Wins: {len(wins):>8} ({wr:.1f}%)")
    print(f" Losses: {len(losses):>8} ({100-wr:.1f}%)")
    print(f" Long: {len(long_t):>8} (WR {long_wr:.1f}%)")
    print(f" Short: {len(short_t):>8} (WR {short_wr:.1f}%)")
    print(f" Profit Factor: {pf:>8.2f}")
    print(f" Max Loss Streak:{max_streak:>8}")
    print(f"\n--- P&L ---")
    print(f" Direction P&L: {total_pnl:>+12.4f} USDT")
    print(f" Avg per trade: {avg_pnl:>+12.4f} USDT")
    print(f" Avg win: {avg_win:>+12.4f} USDT")
    print(f" Avg loss: {avg_loss:>+12.4f} USDT")
    print(f" Best trade: {max(t.pnl for t in trades):>+12.4f} USDT")
    print(f" Worst trade: {min(t.pnl for t in trades):>+12.4f} USDT")
    print(f"\n--- Fees & Rebate ---")
    print(f" Volume: {total_vol:>12.2f} USDT")
    print(f" Total Fees: {total_fee:>12.4f} USDT")
    # Label follows the configured rebate rate instead of a hard-coded 90%.
    print(f" Rebate ({strategy.rebate_rate*100:.0f}%): {total_rebate:>+12.4f} USDT")
    print(f" Net Fee Cost: {total_fee - total_rebate:>12.4f} USDT")
    print(f"\n--- Hold Time ---")
    print(f" Average: {avg_hold:>8.0f}s ({avg_hold/60:.1f}min)")
    print(f" Shortest: {min(t.hold_seconds for t in trades):>8.0f}s")
    print(f" Longest: {max(t.hold_seconds for t in trades):>8.0f}s")
    print(f" Under 3min: {under_3m:>8} (hard SL only)")
    print(f"\n--- Close Reasons ---")
    for r, c in reasons.most_common():
        print(f" {r:<22} {c:>6} ({c/n*100:.1f}%)")

    # Monthly statistics, bucketed by close time.
    print(f"\n--- Monthly ---")
    print(f" {'Month':<10} {'Trades':>6} {'Dir PnL':>10} {'Rebate':>10} {'Net':>10} {'WR':>6}")
    print(f" {'-'*54}")
    monthly = {}
    for t in trades:
        m = monthly.setdefault(t.close_time.strftime('%Y-%m'),
                               {'n': 0, 'pnl': 0, 'rebate': 0, 'fee': 0, 'wins': 0})
        m['n'] += 1
        m['pnl'] += t.pnl
        m['rebate'] += t.rebate
        m['fee'] += t.fee
        if t.pnl > 0:
            m['wins'] += 1
    for month in sorted(monthly):
        m = monthly[month]
        # Same accounting as the balance: direction P&L minus fees plus rebate.
        net_m = m['pnl'] - m['fee'] + m['rebate']
        wr_m = m['wins'] / m['n'] * 100 if m['n'] > 0 else 0
        print(f" {month:<10} {m['n']:>6} {m['pnl']:>+10.2f} {m['rebate']:>10.2f} {net_m:>+10.2f} {wr_m:>5.1f}%")
    print(f"{'='*w}")

    csv_path = _write_trades_csv(strategy)
    log.ok(f"Trades saved: {csv_path}")
# ========================= 主函数 =========================
def _summarize(strategy):
    """Collect the headline metrics shown in the comparison table for *strategy*."""
    trades = strategy.trades
    n = len(trades)
    wins = sum(1 for t in trades if t.pnl > 0)
    return {
        'net': strategy.balance - strategy.initial_balance,
        'trades_n': n,
        'wr': wins / n * 100 if trades else 0,
        'dir_pnl': sum(t.pnl for t in trades),
        'rebate': sum(t.rebate for t in trades),
        'fee': sum(t.fee for t in trades),
        'vol': sum(t.size for t in trades) * 2,
    }


def main():
    """Load the 2025 1-minute data, run both strategies, then print their reports and a comparison."""
    data = load_1m_klines('2025-01-01', '2025-12-31')
    if not data:
        log.err("No data!")
        return
    common = dict(
        initial_balance=1000.0,
        leverage=50,
        risk_pct=0.005,
        taker_fee=0.0006,
        rebate_rate=0.90,
        min_hold_sec=200,
        max_hold_sec=1800,
    )

    # Strategy A: grid trading.
    grid = GridStrategy(
        grid_pct=0.0020,        # 0.20% grid spacing
        tp_grids=1,             # take profit after 1 grid (0.20%)
        sl_grids=3,             # stop loss after 3 grids (0.60%)
        trend_ema_period=120,   # 120-bar (2 h on 1m bars) EMA trend filter
        **common,
    )
    grid.run(data)
    print_report(grid)

    # Strategy B: EMA trend following.
    ema = EMATrendStrategy(
        fast_period=8,
        slow_period=21,
        big_period=120,
        atr_period=14,
        atr_min_pct=0.0003,     # minimum volatility filter
        **common,
    )
    ema.run(data)
    print_report(ema)

    # Comparison summary; metrics are computed locally instead of being
    # attached to the strategy objects as ad-hoc private attributes.
    g = _summarize(grid)
    e = _summarize(ema)
    print(f"\n{'='*65}")
    print(f" COMPARISON SUMMARY")
    print(f"{'='*65}")
    print(f" {'Metric':<25} {'Grid+Trend':>18} {'EMA-Trend':>18}")
    print(f" {'-'*61}")
    rows = [
        ("Net P&L (USDT)", f"{g['net']:+.2f}", f"{e['net']:+.2f}"),
        ("Net P&L (%)", f"{g['net']/grid.initial_balance*100:+.2f}%", f"{e['net']/ema.initial_balance*100:+.2f}%"),
        ("Max Drawdown", f"{grid.max_dd_pct*100:.2f}%", f"{ema.max_dd_pct*100:.2f}%"),
        ("Total Trades", f"{g['trades_n']}", f"{e['trades_n']}"),
        ("Win Rate", f"{g['wr']:.1f}%", f"{e['wr']:.1f}%"),
        ("Direction P&L", f"{g['dir_pnl']:+.2f}", f"{e['dir_pnl']:+.2f}"),
        ("Total Volume", f"{g['vol']:,.0f}", f"{e['vol']:,.0f}"),
        ("Total Fees", f"{g['fee']:.2f}", f"{e['fee']:.2f}"),
        ("Rebate Income", f"{g['rebate']:+.2f}", f"{e['rebate']:+.2f}"),
    ]
    for label, v1, v2 in rows:
        print(f" {label:<25} {v1:>18} {v2:>18}")
    print(f"{'='*65}")


if __name__ == '__main__':
    main()