306 lines
10 KiB
Python
306 lines
10 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
自适应三分位趋势策略 - 核心逻辑
|
||
趋势过滤、动态阈值、信号确认、市场状态
|
||
"""
|
||
|
||
from typing import List, Dict, Optional, Tuple
|
||
import sys
|
||
import os
|
||
# Project root = parent of the directory containing this file; put it at the
# front of sys.path (once) so the absolute `adaptive_third_strategy.*` imports
# below resolve when this module is run directly.
ROOT_DIR = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)
if ROOT_DIR not in sys.path:
    sys.path.insert(0, ROOT_DIR)
|
||
|
||
from adaptive_third_strategy.config import (
|
||
MIN_BODY_ATR_RATIO,
|
||
MIN_VOLATILITY_PERCENT,
|
||
EMA_SHORT,
|
||
EMA_LONG_FAST,
|
||
EMA_LONG_SLOW,
|
||
EMA_MID_FAST,
|
||
EMA_MID_SLOW,
|
||
ATR_PERIOD,
|
||
VOLATILITY_COEF_CLAMP,
|
||
BASE_COEF,
|
||
TREND_FAVOR_COEF,
|
||
TREND_AGAINST_COEF,
|
||
TREND_MODE,
|
||
CONFIRM_REQUIRED,
|
||
VOLUME_MA_PERIOD,
|
||
VOLUME_RATIO_THRESHOLD,
|
||
REVERSE_BREAK_MULT,
|
||
MIN_BARS_SINCE_ENTRY,
|
||
FORBIDDEN_PERIODS,
|
||
ATR_PAUSE_MULT,
|
||
STRONG_TREND_COEF,
|
||
RANGE_COEF,
|
||
HIGH_VOL_EXTRA_CONFIRM,
|
||
)
|
||
from adaptive_third_strategy.indicators import (
|
||
get_ema_atr_from_klines,
|
||
align_higher_tf_ema,
|
||
ema,
|
||
)
|
||
|
||
|
||
def get_body_size(candle: Dict) -> float:
    """Return the absolute candle body size, ``|open - close|``.

    Both fields are coerced with ``float()``, so numeric strings are accepted.
    """
    opening = float(candle["open"])
    closing = float(candle["close"])
    return abs(opening - closing)
|
||
|
||
|
||
def is_bullish(candle: Dict) -> bool:
    """Return True when the candle closed strictly above its open.

    A doji (close == open) is not considered bullish.
    """
    closing, opening = float(candle["close"]), float(candle["open"])
    return opening < closing
|
||
|
||
|
||
def get_min_body_threshold(price: float, atr_value: Optional[float]) -> float:
    """Minimum body size for a bar to count as "valid".

    Returns ``price * MIN_VOLATILITY_PERCENT``. When ``atr_value`` is a positive
    number, it returns the larger of that and ``atr_value * MIN_BODY_ATR_RATIO``.
    The original notes describe the factors as ATR*0.1 and price*0.05%; the
    actual values come from config.
    """
    price_floor = price * MIN_VOLATILITY_PERCENT
    atr_usable = atr_value is not None and atr_value > 0
    if not atr_usable:
        return price_floor
    return max(price_floor, atr_value * MIN_BODY_ATR_RATIO)
|
||
|
||
|
||
def find_valid_prev_bar(
    all_data: List[Dict],
    current_idx: int,
    atr_series: List[Optional[float]],
    min_body_override: Optional[float] = None,
) -> Tuple[Optional[int], Optional[Dict]]:
    """Walk backwards from ``current_idx - 1`` to the newest bar with a big enough body.

    The threshold is ``min_body_override`` when given. Otherwise it is
    ``get_min_body_threshold(bar close, atr_series[idx])``, with the ATR taken
    as None when ``atr_series`` is shorter than ``idx``.

    Returns:
        ``(index, bar)`` of the first qualifying bar, or ``(None, None)`` when
        ``current_idx <= 0`` or no earlier bar qualifies.
    """
    if current_idx <= 0:
        return None, None

    atr_len = len(atr_series)
    for idx in reversed(range(current_idx)):
        bar = all_data[idx]
        bar_body = get_body_size(bar)
        if min_body_override is not None:
            threshold = min_body_override
        else:
            atr_here = atr_series[idx] if idx < atr_len else None
            threshold = get_min_body_threshold(float(bar["close"]), atr_here)
        if bar_body >= threshold:
            return idx, bar

    return None, None
|
||
|
||
|
||
def get_trend(
    klines_5m: List[Dict],
    idx_5m: int,
    ema_5m: List[Optional[float]],
    ema_15m_align: List[Dict],
    ema_60m_align: List[Dict],
) -> str:
    """
    Multi-timeframe trend classification: returns "long" / "short" / "neutral".

    Three components are compared:

    - Short term: 5m close vs ``ema_5m[idx_5m]``.
    - Mid term: 15m ``ema_fast`` vs ``ema_slow``.
    - Long term: 1h ``ema_fast`` vs ``ema_slow``.

    The original notes describe these as 5m close vs EMA9, 15m EMA20 vs EMA50
    and 1h EMA50 vs EMA200; the actual periods come from config.

    ``ema_15m_align`` / ``ema_60m_align`` are lists aligned index-for-index
    with the 5m bars. Each entry is a dict with ``"ema_fast"`` / ``"ema_slow"``.

    ``TREND_MODE`` selects which components must agree:

    - ``"aggressive"``: short term only.
    - ``"conservative"``: mid and short term.
    - Anything else (strict): all three.

    A component whose inputs are missing (index out of range or a None EMA,
    e.g. during indicator warm-up) is *unknown*. If any component the mode
    needs is unknown, the result is "neutral". It is not counted as bearish.
    """
    if idx_5m < 0 or idx_5m >= len(klines_5m):
        return "neutral"

    close_5 = float(klines_5m[idx_5m]["close"])

    # Short term: 5m close above its EMA.
    ema9 = ema_5m[idx_5m] if idx_5m < len(ema_5m) else None
    short_known = ema9 is not None
    short_bull = short_known and close_5 > ema9

    def _ema_pair(align: List[Dict]) -> Tuple[bool, bool]:
        """Return (known, bullish) for the aligned fast/slow EMA pair at idx_5m."""
        entry = (align[idx_5m] if idx_5m < len(align) else None) or {}
        fast, slow = entry.get("ema_fast"), entry.get("ema_slow")
        if fast is None or slow is None:
            return False, False
        return True, fast > slow

    mid_known, mid_bull = _ema_pair(ema_15m_align)
    long_known, long_bull = _ema_pair(ema_60m_align)

    if TREND_MODE == "aggressive":
        if not short_known:
            return "neutral"
        return "long" if short_bull else "short"

    if TREND_MODE == "conservative":
        if not (mid_known and short_known):
            return "neutral"
        if mid_bull and short_bull:
            return "long"
        if not mid_bull and not short_bull:
            return "short"
        return "neutral"

    # strict (default for any other TREND_MODE value)
    if not (long_known and mid_known and short_known):
        return "neutral"
    if long_bull and mid_bull and short_bull:
        return "long"
    if not long_bull and not mid_bull and not short_bull:
        return "short"
    return "neutral"
|
||
|
||
|
||
def get_dynamic_trigger_levels(
    prev: Dict,
    atr_value: float,
    trend: str,
    market_state: str = "normal",
) -> Tuple[Optional[float], Optional[float]]:
    """
    Compute the dynamic "one-third" trigger prices around a reference bar.

    The base coefficient is ``BASE_COEF * clamp(body / ATR)``, with the clamp
    bounds taken from ``VOLATILITY_COEF_CLAMP``. In the ``"strong_trend"`` and
    ``"range"`` market states the coefficient is replaced outright by
    ``STRONG_TREND_COEF`` / ``RANGE_COEF``.

    The side that agrees with ``trend`` is scaled by ``TREND_FAVOR_COEF``; the
    opposite side is scaled by ``TREND_AGAINST_COEF``. Any other trend value
    leaves both sides unscaled. The original notes quote clamp 0.3..3.0,
    base 0.33, favor 0.8 and against 1.2; confirm against config.

    Args:
        prev: reference bar (needs ``open`` / ``close``).
        atr_value: ATR at the current bar; must be positive.
        trend: "long" / "short" / anything else (treated as neutral).
        market_state: result of ``get_market_state``.

    Returns:
        ``(long_trigger, short_trigger)``, computed as
        ``prev close + body * long_coef`` and ``prev close - body * short_coef``.
        Returns ``(None, None)`` when the body is below 1e-6 or ATR is not
        positive.
    """
    body = get_body_size(prev)
    if body < 1e-6 or atr_value <= 0:
        return None, None

    lower, upper = VOLATILITY_COEF_CLAMP[0], VOLATILITY_COEF_CLAMP[1]
    clamped_ratio = max(lower, min(upper, body / atr_value))

    if market_state == "strong_trend":
        coef = STRONG_TREND_COEF
    elif market_state == "range":
        coef = RANGE_COEF
    else:
        coef = BASE_COEF * clamped_ratio

    if trend == "long":
        up_coef, down_coef = coef * TREND_FAVOR_COEF, coef * TREND_AGAINST_COEF
    elif trend == "short":
        up_coef, down_coef = coef * TREND_AGAINST_COEF, coef * TREND_FAVOR_COEF
    else:
        up_coef = down_coef = coef

    reference_close = float(prev["close"])
    return reference_close + body * up_coef, reference_close - body * down_coef
|
||
|
||
|
||
def check_signal_confirm(
    curr: Dict,
    direction: str,
    trigger_price: float,
    all_data: List[Dict],
    current_idx: int,
    volume_ma: Optional[float],
    required: Optional[int] = None,
) -> int:
    """
    Count how many confirmation conditions the current bar satisfies.

    Conditions (one point each):
      1. Close confirmation: close >= trigger for "long", close <= trigger for
         "short".
      2. Volume confirmation: ``volume >= volume_ma * VOLUME_RATIO_THRESHOLD``.
         Only checked when ``volume_ma`` is a positive number.
      3. Momentum confirmation: the bar body points the signal's way (close > open
         for "long", close < open for "short"). A doji (close == open) confirms
         neither side.

    Args:
        curr: current bar (needs ``open`` / ``close``; ``volume`` optional,
            a missing or None volume counts as 0).
        direction: "long" or "short"; any other value scores 0 on checks 1 and 3.
        trigger_price: the trigger level that fired.
        all_data, current_idx: unused, kept for interface compatibility.
        volume_ma: volume moving average at this bar, or None.
        required: unused, kept for interface compatibility. Callers compare the
            returned count against their own threshold.

    Returns:
        Number of satisfied conditions (0-3).
    """
    c_open = float(curr["open"])
    c_close = float(curr["close"])
    c_volume = float(curr.get("volume") or 0)
    is_long = direction == "long"
    is_short = direction == "short"

    count = 0
    # 1. Close confirmation
    if (is_long and c_close >= trigger_price) or (is_short and c_close <= trigger_price):
        count += 1
    # 2. Volume confirmation (short-circuits when no usable moving average)
    if volume_ma is not None and volume_ma > 0 and c_volume >= volume_ma * VOLUME_RATIO_THRESHOLD:
        count += 1
    # 3. Momentum confirmation: strict body direction, symmetric for both sides
    if (is_long and c_close > c_open) or (is_short and c_close < c_open):
        count += 1
    return count
|
||
|
||
|
||
def in_forbidden_period(
    ts_sec: int,
    periods: Optional[List[Tuple[int, int, int, int]]] = None,
) -> bool:
    """
    Return True if ``ts_sec`` falls inside a forbidden trading window.

    Window times are interpreted in UTC+8 wall-clock hours/minutes.

    Args:
        ts_sec: Unix timestamp in seconds.
        periods: iterable of ``(start_hour, start_min, end_hour, end_min)``
            windows, with the start inclusive and the end exclusive. Defaults to
            ``FORBIDDEN_PERIODS`` from config. A window whose start is after its
            end wraps past midnight (e.g. ``(23, 30, 0, 30)``). A window whose
            start equals its end is empty.

    Raises:
        Whatever ``datetime.fromtimestamp`` raises for out-of-range timestamps
        (e.g. OverflowError / OSError / ValueError).
    """
    from datetime import datetime, timedelta, timezone

    if periods is None:
        periods = FORBIDDEN_PERIODS

    # Direct conversion to UTC+8; equivalent to UTC hour + 8 (mod 24), same minute.
    local = datetime.fromtimestamp(ts_sec, tz=timezone(timedelta(hours=8)))
    t = local.hour * 60 + local.minute

    for h1, m1, h2, m2 in periods:
        start = h1 * 60 + m1
        end = h2 * 60 + m2
        if start <= end:
            if start <= t < end:
                return True
        elif t >= start or t < end:
            # Window wraps midnight.
            return True
    return False
|
||
|
||
|
||
def get_market_state(
    atr_value: float,
    atr_avg: Optional[float],
    trend: str,
    strong_trend_mult: float = 1.2,
) -> str:
    """
    Classify the market as "high_vol", "strong_trend", "range" or "normal".

    Rules, checked in order:

    - "high_vol": ``atr_value >= atr_avg * ATR_PAUSE_MULT``.
    - "strong_trend": trend is "long"/"short" and
      ``atr_value > atr_avg * strong_trend_mult``.
    - "range": trend is "neutral".
    - "normal": anything else.

    Both ATR comparisons require ``atr_avg`` to be a positive number. Without a
    usable average, only the trend-based rules apply.

    Args:
        atr_value: current ATR.
        atr_avg: reference ATR average, or None when not yet available.
        trend: "long" / "short" / "neutral".
        strong_trend_mult: ATR expansion factor for "strong_trend" (default 1.2,
            the previously hard-coded value).
    """
    has_avg = atr_avg is not None and atr_avg > 0
    if has_avg and atr_value >= atr_avg * ATR_PAUSE_MULT:
        return "high_vol"
    if trend in ("long", "short") and has_avg and atr_value > atr_avg * strong_trend_mult:
        return "strong_trend"
    if trend == "neutral":
        return "range"
    return "normal"
|
||
|
||
|
||
def check_trigger(
    all_data: List[Dict],
    current_idx: int,
    atr_series: List[Optional[float]],
    ema_5m: List[Optional[float]],
    ema_15m_align: List[Dict],
    ema_60m_align: List[Dict],
    volume_ma_list: Optional[List[Optional[float]]] = None,
    use_confirm: bool = True,
) -> Tuple[Optional[str], Optional[float], Optional[int], Optional[Dict]]:
    """
    Check whether the bar at ``current_idx`` produces a valid entry signal.

    Steps, in order:
      1. Find the newest earlier bar whose body clears the minimum-body threshold
         (``find_valid_prev_bar``).
      2. Require a positive ATR at ``current_idx``.
      3. Classify trend (``get_trend``) and market state (``get_market_state``).
         Bars in the "high_vol" state are skipped.
      4. Compute long/short trigger levels from the reference bar and test them
         against the current bar's high / low.
      5. Discard any trigger that goes against the trend, then resolve which
         side fired.
      6. Reject bars inside a forbidden trading period (timestamp from
         ``curr["id"]``).
      7. When ``use_confirm`` is set and ``CONFIRM_REQUIRED > 0``, require at
         least ``CONFIRM_REQUIRED`` confirmations from ``check_signal_confirm``.

    Returns:
        ``(direction, trigger_price, valid_prev_idx, valid_prev_bar)``, where
        ``direction`` is "long" or "short". Returns ``(None, None, None, None)``
        when there is no signal.
    """
    no_signal = (None, None, None, None)
    if current_idx <= 0 or current_idx >= len(all_data):
        return no_signal
    curr = all_data[current_idx]

    valid_prev_idx, prev = find_valid_prev_bar(all_data, current_idx, atr_series)
    if prev is None:
        return no_signal

    atr_val = atr_series[current_idx] if current_idx < len(atr_series) else None
    if atr_val is None or atr_val <= 0:
        return no_signal

    trend = get_trend(all_data, current_idx, ema_5m, ema_15m_align, ema_60m_align)

    # Expanding mean of every positive ATR up to and including this bar.
    # It is only trusted once at least ATR_PERIOD values exist.
    # NOTE(review): this rescans the whole prefix on each call (O(n^2) over a
    # full backtest); precompute if it becomes a hotspot.
    atr_avg = None
    valid_atr = [x for x in atr_series[: current_idx + 1] if x is not None and x > 0]
    if valid_atr and len(valid_atr) >= ATR_PERIOD:
        atr_avg = sum(valid_atr) / len(valid_atr)

    market_state = get_market_state(atr_val, atr_avg, trend)
    if market_state == "high_vol":
        return no_signal

    long_trigger, short_trigger = get_dynamic_trigger_levels(prev, atr_val, trend, market_state)
    if long_trigger is None or short_trigger is None:
        return no_signal

    long_hit = float(curr["high"]) >= long_trigger
    short_hit = float(curr["low"]) <= short_trigger

    # Trend filter: counter-trend trades are never taken. Apply it *before*
    # resolving an outside bar, so that a counter-trend level lying nearer the
    # open cannot mask a with-trend level that was also reached in this bar.
    if trend == "long":
        short_hit = False
    elif trend == "short":
        long_hit = False

    if long_hit and short_hit:
        # Only reachable with a neutral trend. Intrabar order is unknown, so
        # assume the level nearer the open was hit first (ties go to short).
        c_open = float(curr["open"])
        if abs(c_open - short_trigger) <= abs(c_open - long_trigger):
            direction, trigger_price = "short", short_trigger
        else:
            direction, trigger_price = "long", long_trigger
    elif short_hit:
        direction, trigger_price = "short", short_trigger
    elif long_hit:
        direction, trigger_price = "long", long_trigger
    else:
        return no_signal

    # Forbidden trading periods
    if in_forbidden_period(curr["id"]):
        return no_signal

    # Signal confirmation
    if use_confirm and CONFIRM_REQUIRED > 0:
        vol_ma = (
            volume_ma_list[current_idx]
            if volume_ma_list and current_idx < len(volume_ma_list)
            else None
        )
        confirmations = check_signal_confirm(
            curr, direction, trigger_price, all_data, current_idx, vol_ma, CONFIRM_REQUIRED
        )
        if confirmations < CONFIRM_REQUIRED:
            return no_signal

    return direction, trigger_price, valid_prev_idx, prev
|
||
|
||
|
||
def build_volume_ma(klines: List[Dict], period: Optional[int] = None) -> List[Optional[float]]:
    """
    Simple moving average of bar volume, aligned index-for-index with ``klines``.

    The first ``period - 1`` entries are None (or every entry, when there are
    fewer bars than that). Each later entry ``i`` is the mean volume of bars
    ``i - period + 1 .. i``. A missing or None ``volume`` counts as 0.

    Args:
        klines: bars; only the ``volume`` field is read.
        period: window length. Defaults to ``VOLUME_MA_PERIOD`` from config,
            resolved at call time.

    Returns:
        A list with exactly ``len(klines)`` entries.

    Raises:
        ValueError: if ``period < 1``.
    """
    if period is None:
        period = VOLUME_MA_PERIOD
    if period < 1:
        raise ValueError(f"period must be >= 1, got {period}")

    vol = [float(k.get("volume") or 0) for k in klines]
    # Cap the warm-up padding so the output never outgrows the input.
    out: List[Optional[float]] = [None] * min(period - 1, len(vol))
    for i in range(period - 1, len(vol)):
        out.append(sum(vol[i - period + 1 : i + 1]) / period)
    return out
|