Files
lm_code/交易/bitmart-参数优化.py
Your Name b5af5b07f3 哈哈
2026-02-15 02:16:45 +08:00

377 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

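"""
Parameter-optimization scan for the EMA trend strategy.

Based on the previous backtest round: the EMA-Trend directional P&L was
+327 USDT, only 101 USDT short of being net profitable.
Main optimization goal: reduce the number of trades (lowering fees) while
keeping the directional P&L positive.

Parameters scanned (the authoritative values are in ``param_grid`` in ``main()``):
- fast_ema: [8, 13, 15, 20]
- slow_ema: [21, 34, 40, 55]
- big_ema: [120, 200, 300]
- atr_min_pct: [0.0003, 0.0005, 0.0008, 0.0012]
- stop_loss: [0.003, 0.004, 0.005, 0.008]
- max_hold: [900, 1200, 1800, 3600]
"""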
import time
import datetime
import sqlite3
import itertools
from pathlib import Path
from dataclasses import dataclass
from typing import List
# ========================= EMA =========================
class EMA:
    """Incrementally updated exponential moving average.

    The smoothing factor is ``2 / (period + 1)``. The first price passed to
    ``update`` seeds the average; every later price is blended into it.
    """

    def __init__(self, period):
        self.value = None
        self.k = 2.0 / (period + 1)

    def update(self, price):
        """Fold ``price`` into the average and return the updated value."""
        previous = self.value
        if previous is None:
            # No history yet: the first observation becomes the average.
            smoothed = price
        else:
            smoothed = price * self.k + previous * (1 - self.k)
        self.value = smoothed
        return smoothed
# ========================= 数据加载 =========================
def load_data(start_date='2025-01-01', end_date='2025-12-31'):
    """Load 1-minute ETH bars from the local SQLite database.

    Rows are read from table ``bitmart_eth_1m`` in
    ``<parent of this file's directory>/models/database.db``. The ``id``
    column is treated as a millisecond Unix timestamp.

    Args:
        start_date: First day to include, formatted ``YYYY-MM-DD``.
        end_date: Last day to include (inclusive), formatted ``YYYY-MM-DD``.

    Returns:
        A list of dicts with keys ``'datetime'``, ``'open'``, ``'high'``,
        ``'low'`` and ``'close'``, ordered by ascending ``id``.

    Raises:
        ValueError: If either date string does not match ``YYYY-MM-DD``.
        FileNotFoundError: If the database file does not exist.
        sqlite3.Error: If the query fails.

    Note:
        The date strings are parsed as naive datetimes, so the range
        boundaries and the returned ``'datetime'`` values are in the local
        timezone of the machine running the script.
    """
    db_path = Path(__file__).parent.parent / 'models' / 'database.db'

    day_format = '%Y-%m-%d'
    range_start = datetime.datetime.strptime(start_date, day_format)
    # The end date is inclusive, so the exclusive upper bound is the next midnight.
    range_end = datetime.datetime.strptime(end_date, day_format) + datetime.timedelta(days=1)
    start_ms = int(range_start.timestamp()) * 1000
    end_ms = int(range_end.timestamp()) * 1000

    # sqlite3.connect() would silently create an empty database for a missing
    # path and the query would then fail with an unhelpful "no such table".
    if not db_path.is_file():
        raise FileNotFoundError(f"SQLite database not found: {db_path}")

    conn = sqlite3.connect(str(db_path))
    try:
        rows = conn.cursor().execute(
            "SELECT id, open, high, low, close FROM bitmart_eth_1m WHERE id >= ? AND id < ? ORDER BY id",
            (start_ms, end_ms)
        ).fetchall()
    finally:
        # Always release the connection, even when the query raises.
        conn.close()

    return [
        {
            'datetime': datetime.datetime.fromtimestamp(ts_ms / 1000.0),
            'open': open_, 'high': high, 'low': low, 'close': close,
        }
        for ts_ms, open_, high, low, close in rows
    ]
# ========================= 快速回测引擎 =========================
def run_ema_backtest(data, fast_p, slow_p, big_p, atr_period, atr_min,
                     stop_loss_pct, hard_sl_pct, max_hold_sec,
                     initial_balance=1000.0, leverage=50, risk_pct=0.005,
                     taker_fee=0.0006, rebate_rate=0.90, min_hold_sec=200):
    """
    Run a bar-by-bar backtest of the EMA trend strategy and return summary metrics.

    All decisions and fills use the bar's close price. A position is opened
    when the fast EMA crosses the slow EMA, the close is on the matching side
    of the big EMA, and ATR/close is at least ``atr_min``.

    Args:
        data: Sequence of bar dicts; the keys read are ``'datetime'``,
            ``'high'``, ``'low'`` and ``'close'``.
        fast_p: Period of the fast EMA.
        slow_p: Period of the slow EMA.
        big_p: Period of the big EMA used as the trend filter.
        atr_period: Number of true-range values averaged into the ATR.
        atr_min: Minimum ATR/close ratio required to open a position.
        stop_loss_pct: Adverse move (fraction of the entry price) that closes
            the position once ``min_hold_sec`` has elapsed.
        hard_sl_pct: Adverse move that closes the position immediately,
            regardless of holding time.
        max_hold_sec: Holding time in seconds after which the position is closed.
        initial_balance: Starting account balance.
        leverage: Multiplier applied to the risked balance to get the notional.
        risk_pct: Fraction of the balance committed per trade; the notional is
            ``balance * risk_pct * leverage``.
        taker_fee: Fee rate charged on the notional at open and at close.
        rebate_rate: Fraction of the round-trip fee credited back at close.
        min_hold_sec: Minimum holding time before any exit other than the hard
            stop is evaluated.

    Returns:
        A dict with keys ``'balance'``, ``'net'``, ``'net_pct'``, ``'trades'``,
        ``'win_rate'``, ``'dir_pnl'``, ``'total_fee'``, ``'rebate'``,
        ``'net_fee'``, ``'volume'`` and ``'avg_dir_pnl'``.
    """
    balance = initial_balance
    position = 0  # -1 = short, 0 = flat, 1 = long
    open_price = 0.0
    open_time = None
    pos_size = 0.0
    # Deferred close request recorded while the minimum hold time has not elapsed.
    pending = None
    ema_f = EMA(fast_p)
    ema_s = EMA(slow_p)
    ema_b = EMA(big_p)
    # Full history of bar highs/lows/closes; calc_atr reads only the tail.
    highs = []
    lows = []
    closes = []
    prev_fast = None
    prev_slow = None
    trade_count = 0
    win_count = 0
    total_dir_pnl = 0.0
    total_fee = 0.0
    total_rebate = 0.0

    def calc_atr():
        # Average of the last `atr_period` true ranges; None until one more bar
        # than `atr_period` is available (each true range needs the previous close).
        if len(highs) < atr_period + 1:
            return None
        trs = []
        for i in range(-atr_period, 0):
            tr = max(highs[i] - lows[i],
                     abs(highs[i] - closes[i-1]),
                     abs(lows[i] - closes[i-1]))
            trs.append(tr)
        return sum(trs) / len(trs)

    def do_open(direction, price, dt_):
        # Open a position of notional balance * risk_pct * leverage at `price`.
        # `total_fee` is declared nonlocal but not modified here; the opening
        # fee is added to it in do_close.
        nonlocal balance, position, open_price, open_time, pos_size, total_fee
        ps = balance * risk_pct * leverage
        if ps < 1:
            # Notional too small: skip the trade.
            return
        fee_ = ps * taker_fee
        # The opening fee is deducted from the balance immediately.
        balance -= fee_
        position = 1 if direction == 'long' else -1
        open_price = price
        open_time = dt_
        pos_size = ps

    def do_close(price, dt_):
        # Close the current position at `price`, settle P&L, fees and rebate,
        # update the trade statistics and reset the position state.
        nonlocal balance, position, open_price, open_time, pos_size
        nonlocal trade_count, win_count, total_dir_pnl, total_fee, total_rebate, pending
        if position == 0:
            return
        # Fractional return of the position, sign-adjusted for direction.
        if position == 1:
            pp = (price - open_price) / open_price
        else:
            pp = (open_price - price) / open_price
        pnl_ = pos_size * pp
        # Notional at close, used as the base for the closing fee.
        cv = pos_size * (1 + pp)
        cf = cv * taker_fee
        # Opening fee recomputed for reporting and rebate; it was already
        # deducted from the balance in do_open.
        of = pos_size * taker_fee
        tf = of + cf
        # The rebate is paid on both the opening and the closing fee.
        rb = tf * rebate_rate
        balance += pnl_ - cf + rb
        total_dir_pnl += pnl_
        total_fee += tf
        total_rebate += rb
        trade_count += 1
        # A win is counted on directional P&L, before fees.
        if pnl_ > 0:
            win_count += 1
        position = 0
        open_price = 0.0
        open_time = None
        pos_size = 0.0
        pending = None

    for bar in data:
        price = bar['close']
        dt = bar['datetime']
        highs.append(bar['high'])
        lows.append(bar['low'])
        closes.append(price)
        fast = ema_f.update(price)
        slow = ema_s.update(price)
        big = ema_b.update(price)
        atr = calc_atr()
        # ATR as a fraction of price; 0 while the ATR is not yet available.
        atr_pct = atr / price if atr and price > 0 else 0
        # A cross compares the previous bar's fast/slow relation with the current one.
        cross_up = (prev_fast is not None and prev_fast <= prev_slow and fast > slow)
        cross_down = (prev_fast is not None and prev_fast >= prev_slow and fast < slow)
        prev_fast = fast
        prev_slow = slow

        # === In a position ===
        if position != 0 and open_time:
            # Unrealized fractional return, sign-adjusted for direction.
            if position == 1:
                p = (price - open_price) / open_price
            else:
                p = (open_price - price) / open_price
            # Seconds held so far.
            hs = (dt - open_time).total_seconds()
            # Hard stop loss: applies regardless of the minimum hold time.
            if -p >= hard_sl_pct:
                do_close(price, dt)
                continue
            can_close_ = hs >= min_hold_sec
            if can_close_:
                # Regular stop loss
                if -p >= stop_loss_pct:
                    do_close(price, dt)
                    continue
                # Maximum holding time reached
                if hs >= max_hold_sec:
                    do_close(price, dt)
                    continue
                # Reversal: close on an opposite cross and, if the trend and
                # ATR filters pass, open the opposite position on the same bar.
                if position == 1 and cross_down:
                    do_close(price, dt)
                    if price < big and atr_pct >= atr_min:
                        do_open('short', price, dt)
                    continue
                if position == -1 and cross_up:
                    do_close(price, dt)
                    if price > big and atr_pct >= atr_min:
                        do_open('long', price, dt)
                    continue
                # Deferred signal: an opposite cross seen during the minimum
                # hold period is acted on now; the reverse entry additionally
                # requires the EMAs to still be on the opposite side.
                if pending == 'close_long' and position == 1:
                    do_close(price, dt)
                    if fast < slow and price < big and atr_pct >= atr_min:
                        do_open('short', price, dt)
                    continue
                if pending == 'close_short' and position == -1:
                    do_close(price, dt)
                    if fast > slow and price > big and atr_pct >= atr_min:
                        do_open('long', price, dt)
                    continue
            else:
                # Too early to close: remember the opposite cross for later.
                # `pending` is only reset inside do_close.
                if position == 1 and cross_down:
                    pending = 'close_long'
                elif position == -1 and cross_up:
                    pending = 'close_short'

        # === Flat ===
        if position == 0 and atr_pct >= atr_min:
            if cross_up and price > big:
                do_open('long', price, dt)
            elif cross_down and price < big:
                do_open('short', price, dt)

    # Force-close any position still open at the last bar's close.
    if position != 0:
        do_close(data[-1]['close'], data[-1]['datetime'])

    net = balance - initial_balance
    net_fee_cost = total_fee - total_rebate
    # Traded notional implied by the fees paid.
    vol = total_fee / taker_fee if taker_fee > 0 else 0
    wr = win_count / trade_count * 100 if trade_count > 0 else 0
    avg_dir = total_dir_pnl / trade_count if trade_count > 0 else 0
    return {
        'balance': balance,
        'net': net,
        'net_pct': net / initial_balance * 100,
        'trades': trade_count,
        'win_rate': wr,
        'dir_pnl': total_dir_pnl,
        'total_fee': total_fee,
        'rebate': total_rebate,
        'net_fee': net_fee_cost,
        'volume': vol,
        'avg_dir_pnl': avg_dir,
    }
# ========================= 参数扫描 =========================
def _format_result_row(rank, r, with_net_fee=True):
    """Format one scan result as a fixed-width table row.

    Args:
        rank: 1-based position of the result in the sorted list.
        r: Result dict as produced in ``main`` (backtest metrics plus ``'params'``).
        with_net_fee: Whether to append the ``'net_fee'`` column.

    Returns:
        The formatted row string.
    """
    row = (f" {rank:>3} {r['params']:<52} {r['net_pct']:>+6.2f}% {r['net']:>+8.2f} "
           f"{r['trades']:>7} {r['win_rate']:>5.1f}% {r['dir_pnl']:>+8.2f} {r['rebate']:>8.2f}")
    if with_net_fee:
        row += f" {r['net_fee']:>8.2f}"
    return row


def main():
    """Scan EMA-strategy parameter combinations and report the results.

    Loads the 2025 bars, builds the parameter grid, reduces it to a coarse
    subset, backtests every remaining combination, prints the best, worst and
    all profitable combinations, and writes every result to
    ``param_scan_results.csv`` in the parent of this file's directory.
    """
    print("Loading data...")
    data = load_data('2025-01-01', '2025-12-31')
    print(f"Loaded {len(data)} bars\n")

    # Parameter grid
    param_grid = {
        'fast_p': [8, 13, 15, 20],
        'slow_p': [21, 34, 40, 55],
        'big_p': [120, 200, 300],
        'atr_min': [0.0003, 0.0005, 0.0008, 0.0012],
        'stop_loss_pct': [0.003, 0.004, 0.005, 0.008],
        'max_hold_sec': [900, 1200, 1800, 3600],
    }

    # Drop invalid combinations (fast period must be shorter than slow period).
    combos = [
        combo
        for combo in itertools.product(
            param_grid['fast_p'], param_grid['slow_p'], param_grid['big_p'],
            param_grid['atr_min'], param_grid['stop_loss_pct'], param_grid['max_hold_sec']
        )
        if combo[0] < combo[1]
    ]
    print(f"Total parameter combinations: {len(combos)}")
    print(f"Estimated time: ~{len(combos) * 2 / 60:.0f} minutes\n")

    # Keep only combinations aimed at fewer trades (longer EMA periods).
    # Every value in the grid above satisfies both bounds, so this currently
    # removes nothing.
    focused_combos = [combo for combo in combos if combo[0] >= 8 and combo[1] >= 21]
    print(f"Focused combinations: {len(focused_combos)}")

    # If there are still too many combinations, scan a coarse subset first.
    if len(focused_combos) > 500:
        print("Phase 1: Coarse scan with subset...")
        coarse_combos = []
        for fp, sp, bp, am, sl, mh in focused_combos:
            # NOTE(review): 0.006 is not in param_grid['stop_loss_pct'], so
            # only sl == 0.004 can match this branch - confirm the intent.
            if am in [0.0003, 0.0008] and sl in [0.004, 0.006] and mh in [1200, 1800]:
                coarse_combos.append((fp, sp, bp, am, sl, mh))
            elif am in [0.0005, 0.0012] and sl in [0.003, 0.005, 0.008] and mh in [900, 3600]:
                coarse_combos.append((fp, sp, bp, am, sl, mh))
        # NOTE(review): the cap keeps a prefix in itertools.product order, so
        # whenever more than 600 remain the dropped tail is the combinations
        # with the largest fast_p - confirm this bias is acceptable.
        focused_combos = coarse_combos[:600]  # cap
        print(f" Reduced to {len(focused_combos)} combos")

    results = []
    t0 = time.time()
    for idx, (fp, sp, bp, am, sl, mh) in enumerate(focused_combos):
        r = run_ema_backtest(
            data, fast_p=fp, slow_p=sp, big_p=bp,
            atr_period=14, atr_min=am,
            # The hard stop is fixed at 1.5x the regular stop loss.
            stop_loss_pct=sl, hard_sl_pct=sl * 1.5,
            max_hold_sec=mh,
        )
        r['params'] = f"EMA({fp}/{sp}/{bp}) ATR>{am*100:.2f}% SL={sl*100:.1f}% MaxH={mh}s"
        r['fp'] = fp
        r['sp'] = sp
        r['bp'] = bp
        r['am'] = am
        r['sl'] = sl
        r['mh'] = mh
        results.append(r)
        if (idx + 1) % 50 == 0:
            elapsed = time.time() - t0
            eta = elapsed / (idx + 1) * (len(focused_combos) - idx - 1)
            print(f" [{idx+1}/{len(focused_combos)}] elapsed={elapsed:.0f}s eta={eta:.0f}s")
    total_time = time.time() - t0
    print(f"\nScan complete! {len(results)} combos in {total_time:.1f}s")

    # Sort by net profit, best first.
    results.sort(key=lambda x: x['net'], reverse=True)

    # === Top 20 ===
    print(f"\n{'='*120}")
    print(" TOP 20 PARAMETER COMBINATIONS (by Net P&L)")
    print(f"{'='*120}")
    print(f" {'#':>3} {'Params':<52} {'Net%':>7} {'Net$':>9} {'Trades':>7} {'WR':>6} {'DirPnL':>9} {'Rebate':>9} {'NetFee':>8}")
    print(f" {'-'*116}")
    for i, r in enumerate(results[:20]):
        print(_format_result_row(i + 1, r))

    # === Bottom 5 (worst) ===
    print("\n BOTTOM 5:")
    print(f" {'-'*116}")
    bottom = results[-5:]
    # Derive the first rank from the actual slice length so the numbering
    # stays correct when fewer than five results exist.
    first_bottom_rank = len(results) - len(bottom) + 1
    for offset, r in enumerate(bottom):
        print(_format_result_row(first_bottom_rank + offset, r))
    print(f"{'='*120}")

    # === Profitable combinations ===
    profitable = [r for r in results if r['net'] > 0]
    # Guard against an empty result list to avoid a ZeroDivisionError.
    profitable_share = len(profitable) / len(results) * 100 if results else 0.0
    print(f"\nProfitable combinations: {len(profitable)} / {len(results)} ({profitable_share:.1f}%)")
    if profitable:
        print("\nAll profitable combinations:")
        print(f" {'#':>3} {'Params':<52} {'Net%':>7} {'Net$':>9} {'Trades':>7} {'WR':>6} {'DirPnL':>9} {'Rebate':>9}")
        print(f" {'-'*106}")
        for i, r in enumerate(profitable):
            print(_format_result_row(i + 1, r, with_net_fee=False))

    # Save all results to CSV.
    csv_path = Path(__file__).parent.parent / 'param_scan_results.csv'
    with open(csv_path, 'w', encoding='utf-8-sig') as f:
        f.write("fast,slow,big,atr_min,stop_loss,max_hold,net_pct,net_usd,trades,win_rate,dir_pnl,rebate,net_fee,volume\n")
        for r in results:
            f.write(f"{r['fp']},{r['sp']},{r['bp']},{r['am']},{r['sl']},{r['mh']},"
                    f"{r['net_pct']:.4f},{r['net']:.4f},{r['trades']},{r['win_rate']:.2f},"
                    f"{r['dir_pnl']:.4f},{r['rebate']:.4f},{r['net_fee']:.4f},{r['volume']:.0f}\n")
    print(f"\nFull results saved: {csv_path}")
# Run the parameter scan only when this file is executed as a script.
if __name__ == '__main__':
    main()