hahaa
This commit is contained in:
305
adaptive_third_strategy/strategy_core.py
Normal file
305
adaptive_third_strategy/strategy_core.py
Normal file
@@ -0,0 +1,305 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
自适应三分位趋势策略 - 核心逻辑
|
||||
趋势过滤、动态阈值、信号确认、市场状态
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
import sys
|
||||
import os
|
||||
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
if ROOT_DIR not in sys.path:
|
||||
sys.path.insert(0, ROOT_DIR)
|
||||
|
||||
from adaptive_third_strategy.config import (
|
||||
MIN_BODY_ATR_RATIO,
|
||||
MIN_VOLATILITY_PERCENT,
|
||||
EMA_SHORT,
|
||||
EMA_LONG_FAST,
|
||||
EMA_LONG_SLOW,
|
||||
EMA_MID_FAST,
|
||||
EMA_MID_SLOW,
|
||||
ATR_PERIOD,
|
||||
VOLATILITY_COEF_CLAMP,
|
||||
BASE_COEF,
|
||||
TREND_FAVOR_COEF,
|
||||
TREND_AGAINST_COEF,
|
||||
TREND_MODE,
|
||||
CONFIRM_REQUIRED,
|
||||
VOLUME_MA_PERIOD,
|
||||
VOLUME_RATIO_THRESHOLD,
|
||||
REVERSE_BREAK_MULT,
|
||||
MIN_BARS_SINCE_ENTRY,
|
||||
FORBIDDEN_PERIODS,
|
||||
ATR_PAUSE_MULT,
|
||||
STRONG_TREND_COEF,
|
||||
RANGE_COEF,
|
||||
HIGH_VOL_EXTRA_CONFIRM,
|
||||
)
|
||||
from adaptive_third_strategy.indicators import (
|
||||
get_ema_atr_from_klines,
|
||||
align_higher_tf_ema,
|
||||
ema,
|
||||
)
|
||||
|
||||
|
||||
def get_body_size(candle: Dict) -> float:
|
||||
return abs(float(candle["open"]) - float(candle["close"]))
|
||||
|
||||
|
||||
def is_bullish(candle: Dict) -> bool:
|
||||
return float(candle["close"]) > float(candle["open"])
|
||||
|
||||
|
||||
def get_min_body_threshold(price: float, atr_value: Optional[float]) -> float:
|
||||
"""有效K线最小实体 = max(ATR*0.1, 价格*0.05%)"""
|
||||
min_vol = price * MIN_VOLATILITY_PERCENT
|
||||
if atr_value is not None and atr_value > 0:
|
||||
min_vol = max(min_vol, atr_value * MIN_BODY_ATR_RATIO)
|
||||
return min_vol
|
||||
|
||||
|
||||
def find_valid_prev_bar(
|
||||
all_data: List[Dict],
|
||||
current_idx: int,
|
||||
atr_series: List[Optional[float]],
|
||||
min_body_override: Optional[float] = None,
|
||||
) -> Tuple[Optional[int], Optional[Dict]]:
|
||||
"""从当前索引往前找实体>=阈值的K线。阈值 = max(ATR*0.1, 价格*0.05%)"""
|
||||
if current_idx <= 0:
|
||||
return None, None
|
||||
for i in range(current_idx - 1, -1, -1):
|
||||
prev = all_data[i]
|
||||
body = get_body_size(prev)
|
||||
price = float(prev["close"])
|
||||
atr_val = atr_series[i] if i < len(atr_series) else None
|
||||
th = min_body_override if min_body_override is not None else get_min_body_threshold(price, atr_val)
|
||||
if body >= th:
|
||||
return i, prev
|
||||
return None, None
|
||||
|
||||
|
||||
def get_trend(
|
||||
klines_5m: List[Dict],
|
||||
idx_5m: int,
|
||||
ema_5m: List[Optional[float]],
|
||||
ema_15m_align: List[Dict],
|
||||
ema_60m_align: List[Dict],
|
||||
) -> str:
|
||||
"""
|
||||
多时间框架趋势。返回 "long" / "short" / "neutral"。
|
||||
长期:1h EMA50 vs EMA200;中期:15m EMA20 vs EMA50;短期:5m close vs EMA9。
|
||||
ema_*_align: 与 5m 对齐的列表,每项 {"ema_fast", "ema_slow"}。
|
||||
"""
|
||||
if idx_5m >= len(klines_5m):
|
||||
return "neutral"
|
||||
curr = klines_5m[idx_5m]
|
||||
close_5 = float(curr["close"])
|
||||
# 短期
|
||||
ema9 = ema_5m[idx_5m] if idx_5m < len(ema_5m) else None
|
||||
short_bull = (ema9 is not None and close_5 > ema9)
|
||||
# 中期
|
||||
mid = ema_15m_align[idx_5m] if idx_5m < len(ema_15m_align) else {}
|
||||
mid_bull = (mid.get("ema_fast") is not None and mid.get("ema_slow") is not None
|
||||
and mid["ema_fast"] > mid["ema_slow"])
|
||||
# 长期
|
||||
long_d = ema_60m_align[idx_5m] if idx_5m < len(ema_60m_align) else {}
|
||||
long_bull = (long_d.get("ema_fast") is not None and long_d.get("ema_slow") is not None
|
||||
and long_d["ema_fast"] > long_d["ema_slow"])
|
||||
|
||||
if TREND_MODE == "aggressive":
|
||||
return "long" if short_bull else "short"
|
||||
if TREND_MODE == "conservative":
|
||||
if mid_bull and short_bull:
|
||||
return "long"
|
||||
if not mid_bull and not short_bull:
|
||||
return "short"
|
||||
return "neutral"
|
||||
# strict
|
||||
if long_bull and mid_bull and short_bull:
|
||||
return "long"
|
||||
if not long_bull and not mid_bull and not short_bull:
|
||||
return "short"
|
||||
return "neutral"
|
||||
|
||||
|
||||
def get_dynamic_trigger_levels(
|
||||
prev: Dict,
|
||||
atr_value: float,
|
||||
trend: str,
|
||||
market_state: str = "normal",
|
||||
) -> Tuple[Optional[float], Optional[float]]:
|
||||
"""
|
||||
动态三分位触发价。
|
||||
波动率系数 = clamp(实体/ATR, 0.3, 3.0),调整系数 = 0.33 * 波动率系数。
|
||||
顺势方向 ×0.8,逆势 ×1.2。
|
||||
"""
|
||||
body = get_body_size(prev)
|
||||
if body < 1e-6 or atr_value <= 0:
|
||||
return None, None
|
||||
vol_coef = body / atr_value
|
||||
vol_coef = max(VOLATILITY_COEF_CLAMP[0], min(VOLATILITY_COEF_CLAMP[1], vol_coef))
|
||||
adj = BASE_COEF * vol_coef
|
||||
if market_state == "strong_trend":
|
||||
adj = STRONG_TREND_COEF
|
||||
elif market_state == "range":
|
||||
adj = RANGE_COEF
|
||||
p_close = float(prev["close"])
|
||||
if trend == "long":
|
||||
long_adj = adj * TREND_FAVOR_COEF
|
||||
short_adj = adj * TREND_AGAINST_COEF
|
||||
elif trend == "short":
|
||||
long_adj = adj * TREND_AGAINST_COEF
|
||||
short_adj = adj * TREND_FAVOR_COEF
|
||||
else:
|
||||
long_adj = short_adj = adj
|
||||
long_trigger = p_close + body * long_adj
|
||||
short_trigger = p_close - body * short_adj
|
||||
return long_trigger, short_trigger
|
||||
|
||||
|
||||
def check_signal_confirm(
|
||||
curr: Dict,
|
||||
direction: str,
|
||||
trigger_price: float,
|
||||
all_data: List[Dict],
|
||||
current_idx: int,
|
||||
volume_ma: Optional[float],
|
||||
required: int = CONFIRM_REQUIRED,
|
||||
) -> int:
|
||||
"""
|
||||
确认条件计数:收盘价确认、成交量确认、动量确认。
|
||||
返回满足的个数。
|
||||
"""
|
||||
count = 0
|
||||
c_close = float(curr["close"])
|
||||
c_volume = float(curr.get("volume", 0))
|
||||
# 1. 收盘价确认
|
||||
if direction == "long" and c_close >= trigger_price:
|
||||
count += 1
|
||||
elif direction == "short" and c_close <= trigger_price:
|
||||
count += 1
|
||||
# 2. 成交量确认
|
||||
if volume_ma is not None and volume_ma > 0 and c_volume >= volume_ma * VOLUME_RATIO_THRESHOLD:
|
||||
count += 1
|
||||
# 3. 动量确认:当前K线实体方向与信号一致
|
||||
if direction == "long" and is_bullish(curr):
|
||||
count += 1
|
||||
elif direction == "short" and not is_bullish(curr):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
def in_forbidden_period(ts_sec: int) -> bool:
|
||||
"""是否在禁止交易时段(按 UTC+8 小时:分)"""
|
||||
from datetime import datetime, timezone
|
||||
try:
|
||||
dt = datetime.fromtimestamp(ts_sec, tz=timezone.utc)
|
||||
except Exception:
|
||||
dt = datetime.utcfromtimestamp(ts_sec)
|
||||
# 转 UTC+8
|
||||
hour = (dt.hour + 8) % 24
|
||||
minute = dt.minute
|
||||
for h1, m1, h2, m2 in FORBIDDEN_PERIODS:
|
||||
t1 = h1 * 60 + m1
|
||||
t2 = h2 * 60 + m2
|
||||
t = hour * 60 + minute
|
||||
if t1 <= t < t2:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_market_state(
|
||||
atr_value: float,
|
||||
atr_avg: Optional[float],
|
||||
trend: str,
|
||||
) -> str:
|
||||
"""normal / strong_trend / range / high_vol"""
|
||||
if atr_avg is not None and atr_avg > 0 and atr_value >= atr_avg * ATR_PAUSE_MULT:
|
||||
return "high_vol"
|
||||
if trend in ("long", "short") and atr_avg is not None and atr_value > atr_avg * 1.2:
|
||||
return "strong_trend"
|
||||
if trend == "neutral":
|
||||
return "range"
|
||||
return "normal"
|
||||
|
||||
|
||||
def check_trigger(
|
||||
all_data: List[Dict],
|
||||
current_idx: int,
|
||||
atr_series: List[Optional[float]],
|
||||
ema_5m: List[Optional[float]],
|
||||
ema_15m_align: List[Dict],
|
||||
ema_60m_align: List[Dict],
|
||||
volume_ma_list: Optional[List[Optional[float]]] = None,
|
||||
use_confirm: bool = True,
|
||||
) -> Tuple[Optional[str], Optional[float], Optional[int], Optional[Dict]]:
|
||||
"""
|
||||
检查当前K线是否产生有效信号(含趋势过滤与确认)。
|
||||
返回 (方向, 触发价, 有效前一根索引, 有效前一根K线) 或 (None, None, None, None)。
|
||||
"""
|
||||
if current_idx <= 0 or current_idx >= len(all_data):
|
||||
return None, None, None, None
|
||||
curr = all_data[current_idx]
|
||||
valid_prev_idx, prev = find_valid_prev_bar(all_data, current_idx, atr_series)
|
||||
if prev is None:
|
||||
return None, None, None, None
|
||||
atr_val = atr_series[current_idx] if current_idx < len(atr_series) else None
|
||||
if atr_val is None or atr_val <= 0:
|
||||
return None, None, None, None
|
||||
trend = get_trend(
|
||||
all_data, current_idx, ema_5m, ema_15m_align, ema_60m_align,
|
||||
)
|
||||
atr_avg = None
|
||||
if atr_series:
|
||||
valid_atr = [x for x in atr_series[: current_idx + 1] if x is not None and x > 0]
|
||||
if len(valid_atr) >= ATR_PERIOD:
|
||||
atr_avg = sum(valid_atr) / len(valid_atr)
|
||||
market_state = get_market_state(atr_val, atr_avg, trend)
|
||||
if market_state == "high_vol":
|
||||
return None, None, None, None
|
||||
long_trigger, short_trigger = get_dynamic_trigger_levels(prev, atr_val, trend, market_state)
|
||||
if long_trigger is None:
|
||||
return None, None, None, None
|
||||
c_high = float(curr["high"])
|
||||
c_low = float(curr["low"])
|
||||
long_triggered = c_high >= long_trigger
|
||||
short_triggered = c_low <= short_trigger
|
||||
direction = None
|
||||
trigger_price = None
|
||||
if long_triggered and short_triggered:
|
||||
c_open = float(curr["open"])
|
||||
if abs(c_open - short_trigger) <= abs(c_open - long_trigger):
|
||||
direction, trigger_price = "short", short_trigger
|
||||
else:
|
||||
direction, trigger_price = "long", long_trigger
|
||||
elif short_triggered:
|
||||
direction, trigger_price = "short", short_trigger
|
||||
elif long_triggered:
|
||||
direction, trigger_price = "long", long_trigger
|
||||
if direction is None:
|
||||
return None, None, None, None
|
||||
# 趋势过滤:逆势不交易(可选,这里做过滤)
|
||||
if trend == "long" and direction == "short":
|
||||
return None, None, None, None
|
||||
if trend == "short" and direction == "long":
|
||||
return None, None, None, None
|
||||
# 禁止时段
|
||||
if in_forbidden_period(curr["id"]):
|
||||
return None, None, None, None
|
||||
# 信号确认
|
||||
if use_confirm and CONFIRM_REQUIRED > 0:
|
||||
vol_ma = volume_ma_list[current_idx] if volume_ma_list and current_idx < len(volume_ma_list) else None
|
||||
n = check_signal_confirm(curr, direction, trigger_price, all_data, current_idx, vol_ma, CONFIRM_REQUIRED)
|
||||
if n < CONFIRM_REQUIRED:
|
||||
return None, None, None, None
|
||||
return direction, trigger_price, valid_prev_idx, prev
|
||||
|
||||
|
||||
def build_volume_ma(klines: List[Dict], period: int = VOLUME_MA_PERIOD) -> List[Optional[float]]:
|
||||
"""前 period-1 为 None,之后为 volume 的 SMA"""
|
||||
vol = [float(k.get("volume", 0)) for k in klines]
|
||||
out: List[Optional[float]] = [None] * (period - 1)
|
||||
for i in range(period - 1, len(vol)):
|
||||
out.append(sum(vol[i - period + 1 : i + 1]) / period)
|
||||
return out
|
||||
Reference in New Issue
Block a user