Files
lm_code/adaptive_third_strategy/backtest.py
ddrwode 970080a2e6 hahaa
2026-01-31 10:35:25 +08:00

465 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
自适应三分位趋势策略 - 回测引擎
使用 Bitmart 方式获取数据(API 或 CSV),含硬止损、分批止盈、移动止损、时间止损
"""
import os
import sys
import csv
import datetime
from typing import List, Dict, Optional, Tuple
# 使用 UTC 时区,避免 utcfromtimestamp 弃用警告
def _utc_dt(ts):
if hasattr(datetime, "timezone") and hasattr(datetime.timezone, "utc"):
return datetime.datetime.fromtimestamp(ts, tz=datetime.timezone.utc)
return datetime.datetime.utcfromtimestamp(ts)
# Prefer loguru when it is installed; otherwise fall back to a stdlib logger
# named after this module.
try:
    from loguru import logger
except ImportError:
    import logging
    logger = logging.getLogger(__name__)
# ROOT_DIR is the parent of the directory containing this file. It is put at
# the front of sys.path (once) before the `adaptive_third_strategy.*` imports
# that follow.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)
from adaptive_third_strategy.config import (
STEP_5M,
STEP_15M,
STEP_60M,
ATR_PERIOD,
EMA_SHORT,
EMA_MID_FAST,
EMA_MID_SLOW,
EMA_LONG_FAST,
EMA_LONG_SLOW,
STOP_LOSS_ATR_MULT,
TIME_STOP_BARS,
TRAIL_START_ATR,
TRAIL_ATR_MULT,
TP1_ATR,
TP2_ATR,
TP3_ATR,
TP1_RATIO,
TP2_RATIO,
TP3_RATIO,
MIN_BARS_SINCE_ENTRY,
SAME_KLINE_NO_REVERSE,
REVERSE_BREAK_MULT,
REVERSE_LOSS_ATR,
MAX_POSITION_PERCENT,
BASE_POSITION_PERCENT,
CONTRACT_SIZE,
FEE_RATE,
FEE_FIXED,
FEE_FIXED_BACKTEST,
MIN_BARS_BETWEEN_TRADES,
SLIPPAGE_POINTS,
)
from adaptive_third_strategy.indicators import get_ema_atr_from_klines, align_higher_tf_ema
from adaptive_third_strategy.strategy_core import (
check_trigger,
get_body_size,
build_volume_ma,
)
from adaptive_third_strategy.data_fetcher import (
fetch_multi_timeframe,
load_klines_csv,
save_klines_csv,
)
# Directory containing this file; base for the default data directory and for
# the trades CSV written by main().
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
def ensure_data(
    start_time: int,
    end_time: int,
    data_dir: str,
    use_api: bool = True,
    api_key: str = "",
    secret_key: str = "",
) -> Tuple[List[Dict], List[Dict], List[Dict]]:
    """
    Make sure 5/15/60-minute kline data is available and return it.

    The three CSV files under ``data_dir`` are loaded first. If any of them
    yields no rows and ``use_api`` is true, all three timeframes are fetched
    through the API and written back to those CSV files. With
    ``use_api=False`` only the CSV contents are used (they must have been
    fetched beforehand).

    Returns:
        ``(k5, k15, k60)``: each list keeps only rows whose ``"id"`` lies in
        ``[start_time, end_time)`` and is sorted ascending by ``"id"``.
    """
    data_dir = data_dir or os.path.join(SCRIPT_DIR, "data")
    os.makedirs(data_dir, exist_ok=True)

    steps = (5, 15, 60)
    csv_paths = {
        5: os.path.join(data_dir, "kline_5m.csv"),
        15: os.path.join(data_dir, "kline_15m.csv"),
        60: os.path.join(data_dir, "kline_60m.csv"),
    }
    klines = {step: load_klines_csv(csv_paths[step]) for step in steps}

    if not all(klines[step] for step in steps) and use_api:
        from adaptive_third_strategy.data_fetcher import DEFAULT_API_KEY, DEFAULT_SECRET_KEY, DEFAULT_MEMO
        fetched = fetch_multi_timeframe(
            start_time,
            end_time,
            list(steps),
            api_key or DEFAULT_API_KEY,
            secret_key or DEFAULT_SECRET_KEY,
            DEFAULT_MEMO,
        )
        klines = {step: fetched[step] for step in steps}
        for step in steps:
            save_klines_csv(klines[step], csv_paths[step])

    def _window(rows):
        # Half-open range: when end_time is midnight of the following day,
        # "<" keeps only bars up to and including the end date.
        kept = [row for row in rows if start_time <= row["id"] < end_time]
        kept.sort(key=lambda row: row["id"])
        return kept

    return _window(klines[5]), _window(klines[15]), _window(klines[60])
def run_backtest(
    klines_5m: List[Dict],
    klines_15m: List[Dict],
    klines_60m: List[Dict],
    initial_capital: float = 10000.0,
    use_stop_loss: bool = True,
    use_take_profit: bool = True,
    use_trailing_stop: bool = True,
    use_time_stop: bool = True,
    fixed_margin_usdt: Optional[float] = None,
    leverage: Optional[float] = None,
    deduct_fee: bool = True,
) -> Tuple[List[Dict], Dict, List[Dict]]:
    """
    Main backtest loop over the 5-minute klines.

    - When both ``fixed_margin_usdt`` and ``leverage`` are given, every entry
      uses a notional value of ``fixed_margin_usdt * leverage`` (e.g. 50 USDT
      at 100x = 5000 USDT); otherwise the size is a fraction of current capital.
    - With ``deduct_fee=False`` no fee is charged, so the total profit equals
      the sum of all ``money_pnl`` values.

    At most one position is open at a time. On each bar an open position is
    first checked for stop loss / take profit / trailing stop / time stop
    (each can be disabled through the ``use_*`` flags); if it survives, a
    trigger in the opposite direction may close it and reverse. A position
    still open after the last bar is closed at that bar's close ("tail").

    Returns:
        ``(trades, stats, klines_5m)``. ``trades`` is a list of per-trade
        dicts and ``stats`` the summary dict. If fewer than
        ``ATR_PERIOD + 50`` bars are supplied, ``([], {}, klines_5m)`` is
        returned after a warning.

    NOTE(review): this block was recovered from a copy whose indentation had
    been stripped. The trailing-stop check is assumed to be nested under the
    activation condition; confirm against the original file.
    """
    if len(klines_5m) < ATR_PERIOD + 50:
        logger.warning("5分钟K线数量不足")
        return [], {}, klines_5m

    ema_5m, atr_5m = get_ema_atr_from_klines(klines_5m, EMA_SHORT, ATR_PERIOD)
    ema_15m_align = align_higher_tf_ema(klines_5m, klines_15m, EMA_MID_FAST, EMA_MID_SLOW)
    ema_60m_align = align_higher_tf_ema(klines_5m, klines_60m, EMA_LONG_FAST, EMA_LONG_SLOW)
    volume_ma = build_volume_ma(klines_5m)

    trades: List[Dict] = []
    position: Optional[Dict] = None
    last_trade_bar_idx: Optional[int] = None
    last_close_bar_idx: Optional[int] = None
    equity_curve: List[float] = [initial_capital]
    capital = initial_capital
    # Fixed notional per trade in USDT (50 USDT at 100x = 5000).
    fixed_notional = (fixed_margin_usdt * leverage) if (fixed_margin_usdt is not None and leverage is not None) else None
    fee_fixed = 0 if not deduct_fee else (FEE_FIXED_BACKTEST if FEE_FIXED_BACKTEST is not None else FEE_FIXED)

    def _size_usdt(cap: float) -> float:
        # Position notional: fixed when configured, else a share of capital.
        if fixed_notional is not None:
            return fixed_notional
        return min(cap * MAX_POSITION_PERCENT, cap * BASE_POSITION_PERCENT)

    def _fee(sz: float) -> float:
        # Fixed fee plus proportional fee on both legs (open and close).
        return 0 if not deduct_fee else (fee_fixed + sz * FEE_RATE * 2)

    def _settle(pos: Dict, exit_price: float, exit_ts, exit_reason: str, hold_bars: int, cap: float) -> Dict:
        # Build the trade record for closing `pos` at `exit_price`.
        entry_price = pos["entry_price"]
        is_long = pos["direction"] == "long"
        point_pnl = (exit_price - entry_price) if is_long else (entry_price - exit_price)
        size_usdt = pos.get("size_usdt", _size_usdt(cap))
        money_pnl = point_pnl / entry_price * size_usdt
        fee = _fee(size_usdt)
        return {
            "direction": "做多" if is_long else "做空",
            "entry_time": _utc_dt(pos["entry_time"]),
            "exit_time": _utc_dt(exit_ts),
            "entry_price": entry_price,
            "exit_price": exit_price,
            "point_pnl": point_pnl,
            "money_pnl": money_pnl,
            "fee": fee,
            "net_profit": money_pnl - fee,
            "exit_reason": exit_reason,
            "hold_bars": hold_bars,
        }

    for idx in range(ATR_PERIOD, len(klines_5m)):
        just_reversed = False
        curr = klines_5m[idx]
        bar_id = curr["id"]
        high, low, close = float(curr["high"]), float(curr["low"]), float(curr["close"])
        atr_val = atr_5m[idx]
        if atr_val is None or atr_val <= 0:
            equity_curve.append(capital)
            continue

        # ---------- Position management: stop loss / take profit / trailing stop / time stop ----------
        if position is not None:
            pos_dir = position["direction"]
            entry_price = position["entry_price"]
            entry_idx = position["entry_bar_idx"]
            entry_atr = position["entry_atr"]
            stop_price = position.get("stop_price")
            exit_reason = None
            exit_price = close
            if pos_dir == "long":
                # Hard stop loss (checked before take profit on the same bar).
                if use_stop_loss and stop_price is not None and low <= stop_price:
                    exit_price = min(stop_price, high)
                    exit_reason = "stop_loss"
                # Take profit (simplified: the whole position is closed at the
                # furthest target touched by this bar).
                elif use_take_profit:
                    tp1 = entry_price + entry_atr * TP1_ATR
                    tp2 = entry_price + entry_atr * TP2_ATR
                    tp3 = entry_price + entry_atr * TP3_ATR
                    if high >= tp3:
                        exit_price = tp3
                        exit_reason = "tp3"
                    elif high >= tp2:
                        exit_price = tp2
                        exit_reason = "tp2"
                    elif high >= tp1:
                        exit_price = tp1
                        exit_reason = "tp1"
                # Trailing stop.
                # NOTE(review): the trailing level is derived from this bar's
                # close, not from a running extreme; confirm this is intended.
                if use_trailing_stop and not exit_reason:
                    if close >= entry_price + entry_atr * TRAIL_START_ATR:
                        position["trail_activated"] = True
                        trail_stop = close - entry_atr * TRAIL_ATR_MULT
                        if low <= trail_stop:
                            exit_price = trail_stop
                            exit_reason = "trail_stop"
                # Time stop: only closes a position that is not in profit.
                if use_time_stop and not exit_reason and (idx - entry_idx) >= TIME_STOP_BARS:
                    if close <= entry_price:
                        exit_price = close
                        exit_reason = "time_stop"
            else:
                # Short side: mirror image of the long checks above.
                if use_stop_loss and stop_price is not None and high >= stop_price:
                    exit_price = max(stop_price, low)
                    exit_reason = "stop_loss"
                elif use_take_profit:
                    tp1 = entry_price - entry_atr * TP1_ATR
                    tp2 = entry_price - entry_atr * TP2_ATR
                    tp3 = entry_price - entry_atr * TP3_ATR
                    if low <= tp3:
                        exit_price = tp3
                        exit_reason = "tp3"
                    elif low <= tp2:
                        exit_price = tp2
                        exit_reason = "tp2"
                    elif low <= tp1:
                        exit_price = tp1
                        exit_reason = "tp1"
                if use_trailing_stop and not exit_reason:
                    if close <= entry_price - entry_atr * TRAIL_START_ATR:
                        position["trail_activated"] = True
                        trail_stop = close + entry_atr * TRAIL_ATR_MULT
                        if high >= trail_stop:
                            exit_price = trail_stop
                            exit_reason = "trail_stop"
                if use_time_stop and not exit_reason and (idx - entry_idx) >= TIME_STOP_BARS:
                    if close >= entry_price:
                        exit_price = close
                        exit_reason = "time_stop"

            if exit_reason:
                # Close the position; no new entry is attempted on this bar.
                trade = _settle(position, exit_price, bar_id, exit_reason, idx - entry_idx, capital)
                capital += trade["net_profit"]
                trades.append(trade)
                position = None
                last_close_bar_idx = idx
                equity_curve.append(capital)
                continue

        # ---------- Signal detection ----------
        direction, trigger_price, _valid_prev_idx, _valid_prev = check_trigger(
            klines_5m, idx, atr_5m, ema_5m, ema_15m_align, ema_60m_align, volume_ma, use_confirm=True
        )
        if direction is None:
            equity_curve.append(capital)
            continue
        if SAME_KLINE_NO_REVERSE and last_trade_bar_idx == idx:
            equity_curve.append(capital)
            continue

        if position is not None:
            # A signal in the held direction changes nothing.
            if direction == position["direction"]:
                equity_curve.append(capital)
                continue
            # Reversal conditions: minimum holding time and a minimum adverse
            # move measured in entry-ATR units.
            bars_since = idx - position["entry_bar_idx"]
            if bars_since < MIN_BARS_SINCE_ENTRY:
                equity_curve.append(capital)
                continue
            entry_atr_pos = position.get("entry_atr") or atr_val
            pos_loss_atr = (position["entry_price"] - close) / entry_atr_pos if position["direction"] == "long" else (close - position["entry_price"]) / entry_atr_pos
            if pos_loss_atr < REVERSE_LOSS_ATR:
                equity_curve.append(capital)
                continue
            # Simplified reversal: book the close at this bar's close price,
            # then fall through to the shared entry logic below.
            trade = _settle(position, close, bar_id, "reverse", idx - position["entry_bar_idx"], capital)
            capital += trade["net_profit"]
            trades.append(trade)
            position = None
            last_close_bar_idx = idx
            just_reversed = True

        # ---------- Entry ----------
        # A reversal may re-enter on the same bar; otherwise at least
        # MIN_BARS_BETWEEN_TRADES bars must have passed since the last close.
        if not just_reversed and last_close_bar_idx is not None and (idx - last_close_bar_idx) < MIN_BARS_BETWEEN_TRADES:
            equity_curve.append(capital)
            continue
        size_usdt = _size_usdt(capital)
        if size_usdt <= 0:
            equity_curve.append(capital)
            continue
        if direction == "long":
            stop_price = trigger_price - atr_val * STOP_LOSS_ATR_MULT
        else:
            stop_price = trigger_price + atr_val * STOP_LOSS_ATR_MULT
        position = {
            "direction": direction,
            "entry_price": trigger_price,
            "entry_time": bar_id,
            "entry_bar_idx": idx,
            "entry_atr": atr_val,
            "stop_price": stop_price,
            "size_usdt": size_usdt,
            "closed_ratio": 0,
            "trail_activated": False,
        }
        last_trade_bar_idx = idx
        equity_curve.append(capital)

    # Tail position: force-close whatever is still open at the last bar's close.
    if position is not None:
        last_bar = klines_5m[-1]
        trade = _settle(
            position,
            float(last_bar["close"]),
            last_bar["id"],
            "tail",
            len(klines_5m) - 1 - position["entry_bar_idx"],
            capital,
        )
        capital += trade["net_profit"]
        trades.append(trade)
        # Record the post-liquidation equity so that the drawdown statistics
        # below also account for the final trade.
        equity_curve.append(capital)

    # Summary statistics.
    total_net = sum(t["net_profit"] for t in trades)
    total_gross = sum(t["money_pnl"] for t in trades)
    total_fee = sum(t["fee"] for t in trades)
    win_count = len([t for t in trades if t["net_profit"] > 0])
    stats = {
        "total_trades": len(trades),
        "win_count": win_count,
        "win_rate": (win_count / len(trades) * 100) if trades else 0,
        "total_gross_profit": total_gross,
        "total_fee": total_fee,
        "total_net_profit": total_net,
        "final_capital": capital,
        "max_drawdown": 0,
        "max_drawdown_pct": 0,
    }
    # Max drawdown: largest drop from the running equity peak, in absolute
    # terms and as a percentage of that peak.
    peak = initial_capital
    for eq in equity_curve:
        peak = max(peak, eq)
        dd = peak - eq
        if peak > 0:
            stats["max_drawdown"] = max(stats["max_drawdown"], dd)
            stats["max_drawdown_pct"] = max(stats["max_drawdown_pct"], dd / peak * 100)
    return trades, stats, klines_5m
def main():
    """
    Command-line entry point.

    Parses the options, loads (or fetches) the kline data, runs the backtest,
    logs a one-line summary and, when there are trades, writes them to
    ``backtest_trades.csv`` next to this file.
    """
    import argparse

    parser = argparse.ArgumentParser(description="自适应三分位趋势策略回测")
    parser.add_argument("--start", default="2025-01-01", help="开始日期 YYYY-MM-DD")
    parser.add_argument("--end", default="2025-12-31", help="结束日期,默认 2025-12-31")
    parser.add_argument("--data-dir", default=None, help="数据目录,默认 adaptive_third_strategy/data")
    parser.add_argument("--no-api", action="store_true", help="不从 API 拉取,仅用本地 CSV")
    parser.add_argument("--capital", type=float, default=10000, help="初始资金(按比例开仓时用)")
    parser.add_argument("--fixed-margin", type=float, default=None, help="每笔固定保证金 USDT如 50")
    parser.add_argument("--leverage", type=float, default=None, help="杠杆倍数,如 100")
    parser.add_argument("--no-fee", action="store_true", help="不扣手续费,只算总盈利")
    args = parser.parse_args()

    # NOTE(review): strptime yields naive datetimes, so .timestamp() reads the
    # dates in the machine's local timezone, while trade times are reported in
    # UTC; confirm that this is intended.
    start_dt = datetime.datetime.strptime(args.start, "%Y-%m-%d")
    if args.end:
        end_dt = datetime.datetime.strptime(args.end, "%Y-%m-%d")
        # Include the whole end day: use midnight of the next day as an
        # exclusive bound, so every bar with id < end_ts is kept.
        end_ts = int((end_dt + datetime.timedelta(days=1)).timestamp())
    else:
        # An aware "now" is required here: utcnow() is naive, and calling
        # .timestamp() on it would treat the UTC wall time as local time.
        end_ts = int(datetime.datetime.now(datetime.timezone.utc).timestamp())
    start_ts = int(start_dt.timestamp())

    data_dir = args.data_dir or os.path.join(SCRIPT_DIR, "data")
    k5, k15, k60 = ensure_data(start_ts, end_ts, data_dir, use_api=not args.no_api)
    if not k5:
        logger.error("无 5 分钟数据,请先抓取或开启 --api")
        return
    logger.info(f"5m={len(k5)} 15m={len(k15)} 60m={len(k60)}")

    trades, stats, _ = run_backtest(
        k5, k15, k60,
        initial_capital=args.capital,
        fixed_margin_usdt=args.fixed_margin,
        leverage=args.leverage,
        deduct_fee=not args.no_fee,
    )
    if not stats:
        # run_backtest returns an empty stats dict when there are too few
        # 5-minute bars; indexing it below would raise KeyError.
        logger.error("回测未产生统计结果(5分钟K线数量不足),已跳过汇总输出")
        return
    logger.info(f"交易笔数: {stats['total_trades']} 胜率: {stats['win_rate']:.2f}% "
                f"总盈利(未扣费): {stats['total_gross_profit']:.2f} USDT "
                f"总手续费: {stats['total_fee']:.2f} 总净利润: {stats['total_net_profit']:.2f} "
                f"最大回撤: {stats['max_drawdown']:.2f} ({stats['max_drawdown_pct']:.2f}%)")

    out_csv = os.path.join(SCRIPT_DIR, "backtest_trades.csv")
    if trades:
        with open(out_csv, "w", newline="", encoding="utf-8") as f:
            w = csv.DictWriter(f, fieldnames=["direction", "entry_time", "exit_time", "entry_price", "exit_price", "point_pnl", "money_pnl", "fee", "net_profit", "exit_reason", "hold_bars"])
            w.writeheader()
            for t in trades:
                # datetimes are written in their str() form.
                w.writerow({k: str(v) if isinstance(v, datetime.datetime) else v for k, v in t.items()})
        logger.info(f"交易记录已保存: {out_csv}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()