257 lines
10 KiB
Python
257 lines
10 KiB
Python
"""
|
||
方案A:统计筛选 + 规则组合策略
|
||
1. 从52个指标中用统计方法筛选最有效的指标
|
||
2. 用经典规则组合生成交易信号
|
||
3. 网格搜索优化参数(尚未实现:当前规则阈值均为固定值)
|
||
"""
|
||
import pandas as pd
|
||
import numpy as np
|
||
from sklearn.ensemble import RandomForestClassifier
|
||
from sklearn.feature_selection import mutual_info_classif
|
||
from loguru import logger
|
||
|
||
from .config import STAT_CONFIG as SC, PRIMARY_PERIOD, AUX_PERIODS
|
||
from .feature_engine import prepare_dataset
|
||
from .backtest import BacktestEngine, print_metrics
|
||
|
||
|
||
class StatStrategy:
    """Plan A: statistical feature selection combined with rule-based signals.

    Workflow (see :meth:`run`):

    1. Score every candidate feature with three methods (absolute Pearson
       correlation, mutual information, random-forest feature importance),
       rescale each method so its best feature scores 1.0, and sum them.
    2. Keep the top-ranked features after dropping highly inter-correlated ones.
    3. Map each kept feature to a fixed threshold rule and let the rules vote
       long / short.
    """

    # Auxiliary-timeframe suffixes stripped from feature names before rule
    # lookup, so that e.g. ``rsi_15m`` is handled by the RSI rule.
    _PERIOD_SUFFIXES = ('_5m', '_60m', '_3m', '_15m', '_30m', '_1m')

    def __init__(self):
        # Feature names chosen by select_features(), best first.
        self.top_features = []
        # Combined selection score of every candidate feature.
        self.feature_scores = {}
        # Not assigned anywhere in this class; presumably reserved for a
        # parameter search that is not implemented yet.
        self.best_params = {}

    def select_features(self, X: pd.DataFrame, y: pd.Series) -> list:
        """Rank features with several statistical methods and keep the best ones.

        Each method's scores are divided by that method's maximum, so the three
        methods carry equal weight in the combined score.

        :param X: candidate feature matrix (NaNs are treated as 0 for mutual
            information and the random forest).
        :param y: class labels aligned with ``X``.
        :return: Top N feature names (``STAT_CONFIG['top_n_features']``), best
            first; also stored in ``self.top_features``.
        """
        logger.info("=" * 50)
        logger.info("开始特征筛选...")

        # 1. Absolute Pearson correlation with the label. Constant columns give
        #    NaN and score 0; reindex covers any column corrwith leaves out.
        logger.info("计算皮尔逊相关系数...")
        corr = X.corrwith(y).abs().reindex(X.columns).fillna(0)
        corr_norm = self._normalize_by_max(corr)

        # 2. Mutual information.
        logger.info("计算互信息...")
        X_filled = X.fillna(0)
        mi = pd.Series(mutual_info_classif(X_filled, y, random_state=42), index=X.columns)
        mi_norm = self._normalize_by_max(mi)

        # 3. Random-forest feature importances.
        logger.info("训练随机森林评估特征重要性...")
        rf = RandomForestClassifier(n_estimators=200, max_depth=8, random_state=42, n_jobs=-1)
        rf.fit(X_filled, y)
        rf_norm = self._normalize_by_max(pd.Series(rf.feature_importances_, index=X.columns))

        # Combined ranking: equal-weight sum of the three normalized scores.
        score_series = (corr_norm + mi_norm + rf_norm).sort_values(ascending=False)
        self.feature_scores = score_series.to_dict()

        # Take twice as many candidates as needed so that dropping redundant
        # (highly correlated) features still leaves enough to fill Top N.
        top_n = SC['top_n_features']
        top_candidates = score_series.head(top_n * 2).index.tolist()
        selected = self._remove_correlated(X[top_candidates], SC['correlation_threshold'])
        self.top_features = selected[:top_n]

        logger.info(f"筛选出 Top {len(self.top_features)} 特征:")
        for rank, feat in enumerate(self.top_features, start=1):
            logger.info(f" {rank}. {feat} (综合得分: {score_series[feat]:.4f})")

        return self.top_features

    @staticmethod
    def _normalize_by_max(scores: pd.Series) -> pd.Series:
        """Scale ``scores`` so the largest becomes 1.0; return them unchanged if the max is not positive."""
        peak = scores.max()
        return scores / peak if peak > 0 else scores

    def _remove_correlated(self, X: pd.DataFrame, threshold: float) -> list:
        """Greedily drop features that are highly correlated with an earlier one.

        Columns are visited in their given order (``select_features`` passes
        them best first), so in each correlated pair the earlier column is kept.

        :param X: feature matrix whose columns are the candidates, in priority order.
        :param threshold: absolute correlation above which the later column is dropped.
        :return: surviving column names, in their original order.
        """
        corr_matrix = X.corr().abs()
        columns = list(X.columns)
        dropped = set()

        for i, keeper in enumerate(columns):
            if keeper in dropped:
                continue
            for other in columns[i + 1:]:
                # NaN correlations (e.g. constant columns) compare False and are kept.
                if other not in dropped and corr_matrix.loc[keeper, other] > threshold:
                    dropped.add(other)

        result = [c for c in columns if c not in dropped]
        if dropped:
            logger.info(f"移除 {len(dropped)} 个高相关冗余特征")
        return result

    @classmethod
    def _base_name(cls, feat: str) -> str:
        """Strip the ``_lagN`` suffix and at most one auxiliary-period suffix from a feature name."""
        base = feat.split('_lag')[0]
        for suffix in cls._PERIOD_SUFFIXES:
            if base.endswith(suffix):
                return base[:-len(suffix)]
        return base

    @staticmethod
    def _rule_scores(base: str, col: pd.Series, df: pd.DataFrame) -> "tuple[pd.Series, pd.Series] | None":
        """Apply the fixed threshold rule for one indicator.

        :param base: indicator name with lag/period suffixes removed (see ``_base_name``).
        :param col: raw (non-standardized) indicator values.
        :param df: full feature frame; its ``close`` column is used by the
            Bollinger-width and Ichimoku rules when present.
        :return: ``(long_votes, short_votes)`` float Series aligned with ``col``,
            or ``None`` when no rule applies to ``base``.
        """
        has_close = 'close' in df.columns

        # The stochastic rule must be checked before RSI: 'stoch_rsi_k' contains
        # 'rsi', so testing RSI first made this rule unreachable for it.
        # NOTE(review): thresholds assume a 0-100 scale; confirm the stoch RSI
        # feature is not on a 0-1 scale in feature_engine.
        if 'stoch_k' in base or 'stoch_rsi_k' in base:
            return (col < 25).astype(float), (col > 75).astype(float)
        if 'rsi' in base:  # oversold / overbought
            return (col < 35).astype(float), (col > 65).astype(float)
        if base in ('macd_hist', 'macd'):  # sign of MACD line / histogram
            return (col > 0).astype(float), (col < 0).astype(float)
        if 'bb_pband' in base:  # position inside the Bollinger band
            return (col < 0.2).astype(float), (col > 0.8).astype(float)
        if 'adx' in base:
            # Trend strength only (non-directional): same vote on both sides.
            strong = (col > 25).astype(float)
            return strong, strong.copy()
        if 'cci' in base:
            return (col < -100).astype(float), (col > 100).astype(float)
        if 'williams_r' in base:
            return (col < -80).astype(float), (col > -20).astype(float)
        if any(tag in base for tag in ('ao', 'tsi', 'roc', 'ppo')):
            # Zero-line oscillators. NOTE(review): plain substring tests; short
            # tags such as 'ao' / 'roc' may also match unrelated feature names.
            return (col > 0).astype(float), (col < 0).astype(float)
        if 'atr' in base or 'volatility_ratio' in base:
            # Volatility above its rolling 200-bar median; non-directional, half weight.
            high = (col > col.rolling(200, min_periods=50).median()).astype(float) * 0.5
            return high, high.copy()
        if 'high_low_range' in base or 'body_ratio' in base:
            # Bar range above its rolling median; non-directional, 0.3 weight.
            wide = (col > col.rolling(200, min_periods=50).median()).astype(float) * 0.3
            return wide, wide.copy()
        if 'bb_width' in base:
            # Breakout: band width expanding after the previous bar was below the
            # rolling median; direction taken from the latest close-to-close move.
            prev = col.shift(1)
            breakout = (col > prev) & (prev < col.rolling(200, min_periods=50).median())
            if has_close:
                price_up = df['close'] > df['close'].shift(1)
                return (breakout & price_up).astype(float), (breakout & ~price_up).astype(float)
            half = breakout.astype(float) * 0.5
            return half, half.copy()
        if 'close_sma20_ratio' in base or 'close_ema12_ratio' in base:
            # Price above (long) / below (short) the moving average.
            return (col > 1.0).astype(float), (col < 1.0).astype(float)
        if 'ichimoku' in base:
            # Needs the close price; without it no rule is applied.
            if not has_close:
                return None
            return (df['close'] > col).astype(float), (df['close'] < col).astype(float)
        if 'momentum' in base or 'price_change' in base:
            return (col > 0).astype(float), (col < 0).astype(float)
        return None

    @staticmethod
    def _resolve_signals(long_score: pd.Series, short_score: pd.Series, threshold: float) -> pd.Series:
        """Turn long/short vote totals into signals 0 (flat) / 1 (long) / 2 (short).

        A side fires when its score reaches ``threshold``; when both fire the
        larger score wins and an exact tie stays flat.
        """
        long_ok = long_score >= threshold
        short_ok = short_score >= threshold
        signals = pd.Series(0, index=long_score.index)
        signals.loc[long_ok & (~short_ok | (long_score > short_score))] = 1
        signals.loc[short_ok & (~long_ok | (short_score > long_score))] = 2
        return signals

    def generate_signals(self, df: pd.DataFrame) -> pd.Series:
        """Generate trading signals by voting the rules of the selected indicators.

        :param df: DataFrame holding the indicator columns as raw (non-standardized) values.
        :return: signal Series (0=flat, 1=long, 2=short) indexed like ``df``.
        """
        long_score = pd.Series(0.0, index=df.index)
        short_score = pd.Series(0.0, index=df.index)
        matched = 0

        for feat in self.top_features:
            if feat not in df.columns:
                continue
            votes = self._rule_scores(self._base_name(feat), df[feat], df)
            if votes is None:
                continue
            long_votes, short_votes = votes
            long_score += long_votes
            short_score += short_votes
            matched += 1

        logger.info(f"规则匹配: {matched}/{len(self.top_features)} 个特征有对应规则")

        # At least half of the matched rules (and never fewer than 3) must agree.
        threshold = max(3, matched * 0.5)
        logger.info(f"信号阈值: {threshold:.1f} (需要至少这么多指标同时确认)")
        signals = self._resolve_signals(long_score, short_score, threshold)

        dist = signals.value_counts().to_dict()
        logger.info(f"规则信号分布: 观望={dist.get(0, 0)}, 做多={dist.get(1, 0)}, 做空={dist.get(2, 0)}")
        return signals

    def run(self, period: int = None, start_date: str = None, end_date: str = None) -> dict:
        """Run plan A end to end: select features, generate signals, backtest.

        Features are selected on the ``prepare_dataset`` output; the rules are
        then evaluated on raw features rebuilt with ``build_features``.

        NOTE(review): feature selection and the backtest use the same period and
        date range, so the reported metrics are in-sample.

        :param period: primary bar period; defaults to ``PRIMARY_PERIOD``.
        :param start_date: optional start of the data range, passed to the loaders.
        :param end_date: optional end of the data range, passed to the loaders.
        :return: result dict from ``BacktestEngine.run`` (its ``'metrics'`` entry is printed).
        """
        if period is None:
            period = PRIMARY_PERIOD

        logger.info("=" * 60)
        logger.info("方案A:统计筛选 + 规则组合策略")
        logger.info("=" * 60)

        # 1. Prepare data (standardized version, used only for feature selection).
        X, y, _, _ = prepare_dataset(period, start_date, end_date)

        # 2. Select features.
        self.select_features(X, y)

        # 3. Rebuild the full feature matrix with raw values for the rule thresholds.
        #    Imports stay local as in the original (possibly to avoid import cycles).
        from .data_loader import load_kline
        from .feature_engine import build_features

        primary_df = load_kline(period, start_date, end_date)
        aux_dfs = {}
        for aux_p in AUX_PERIODS:
            try:
                aux_df = load_kline(aux_p, start_date, end_date)
                if not aux_df.empty:
                    aux_dfs[aux_p] = aux_df
            except Exception as e:
                # Auxiliary periods are optional: skip them, but do not hide the failure.
                logger.warning(f"辅助周期 {aux_p} 数据加载失败,已跳过: {e}")
        df = build_features(primary_df, aux_dfs).dropna()

        # 4. Generate signals.
        signals = self.generate_signals(df)

        # 5. Backtest.
        engine = BacktestEngine()
        result = engine.run(df['close'], signals)

        print_metrics(result['metrics'], "方案A: 统计筛选策略")
        return result
|
||
|
||
|
||
def run_stat_strategy(period: int = None, start_date: str = None, end_date: str = None) -> dict:
    """Shortcut entry point for plan A.

    Creates a fresh ``StatStrategy`` and delegates to its ``run`` method.

    :param period: primary bar period; ``None`` lets the strategy use its default.
    :param start_date: optional start of the data range.
    :param end_date: optional end of the data range.
    :return: the backtest result returned by ``StatStrategy.run``.
    """
    return StatStrategy().run(period, start_date, end_date)
|