Files
lm_code/adaptive_third_strategy/strategy_core.py
ddrwode 970080a2e6 hahaa
2026-01-31 10:35:25 +08:00

306 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
自适应三分位趋势策略 - 核心逻辑
趋势过滤、动态阈值、信号确认、市场状态
"""
from typing import List, Dict, Optional, Tuple
import sys
import os
# Directory two levels above this file, i.e. the parent of the
# `adaptive_third_strategy` package directory. It is added to sys.path so the
# absolute `adaptive_third_strategy.*` imports below resolve even when that
# directory is not already importable.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if ROOT_DIR not in sys.path:
    # Prepend (index 0) so this root wins over any later sys.path entries.
    sys.path.insert(0, ROOT_DIR)
from adaptive_third_strategy.config import (
MIN_BODY_ATR_RATIO,
MIN_VOLATILITY_PERCENT,
EMA_SHORT,
EMA_LONG_FAST,
EMA_LONG_SLOW,
EMA_MID_FAST,
EMA_MID_SLOW,
ATR_PERIOD,
VOLATILITY_COEF_CLAMP,
BASE_COEF,
TREND_FAVOR_COEF,
TREND_AGAINST_COEF,
TREND_MODE,
CONFIRM_REQUIRED,
VOLUME_MA_PERIOD,
VOLUME_RATIO_THRESHOLD,
REVERSE_BREAK_MULT,
MIN_BARS_SINCE_ENTRY,
FORBIDDEN_PERIODS,
ATR_PAUSE_MULT,
STRONG_TREND_COEF,
RANGE_COEF,
HIGH_VOL_EXTRA_CONFIRM,
)
from adaptive_third_strategy.indicators import (
get_ema_atr_from_klines,
align_higher_tf_ema,
ema,
)
def get_body_size(candle: Dict) -> float:
    """Return the candle body height, ``|open - close|``.

    ``open``/``close`` go through ``float()``, so numeric strings are accepted.
    """
    open_px = float(candle["open"])
    close_px = float(candle["close"])
    return abs(open_px - close_px)
def is_bullish(candle: Dict) -> bool:
    """True when the candle closed strictly above its open (a doji is not bullish)."""
    close_px = float(candle["close"])
    open_px = float(candle["open"])
    return open_px < close_px
def get_min_body_threshold(price: float, atr_value: Optional[float]) -> float:
    """Minimum body size for a bar to count as a "valid" bar.

    Threshold = ``max(price * MIN_VOLATILITY_PERCENT, atr_value * MIN_BODY_ATR_RATIO)``.
    The ATR term only participates when ``atr_value`` is a positive number.
    (The original note quotes ATR*0.1 and price*0.05% as the intended values;
    the actual ratios come from config.)
    """
    price_floor = price * MIN_VOLATILITY_PERCENT
    if atr_value is not None and atr_value > 0:
        return max(price_floor, atr_value * MIN_BODY_ATR_RATIO)
    return price_floor
def find_valid_prev_bar(
    all_data: List[Dict],
    current_idx: int,
    atr_series: List[Optional[float]],
    min_body_override: Optional[float] = None,
) -> Tuple[Optional[int], Optional[Dict]]:
    """Find the nearest earlier bar whose body is large enough to be "valid".

    Bars are scanned from ``current_idx - 1`` down to 0. A bar qualifies when
    its body is >= the threshold. The threshold is ``min_body_override`` when
    given. Otherwise it is ``get_min_body_threshold(close, atr)`` for that bar,
    where the ATR is None if ``atr_series`` is shorter than the scan position.

    Returns:
        ``(index, bar)`` of the first qualifying bar, or ``(None, None)`` when
        ``current_idx <= 0`` or no bar qualifies.
    """
    if current_idx <= 0:
        return None, None
    atr_len = len(atr_series)
    for idx in reversed(range(current_idx)):
        candidate = all_data[idx]
        body = get_body_size(candidate)
        if min_body_override is not None:
            threshold = min_body_override
        else:
            bar_atr = atr_series[idx] if idx < atr_len else None
            threshold = get_min_body_threshold(float(candidate["close"]), bar_atr)
        if body >= threshold:
            return idx, candidate
    return None, None
def _ema_pair_bullish(aligned: List[Dict], idx: int) -> Optional[bool]:
    """Return ``ema_fast > ema_slow`` for ``aligned[idx]``; None when that entry or either EMA is missing."""
    entry = aligned[idx] if idx < len(aligned) else None
    if not entry:
        return None
    fast = entry.get("ema_fast")
    slow = entry.get("ema_slow")
    if fast is None or slow is None:
        return None
    return fast > slow


def get_trend(
    klines_5m: List[Dict],
    idx_5m: int,
    ema_5m: List[Optional[float]],
    ema_15m_align: List[Dict],
    ema_60m_align: List[Dict],
    *,
    mode: Optional[str] = None,
) -> str:
    """Classify the multi-timeframe trend at 5m bar ``idx_5m``.

    Three timeframes each give a bullish/bearish vote:
      * short term: 5m close vs ``ema_5m[idx_5m]``;
      * mid term: 15m ``ema_fast`` vs ``ema_slow`` from ``ema_15m_align``;
      * long term: 1h ``ema_fast`` vs ``ema_slow`` from ``ema_60m_align``.
    Per the original note these are EMA9 (5m), EMA20/EMA50 (15m) and
    EMA50/EMA200 (1h); the real periods depend on how the caller built the
    series. The ``*_align`` lists are index-aligned with the 5m bars, and each
    item is a dict with ``"ema_fast"`` / ``"ema_slow"``.

    Modes (``mode`` defaults to config ``TREND_MODE``):
      * ``"aggressive"``: the short-term vote alone decides long/short;
      * ``"conservative"``: mid and short must agree, otherwise neutral;
      * anything else (strict): all three must agree, otherwise neutral.

    A timeframe whose data is missing (None EMA, absent key, index beyond the
    list) is *unknown*. If the selected mode needs an unknown timeframe, the
    result is ``"neutral"``. Missing data is no longer counted as a bearish
    vote, which previously biased warm-up periods toward ``"short"``.

    Returns:
        ``"long"``, ``"short"`` or ``"neutral"``.
    """
    if idx_5m < 0 or idx_5m >= len(klines_5m):
        return "neutral"
    if mode is None:
        mode = TREND_MODE

    close_5 = float(klines_5m[idx_5m]["close"])
    ema_short = ema_5m[idx_5m] if idx_5m < len(ema_5m) else None
    short_bull: Optional[bool] = None if ema_short is None else close_5 > ema_short
    mid_bull = _ema_pair_bullish(ema_15m_align, idx_5m)
    long_bull = _ema_pair_bullish(ema_60m_align, idx_5m)

    if mode == "aggressive":
        votes = (short_bull,)
    elif mode == "conservative":
        votes = (mid_bull, short_bull)
    else:  # strict (also the fallback for unrecognised modes)
        votes = (long_bull, mid_bull, short_bull)

    if any(v is None for v in votes):
        return "neutral"
    if all(votes):
        return "long"
    if not any(votes):
        return "short"
    return "neutral"
def get_dynamic_trigger_levels(
    prev: Dict,
    atr_value: float,
    trend: str,
    market_state: str = "normal",
) -> Tuple[Optional[float], Optional[float]]:
    """Compute the dynamic long/short trigger prices from a reference bar.

    The base coefficient is ``BASE_COEF * clamp(body / atr_value, *VOLATILITY_COEF_CLAMP)``.
    In the ``"strong_trend"`` / ``"range"`` states it is replaced outright by
    ``STRONG_TREND_COEF`` / ``RANGE_COEF``. The with-trend side is then scaled
    by ``TREND_FAVOR_COEF`` and the counter-trend side by ``TREND_AGAINST_COEF``.
    Both sides are unscaled when the trend is neither "long" nor "short".
    (The original note quotes clamp 0.3-3.0, base 0.33, x0.8 / x1.2; the actual
    values live in config.)

    Returns:
        ``(close + body * long_adj, close - body * short_adj)`` anchored on the
        reference bar's close, or ``(None, None)`` when the body is < 1e-6 or
        ``atr_value <= 0``.
    """
    span = get_body_size(prev)
    if span < 1e-6 or atr_value <= 0:
        return None, None

    clamped = max(VOLATILITY_COEF_CLAMP[0], min(VOLATILITY_COEF_CLAMP[1], span / atr_value))
    if market_state == "strong_trend":
        adj = STRONG_TREND_COEF
    elif market_state == "range":
        adj = RANGE_COEF
    else:
        adj = BASE_COEF * clamped

    anchor = float(prev["close"])
    if trend == "long":
        up_adj, down_adj = adj * TREND_FAVOR_COEF, adj * TREND_AGAINST_COEF
    elif trend == "short":
        up_adj, down_adj = adj * TREND_AGAINST_COEF, adj * TREND_FAVOR_COEF
    else:
        up_adj = down_adj = adj
    return anchor + span * up_adj, anchor - span * down_adj
def check_signal_confirm(
    curr: Dict,
    direction: str,
    trigger_price: float,
    all_data: List[Dict],
    current_idx: int,
    volume_ma: Optional[float],
    required: int = CONFIRM_REQUIRED,
) -> int:
    """Count how many confirmation conditions the bar ``curr`` satisfies.

    Conditions (each worth 1):
      1. Close: ``close >= trigger_price`` for "long", ``close <= trigger_price``
         for "short".
      2. Volume: ``volume >= volume_ma * VOLUME_RATIO_THRESHOLD``. This check is
         skipped when ``volume_ma`` is None or not positive. A missing or None
         volume counts as 0.
      3. Momentum: the body points in the signal direction, i.e.
         ``close > open`` for "long" and ``close < open`` for "short". A doji
         (close == open) has no direction and confirms neither side; it used to
         confirm "short" only.

    ``all_data``, ``current_idx`` and ``required`` are currently unused. They
    are kept for call-site compatibility, and callers compare the returned
    count against their own threshold.

    Returns:
        The number of satisfied conditions (0-3). A direction other than
        "long"/"short" can only score on volume.
    """
    c_close = float(curr["close"])
    c_volume = float(curr.get("volume") or 0)
    is_long = direction == "long"
    is_short = direction == "short"
    count = 0
    # 1. close confirmation
    if (is_long and c_close >= trigger_price) or (is_short and c_close <= trigger_price):
        count += 1
    # 2. volume confirmation
    if volume_ma is not None and volume_ma > 0 and c_volume >= volume_ma * VOLUME_RATIO_THRESHOLD:
        count += 1
    # 3. momentum confirmation: strict inequality on both sides so a doji is neutral
    c_open = float(curr["open"])
    if (is_long and c_close > c_open) or (is_short and c_close < c_open):
        count += 1
    return count
def in_forbidden_period(
    ts_sec: int,
    *,
    periods: Optional[List[Tuple[int, int, int, int]]] = None,
) -> bool:
    """Return True if ``ts_sec`` falls inside a forbidden trading window.

    Args:
        ts_sec: Unix timestamp in seconds.
        periods: Iterable of ``(start_hour, start_minute, end_hour, end_minute)``
            in UTC+8 wall-clock time, with start inclusive and end exclusive.
            Defaults to config ``FORBIDDEN_PERIODS``. A window whose start is
            later than its end wraps past midnight (e.g. 23:30-00:30); such
            windows previously never matched. A window with start == end is
            empty.

    Raises:
        Whatever ``datetime.fromtimestamp`` raises for out-of-range timestamps.
        The old ``utcfromtimestamp`` fallback failed identically and is
        deprecated, so it was removed.
    """
    from datetime import datetime, timedelta, timezone

    if periods is None:
        periods = FORBIDDEN_PERIODS
    local = datetime.fromtimestamp(ts_sec, tz=timezone(timedelta(hours=8)))
    minute_of_day = local.hour * 60 + local.minute
    for h1, m1, h2, m2 in periods:
        start = h1 * 60 + m1
        end = h2 * 60 + m2
        if start <= end:
            if start <= minute_of_day < end:
                return True
        elif minute_of_day >= start or minute_of_day < end:
            # Window wraps around midnight.
            return True
    return False
def get_market_state(
    atr_value: float,
    atr_avg: Optional[float],
    trend: str,
    *,
    pause_mult: Optional[float] = None,
    strong_trend_mult: float = 1.2,
) -> str:
    """Classify the market as "high_vol", "strong_trend", "range" or "normal".

    Checks run in this order:
      1. ``"high_vol"``: ``atr_avg`` is positive and
         ``atr_value >= atr_avg * pause_mult`` (defaults to config ``ATR_PAUSE_MULT``).
      2. ``"strong_trend"``: trend is "long"/"short", ``atr_avg`` is not None
         and ``atr_value > atr_avg * strong_trend_mult``. The multiplier was
         previously a hard-coded 1.2, which remains the default.
      3. ``"range"``: trend is "neutral".
      4. ``"normal"`` otherwise.
    """
    if pause_mult is None:
        pause_mult = ATR_PAUSE_MULT
    if atr_avg is not None and atr_avg > 0 and atr_value >= atr_avg * pause_mult:
        return "high_vol"
    if trend in ("long", "short") and atr_avg is not None and atr_value > atr_avg * strong_trend_mult:
        return "strong_trend"
    if trend == "neutral":
        return "range"
    return "normal"
def check_trigger(
    all_data: List[Dict],
    current_idx: int,
    atr_series: List[Optional[float]],
    ema_5m: List[Optional[float]],
    ema_15m_align: List[Dict],
    ema_60m_align: List[Dict],
    volume_ma_list: Optional[List[Optional[float]]] = None,
    use_confirm: bool = True,
) -> Tuple[Optional[str], Optional[float], Optional[int], Optional[Dict]]:
    """Decide whether the bar at ``current_idx`` produces a tradable signal.

    Filters are applied in this order, and the first failure rejects the bar:
      1. ``current_idx`` is in range and has at least one earlier bar;
      2. a "valid" previous bar exists (``find_valid_prev_bar``);
      3. the ATR at the current bar is present and positive;
      4. the market state (trend + ATR vs. its expanding mean) is not "high_vol";
      5. the bar's high/low reaches a dynamic trigger level;
      6. the direction does not oppose a non-neutral trend;
      7. the bar's ``"id"`` timestamp is outside the forbidden periods;
      8. if enabled, at least ``CONFIRM_REQUIRED`` confirmations hold.

    Returns:
        ``(direction, trigger_price, valid_prev_idx, valid_prev_bar)`` when a
        signal fires, otherwise ``(None, None, None, None)``.
    """
    no_signal = (None, None, None, None)
    if not 0 < current_idx < len(all_data):
        return no_signal
    bar = all_data[current_idx]

    prev_idx, prev_bar = find_valid_prev_bar(all_data, current_idx, atr_series)
    if prev_bar is None:
        return no_signal

    atr_now = atr_series[current_idx] if current_idx < len(atr_series) else None
    if atr_now is None or atr_now <= 0:
        return no_signal

    trend = get_trend(all_data, current_idx, ema_5m, ema_15m_align, ema_60m_align)

    # Expanding (not rolling) mean of every positive ATR up to and including
    # this bar, trusted only once ATR_PERIOD such values exist. atr_series is
    # non-empty here because atr_now came from it.
    # NOTE(review): rescans the whole history on each call -> O(n^2) per backtest.
    seen = [a for a in atr_series[: current_idx + 1] if a is not None and a > 0]
    atr_avg = sum(seen) / len(seen) if len(seen) >= ATR_PERIOD else None

    state = get_market_state(atr_now, atr_avg, trend)
    if state == "high_vol":
        return no_signal

    long_level, short_level = get_dynamic_trigger_levels(prev_bar, atr_now, trend, state)
    if long_level is None:
        return no_signal

    hit_long = float(bar["high"]) >= long_level
    hit_short = float(bar["low"]) <= short_level
    if hit_long and hit_short:
        # Both levels inside one bar: assume the one nearer the open was hit
        # first; an exact tie resolves to short.
        bar_open = float(bar["open"])
        direction = "short" if abs(bar_open - short_level) <= abs(bar_open - long_level) else "long"
    elif hit_short:
        direction = "short"
    elif hit_long:
        direction = "long"
    else:
        return no_signal
    trigger_price = short_level if direction == "short" else long_level

    # Trend filter: never trade against a non-neutral trend.
    if trend in ("long", "short") and direction != trend:
        return no_signal

    if in_forbidden_period(bar["id"]):
        return no_signal

    if use_confirm and CONFIRM_REQUIRED > 0:
        vol_ma = None
        if volume_ma_list and current_idx < len(volume_ma_list):
            vol_ma = volume_ma_list[current_idx]
        confirmations = check_signal_confirm(
            bar, direction, trigger_price, all_data, current_idx, vol_ma, CONFIRM_REQUIRED,
        )
        if confirmations < CONFIRM_REQUIRED:
            return no_signal

    return direction, trigger_price, prev_idx, prev_bar
def build_volume_ma(klines: List[Dict], period: int = VOLUME_MA_PERIOD) -> List[Optional[float]]:
    """Simple moving average of bar volume, index-aligned with ``klines``.

    ``out[i]`` is None while fewer than ``period`` bars are available
    (``i < period - 1``). Otherwise it is the mean ``"volume"`` of
    ``klines[i - period + 1 .. i]``, where a missing ``"volume"`` key counts as 0.
    The result always has exactly ``len(klines)`` entries. Previously, with
    fewer than ``period - 1`` bars, the None prefix was longer than the input,
    which broke index alignment.

    Raises:
        ValueError: if ``period < 1``. Previously this gave ZeroDivisionError
            or meaningless values.
    """
    if period < 1:
        raise ValueError(f"period must be >= 1, got {period}")
    vol = [float(k.get("volume", 0)) for k in klines]
    warmup = min(period - 1, len(vol))
    out: List[Optional[float]] = [None] * warmup
    for i in range(period - 1, len(vol)):
        out.append(sum(vol[i - period + 1 : i + 1]) / period)
    return out