Files
jyx_code4/strategy/stat_strategy.py
ddrwode 21f2adc4a4 哈哈
2026-02-20 20:57:25 +08:00

257 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
方案A统计筛选 + 规则组合策略
1. 从52个指标中用统计方法筛选最有效的指标
2. 用经典规则组合生成交易信号
3. 网格搜索优化参数
"""
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import mutual_info_classif
from loguru import logger
from .config import STAT_CONFIG as SC, PRIMARY_PERIOD, AUX_PERIODS
from .feature_engine import prepare_dataset
from .backtest import BacktestEngine, print_metrics
class StatStrategy:
    """Scheme A: statistical feature selection + rule-based signal voting.

    ``run`` ranks indicators with three statistical scores, keeps the top
    de-correlated ones, then lets hand-written per-indicator rules vote into
    0 (flat) / 1 (long) / 2 (short) signals that are backtested.
    """

    def __init__(self):
        # Feature names chosen by select_features, best first.
        self.top_features = []
        # Combined selection score for every candidate feature (name -> score).
        self.feature_scores = {}
        # Nothing in this class writes this; kept for callers that may read it.
        self.best_params = {}

    def select_features(self, X: pd.DataFrame, y: pd.Series) -> list:
        """Rank features with several statistical measures and keep the best.

        Combined score per feature = |Pearson corr with y|
        + mutual information scaled by its max
        + random-forest importance scaled by its max.
        The top ``2 * SC['top_n_features']`` features by score are pruned of
        highly correlated duplicates, and the first ``SC['top_n_features']``
        survivors are kept.

        :param X: feature matrix, one column per indicator.
        :param y: class labels aligned with the rows of ``X``.
        :return: selected feature names (also stored in ``self.top_features``).
        """
        logger.info("=" * 50)
        logger.info("开始特征筛选...")
        # Zero-filled once and shared by both sklearn scorers below.
        X_filled = X.fillna(0)

        # 1. Absolute Pearson correlation; undefined correlations count as 0.
        logger.info("计算皮尔逊相关系数...")
        corr_scores = X.corrwith(y).abs().fillna(0)
        # NOTE(review): unlike MI and RF below, this term is not rescaled by its
        # max, so it usually carries less weight in the sum; confirm intended.

        # 2. Mutual information, scaled so the best feature scores 1.
        logger.info("计算互信息...")
        mi = pd.Series(mutual_info_classif(X_filled, y, random_state=42), index=X.columns)
        mi_norm = mi / mi.max() if mi.max() > 0 else mi

        # 3. Random-forest feature importance, scaled the same way.
        logger.info("训练随机森林评估特征重要性...")
        rf = RandomForestClassifier(n_estimators=200, max_depth=8, random_state=42, n_jobs=-1)
        rf.fit(X_filled, y)
        rf_imp = pd.Series(rf.feature_importances_, index=X.columns)
        rf_norm = rf_imp / rf_imp.max() if rf_imp.max() > 0 else rf_imp

        # Combined ranking. reindex makes a column absent from corrwith's
        # output contribute 0 for that term instead of NaN.
        score_series = (
            corr_scores.reindex(X.columns, fill_value=0)
            + mi_norm
            + rf_norm
        ).sort_values(ascending=False)
        self.feature_scores = score_series.to_dict()

        # Prune redundant (highly correlated) features among the top candidates.
        n_top = SC['top_n_features']
        top_candidates = score_series.head(n_top * 2).index.tolist()
        selected = self._remove_correlated(X[top_candidates], SC['correlation_threshold'])
        self.top_features = selected[:n_top]

        logger.info(f"筛选出 Top {len(self.top_features)} 特征:")
        for rank, feat in enumerate(self.top_features, start=1):
            logger.info(f" {rank}. {feat} (综合得分: {score_series[feat]:.4f})")
        return self.top_features

    def _remove_correlated(self, X: pd.DataFrame, threshold: float) -> list:
        """Drop features whose |correlation| with an earlier kept column exceeds ``threshold``.

        Columns are visited in their existing order, so earlier (higher-ranked)
        columns win. Pairs with an undefined (NaN) correlation are never dropped.

        :return: surviving column names, in original order.
        """
        corr_matrix = X.corr().abs()
        columns = list(X.columns)
        to_remove = set()
        for i, kept in enumerate(columns):
            if kept in to_remove:
                continue
            for other in columns[i + 1:]:
                if other not in to_remove and corr_matrix.loc[kept, other] > threshold:
                    to_remove.add(other)
        result = [c for c in columns if c not in to_remove]
        if to_remove:
            logger.info(f"移除 {len(to_remove)} 个高相关冗余特征")
        return result

    @staticmethod
    def _rule_base_name(feat: str) -> str:
        """Strip the ``_lagN`` suffix and then one auxiliary-period suffix from a feature name."""
        base = feat.split('_lag')[0]
        # First matching suffix wins; order kept from the original implementation.
        for suffix in ('_5m', '_60m', '_3m', '_15m', '_30m', '_1m'):
            if base.endswith(suffix):
                return base[:-len(suffix)]
        return base

    @staticmethod
    def _rule_votes(base: str, col: pd.Series, close=None):
        """Compute one indicator's long/short votes.

        :param base: indicator base name (see ``_rule_base_name``); matched by substring.
        :param col: raw indicator values.
        :param close: close-price Series aligned with ``col``, or ``None`` if unavailable.
        :return: ``(long_votes, short_votes)`` float Series (each value <= 1),
            or ``None`` when no rule applies. The first matching rule wins.
        """
        # Stochastic is tested before RSI: 'stoch_rsi_k' also contains 'rsi',
        # which would otherwise shadow this dedicated rule.
        # NOTE(review): 25/75 assumes a 0-100 scale; confirm for stoch_rsi_k.
        if 'stoch_k' in base or 'stoch_rsi_k' in base:
            return (col < 25).astype(float), (col > 75).astype(float)
        if 'rsi' in base:
            return (col < 35).astype(float), (col > 65).astype(float)
        if base in ('macd_hist', 'macd'):
            return (col > 0).astype(float), (col < 0).astype(float)
        if 'bb_pband' in base:
            return (col < 0.2).astype(float), (col > 0.8).astype(float)
        if 'adx' in base:
            # Direction-neutral trend-strength vote: adds equally to both sides.
            strong = (col > 25).astype(float)
            return strong, strong
        if 'cci' in base:
            return (col < -100).astype(float), (col > 100).astype(float)
        if 'williams_r' in base:
            return (col < -80).astype(float), (col > -20).astype(float)
        if 'ao' in base or 'tsi' in base or 'roc' in base or 'ppo' in base:
            return (col > 0).astype(float), (col < 0).astype(float)
        if 'atr' in base or 'volatility_ratio' in base:
            # High volatility is taken to strengthen trends; split at the
            # rolling 200-bar median and vote half-weight on both sides.
            high_vol = (col > col.rolling(200, min_periods=50).median()).astype(float) * 0.5
            return high_vol, high_vol
        if 'high_low_range' in base or 'body_ratio' in base:
            wide = (col > col.rolling(200, min_periods=50).median()).astype(float) * 0.3
            return wide, wide
        if 'bb_width' in base:
            # Squeeze breakout: width expanding after being below its rolling median,
            # attributed to the side of the latest price move when prices are known.
            median = col.rolling(200, min_periods=50).median()
            prev = col.shift(1)
            breakout = (col > prev) & (prev < median)
            if close is not None:
                price_up = close > close.shift(1)
                return (breakout & price_up).astype(float), (breakout & ~price_up).astype(float)
            half = breakout.astype(float) * 0.5
            return half, half
        if 'close_sma20_ratio' in base or 'close_ema12_ratio' in base:
            # Price above the moving average = bullish, below = bearish.
            return (col > 1.0).astype(float), (col < 1.0).astype(float)
        if 'ichimoku' in base:
            if close is None:
                # No price to compare against: cast no vote (not counted as matched).
                return None
            return (close > col).astype(float), (close < col).astype(float)
        if 'momentum' in base or 'price_change' in base:
            return (col > 0).astype(float), (col < 0).astype(float)
        return None

    @staticmethod
    def _resolve_signals(long_score: pd.Series, short_score: pd.Series, threshold) -> pd.Series:
        """Turn vote totals into 0/1/2 signals.

        A side fires when its score >= ``threshold``. If both fire, the higher
        score wins and a tie yields 0.
        """
        long_ok = long_score >= threshold
        short_ok = short_score >= threshold
        signals = pd.Series(0, index=long_score.index)
        signals[long_ok] = 1
        signals[short_ok] = 2
        both = long_ok & short_ok
        signals[both & (long_score > short_score)] = 1
        signals[both & (short_score > long_score)] = 2
        signals[both & (long_score == short_score)] = 0
        return signals

    def generate_signals(self, df: pd.DataFrame) -> pd.Series:
        """Vote the selected indicators into a trading signal.

        :param df: raw (non-standardized) indicator values. Its ``close``
            column, when present, feeds the ``bb_width`` and ``ichimoku`` rules.
        :return: int Series on ``df.index``: 0 = flat, 1 = long, 2 = short.
        """
        long_score = pd.Series(0.0, index=df.index)
        short_score = pd.Series(0.0, index=df.index)
        close = df['close'] if 'close' in df.columns else None
        matched = 0
        for feat in self.top_features:
            if feat not in df.columns:
                continue
            votes = self._rule_votes(self._rule_base_name(feat), df[feat], close)
            if votes is None:
                continue
            long_votes, short_votes = votes
            long_score += long_votes
            short_score += short_votes
            matched += 1

        logger.info(f"规则匹配: {matched}/{len(self.top_features)} 个特征有对应规则")
        # A side needs a score of at least half the matched-rule count, and never
        # less than 3. Each rule adds at most 1 per side, so with fewer than 3
        # matched rules no signal can fire.
        threshold = max(3, matched * 0.5)
        logger.info(f"信号阈值: {threshold:.1f} (需要至少这么多指标同时确认)")

        signals = self._resolve_signals(long_score, short_score, threshold)
        dist = signals.value_counts().to_dict()
        logger.info(f"规则信号分布: 观望={dist.get(0, 0)}, 做多={dist.get(1, 0)}, 做空={dist.get(2, 0)}")
        return signals

    def run(self, period: int = None, start_date: str = None, end_date: str = None) -> dict:
        """Run scheme A end to end: select features, generate signals, backtest.

        :param period: primary K-line period; defaults to ``PRIMARY_PERIOD``.
        :param start_date: optional start of the data window (passed to the loaders).
        :param end_date: optional end of the data window (passed to the loaders).
        :return: the dict returned by ``BacktestEngine.run``; its ``'metrics'``
            entry is printed before returning.
        """
        if period is None:
            period = PRIMARY_PERIOD
        logger.info("=" * 60)
        logger.info("方案A统计筛选 + 规则组合策略")
        logger.info("=" * 60)

        # 1. Standardized dataset, used only for feature selection.
        X, y, _, _ = prepare_dataset(period, start_date, end_date)
        # 2. Feature selection.
        self.select_features(X, y)

        # 3. Raw (non-standardized) feature matrix for the rule thresholds.
        from .data_loader import load_kline
        from .feature_engine import build_features

        primary_df = load_kline(period, start_date, end_date)
        aux_dfs = {}
        for aux_p in AUX_PERIODS:
            # Auxiliary periods are best-effort: failures are logged and skipped.
            try:
                aux_df = load_kline(aux_p, start_date, end_date)
                if not aux_df.empty:
                    aux_dfs[aux_p] = aux_df
            except Exception as exc:
                logger.warning(f"Skipping auxiliary period {aux_p}: failed to load ({exc})")
        df = build_features(primary_df, aux_dfs)
        df.dropna(inplace=True)

        # 4. Signals.
        signals = self.generate_signals(df)
        # 5. Backtest on close prices.
        engine = BacktestEngine()
        result = engine.run(df['close'], signals)
        print_metrics(result['metrics'], "方案A: 统计筛选策略")
        return result
def run_stat_strategy(period: int = None, start_date: str = None, end_date: str = None) -> dict:
    """Convenience entry point for scheme A.

    Builds a fresh ``StatStrategy``, forwards all arguments unchanged to its
    ``run`` method, and returns that method's backtest result dict.
    """
    return StatStrategy().run(period, start_date, end_date)