This commit is contained in:
ddrwode
2026-02-20 20:57:25 +08:00
parent de18d79d1a
commit 21f2adc4a4
13 changed files with 2448 additions and 0 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 186 KiB

9
requirements.txt Normal file
View File

@@ -0,0 +1,9 @@
pandas>=2.0
numpy>=1.24
ta>=0.11.0
scikit-learn>=1.3
lightgbm>=4.0
xgboost>=2.0
matplotlib>=3.7
peewee>=3.16
loguru>=0.7

1
strategy/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""52指标AI交易策略系统"""

214
strategy/ai_strategy.py Normal file
View File

@@ -0,0 +1,214 @@
"""
方案BAI模型训练 + 信号生成
使用 LightGBM / XGBoostWalk-Forward 滚动训练
"""
import json
import joblib
import pandas as pd
import numpy as np
from pathlib import Path
from loguru import logger
from .config import MODEL_CONFIG as MC, PRIMARY_PERIOD, PROJECT_ROOT
from .feature_engine import prepare_dataset, get_latest_feature_row
from .backtest import BacktestEngine, print_metrics
SCHEME_B_MODEL_DIR = PROJECT_ROOT / 'models'
SCHEME_B_MODEL_FILE = SCHEME_B_MODEL_DIR / 'scheme_b_last_model.joblib'
SCHEME_B_SCALER_FILE = SCHEME_B_MODEL_DIR / 'scheme_b_scaler.joblib'
SCHEME_B_FEATURES_FILE = SCHEME_B_MODEL_DIR / 'scheme_b_features.json'
class AIStrategy:
"""AI模型策略 — LightGBM / XGBoost Walk-Forward"""
def __init__(self, model_type: str = 'lightgbm'):
"""
:param model_type: 'lightgbm''xgboost'
"""
self.model_type = model_type
self.models = [] # 存储每个窗口训练的模型
self.feature_importance = None
def _create_model(self):
"""创建模型实例"""
if self.model_type == 'lightgbm':
import lightgbm as lgb
params = MC['lightgbm'].copy()
return lgb.LGBMClassifier(**params)
elif self.model_type == 'xgboost':
import xgboost as xgb
params = MC['xgboost'].copy()
return xgb.XGBClassifier(**params)
else:
raise ValueError(f"不支持的模型类型: {self.model_type}")
def walk_forward_train(self, X: pd.DataFrame, y: pd.Series,
confidence_threshold: float = 0.45) -> pd.Series:
"""
Walk-Forward 滚动训练与预测
:param confidence_threshold: 概率阈值低于此值的预测设为0观望
:return: 全部测试窗口拼接的预测信号
"""
train_size = MC['walk_forward_train_size']
test_size = MC['walk_forward_test_size']
step = MC['walk_forward_step']
n = len(X)
all_preds = pd.Series(dtype=float)
window_count = 0
logger.info(f"Walk-Forward: 数据量={n}, 训练窗口={train_size}, "
f"测试窗口={test_size}, 步长={step}, 置信阈值={confidence_threshold}")
start = 0
while start + train_size + test_size <= n:
train_end = start + train_size
test_end = min(train_end + test_size, n)
X_train = X.iloc[start:train_end]
y_train = y.iloc[start:train_end]
X_test = X.iloc[train_end:test_end]
y_test = y.iloc[train_end:test_end]
# 训练
model = self._create_model()
model.fit(X_train, y_train)
self.models.append(model)
# 预测概率 + 置信度过滤
proba = model.predict_proba(X_test)
max_proba = proba.max(axis=1)
raw_preds = model.predict(X_test)
# 置信度不够的设为观望
filtered_preds = raw_preds.copy()
filtered_preds[max_proba < confidence_threshold] = 0
preds = pd.Series(filtered_preds, index=X_test.index)
all_preds = pd.concat([all_preds, preds])
# 准确率(用原始预测算)
acc = (raw_preds == y_test).mean()
n_filtered = (max_proba < confidence_threshold).sum()
window_count += 1
logger.info(f" 窗口 {window_count}: 训练[{start}:{train_end}] "
f"测试[{train_end}:{test_end}] 准确率={acc:.2%} "
f"过滤={n_filtered}/{len(X_test)}")
start += step
# 特征重要性(取最后一个模型)
if self.models:
last_model = self.models[-1]
if hasattr(last_model, 'feature_importances_'):
self.feature_importance = pd.Series(
last_model.feature_importances_, index=X.columns
).sort_values(ascending=False)
logger.info(f"Walk-Forward 完成: {window_count} 个窗口, "
f"{len(all_preds)} 条预测")
return all_preds
def get_top_features(self, n: int = 20) -> pd.Series:
"""获取Top N重要特征"""
if self.feature_importance is not None:
return self.feature_importance.head(n)
return pd.Series(dtype=float)
def run(self, period: int = None, start_date: str = None, end_date: str = None) -> dict:
"""
完整运行方案B
若指定了 start_date/end_date会向前加载 warm_up_months 月数据用于训练,使回测区间首月即有预测。
:return: 回测结果
"""
if period is None:
period = PRIMARY_PERIOD
logger.info("=" * 60)
logger.info(f"方案BAI模型策略 ({self.model_type})")
logger.info("=" * 60)
from .data_loader import load_kline
# 1. 准备数据:若指定了回测区间,则向前加载预热数据,使区间内从首月就有预测
load_start, load_end = start_date, end_date
if start_date and end_date:
warm_months = MC.get('warm_up_months', 12)
load_start_ts = pd.Timestamp(start_date) - pd.DateOffset(months=warm_months)
load_start = load_start_ts.strftime('%Y-%m-%d')
logger.info(f"回测区间 [{start_date} ~ {end_date}],向前加载 {warm_months} 月至 {load_start} 用于训练")
X, y, feature_names, scaler = prepare_dataset(period, load_start, load_end)
# 2. Walk-Forward 训练
predictions = self.walk_forward_train(X, y)
# 3. 回测仅用用户指定区间将预测对齐到该区间的每根K线
df = load_kline(period, start_date, end_date)
if df.empty:
logger.warning("回测区间内无K线数据")
return BacktestEngine()._empty_result()
# 对齐信号:回测区间内有的时间戳用预测,缺失的填 0观望
signals = predictions.reindex(df.index, fill_value=0).astype(int)
prices = df['close']
# 4. 回测
engine = BacktestEngine()
result = engine.run(prices, signals)
print_metrics(result['metrics'], f"方案B: {self.model_type} AI策略")
# 5. 保存最后一窗模型、scaler、特征列供实盘 get_live_signal 使用)
if self.models and scaler is not None:
SCHEME_B_MODEL_DIR.mkdir(parents=True, exist_ok=True)
joblib.dump(self.models[-1], SCHEME_B_MODEL_FILE)
joblib.dump(scaler, SCHEME_B_SCALER_FILE)
with open(SCHEME_B_FEATURES_FILE, 'w', encoding='utf-8') as f:
json.dump(feature_names, f, ensure_ascii=False)
logger.info(f"已保存方案B模型: {SCHEME_B_MODEL_FILE}, scaler, {len(feature_names)} 个特征")
# 6. 输出特征重要性
top_feat = self.get_top_features(15)
if not top_feat.empty:
logger.info("\nTop 15 重要特征:")
for i, (feat, imp) in enumerate(top_feat.items()):
logger.info(f" {i+1}. {feat}: {imp:.4f}")
result['feature_importance'] = self.feature_importance
return result
def run_ai_strategy(model_type: str = 'lightgbm', period: int = None,
start_date: str = None, end_date: str = None) -> dict:
"""方案B快捷入口"""
strategy = AIStrategy(model_type=model_type)
return strategy.run(period, start_date, end_date)
def get_live_signal(period: int = None, model_type: str = 'lightgbm',
start_date: str = None, end_date: str = None) -> int:
"""
使用已保存的方案B模型对当前最新K线生成信号供实盘/模拟盘调用)。
需先运行过 run_ai_strategy 或 AIStrategy().run() 以生成 models/scheme_b_*.joblib 与 features.json。
:param period: K线主周期默认 15
:param model_type: 未使用(模型已固定为磁盘上的 scheme_b_last_model.joblib
:param start_date, end_date: 可选,限制 load_kline 范围
:return: 0=观望, 1=做多, 2=做空
"""
if period is None:
period = PRIMARY_PERIOD
if not SCHEME_B_MODEL_FILE.exists() or not SCHEME_B_SCALER_FILE.exists() or not SCHEME_B_FEATURES_FILE.exists():
logger.warning("方案B模型未找到请先运行 AI 策略训练保存模型")
return 0
model = joblib.load(SCHEME_B_MODEL_FILE)
scaler = joblib.load(SCHEME_B_SCALER_FILE)
with open(SCHEME_B_FEATURES_FILE, 'r', encoding='utf-8') as f:
feature_cols = json.load(f)
X_last = get_latest_feature_row(period, feature_cols, start_date, end_date)
if X_last.empty:
return 0
X_scaled = scaler.transform(X_last)
pred = model.predict(X_scaled)
return int(pred[0])

298
strategy/backtest.py Normal file
View File

@@ -0,0 +1,298 @@
"""
回测引擎 — 多空双向、手续费、滑点、绩效统计
每笔固定名义 100U、100 倍杠杆;同一时间仅一个仓位;最大回撤 300U 硬约束;手续费 90% 返佣
"""
import pandas as pd
import numpy as np
from loguru import logger
from .config import TRADE_CONFIG as TC, SIGNAL_MAP
class BacktestEngine:
"""回测引擎:固定每笔 100U 名义、100x单仓位最大回撤 300U 内实付手续费90% 返佣)"""
def __init__(self, commission: float = None, slippage: float = None,
initial_capital: float = None, position_size: float = None,
position_notional_usd: float = None, max_drawdown_limit: float = None,
commission_rebate: float = None):
raw_commission = commission if commission is not None else TC['commission']
rebate = commission_rebate if commission_rebate is not None else TC.get('commission_rebate', 0)
self.commission = raw_commission * (1 - rebate) # 实付手续费90% 返佣)
self.slippage = slippage or TC['slippage']
self.initial_capital = initial_capital or TC['initial_capital']
self.position_size = position_size or TC['position_size']
self.position_notional_usd = position_notional_usd if position_notional_usd is not None else TC.get('position_notional_usd', self.initial_capital * self.position_size)
self.max_drawdown_limit = max_drawdown_limit if max_drawdown_limit is not None else TC.get('max_drawdown_limit', float('inf'))
def run(self, prices: pd.Series, signals: pd.Series) -> dict:
"""
执行回测
:param prices: 收盘价 Series
:param signals: 信号 Series值: 0=观望, 1=做多, 2=做空
:return: 回测结果字典
"""
df = pd.DataFrame({'price': prices, 'signal': signals}).dropna()
if df.empty:
logger.warning("回测数据为空")
return self._empty_result()
n = len(df)
capital = self.initial_capital
position = 0 # 持仓数量(正=多头,负=空头)
entry_price = 0.0
direction = 0 # 当前方向: 0=空仓, 1=多, 2=空
trades = []
equity_curve = np.zeros(n)
peak_equity = self.initial_capital # 用于 300U 最大回撤约束
for i in range(n):
price = df.iloc[i]['price']
signal = int(df.iloc[i]['signal'])
# 计算当前权益
if position > 0:
equity_curve[i] = capital + position * (price - entry_price)
elif position < 0:
equity_curve[i] = capital + position * (price - entry_price)
else:
equity_curve[i] = capital
current_equity = equity_curve[i]
peak_equity = max(peak_equity, current_equity)
# 最大回撤硬约束:从峰值回落超过 300U 则不再开新仓(只允许平仓)
drawdown_usd = peak_equity - current_equity
can_open = drawdown_usd < self.max_drawdown_limit
# 强制止损持仓浮亏超过初始资金的20%时强制平仓
if position != 0:
unrealized = position * (price - entry_price)
if unrealized < -self.initial_capital * 0.20:
capital, trade = self._close_position(
capital, position, entry_price, price, df.index[i])
trades.append(trade)
position = 0
direction = 0
continue
# 资金不足时不开新仓;同一时间仅一个仓位(已有持仓则只能先平再开)
min_capital = self.initial_capital * 0.05
can_trade = capital > min_capital and can_open
# 跳过:信号与当前持仓方向相同
if signal == direction:
continue
# 每笔固定名义 100Uqty = position_notional_usd / price
notional = self.position_notional_usd
qty = notional / price if price > 0 else 0
# 需要换方向或平仓
if signal == 1 and direction != 1:
# 先平仓
if position != 0:
capital, trade = self._close_position(
capital, position, entry_price, price, df.index[i])
trades.append(trade)
position = 0
if not can_trade or qty <= 0:
direction = 0
continue
# 开多:固定 100U 名义,实付手续费
cost = notional * (self.commission + self.slippage)
capital -= cost
position = qty
entry_price = price
direction = 1
elif signal == 2 and direction != 2:
# 先平仓
if position != 0:
capital, trade = self._close_position(
capital, position, entry_price, price, df.index[i])
trades.append(trade)
position = 0
if not can_trade or qty <= 0:
direction = 0
continue
# 开空:固定 100U 名义
cost = notional * (self.commission + self.slippage)
capital -= cost
position = -qty
entry_price = price
direction = 2
elif signal == 0 and position != 0:
# 平仓
capital, trade = self._close_position(
capital, position, entry_price, price, df.index[i])
trades.append(trade)
position = 0
direction = 0
# 最终平仓
if position != 0:
price = df.iloc[-1]['price']
capital, trade = self._close_position(
capital, position, entry_price, price, df.index[-1])
trades.append(trade)
equity = pd.Series(equity_curve, index=df.index)
trades_df = pd.DataFrame(trades) if trades else pd.DataFrame()
metrics = self._calc_metrics(equity, trades_df)
monthly_pnl = self._monthly_pnl(equity)
return {
'equity_curve': equity,
'trades': trades_df,
'metrics': metrics,
'final_capital': capital,
'monthly_pnl': monthly_pnl,
}
def _close_position(self, capital, position, entry_price, exit_price, time):
"""平仓并返回更新后的capital和交易记录"""
if position > 0:
pnl = position * (exit_price - entry_price)
cost = position * exit_price * (self.commission + self.slippage)
trade_type = '平多'
else:
pnl = -position * (entry_price - exit_price)
cost = abs(position) * exit_price * (self.commission + self.slippage)
trade_type = '平空'
capital += pnl - cost
trade = {
'type': trade_type,
'entry': entry_price,
'exit': exit_price,
'pnl': pnl - cost,
'return_pct': (exit_price / entry_price - 1) * (1 if position > 0 else -1),
'time': time,
}
return capital, trade
def _calc_metrics(self, equity: pd.Series, trades: pd.DataFrame) -> dict:
"""计算绩效指标"""
if equity.empty:
return self._empty_metrics()
total_return = (equity.iloc[-1] / self.initial_capital) - 1
n_bars = len(equity)
if n_bars > 1:
# 按日聚合收益率计算夏普
daily_equity = equity.resample('D').last().dropna()
if len(daily_equity) > 1:
daily_returns = daily_equity.pct_change().dropna()
n_days = len(daily_returns)
annualized_return = (1 + total_return) ** (365 / max(n_days, 1)) - 1 if total_return > -1 else -1.0
sharpe = (daily_returns.mean() / daily_returns.std() * np.sqrt(365)
if daily_returns.std() > 0 else 0)
else:
annualized_return = 0
sharpe = 0
else:
annualized_return = 0
sharpe = 0
# 最大回撤(比例与绝对 USDT
cummax = equity.cummax()
drawdown = (equity - cummax) / cummax.replace(0, np.nan)
max_drawdown = drawdown.min() if not drawdown.empty else 0
max_drawdown_usd = (cummax - equity).max() if not equity.empty else 0
# 按月收益(自然月):用于目标「每月盈利 ≥ 1000U」
monthly_pnl_usd = None
months_above_1000 = 0
if not equity.empty and hasattr(equity.index, 'to_period'):
try:
monthly = equity.resample('ME').last().dropna()
if len(monthly) > 0:
monthly_pnl = monthly.diff()
monthly_pnl.iloc[0] = monthly.iloc[0] - self.initial_capital
monthly_pnl_usd = monthly_pnl
months_above_1000 = (monthly_pnl >= 1000).sum()
except Exception:
pass
# 交易统计
n_trades = len(trades)
if n_trades > 0:
wins = trades[trades['pnl'] > 0]
losses = trades[trades['pnl'] <= 0]
win_rate = len(wins) / n_trades
avg_win = wins['pnl'].mean() if len(wins) > 0 else 0
avg_loss = abs(losses['pnl'].mean()) if len(losses) > 0 else 0
profit_factor = (wins['pnl'].sum() / abs(losses['pnl'].sum())
if len(losses) > 0 and losses['pnl'].sum() != 0 else float('inf'))
else:
win_rate = 0
avg_win = 0
avg_loss = 0
profit_factor = 0
out = {
'总收益率': f'{total_return:.2%}',
'年化收益率': f'{annualized_return:.2%}',
'最大回撤': f'{max_drawdown:.2%}',
'最大回撤(U)': f'{max_drawdown_usd:.2f}',
'夏普比率': f'{sharpe:.2f}',
'总交易次数': n_trades,
'胜率': f'{win_rate:.2%}',
'平均盈利': f'{avg_win:.2f}',
'平均亏损': f'{avg_loss:.2f}',
'盈亏比': f'{profit_factor:.2f}',
'最终资金': f'{equity.iloc[-1]:.2f}',
}
out['月盈利≥1000U的月数'] = months_above_1000 if monthly_pnl_usd is not None else 0
if monthly_pnl_usd is not None and len(monthly_pnl_usd) > 0:
out['月均盈利(U)'] = f'{monthly_pnl_usd.mean():.2f}'
else:
out['月均盈利(U)'] = '0.00'
return out
def _monthly_pnl(self, equity: pd.Series):
"""按自然月汇总收益USDT首月为当月权益 - 初始资金"""
if equity.empty or not hasattr(equity.index, 'to_period'):
return None
try:
monthly = equity.resample('ME').last().dropna()
if len(monthly) == 0:
return None
pnl = monthly.diff()
pnl.iloc[0] = monthly.iloc[0] - self.initial_capital
return pnl
except Exception:
return None
def _empty_result(self):
return {
'equity_curve': pd.Series(dtype=float),
'trades': pd.DataFrame(),
'metrics': self._empty_metrics(),
'final_capital': self.initial_capital,
'monthly_pnl': None,
}
def _empty_metrics(self):
return {
'总收益率': '0.00%', '年化收益率': '0.00%', '最大回撤': '0.00%',
'最大回撤(U)': '0.00', '夏普比率': '0.00', '总交易次数': 0, '胜率': '0.00%',
'平均盈利': '0.00', '平均亏损': '0.00', '盈亏比': '0.00',
'最终资金': f'{self.initial_capital:.2f}', '月盈利≥1000U的月数': 0,
'月均盈利(U)': '0.00',
}
def print_metrics(metrics: dict, title: str = "回测结果"):
"""打印绩效指标"""
logger.info(f"\n{'='*50}")
logger.info(f" {title}")
logger.info(f"{'='*50}")
for k, v in metrics.items():
logger.info(f" {k:>12}: {v}")
logger.info(f"{'='*50}")

222
strategy/compare.py Normal file
View File

@@ -0,0 +1,222 @@
"""
方案对比评估 — 方案A(统计筛选) vs 方案B(AI模型) 并排对比 + 报告输出
"""
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg') # 非交互式后端
import matplotlib.pyplot as plt
from matplotlib import font_manager
from pathlib import Path
from loguru import logger
# 设置中文字体
_zh_font = None
for fname in ['PingFang SC', 'Heiti SC', 'STHeiti', 'SimHei', 'Microsoft YaHei', 'WenQuanYi Micro Hei']:
try:
_zh_font = font_manager.FontProperties(family=fname)
# 验证字体存在
font_manager.findfont(_zh_font, fallback_to_default=False)
plt.rcParams['font.sans-serif'] = [fname, 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
break
except Exception:
_zh_font = None
continue
from .config import PRIMARY_PERIOD, PROJECT_ROOT
from .stat_strategy import StatStrategy
from .ai_strategy import AIStrategy
from .backtest import print_metrics
REPORT_DIR = PROJECT_ROOT / 'reports'
def compare_strategies(period: int = None, start_date: str = None, end_date: str = None,
save_plot: bool = True) -> dict:
"""
运行两种方案并对比
:return: {'stat': result_a, 'lgb': result_b, 'xgb': result_c, 'comparison': DataFrame}
"""
if period is None:
period = PRIMARY_PERIOD
logger.info("=" * 70)
logger.info(" 开始策略对比评估")
logger.info("=" * 70)
results = {}
# 方案A统计筛选
logger.info("\n>>> 运行方案A: 统计筛选策略")
stat = StatStrategy()
results['stat'] = stat.run(period, start_date, end_date)
# 方案B-1LightGBM
logger.info("\n>>> 运行方案B-1: LightGBM AI策略")
lgb_strategy = AIStrategy(model_type='lightgbm')
results['lgb'] = lgb_strategy.run(period, start_date, end_date)
# 方案B-2XGBoost
logger.info("\n>>> 运行方案B-2: XGBoost AI策略")
xgb_strategy = AIStrategy(model_type='xgboost')
results['xgb'] = xgb_strategy.run(period, start_date, end_date)
# 对比表格
comparison = _build_comparison_table(results)
results['comparison'] = comparison
# 打印对比
logger.info("\n" + "=" * 70)
logger.info(" 策略对比总结")
logger.info("=" * 70)
logger.info(f"\n{comparison.to_string()}")
# 每月盈利U
_log_monthly_pnl(results)
# 保存图表
if save_plot:
_save_equity_plot(results)
return results
def _build_comparison_table(results: dict) -> pd.DataFrame:
"""构建对比表格"""
rows = {}
name_map = {
'stat': '方案A: 统计筛选',
'lgb': '方案B-1: LightGBM',
'xgb': '方案B-2: XGBoost',
}
for key, name in name_map.items():
if key in results and 'metrics' in results[key]:
rows[name] = results[key]['metrics']
if not rows:
return pd.DataFrame()
df = pd.DataFrame(rows).T
return df
def _log_monthly_pnl(results: dict):
"""打印各策略每月盈利USDT"""
name_map = {
'stat': '方案A',
'lgb': 'LightGBM',
'xgb': 'XGBoost',
}
cols = []
for key, name in name_map.items():
if key not in results or results[key].get('monthly_pnl') is None:
continue
s = results[key]['monthly_pnl']
if s is None or s.empty:
continue
s = s.astype(float).round(2)
s.name = name
cols.append(s)
if not cols:
return
monthly_df = pd.concat(cols, axis=1)
monthly_df = monthly_df.fillna(0)
logger.info("\n" + "-" * 70)
logger.info(" 每月盈利 (USDT)")
logger.info("-" * 70)
logger.info(f"\n{monthly_df.to_string()}")
logger.info("-" * 70)
def _save_equity_plot(results: dict):
"""保存权益曲线对比图"""
REPORT_DIR.mkdir(parents=True, exist_ok=True)
fig, axes = plt.subplots(2, 1, figsize=(14, 10))
# 上图:权益曲线
ax1 = axes[0]
name_map = {
'stat': '方案A: 统计筛选',
'lgb': '方案B-1: LightGBM',
'xgb': '方案B-2: XGBoost',
}
colors = {'stat': '#2196F3', 'lgb': '#4CAF50', 'xgb': '#FF9800'}
for key, name in name_map.items():
if key in results and 'equity_curve' in results[key]:
eq = results[key]['equity_curve']
if not eq.empty:
ax1.plot(eq.index, eq.values, label=name, color=colors.get(key, 'gray'), linewidth=1)
ax1.set_title('权益曲线对比', fontsize=14)
ax1.set_ylabel('资金 (USDT)')
ax1.legend()
ax1.grid(True, alpha=0.3)
# 下图:回撤曲线
ax2 = axes[1]
for key, name in name_map.items():
if key in results and 'equity_curve' in results[key]:
eq = results[key]['equity_curve']
if not eq.empty:
cummax = eq.cummax()
drawdown = (eq - cummax) / cummax * 100
ax2.fill_between(drawdown.index, drawdown.values, 0,
alpha=0.3, label=name, color=colors.get(key, 'gray'))
ax2.set_title('回撤对比', fontsize=14)
ax2.set_ylabel('回撤 (%)')
ax2.legend()
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plot_path = REPORT_DIR / 'strategy_comparison.png'
plt.savefig(str(plot_path), dpi=150)
plt.close()
logger.info(f"对比图表已保存: {plot_path}")
def run_full_comparison(period: int = None, start_date: str = None, end_date: str = None, save_plot: bool = True):
"""完整对比入口(可直接调用)"""
results = compare_strategies(period, start_date, end_date, save_plot=save_plot)
# 推荐最优方案
best_key = None
best_return = -float('inf')
for key in ['stat', 'lgb', 'xgb']:
if key in results and 'metrics' in results[key]:
ret_str = results[key]['metrics'].get('总收益率', '0%')
ret_val = float(ret_str.strip('%')) / 100
if ret_val > best_return:
best_return = ret_val
best_key = key
name_map = {'stat': '方案A: 统计筛选', 'lgb': '方案B-1: LightGBM', 'xgb': '方案B-2: XGBoost'}
if best_key:
logger.info(f"\n{'='*50}")
logger.info(f" 推荐方案: {name_map.get(best_key, best_key)}")
logger.info(f" 总收益率: {best_return:.2%}")
logger.info(f"{'='*50}")
return results
if __name__ == '__main__':
import argparse
p = argparse.ArgumentParser(description='运行策略对比方案A(统计) vs 方案B(LightGBM/XGBoost)')
p.add_argument('--period', type=int, default=None, help='K线周期分钟')
p.add_argument('--start', type=str, default=None, help='开始日期,如 2024-01-01')
p.add_argument('--end', type=str, default=None, help='结束日期,如 2024-12-31')
p.add_argument('--no-plot', action='store_true', help='不保存权益/回撤图')
args = p.parse_args()
run_full_comparison(
period=args.period,
start_date=args.start,
end_date=args.end,
save_plot=not args.no_plot,
)

151
strategy/config.py Normal file
View File

@@ -0,0 +1,151 @@
"""
全局配置 — 交易参数、指标参数、模型参数
"""
from pathlib import Path
# ============ 路径 ============
PROJECT_ROOT = Path(__file__).parent.parent
DB_PATH = PROJECT_ROOT / 'models' / 'database.db'
# ============ 交易参数 ============
TRADE_CONFIG = {
'symbol': 'ETHUSDT',
'commission': 0.0006, # 名义手续费 0.06%
'commission_rebate': 0.90, # 90% 返佣(次日 8 点结算),实付 = commission * (1 - rebate)
'slippage': 0.0001, # 滑点 0.01%
'initial_capital': 10000, # 本金 USDT
'leverage': 100, # 杠杆倍数
'position_notional_usd': 500, # 每笔名义 500U开 100 倍),目标月均收益约 500U
'max_drawdown_limit': 300, # 最大回撤硬约束:权益从峰值回落超过 300U 则不再开新仓
# 兼容旧逻辑(若用比例算仓则用此项)
'position_size': 0.95,
}
# ============ K线周期 ============
KLINE_PERIODS = {
1: '1m',
3: '3m',
5: '5m',
15: '15m',
30: '30m',
60: '1h',
}
# 主周期(用于生成信号)
PRIMARY_PERIOD = 15 # 15分钟
# 辅助周期(用于多周期融合特征)
AUX_PERIODS = [5, 60]
# ============ 指标参数 ============
INDICATOR_PARAMS = {
# 趋势类
'sma_windows': [5, 10, 20, 50, 200],
'ema_windows': [12, 26],
'macd_fast': 12,
'macd_slow': 26,
'macd_signal': 9,
'adx_window': 14,
'ichimoku_conversion': 9,
'ichimoku_base': 26,
'ichimoku_span_b': 52,
'trix_window': 15,
'aroon_window': 25,
'cci_window': 20,
'dpo_window': 20,
'kst_roc1': 10,
'kst_roc2': 15,
'kst_roc3': 20,
'kst_roc4': 30,
'vortex_window': 14,
# 动量类
'rsi_window': 14,
'stoch_window': 14,
'stoch_smooth': 3,
'williams_window': 14,
'roc_window': 12,
'mfi_window': 14,
'tsi_slow': 25,
'tsi_fast': 13,
'uo_short': 7,
'uo_medium': 14,
'uo_long': 28,
'ao_short': 5,
'ao_long': 34,
'kama_window': 10,
'ppo_slow': 26,
'ppo_fast': 12,
'stoch_rsi_window': 14,
'stoch_rsi_smooth': 3,
# 波动率类
'bb_window': 20,
'bb_std': 2,
'atr_window': 14,
'kc_window': 20,
'dc_window': 20,
# 成交量类部分指标需要volumeK线数据可能无volume则跳过
'obv_enabled': True,
'cmf_window': 20,
'emv_window': 14,
'fi_window': 13,
}
# ============ 特征工程参数 ============
FEATURE_CONFIG = {
'label_forward_periods': 10, # 未来N根K线用于生成标签
'label_threshold': 0.002, # 涨跌阈值0.2%以内算震荡)
'lookback_lags': [1, 3, 5], # 滞后特征的lag值
'normalize': True, # 是否标准化
}
# ============ 模型参数 ============
MODEL_CONFIG = {
'walk_forward_train_size': 20000, # Walk-Forward 训练窗口大小
'walk_forward_test_size': 2000, # Walk-Forward 测试窗口大小
'walk_forward_step': 2000, # 滚动步长
'warm_up_months': 12, # 指定回测区间时向前加载的月数,使区间首月即有预测
'lightgbm': {
'n_estimators': 300,
'max_depth': 4,
'learning_rate': 0.03,
'num_leaves': 15,
'min_child_samples': 50,
'subsample': 0.7,
'colsample_bytree': 0.6,
'reg_alpha': 1.0,
'reg_lambda': 1.0,
'objective': 'multiclass',
'num_class': 3,
'verbose': -1,
},
'xgboost': {
'n_estimators': 300,
'max_depth': 4,
'learning_rate': 0.03,
'subsample': 0.7,
'colsample_bytree': 0.6,
'reg_alpha': 1.0,
'reg_lambda': 1.0,
'objective': 'multi:softprob',
'num_class': 3,
'verbosity': 0,
},
}
# ============ 统计筛选参数 ============
STAT_CONFIG = {
'top_n_features': 15, # 筛选Top N个指标
'correlation_threshold': 0.9, # 去除高相关特征的阈值
'grid_search_cv': 3, # 网格搜索交叉验证折数
}
# ============ 信号标签映射 ============
SIGNAL_MAP = {
0: '观望',
1: '做多',
2: '做空',
}

119
strategy/data_loader.py Normal file
View File

@@ -0,0 +1,119 @@
"""
数据加载器 — 从SQLite加载K线数据为pandas DataFrame
"""
import sqlite3
import pandas as pd
from loguru import logger
from .config import DB_PATH, KLINE_PERIODS
def load_kline(period: int = 15, start_date: str = None, end_date: str = None) -> pd.DataFrame:
"""
加载指定周期的K线数据
:param period: K线周期分钟如 1, 3, 5, 15, 30, 60
:param start_date: 起始日期 'YYYY-MM-DD'(可选)
:param end_date: 结束日期 'YYYY-MM-DD'(可选)
:return: DataFrame列: timestamp, open, high, low, close
"""
suffix = KLINE_PERIODS.get(period)
if suffix is None:
raise ValueError(f"不支持的周期: {period},可选: {list(KLINE_PERIODS.keys())}")
table_name = f'bitmart_eth_{suffix}'
conn = sqlite3.connect(str(DB_PATH))
query = f"SELECT id as timestamp, open, high, low, close FROM {table_name} ORDER BY id"
df = pd.read_sql_query(query, conn)
conn.close()
if df.empty:
logger.warning(f"[{suffix}] 表中无数据")
return df
# id 是毫秒时间戳,转为 datetime 索引
df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms')
df.set_index('datetime', inplace=True)
# 按日期过滤
if start_date:
df = df[df.index >= start_date]
if end_date:
df = df[df.index <= end_date]
logger.info(f"[{suffix}] 加载 {len(df)} 条K线 | {df.index[0]} ~ {df.index[-1]}")
return df
def load_multi_period(periods: list = None, start_date: str = None, end_date: str = None) -> dict:
"""
加载多个周期的K线数据
:param periods: 周期列表,如 [5, 15, 60],默认全部
:param start_date: 起始日期
:param end_date: 结束日期
:return: {period: DataFrame} 字典
"""
if periods is None:
periods = list(KLINE_PERIODS.keys())
result = {}
for p in periods:
try:
df = load_kline(p, start_date, end_date)
if not df.empty:
result[p] = df
except Exception as e:
logger.error(f"加载 {p}分钟 K线失败: {e}")
return result
def load_trades(start_date: str = None, end_date: str = None, limit: int = None) -> pd.DataFrame:
"""
加载原始成交记录
:return: DataFrame列: id, timestamp, price, volume, side
"""
conn = sqlite3.connect(str(DB_PATH))
query = "SELECT id, timestamp, price, volume, side FROM bitmart_eth_trades ORDER BY timestamp"
df = pd.read_sql_query(query, conn)
conn.close()
if df.empty:
logger.warning("成交记录表中无数据")
return df
df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms')
df.set_index('datetime', inplace=True)
if start_date:
df = df[df.index >= start_date]
if end_date:
df = df[df.index <= end_date]
if limit:
df = df.head(limit)
logger.info(f"加载 {len(df)} 条成交记录")
return df
def get_available_tables() -> list:
"""列出数据库中所有可用的表"""
conn = sqlite3.connect(str(DB_PATH))
cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
tables = [row[0] for row in cursor.fetchall()]
conn.close()
return tables
def get_table_stats() -> dict:
"""获取各表的数据统计"""
conn = sqlite3.connect(str(DB_PATH))
tables = get_available_tables()
stats = {}
for table in tables:
try:
count = pd.read_sql_query(f"SELECT COUNT(*) as cnt FROM {table}", conn).iloc[0]['cnt']
stats[table] = count
except Exception:
stats[table] = 0
conn.close()
return stats

204
strategy/feature_engine.py Normal file
View File

@@ -0,0 +1,204 @@
"""
特征工程 — 标准化、多周期融合、滞后特征、标签生成
"""
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from loguru import logger
from .config import FEATURE_CONFIG as FC, PRIMARY_PERIOD, AUX_PERIODS
from .data_loader import load_kline
from .indicators import compute_all_indicators, get_indicator_names
def build_features(primary_df: pd.DataFrame,
aux_dfs: dict = None,
has_volume: bool = False) -> pd.DataFrame:
"""
构建完整特征矩阵
:param primary_df: 主周期K线 DataFrame
:param aux_dfs: {period: DataFrame} 辅助周期数据(可选)
:param has_volume: 是否有成交量
:return: 特征矩阵 DataFrame
"""
# 1. 主周期指标
logger.info("计算主周期指标...")
df = compute_all_indicators(primary_df, has_volume=has_volume)
# 2. 滞后特征
logger.info("生成滞后特征...")
indicator_cols = get_indicator_names(has_volume)
existing_cols = [col for col in indicator_cols if col in df.columns]
lag_frames = []
for lag in FC['lookback_lags']:
lagged = df[existing_cols].shift(lag).add_suffix(f'_lag{lag}')
lag_frames.append(lagged)
if lag_frames:
df = pd.concat([df] + lag_frames, axis=1)
# 3. 多周期融合
if aux_dfs:
aux_frames = []
for period, aux_df in aux_dfs.items():
logger.info(f"融合 {period}分钟 辅助周期特征...")
aux_with_ind = compute_all_indicators(aux_df, has_volume=has_volume)
key_indicators = ['rsi', 'macd', 'adx', 'bb_pband', 'atr', 'cci']
for ind in key_indicators:
if ind in aux_with_ind.columns:
aligned = aux_with_ind[ind].reindex(df.index, method='ffill')
aux_frames.append(aligned.rename(f'{ind}_{period}m'))
if aux_frames:
df = pd.concat([df] + aux_frames, axis=1)
# 4. 去除全NaN列
before_cols = len(df.columns)
df.dropna(axis=1, how='all', inplace=True)
after_cols = len(df.columns)
if before_cols != after_cols:
logger.info(f"移除 {before_cols - after_cols} 个全NaN列")
logger.info(f"特征矩阵: {df.shape[0]} 行 x {df.shape[1]}")
return df
def generate_labels(df: pd.DataFrame, forward_periods: int = None,
threshold: float = None) -> pd.Series:
"""
生成交易标签
:param df: 包含 close 列的 DataFrame
:param forward_periods: 未来N根K线
:param threshold: 涨跌阈值
:return: Series值为 0=观望, 1=做多, 2=做空
"""
if forward_periods is None:
forward_periods = FC['label_forward_periods']
if threshold is None:
threshold = FC['label_threshold']
future_return = df['close'].shift(-forward_periods) / df['close'] - 1
labels = pd.Series(0, index=df.index, name='label') # 默认观望
labels[future_return > threshold] = 1 # 做多
labels[future_return < -threshold] = 2 # 做空
# 最后 forward_periods 行无法计算标签设为NaN
labels.iloc[-forward_periods:] = np.nan
dist = labels.value_counts().to_dict()
logger.info(f"标签分布: 观望={dist.get(0, 0)}, 做多={dist.get(1, 0)}, 做空={dist.get(2, 0)}")
return labels
def prepare_dataset(period: int = None, start_date: str = None, end_date: str = None,
has_volume: bool = False) -> tuple:
"""
一键准备训练数据集
:return: (X, y, feature_names) — 已去NaN、已标准化
"""
if period is None:
period = PRIMARY_PERIOD
# 加载主周期
primary_df = load_kline(period, start_date, end_date)
if primary_df.empty:
raise ValueError(f"{period}分钟 K线数据为空")
# 加载辅助周期
aux_dfs = {}
for aux_p in AUX_PERIODS:
try:
aux_df = load_kline(aux_p, start_date, end_date)
if not aux_df.empty:
aux_dfs[aux_p] = aux_df
except Exception as e:
logger.warning(f"加载 {aux_p}分钟 辅助数据失败: {e}")
# 构建特征
df = build_features(primary_df, aux_dfs, has_volume=has_volume)
# 生成标签
labels = generate_labels(df)
df = df.copy()
df['label'] = labels
# 去除NaN行
df.dropna(inplace=True)
logger.info(f"去NaN后剩余 {len(df)}")
# 分离 X, y
exclude_cols = ['open', 'high', 'low', 'close', 'timestamp', 'label']
if 'volume' in df.columns:
exclude_cols.append('volume')
# 排除价格级别特征(会泄露绝对价格信息导致过拟合)
price_level_patterns = [
'sma_', 'ema_', 'bb_upper', 'bb_mid', 'bb_lower',
'kc_upper', 'kc_mid', 'kc_lower', 'dc_upper', 'dc_mid', 'dc_lower',
'ichimoku_conv', 'ichimoku_base', 'ichimoku_a', 'ichimoku_b',
'kama', 'vwap', 'obv', 'adi', 'vpt', 'nvi',
'momentum_3', 'momentum_5',
]
feature_cols = []
for c in df.columns:
if c in exclude_cols:
continue
base_name = c.split('_lag')[0]
# 去掉周期后缀
for suffix in ['_5m', '_60m', '_3m', '_15m', '_30m', '_1m']:
if base_name.endswith(suffix):
base_name = base_name[:-len(suffix)]
break
if any(base_name.startswith(p) or base_name == p.rstrip('_') for p in price_level_patterns):
continue
feature_cols.append(c)
logger.info(f"排除价格级别特征后剩余 {len(feature_cols)} 个特征")
X = df[feature_cols].copy()
y = df['label'].astype(int)
# 标准化
scaler = None
if FC['normalize']:
scaler = StandardScaler()
X = pd.DataFrame(scaler.fit_transform(X), columns=feature_cols, index=X.index)
logger.info("特征已标准化")
logger.info(f"最终数据集: X={X.shape}, y={y.shape}, 特征数={len(feature_cols)}")
return X, y, feature_cols, scaler
def get_latest_feature_row(period: int = None, feature_cols: list = None,
start_date: str = None, end_date: str = None) -> pd.DataFrame:
"""
构建特征并返回最后一行的特征矩阵(用于实盘/模拟盘预测)。
需传入已保存的 feature_cols保证与训练时一致。
:return: 1 行 DataFrame列为 feature_cols若缺列则返回空 DataFrame
"""
if period is None:
period = PRIMARY_PERIOD
if not feature_cols:
return pd.DataFrame()
primary_df = load_kline(period, start_date, end_date)
if primary_df.empty:
return pd.DataFrame()
aux_dfs = {}
for aux_p in AUX_PERIODS:
try:
aux_df = load_kline(aux_p, start_date, end_date)
if not aux_df.empty:
aux_dfs[aux_p] = aux_df
except Exception:
pass
df = build_features(primary_df, aux_dfs, has_volume=False)
labels = generate_labels(df)
df = df.copy()
df['label'] = labels
df.dropna(inplace=True)
missing = [c for c in feature_cols if c not in df.columns]
if missing:
logger.warning(f"get_latest_feature_row: 缺少特征列 {missing[:5]}{'...' if len(missing) > 5 else ''}")
return pd.DataFrame()
return df[feature_cols].iloc[-1:].copy()

278
strategy/indicators.py Normal file
View File

@@ -0,0 +1,278 @@
"""
52个技术指标计算引擎 — 基于 ta 库
覆盖趋势、动量、波动率、成交量、自定义衍生特征五大类
"""
import pandas as pd
import numpy as np
import ta
from .config import INDICATOR_PARAMS as P
def compute_all_indicators(df: pd.DataFrame, has_volume: bool = False) -> pd.DataFrame:
"""
计算全部52个技术指标返回拼接后的DataFrame
:param df: 必须包含 open, high, low, close 列;可选 volume 列
:param has_volume: 是否有成交量数据
:return: 原始列 + 52个指标列
"""
out = df.copy()
o, h, l, c = out['open'], out['high'], out['low'], out['close']
v = out['volume'] if has_volume and 'volume' in out.columns else None
# ========== 趋势类 (14) ==========
out = _add_trend(out, o, h, l, c)
# ========== 动量类 (12) ==========
out = _add_momentum(out, h, l, c, v)
# ========== 波动率类 (8) ==========
out = _add_volatility(out, h, l, c)
# ========== 成交量类 (8) ==========
if has_volume and v is not None:
out = _add_volume(out, h, l, c, v)
# ========== 自定义衍生特征 (10) ==========
out = _add_custom(out, o, h, l, c)
return out
def _add_trend(out, o, h, l, c):
"""趋势类指标 (14个特征)"""
# SMA (5个)
for w in P['sma_windows']:
out[f'sma_{w}'] = ta.trend.sma_indicator(c, window=w)
# EMA (2个)
for w in P['ema_windows']:
out[f'ema_{w}'] = ta.trend.ema_indicator(c, window=w)
# MACD (3个)
macd = ta.trend.MACD(c, window_slow=P['macd_slow'], window_fast=P['macd_fast'],
window_sign=P['macd_signal'])
out['macd'] = macd.macd()
out['macd_signal'] = macd.macd_signal()
out['macd_hist'] = macd.macd_diff()
# ADX + DI (3个)
adx = ta.trend.ADXIndicator(h, l, c, window=P['adx_window'])
out['adx'] = adx.adx()
out['di_plus'] = adx.adx_pos()
out['di_minus'] = adx.adx_neg()
# Ichimoku (4个)
ichi = ta.trend.IchimokuIndicator(h, l,
window1=P['ichimoku_conversion'],
window2=P['ichimoku_base'],
window3=P['ichimoku_span_b'])
out['ichimoku_conv'] = ichi.ichimoku_conversion_line()
out['ichimoku_base'] = ichi.ichimoku_base_line()
out['ichimoku_a'] = ichi.ichimoku_a()
out['ichimoku_b'] = ichi.ichimoku_b()
# TRIX
out['trix'] = ta.trend.trix(c, window=P['trix_window'])
# Aroon (2个)
aroon = ta.trend.AroonIndicator(h, l, window=P['aroon_window'])
out['aroon_up'] = aroon.aroon_up()
out['aroon_down'] = aroon.aroon_down()
# CCI
out['cci'] = ta.trend.cci(h, l, c, window=P['cci_window'])
# DPO
out['dpo'] = ta.trend.dpo(c, window=P['dpo_window'])
# KST
kst = ta.trend.KSTIndicator(c, roc1=P['kst_roc1'], roc2=P['kst_roc2'],
roc3=P['kst_roc3'], roc4=P['kst_roc4'])
out['kst'] = kst.kst()
# Vortex (2个)
vortex = ta.trend.VortexIndicator(h, l, c, window=P['vortex_window'])
out['vortex_pos'] = vortex.vortex_indicator_pos()
out['vortex_neg'] = vortex.vortex_indicator_neg()
return out
def _add_momentum(out, h, l, c, v):
"""动量类指标 (12个特征)"""
# RSI
out['rsi'] = ta.momentum.rsi(c, window=P['rsi_window'])
# Stochastic %K / %D
stoch = ta.momentum.StochasticOscillator(h, l, c,
window=P['stoch_window'],
smooth_window=P['stoch_smooth'])
out['stoch_k'] = stoch.stoch()
out['stoch_d'] = stoch.stoch_signal()
# Williams %R
out['williams_r'] = ta.momentum.williams_r(h, l, c, lbp=P['williams_window'])
# ROC
out['roc'] = ta.momentum.roc(c, window=P['roc_window'])
# MFI需要volume
if v is not None:
out['mfi'] = ta.volume.money_flow_index(h, l, c, v, window=P['mfi_window'])
# TSI
out['tsi'] = ta.momentum.tsi(c, window_slow=P['tsi_slow'], window_fast=P['tsi_fast'])
# Ultimate Oscillator
out['uo'] = ta.momentum.ultimate_oscillator(h, l, c,
window1=P['uo_short'],
window2=P['uo_medium'],
window3=P['uo_long'])
# Awesome Oscillator
out['ao'] = ta.momentum.awesome_oscillator(h, l,
window1=P['ao_short'],
window2=P['ao_long'])
# KAMA
out['kama'] = ta.momentum.kama(c, window=P['kama_window'])
# PPO
out['ppo'] = ta.momentum.ppo(c, window_slow=P['ppo_slow'], window_fast=P['ppo_fast'])
# Stochastic RSI %K / %D
stoch_rsi = ta.momentum.StochRSIIndicator(c,
window=P['stoch_rsi_window'],
smooth1=P['stoch_rsi_smooth'],
smooth2=P['stoch_rsi_smooth'])
out['stoch_rsi_k'] = stoch_rsi.stochrsi_k()
out['stoch_rsi_d'] = stoch_rsi.stochrsi_d()
return out
def _add_volatility(out, h, l, c):
"""波动率类指标 (8个特征 — 含子指标共12列)"""
# Bollinger Bands (5个)
bb = ta.volatility.BollingerBands(c, window=P['bb_window'], window_dev=P['bb_std'])
out['bb_upper'] = bb.bollinger_hband()
out['bb_mid'] = bb.bollinger_mavg()
out['bb_lower'] = bb.bollinger_lband()
out['bb_width'] = bb.bollinger_wband()
out['bb_pband'] = bb.bollinger_pband()
# ATR
out['atr'] = ta.volatility.average_true_range(h, l, c, window=P['atr_window'])
# Keltner Channel (3个)
kc = ta.volatility.KeltnerChannel(h, l, c, window=P['kc_window'])
out['kc_upper'] = kc.keltner_channel_hband()
out['kc_mid'] = kc.keltner_channel_mband()
out['kc_lower'] = kc.keltner_channel_lband()
# Donchian Channel (3个)
dc = ta.volatility.DonchianChannel(h, l, c, window=P['dc_window'])
out['dc_upper'] = dc.donchian_channel_hband()
out['dc_mid'] = dc.donchian_channel_mband()
out['dc_lower'] = dc.donchian_channel_lband()
return out
def _add_volume(out, h, l, c, v):
"""成交量类指标 (8个特征)"""
# OBV
out['obv'] = ta.volume.on_balance_volume(c, v)
# VWAP
out['vwap'] = ta.volume.volume_weighted_average_price(h, l, c, v)
# CMF
out['cmf'] = ta.volume.chaikin_money_flow(h, l, c, v, window=P['cmf_window'])
# ADI (Accumulation/Distribution Index)
out['adi'] = ta.volume.acc_dist_index(h, l, c, v)
# EMV (Ease of Movement)
out['emv'] = ta.volume.ease_of_movement(h, l, v, window=P['emv_window'])
# Force Index
out['fi'] = ta.volume.force_index(c, v, window=P['fi_window'])
# VPT (Volume Price Trend)
out['vpt'] = ta.volume.volume_price_trend(c, v)
# NVI (Negative Volume Index)
out['nvi'] = ta.volume.negative_volume_index(c, v)
return out
def _add_custom(out, o, h, l, c):
"""自定义衍生特征 (10个)"""
# 价格变化率
out['price_change_pct'] = c.pct_change()
# 振幅High-Low范围 / Close
out['high_low_range'] = (h - l) / c
# 实体比率(|Close-Open| / (High-Low)
body = (c - o).abs()
hl_range = (h - l).replace(0, np.nan)
out['body_ratio'] = body / hl_range
# 上影线比率
upper_shadow = h - pd.concat([o, c], axis=1).max(axis=1)
out['upper_shadow'] = upper_shadow / hl_range
# 下影线比率
lower_shadow = pd.concat([o, c], axis=1).min(axis=1) - l
out['lower_shadow'] = lower_shadow / hl_range
# 波动率比率ATR / Close 的滚动比值)
atr = ta.volatility.average_true_range(h, l, c, window=14)
out['volatility_ratio'] = atr / c
# Close / SMA20 比率
sma20 = ta.trend.sma_indicator(c, window=20)
out['close_sma20_ratio'] = c / sma20.replace(0, np.nan)
# Close / EMA12 比率
ema12 = ta.trend.ema_indicator(c, window=12)
out['close_ema12_ratio'] = c / ema12.replace(0, np.nan)
# 动量 3周期
out['momentum_3'] = c - c.shift(3)
# 动量 5周期
out['momentum_5'] = c - c.shift(5)
return out
def get_indicator_names(has_volume: bool = False) -> list:
"""返回所有指标列名"""
names = []
# 趋势
for w in P['sma_windows']:
names.append(f'sma_{w}')
for w in P['ema_windows']:
names.append(f'ema_{w}')
names += ['macd', 'macd_signal', 'macd_hist', 'adx', 'di_plus', 'di_minus']
names += ['ichimoku_conv', 'ichimoku_base', 'ichimoku_a', 'ichimoku_b']
names += ['trix', 'aroon_up', 'aroon_down', 'cci', 'dpo', 'kst', 'vortex_pos', 'vortex_neg']
# 动量
names += ['rsi', 'stoch_k', 'stoch_d', 'williams_r', 'roc', 'tsi', 'uo', 'ao', 'kama', 'ppo',
'stoch_rsi_k', 'stoch_rsi_d']
if has_volume:
names.append('mfi')
# 波动率
names += ['bb_upper', 'bb_mid', 'bb_lower', 'bb_width', 'bb_pband', 'atr',
'kc_upper', 'kc_mid', 'kc_lower', 'dc_upper', 'dc_mid', 'dc_lower']
# 成交量
if has_volume:
names += ['obv', 'vwap', 'cmf', 'adi', 'emv', 'fi', 'vpt', 'nvi']
# 自定义
names += ['price_change_pct', 'high_low_range', 'body_ratio', 'upper_shadow', 'lower_shadow',
'volatility_ratio', 'close_sma20_ratio', 'close_ema12_ratio', 'momentum_3', 'momentum_5']
return names

256
strategy/stat_strategy.py Normal file
View File

@@ -0,0 +1,256 @@
"""
方案A统计筛选 + 规则组合策略
1. 从52个指标中用统计方法筛选最有效的指标
2. 用经典规则组合生成交易信号
3. 网格搜索优化参数
"""
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import mutual_info_classif
from loguru import logger
from .config import STAT_CONFIG as SC, PRIMARY_PERIOD, AUX_PERIODS
from .feature_engine import prepare_dataset
from .backtest import BacktestEngine, print_metrics
class StatStrategy:
"""统计筛选策略"""
def __init__(self):
self.top_features = []
self.feature_scores = {}
self.best_params = {}
def select_features(self, X: pd.DataFrame, y: pd.Series) -> list:
"""
用多种统计方法筛选有效指标
:return: Top N 特征名列表
"""
logger.info("=" * 50)
logger.info("开始特征筛选...")
scores = {}
# 1. 皮尔逊相关系数
logger.info("计算皮尔逊相关系数...")
corr_scores = X.corrwith(y).abs().fillna(0)
for col in X.columns:
scores[col] = scores.get(col, 0) + corr_scores.get(col, 0)
# 2. 互信息
logger.info("计算互信息...")
mi = mutual_info_classif(X.fillna(0), y, random_state=42)
mi_series = pd.Series(mi, index=X.columns)
mi_norm = mi_series / mi_series.max() if mi_series.max() > 0 else mi_series
for col in X.columns:
scores[col] = scores.get(col, 0) + mi_norm.get(col, 0)
# 3. 随机森林特征重要性
logger.info("训练随机森林评估特征重要性...")
rf = RandomForestClassifier(n_estimators=200, max_depth=8, random_state=42, n_jobs=-1)
rf.fit(X.fillna(0), y)
rf_imp = pd.Series(rf.feature_importances_, index=X.columns)
rf_norm = rf_imp / rf_imp.max() if rf_imp.max() > 0 else rf_imp
for col in X.columns:
scores[col] = scores.get(col, 0) + rf_norm.get(col, 0)
# 综合排名
score_series = pd.Series(scores).sort_values(ascending=False)
self.feature_scores = score_series.to_dict()
# 去除高相关特征
top_candidates = score_series.head(SC['top_n_features'] * 2).index.tolist()
selected = self._remove_correlated(X[top_candidates], SC['correlation_threshold'])
self.top_features = selected[:SC['top_n_features']]
logger.info(f"筛选出 Top {len(self.top_features)} 特征:")
for i, feat in enumerate(self.top_features):
logger.info(f" {i+1}. {feat} (综合得分: {score_series[feat]:.4f})")
return self.top_features
def _remove_correlated(self, X: pd.DataFrame, threshold: float) -> list:
"""去除高度相关的冗余特征"""
corr_matrix = X.corr().abs()
selected = list(X.columns)
to_remove = set()
for i in range(len(selected)):
if selected[i] in to_remove:
continue
for j in range(i + 1, len(selected)):
if selected[j] in to_remove:
continue
if corr_matrix.loc[selected[i], selected[j]] > threshold:
to_remove.add(selected[j])
result = [c for c in selected if c not in to_remove]
if to_remove:
logger.info(f"移除 {len(to_remove)} 个高相关冗余特征")
return result
def generate_signals(self, df: pd.DataFrame) -> pd.Series:
"""
基于筛选出的指标,用规则组合生成交易信号
:param df: 包含指标列的 DataFrame原始值非标准化
:return: 信号 Series (0=观望, 1=做多, 2=做空)
"""
signals = pd.Series(0, index=df.index)
long_score = pd.Series(0.0, index=df.index)
short_score = pd.Series(0.0, index=df.index)
matched = 0
for feat in self.top_features:
if feat not in df.columns:
continue
col = df[feat]
base = feat.split('_lag')[0] # 去掉 _lagN 后缀
# 去掉辅助周期后缀 _5m / _60m
for suffix in ['_5m', '_60m', '_3m', '_15m', '_30m', '_1m']:
if base.endswith(suffix):
base = base[:-len(suffix)]
break
if 'rsi' in base:
long_score += (col < 35).astype(float)
short_score += (col > 65).astype(float)
matched += 1
elif base == 'macd_hist':
long_score += (col > 0).astype(float)
short_score += (col < 0).astype(float)
matched += 1
elif base == 'macd':
long_score += (col > 0).astype(float)
short_score += (col < 0).astype(float)
matched += 1
elif 'bb_pband' in base:
long_score += (col < 0.2).astype(float)
short_score += (col > 0.8).astype(float)
matched += 1
elif 'adx' in base:
long_score += (col > 25).astype(float)
short_score += (col > 25).astype(float)
matched += 1
elif 'cci' in base:
long_score += (col < -100).astype(float)
short_score += (col > 100).astype(float)
matched += 1
elif 'stoch_k' in base or 'stoch_rsi_k' in base:
long_score += (col < 25).astype(float)
short_score += (col > 75).astype(float)
matched += 1
elif 'williams_r' in base:
long_score += (col < -80).astype(float)
short_score += (col > -20).astype(float)
matched += 1
elif 'ao' in base or 'tsi' in base or 'roc' in base or 'ppo' in base:
long_score += (col > 0).astype(float)
short_score += (col < 0).astype(float)
matched += 1
elif 'atr' in base or 'volatility_ratio' in base:
# 波动率类:高波动时趋势更强,用均值分界
median = col.rolling(200, min_periods=50).median()
long_score += (col > median).astype(float) * 0.5
short_score += (col > median).astype(float) * 0.5
matched += 1
elif 'high_low_range' in base or 'body_ratio' in base:
median = col.rolling(200, min_periods=50).median()
long_score += (col > median).astype(float) * 0.3
short_score += (col > median).astype(float) * 0.3
matched += 1
elif 'bb_width' in base:
# 布林带宽度收窄后扩张 = 突破信号,结合价格方向
median = col.rolling(200, min_periods=50).median()
expanding = col > col.shift(1) # 宽度在扩张
was_narrow = col.shift(1) < median # 之前是收窄的
breakout = expanding & was_narrow
if 'close' in df.columns:
price_up = df['close'] > df['close'].shift(1)
long_score += (breakout & price_up).astype(float)
short_score += (breakout & ~price_up).astype(float)
else:
long_score += breakout.astype(float) * 0.5
short_score += breakout.astype(float) * 0.5
matched += 1
elif 'close_sma20_ratio' in base or 'close_ema12_ratio' in base:
# 价格在均线上方=多头,下方=空头
long_score += (col > 1.0).astype(float)
short_score += (col < 1.0).astype(float)
matched += 1
elif 'ichimoku' in base:
if 'close' in df.columns:
long_score += (df['close'] > col).astype(float)
short_score += (df['close'] < col).astype(float)
matched += 1
elif 'momentum' in base or 'price_change' in base:
long_score += (col > 0).astype(float)
short_score += (col < 0).astype(float)
matched += 1
logger.info(f"规则匹配: {matched}/{len(self.top_features)} 个特征有对应规则")
# 阈值至少50%的匹配特征同时确认(更严格)
threshold = max(3, matched * 0.5)
logger.info(f"信号阈值: {threshold:.1f} (需要至少这么多指标同时确认)")
signals[long_score >= threshold] = 1
signals[short_score >= threshold] = 2
# 多空同时满足时取更强的
both = (long_score >= threshold) & (short_score >= threshold)
signals[both & (long_score > short_score)] = 1
signals[both & (short_score > long_score)] = 2
signals[both & (long_score == short_score)] = 0
dist = signals.value_counts().to_dict()
logger.info(f"规则信号分布: 观望={dist.get(0, 0)}, 做多={dist.get(1, 0)}, 做空={dist.get(2, 0)}")
return signals
def run(self, period: int = None, start_date: str = None, end_date: str = None) -> dict:
"""
完整运行方案A
:return: 回测结果
"""
if period is None:
period = PRIMARY_PERIOD
logger.info("=" * 60)
logger.info("方案A统计筛选 + 规则组合策略")
logger.info("=" * 60)
# 1. 准备数据(标准化版本,用于特征筛选)
X, y, feature_names, _ = prepare_dataset(period, start_date, end_date)
# 2. 筛选特征
self.select_features(X, y)
# 3. 构建完整特征矩阵(原始值,非标准化,用于规则判断)
from .data_loader import load_kline, load_multi_period
from .feature_engine import build_features
primary_df = load_kline(period, start_date, end_date)
aux_dfs = {}
for aux_p in AUX_PERIODS:
try:
aux_df = load_kline(aux_p, start_date, end_date)
if not aux_df.empty:
aux_dfs[aux_p] = aux_df
except Exception:
pass
df = build_features(primary_df, aux_dfs)
df.dropna(inplace=True)
# 4. 生成信号
signals = self.generate_signals(df)
# 5. 回测
engine = BacktestEngine()
result = engine.run(df['close'], signals)
print_metrics(result['metrics'], "方案A: 统计筛选策略")
return result
def run_stat_strategy(period: int = None, start_date: str = None, end_date: str = None) -> dict:
"""方案A快捷入口"""
strategy = StatStrategy()
return strategy.run(period, start_date, end_date)

2
test.py Normal file
View File

@@ -0,0 +1,2 @@
from strategy.compare import run_full_comparison
results = run_full_comparison(period=15)

View File

@@ -0,0 +1,694 @@
import sys
import time
from pathlib import Path
from tqdm import tqdm
from loguru import logger
from bit_tools import openBrowser
from DrissionPage import ChromiumPage
from DrissionPage import ChromiumOptions
from bitmart.api_contract import APIContract
# 方案B从 strategy 模块获取实盘信号(需先运行 AI 策略训练保存模型,并保持 models/database.db 有最新 15m/5m/1h K 线,例如运行 抓取多周期K线.py
sys.path.insert(0, str(Path(__file__).resolve().parent))
from strategy.ai_strategy import get_live_signal
class BitmartFuturesTransaction:
def __init__(self, bit_id):
self.page: ChromiumPage | None = None
self.api_key = "a0fb7b98464fd9bcce67e7c519d58ec10d0c38a8"
self.secret_key = "4eaeba78e77aeaab1c2027f846a276d164f264a44c2c1bb1c5f3be50c8de1ca5"
self.memo = "合约交易"
self.contract_symbol = "ETHUSDT"
self.contractAPI = APIContract(self.api_key, self.secret_key, self.memo, timeout=(5, 15))
self.start = 0 # 持仓状态: -1 空, 0 无, 1 多
self.pbar = tqdm(total=30, desc="等待K线", ncols=80) # 可选:用于长时间等待时展示进度
self.last_kline_time = None # 上一次出信号的 15 分钟 K 线 id方案B 每根 15m 只出一次信号)
# 反手频率控制
self.reverse_cooldown_seconds = 1.5 * 60 # 反手冷却时间(秒)
self.reverse_min_move_pct = 0.05 # 反手最小价差过滤(百分比)
self.last_reverse_time = None # 上次反手时间
# 开仓频率控制
self.open_cooldown_seconds = 60 # 开仓冷却时间(秒),两次开仓至少间隔此时长
self.last_open_time = None # 上次开仓时间
self.last_open_kline_id = None # 上次开仓所在 K 线 id同一根 K 线只允许开仓一次
self.leverage = "100" # 高杠杆(全仓模式下可开更大仓位)
self.open_type = "cross" # 全仓模式
self.risk_percent = 0 # 未使用;若启用则可为每次开仓占可用余额的百分比
self.take_profit_usd = 5 # 仓位盈利达到此金额(美元)时平仓止盈
self.stop_loss_usd = -3 # 固定止损:亏损达到 3 美元平仓
self.trailing_activation_usd = 2 # 盈利达到此金额后启动移动止损
self.trailing_distance_usd = 1.5 # 从最高盈利回撤此金额则平仓
self.max_unrealized_pnl_seen = None # 持仓期间见过的最大盈利(用于移动止损)
self.open_avg_price = None # 开仓价格
self.current_amount = None # 持仓量
self.bit_id = bit_id
self.default_order_size = 25 # 开仓/反手张数,统一在此修改
# 策略相关变量
self.prev_kline = None # 上一根K线
self.current_kline = None # 当前K线
self.prev_entity = None # 上一根K线实体大小
self.current_open = None # 当前K线开盘价
def get_klines(self):
"""获取最近2根K线当前K线和上一根K线"""
try:
end_time = int(time.time())
# 获取足够多的条目确保有最新的K线
response = self.contractAPI.get_kline(
contract_symbol=self.contract_symbol,
step=5, # 5分钟
start_time=end_time - 3600 * 3, # 取最近3小时
end_time=end_time
)[0]["data"]
# 每根: [timestamp, open, high, low, close, volume]
formatted = []
for k in response:
formatted.append({
'id': int(k["timestamp"]),
'open': float(k["open_price"]),
'high': float(k["high_price"]),
'low': float(k["low_price"]),
'close': float(k["close_price"])
})
formatted.sort(key=lambda x: x['id'])
# 返回最近2根K线倒数第二根上一根和最后一根当前
if len(formatted) >= 2:
return formatted[-2], formatted[-1]
return None, None
except Exception as e:
logger.error(f"获取K线异常: {e}")
self.ding(text="获取K线异常", error=True)
return None, None
def get_current_price(self):
"""获取当前最新价格"""
try:
end_time = int(time.time())
response = self.contractAPI.get_kline(
contract_symbol=self.contract_symbol,
step=1, # 1分钟
start_time=end_time - 3600 * 1, # 取最近1小时
end_time=end_time
)[0]
if response['code'] == 1000:
return float(response['data'][-1]["close_price"])
return None
except Exception as e:
logger.error(f"获取价格异常: {e}")
return None
def get_available_balance(self):
"""获取合约账户可用USDT余额"""
try:
response = self.contractAPI.get_assets_detail()[0]
if response['code'] == 1000:
data = response['data']
if isinstance(data, dict):
return float(data.get('available_balance', 0))
elif isinstance(data, list):
for asset in data:
if asset.get('currency') == 'USDT':
return float(asset.get('available_balance', 0))
return None
except Exception as e:
logger.error(f"余额查询异常: {e}")
return None
def get_position_status(self):
"""获取当前持仓方向"""
try:
response = self.contractAPI.get_position(contract_symbol=self.contract_symbol)[0]
if response['code'] == 1000:
positions = response['data']
if not positions:
self.start = 0
self.open_avg_price = None
self.current_amount = None
self.unrealized_pnl = None
return True
pos = positions[0]
self.start = 1 if pos['position_type'] == 1 else -1
self.open_avg_price = float(pos['open_avg_price'])
self.current_amount = float(pos['current_amount'])
self.position_cross = pos["position_cross"]
# 直接从API获取未实现盈亏Bitmart返回的是 unrealized_value 字段)
self.unrealized_pnl = float(pos.get('unrealized_value', 0))
logger.debug(f"持仓详情: 方向={self.start}, 开仓均价={self.open_avg_price}, "
f"持仓量={self.current_amount}, 未实现盈亏={self.unrealized_pnl:.2f}")
return True
else:
return False
except Exception as e:
logger.error(f"持仓查询异常: {e}")
return False
def get_unrealized_pnl_usd(self):
"""
获取当前持仓未实现盈亏美元直接使用API返回值
"""
if self.start == 0 or self.unrealized_pnl is None:
return None
return self.unrealized_pnl
def set_leverage(self):
"""程序启动时设置全仓 + 高杠杆"""
try:
response = self.contractAPI.post_submit_leverage(
contract_symbol=self.contract_symbol,
leverage=self.leverage,
open_type=self.open_type
)[0]
if response['code'] == 1000:
logger.success(f"全仓模式 + {self.leverage}x 杠杆设置成功")
return True
else:
logger.error(f"杠杆设置失败: {response}")
return False
except Exception as e:
logger.error(f"设置杠杆异常: {e}")
return False
def openBrowser(self):
"""打开 TGE 对应浏览器实例"""
try:
bit_port = openBrowser(id=self.bit_id)
co = ChromiumOptions()
co.set_local_port(port=bit_port)
self.page = ChromiumPage(addr_or_opts=co)
return True
except:
return False
def click_safe(self, xpath, sleep=0.5):
"""安全点击"""
try:
ele = self.page.ele(xpath)
if not ele:
return False
# ele.scroll.to_see(center=True)
# time.sleep(sleep)
ele.click(by_js=True)
return True
except:
return False
def 平仓(self):
"""平仓操作"""
self.click_safe('x://span[normalize-space(text()) ="市价"]')
def 开单(self, marketPriceLongOrder=0, limitPriceShortOrder=0, size=None, price=None):
"""
marketPriceLongOrder 市价做多或者做空1是做多-1是做空
limitPriceShortOrder 限价做多或者做空
"""
if marketPriceLongOrder == -1:
# self.click_safe('x://button[normalize-space(text()) ="市价"]')
# self.page.ele('x://*[@id="size_0"]').input(vals=size, clear=True)
self.click_safe('x://span[normalize-space(text()) ="卖出/做空"]')
elif marketPriceLongOrder == 1:
# self.click_safe('x://button[normalize-space(text()) ="市价"]')
# self.page.ele('x://*[@id="size_0"]').input(vals=size, clear=True)
self.click_safe('x://span[normalize-space(text()) ="买入/做多"]')
if limitPriceShortOrder == -1:
self.click_safe('x://button[normalize-space(text()) ="限价"]')
self.page.ele('x://*[@id="price_0"]').input(vals=price, clear=True)
time.sleep(1)
self.page.ele('x://*[@id="size_0"]').input(1)
self.click_safe('x://span[normalize-space(text()) ="卖出/做空"]')
elif limitPriceShortOrder == 1:
self.click_safe('x://button[normalize-space(text()) ="限价"]')
self.page.ele('x://*[@id="price_0"]').input(vals=price, clear=True)
time.sleep(1)
self.page.ele('x://*[@id="size_0"]').input(1)
self.click_safe('x://span[normalize-space(text()) ="买入/做多"]')
def ding(self, text, error=False):
"""日志通知"""
if error:
logger.error(text)
else:
logger.info(text)
def calculate_entity(self, kline):
"""计算K线实体大小绝对值"""
return abs(kline['close'] - kline['open'])
def calculate_upper_shadow(self, kline):
"""计算上阴线(上影线)涨幅百分比"""
# 上阴线 = (最高价 - max(开盘价, 收盘价)) / max(开盘价, 收盘价)
body_top = max(kline['open'], kline['close'])
if body_top == 0:
return 0
return (kline['high'] - body_top) / body_top * 100
def calculate_lower_shadow(self, kline):
"""计算下阴线(下影线)跌幅百分比"""
# 下阴线 = (min(开盘价, 收盘价) - 最低价) / min(开盘价, 收盘价)
body_bottom = min(kline['open'], kline['close'])
if body_bottom == 0:
return 0
return (body_bottom - kline['low']) / body_bottom * 100
def get_entity_edge(self, kline):
"""获取K线实体边收盘价或开盘价取决于是阳线还是阴线"""
# 阳线(收盘>开盘):实体上边=收盘价,实体下边=开盘价
# 阴线(收盘<开盘):实体上边=开盘价,实体下边=收盘价
return {
'upper': max(kline['open'], kline['close']), # 实体上边
'lower': min(kline['open'], kline['close']) # 实体下边
}
def check_signal(self, current_price, prev_kline, current_kline):
"""
检查交易信号
返回: ('long', trigger_price) / ('short', trigger_price) / None
"""
# 计算上一根K线实体
prev_entity = self.calculate_entity(prev_kline)
# 实体过小不交易(实体 < 0.1
if prev_entity < 0.1:
logger.info(f"上一根K线实体过小: {prev_entity:.4f},跳过信号检测")
return None
# 获取上一根K线的实体上下边
prev_entity_edge = self.get_entity_edge(prev_kline)
prev_entity_upper = prev_entity_edge['upper'] # 实体上边
prev_entity_lower = prev_entity_edge['lower'] # 实体下边
# 优化:以下两种情况以当前这根的开盘价作为计算基准
# 1) 上一根阳线 且 当前开盘价 > 上一根收盘价(跳空高开)
# 2) 上一根阴线 且 当前开盘价 < 上一根收盘价(跳空低开)
prev_is_bullish_for_calc = prev_kline['close'] > prev_kline['open']
prev_is_bearish_for_calc = prev_kline['close'] < prev_kline['open']
current_open_above_prev_close = current_kline['open'] > prev_kline['close']
current_open_below_prev_close = current_kline['open'] < prev_kline['close']
use_current_open_as_base = (prev_is_bullish_for_calc and current_open_above_prev_close) or (prev_is_bearish_for_calc and current_open_below_prev_close)
if use_current_open_as_base:
# 以当前K线开盘价为基准计算跳空时用当前开盘价参与计算
calc_lower = current_kline['open']
calc_upper = current_kline['open'] # 同一基准,上下四分之一对称
long_trigger = calc_lower + prev_entity / 4
short_trigger = calc_upper - prev_entity / 4
long_breakout = calc_upper + prev_entity / 4
short_breakout = calc_lower - prev_entity / 4
else:
# 原有计算方式
long_trigger = prev_entity_lower + prev_entity / 4 # 做多触发价 = 实体下边 + 实体/4下四分之一处
short_trigger = prev_entity_upper - prev_entity / 4 # 做空触发价 = 实体上边 - 实体/4上四分之一处
long_breakout = prev_entity_upper + prev_entity / 4 # 做多突破价 = 实体上边 + 实体/4
short_breakout = prev_entity_lower - prev_entity / 4 # 做空突破价 = 实体下边 - 实体/4
# 上一根阴线 + 当前阳线做多形态不按上一根K线上三分之一做空
prev_is_bearish = prev_kline['close'] < prev_kline['open']
current_is_bullish = current_kline['close'] > current_kline['open']
skip_short_by_upper_third = prev_is_bearish and current_is_bullish
# 上一根阳线 + 当前阴线做空形态不按上一根K线下三分之一做多
prev_is_bullish = prev_kline['close'] > prev_kline['open']
current_is_bearish = current_kline['close'] < current_kline['open']
skip_long_by_lower_third = prev_is_bullish and current_is_bearish
if use_current_open_as_base:
if prev_is_bullish_for_calc and current_open_above_prev_close:
logger.info(f"上一根阳线且当前开盘价({current_kline['open']:.2f})>上一根收盘价({prev_kline['close']:.2f}),以当前开盘价为基准计算")
else:
logger.info(f"上一根阴线且当前开盘价({current_kline['open']:.2f})<上一根收盘价({prev_kline['close']:.2f}),以当前开盘价为基准计算")
logger.info(f"当前价格: {current_price:.2f}, 上一根实体: {prev_entity:.4f}")
logger.info(f"上一根实体上边: {prev_entity_upper:.2f}, 下边: {prev_entity_lower:.2f}")
logger.info(f"做多触发价(下1/4): {long_trigger:.2f}, 做空触发价(上1/4): {short_trigger:.2f}")
logger.info(f"突破做多价(上1/4外): {long_breakout:.2f}, 突破做空价(下1/4外): {short_breakout:.2f}")
if skip_short_by_upper_third:
logger.info("上一根阴线+当前阳线(做多形态),不按上四分之一做空")
if skip_long_by_lower_third:
logger.info("上一根阳线+当前阴线(做空形态),不按下四分之一做多")
# 无持仓时检查开仓信号
if self.start == 0:
if current_price >= long_breakout and not skip_long_by_lower_third:
logger.info(f"触发做多信号!价格 {current_price:.2f} >= 突破价(上1/4外) {long_breakout:.2f}")
return ('long', long_breakout)
elif current_price <= short_breakout and not skip_short_by_upper_third:
logger.info(f"触发做空信号!价格 {current_price:.2f} <= 突破价(下1/4外) {short_breakout:.2f}")
return ('short', short_breakout)
# 持仓时检查反手信号
elif self.start == 1: # 持多仓
# 反手条件1: 价格跌到上一根K线的上三分之一处做空触发价上一根阴线+当前阳线做多时跳过
if current_price <= short_trigger and not skip_short_by_upper_third:
logger.info(f"持多反手做空!价格 {current_price:.2f} <= 触发价(上1/4) {short_trigger:.2f}")
return ('reverse_short', short_trigger)
# 反手条件2: 上一根K线上阴线涨幅>0.01%,当前跌到上一根实体下边
upper_shadow_pct = self.calculate_upper_shadow(prev_kline)
if upper_shadow_pct > 0.01 and current_price <= prev_entity_lower:
logger.info(f"持多反手做空!上阴线涨幅 {upper_shadow_pct:.4f}% > 0.01%"
f"价格 {current_price:.2f} <= 实体下边 {prev_entity_lower:.2f}")
return ('reverse_short', prev_entity_lower)
elif self.start == -1: # 持空仓
# 反手条件1: 价格涨到上一根K线的下三分之一处做多触发价上一根阳线+当前阴线做空时跳过
if current_price >= long_trigger and not skip_long_by_lower_third:
logger.info(f"持空反手做多!价格 {current_price:.2f} >= 触发价(下1/4) {long_trigger:.2f}")
return ('reverse_long', long_trigger)
# 反手条件2: 上一根K线下阴线跌幅>0.01%,当前涨到上一根实体上边
lower_shadow_pct = self.calculate_lower_shadow(prev_kline)
if lower_shadow_pct > 0.01 and current_price >= prev_entity_upper:
logger.info(f"持空反手做多!下阴线跌幅 {lower_shadow_pct:.4f}% > 0.01%"
f"价格 {current_price:.2f} >= 实体上边 {prev_entity_upper:.2f}")
return ('reverse_long', prev_entity_upper)
return None
def can_open(self, current_kline_id):
"""开仓前过滤:同一根 K 线只开一次 + 开仓冷却时间。仅用于 long/short 新开仓。"""
now = time.time()
if self.last_open_kline_id is not None and self.last_open_kline_id == current_kline_id:
logger.info(f"开仓频率控制:本 K 线({current_kline_id})已开过仓,跳过")
return False
if self.last_open_time is not None and now - self.last_open_time < self.open_cooldown_seconds:
remain = self.open_cooldown_seconds - (now - self.last_open_time)
logger.info(f"开仓冷却中,剩余 {remain:.0f}")
return False
return True
def can_reverse(self, current_price, trigger_price):
"""反手前过滤:冷却时间 + 最小价差"""
now = time.time()
if self.last_reverse_time and now - self.last_reverse_time < self.reverse_cooldown_seconds:
remain = self.reverse_cooldown_seconds - (now - self.last_reverse_time)
logger.info(f"反手冷却中,剩余 {remain:.0f}")
return False
if trigger_price and trigger_price > 0:
move_pct = abs(current_price - trigger_price) / trigger_price * 100
if move_pct < self.reverse_min_move_pct:
logger.info(f"反手价差不足: {move_pct:.4f}% < {self.reverse_min_move_pct}%")
return False
return True
def verify_no_position(self, max_retries=5, retry_interval=3):
"""
验证当前无持仓
返回: True 表示无持仓可以开仓False 表示有持仓不能开仓
"""
for i in range(max_retries):
if self.get_position_status():
if self.start == 0:
logger.info(f"确认无持仓,可以开仓")
return True
else:
logger.warning(
f"仍有持仓 (方向: {self.start}),等待 {retry_interval} 秒后重试 ({i + 1}/{max_retries})")
time.sleep(retry_interval)
else:
logger.warning(f"查询持仓状态失败,等待 {retry_interval} 秒后重试 ({i + 1}/{max_retries})")
time.sleep(retry_interval)
logger.error(f"经过 {max_retries} 次重试仍有持仓或查询失败,放弃开仓")
return False
def verify_position_direction(self, expected_direction):
"""
验证当前持仓方向是否与预期一致
expected_direction: 1 多仓, -1 空仓
返回: True 表示持仓方向正确False 表示不正确
"""
if self.get_position_status():
if self.start == expected_direction:
logger.info(f"持仓方向验证成功: {self.start}")
return True
else:
logger.warning(f"持仓方向不符: 期望 {expected_direction}, 实际 {self.start}")
return False
else:
logger.error("查询持仓状态失败")
return False
def execute_trade(self, signal, size=None):
"""执行交易。size 不传或为 None 时使用 default_order_size。"""
signal_type, trigger_price = signal
size = self.default_order_size if size is None else size
if signal_type == 'long':
# 开多前先确认无持仓
logger.info(f"准备开多,触发价: {trigger_price:.2f}")
if not self.get_position_status():
logger.error("开仓前查询持仓状态失败,放弃开仓")
return False
if self.start != 0:
logger.warning(f"开多前发现已有持仓 (方向: {self.start}),放弃开仓避免双向持仓")
return False
logger.info(f"确认无持仓,执行开多")
self.开单(marketPriceLongOrder=1, size=size)
time.sleep(3) # 等待订单执行
# 验证开仓是否成功
if self.verify_position_direction(1):
self.max_unrealized_pnl_seen = None # 新仓位重置移动止损记录
self.last_open_time = time.time()
self.last_open_kline_id = getattr(self, "_current_kline_id_for_open", None)
logger.success("开多成功")
return True
else:
logger.error("开多后持仓验证失败")
return False
elif signal_type == 'short':
# 开空前先确认无持仓
logger.info(f"准备开空,触发价: {trigger_price:.2f}")
if not self.get_position_status():
logger.error("开仓前查询持仓状态失败,放弃开仓")
return False
if self.start != 0:
logger.warning(f"开空前发现已有持仓 (方向: {self.start}),放弃开仓避免双向持仓")
return False
logger.info(f"确认无持仓,执行开空")
self.开单(marketPriceLongOrder=-1, size=size)
time.sleep(3) # 等待订单执行
# 验证开仓是否成功
if self.verify_position_direction(-1):
self.max_unrealized_pnl_seen = None # 新仓位重置移动止损记录
self.last_open_time = time.time()
self.last_open_kline_id = getattr(self, "_current_kline_id_for_open", None)
logger.success("开空成功")
return True
else:
logger.error("开空后持仓验证失败")
return False
elif signal_type == 'reverse_long':
# 平空 + 开多(反手做多):先平仓,确认无仓后再开多,避免双向持仓
logger.info(f"执行反手做多,触发价: {trigger_price:.2f}")
self.平仓()
time.sleep(1) # 给交易所处理平仓的时间
# 轮询确认已无持仓再开多(最多等约 10 秒)
for _ in range(10):
if self.get_position_status() and self.start == 0:
break
time.sleep(1)
if self.start != 0:
logger.warning("反手做多:平仓后仍有持仓,放弃本次开多")
return False
logger.info("已确认无持仓,执行开多")
self.开单(marketPriceLongOrder=1, size=size)
time.sleep(3)
if self.verify_position_direction(1):
self.max_unrealized_pnl_seen = None
logger.success("反手做多成功")
self.last_reverse_time = time.time()
time.sleep(20)
return True
else:
logger.error("反手做多后持仓验证失败")
return False
elif signal_type == 'reverse_short':
# 平多 + 开空(反手做空):先平仓,确认无仓后再开空
logger.info(f"执行反手做空,触发价: {trigger_price:.2f}")
self.平仓()
time.sleep(1)
for _ in range(10):
if self.get_position_status() and self.start == 0:
break
time.sleep(1)
if self.start != 0:
logger.warning("反手做空:平仓后仍有持仓,放弃本次开空")
return False
logger.info("已确认无持仓,执行开空")
self.开单(marketPriceLongOrder=-1, size=size)
time.sleep(3)
if self.verify_position_direction(-1):
self.max_unrealized_pnl_seen = None
logger.success("反手做空成功")
self.last_reverse_time = time.time()
time.sleep(20)
return True
else:
logger.error("反手做空后持仓验证失败")
return False
return False
def action(self):
"""主循环"""
logger.info("开始运行方案BAI 策略)交易...")
# 启动时设置全仓高杠杆
if not self.set_leverage():
logger.error("杠杆设置失败,程序继续运行但可能下单失败")
return
page_start = True
while True:
if page_start:
# 打开浏览器
for i in range(5):
if self.openBrowser():
logger.info("浏览器打开成功")
break
else:
self.ding("打开浏览器失败!", error=True)
return
# 进入交易页面
self.page.get("https://derivatives.bitmart.com/zh-CN/futures/ETHUSDT")
self.click_safe('x://button[normalize-space(text()) ="市价"]')
self.page.ele('x://*[@id="size_0"]').input(vals=25, clear=True)
page_start = False
try:
# 1. 获取当前价格
current_price = self.get_current_price()
if not current_price:
logger.warning("获取价格失败,等待重试...")
time.sleep(2)
continue
# 2. 每次循环都通过SDK获取真实持仓状态避免状态不同步导致双向持仓
if not self.get_position_status():
logger.warning("获取持仓状态失败,等待重试...")
time.sleep(2)
continue
logger.debug(f"当前持仓状态: {self.start} (0=无, 1=多, -1=空)")
# 3. 止损/止盈/移动止损
if self.start != 0:
pnl_usd = self.get_unrealized_pnl_usd()
if pnl_usd is not None:
# 固定止损:亏损达到 3 美元平仓
if pnl_usd <= self.stop_loss_usd:
logger.info(f"仓位亏损 {pnl_usd:.2f} 美元 <= 止损 {self.stop_loss_usd} 美元,执行止损平仓")
self.平仓()
self.max_unrealized_pnl_seen = None
time.sleep(3)
continue
# 更新持仓期间最大盈利(用于移动止损)
if self.max_unrealized_pnl_seen is None:
self.max_unrealized_pnl_seen = pnl_usd
else:
self.max_unrealized_pnl_seen = max(self.max_unrealized_pnl_seen, pnl_usd)
# 移动止损:盈利曾达到 activation 后,从最高盈利回撤 trailing_distance 则平仓
if self.max_unrealized_pnl_seen >= self.trailing_activation_usd:
if pnl_usd < self.max_unrealized_pnl_seen - self.trailing_distance_usd:
logger.info(f"移动止损:当前盈利 {pnl_usd:.2f} 从最高 {self.max_unrealized_pnl_seen:.2f} 回撤 >= {self.trailing_distance_usd} 美元,平仓")
self.平仓()
self.max_unrealized_pnl_seen = None
time.sleep(3)
continue
# 止盈:盈利达到 take_profit_usd 平仓
if pnl_usd >= self.take_profit_usd:
logger.info(f"仓位盈利 {pnl_usd:.2f} 美元 >= {self.take_profit_usd} 美元,执行止盈平仓")
self.平仓()
self.max_unrealized_pnl_seen = None
time.sleep(3)
continue
# 4. 方案B仅在新的 15 分钟 K 线时取一次信号0=观望, 1=做多, 2=做空)
current_15m_id = int(time.time() // 900) * 900 # 15 分钟 bar 起始时间戳
signal = None
if current_15m_id != self.last_kline_time:
self.last_kline_time = current_15m_id
logger.info(f"进入新 15m K 线: {current_15m_id}")
raw = get_live_signal(period=15)
if raw == 1:
if self.start == 0:
signal = ('long', current_price)
elif self.start == -1:
signal = ('reverse_long', current_price)
elif raw == 2:
if self.start == 0:
signal = ('short', current_price)
elif self.start == 1:
signal = ('reverse_short', current_price)
# 5. 反手过滤:冷却时间 + 最小价差
if signal and signal[0].startswith('reverse_'):
if not self.can_reverse(current_price, signal[1]):
signal = None
# 5.5 开仓频率过滤:同一根 15m K 线只开一次 + 开仓冷却
if signal and signal[0] in ('long', 'short'):
if not self.can_open(current_15m_id):
signal = None
else:
self._current_kline_id_for_open = current_15m_id # 供 execute_trade 成功后记录
# 6. 有信号则执行交易
if signal:
trade_success = self.execute_trade(signal)
if trade_success:
logger.success(f"交易执行完成: {signal[0]}, 当前持仓状态: {self.start}")
page_start = True
else:
logger.warning(f"交易执行失败或被阻止: {signal[0]}")
# 短暂等待后继续循环同一根K线遇到信号就操作
time.sleep(0.1)
if page_start:
self.page.close()
time.sleep(5)
except KeyboardInterrupt:
logger.info("用户中断,程序退出")
break
except Exception as e:
logger.error(f"主循环异常: {e}")
time.sleep(5)
if __name__ == '__main__':
BitmartFuturesTransaction(bit_id="f2320f57e24c45529a009e1541e25961").action()