257 lines
10 KiB
Python
257 lines
10 KiB
Python
|
|
"""
|
|||
|
|
方案A:统计筛选 + 规则组合策略
|
|||
|
|
1. 从52个指标中用统计方法筛选最有效的指标
|
|||
|
|
2. 用经典规则组合生成交易信号
|
|||
|
|
3. 网格搜索优化参数
|
|||
|
|
"""
|
|||
|
|
import pandas as pd
|
|||
|
|
import numpy as np
|
|||
|
|
from sklearn.ensemble import RandomForestClassifier
|
|||
|
|
from sklearn.feature_selection import mutual_info_classif
|
|||
|
|
from loguru import logger
|
|||
|
|
|
|||
|
|
from .config import STAT_CONFIG as SC, PRIMARY_PERIOD, AUX_PERIODS
|
|||
|
|
from .feature_engine import prepare_dataset
|
|||
|
|
from .backtest import BacktestEngine, print_metrics
|
|||
|
|
|
|||
|
|
|
|||
|
|
class StatStrategy:
    """Plan A strategy: statistical feature screening plus rule-based voting.

    ``run`` drives the whole workflow: candidate indicator columns are ranked
    with three statistical scores, redundant highly-correlated columns are
    dropped, and every surviving indicator then casts long/short votes through
    hand-written threshold rules. The votes are turned into a signal series
    that is handed to the backtest engine.

    Attributes:
        top_features: feature names kept by the latest ``select_features`` call.
        feature_scores: mapping of feature name -> combined screening score.
        best_params: initialised to an empty dict; nothing in this class
            writes to it.
    """

    # Suffixes stripped from feature names to recover the base indicator name
    # (presumably they mark auxiliary timeframes -- confirm against the
    # feature builder). Checked in this order; the first match is removed.
    _PERIOD_SUFFIXES = ('_5m', '_60m', '_3m', '_15m', '_30m', '_1m')

    def __init__(self):
        self.top_features = []
        self.feature_scores = {}
        self.best_params = {}

    @staticmethod
    def _scale_by_max(values: pd.Series) -> pd.Series:
        """Divide ``values`` by its maximum; return it unchanged if max <= 0."""
        peak = values.max()
        return values / peak if peak > 0 else values

    def select_features(self, X: pd.DataFrame, y: pd.Series) -> list:
        """Rank features with three scores and keep the best non-redundant ones.

        The combined score of a column is the sum of:
          1. its absolute Pearson correlation with ``y`` (NaN counted as 0),
          2. its mutual information with ``y``, divided by the largest value,
          3. its random-forest importance, divided by the largest value.

        The ``2 * SC['top_n_features']`` best-scoring columns are passed through
        ``_remove_correlated`` (threshold ``SC['correlation_threshold']``) and
        the first ``SC['top_n_features']`` survivors are stored in
        ``self.top_features``. All combined scores are stored in
        ``self.feature_scores``.

        :param X: feature matrix, one column per candidate feature
        :param y: target labels aligned with ``X``
        :return: selected feature names, highest combined score first
        """
        logger.info("=" * 50)
        logger.info("开始特征筛选...")
        scores = dict.fromkeys(X.columns, 0)

        def accumulate(partial: pd.Series) -> None:
            # Columns missing from a partial score contribute 0.
            for col in X.columns:
                scores[col] += partial.get(col, 0)

        # NaNs are replaced by 0 for the two sklearn estimators; computed once.
        X_filled = X.fillna(0)

        # 1. Pearson correlation
        logger.info("计算皮尔逊相关系数...")
        accumulate(X.corrwith(y).abs().fillna(0))

        # 2. Mutual information
        logger.info("计算互信息...")
        mi = mutual_info_classif(X_filled, y, random_state=42)
        accumulate(self._scale_by_max(pd.Series(mi, index=X.columns)))

        # 3. Random-forest feature importance
        logger.info("训练随机森林评估特征重要性...")
        rf = RandomForestClassifier(n_estimators=200, max_depth=8, random_state=42, n_jobs=-1)
        rf.fit(X_filled, y)
        accumulate(self._scale_by_max(pd.Series(rf.feature_importances_, index=X.columns)))

        # Combined ranking, best first
        score_series = pd.Series(scores).sort_values(ascending=False)
        self.feature_scores = score_series.to_dict()

        # Take twice as many candidates as needed so that enough remain after
        # the redundant (highly correlated) ones are dropped.
        top_n = SC['top_n_features']
        top_candidates = score_series.head(top_n * 2).index.tolist()
        selected = self._remove_correlated(X[top_candidates], SC['correlation_threshold'])
        self.top_features = selected[:top_n]

        logger.info(f"筛选出 Top {len(self.top_features)} 特征:")
        for rank, feat in enumerate(self.top_features, start=1):
            logger.info(f" {rank}. {feat} (综合得分: {score_series[feat]:.4f})")

        return self.top_features

    def _remove_correlated(self, X: pd.DataFrame, threshold: float) -> list:
        """Greedily drop features that are highly correlated with an earlier one.

        Columns are visited in the order given. Each column that is still kept
        knocks out every later column whose absolute pairwise correlation with
        it is strictly greater than ``threshold``.

        :param X: candidate features, in priority order
        :param threshold: absolute-correlation cut-off
        :return: names of the surviving columns, original order preserved
        """
        corr_matrix = X.corr().abs()
        columns = list(X.columns)
        dropped = set()

        for pos, keeper in enumerate(columns):
            if keeper in dropped:
                continue
            for other in columns[pos + 1:]:
                if other not in dropped and corr_matrix.loc[keeper, other] > threshold:
                    dropped.add(other)

        kept = [c for c in columns if c not in dropped]
        if dropped:
            logger.info(f"移除 {len(dropped)} 个高相关冗余特征")
        return kept

    @classmethod
    def _base_indicator_name(cls, feat: str) -> str:
        """Strip the ``_lagN`` part and one trailing period suffix from ``feat``.

        Everything from the first ``_lag`` onwards is discarded, then the first
        matching entry of ``_PERIOD_SUFFIXES`` (if any) is removed from the end.
        """
        base = feat.split('_lag')[0]
        for suffix in cls._PERIOD_SUFFIXES:
            if base.endswith(suffix):
                return base[:-len(suffix)]
        return base

    def _score_rules(self, df: pd.DataFrame) -> tuple:
        """Accumulate long/short votes from every selected feature with a rule.

        Features in ``self.top_features`` that are absent from ``df`` or whose
        base name matches no rule are skipped and not counted.

        :param df: DataFrame holding the indicator columns
        :return: ``(long_score, short_score, matched)`` -- two float Series
            indexed like ``df`` and the number of features that cast a vote
        """
        long_score = pd.Series(0.0, index=df.index)
        short_score = pd.Series(0.0, index=df.index)
        matched = 0
        has_close = 'close' in df.columns

        for feat in self.top_features:
            if feat not in df.columns:
                continue

            col = df[feat]
            base = self._base_indicator_name(feat)
            weight = 1.0

            # The stochastic rule must be tested before the generic 'rsi'
            # substring rule, otherwise 'stoch_rsi_k' never reaches it.
            if 'stoch_k' in base or 'stoch_rsi_k' in base:
                long_mask, short_mask = col < 25, col > 75
            elif 'rsi' in base:
                long_mask, short_mask = col < 35, col > 65
            elif base in ('macd_hist', 'macd'):
                long_mask, short_mask = col > 0, col < 0
            elif 'bb_pband' in base:
                long_mask, short_mask = col < 0.2, col > 0.8
            elif 'adx' in base:
                # Same condition on both sides: it adds to long and short alike.
                long_mask = short_mask = col > 25
            elif 'cci' in base:
                long_mask, short_mask = col < -100, col > 100
            elif 'williams_r' in base:
                long_mask, short_mask = col < -80, col > -20
            elif 'ao' in base or 'tsi' in base or 'roc' in base or 'ppo' in base:
                long_mask, short_mask = col > 0, col < 0
            elif 'atr' in base or 'volatility_ratio' in base:
                # Volatility: above its rolling median adds half a vote to
                # both sides.
                long_mask = short_mask = col > col.rolling(200, min_periods=50).median()
                weight = 0.5
            elif 'high_low_range' in base or 'body_ratio' in base:
                long_mask = short_mask = col > col.rolling(200, min_periods=50).median()
                weight = 0.3
            elif 'bb_width' in base:
                # Breakout: band width is expanding now and was below its
                # rolling median one bar earlier.
                median = col.rolling(200, min_periods=50).median()
                previous = col.shift(1)
                breakout = (col > previous) & (previous < median)
                if has_close:
                    # Direction of the breakout comes from the last price move.
                    price_up = df['close'] > df['close'].shift(1)
                    long_mask, short_mask = breakout & price_up, breakout & ~price_up
                else:
                    long_mask = short_mask = breakout
                    weight = 0.5
            elif 'close_sma20_ratio' in base or 'close_ema12_ratio' in base:
                # Ratio above 1 votes long, below 1 votes short.
                long_mask, short_mask = col > 1.0, col < 1.0
            elif 'ichimoku' in base:
                if not has_close:
                    # No price to compare against: no vote, so do not count it
                    # as matched (that would only raise the threshold).
                    continue
                long_mask, short_mask = df['close'] > col, df['close'] < col
            elif 'momentum' in base or 'price_change' in base:
                long_mask, short_mask = col > 0, col < 0
            else:
                continue

            long_score += long_mask.astype(float) * weight
            short_score += short_mask.astype(float) * weight
            matched += 1

        return long_score, short_score, matched

    @staticmethod
    def _votes_to_signals(long_score: pd.Series, short_score: pd.Series, threshold: float) -> pd.Series:
        """Convert vote totals to signals (0 = hold, 1 = long, 2 = short).

        A side fires when its score reaches ``threshold``. When both sides
        fire, the larger score wins and an exact tie yields 0.
        """
        signals = pd.Series(0, index=long_score.index)
        long_ok = long_score >= threshold
        short_ok = short_score >= threshold

        signals[long_ok] = 1
        signals[short_ok] = 2

        # Resolve rows where both sides reached the threshold.
        both = long_ok & short_ok
        signals[both & (long_score > short_score)] = 1
        signals[both & (short_score > long_score)] = 2
        signals[both & (long_score == short_score)] = 0
        return signals

    def generate_signals(self, df: pd.DataFrame) -> pd.Series:
        """Generate trading signals from the selected features via rule voting.

        :param df: DataFrame with the indicator columns (raw, non-standardised
            values, since the rules compare against absolute levels)
        :return: signal Series indexed like ``df`` (0 = hold, 1 = long,
            2 = short)
        """
        long_score, short_score, matched = self._score_rules(df)
        logger.info(f"规则匹配: {matched}/{len(self.top_features)} 个特征有对应规则")

        # Require at least half of the matched features to agree, and never
        # fewer than 3 votes.
        threshold = max(3, matched * 0.5)
        logger.info(f"信号阈值: {threshold:.1f} (需要至少这么多指标同时确认)")

        signals = self._votes_to_signals(long_score, short_score, threshold)

        dist = signals.value_counts().to_dict()
        logger.info(f"规则信号分布: 观望={dist.get(0, 0)}, 做多={dist.get(1, 0)}, 做空={dist.get(2, 0)}")
        return signals

    def run(self, period: int = None, start_date: str = None, end_date: str = None) -> dict:
        """Run plan A end to end: select features, build signals, backtest.

        :param period: primary period; defaults to ``PRIMARY_PERIOD`` when None
        :param start_date: passed through to the data loaders
        :param end_date: passed through to the data loaders
        :return: the result returned by ``BacktestEngine.run``; its
            ``'metrics'`` entry is printed before returning
        """
        if period is None:
            period = PRIMARY_PERIOD

        logger.info("=" * 60)
        logger.info("方案A:统计筛选 + 规则组合策略")
        logger.info("=" * 60)

        # 1. Prepare the dataset used for feature screening.
        X, y, _feature_names, _ = prepare_dataset(period, start_date, end_date)

        # 2. Screen features.
        self.select_features(X, y)

        # 3. Rebuild the full feature matrix for the rule checks.
        # NOTE(review): these imports are function-local in the original as
        # well; the reason is not visible here (possibly an import cycle).
        from .data_loader import load_kline
        from .feature_engine import build_features

        primary_df = load_kline(period, start_date, end_date)
        aux_dfs = {}
        for aux_p in AUX_PERIODS:
            try:
                aux_df = load_kline(aux_p, start_date, end_date)
                if not aux_df.empty:
                    aux_dfs[aux_p] = aux_df
            except Exception as exc:
                # Auxiliary periods stay best-effort, but the failure is now
                # reported instead of being silently discarded.
                logger.warning(f"辅助周期 {aux_p} 数据加载失败,已跳过: {exc}")

        df = build_features(primary_df, aux_dfs)
        df.dropna(inplace=True)

        # 4. Generate signals.
        signals = self.generate_signals(df)

        # 5. Backtest.
        engine = BacktestEngine()
        result = engine.run(df['close'], signals)

        print_metrics(result['metrics'], "方案A: 统计筛选策略")
        return result
|
|||
|
|
|
|||
|
|
|
|||
|
|
def run_stat_strategy(period: int = None, start_date: str = None, end_date: str = None) -> dict:
    """Convenience entry point for plan A.

    Creates a fresh ``StatStrategy`` and returns the result of its ``run``
    method, forwarding all three arguments unchanged.
    """
    return StatStrategy().run(period, start_date, end_date)
|