465 lines
19 KiB
Python
465 lines
19 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
自适应三分位趋势策略 - 回测引擎
|
||
使用 Bitmart 方式获取数据(API 或 CSV),含硬止损、分批止盈、移动止损、时间止损
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import csv
|
||
import datetime
|
||
from typing import List, Dict, Optional, Tuple
|
||
|
||
# 使用 UTC 时区,避免 utcfromtimestamp 弃用警告
|
||
def _utc_dt(ts):
|
||
if hasattr(datetime, "timezone") and hasattr(datetime.timezone, "utc"):
|
||
return datetime.datetime.fromtimestamp(ts, tz=datetime.timezone.utc)
|
||
return datetime.datetime.utcfromtimestamp(ts)
|
||
|
||
try:
|
||
from loguru import logger
|
||
except ImportError:
|
||
import logging
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# The project root is the parent of the directory containing this script.
# Prepending it to sys.path lets the `adaptive_third_strategy.*` imports resolve
# when this file is executed directly as a script.
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)
|
||
|
||
from adaptive_third_strategy.config import (
|
||
STEP_5M,
|
||
STEP_15M,
|
||
STEP_60M,
|
||
ATR_PERIOD,
|
||
EMA_SHORT,
|
||
EMA_MID_FAST,
|
||
EMA_MID_SLOW,
|
||
EMA_LONG_FAST,
|
||
EMA_LONG_SLOW,
|
||
STOP_LOSS_ATR_MULT,
|
||
TIME_STOP_BARS,
|
||
TRAIL_START_ATR,
|
||
TRAIL_ATR_MULT,
|
||
TP1_ATR,
|
||
TP2_ATR,
|
||
TP3_ATR,
|
||
TP1_RATIO,
|
||
TP2_RATIO,
|
||
TP3_RATIO,
|
||
MIN_BARS_SINCE_ENTRY,
|
||
SAME_KLINE_NO_REVERSE,
|
||
REVERSE_BREAK_MULT,
|
||
REVERSE_LOSS_ATR,
|
||
MAX_POSITION_PERCENT,
|
||
BASE_POSITION_PERCENT,
|
||
CONTRACT_SIZE,
|
||
FEE_RATE,
|
||
FEE_FIXED,
|
||
FEE_FIXED_BACKTEST,
|
||
MIN_BARS_BETWEEN_TRADES,
|
||
SLIPPAGE_POINTS,
|
||
)
|
||
from adaptive_third_strategy.indicators import get_ema_atr_from_klines, align_higher_tf_ema
|
||
from adaptive_third_strategy.strategy_core import (
|
||
check_trigger,
|
||
get_body_size,
|
||
build_volume_ma,
|
||
)
|
||
from adaptive_third_strategy.data_fetcher import (
|
||
fetch_multi_timeframe,
|
||
load_klines_csv,
|
||
save_klines_csv,
|
||
)
|
||
|
||
# Directory holding this script; used as the default home for data/ and output CSVs.
SCRIPT_DIR = os.path.split(os.path.abspath(__file__))[0]
|
||
|
||
|
||
def ensure_data(
    start_time: int,
    end_time: int,
    data_dir: Optional[str],
    use_api: bool = True,
    api_key: str = "",
    secret_key: str = "",
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """Return 5/15/60-minute klines for ``[start_time, end_time)``, using a CSV cache.

    The cache files live in ``data_dir`` as ``kline_5m.csv``, ``kline_15m.csv``
    and ``kline_60m.csv``. When ``data_dir`` is empty or None, it defaults to
    ``<script dir>/data``.

    If any of the three caches loads empty and ``use_api`` is true, all three
    timeframes are fetched from the API. Only timeframes that came back
    non-empty replace the cached data and are written back to CSV. A failed or
    empty fetch therefore never overwrites a good cache file. With
    ``use_api=False``, only the local CSVs are used, so they must be fetched
    beforehand.

    Args:
        start_time: Inclusive lower bound on each kline's ``id``. main() passes
            epoch seconds.
        end_time: Exclusive upper bound on ``id``. Pass the next day's 00:00 to
            include the whole end date.
        data_dir: Cache directory. It is created if missing.
        use_api: Whether missing data may be downloaded.
        api_key: API key. Falls back to ``DEFAULT_API_KEY`` from data_fetcher.
        secret_key: Secret key. Falls back to ``DEFAULT_SECRET_KEY`` from
            data_fetcher.

    Returns:
        ``(klines_5m, klines_15m, klines_60m)``. Each list is filtered to
        ``start_time <= id < end_time`` and sorted by ``id`` in ascending order.
    """
    data_dir = data_dir or os.path.join(SCRIPT_DIR, "data")
    os.makedirs(data_dir, exist_ok=True)

    steps = (5, 15, 60)
    paths = {step: os.path.join(data_dir, f"kline_{step}m.csv") for step in steps}
    klines = {step: load_klines_csv(paths[step]) or [] for step in steps}

    if use_api and not all(klines.values()):
        from adaptive_third_strategy.data_fetcher import DEFAULT_API_KEY, DEFAULT_SECRET_KEY, DEFAULT_MEMO

        fetched = fetch_multi_timeframe(
            start_time,
            end_time,
            list(steps),
            api_key or DEFAULT_API_KEY,
            secret_key or DEFAULT_SECRET_KEY,
            DEFAULT_MEMO,
        )
        for step in steps:
            # The original code indexes the result by step, so a step -> klines
            # mapping is expected here.
            fresh = fetched.get(step) or []
            if fresh:
                klines[step] = fresh
                save_klines_csv(fresh, paths[step])
            else:
                # Keep any existing cache instead of overwriting it with nothing.
                logger.warning(f"{step}m K线拉取结果为空,保留本地缓存: {paths[step]}")

    def _in_window(rows: List[Dict]) -> List[Dict]:
        # end_time is exclusive. A next-day 00:00 bound keeps the whole end date.
        return sorted(
            (row for row in rows if start_time <= row["id"] < end_time),
            key=lambda row: row["id"],
        )

    return _in_window(klines[5]), _in_window(klines[15]), _in_window(klines[60])
|
||
|
||
|
||
def _record_trade(
    position: Dict,
    exit_price: float,
    exit_ts,
    exit_reason: str,
    hold_bars: int,
    fee: float,
) -> Dict:
    """Build the trade record for closing ``position`` at ``exit_price``.

    ``point_pnl`` is the price move in the position's favour. ``money_pnl``
    scales that relative move by the position's notional ``size_usdt``.
    ``net_profit`` is ``money_pnl - fee``.
    """
    entry_price = position["entry_price"]
    is_long = position["direction"] == "long"
    point_pnl = (exit_price - entry_price) if is_long else (entry_price - exit_price)
    money_pnl = point_pnl / entry_price * position["size_usdt"]
    return {
        "direction": "做多" if is_long else "做空",
        "entry_time": _utc_dt(position["entry_time"]),
        "exit_time": _utc_dt(exit_ts),
        "entry_price": entry_price,
        "exit_price": exit_price,
        "point_pnl": point_pnl,
        "money_pnl": money_pnl,
        "fee": fee,
        "net_profit": money_pnl - fee,
        "exit_reason": exit_reason,
        "hold_bars": hold_bars,
    }


def _evaluate_exit(
    position: Dict,
    idx: int,
    high: float,
    low: float,
    close: float,
    use_stop_loss: bool,
    use_take_profit: bool,
    use_trailing_stop: bool,
    use_time_stop: bool,
) -> Tuple[Optional[float], Optional[str]]:
    """Decide whether the open ``position`` exits on bar ``idx``.

    Checks run in this order: hard stop, take profit, trailing stop, time stop.
    The function mutates ``position`` to keep trailing-stop state
    (``trail_activated`` and ``trail_stop``).

    Returns:
        ``(exit_price, exit_reason)``, or ``(None, None)`` when the position
        stays open.
    """
    is_long = position["direction"] == "long"
    entry_price = position["entry_price"]
    entry_atr = position["entry_atr"]
    stop_price = position.get("stop_price")

    # 1) Hard stop. It is checked before targets, which is conservative when a
    #    single bar touches both. A bar entirely beyond the stop fills at the
    #    bar extreme.
    if use_stop_loss and stop_price is not None:
        if is_long and low <= stop_price:
            return min(stop_price, high), "stop_loss"
        if not is_long and high >= stop_price:
            return max(stop_price, low), "stop_loss"

    # 2) Take profit (simplified: the first target hit closes the whole position).
    # NOTE(review): the farthest target is tested first. When one bar spans
    # several targets, the nearer target was crossed first, so this fill is
    # optimistic intrabar.
    if use_take_profit:
        for name, mult in (("tp3", TP3_ATR), ("tp2", TP2_ATR), ("tp1", TP1_ATR)):
            if is_long:
                target = entry_price + entry_atr * mult
                if high >= target:
                    return target, name
            else:
                target = entry_price - entry_atr * mult
                if low <= target:
                    return target, name

    # 3) Trailing stop. The level set on earlier bars is checked against this
    #    bar first. This bar's close only moves the level for the next bar,
    #    which avoids using the close to judge the same bar's low/high. Once
    #    activated, it stays active and only ratchets in the position's favour.
    if use_trailing_stop:
        trail_stop = position.get("trail_stop")
        if trail_stop is not None:
            if is_long and low <= trail_stop:
                return min(trail_stop, high), "trail_stop"
            if not is_long and high >= trail_stop:
                return max(trail_stop, low), "trail_stop"
        if is_long:
            if close >= entry_price + entry_atr * TRAIL_START_ATR:
                position["trail_activated"] = True
            if position.get("trail_activated"):
                candidate = close - entry_atr * TRAIL_ATR_MULT
                if trail_stop is None or candidate > trail_stop:
                    position["trail_stop"] = candidate
        else:
            if close <= entry_price - entry_atr * TRAIL_START_ATR:
                position["trail_activated"] = True
            if position.get("trail_activated"):
                candidate = close + entry_atr * TRAIL_ATR_MULT
                if trail_stop is None or candidate < trail_stop:
                    position["trail_stop"] = candidate

    # 4) Time stop. After TIME_STOP_BARS it only exits a position that is not in profit.
    if use_time_stop and (idx - position["entry_bar_idx"]) >= TIME_STOP_BARS:
        not_in_profit = close <= entry_price if is_long else close >= entry_price
        if not_in_profit:
            return close, "time_stop"

    return None, None


def run_backtest(
    klines_5m: List[Dict],
    klines_15m: List[Dict],
    klines_60m: List[Dict],
    initial_capital: float = 10000.0,
    use_stop_loss: bool = True,
    use_take_profit: bool = True,
    use_trailing_stop: bool = True,
    use_time_stop: bool = True,
    fixed_margin_usdt: Optional[float] = None,
    leverage: Optional[float] = None,
    deduct_fee: bool = True,
) -> Tuple[List[Dict], Dict, List[Dict]]:
    """Run the bar-by-bar backtest on 5-minute klines.

    For each bar, the function first manages any open position (exit checks in
    ``_evaluate_exit``). If nothing exited, it looks for a signal via
    ``check_trigger``, optionally reverses an opposite position, and opens a new
    one at ``trigger_price``. A position still open after the last bar is
    closed at the final close with reason ``"tail"``.

    Position sizing:
        - If both ``fixed_margin_usdt`` and ``leverage`` are given, each trade's
          notional is ``fixed_margin_usdt * leverage`` (e.g. 50 USDT at 100x is
          5000 USDT).
        - Otherwise the notional is
          ``capital * min(MAX_POSITION_PERCENT, BASE_POSITION_PERCENT)``.

    Fees:
        With ``deduct_fee=False`` no fee is charged, so total profit equals the
        sum of all ``money_pnl`` values. Otherwise each round trip costs
        ``fee_fixed + size * FEE_RATE * 2``. No slippage is applied to fills.

    Returns:
        ``(trades, stats, klines_5m)``. When there are fewer than
        ``ATR_PERIOD + 50`` bars, a warning is logged and
        ``([], {}, klines_5m)`` is returned. Max drawdown is measured on
        realized equity (capital after each closed trade), not mark-to-market.
    """
    if len(klines_5m) < ATR_PERIOD + 50:
        logger.warning("5分钟K线数量不足")
        return [], {}, klines_5m

    ema_5m, atr_5m = get_ema_atr_from_klines(klines_5m, EMA_SHORT, ATR_PERIOD)
    ema_15m_align = align_higher_tf_ema(klines_5m, klines_15m, EMA_MID_FAST, EMA_MID_SLOW)
    ema_60m_align = align_higher_tf_ema(klines_5m, klines_60m, EMA_LONG_FAST, EMA_LONG_SLOW)
    volume_ma = build_volume_ma(klines_5m)

    trades: List[Dict] = []
    position: Optional[Dict] = None
    last_trade_bar_idx: Optional[int] = None
    last_close_bar_idx: Optional[int] = None
    equity_curve: List[float] = [initial_capital]
    capital = initial_capital

    fixed_notional = (
        fixed_margin_usdt * leverage
        if (fixed_margin_usdt is not None and leverage is not None)
        else None
    )
    if deduct_fee:
        fee_fixed = FEE_FIXED_BACKTEST if FEE_FIXED_BACKTEST is not None else FEE_FIXED
    else:
        fee_fixed = 0

    def _size_usdt(cap: float) -> float:
        # Notional (USDT) for a new position given the current capital.
        if fixed_notional is not None:
            return fixed_notional
        return min(cap * MAX_POSITION_PERCENT, cap * BASE_POSITION_PERCENT)

    def _fee(sz: float) -> float:
        # Round-trip fee (open + close) for notional ``sz``.
        return (fee_fixed + sz * FEE_RATE * 2) if deduct_fee else 0

    def _close(pos: Dict, exit_price: float, exit_ts, reason: str, hold_bars: int) -> None:
        # Record the trade and book its net profit into capital.
        nonlocal capital
        trade = _record_trade(pos, exit_price, exit_ts, reason, hold_bars, _fee(pos["size_usdt"]))
        trades.append(trade)
        capital += trade["net_profit"]

    for idx in range(ATR_PERIOD, len(klines_5m)):
        curr = klines_5m[idx]
        bar_id = curr["id"]
        high, low, close = float(curr["high"]), float(curr["low"]), float(curr["close"])
        atr_val = atr_5m[idx]
        if atr_val is None or atr_val <= 0:
            equity_curve.append(capital)
            continue

        # ---------- Position management: stop / target / trailing / time stop ----------
        if position is not None:
            exit_price, exit_reason = _evaluate_exit(
                position, idx, high, low, close,
                use_stop_loss, use_take_profit, use_trailing_stop, use_time_stop,
            )
            if exit_reason:
                _close(position, exit_price, bar_id, exit_reason, idx - position["entry_bar_idx"])
                position = None
                last_close_bar_idx = idx
                equity_curve.append(capital)
                continue

        # ---------- Signal detection ----------
        direction, trigger_price, _valid_prev_idx, _valid_prev = check_trigger(
            klines_5m, idx, atr_5m, ema_5m, ema_15m_align, ema_60m_align, volume_ma, use_confirm=True
        )
        if direction is None:
            equity_curve.append(capital)
            continue
        # NOTE(review): last_trade_bar_idx is only assigned after this check
        # within the same bar, so as written this condition never fires.
        if SAME_KLINE_NO_REVERSE and last_trade_bar_idx == idx:
            equity_curve.append(capital)
            continue

        just_reversed = False
        if position is not None:
            # A same-direction signal is ignored, and a reversal needs a
            # minimum holding time.
            if direction == position["direction"] or (idx - position["entry_bar_idx"]) < MIN_BARS_SINCE_ENTRY:
                equity_curve.append(capital)
                continue
            # Reverse only when the open position is losing at least
            # REVERSE_LOSS_ATR ATRs.
            entry_atr_pos = position.get("entry_atr") or atr_val
            if position["direction"] == "long":
                pos_loss_atr = (position["entry_price"] - close) / entry_atr_pos
            else:
                pos_loss_atr = (close - position["entry_price"]) / entry_atr_pos
            if pos_loss_atr < REVERSE_LOSS_ATR:
                equity_curve.append(capital)
                continue
            # Close at this bar's close, then fall through to open the opposite side.
            _close(position, close, bar_id, "reverse", idx - position["entry_bar_idx"])
            position = None
            last_close_bar_idx = idx
            just_reversed = True

        # ---------- Open ----------
        # After a reversal the new position may open on the same bar. Otherwise
        # MIN_BARS_BETWEEN_TRADES bars must pass since the last close.
        if (
            not just_reversed
            and last_close_bar_idx is not None
            and (idx - last_close_bar_idx) < MIN_BARS_BETWEEN_TRADES
        ):
            equity_curve.append(capital)
            continue

        size_usdt = _size_usdt(capital)
        if size_usdt <= 0:
            equity_curve.append(capital)
            continue

        if direction == "long":
            stop_price = trigger_price - atr_val * STOP_LOSS_ATR_MULT
        else:
            stop_price = trigger_price + atr_val * STOP_LOSS_ATR_MULT
        position = {
            "direction": direction,
            "entry_price": trigger_price,
            "entry_time": bar_id,
            "entry_bar_idx": idx,
            "entry_atr": atr_val,
            "stop_price": stop_price,
            "size_usdt": size_usdt,
            "trail_activated": False,
            "trail_stop": None,
        }
        last_trade_bar_idx = idx
        equity_curve.append(capital)

    # Close any position still open at the end of the data at the final close.
    if position is not None:
        last_bar = klines_5m[-1]
        _close(
            position,
            float(last_bar["close"]),
            last_bar["id"],
            "tail",
            len(klines_5m) - 1 - position["entry_bar_idx"],
        )
        position = None
        # Record the post-tail equity so the drawdown stats include this trade.
        equity_curve.append(capital)

    # ---------- Statistics ----------
    win_count = sum(1 for t in trades if t["net_profit"] > 0)
    stats = {
        "total_trades": len(trades),
        "win_count": win_count,
        "win_rate": (win_count / len(trades) * 100) if trades else 0,
        "total_gross_profit": sum(t["money_pnl"] for t in trades),
        "total_fee": sum(t["fee"] for t in trades),
        "total_net_profit": sum(t["net_profit"] for t in trades),
        "final_capital": capital,
        "max_drawdown": 0,
        "max_drawdown_pct": 0,
    }
    peak = initial_capital
    for eq in equity_curve:
        peak = max(peak, eq)
        if peak > 0:
            dd = peak - eq
            stats["max_drawdown"] = max(stats["max_drawdown"], dd)
            stats["max_drawdown_pct"] = max(stats["max_drawdown_pct"], dd / peak * 100)

    return trades, stats, klines_5m
|
||
|
||
|
||
def main():
    """CLI entry point.

    Parses the date range and sizing options, then loads or fetches the klines
    via ``ensure_data``. It runs ``run_backtest``, logs the summary, and writes
    every trade to ``backtest_trades.csv`` next to this script.
    """
    import argparse
    import logging

    # Without loguru, `logger` is a stdlib Logger whose default level hides
    # INFO. Configure INFO so the results below are actually printed.
    if isinstance(logger, logging.Logger):
        logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(description="自适应三分位趋势策略回测")
    parser.add_argument("--start", default="2025-01-01", help="开始日期 YYYY-MM-DD")
    parser.add_argument("--end", default="2025-12-31", help="结束日期,默认 2025-12-31")
    parser.add_argument("--data-dir", default=None, help="数据目录,默认 adaptive_third_strategy/data")
    parser.add_argument("--no-api", action="store_true", help="不从 API 拉取,仅用本地 CSV")
    parser.add_argument("--capital", type=float, default=10000, help="初始资金(按比例开仓时用)")
    parser.add_argument("--fixed-margin", type=float, default=None, help="每笔固定保证金 USDT,如 50")
    parser.add_argument("--leverage", type=float, default=None, help="杠杆倍数,如 100")
    parser.add_argument("--no-fee", action="store_true", help="不扣手续费,只算总盈利")
    args = parser.parse_args()

    # Interpret the CLI dates as UTC midnight, matching the UTC trade timestamps
    # produced by _utc_dt. Naive datetimes would be read in the machine's local
    # time zone by .timestamp().
    utc = datetime.timezone.utc
    start_dt = datetime.datetime.strptime(args.start, "%Y-%m-%d").replace(tzinfo=utc)
    if args.end:
        end_dt = datetime.datetime.strptime(args.end, "%Y-%m-%d").replace(tzinfo=utc)
        # Include the whole end date: use the next day's 00:00 as an exclusive bound (id < end_ts).
        end_ts = int((end_dt + datetime.timedelta(days=1)).timestamp())
    else:
        end_ts = int(datetime.datetime.now(utc).timestamp())
    start_ts = int(start_dt.timestamp())

    data_dir = args.data_dir or os.path.join(SCRIPT_DIR, "data")
    k5, k15, k60 = ensure_data(start_ts, end_ts, data_dir, use_api=not args.no_api)
    if not k5:
        logger.error("无 5 分钟数据,请先抓取或去掉 --no-api 以从 API 拉取")
        return
    logger.info(f"5m={len(k5)} 15m={len(k15)} 60m={len(k60)}")

    trades, stats, _ = run_backtest(
        k5, k15, k60,
        initial_capital=args.capital,
        fixed_margin_usdt=args.fixed_margin,
        leverage=args.leverage,
        deduct_fee=not args.no_fee,
    )
    if not stats:
        # run_backtest returns an empty stats dict (after logging a warning)
        # when there are too few 5m bars, so there is nothing to report.
        return

    logger.info(f"交易笔数: {stats['total_trades']} 胜率: {stats['win_rate']:.2f}% "
                f"总盈利(未扣费): {stats['total_gross_profit']:.2f} USDT "
                f"总手续费: {stats['total_fee']:.2f} 总净利润: {stats['total_net_profit']:.2f} "
                f"最大回撤: {stats['max_drawdown']:.2f} ({stats['max_drawdown_pct']:.2f}%)")

    out_csv = os.path.join(SCRIPT_DIR, "backtest_trades.csv")
    if trades:
        fieldnames = [
            "direction", "entry_time", "exit_time", "entry_price", "exit_price",
            "point_pnl", "money_pnl", "fee", "net_profit", "exit_reason", "hold_bars",
        ]
        with open(out_csv, "w", newline="", encoding="utf-8") as f:
            w = csv.DictWriter(f, fieldnames=fieldnames)
            w.writeheader()
            for t in trades:
                # Datetimes are written via str(); other fields are written as-is.
                w.writerow({k: str(v) if isinstance(v, datetime.datetime) else v for k, v in t.items()})
        logger.info(f"交易记录已保存: {out_csv}")


if __name__ == "__main__":
    main()
|