哈哈
This commit is contained in:
BIN
reports/strategy_comparison.png
Normal file
BIN
reports/strategy_comparison.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 186 KiB |
9
requirements.txt
Normal file
9
requirements.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
pandas>=2.0
|
||||
numpy>=1.24
|
||||
ta>=0.11.0
|
||||
scikit-learn>=1.3
|
||||
lightgbm>=4.0
|
||||
xgboost>=2.0
|
||||
matplotlib>=3.7
|
||||
peewee>=3.16
|
||||
loguru>=0.7
|
||||
1
strategy/__init__.py
Normal file
1
strategy/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""52指标AI交易策略系统"""
|
||||
214
strategy/ai_strategy.py
Normal file
214
strategy/ai_strategy.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""
|
||||
方案B:AI模型训练 + 信号生成
|
||||
使用 LightGBM / XGBoost,Walk-Forward 滚动训练
|
||||
"""
|
||||
import json
|
||||
import joblib
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
from .config import MODEL_CONFIG as MC, PRIMARY_PERIOD, PROJECT_ROOT
|
||||
from .feature_engine import prepare_dataset, get_latest_feature_row
|
||||
from .backtest import BacktestEngine, print_metrics
|
||||
|
||||
SCHEME_B_MODEL_DIR = PROJECT_ROOT / 'models'
|
||||
SCHEME_B_MODEL_FILE = SCHEME_B_MODEL_DIR / 'scheme_b_last_model.joblib'
|
||||
SCHEME_B_SCALER_FILE = SCHEME_B_MODEL_DIR / 'scheme_b_scaler.joblib'
|
||||
SCHEME_B_FEATURES_FILE = SCHEME_B_MODEL_DIR / 'scheme_b_features.json'
|
||||
|
||||
|
||||
class AIStrategy:
|
||||
"""AI模型策略 — LightGBM / XGBoost Walk-Forward"""
|
||||
|
||||
def __init__(self, model_type: str = 'lightgbm'):
|
||||
"""
|
||||
:param model_type: 'lightgbm' 或 'xgboost'
|
||||
"""
|
||||
self.model_type = model_type
|
||||
self.models = [] # 存储每个窗口训练的模型
|
||||
self.feature_importance = None
|
||||
|
||||
def _create_model(self):
|
||||
"""创建模型实例"""
|
||||
if self.model_type == 'lightgbm':
|
||||
import lightgbm as lgb
|
||||
params = MC['lightgbm'].copy()
|
||||
return lgb.LGBMClassifier(**params)
|
||||
elif self.model_type == 'xgboost':
|
||||
import xgboost as xgb
|
||||
params = MC['xgboost'].copy()
|
||||
return xgb.XGBClassifier(**params)
|
||||
else:
|
||||
raise ValueError(f"不支持的模型类型: {self.model_type}")
|
||||
|
||||
def walk_forward_train(self, X: pd.DataFrame, y: pd.Series,
|
||||
confidence_threshold: float = 0.45) -> pd.Series:
|
||||
"""
|
||||
Walk-Forward 滚动训练与预测
|
||||
:param confidence_threshold: 概率阈值,低于此值的预测设为0(观望)
|
||||
:return: 全部测试窗口拼接的预测信号
|
||||
"""
|
||||
train_size = MC['walk_forward_train_size']
|
||||
test_size = MC['walk_forward_test_size']
|
||||
step = MC['walk_forward_step']
|
||||
|
||||
n = len(X)
|
||||
all_preds = pd.Series(dtype=float)
|
||||
window_count = 0
|
||||
|
||||
logger.info(f"Walk-Forward: 数据量={n}, 训练窗口={train_size}, "
|
||||
f"测试窗口={test_size}, 步长={step}, 置信阈值={confidence_threshold}")
|
||||
|
||||
start = 0
|
||||
while start + train_size + test_size <= n:
|
||||
train_end = start + train_size
|
||||
test_end = min(train_end + test_size, n)
|
||||
|
||||
X_train = X.iloc[start:train_end]
|
||||
y_train = y.iloc[start:train_end]
|
||||
X_test = X.iloc[train_end:test_end]
|
||||
y_test = y.iloc[train_end:test_end]
|
||||
|
||||
# 训练
|
||||
model = self._create_model()
|
||||
model.fit(X_train, y_train)
|
||||
self.models.append(model)
|
||||
|
||||
# 预测概率 + 置信度过滤
|
||||
proba = model.predict_proba(X_test)
|
||||
max_proba = proba.max(axis=1)
|
||||
raw_preds = model.predict(X_test)
|
||||
|
||||
# 置信度不够的设为观望
|
||||
filtered_preds = raw_preds.copy()
|
||||
filtered_preds[max_proba < confidence_threshold] = 0
|
||||
|
||||
preds = pd.Series(filtered_preds, index=X_test.index)
|
||||
all_preds = pd.concat([all_preds, preds])
|
||||
|
||||
# 准确率(用原始预测算)
|
||||
acc = (raw_preds == y_test).mean()
|
||||
n_filtered = (max_proba < confidence_threshold).sum()
|
||||
window_count += 1
|
||||
logger.info(f" 窗口 {window_count}: 训练[{start}:{train_end}] "
|
||||
f"测试[{train_end}:{test_end}] 准确率={acc:.2%} "
|
||||
f"过滤={n_filtered}/{len(X_test)}")
|
||||
|
||||
start += step
|
||||
|
||||
# 特征重要性(取最后一个模型)
|
||||
if self.models:
|
||||
last_model = self.models[-1]
|
||||
if hasattr(last_model, 'feature_importances_'):
|
||||
self.feature_importance = pd.Series(
|
||||
last_model.feature_importances_, index=X.columns
|
||||
).sort_values(ascending=False)
|
||||
|
||||
logger.info(f"Walk-Forward 完成: {window_count} 个窗口, "
|
||||
f"共 {len(all_preds)} 条预测")
|
||||
return all_preds
|
||||
|
||||
def get_top_features(self, n: int = 20) -> pd.Series:
|
||||
"""获取Top N重要特征"""
|
||||
if self.feature_importance is not None:
|
||||
return self.feature_importance.head(n)
|
||||
return pd.Series(dtype=float)
|
||||
|
||||
def run(self, period: int = None, start_date: str = None, end_date: str = None) -> dict:
|
||||
"""
|
||||
完整运行方案B
|
||||
若指定了 start_date/end_date,会向前加载 warm_up_months 月数据用于训练,使回测区间首月即有预测。
|
||||
:return: 回测结果
|
||||
"""
|
||||
if period is None:
|
||||
period = PRIMARY_PERIOD
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"方案B:AI模型策略 ({self.model_type})")
|
||||
logger.info("=" * 60)
|
||||
|
||||
from .data_loader import load_kline
|
||||
|
||||
# 1. 准备数据:若指定了回测区间,则向前加载预热数据,使区间内从首月就有预测
|
||||
load_start, load_end = start_date, end_date
|
||||
if start_date and end_date:
|
||||
warm_months = MC.get('warm_up_months', 12)
|
||||
load_start_ts = pd.Timestamp(start_date) - pd.DateOffset(months=warm_months)
|
||||
load_start = load_start_ts.strftime('%Y-%m-%d')
|
||||
logger.info(f"回测区间 [{start_date} ~ {end_date}],向前加载 {warm_months} 月至 {load_start} 用于训练")
|
||||
|
||||
X, y, feature_names, scaler = prepare_dataset(period, load_start, load_end)
|
||||
|
||||
# 2. Walk-Forward 训练
|
||||
predictions = self.walk_forward_train(X, y)
|
||||
|
||||
# 3. 回测仅用用户指定区间;将预测对齐到该区间的每根K线
|
||||
df = load_kline(period, start_date, end_date)
|
||||
if df.empty:
|
||||
logger.warning("回测区间内无K线数据")
|
||||
return BacktestEngine()._empty_result()
|
||||
|
||||
# 对齐信号:回测区间内有的时间戳用预测,缺失的填 0(观望)
|
||||
signals = predictions.reindex(df.index, fill_value=0).astype(int)
|
||||
prices = df['close']
|
||||
|
||||
# 4. 回测
|
||||
engine = BacktestEngine()
|
||||
result = engine.run(prices, signals)
|
||||
|
||||
print_metrics(result['metrics'], f"方案B: {self.model_type} AI策略")
|
||||
|
||||
# 5. 保存最后一窗模型、scaler、特征列(供实盘 get_live_signal 使用)
|
||||
if self.models and scaler is not None:
|
||||
SCHEME_B_MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
||||
joblib.dump(self.models[-1], SCHEME_B_MODEL_FILE)
|
||||
joblib.dump(scaler, SCHEME_B_SCALER_FILE)
|
||||
with open(SCHEME_B_FEATURES_FILE, 'w', encoding='utf-8') as f:
|
||||
json.dump(feature_names, f, ensure_ascii=False)
|
||||
logger.info(f"已保存方案B模型: {SCHEME_B_MODEL_FILE}, scaler, {len(feature_names)} 个特征")
|
||||
|
||||
# 6. 输出特征重要性
|
||||
top_feat = self.get_top_features(15)
|
||||
if not top_feat.empty:
|
||||
logger.info("\nTop 15 重要特征:")
|
||||
for i, (feat, imp) in enumerate(top_feat.items()):
|
||||
logger.info(f" {i+1}. {feat}: {imp:.4f}")
|
||||
|
||||
result['feature_importance'] = self.feature_importance
|
||||
return result
|
||||
|
||||
|
||||
def run_ai_strategy(model_type: str = 'lightgbm', period: int = None,
|
||||
start_date: str = None, end_date: str = None) -> dict:
|
||||
"""方案B快捷入口"""
|
||||
strategy = AIStrategy(model_type=model_type)
|
||||
return strategy.run(period, start_date, end_date)
|
||||
|
||||
|
||||
def get_live_signal(period: int = None, model_type: str = 'lightgbm',
|
||||
start_date: str = None, end_date: str = None) -> int:
|
||||
"""
|
||||
使用已保存的方案B模型对当前最新K线生成信号(供实盘/模拟盘调用)。
|
||||
需先运行过 run_ai_strategy 或 AIStrategy().run() 以生成 models/scheme_b_*.joblib 与 features.json。
|
||||
:param period: K线主周期,默认 15
|
||||
:param model_type: 未使用(模型已固定为磁盘上的 scheme_b_last_model.joblib)
|
||||
:param start_date, end_date: 可选,限制 load_kline 范围
|
||||
:return: 0=观望, 1=做多, 2=做空
|
||||
"""
|
||||
if period is None:
|
||||
period = PRIMARY_PERIOD
|
||||
if not SCHEME_B_MODEL_FILE.exists() or not SCHEME_B_SCALER_FILE.exists() or not SCHEME_B_FEATURES_FILE.exists():
|
||||
logger.warning("方案B模型未找到,请先运行 AI 策略训练保存模型")
|
||||
return 0
|
||||
model = joblib.load(SCHEME_B_MODEL_FILE)
|
||||
scaler = joblib.load(SCHEME_B_SCALER_FILE)
|
||||
with open(SCHEME_B_FEATURES_FILE, 'r', encoding='utf-8') as f:
|
||||
feature_cols = json.load(f)
|
||||
X_last = get_latest_feature_row(period, feature_cols, start_date, end_date)
|
||||
if X_last.empty:
|
||||
return 0
|
||||
X_scaled = scaler.transform(X_last)
|
||||
pred = model.predict(X_scaled)
|
||||
return int(pred[0])
|
||||
298
strategy/backtest.py
Normal file
298
strategy/backtest.py
Normal file
@@ -0,0 +1,298 @@
|
||||
"""
|
||||
回测引擎 — 多空双向、手续费、滑点、绩效统计
|
||||
每笔固定名义 100U、100 倍杠杆;同一时间仅一个仓位;最大回撤 300U 硬约束;手续费 90% 返佣
|
||||
"""
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from .config import TRADE_CONFIG as TC, SIGNAL_MAP
|
||||
|
||||
|
||||
class BacktestEngine:
|
||||
"""回测引擎:固定每笔 100U 名义、100x,单仓位,最大回撤 300U 内,实付手续费(90% 返佣)"""
|
||||
|
||||
def __init__(self, commission: float = None, slippage: float = None,
|
||||
initial_capital: float = None, position_size: float = None,
|
||||
position_notional_usd: float = None, max_drawdown_limit: float = None,
|
||||
commission_rebate: float = None):
|
||||
raw_commission = commission if commission is not None else TC['commission']
|
||||
rebate = commission_rebate if commission_rebate is not None else TC.get('commission_rebate', 0)
|
||||
self.commission = raw_commission * (1 - rebate) # 实付手续费(90% 返佣)
|
||||
self.slippage = slippage or TC['slippage']
|
||||
self.initial_capital = initial_capital or TC['initial_capital']
|
||||
self.position_size = position_size or TC['position_size']
|
||||
self.position_notional_usd = position_notional_usd if position_notional_usd is not None else TC.get('position_notional_usd', self.initial_capital * self.position_size)
|
||||
self.max_drawdown_limit = max_drawdown_limit if max_drawdown_limit is not None else TC.get('max_drawdown_limit', float('inf'))
|
||||
|
||||
def run(self, prices: pd.Series, signals: pd.Series) -> dict:
|
||||
"""
|
||||
执行回测
|
||||
:param prices: 收盘价 Series
|
||||
:param signals: 信号 Series,值: 0=观望, 1=做多, 2=做空
|
||||
:return: 回测结果字典
|
||||
"""
|
||||
df = pd.DataFrame({'price': prices, 'signal': signals}).dropna()
|
||||
if df.empty:
|
||||
logger.warning("回测数据为空")
|
||||
return self._empty_result()
|
||||
|
||||
n = len(df)
|
||||
capital = self.initial_capital
|
||||
position = 0 # 持仓数量(正=多头,负=空头)
|
||||
entry_price = 0.0
|
||||
direction = 0 # 当前方向: 0=空仓, 1=多, 2=空
|
||||
trades = []
|
||||
equity_curve = np.zeros(n)
|
||||
peak_equity = self.initial_capital # 用于 300U 最大回撤约束
|
||||
|
||||
for i in range(n):
|
||||
price = df.iloc[i]['price']
|
||||
signal = int(df.iloc[i]['signal'])
|
||||
|
||||
# 计算当前权益
|
||||
if position > 0:
|
||||
equity_curve[i] = capital + position * (price - entry_price)
|
||||
elif position < 0:
|
||||
equity_curve[i] = capital + position * (price - entry_price)
|
||||
else:
|
||||
equity_curve[i] = capital
|
||||
current_equity = equity_curve[i]
|
||||
peak_equity = max(peak_equity, current_equity)
|
||||
|
||||
# 最大回撤硬约束:从峰值回落超过 300U 则不再开新仓(只允许平仓)
|
||||
drawdown_usd = peak_equity - current_equity
|
||||
can_open = drawdown_usd < self.max_drawdown_limit
|
||||
|
||||
# 强制止损:持仓浮亏超过初始资金的20%时强制平仓
|
||||
if position != 0:
|
||||
unrealized = position * (price - entry_price)
|
||||
if unrealized < -self.initial_capital * 0.20:
|
||||
capital, trade = self._close_position(
|
||||
capital, position, entry_price, price, df.index[i])
|
||||
trades.append(trade)
|
||||
position = 0
|
||||
direction = 0
|
||||
continue
|
||||
|
||||
# 资金不足时不开新仓;同一时间仅一个仓位(已有持仓则只能先平再开)
|
||||
min_capital = self.initial_capital * 0.05
|
||||
can_trade = capital > min_capital and can_open
|
||||
|
||||
# 跳过:信号与当前持仓方向相同
|
||||
if signal == direction:
|
||||
continue
|
||||
|
||||
# 每笔固定名义 100U:qty = position_notional_usd / price
|
||||
notional = self.position_notional_usd
|
||||
qty = notional / price if price > 0 else 0
|
||||
|
||||
# 需要换方向或平仓
|
||||
if signal == 1 and direction != 1:
|
||||
# 先平仓
|
||||
if position != 0:
|
||||
capital, trade = self._close_position(
|
||||
capital, position, entry_price, price, df.index[i])
|
||||
trades.append(trade)
|
||||
position = 0
|
||||
|
||||
if not can_trade or qty <= 0:
|
||||
direction = 0
|
||||
continue
|
||||
|
||||
# 开多:固定 100U 名义,实付手续费
|
||||
cost = notional * (self.commission + self.slippage)
|
||||
capital -= cost
|
||||
position = qty
|
||||
entry_price = price
|
||||
direction = 1
|
||||
|
||||
elif signal == 2 and direction != 2:
|
||||
# 先平仓
|
||||
if position != 0:
|
||||
capital, trade = self._close_position(
|
||||
capital, position, entry_price, price, df.index[i])
|
||||
trades.append(trade)
|
||||
position = 0
|
||||
|
||||
if not can_trade or qty <= 0:
|
||||
direction = 0
|
||||
continue
|
||||
|
||||
# 开空:固定 100U 名义
|
||||
cost = notional * (self.commission + self.slippage)
|
||||
capital -= cost
|
||||
position = -qty
|
||||
entry_price = price
|
||||
direction = 2
|
||||
|
||||
elif signal == 0 and position != 0:
|
||||
# 平仓
|
||||
capital, trade = self._close_position(
|
||||
capital, position, entry_price, price, df.index[i])
|
||||
trades.append(trade)
|
||||
position = 0
|
||||
direction = 0
|
||||
|
||||
# 最终平仓
|
||||
if position != 0:
|
||||
price = df.iloc[-1]['price']
|
||||
capital, trade = self._close_position(
|
||||
capital, position, entry_price, price, df.index[-1])
|
||||
trades.append(trade)
|
||||
|
||||
equity = pd.Series(equity_curve, index=df.index)
|
||||
trades_df = pd.DataFrame(trades) if trades else pd.DataFrame()
|
||||
metrics = self._calc_metrics(equity, trades_df)
|
||||
monthly_pnl = self._monthly_pnl(equity)
|
||||
|
||||
return {
|
||||
'equity_curve': equity,
|
||||
'trades': trades_df,
|
||||
'metrics': metrics,
|
||||
'final_capital': capital,
|
||||
'monthly_pnl': monthly_pnl,
|
||||
}
|
||||
|
||||
def _close_position(self, capital, position, entry_price, exit_price, time):
|
||||
"""平仓并返回更新后的capital和交易记录"""
|
||||
if position > 0:
|
||||
pnl = position * (exit_price - entry_price)
|
||||
cost = position * exit_price * (self.commission + self.slippage)
|
||||
trade_type = '平多'
|
||||
else:
|
||||
pnl = -position * (entry_price - exit_price)
|
||||
cost = abs(position) * exit_price * (self.commission + self.slippage)
|
||||
trade_type = '平空'
|
||||
|
||||
capital += pnl - cost
|
||||
trade = {
|
||||
'type': trade_type,
|
||||
'entry': entry_price,
|
||||
'exit': exit_price,
|
||||
'pnl': pnl - cost,
|
||||
'return_pct': (exit_price / entry_price - 1) * (1 if position > 0 else -1),
|
||||
'time': time,
|
||||
}
|
||||
return capital, trade
|
||||
|
||||
def _calc_metrics(self, equity: pd.Series, trades: pd.DataFrame) -> dict:
|
||||
"""计算绩效指标"""
|
||||
if equity.empty:
|
||||
return self._empty_metrics()
|
||||
|
||||
total_return = (equity.iloc[-1] / self.initial_capital) - 1
|
||||
|
||||
n_bars = len(equity)
|
||||
if n_bars > 1:
|
||||
# 按日聚合收益率计算夏普
|
||||
daily_equity = equity.resample('D').last().dropna()
|
||||
if len(daily_equity) > 1:
|
||||
daily_returns = daily_equity.pct_change().dropna()
|
||||
n_days = len(daily_returns)
|
||||
annualized_return = (1 + total_return) ** (365 / max(n_days, 1)) - 1 if total_return > -1 else -1.0
|
||||
sharpe = (daily_returns.mean() / daily_returns.std() * np.sqrt(365)
|
||||
if daily_returns.std() > 0 else 0)
|
||||
else:
|
||||
annualized_return = 0
|
||||
sharpe = 0
|
||||
else:
|
||||
annualized_return = 0
|
||||
sharpe = 0
|
||||
|
||||
# 最大回撤(比例与绝对 USDT)
|
||||
cummax = equity.cummax()
|
||||
drawdown = (equity - cummax) / cummax.replace(0, np.nan)
|
||||
max_drawdown = drawdown.min() if not drawdown.empty else 0
|
||||
max_drawdown_usd = (cummax - equity).max() if not equity.empty else 0
|
||||
|
||||
# 按月收益(自然月):用于目标「每月盈利 ≥ 1000U」
|
||||
monthly_pnl_usd = None
|
||||
months_above_1000 = 0
|
||||
if not equity.empty and hasattr(equity.index, 'to_period'):
|
||||
try:
|
||||
monthly = equity.resample('ME').last().dropna()
|
||||
if len(monthly) > 0:
|
||||
monthly_pnl = monthly.diff()
|
||||
monthly_pnl.iloc[0] = monthly.iloc[0] - self.initial_capital
|
||||
monthly_pnl_usd = monthly_pnl
|
||||
months_above_1000 = (monthly_pnl >= 1000).sum()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 交易统计
|
||||
n_trades = len(trades)
|
||||
if n_trades > 0:
|
||||
wins = trades[trades['pnl'] > 0]
|
||||
losses = trades[trades['pnl'] <= 0]
|
||||
win_rate = len(wins) / n_trades
|
||||
avg_win = wins['pnl'].mean() if len(wins) > 0 else 0
|
||||
avg_loss = abs(losses['pnl'].mean()) if len(losses) > 0 else 0
|
||||
profit_factor = (wins['pnl'].sum() / abs(losses['pnl'].sum())
|
||||
if len(losses) > 0 and losses['pnl'].sum() != 0 else float('inf'))
|
||||
else:
|
||||
win_rate = 0
|
||||
avg_win = 0
|
||||
avg_loss = 0
|
||||
profit_factor = 0
|
||||
|
||||
out = {
|
||||
'总收益率': f'{total_return:.2%}',
|
||||
'年化收益率': f'{annualized_return:.2%}',
|
||||
'最大回撤': f'{max_drawdown:.2%}',
|
||||
'最大回撤(U)': f'{max_drawdown_usd:.2f}',
|
||||
'夏普比率': f'{sharpe:.2f}',
|
||||
'总交易次数': n_trades,
|
||||
'胜率': f'{win_rate:.2%}',
|
||||
'平均盈利': f'{avg_win:.2f}',
|
||||
'平均亏损': f'{avg_loss:.2f}',
|
||||
'盈亏比': f'{profit_factor:.2f}',
|
||||
'最终资金': f'{equity.iloc[-1]:.2f}',
|
||||
}
|
||||
out['月盈利≥1000U的月数'] = months_above_1000 if monthly_pnl_usd is not None else 0
|
||||
if monthly_pnl_usd is not None and len(monthly_pnl_usd) > 0:
|
||||
out['月均盈利(U)'] = f'{monthly_pnl_usd.mean():.2f}'
|
||||
else:
|
||||
out['月均盈利(U)'] = '0.00'
|
||||
return out
|
||||
|
||||
def _monthly_pnl(self, equity: pd.Series):
|
||||
"""按自然月汇总收益(USDT),首月为当月权益 - 初始资金"""
|
||||
if equity.empty or not hasattr(equity.index, 'to_period'):
|
||||
return None
|
||||
try:
|
||||
monthly = equity.resample('ME').last().dropna()
|
||||
if len(monthly) == 0:
|
||||
return None
|
||||
pnl = monthly.diff()
|
||||
pnl.iloc[0] = monthly.iloc[0] - self.initial_capital
|
||||
return pnl
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def _empty_result(self):
|
||||
return {
|
||||
'equity_curve': pd.Series(dtype=float),
|
||||
'trades': pd.DataFrame(),
|
||||
'metrics': self._empty_metrics(),
|
||||
'final_capital': self.initial_capital,
|
||||
'monthly_pnl': None,
|
||||
}
|
||||
|
||||
def _empty_metrics(self):
|
||||
return {
|
||||
'总收益率': '0.00%', '年化收益率': '0.00%', '最大回撤': '0.00%',
|
||||
'最大回撤(U)': '0.00', '夏普比率': '0.00', '总交易次数': 0, '胜率': '0.00%',
|
||||
'平均盈利': '0.00', '平均亏损': '0.00', '盈亏比': '0.00',
|
||||
'最终资金': f'{self.initial_capital:.2f}', '月盈利≥1000U的月数': 0,
|
||||
'月均盈利(U)': '0.00',
|
||||
}
|
||||
|
||||
|
||||
def print_metrics(metrics: dict, title: str = "回测结果"):
|
||||
"""打印绩效指标"""
|
||||
logger.info(f"\n{'='*50}")
|
||||
logger.info(f" {title}")
|
||||
logger.info(f"{'='*50}")
|
||||
for k, v in metrics.items():
|
||||
logger.info(f" {k:>12}: {v}")
|
||||
logger.info(f"{'='*50}")
|
||||
222
strategy/compare.py
Normal file
222
strategy/compare.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""
|
||||
方案对比评估 — 方案A(统计筛选) vs 方案B(AI模型) 并排对比 + 报告输出
|
||||
"""
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
matplotlib.use('Agg') # 非交互式后端
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib import font_manager
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
# 设置中文字体
|
||||
_zh_font = None
|
||||
for fname in ['PingFang SC', 'Heiti SC', 'STHeiti', 'SimHei', 'Microsoft YaHei', 'WenQuanYi Micro Hei']:
|
||||
try:
|
||||
_zh_font = font_manager.FontProperties(family=fname)
|
||||
# 验证字体存在
|
||||
font_manager.findfont(_zh_font, fallback_to_default=False)
|
||||
plt.rcParams['font.sans-serif'] = [fname, 'DejaVu Sans']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
break
|
||||
except Exception:
|
||||
_zh_font = None
|
||||
continue
|
||||
|
||||
from .config import PRIMARY_PERIOD, PROJECT_ROOT
|
||||
from .stat_strategy import StatStrategy
|
||||
from .ai_strategy import AIStrategy
|
||||
from .backtest import print_metrics
|
||||
|
||||
|
||||
REPORT_DIR = PROJECT_ROOT / 'reports'
|
||||
|
||||
|
||||
def compare_strategies(period: int = None, start_date: str = None, end_date: str = None,
|
||||
save_plot: bool = True) -> dict:
|
||||
"""
|
||||
运行两种方案并对比
|
||||
:return: {'stat': result_a, 'lgb': result_b, 'xgb': result_c, 'comparison': DataFrame}
|
||||
"""
|
||||
if period is None:
|
||||
period = PRIMARY_PERIOD
|
||||
|
||||
logger.info("=" * 70)
|
||||
logger.info(" 开始策略对比评估")
|
||||
logger.info("=" * 70)
|
||||
|
||||
results = {}
|
||||
|
||||
# 方案A:统计筛选
|
||||
logger.info("\n>>> 运行方案A: 统计筛选策略")
|
||||
stat = StatStrategy()
|
||||
results['stat'] = stat.run(period, start_date, end_date)
|
||||
|
||||
# 方案B-1:LightGBM
|
||||
logger.info("\n>>> 运行方案B-1: LightGBM AI策略")
|
||||
lgb_strategy = AIStrategy(model_type='lightgbm')
|
||||
results['lgb'] = lgb_strategy.run(period, start_date, end_date)
|
||||
|
||||
# 方案B-2:XGBoost
|
||||
logger.info("\n>>> 运行方案B-2: XGBoost AI策略")
|
||||
xgb_strategy = AIStrategy(model_type='xgboost')
|
||||
results['xgb'] = xgb_strategy.run(period, start_date, end_date)
|
||||
|
||||
# 对比表格
|
||||
comparison = _build_comparison_table(results)
|
||||
results['comparison'] = comparison
|
||||
|
||||
# 打印对比
|
||||
logger.info("\n" + "=" * 70)
|
||||
logger.info(" 策略对比总结")
|
||||
logger.info("=" * 70)
|
||||
logger.info(f"\n{comparison.to_string()}")
|
||||
|
||||
# 每月盈利(U)
|
||||
_log_monthly_pnl(results)
|
||||
|
||||
# 保存图表
|
||||
if save_plot:
|
||||
_save_equity_plot(results)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _build_comparison_table(results: dict) -> pd.DataFrame:
|
||||
"""构建对比表格"""
|
||||
rows = {}
|
||||
name_map = {
|
||||
'stat': '方案A: 统计筛选',
|
||||
'lgb': '方案B-1: LightGBM',
|
||||
'xgb': '方案B-2: XGBoost',
|
||||
}
|
||||
|
||||
for key, name in name_map.items():
|
||||
if key in results and 'metrics' in results[key]:
|
||||
rows[name] = results[key]['metrics']
|
||||
|
||||
if not rows:
|
||||
return pd.DataFrame()
|
||||
|
||||
df = pd.DataFrame(rows).T
|
||||
return df
|
||||
|
||||
|
||||
def _log_monthly_pnl(results: dict):
|
||||
"""打印各策略每月盈利(USDT)"""
|
||||
name_map = {
|
||||
'stat': '方案A',
|
||||
'lgb': 'LightGBM',
|
||||
'xgb': 'XGBoost',
|
||||
}
|
||||
cols = []
|
||||
for key, name in name_map.items():
|
||||
if key not in results or results[key].get('monthly_pnl') is None:
|
||||
continue
|
||||
s = results[key]['monthly_pnl']
|
||||
if s is None or s.empty:
|
||||
continue
|
||||
s = s.astype(float).round(2)
|
||||
s.name = name
|
||||
cols.append(s)
|
||||
if not cols:
|
||||
return
|
||||
monthly_df = pd.concat(cols, axis=1)
|
||||
monthly_df = monthly_df.fillna(0)
|
||||
logger.info("\n" + "-" * 70)
|
||||
logger.info(" 每月盈利 (USDT)")
|
||||
logger.info("-" * 70)
|
||||
logger.info(f"\n{monthly_df.to_string()}")
|
||||
logger.info("-" * 70)
|
||||
|
||||
|
||||
def _save_equity_plot(results: dict):
|
||||
"""保存权益曲线对比图"""
|
||||
REPORT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
fig, axes = plt.subplots(2, 1, figsize=(14, 10))
|
||||
|
||||
# 上图:权益曲线
|
||||
ax1 = axes[0]
|
||||
name_map = {
|
||||
'stat': '方案A: 统计筛选',
|
||||
'lgb': '方案B-1: LightGBM',
|
||||
'xgb': '方案B-2: XGBoost',
|
||||
}
|
||||
colors = {'stat': '#2196F3', 'lgb': '#4CAF50', 'xgb': '#FF9800'}
|
||||
|
||||
for key, name in name_map.items():
|
||||
if key in results and 'equity_curve' in results[key]:
|
||||
eq = results[key]['equity_curve']
|
||||
if not eq.empty:
|
||||
ax1.plot(eq.index, eq.values, label=name, color=colors.get(key, 'gray'), linewidth=1)
|
||||
|
||||
ax1.set_title('权益曲线对比', fontsize=14)
|
||||
ax1.set_ylabel('资金 (USDT)')
|
||||
ax1.legend()
|
||||
ax1.grid(True, alpha=0.3)
|
||||
|
||||
# 下图:回撤曲线
|
||||
ax2 = axes[1]
|
||||
for key, name in name_map.items():
|
||||
if key in results and 'equity_curve' in results[key]:
|
||||
eq = results[key]['equity_curve']
|
||||
if not eq.empty:
|
||||
cummax = eq.cummax()
|
||||
drawdown = (eq - cummax) / cummax * 100
|
||||
ax2.fill_between(drawdown.index, drawdown.values, 0,
|
||||
alpha=0.3, label=name, color=colors.get(key, 'gray'))
|
||||
|
||||
ax2.set_title('回撤对比', fontsize=14)
|
||||
ax2.set_ylabel('回撤 (%)')
|
||||
ax2.legend()
|
||||
ax2.grid(True, alpha=0.3)
|
||||
|
||||
plt.tight_layout()
|
||||
plot_path = REPORT_DIR / 'strategy_comparison.png'
|
||||
plt.savefig(str(plot_path), dpi=150)
|
||||
plt.close()
|
||||
logger.info(f"对比图表已保存: {plot_path}")
|
||||
|
||||
|
||||
def run_full_comparison(period: int = None, start_date: str = None, end_date: str = None, save_plot: bool = True):
|
||||
"""完整对比入口(可直接调用)"""
|
||||
results = compare_strategies(period, start_date, end_date, save_plot=save_plot)
|
||||
|
||||
# 推荐最优方案
|
||||
best_key = None
|
||||
best_return = -float('inf')
|
||||
|
||||
for key in ['stat', 'lgb', 'xgb']:
|
||||
if key in results and 'metrics' in results[key]:
|
||||
ret_str = results[key]['metrics'].get('总收益率', '0%')
|
||||
ret_val = float(ret_str.strip('%')) / 100
|
||||
if ret_val > best_return:
|
||||
best_return = ret_val
|
||||
best_key = key
|
||||
|
||||
name_map = {'stat': '方案A: 统计筛选', 'lgb': '方案B-1: LightGBM', 'xgb': '方案B-2: XGBoost'}
|
||||
if best_key:
|
||||
logger.info(f"\n{'='*50}")
|
||||
logger.info(f" 推荐方案: {name_map.get(best_key, best_key)}")
|
||||
logger.info(f" 总收益率: {best_return:.2%}")
|
||||
logger.info(f"{'='*50}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
p = argparse.ArgumentParser(description='运行策略对比:方案A(统计) vs 方案B(LightGBM/XGBoost)')
|
||||
p.add_argument('--period', type=int, default=None, help='K线周期分钟')
|
||||
p.add_argument('--start', type=str, default=None, help='开始日期,如 2024-01-01')
|
||||
p.add_argument('--end', type=str, default=None, help='结束日期,如 2024-12-31')
|
||||
p.add_argument('--no-plot', action='store_true', help='不保存权益/回撤图')
|
||||
args = p.parse_args()
|
||||
run_full_comparison(
|
||||
period=args.period,
|
||||
start_date=args.start,
|
||||
end_date=args.end,
|
||||
save_plot=not args.no_plot,
|
||||
)
|
||||
151
strategy/config.py
Normal file
151
strategy/config.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""
|
||||
全局配置 — 交易参数、指标参数、模型参数
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
# ============ 路径 ============
|
||||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
DB_PATH = PROJECT_ROOT / 'models' / 'database.db'
|
||||
|
||||
# ============ 交易参数 ============
|
||||
TRADE_CONFIG = {
|
||||
'symbol': 'ETHUSDT',
|
||||
'commission': 0.0006, # 名义手续费 0.06%
|
||||
'commission_rebate': 0.90, # 90% 返佣(次日 8 点结算),实付 = commission * (1 - rebate)
|
||||
'slippage': 0.0001, # 滑点 0.01%
|
||||
'initial_capital': 10000, # 本金 USDT
|
||||
'leverage': 100, # 杠杆倍数
|
||||
'position_notional_usd': 500, # 每笔名义 500U(开 100 倍),目标月均收益约 500U
|
||||
'max_drawdown_limit': 300, # 最大回撤硬约束:权益从峰值回落超过 300U 则不再开新仓
|
||||
# 兼容旧逻辑(若用比例算仓则用此项)
|
||||
'position_size': 0.95,
|
||||
}
|
||||
|
||||
# ============ K线周期 ============
|
||||
KLINE_PERIODS = {
|
||||
1: '1m',
|
||||
3: '3m',
|
||||
5: '5m',
|
||||
15: '15m',
|
||||
30: '30m',
|
||||
60: '1h',
|
||||
}
|
||||
|
||||
# 主周期(用于生成信号)
|
||||
PRIMARY_PERIOD = 15 # 15分钟
|
||||
# 辅助周期(用于多周期融合特征)
|
||||
AUX_PERIODS = [5, 60]
|
||||
|
||||
# ============ 指标参数 ============
|
||||
INDICATOR_PARAMS = {
|
||||
# 趋势类
|
||||
'sma_windows': [5, 10, 20, 50, 200],
|
||||
'ema_windows': [12, 26],
|
||||
'macd_fast': 12,
|
||||
'macd_slow': 26,
|
||||
'macd_signal': 9,
|
||||
'adx_window': 14,
|
||||
'ichimoku_conversion': 9,
|
||||
'ichimoku_base': 26,
|
||||
'ichimoku_span_b': 52,
|
||||
'trix_window': 15,
|
||||
'aroon_window': 25,
|
||||
'cci_window': 20,
|
||||
'dpo_window': 20,
|
||||
'kst_roc1': 10,
|
||||
'kst_roc2': 15,
|
||||
'kst_roc3': 20,
|
||||
'kst_roc4': 30,
|
||||
'vortex_window': 14,
|
||||
|
||||
# 动量类
|
||||
'rsi_window': 14,
|
||||
'stoch_window': 14,
|
||||
'stoch_smooth': 3,
|
||||
'williams_window': 14,
|
||||
'roc_window': 12,
|
||||
'mfi_window': 14,
|
||||
'tsi_slow': 25,
|
||||
'tsi_fast': 13,
|
||||
'uo_short': 7,
|
||||
'uo_medium': 14,
|
||||
'uo_long': 28,
|
||||
'ao_short': 5,
|
||||
'ao_long': 34,
|
||||
'kama_window': 10,
|
||||
'ppo_slow': 26,
|
||||
'ppo_fast': 12,
|
||||
'stoch_rsi_window': 14,
|
||||
'stoch_rsi_smooth': 3,
|
||||
|
||||
# 波动率类
|
||||
'bb_window': 20,
|
||||
'bb_std': 2,
|
||||
'atr_window': 14,
|
||||
'kc_window': 20,
|
||||
'dc_window': 20,
|
||||
|
||||
# 成交量类(部分指标需要volume,K线数据可能无volume则跳过)
|
||||
'obv_enabled': True,
|
||||
'cmf_window': 20,
|
||||
'emv_window': 14,
|
||||
'fi_window': 13,
|
||||
}
|
||||
|
||||
# ============ 特征工程参数 ============
|
||||
FEATURE_CONFIG = {
|
||||
'label_forward_periods': 10, # 未来N根K线用于生成标签
|
||||
'label_threshold': 0.002, # 涨跌阈值(0.2%以内算震荡)
|
||||
'lookback_lags': [1, 3, 5], # 滞后特征的lag值
|
||||
'normalize': True, # 是否标准化
|
||||
}
|
||||
|
||||
# ============ 模型参数 ============
|
||||
MODEL_CONFIG = {
|
||||
'walk_forward_train_size': 20000, # Walk-Forward 训练窗口大小
|
||||
'walk_forward_test_size': 2000, # Walk-Forward 测试窗口大小
|
||||
'walk_forward_step': 2000, # 滚动步长
|
||||
'warm_up_months': 12, # 指定回测区间时向前加载的月数,使区间首月即有预测
|
||||
|
||||
'lightgbm': {
|
||||
'n_estimators': 300,
|
||||
'max_depth': 4,
|
||||
'learning_rate': 0.03,
|
||||
'num_leaves': 15,
|
||||
'min_child_samples': 50,
|
||||
'subsample': 0.7,
|
||||
'colsample_bytree': 0.6,
|
||||
'reg_alpha': 1.0,
|
||||
'reg_lambda': 1.0,
|
||||
'objective': 'multiclass',
|
||||
'num_class': 3,
|
||||
'verbose': -1,
|
||||
},
|
||||
|
||||
'xgboost': {
|
||||
'n_estimators': 300,
|
||||
'max_depth': 4,
|
||||
'learning_rate': 0.03,
|
||||
'subsample': 0.7,
|
||||
'colsample_bytree': 0.6,
|
||||
'reg_alpha': 1.0,
|
||||
'reg_lambda': 1.0,
|
||||
'objective': 'multi:softprob',
|
||||
'num_class': 3,
|
||||
'verbosity': 0,
|
||||
},
|
||||
}
|
||||
|
||||
# ============ 统计筛选参数 ============
|
||||
STAT_CONFIG = {
|
||||
'top_n_features': 15, # 筛选Top N个指标
|
||||
'correlation_threshold': 0.9, # 去除高相关特征的阈值
|
||||
'grid_search_cv': 3, # 网格搜索交叉验证折数
|
||||
}
|
||||
|
||||
# ============ 信号标签映射 ============
|
||||
SIGNAL_MAP = {
|
||||
0: '观望',
|
||||
1: '做多',
|
||||
2: '做空',
|
||||
}
|
||||
119
strategy/data_loader.py
Normal file
119
strategy/data_loader.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
数据加载器 — 从SQLite加载K线数据为pandas DataFrame
|
||||
"""
|
||||
import sqlite3
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
from .config import DB_PATH, KLINE_PERIODS
|
||||
|
||||
|
||||
def load_kline(period: int = 15, start_date: str = None, end_date: str = None) -> pd.DataFrame:
|
||||
"""
|
||||
加载指定周期的K线数据
|
||||
:param period: K线周期(分钟),如 1, 3, 5, 15, 30, 60
|
||||
:param start_date: 起始日期 'YYYY-MM-DD'(可选)
|
||||
:param end_date: 结束日期 'YYYY-MM-DD'(可选)
|
||||
:return: DataFrame,列: timestamp, open, high, low, close
|
||||
"""
|
||||
suffix = KLINE_PERIODS.get(period)
|
||||
if suffix is None:
|
||||
raise ValueError(f"不支持的周期: {period},可选: {list(KLINE_PERIODS.keys())}")
|
||||
|
||||
table_name = f'bitmart_eth_{suffix}'
|
||||
conn = sqlite3.connect(str(DB_PATH))
|
||||
|
||||
query = f"SELECT id as timestamp, open, high, low, close FROM {table_name} ORDER BY id"
|
||||
df = pd.read_sql_query(query, conn)
|
||||
conn.close()
|
||||
|
||||
if df.empty:
|
||||
logger.warning(f"[{suffix}] 表中无数据")
|
||||
return df
|
||||
|
||||
# id 是毫秒时间戳,转为 datetime 索引
|
||||
df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms')
|
||||
df.set_index('datetime', inplace=True)
|
||||
|
||||
# 按日期过滤
|
||||
if start_date:
|
||||
df = df[df.index >= start_date]
|
||||
if end_date:
|
||||
df = df[df.index <= end_date]
|
||||
|
||||
logger.info(f"[{suffix}] 加载 {len(df)} 条K线 | {df.index[0]} ~ {df.index[-1]}")
|
||||
return df
|
||||
|
||||
|
||||
def load_multi_period(periods: list = None, start_date: str = None, end_date: str = None) -> dict:
|
||||
"""
|
||||
加载多个周期的K线数据
|
||||
:param periods: 周期列表,如 [5, 15, 60],默认全部
|
||||
:param start_date: 起始日期
|
||||
:param end_date: 结束日期
|
||||
:return: {period: DataFrame} 字典
|
||||
"""
|
||||
if periods is None:
|
||||
periods = list(KLINE_PERIODS.keys())
|
||||
|
||||
result = {}
|
||||
for p in periods:
|
||||
try:
|
||||
df = load_kline(p, start_date, end_date)
|
||||
if not df.empty:
|
||||
result[p] = df
|
||||
except Exception as e:
|
||||
logger.error(f"加载 {p}分钟 K线失败: {e}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def load_trades(start_date: str = None, end_date: str = None, limit: int = None) -> pd.DataFrame:
|
||||
"""
|
||||
加载原始成交记录
|
||||
:return: DataFrame,列: id, timestamp, price, volume, side
|
||||
"""
|
||||
conn = sqlite3.connect(str(DB_PATH))
|
||||
query = "SELECT id, timestamp, price, volume, side FROM bitmart_eth_trades ORDER BY timestamp"
|
||||
df = pd.read_sql_query(query, conn)
|
||||
conn.close()
|
||||
|
||||
if df.empty:
|
||||
logger.warning("成交记录表中无数据")
|
||||
return df
|
||||
|
||||
df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms')
|
||||
df.set_index('datetime', inplace=True)
|
||||
|
||||
if start_date:
|
||||
df = df[df.index >= start_date]
|
||||
if end_date:
|
||||
df = df[df.index <= end_date]
|
||||
if limit:
|
||||
df = df.head(limit)
|
||||
|
||||
logger.info(f"加载 {len(df)} 条成交记录")
|
||||
return df
|
||||
|
||||
|
||||
def get_available_tables() -> list:
|
||||
"""列出数据库中所有可用的表"""
|
||||
conn = sqlite3.connect(str(DB_PATH))
|
||||
cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
|
||||
tables = [row[0] for row in cursor.fetchall()]
|
||||
conn.close()
|
||||
return tables
|
||||
|
||||
|
||||
def get_table_stats() -> dict:
|
||||
"""获取各表的数据统计"""
|
||||
conn = sqlite3.connect(str(DB_PATH))
|
||||
tables = get_available_tables()
|
||||
stats = {}
|
||||
for table in tables:
|
||||
try:
|
||||
count = pd.read_sql_query(f"SELECT COUNT(*) as cnt FROM {table}", conn).iloc[0]['cnt']
|
||||
stats[table] = count
|
||||
except Exception:
|
||||
stats[table] = 0
|
||||
conn.close()
|
||||
return stats
|
||||
204
strategy/feature_engine.py
Normal file
204
strategy/feature_engine.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
特征工程 — 标准化、多周期融合、滞后特征、标签生成
|
||||
"""
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
from loguru import logger
|
||||
|
||||
from .config import FEATURE_CONFIG as FC, PRIMARY_PERIOD, AUX_PERIODS
|
||||
from .data_loader import load_kline
|
||||
from .indicators import compute_all_indicators, get_indicator_names
|
||||
|
||||
|
||||
def build_features(primary_df: pd.DataFrame,
|
||||
aux_dfs: dict = None,
|
||||
has_volume: bool = False) -> pd.DataFrame:
|
||||
"""
|
||||
构建完整特征矩阵
|
||||
:param primary_df: 主周期K线 DataFrame
|
||||
:param aux_dfs: {period: DataFrame} 辅助周期数据(可选)
|
||||
:param has_volume: 是否有成交量
|
||||
:return: 特征矩阵 DataFrame
|
||||
"""
|
||||
# 1. 主周期指标
|
||||
logger.info("计算主周期指标...")
|
||||
df = compute_all_indicators(primary_df, has_volume=has_volume)
|
||||
|
||||
# 2. 滞后特征
|
||||
logger.info("生成滞后特征...")
|
||||
indicator_cols = get_indicator_names(has_volume)
|
||||
existing_cols = [col for col in indicator_cols if col in df.columns]
|
||||
|
||||
lag_frames = []
|
||||
for lag in FC['lookback_lags']:
|
||||
lagged = df[existing_cols].shift(lag).add_suffix(f'_lag{lag}')
|
||||
lag_frames.append(lagged)
|
||||
if lag_frames:
|
||||
df = pd.concat([df] + lag_frames, axis=1)
|
||||
|
||||
# 3. 多周期融合
|
||||
if aux_dfs:
|
||||
aux_frames = []
|
||||
for period, aux_df in aux_dfs.items():
|
||||
logger.info(f"融合 {period}分钟 辅助周期特征...")
|
||||
aux_with_ind = compute_all_indicators(aux_df, has_volume=has_volume)
|
||||
key_indicators = ['rsi', 'macd', 'adx', 'bb_pband', 'atr', 'cci']
|
||||
for ind in key_indicators:
|
||||
if ind in aux_with_ind.columns:
|
||||
aligned = aux_with_ind[ind].reindex(df.index, method='ffill')
|
||||
aux_frames.append(aligned.rename(f'{ind}_{period}m'))
|
||||
if aux_frames:
|
||||
df = pd.concat([df] + aux_frames, axis=1)
|
||||
|
||||
# 4. 去除全NaN列
|
||||
before_cols = len(df.columns)
|
||||
df.dropna(axis=1, how='all', inplace=True)
|
||||
after_cols = len(df.columns)
|
||||
if before_cols != after_cols:
|
||||
logger.info(f"移除 {before_cols - after_cols} 个全NaN列")
|
||||
|
||||
logger.info(f"特征矩阵: {df.shape[0]} 行 x {df.shape[1]} 列")
|
||||
return df
|
||||
|
||||
|
||||
def generate_labels(df: pd.DataFrame, forward_periods: int = None,
|
||||
threshold: float = None) -> pd.Series:
|
||||
"""
|
||||
生成交易标签
|
||||
:param df: 包含 close 列的 DataFrame
|
||||
:param forward_periods: 未来N根K线
|
||||
:param threshold: 涨跌阈值
|
||||
:return: Series,值为 0=观望, 1=做多, 2=做空
|
||||
"""
|
||||
if forward_periods is None:
|
||||
forward_periods = FC['label_forward_periods']
|
||||
if threshold is None:
|
||||
threshold = FC['label_threshold']
|
||||
|
||||
future_return = df['close'].shift(-forward_periods) / df['close'] - 1
|
||||
|
||||
labels = pd.Series(0, index=df.index, name='label') # 默认观望
|
||||
labels[future_return > threshold] = 1 # 做多
|
||||
labels[future_return < -threshold] = 2 # 做空
|
||||
|
||||
# 最后 forward_periods 行无法计算标签,设为NaN
|
||||
labels.iloc[-forward_periods:] = np.nan
|
||||
|
||||
dist = labels.value_counts().to_dict()
|
||||
logger.info(f"标签分布: 观望={dist.get(0, 0)}, 做多={dist.get(1, 0)}, 做空={dist.get(2, 0)}")
|
||||
return labels
|
||||
|
||||
|
||||
def prepare_dataset(period: int = None, start_date: str = None, end_date: str = None,
|
||||
has_volume: bool = False) -> tuple:
|
||||
"""
|
||||
一键准备训练数据集
|
||||
:return: (X, y, feature_names) — 已去NaN、已标准化
|
||||
"""
|
||||
if period is None:
|
||||
period = PRIMARY_PERIOD
|
||||
|
||||
# 加载主周期
|
||||
primary_df = load_kline(period, start_date, end_date)
|
||||
if primary_df.empty:
|
||||
raise ValueError(f"{period}分钟 K线数据为空")
|
||||
|
||||
# 加载辅助周期
|
||||
aux_dfs = {}
|
||||
for aux_p in AUX_PERIODS:
|
||||
try:
|
||||
aux_df = load_kline(aux_p, start_date, end_date)
|
||||
if not aux_df.empty:
|
||||
aux_dfs[aux_p] = aux_df
|
||||
except Exception as e:
|
||||
logger.warning(f"加载 {aux_p}分钟 辅助数据失败: {e}")
|
||||
|
||||
# 构建特征
|
||||
df = build_features(primary_df, aux_dfs, has_volume=has_volume)
|
||||
|
||||
# 生成标签
|
||||
labels = generate_labels(df)
|
||||
df = df.copy()
|
||||
df['label'] = labels
|
||||
|
||||
# 去除NaN行
|
||||
df.dropna(inplace=True)
|
||||
logger.info(f"去NaN后剩余 {len(df)} 行")
|
||||
|
||||
# 分离 X, y
|
||||
exclude_cols = ['open', 'high', 'low', 'close', 'timestamp', 'label']
|
||||
if 'volume' in df.columns:
|
||||
exclude_cols.append('volume')
|
||||
|
||||
# 排除价格级别特征(会泄露绝对价格信息导致过拟合)
|
||||
price_level_patterns = [
|
||||
'sma_', 'ema_', 'bb_upper', 'bb_mid', 'bb_lower',
|
||||
'kc_upper', 'kc_mid', 'kc_lower', 'dc_upper', 'dc_mid', 'dc_lower',
|
||||
'ichimoku_conv', 'ichimoku_base', 'ichimoku_a', 'ichimoku_b',
|
||||
'kama', 'vwap', 'obv', 'adi', 'vpt', 'nvi',
|
||||
'momentum_3', 'momentum_5',
|
||||
]
|
||||
feature_cols = []
|
||||
for c in df.columns:
|
||||
if c in exclude_cols:
|
||||
continue
|
||||
base_name = c.split('_lag')[0]
|
||||
# 去掉周期后缀
|
||||
for suffix in ['_5m', '_60m', '_3m', '_15m', '_30m', '_1m']:
|
||||
if base_name.endswith(suffix):
|
||||
base_name = base_name[:-len(suffix)]
|
||||
break
|
||||
if any(base_name.startswith(p) or base_name == p.rstrip('_') for p in price_level_patterns):
|
||||
continue
|
||||
feature_cols.append(c)
|
||||
|
||||
logger.info(f"排除价格级别特征后剩余 {len(feature_cols)} 个特征")
|
||||
|
||||
X = df[feature_cols].copy()
|
||||
y = df['label'].astype(int)
|
||||
|
||||
# 标准化
|
||||
scaler = None
|
||||
if FC['normalize']:
|
||||
scaler = StandardScaler()
|
||||
X = pd.DataFrame(scaler.fit_transform(X), columns=feature_cols, index=X.index)
|
||||
logger.info("特征已标准化")
|
||||
|
||||
logger.info(f"最终数据集: X={X.shape}, y={y.shape}, 特征数={len(feature_cols)}")
|
||||
return X, y, feature_cols, scaler
|
||||
|
||||
|
||||
def get_latest_feature_row(period: int = None, feature_cols: list = None,
|
||||
start_date: str = None, end_date: str = None) -> pd.DataFrame:
|
||||
"""
|
||||
构建特征并返回最后一行的特征矩阵(用于实盘/模拟盘预测)。
|
||||
需传入已保存的 feature_cols,保证与训练时一致。
|
||||
:return: 1 行 DataFrame,列为 feature_cols;若缺列则返回空 DataFrame
|
||||
"""
|
||||
if period is None:
|
||||
period = PRIMARY_PERIOD
|
||||
if not feature_cols:
|
||||
return pd.DataFrame()
|
||||
|
||||
primary_df = load_kline(period, start_date, end_date)
|
||||
if primary_df.empty:
|
||||
return pd.DataFrame()
|
||||
aux_dfs = {}
|
||||
for aux_p in AUX_PERIODS:
|
||||
try:
|
||||
aux_df = load_kline(aux_p, start_date, end_date)
|
||||
if not aux_df.empty:
|
||||
aux_dfs[aux_p] = aux_df
|
||||
except Exception:
|
||||
pass
|
||||
df = build_features(primary_df, aux_dfs, has_volume=False)
|
||||
labels = generate_labels(df)
|
||||
df = df.copy()
|
||||
df['label'] = labels
|
||||
df.dropna(inplace=True)
|
||||
missing = [c for c in feature_cols if c not in df.columns]
|
||||
if missing:
|
||||
logger.warning(f"get_latest_feature_row: 缺少特征列 {missing[:5]}{'...' if len(missing) > 5 else ''}")
|
||||
return pd.DataFrame()
|
||||
return df[feature_cols].iloc[-1:].copy()
|
||||
278
strategy/indicators.py
Normal file
278
strategy/indicators.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""
|
||||
52个技术指标计算引擎 — 基于 ta 库
|
||||
覆盖趋势、动量、波动率、成交量、自定义衍生特征五大类
|
||||
"""
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import ta
|
||||
from .config import INDICATOR_PARAMS as P
|
||||
|
||||
|
||||
def compute_all_indicators(df: pd.DataFrame, has_volume: bool = False) -> pd.DataFrame:
|
||||
"""
|
||||
计算全部52个技术指标,返回拼接后的DataFrame
|
||||
:param df: 必须包含 open, high, low, close 列;可选 volume 列
|
||||
:param has_volume: 是否有成交量数据
|
||||
:return: 原始列 + 52个指标列
|
||||
"""
|
||||
out = df.copy()
|
||||
o, h, l, c = out['open'], out['high'], out['low'], out['close']
|
||||
v = out['volume'] if has_volume and 'volume' in out.columns else None
|
||||
|
||||
# ========== 趋势类 (14) ==========
|
||||
out = _add_trend(out, o, h, l, c)
|
||||
|
||||
# ========== 动量类 (12) ==========
|
||||
out = _add_momentum(out, h, l, c, v)
|
||||
|
||||
# ========== 波动率类 (8) ==========
|
||||
out = _add_volatility(out, h, l, c)
|
||||
|
||||
# ========== 成交量类 (8) ==========
|
||||
if has_volume and v is not None:
|
||||
out = _add_volume(out, h, l, c, v)
|
||||
|
||||
# ========== 自定义衍生特征 (10) ==========
|
||||
out = _add_custom(out, o, h, l, c)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _add_trend(out, o, h, l, c):
|
||||
"""趋势类指标 (14个特征)"""
|
||||
# SMA (5个)
|
||||
for w in P['sma_windows']:
|
||||
out[f'sma_{w}'] = ta.trend.sma_indicator(c, window=w)
|
||||
|
||||
# EMA (2个)
|
||||
for w in P['ema_windows']:
|
||||
out[f'ema_{w}'] = ta.trend.ema_indicator(c, window=w)
|
||||
|
||||
# MACD (3个)
|
||||
macd = ta.trend.MACD(c, window_slow=P['macd_slow'], window_fast=P['macd_fast'],
|
||||
window_sign=P['macd_signal'])
|
||||
out['macd'] = macd.macd()
|
||||
out['macd_signal'] = macd.macd_signal()
|
||||
out['macd_hist'] = macd.macd_diff()
|
||||
|
||||
# ADX + DI (3个)
|
||||
adx = ta.trend.ADXIndicator(h, l, c, window=P['adx_window'])
|
||||
out['adx'] = adx.adx()
|
||||
out['di_plus'] = adx.adx_pos()
|
||||
out['di_minus'] = adx.adx_neg()
|
||||
|
||||
# Ichimoku (4个)
|
||||
ichi = ta.trend.IchimokuIndicator(h, l,
|
||||
window1=P['ichimoku_conversion'],
|
||||
window2=P['ichimoku_base'],
|
||||
window3=P['ichimoku_span_b'])
|
||||
out['ichimoku_conv'] = ichi.ichimoku_conversion_line()
|
||||
out['ichimoku_base'] = ichi.ichimoku_base_line()
|
||||
out['ichimoku_a'] = ichi.ichimoku_a()
|
||||
out['ichimoku_b'] = ichi.ichimoku_b()
|
||||
|
||||
# TRIX
|
||||
out['trix'] = ta.trend.trix(c, window=P['trix_window'])
|
||||
|
||||
# Aroon (2个)
|
||||
aroon = ta.trend.AroonIndicator(h, l, window=P['aroon_window'])
|
||||
out['aroon_up'] = aroon.aroon_up()
|
||||
out['aroon_down'] = aroon.aroon_down()
|
||||
|
||||
# CCI
|
||||
out['cci'] = ta.trend.cci(h, l, c, window=P['cci_window'])
|
||||
|
||||
# DPO
|
||||
out['dpo'] = ta.trend.dpo(c, window=P['dpo_window'])
|
||||
|
||||
# KST
|
||||
kst = ta.trend.KSTIndicator(c, roc1=P['kst_roc1'], roc2=P['kst_roc2'],
|
||||
roc3=P['kst_roc3'], roc4=P['kst_roc4'])
|
||||
out['kst'] = kst.kst()
|
||||
|
||||
# Vortex (2个)
|
||||
vortex = ta.trend.VortexIndicator(h, l, c, window=P['vortex_window'])
|
||||
out['vortex_pos'] = vortex.vortex_indicator_pos()
|
||||
out['vortex_neg'] = vortex.vortex_indicator_neg()
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _add_momentum(out, h, l, c, v):
|
||||
"""动量类指标 (12个特征)"""
|
||||
# RSI
|
||||
out['rsi'] = ta.momentum.rsi(c, window=P['rsi_window'])
|
||||
|
||||
# Stochastic %K / %D
|
||||
stoch = ta.momentum.StochasticOscillator(h, l, c,
|
||||
window=P['stoch_window'],
|
||||
smooth_window=P['stoch_smooth'])
|
||||
out['stoch_k'] = stoch.stoch()
|
||||
out['stoch_d'] = stoch.stoch_signal()
|
||||
|
||||
# Williams %R
|
||||
out['williams_r'] = ta.momentum.williams_r(h, l, c, lbp=P['williams_window'])
|
||||
|
||||
# ROC
|
||||
out['roc'] = ta.momentum.roc(c, window=P['roc_window'])
|
||||
|
||||
# MFI(需要volume)
|
||||
if v is not None:
|
||||
out['mfi'] = ta.volume.money_flow_index(h, l, c, v, window=P['mfi_window'])
|
||||
|
||||
# TSI
|
||||
out['tsi'] = ta.momentum.tsi(c, window_slow=P['tsi_slow'], window_fast=P['tsi_fast'])
|
||||
|
||||
# Ultimate Oscillator
|
||||
out['uo'] = ta.momentum.ultimate_oscillator(h, l, c,
|
||||
window1=P['uo_short'],
|
||||
window2=P['uo_medium'],
|
||||
window3=P['uo_long'])
|
||||
|
||||
# Awesome Oscillator
|
||||
out['ao'] = ta.momentum.awesome_oscillator(h, l,
|
||||
window1=P['ao_short'],
|
||||
window2=P['ao_long'])
|
||||
|
||||
# KAMA
|
||||
out['kama'] = ta.momentum.kama(c, window=P['kama_window'])
|
||||
|
||||
# PPO
|
||||
out['ppo'] = ta.momentum.ppo(c, window_slow=P['ppo_slow'], window_fast=P['ppo_fast'])
|
||||
|
||||
# Stochastic RSI %K / %D
|
||||
stoch_rsi = ta.momentum.StochRSIIndicator(c,
|
||||
window=P['stoch_rsi_window'],
|
||||
smooth1=P['stoch_rsi_smooth'],
|
||||
smooth2=P['stoch_rsi_smooth'])
|
||||
out['stoch_rsi_k'] = stoch_rsi.stochrsi_k()
|
||||
out['stoch_rsi_d'] = stoch_rsi.stochrsi_d()
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _add_volatility(out, h, l, c):
|
||||
"""波动率类指标 (8个特征 — 含子指标共12列)"""
|
||||
# Bollinger Bands (5个)
|
||||
bb = ta.volatility.BollingerBands(c, window=P['bb_window'], window_dev=P['bb_std'])
|
||||
out['bb_upper'] = bb.bollinger_hband()
|
||||
out['bb_mid'] = bb.bollinger_mavg()
|
||||
out['bb_lower'] = bb.bollinger_lband()
|
||||
out['bb_width'] = bb.bollinger_wband()
|
||||
out['bb_pband'] = bb.bollinger_pband()
|
||||
|
||||
# ATR
|
||||
out['atr'] = ta.volatility.average_true_range(h, l, c, window=P['atr_window'])
|
||||
|
||||
# Keltner Channel (3个)
|
||||
kc = ta.volatility.KeltnerChannel(h, l, c, window=P['kc_window'])
|
||||
out['kc_upper'] = kc.keltner_channel_hband()
|
||||
out['kc_mid'] = kc.keltner_channel_mband()
|
||||
out['kc_lower'] = kc.keltner_channel_lband()
|
||||
|
||||
# Donchian Channel (3个)
|
||||
dc = ta.volatility.DonchianChannel(h, l, c, window=P['dc_window'])
|
||||
out['dc_upper'] = dc.donchian_channel_hband()
|
||||
out['dc_mid'] = dc.donchian_channel_mband()
|
||||
out['dc_lower'] = dc.donchian_channel_lband()
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _add_volume(out, h, l, c, v):
|
||||
"""成交量类指标 (8个特征)"""
|
||||
# OBV
|
||||
out['obv'] = ta.volume.on_balance_volume(c, v)
|
||||
|
||||
# VWAP
|
||||
out['vwap'] = ta.volume.volume_weighted_average_price(h, l, c, v)
|
||||
|
||||
# CMF
|
||||
out['cmf'] = ta.volume.chaikin_money_flow(h, l, c, v, window=P['cmf_window'])
|
||||
|
||||
# ADI (Accumulation/Distribution Index)
|
||||
out['adi'] = ta.volume.acc_dist_index(h, l, c, v)
|
||||
|
||||
# EMV (Ease of Movement)
|
||||
out['emv'] = ta.volume.ease_of_movement(h, l, v, window=P['emv_window'])
|
||||
|
||||
# Force Index
|
||||
out['fi'] = ta.volume.force_index(c, v, window=P['fi_window'])
|
||||
|
||||
# VPT (Volume Price Trend)
|
||||
out['vpt'] = ta.volume.volume_price_trend(c, v)
|
||||
|
||||
# NVI (Negative Volume Index)
|
||||
out['nvi'] = ta.volume.negative_volume_index(c, v)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _add_custom(out, o, h, l, c):
|
||||
"""自定义衍生特征 (10个)"""
|
||||
# 价格变化率
|
||||
out['price_change_pct'] = c.pct_change()
|
||||
|
||||
# 振幅(High-Low范围 / Close)
|
||||
out['high_low_range'] = (h - l) / c
|
||||
|
||||
# 实体比率(|Close-Open| / (High-Low))
|
||||
body = (c - o).abs()
|
||||
hl_range = (h - l).replace(0, np.nan)
|
||||
out['body_ratio'] = body / hl_range
|
||||
|
||||
# 上影线比率
|
||||
upper_shadow = h - pd.concat([o, c], axis=1).max(axis=1)
|
||||
out['upper_shadow'] = upper_shadow / hl_range
|
||||
|
||||
# 下影线比率
|
||||
lower_shadow = pd.concat([o, c], axis=1).min(axis=1) - l
|
||||
out['lower_shadow'] = lower_shadow / hl_range
|
||||
|
||||
# 波动率比率(ATR / Close 的滚动比值)
|
||||
atr = ta.volatility.average_true_range(h, l, c, window=14)
|
||||
out['volatility_ratio'] = atr / c
|
||||
|
||||
# Close / SMA20 比率
|
||||
sma20 = ta.trend.sma_indicator(c, window=20)
|
||||
out['close_sma20_ratio'] = c / sma20.replace(0, np.nan)
|
||||
|
||||
# Close / EMA12 比率
|
||||
ema12 = ta.trend.ema_indicator(c, window=12)
|
||||
out['close_ema12_ratio'] = c / ema12.replace(0, np.nan)
|
||||
|
||||
# 动量 3周期
|
||||
out['momentum_3'] = c - c.shift(3)
|
||||
|
||||
# 动量 5周期
|
||||
out['momentum_5'] = c - c.shift(5)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def get_indicator_names(has_volume: bool = False) -> list:
|
||||
"""返回所有指标列名"""
|
||||
names = []
|
||||
# 趋势
|
||||
for w in P['sma_windows']:
|
||||
names.append(f'sma_{w}')
|
||||
for w in P['ema_windows']:
|
||||
names.append(f'ema_{w}')
|
||||
names += ['macd', 'macd_signal', 'macd_hist', 'adx', 'di_plus', 'di_minus']
|
||||
names += ['ichimoku_conv', 'ichimoku_base', 'ichimoku_a', 'ichimoku_b']
|
||||
names += ['trix', 'aroon_up', 'aroon_down', 'cci', 'dpo', 'kst', 'vortex_pos', 'vortex_neg']
|
||||
# 动量
|
||||
names += ['rsi', 'stoch_k', 'stoch_d', 'williams_r', 'roc', 'tsi', 'uo', 'ao', 'kama', 'ppo',
|
||||
'stoch_rsi_k', 'stoch_rsi_d']
|
||||
if has_volume:
|
||||
names.append('mfi')
|
||||
# 波动率
|
||||
names += ['bb_upper', 'bb_mid', 'bb_lower', 'bb_width', 'bb_pband', 'atr',
|
||||
'kc_upper', 'kc_mid', 'kc_lower', 'dc_upper', 'dc_mid', 'dc_lower']
|
||||
# 成交量
|
||||
if has_volume:
|
||||
names += ['obv', 'vwap', 'cmf', 'adi', 'emv', 'fi', 'vpt', 'nvi']
|
||||
# 自定义
|
||||
names += ['price_change_pct', 'high_low_range', 'body_ratio', 'upper_shadow', 'lower_shadow',
|
||||
'volatility_ratio', 'close_sma20_ratio', 'close_ema12_ratio', 'momentum_3', 'momentum_5']
|
||||
return names
|
||||
256
strategy/stat_strategy.py
Normal file
256
strategy/stat_strategy.py
Normal file
@@ -0,0 +1,256 @@
|
||||
"""
|
||||
方案A:统计筛选 + 规则组合策略
|
||||
1. 从52个指标中用统计方法筛选最有效的指标
|
||||
2. 用经典规则组合生成交易信号
|
||||
3. 网格搜索优化参数
|
||||
"""
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.feature_selection import mutual_info_classif
|
||||
from loguru import logger
|
||||
|
||||
from .config import STAT_CONFIG as SC, PRIMARY_PERIOD, AUX_PERIODS
|
||||
from .feature_engine import prepare_dataset
|
||||
from .backtest import BacktestEngine, print_metrics
|
||||
|
||||
|
||||
class StatStrategy:
|
||||
"""统计筛选策略"""
|
||||
|
||||
def __init__(self):
|
||||
self.top_features = []
|
||||
self.feature_scores = {}
|
||||
self.best_params = {}
|
||||
|
||||
def select_features(self, X: pd.DataFrame, y: pd.Series) -> list:
|
||||
"""
|
||||
用多种统计方法筛选有效指标
|
||||
:return: Top N 特征名列表
|
||||
"""
|
||||
logger.info("=" * 50)
|
||||
logger.info("开始特征筛选...")
|
||||
scores = {}
|
||||
|
||||
# 1. 皮尔逊相关系数
|
||||
logger.info("计算皮尔逊相关系数...")
|
||||
corr_scores = X.corrwith(y).abs().fillna(0)
|
||||
for col in X.columns:
|
||||
scores[col] = scores.get(col, 0) + corr_scores.get(col, 0)
|
||||
|
||||
# 2. 互信息
|
||||
logger.info("计算互信息...")
|
||||
mi = mutual_info_classif(X.fillna(0), y, random_state=42)
|
||||
mi_series = pd.Series(mi, index=X.columns)
|
||||
mi_norm = mi_series / mi_series.max() if mi_series.max() > 0 else mi_series
|
||||
for col in X.columns:
|
||||
scores[col] = scores.get(col, 0) + mi_norm.get(col, 0)
|
||||
|
||||
# 3. 随机森林特征重要性
|
||||
logger.info("训练随机森林评估特征重要性...")
|
||||
rf = RandomForestClassifier(n_estimators=200, max_depth=8, random_state=42, n_jobs=-1)
|
||||
rf.fit(X.fillna(0), y)
|
||||
rf_imp = pd.Series(rf.feature_importances_, index=X.columns)
|
||||
rf_norm = rf_imp / rf_imp.max() if rf_imp.max() > 0 else rf_imp
|
||||
for col in X.columns:
|
||||
scores[col] = scores.get(col, 0) + rf_norm.get(col, 0)
|
||||
|
||||
# 综合排名
|
||||
score_series = pd.Series(scores).sort_values(ascending=False)
|
||||
self.feature_scores = score_series.to_dict()
|
||||
|
||||
# 去除高相关特征
|
||||
top_candidates = score_series.head(SC['top_n_features'] * 2).index.tolist()
|
||||
selected = self._remove_correlated(X[top_candidates], SC['correlation_threshold'])
|
||||
self.top_features = selected[:SC['top_n_features']]
|
||||
|
||||
logger.info(f"筛选出 Top {len(self.top_features)} 特征:")
|
||||
for i, feat in enumerate(self.top_features):
|
||||
logger.info(f" {i+1}. {feat} (综合得分: {score_series[feat]:.4f})")
|
||||
|
||||
return self.top_features
|
||||
|
||||
def _remove_correlated(self, X: pd.DataFrame, threshold: float) -> list:
|
||||
"""去除高度相关的冗余特征"""
|
||||
corr_matrix = X.corr().abs()
|
||||
selected = list(X.columns)
|
||||
to_remove = set()
|
||||
|
||||
for i in range(len(selected)):
|
||||
if selected[i] in to_remove:
|
||||
continue
|
||||
for j in range(i + 1, len(selected)):
|
||||
if selected[j] in to_remove:
|
||||
continue
|
||||
if corr_matrix.loc[selected[i], selected[j]] > threshold:
|
||||
to_remove.add(selected[j])
|
||||
|
||||
result = [c for c in selected if c not in to_remove]
|
||||
if to_remove:
|
||||
logger.info(f"移除 {len(to_remove)} 个高相关冗余特征")
|
||||
return result
|
||||
|
||||
def generate_signals(self, df: pd.DataFrame) -> pd.Series:
|
||||
"""
|
||||
基于筛选出的指标,用规则组合生成交易信号
|
||||
:param df: 包含指标列的 DataFrame(原始值,非标准化)
|
||||
:return: 信号 Series (0=观望, 1=做多, 2=做空)
|
||||
"""
|
||||
signals = pd.Series(0, index=df.index)
|
||||
long_score = pd.Series(0.0, index=df.index)
|
||||
short_score = pd.Series(0.0, index=df.index)
|
||||
matched = 0
|
||||
|
||||
for feat in self.top_features:
|
||||
if feat not in df.columns:
|
||||
continue
|
||||
|
||||
col = df[feat]
|
||||
base = feat.split('_lag')[0] # 去掉 _lagN 后缀
|
||||
# 去掉辅助周期后缀 _5m / _60m
|
||||
for suffix in ['_5m', '_60m', '_3m', '_15m', '_30m', '_1m']:
|
||||
if base.endswith(suffix):
|
||||
base = base[:-len(suffix)]
|
||||
break
|
||||
|
||||
if 'rsi' in base:
|
||||
long_score += (col < 35).astype(float)
|
||||
short_score += (col > 65).astype(float)
|
||||
matched += 1
|
||||
elif base == 'macd_hist':
|
||||
long_score += (col > 0).astype(float)
|
||||
short_score += (col < 0).astype(float)
|
||||
matched += 1
|
||||
elif base == 'macd':
|
||||
long_score += (col > 0).astype(float)
|
||||
short_score += (col < 0).astype(float)
|
||||
matched += 1
|
||||
elif 'bb_pband' in base:
|
||||
long_score += (col < 0.2).astype(float)
|
||||
short_score += (col > 0.8).astype(float)
|
||||
matched += 1
|
||||
elif 'adx' in base:
|
||||
long_score += (col > 25).astype(float)
|
||||
short_score += (col > 25).astype(float)
|
||||
matched += 1
|
||||
elif 'cci' in base:
|
||||
long_score += (col < -100).astype(float)
|
||||
short_score += (col > 100).astype(float)
|
||||
matched += 1
|
||||
elif 'stoch_k' in base or 'stoch_rsi_k' in base:
|
||||
long_score += (col < 25).astype(float)
|
||||
short_score += (col > 75).astype(float)
|
||||
matched += 1
|
||||
elif 'williams_r' in base:
|
||||
long_score += (col < -80).astype(float)
|
||||
short_score += (col > -20).astype(float)
|
||||
matched += 1
|
||||
elif 'ao' in base or 'tsi' in base or 'roc' in base or 'ppo' in base:
|
||||
long_score += (col > 0).astype(float)
|
||||
short_score += (col < 0).astype(float)
|
||||
matched += 1
|
||||
elif 'atr' in base or 'volatility_ratio' in base:
|
||||
# 波动率类:高波动时趋势更强,用均值分界
|
||||
median = col.rolling(200, min_periods=50).median()
|
||||
long_score += (col > median).astype(float) * 0.5
|
||||
short_score += (col > median).astype(float) * 0.5
|
||||
matched += 1
|
||||
elif 'high_low_range' in base or 'body_ratio' in base:
|
||||
median = col.rolling(200, min_periods=50).median()
|
||||
long_score += (col > median).astype(float) * 0.3
|
||||
short_score += (col > median).astype(float) * 0.3
|
||||
matched += 1
|
||||
elif 'bb_width' in base:
|
||||
# 布林带宽度收窄后扩张 = 突破信号,结合价格方向
|
||||
median = col.rolling(200, min_periods=50).median()
|
||||
expanding = col > col.shift(1) # 宽度在扩张
|
||||
was_narrow = col.shift(1) < median # 之前是收窄的
|
||||
breakout = expanding & was_narrow
|
||||
if 'close' in df.columns:
|
||||
price_up = df['close'] > df['close'].shift(1)
|
||||
long_score += (breakout & price_up).astype(float)
|
||||
short_score += (breakout & ~price_up).astype(float)
|
||||
else:
|
||||
long_score += breakout.astype(float) * 0.5
|
||||
short_score += breakout.astype(float) * 0.5
|
||||
matched += 1
|
||||
elif 'close_sma20_ratio' in base or 'close_ema12_ratio' in base:
|
||||
# 价格在均线上方=多头,下方=空头
|
||||
long_score += (col > 1.0).astype(float)
|
||||
short_score += (col < 1.0).astype(float)
|
||||
matched += 1
|
||||
elif 'ichimoku' in base:
|
||||
if 'close' in df.columns:
|
||||
long_score += (df['close'] > col).astype(float)
|
||||
short_score += (df['close'] < col).astype(float)
|
||||
matched += 1
|
||||
elif 'momentum' in base or 'price_change' in base:
|
||||
long_score += (col > 0).astype(float)
|
||||
short_score += (col < 0).astype(float)
|
||||
matched += 1
|
||||
|
||||
logger.info(f"规则匹配: {matched}/{len(self.top_features)} 个特征有对应规则")
|
||||
|
||||
# 阈值:至少50%的匹配特征同时确认(更严格)
|
||||
threshold = max(3, matched * 0.5)
|
||||
logger.info(f"信号阈值: {threshold:.1f} (需要至少这么多指标同时确认)")
|
||||
signals[long_score >= threshold] = 1
|
||||
signals[short_score >= threshold] = 2
|
||||
# 多空同时满足时取更强的
|
||||
both = (long_score >= threshold) & (short_score >= threshold)
|
||||
signals[both & (long_score > short_score)] = 1
|
||||
signals[both & (short_score > long_score)] = 2
|
||||
signals[both & (long_score == short_score)] = 0
|
||||
|
||||
dist = signals.value_counts().to_dict()
|
||||
logger.info(f"规则信号分布: 观望={dist.get(0, 0)}, 做多={dist.get(1, 0)}, 做空={dist.get(2, 0)}")
|
||||
return signals
|
||||
|
||||
def run(self, period: int = None, start_date: str = None, end_date: str = None) -> dict:
|
||||
"""
|
||||
完整运行方案A
|
||||
:return: 回测结果
|
||||
"""
|
||||
if period is None:
|
||||
period = PRIMARY_PERIOD
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("方案A:统计筛选 + 规则组合策略")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# 1. 准备数据(标准化版本,用于特征筛选)
|
||||
X, y, feature_names, _ = prepare_dataset(period, start_date, end_date)
|
||||
|
||||
# 2. 筛选特征
|
||||
self.select_features(X, y)
|
||||
|
||||
# 3. 构建完整特征矩阵(原始值,非标准化,用于规则判断)
|
||||
from .data_loader import load_kline, load_multi_period
|
||||
from .feature_engine import build_features
|
||||
primary_df = load_kline(period, start_date, end_date)
|
||||
aux_dfs = {}
|
||||
for aux_p in AUX_PERIODS:
|
||||
try:
|
||||
aux_df = load_kline(aux_p, start_date, end_date)
|
||||
if not aux_df.empty:
|
||||
aux_dfs[aux_p] = aux_df
|
||||
except Exception:
|
||||
pass
|
||||
df = build_features(primary_df, aux_dfs)
|
||||
df.dropna(inplace=True)
|
||||
|
||||
# 4. 生成信号
|
||||
signals = self.generate_signals(df)
|
||||
|
||||
# 5. 回测
|
||||
engine = BacktestEngine()
|
||||
result = engine.run(df['close'], signals)
|
||||
|
||||
print_metrics(result['metrics'], "方案A: 统计筛选策略")
|
||||
return result
|
||||
|
||||
|
||||
def run_stat_strategy(period: int = None, start_date: str = None, end_date: str = None) -> dict:
|
||||
"""方案A快捷入口"""
|
||||
strategy = StatStrategy()
|
||||
return strategy.run(period, start_date, end_date)
|
||||
2
test.py
Normal file
2
test.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from strategy.compare import run_full_comparison
|
||||
results = run_full_comparison(period=15)
|
||||
694
四分之一,五分钟,反手条件充足.py
Normal file
694
四分之一,五分钟,反手条件充足.py
Normal file
@@ -0,0 +1,694 @@
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from bit_tools import openBrowser
|
||||
from DrissionPage import ChromiumPage
|
||||
from DrissionPage import ChromiumOptions
|
||||
|
||||
from bitmart.api_contract import APIContract
|
||||
|
||||
# 方案B:从 strategy 模块获取实盘信号(需先运行 AI 策略训练保存模型,并保持 models/database.db 有最新 15m/5m/1h K 线,例如运行 抓取多周期K线.py)
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent))
|
||||
from strategy.ai_strategy import get_live_signal
|
||||
|
||||
|
||||
class BitmartFuturesTransaction:
|
||||
def __init__(self, bit_id):
|
||||
|
||||
self.page: ChromiumPage | None = None
|
||||
|
||||
self.api_key = "a0fb7b98464fd9bcce67e7c519d58ec10d0c38a8"
|
||||
self.secret_key = "4eaeba78e77aeaab1c2027f846a276d164f264a44c2c1bb1c5f3be50c8de1ca5"
|
||||
self.memo = "合约交易"
|
||||
|
||||
self.contract_symbol = "ETHUSDT"
|
||||
|
||||
self.contractAPI = APIContract(self.api_key, self.secret_key, self.memo, timeout=(5, 15))
|
||||
|
||||
self.start = 0 # 持仓状态: -1 空, 0 无, 1 多
|
||||
|
||||
self.pbar = tqdm(total=30, desc="等待K线", ncols=80) # 可选:用于长时间等待时展示进度
|
||||
|
||||
self.last_kline_time = None # 上一次出信号的 15 分钟 K 线 id(方案B 每根 15m 只出一次信号)
|
||||
|
||||
# 反手频率控制
|
||||
self.reverse_cooldown_seconds = 1.5 * 60 # 反手冷却时间(秒)
|
||||
self.reverse_min_move_pct = 0.05 # 反手最小价差过滤(百分比)
|
||||
self.last_reverse_time = None # 上次反手时间
|
||||
|
||||
# 开仓频率控制
|
||||
self.open_cooldown_seconds = 60 # 开仓冷却时间(秒),两次开仓至少间隔此时长
|
||||
self.last_open_time = None # 上次开仓时间
|
||||
self.last_open_kline_id = None # 上次开仓所在 K 线 id,同一根 K 线只允许开仓一次
|
||||
|
||||
self.leverage = "100" # 高杠杆(全仓模式下可开更大仓位)
|
||||
self.open_type = "cross" # 全仓模式
|
||||
self.risk_percent = 0 # 未使用;若启用则可为每次开仓占可用余额的百分比
|
||||
self.take_profit_usd = 5 # 仓位盈利达到此金额(美元)时平仓止盈
|
||||
self.stop_loss_usd = -3 # 固定止损:亏损达到 3 美元平仓
|
||||
self.trailing_activation_usd = 2 # 盈利达到此金额后启动移动止损
|
||||
self.trailing_distance_usd = 1.5 # 从最高盈利回撤此金额则平仓
|
||||
self.max_unrealized_pnl_seen = None # 持仓期间见过的最大盈利(用于移动止损)
|
||||
|
||||
self.open_avg_price = None # 开仓价格
|
||||
self.current_amount = None # 持仓量
|
||||
|
||||
self.bit_id = bit_id
|
||||
self.default_order_size = 25 # 开仓/反手张数,统一在此修改
|
||||
|
||||
# 策略相关变量
|
||||
self.prev_kline = None # 上一根K线
|
||||
self.current_kline = None # 当前K线
|
||||
self.prev_entity = None # 上一根K线实体大小
|
||||
self.current_open = None # 当前K线开盘价
|
||||
|
||||
def get_klines(self):
|
||||
"""获取最近2根K线(当前K线和上一根K线)"""
|
||||
try:
|
||||
end_time = int(time.time())
|
||||
# 获取足够多的条目确保有最新的K线
|
||||
response = self.contractAPI.get_kline(
|
||||
contract_symbol=self.contract_symbol,
|
||||
step=5, # 5分钟
|
||||
start_time=end_time - 3600 * 3, # 取最近3小时
|
||||
end_time=end_time
|
||||
)[0]["data"]
|
||||
|
||||
# 每根: [timestamp, open, high, low, close, volume]
|
||||
formatted = []
|
||||
for k in response:
|
||||
formatted.append({
|
||||
'id': int(k["timestamp"]),
|
||||
'open': float(k["open_price"]),
|
||||
'high': float(k["high_price"]),
|
||||
'low': float(k["low_price"]),
|
||||
'close': float(k["close_price"])
|
||||
})
|
||||
formatted.sort(key=lambda x: x['id'])
|
||||
|
||||
# 返回最近2根K线:倒数第二根(上一根)和最后一根(当前)
|
||||
if len(formatted) >= 2:
|
||||
return formatted[-2], formatted[-1]
|
||||
return None, None
|
||||
except Exception as e:
|
||||
logger.error(f"获取K线异常: {e}")
|
||||
self.ding(text="获取K线异常", error=True)
|
||||
return None, None
|
||||
|
||||
def get_current_price(self):
|
||||
"""获取当前最新价格"""
|
||||
try:
|
||||
end_time = int(time.time())
|
||||
response = self.contractAPI.get_kline(
|
||||
contract_symbol=self.contract_symbol,
|
||||
step=1, # 1分钟
|
||||
start_time=end_time - 3600 * 1, # 取最近1小时
|
||||
end_time=end_time
|
||||
)[0]
|
||||
if response['code'] == 1000:
|
||||
return float(response['data'][-1]["close_price"])
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取价格异常: {e}")
|
||||
return None
|
||||
|
||||
def get_available_balance(self):
|
||||
"""获取合约账户可用USDT余额"""
|
||||
try:
|
||||
response = self.contractAPI.get_assets_detail()[0]
|
||||
if response['code'] == 1000:
|
||||
data = response['data']
|
||||
if isinstance(data, dict):
|
||||
return float(data.get('available_balance', 0))
|
||||
elif isinstance(data, list):
|
||||
for asset in data:
|
||||
if asset.get('currency') == 'USDT':
|
||||
return float(asset.get('available_balance', 0))
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"余额查询异常: {e}")
|
||||
return None
|
||||
|
||||
def get_position_status(self):
|
||||
"""获取当前持仓方向"""
|
||||
try:
|
||||
response = self.contractAPI.get_position(contract_symbol=self.contract_symbol)[0]
|
||||
if response['code'] == 1000:
|
||||
positions = response['data']
|
||||
if not positions:
|
||||
self.start = 0
|
||||
self.open_avg_price = None
|
||||
self.current_amount = None
|
||||
self.unrealized_pnl = None
|
||||
return True
|
||||
pos = positions[0]
|
||||
self.start = 1 if pos['position_type'] == 1 else -1
|
||||
self.open_avg_price = float(pos['open_avg_price'])
|
||||
self.current_amount = float(pos['current_amount'])
|
||||
self.position_cross = pos["position_cross"]
|
||||
# 直接从API获取未实现盈亏(Bitmart返回的是 unrealized_value 字段)
|
||||
self.unrealized_pnl = float(pos.get('unrealized_value', 0))
|
||||
logger.debug(f"持仓详情: 方向={self.start}, 开仓均价={self.open_avg_price}, "
|
||||
f"持仓量={self.current_amount}, 未实现盈亏={self.unrealized_pnl:.2f}")
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"持仓查询异常: {e}")
|
||||
return False
|
||||
|
||||
def get_unrealized_pnl_usd(self):
|
||||
"""
|
||||
获取当前持仓未实现盈亏(美元),直接使用API返回值
|
||||
"""
|
||||
if self.start == 0 or self.unrealized_pnl is None:
|
||||
return None
|
||||
return self.unrealized_pnl
|
||||
|
||||
def set_leverage(self):
|
||||
"""程序启动时设置全仓 + 高杠杆"""
|
||||
try:
|
||||
response = self.contractAPI.post_submit_leverage(
|
||||
contract_symbol=self.contract_symbol,
|
||||
leverage=self.leverage,
|
||||
open_type=self.open_type
|
||||
)[0]
|
||||
if response['code'] == 1000:
|
||||
logger.success(f"全仓模式 + {self.leverage}x 杠杆设置成功")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"杠杆设置失败: {response}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"设置杠杆异常: {e}")
|
||||
return False
|
||||
|
||||
def openBrowser(self):
|
||||
"""打开 TGE 对应浏览器实例"""
|
||||
try:
|
||||
bit_port = openBrowser(id=self.bit_id)
|
||||
co = ChromiumOptions()
|
||||
co.set_local_port(port=bit_port)
|
||||
self.page = ChromiumPage(addr_or_opts=co)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
def click_safe(self, xpath, sleep=0.5):
|
||||
"""安全点击"""
|
||||
try:
|
||||
ele = self.page.ele(xpath)
|
||||
if not ele:
|
||||
return False
|
||||
# ele.scroll.to_see(center=True)
|
||||
# time.sleep(sleep)
|
||||
ele.click(by_js=True)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
def 平仓(self):
|
||||
"""平仓操作"""
|
||||
self.click_safe('x://span[normalize-space(text()) ="市价"]')
|
||||
|
||||
def 开单(self, marketPriceLongOrder=0, limitPriceShortOrder=0, size=None, price=None):
|
||||
"""
|
||||
marketPriceLongOrder 市价做多或者做空,1是做多,-1是做空
|
||||
limitPriceShortOrder 限价做多或者做空
|
||||
"""
|
||||
if marketPriceLongOrder == -1:
|
||||
# self.click_safe('x://button[normalize-space(text()) ="市价"]')
|
||||
# self.page.ele('x://*[@id="size_0"]').input(vals=size, clear=True)
|
||||
self.click_safe('x://span[normalize-space(text()) ="卖出/做空"]')
|
||||
elif marketPriceLongOrder == 1:
|
||||
# self.click_safe('x://button[normalize-space(text()) ="市价"]')
|
||||
# self.page.ele('x://*[@id="size_0"]').input(vals=size, clear=True)
|
||||
self.click_safe('x://span[normalize-space(text()) ="买入/做多"]')
|
||||
|
||||
if limitPriceShortOrder == -1:
|
||||
self.click_safe('x://button[normalize-space(text()) ="限价"]')
|
||||
self.page.ele('x://*[@id="price_0"]').input(vals=price, clear=True)
|
||||
time.sleep(1)
|
||||
self.page.ele('x://*[@id="size_0"]').input(1)
|
||||
self.click_safe('x://span[normalize-space(text()) ="卖出/做空"]')
|
||||
elif limitPriceShortOrder == 1:
|
||||
self.click_safe('x://button[normalize-space(text()) ="限价"]')
|
||||
self.page.ele('x://*[@id="price_0"]').input(vals=price, clear=True)
|
||||
time.sleep(1)
|
||||
self.page.ele('x://*[@id="size_0"]').input(1)
|
||||
self.click_safe('x://span[normalize-space(text()) ="买入/做多"]')
|
||||
|
||||
def ding(self, text, error=False):
|
||||
"""日志通知"""
|
||||
if error:
|
||||
logger.error(text)
|
||||
else:
|
||||
logger.info(text)
|
||||
|
||||
def calculate_entity(self, kline):
|
||||
"""计算K线实体大小(绝对值)"""
|
||||
return abs(kline['close'] - kline['open'])
|
||||
|
||||
def calculate_upper_shadow(self, kline):
|
||||
"""计算上阴线(上影线)涨幅百分比"""
|
||||
# 上阴线 = (最高价 - max(开盘价, 收盘价)) / max(开盘价, 收盘价)
|
||||
body_top = max(kline['open'], kline['close'])
|
||||
if body_top == 0:
|
||||
return 0
|
||||
return (kline['high'] - body_top) / body_top * 100
|
||||
|
||||
def calculate_lower_shadow(self, kline):
|
||||
"""计算下阴线(下影线)跌幅百分比"""
|
||||
# 下阴线 = (min(开盘价, 收盘价) - 最低价) / min(开盘价, 收盘价)
|
||||
body_bottom = min(kline['open'], kline['close'])
|
||||
if body_bottom == 0:
|
||||
return 0
|
||||
return (body_bottom - kline['low']) / body_bottom * 100
|
||||
|
||||
def get_entity_edge(self, kline):
|
||||
"""获取K线实体边(收盘价或开盘价,取决于是阳线还是阴线)"""
|
||||
# 阳线(收盘>开盘):实体上边=收盘价,实体下边=开盘价
|
||||
# 阴线(收盘<开盘):实体上边=开盘价,实体下边=收盘价
|
||||
return {
|
||||
'upper': max(kline['open'], kline['close']), # 实体上边
|
||||
'lower': min(kline['open'], kline['close']) # 实体下边
|
||||
}
|
||||
|
||||
def check_signal(self, current_price, prev_kline, current_kline):
|
||||
"""
|
||||
检查交易信号
|
||||
返回: ('long', trigger_price) / ('short', trigger_price) / None
|
||||
"""
|
||||
# 计算上一根K线实体
|
||||
prev_entity = self.calculate_entity(prev_kline)
|
||||
|
||||
# 实体过小不交易(实体 < 0.1)
|
||||
if prev_entity < 0.1:
|
||||
logger.info(f"上一根K线实体过小: {prev_entity:.4f},跳过信号检测")
|
||||
return None
|
||||
|
||||
# 获取上一根K线的实体上下边
|
||||
prev_entity_edge = self.get_entity_edge(prev_kline)
|
||||
prev_entity_upper = prev_entity_edge['upper'] # 实体上边
|
||||
prev_entity_lower = prev_entity_edge['lower'] # 实体下边
|
||||
|
||||
# 优化:以下两种情况以当前这根的开盘价作为计算基准
|
||||
# 1) 上一根阳线 且 当前开盘价 > 上一根收盘价(跳空高开)
|
||||
# 2) 上一根阴线 且 当前开盘价 < 上一根收盘价(跳空低开)
|
||||
prev_is_bullish_for_calc = prev_kline['close'] > prev_kline['open']
|
||||
prev_is_bearish_for_calc = prev_kline['close'] < prev_kline['open']
|
||||
current_open_above_prev_close = current_kline['open'] > prev_kline['close']
|
||||
current_open_below_prev_close = current_kline['open'] < prev_kline['close']
|
||||
use_current_open_as_base = (prev_is_bullish_for_calc and current_open_above_prev_close) or (prev_is_bearish_for_calc and current_open_below_prev_close)
|
||||
|
||||
if use_current_open_as_base:
|
||||
# 以当前K线开盘价为基准计算(跳空时用当前开盘价参与计算)
|
||||
calc_lower = current_kline['open']
|
||||
calc_upper = current_kline['open'] # 同一基准,上下四分之一对称
|
||||
long_trigger = calc_lower + prev_entity / 4
|
||||
short_trigger = calc_upper - prev_entity / 4
|
||||
long_breakout = calc_upper + prev_entity / 4
|
||||
short_breakout = calc_lower - prev_entity / 4
|
||||
else:
|
||||
# 原有计算方式
|
||||
long_trigger = prev_entity_lower + prev_entity / 4 # 做多触发价 = 实体下边 + 实体/4(下四分之一处)
|
||||
short_trigger = prev_entity_upper - prev_entity / 4 # 做空触发价 = 实体上边 - 实体/4(上四分之一处)
|
||||
long_breakout = prev_entity_upper + prev_entity / 4 # 做多突破价 = 实体上边 + 实体/4
|
||||
short_breakout = prev_entity_lower - prev_entity / 4 # 做空突破价 = 实体下边 - 实体/4
|
||||
|
||||
# 上一根阴线 + 当前阳线:做多形态,不按上一根K线上三分之一做空
|
||||
prev_is_bearish = prev_kline['close'] < prev_kline['open']
|
||||
current_is_bullish = current_kline['close'] > current_kline['open']
|
||||
skip_short_by_upper_third = prev_is_bearish and current_is_bullish
|
||||
# 上一根阳线 + 当前阴线:做空形态,不按上一根K线下三分之一做多
|
||||
prev_is_bullish = prev_kline['close'] > prev_kline['open']
|
||||
current_is_bearish = current_kline['close'] < current_kline['open']
|
||||
skip_long_by_lower_third = prev_is_bullish and current_is_bearish
|
||||
|
||||
if use_current_open_as_base:
|
||||
if prev_is_bullish_for_calc and current_open_above_prev_close:
|
||||
logger.info(f"上一根阳线且当前开盘价({current_kline['open']:.2f})>上一根收盘价({prev_kline['close']:.2f}),以当前开盘价为基准计算")
|
||||
else:
|
||||
logger.info(f"上一根阴线且当前开盘价({current_kline['open']:.2f})<上一根收盘价({prev_kline['close']:.2f}),以当前开盘价为基准计算")
|
||||
logger.info(f"当前价格: {current_price:.2f}, 上一根实体: {prev_entity:.4f}")
|
||||
logger.info(f"上一根实体上边: {prev_entity_upper:.2f}, 下边: {prev_entity_lower:.2f}")
|
||||
logger.info(f"做多触发价(下1/4): {long_trigger:.2f}, 做空触发价(上1/4): {short_trigger:.2f}")
|
||||
logger.info(f"突破做多价(上1/4外): {long_breakout:.2f}, 突破做空价(下1/4外): {short_breakout:.2f}")
|
||||
if skip_short_by_upper_third:
|
||||
logger.info("上一根阴线+当前阳线(做多形态),不按上四分之一做空")
|
||||
if skip_long_by_lower_third:
|
||||
logger.info("上一根阳线+当前阴线(做空形态),不按下四分之一做多")
|
||||
|
||||
# 无持仓时检查开仓信号
|
||||
if self.start == 0:
|
||||
if current_price >= long_breakout and not skip_long_by_lower_third:
|
||||
logger.info(f"触发做多信号!价格 {current_price:.2f} >= 突破价(上1/4外) {long_breakout:.2f}")
|
||||
return ('long', long_breakout)
|
||||
elif current_price <= short_breakout and not skip_short_by_upper_third:
|
||||
logger.info(f"触发做空信号!价格 {current_price:.2f} <= 突破价(下1/4外) {short_breakout:.2f}")
|
||||
return ('short', short_breakout)
|
||||
|
||||
# 持仓时检查反手信号
|
||||
elif self.start == 1: # 持多仓
|
||||
# 反手条件1: 价格跌到上一根K线的上三分之一处(做空触发价);上一根阴线+当前阳线做多时跳过
|
||||
if current_price <= short_trigger and not skip_short_by_upper_third:
|
||||
logger.info(f"持多反手做空!价格 {current_price:.2f} <= 触发价(上1/4) {short_trigger:.2f}")
|
||||
return ('reverse_short', short_trigger)
|
||||
|
||||
# 反手条件2: 上一根K线上阴线涨幅>0.01%,当前跌到上一根实体下边
|
||||
upper_shadow_pct = self.calculate_upper_shadow(prev_kline)
|
||||
if upper_shadow_pct > 0.01 and current_price <= prev_entity_lower:
|
||||
logger.info(f"持多反手做空!上阴线涨幅 {upper_shadow_pct:.4f}% > 0.01%,"
|
||||
f"价格 {current_price:.2f} <= 实体下边 {prev_entity_lower:.2f}")
|
||||
return ('reverse_short', prev_entity_lower)
|
||||
|
||||
elif self.start == -1: # 持空仓
|
||||
# 反手条件1: 价格涨到上一根K线的下三分之一处(做多触发价);上一根阳线+当前阴线做空时跳过
|
||||
if current_price >= long_trigger and not skip_long_by_lower_third:
|
||||
logger.info(f"持空反手做多!价格 {current_price:.2f} >= 触发价(下1/4) {long_trigger:.2f}")
|
||||
return ('reverse_long', long_trigger)
|
||||
|
||||
# 反手条件2: 上一根K线下阴线跌幅>0.01%,当前涨到上一根实体上边
|
||||
lower_shadow_pct = self.calculate_lower_shadow(prev_kline)
|
||||
if lower_shadow_pct > 0.01 and current_price >= prev_entity_upper:
|
||||
logger.info(f"持空反手做多!下阴线跌幅 {lower_shadow_pct:.4f}% > 0.01%,"
|
||||
f"价格 {current_price:.2f} >= 实体上边 {prev_entity_upper:.2f}")
|
||||
return ('reverse_long', prev_entity_upper)
|
||||
|
||||
return None
|
||||
|
||||
def can_open(self, current_kline_id):
|
||||
"""开仓前过滤:同一根 K 线只开一次 + 开仓冷却时间。仅用于 long/short 新开仓。"""
|
||||
now = time.time()
|
||||
if self.last_open_kline_id is not None and self.last_open_kline_id == current_kline_id:
|
||||
logger.info(f"开仓频率控制:本 K 线({current_kline_id})已开过仓,跳过")
|
||||
return False
|
||||
if self.last_open_time is not None and now - self.last_open_time < self.open_cooldown_seconds:
|
||||
remain = self.open_cooldown_seconds - (now - self.last_open_time)
|
||||
logger.info(f"开仓冷却中,剩余 {remain:.0f} 秒")
|
||||
return False
|
||||
return True
|
||||
|
||||
def can_reverse(self, current_price, trigger_price):
|
||||
"""反手前过滤:冷却时间 + 最小价差"""
|
||||
now = time.time()
|
||||
if self.last_reverse_time and now - self.last_reverse_time < self.reverse_cooldown_seconds:
|
||||
remain = self.reverse_cooldown_seconds - (now - self.last_reverse_time)
|
||||
logger.info(f"反手冷却中,剩余 {remain:.0f} 秒")
|
||||
return False
|
||||
|
||||
if trigger_price and trigger_price > 0:
|
||||
move_pct = abs(current_price - trigger_price) / trigger_price * 100
|
||||
if move_pct < self.reverse_min_move_pct:
|
||||
logger.info(f"反手价差不足: {move_pct:.4f}% < {self.reverse_min_move_pct}%")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def verify_no_position(self, max_retries=5, retry_interval=3):
|
||||
"""
|
||||
验证当前无持仓
|
||||
返回: True 表示无持仓可以开仓,False 表示有持仓不能开仓
|
||||
"""
|
||||
for i in range(max_retries):
|
||||
if self.get_position_status():
|
||||
if self.start == 0:
|
||||
logger.info(f"确认无持仓,可以开仓")
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
f"仍有持仓 (方向: {self.start}),等待 {retry_interval} 秒后重试 ({i + 1}/{max_retries})")
|
||||
time.sleep(retry_interval)
|
||||
else:
|
||||
logger.warning(f"查询持仓状态失败,等待 {retry_interval} 秒后重试 ({i + 1}/{max_retries})")
|
||||
time.sleep(retry_interval)
|
||||
|
||||
logger.error(f"经过 {max_retries} 次重试仍有持仓或查询失败,放弃开仓")
|
||||
return False
|
||||
|
||||
def verify_position_direction(self, expected_direction):
|
||||
"""
|
||||
验证当前持仓方向是否与预期一致
|
||||
expected_direction: 1 多仓, -1 空仓
|
||||
返回: True 表示持仓方向正确,False 表示不正确
|
||||
"""
|
||||
if self.get_position_status():
|
||||
if self.start == expected_direction:
|
||||
logger.info(f"持仓方向验证成功: {self.start}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"持仓方向不符: 期望 {expected_direction}, 实际 {self.start}")
|
||||
return False
|
||||
else:
|
||||
logger.error("查询持仓状态失败")
|
||||
return False
|
||||
|
||||
def execute_trade(self, signal, size=None):
|
||||
"""执行交易。size 不传或为 None 时使用 default_order_size。"""
|
||||
signal_type, trigger_price = signal
|
||||
size = self.default_order_size if size is None else size
|
||||
|
||||
if signal_type == 'long':
|
||||
# 开多前先确认无持仓
|
||||
logger.info(f"准备开多,触发价: {trigger_price:.2f}")
|
||||
if not self.get_position_status():
|
||||
logger.error("开仓前查询持仓状态失败,放弃开仓")
|
||||
return False
|
||||
if self.start != 0:
|
||||
logger.warning(f"开多前发现已有持仓 (方向: {self.start}),放弃开仓避免双向持仓")
|
||||
return False
|
||||
|
||||
logger.info(f"确认无持仓,执行开多")
|
||||
self.开单(marketPriceLongOrder=1, size=size)
|
||||
time.sleep(3) # 等待订单执行
|
||||
|
||||
# 验证开仓是否成功
|
||||
if self.verify_position_direction(1):
|
||||
self.max_unrealized_pnl_seen = None # 新仓位重置移动止损记录
|
||||
self.last_open_time = time.time()
|
||||
self.last_open_kline_id = getattr(self, "_current_kline_id_for_open", None)
|
||||
logger.success("开多成功")
|
||||
return True
|
||||
else:
|
||||
logger.error("开多后持仓验证失败")
|
||||
return False
|
||||
|
||||
elif signal_type == 'short':
|
||||
# 开空前先确认无持仓
|
||||
logger.info(f"准备开空,触发价: {trigger_price:.2f}")
|
||||
if not self.get_position_status():
|
||||
logger.error("开仓前查询持仓状态失败,放弃开仓")
|
||||
return False
|
||||
if self.start != 0:
|
||||
logger.warning(f"开空前发现已有持仓 (方向: {self.start}),放弃开仓避免双向持仓")
|
||||
return False
|
||||
|
||||
logger.info(f"确认无持仓,执行开空")
|
||||
self.开单(marketPriceLongOrder=-1, size=size)
|
||||
time.sleep(3) # 等待订单执行
|
||||
|
||||
# 验证开仓是否成功
|
||||
if self.verify_position_direction(-1):
|
||||
self.max_unrealized_pnl_seen = None # 新仓位重置移动止损记录
|
||||
self.last_open_time = time.time()
|
||||
self.last_open_kline_id = getattr(self, "_current_kline_id_for_open", None)
|
||||
logger.success("开空成功")
|
||||
return True
|
||||
else:
|
||||
logger.error("开空后持仓验证失败")
|
||||
return False
|
||||
|
||||
elif signal_type == 'reverse_long':
|
||||
# 平空 + 开多(反手做多):先平仓,确认无仓后再开多,避免双向持仓
|
||||
logger.info(f"执行反手做多,触发价: {trigger_price:.2f}")
|
||||
self.平仓()
|
||||
time.sleep(1) # 给交易所处理平仓的时间
|
||||
# 轮询确认已无持仓再开多(最多等约 10 秒)
|
||||
for _ in range(10):
|
||||
if self.get_position_status() and self.start == 0:
|
||||
break
|
||||
time.sleep(1)
|
||||
if self.start != 0:
|
||||
logger.warning("反手做多:平仓后仍有持仓,放弃本次开多")
|
||||
return False
|
||||
logger.info("已确认无持仓,执行开多")
|
||||
self.开单(marketPriceLongOrder=1, size=size)
|
||||
time.sleep(3)
|
||||
|
||||
if self.verify_position_direction(1):
|
||||
self.max_unrealized_pnl_seen = None
|
||||
logger.success("反手做多成功")
|
||||
self.last_reverse_time = time.time()
|
||||
time.sleep(20)
|
||||
return True
|
||||
else:
|
||||
logger.error("反手做多后持仓验证失败")
|
||||
return False
|
||||
|
||||
elif signal_type == 'reverse_short':
|
||||
# 平多 + 开空(反手做空):先平仓,确认无仓后再开空
|
||||
logger.info(f"执行反手做空,触发价: {trigger_price:.2f}")
|
||||
self.平仓()
|
||||
time.sleep(1)
|
||||
for _ in range(10):
|
||||
if self.get_position_status() and self.start == 0:
|
||||
break
|
||||
time.sleep(1)
|
||||
if self.start != 0:
|
||||
logger.warning("反手做空:平仓后仍有持仓,放弃本次开空")
|
||||
return False
|
||||
logger.info("已确认无持仓,执行开空")
|
||||
self.开单(marketPriceLongOrder=-1, size=size)
|
||||
time.sleep(3)
|
||||
|
||||
if self.verify_position_direction(-1):
|
||||
self.max_unrealized_pnl_seen = None
|
||||
logger.success("反手做空成功")
|
||||
self.last_reverse_time = time.time()
|
||||
time.sleep(20)
|
||||
return True
|
||||
else:
|
||||
logger.error("反手做空后持仓验证失败")
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
def action(self):
|
||||
"""主循环"""
|
||||
|
||||
logger.info("开始运行方案B(AI 策略)交易...")
|
||||
|
||||
# 启动时设置全仓高杠杆
|
||||
if not self.set_leverage():
|
||||
logger.error("杠杆设置失败,程序继续运行但可能下单失败")
|
||||
return
|
||||
|
||||
page_start = True
|
||||
|
||||
while True:
|
||||
|
||||
if page_start:
|
||||
# 打开浏览器
|
||||
for i in range(5):
|
||||
if self.openBrowser():
|
||||
logger.info("浏览器打开成功")
|
||||
break
|
||||
else:
|
||||
self.ding("打开浏览器失败!", error=True)
|
||||
return
|
||||
|
||||
# 进入交易页面
|
||||
self.page.get("https://derivatives.bitmart.com/zh-CN/futures/ETHUSDT")
|
||||
self.click_safe('x://button[normalize-space(text()) ="市价"]')
|
||||
|
||||
self.page.ele('x://*[@id="size_0"]').input(vals=25, clear=True)
|
||||
|
||||
page_start = False
|
||||
|
||||
try:
|
||||
# 1. 获取当前价格
|
||||
current_price = self.get_current_price()
|
||||
if not current_price:
|
||||
logger.warning("获取价格失败,等待重试...")
|
||||
time.sleep(2)
|
||||
continue
|
||||
|
||||
# 2. 每次循环都通过SDK获取真实持仓状态(避免状态不同步导致双向持仓)
|
||||
if not self.get_position_status():
|
||||
logger.warning("获取持仓状态失败,等待重试...")
|
||||
time.sleep(2)
|
||||
continue
|
||||
|
||||
logger.debug(f"当前持仓状态: {self.start} (0=无, 1=多, -1=空)")
|
||||
|
||||
# 3. 止损/止盈/移动止损
|
||||
if self.start != 0:
|
||||
pnl_usd = self.get_unrealized_pnl_usd()
|
||||
if pnl_usd is not None:
|
||||
# 固定止损:亏损达到 3 美元平仓
|
||||
if pnl_usd <= self.stop_loss_usd:
|
||||
logger.info(f"仓位亏损 {pnl_usd:.2f} 美元 <= 止损 {self.stop_loss_usd} 美元,执行止损平仓")
|
||||
self.平仓()
|
||||
self.max_unrealized_pnl_seen = None
|
||||
time.sleep(3)
|
||||
continue
|
||||
# 更新持仓期间最大盈利(用于移动止损)
|
||||
if self.max_unrealized_pnl_seen is None:
|
||||
self.max_unrealized_pnl_seen = pnl_usd
|
||||
else:
|
||||
self.max_unrealized_pnl_seen = max(self.max_unrealized_pnl_seen, pnl_usd)
|
||||
# 移动止损:盈利曾达到 activation 后,从最高盈利回撤 trailing_distance 则平仓
|
||||
if self.max_unrealized_pnl_seen >= self.trailing_activation_usd:
|
||||
if pnl_usd < self.max_unrealized_pnl_seen - self.trailing_distance_usd:
|
||||
logger.info(f"移动止损:当前盈利 {pnl_usd:.2f} 从最高 {self.max_unrealized_pnl_seen:.2f} 回撤 >= {self.trailing_distance_usd} 美元,平仓")
|
||||
self.平仓()
|
||||
self.max_unrealized_pnl_seen = None
|
||||
time.sleep(3)
|
||||
continue
|
||||
# 止盈:盈利达到 take_profit_usd 平仓
|
||||
if pnl_usd >= self.take_profit_usd:
|
||||
logger.info(f"仓位盈利 {pnl_usd:.2f} 美元 >= {self.take_profit_usd} 美元,执行止盈平仓")
|
||||
self.平仓()
|
||||
self.max_unrealized_pnl_seen = None
|
||||
time.sleep(3)
|
||||
continue
|
||||
|
||||
# 4. 方案B:仅在新的 15 分钟 K 线时取一次信号(0=观望, 1=做多, 2=做空)
|
||||
current_15m_id = int(time.time() // 900) * 900 # 15 分钟 bar 起始时间戳
|
||||
signal = None
|
||||
if current_15m_id != self.last_kline_time:
|
||||
self.last_kline_time = current_15m_id
|
||||
logger.info(f"进入新 15m K 线: {current_15m_id}")
|
||||
raw = get_live_signal(period=15)
|
||||
if raw == 1:
|
||||
if self.start == 0:
|
||||
signal = ('long', current_price)
|
||||
elif self.start == -1:
|
||||
signal = ('reverse_long', current_price)
|
||||
elif raw == 2:
|
||||
if self.start == 0:
|
||||
signal = ('short', current_price)
|
||||
elif self.start == 1:
|
||||
signal = ('reverse_short', current_price)
|
||||
|
||||
# 5. 反手过滤:冷却时间 + 最小价差
|
||||
if signal and signal[0].startswith('reverse_'):
|
||||
if not self.can_reverse(current_price, signal[1]):
|
||||
signal = None
|
||||
|
||||
# 5.5 开仓频率过滤:同一根 15m K 线只开一次 + 开仓冷却
|
||||
if signal and signal[0] in ('long', 'short'):
|
||||
if not self.can_open(current_15m_id):
|
||||
signal = None
|
||||
else:
|
||||
self._current_kline_id_for_open = current_15m_id # 供 execute_trade 成功后记录
|
||||
|
||||
# 6. 有信号则执行交易
|
||||
if signal:
|
||||
trade_success = self.execute_trade(signal)
|
||||
if trade_success:
|
||||
logger.success(f"交易执行完成: {signal[0]}, 当前持仓状态: {self.start}")
|
||||
page_start = True
|
||||
else:
|
||||
logger.warning(f"交易执行失败或被阻止: {signal[0]}")
|
||||
|
||||
# 短暂等待后继续循环(同一根K线遇到信号就操作)
|
||||
time.sleep(0.1)
|
||||
|
||||
if page_start:
|
||||
self.page.close()
|
||||
time.sleep(5)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("用户中断,程序退出")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"主循环异常: {e}")
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
BitmartFuturesTransaction(bit_id="f2320f57e24c45529a009e1541e25961").action()
|
||||
Reference in New Issue
Block a user