diff --git a/reports/strategy_comparison.png b/reports/strategy_comparison.png new file mode 100644 index 0000000..27a0377 Binary files /dev/null and b/reports/strategy_comparison.png differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..6d71bb4 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +pandas>=2.0 +numpy>=1.24 +ta>=0.11.0 +scikit-learn>=1.3 +lightgbm>=4.0 +xgboost>=2.0 +matplotlib>=3.7 +peewee>=3.16 +loguru>=0.7 diff --git a/strategy/__init__.py b/strategy/__init__.py new file mode 100644 index 0000000..6771cf2 --- /dev/null +++ b/strategy/__init__.py @@ -0,0 +1 @@ +"""52指标AI交易策略系统""" diff --git a/strategy/ai_strategy.py b/strategy/ai_strategy.py new file mode 100644 index 0000000..20f7a0d --- /dev/null +++ b/strategy/ai_strategy.py @@ -0,0 +1,214 @@ +""" +方案B:AI模型训练 + 信号生成 +使用 LightGBM / XGBoost,Walk-Forward 滚动训练 +""" +import json +import joblib +import pandas as pd +import numpy as np +from pathlib import Path +from loguru import logger + +from .config import MODEL_CONFIG as MC, PRIMARY_PERIOD, PROJECT_ROOT +from .feature_engine import prepare_dataset, get_latest_feature_row +from .backtest import BacktestEngine, print_metrics + +SCHEME_B_MODEL_DIR = PROJECT_ROOT / 'models' +SCHEME_B_MODEL_FILE = SCHEME_B_MODEL_DIR / 'scheme_b_last_model.joblib' +SCHEME_B_SCALER_FILE = SCHEME_B_MODEL_DIR / 'scheme_b_scaler.joblib' +SCHEME_B_FEATURES_FILE = SCHEME_B_MODEL_DIR / 'scheme_b_features.json' + + +class AIStrategy: + """AI模型策略 — LightGBM / XGBoost Walk-Forward""" + + def __init__(self, model_type: str = 'lightgbm'): + """ + :param model_type: 'lightgbm' 或 'xgboost' + """ + self.model_type = model_type + self.models = [] # 存储每个窗口训练的模型 + self.feature_importance = None + + def _create_model(self): + """创建模型实例""" + if self.model_type == 'lightgbm': + import lightgbm as lgb + params = MC['lightgbm'].copy() + return lgb.LGBMClassifier(**params) + elif self.model_type == 'xgboost': + import xgboost as xgb + params = MC['xgboost'].copy() + return xgb.XGBClassifier(**params) + else: + raise ValueError(f"不支持的模型类型: {self.model_type}") + + def walk_forward_train(self, X: pd.DataFrame, y: pd.Series, + confidence_threshold: float = 0.45) -> pd.Series: + """ + Walk-Forward 滚动训练与预测 + :param confidence_threshold: 概率阈值,低于此值的预测设为0(观望) + :return: 全部测试窗口拼接的预测信号 + """ + train_size = MC['walk_forward_train_size'] + test_size = MC['walk_forward_test_size'] + step = MC['walk_forward_step'] + + n = len(X) + all_preds = pd.Series(dtype=float) + window_count = 0 + + logger.info(f"Walk-Forward: 数据量={n}, 训练窗口={train_size}, " + f"测试窗口={test_size}, 步长={step}, 置信阈值={confidence_threshold}") + + start = 0 + while start + train_size + test_size <= n: + train_end = start + train_size + test_end = min(train_end + test_size, n) + + X_train = X.iloc[start:train_end] + y_train = y.iloc[start:train_end] + X_test = X.iloc[train_end:test_end] + y_test = y.iloc[train_end:test_end] + + # 训练 + model = self._create_model() + model.fit(X_train, y_train) + self.models.append(model) + + # 预测概率 + 置信度过滤 + proba = model.predict_proba(X_test) + max_proba = proba.max(axis=1) + raw_preds = model.predict(X_test) + + # 置信度不够的设为观望 + filtered_preds = raw_preds.copy() + filtered_preds[max_proba < confidence_threshold] = 0 + + preds = pd.Series(filtered_preds, index=X_test.index) + all_preds = pd.concat([all_preds, preds]) + + # 准确率(用原始预测算) + acc = (raw_preds == y_test).mean() + n_filtered = (max_proba < confidence_threshold).sum() + window_count += 1 + logger.info(f" 窗口 {window_count}: 训练[{start}:{train_end}] " + f"测试[{train_end}:{test_end}] 准确率={acc:.2%} " + f"过滤={n_filtered}/{len(X_test)}") + + start += step + + # 特征重要性(取最后一个模型) + if self.models: + last_model = self.models[-1] + if hasattr(last_model, 'feature_importances_'): + self.feature_importance = pd.Series( + last_model.feature_importances_, index=X.columns + ).sort_values(ascending=False) + + logger.info(f"Walk-Forward 完成: {window_count} 个窗口, " + f"共 {len(all_preds)} 条预测") + return all_preds + + def get_top_features(self, n: int = 20) -> pd.Series: + """获取Top N重要特征""" + if self.feature_importance is not None: + return self.feature_importance.head(n) + return pd.Series(dtype=float) + + def run(self, period: int = None, start_date: str = None, end_date: str = None) -> dict: + """ + 完整运行方案B + 若指定了 start_date/end_date,会向前加载 warm_up_months 月数据用于训练,使回测区间首月即有预测。 + :return: 回测结果 + """ + if period is None: + period = PRIMARY_PERIOD + + logger.info("=" * 60) + logger.info(f"方案B:AI模型策略 ({self.model_type})") + logger.info("=" * 60) + + from .data_loader import load_kline + + # 1. 准备数据:若指定了回测区间,则向前加载预热数据,使区间内从首月就有预测 + load_start, load_end = start_date, end_date + if start_date and end_date: + warm_months = MC.get('warm_up_months', 12) + load_start_ts = pd.Timestamp(start_date) - pd.DateOffset(months=warm_months) + load_start = load_start_ts.strftime('%Y-%m-%d') + logger.info(f"回测区间 [{start_date} ~ {end_date}],向前加载 {warm_months} 月至 {load_start} 用于训练") + + X, y, feature_names, scaler = prepare_dataset(period, load_start, load_end) + + # 2. Walk-Forward 训练 + predictions = self.walk_forward_train(X, y) + + # 3. 回测仅用用户指定区间;将预测对齐到该区间的每根K线 + df = load_kline(period, start_date, end_date) + if df.empty: + logger.warning("回测区间内无K线数据") + return BacktestEngine()._empty_result() + + # 对齐信号:回测区间内有的时间戳用预测,缺失的填 0(观望) + signals = predictions.reindex(df.index, fill_value=0).astype(int) + prices = df['close'] + + # 4. 回测 + engine = BacktestEngine() + result = engine.run(prices, signals) + + print_metrics(result['metrics'], f"方案B: {self.model_type} AI策略") + + # 5. 保存最后一窗模型、scaler、特征列(供实盘 get_live_signal 使用) + if self.models and scaler is not None: + SCHEME_B_MODEL_DIR.mkdir(parents=True, exist_ok=True) + joblib.dump(self.models[-1], SCHEME_B_MODEL_FILE) + joblib.dump(scaler, SCHEME_B_SCALER_FILE) + with open(SCHEME_B_FEATURES_FILE, 'w', encoding='utf-8') as f: + json.dump(feature_names, f, ensure_ascii=False) + logger.info(f"已保存方案B模型: {SCHEME_B_MODEL_FILE}, scaler, {len(feature_names)} 个特征") + + # 6. 输出特征重要性 + top_feat = self.get_top_features(15) + if not top_feat.empty: + logger.info("\nTop 15 重要特征:") + for i, (feat, imp) in enumerate(top_feat.items()): + logger.info(f" {i+1}. {feat}: {imp:.4f}") + + result['feature_importance'] = self.feature_importance + return result + + +def run_ai_strategy(model_type: str = 'lightgbm', period: int = None, + start_date: str = None, end_date: str = None) -> dict: + """方案B快捷入口""" + strategy = AIStrategy(model_type=model_type) + return strategy.run(period, start_date, end_date) + + +def get_live_signal(period: int = None, model_type: str = 'lightgbm', + start_date: str = None, end_date: str = None) -> int: + """ + 使用已保存的方案B模型对当前最新K线生成信号(供实盘/模拟盘调用)。 + 需先运行过 run_ai_strategy 或 AIStrategy().run() 以生成 models/scheme_b_*.joblib 与 features.json。 + :param period: K线主周期,默认 15 + :param model_type: 未使用(模型已固定为磁盘上的 scheme_b_last_model.joblib) + :param start_date, end_date: 可选,限制 load_kline 范围 + :return: 0=观望, 1=做多, 2=做空 + """ + if period is None: + period = PRIMARY_PERIOD + if not SCHEME_B_MODEL_FILE.exists() or not SCHEME_B_SCALER_FILE.exists() or not SCHEME_B_FEATURES_FILE.exists(): + logger.warning("方案B模型未找到,请先运行 AI 策略训练保存模型") + return 0 + model = joblib.load(SCHEME_B_MODEL_FILE) + scaler = joblib.load(SCHEME_B_SCALER_FILE) + with open(SCHEME_B_FEATURES_FILE, 'r', encoding='utf-8') as f: + feature_cols = json.load(f) + X_last = get_latest_feature_row(period, feature_cols, start_date, end_date) + if X_last.empty: + return 0 + X_scaled = scaler.transform(X_last) + pred = model.predict(X_scaled) + return int(pred[0]) diff --git a/strategy/backtest.py b/strategy/backtest.py new file mode 100644 index 0000000..64f82c7 --- /dev/null +++ b/strategy/backtest.py @@ -0,0 +1,298 @@ +""" +回测引擎 — 多空双向、手续费、滑点、绩效统计 +每笔固定名义 100U、100 倍杠杆;同一时间仅一个仓位;最大回撤 300U 硬约束;手续费 90% 返佣 +""" +import pandas as pd +import numpy as np +from loguru import logger +from .config import TRADE_CONFIG as TC, SIGNAL_MAP + + +class BacktestEngine: + """回测引擎:固定每笔 100U 名义、100x,单仓位,最大回撤 300U 内,实付手续费(90% 返佣)""" + + def __init__(self, commission: float = None, slippage: float = None, + initial_capital: float = None, position_size: float = None, + position_notional_usd: float = None, max_drawdown_limit: float = None, + commission_rebate: float = None): + raw_commission = commission if commission is not None else TC['commission'] + rebate = commission_rebate if commission_rebate is not None else TC.get('commission_rebate', 0) + self.commission = raw_commission * (1 - rebate) # 实付手续费(90% 返佣) + self.slippage = slippage or TC['slippage'] + self.initial_capital = initial_capital or TC['initial_capital'] + self.position_size = position_size or TC['position_size'] + self.position_notional_usd = position_notional_usd if position_notional_usd is not None else TC.get('position_notional_usd', self.initial_capital * self.position_size) + self.max_drawdown_limit = max_drawdown_limit if max_drawdown_limit is not None else TC.get('max_drawdown_limit', float('inf')) + + def run(self, prices: pd.Series, signals: pd.Series) -> dict: + """ + 执行回测 + :param prices: 收盘价 Series + :param signals: 信号 Series,值: 0=观望, 1=做多, 2=做空 + :return: 回测结果字典 + """ + df = pd.DataFrame({'price': prices, 'signal': signals}).dropna() + if df.empty: + logger.warning("回测数据为空") + return self._empty_result() + + n = len(df) + capital = self.initial_capital + position = 0 # 持仓数量(正=多头,负=空头) + entry_price = 0.0 + direction = 0 # 当前方向: 0=空仓, 1=多, 2=空 + trades = [] + equity_curve = np.zeros(n) + peak_equity = self.initial_capital # 用于 300U 最大回撤约束 + + for i in range(n): + price = df.iloc[i]['price'] + signal = int(df.iloc[i]['signal']) + + # 计算当前权益 + if position > 0: + equity_curve[i] = capital + position * (price - entry_price) + elif position < 0: + equity_curve[i] = capital + position * (price - entry_price) + else: + equity_curve[i] = capital + current_equity = equity_curve[i] + peak_equity = max(peak_equity, current_equity) + + # 最大回撤硬约束:从峰值回落超过 300U 则不再开新仓(只允许平仓) + drawdown_usd = peak_equity - current_equity + can_open = drawdown_usd < self.max_drawdown_limit + + # 强制止损:持仓浮亏超过初始资金的20%时强制平仓 + if position != 0: + unrealized = position * (price - entry_price) + if unrealized < -self.initial_capital * 0.20: + capital, trade = self._close_position( + capital, position, entry_price, price, df.index[i]) + trades.append(trade) + position = 0 + direction = 0 + continue + + # 资金不足时不开新仓;同一时间仅一个仓位(已有持仓则只能先平再开) + min_capital = self.initial_capital * 0.05 + can_trade = capital > min_capital and can_open + + # 跳过:信号与当前持仓方向相同 + if signal == direction: + continue + + # 每笔固定名义 100U:qty = position_notional_usd / price + notional = self.position_notional_usd + qty = notional / price if price > 0 else 0 + + # 需要换方向或平仓 + if signal == 1 and direction != 1: + # 先平仓 + if position != 0: + capital, trade = self._close_position( + capital, position, entry_price, price, df.index[i]) + trades.append(trade) + position = 0 + + if not can_trade or qty <= 0: + direction = 0 + continue + + # 开多:固定 100U 名义,实付手续费 + cost = notional * (self.commission + self.slippage) + capital -= cost + position = qty + entry_price = price + direction = 1 + + elif signal == 2 and direction != 2: + # 先平仓 + if position != 0: + capital, trade = self._close_position( + capital, position, entry_price, price, df.index[i]) + trades.append(trade) + position = 0 + + if not can_trade or qty <= 0: + direction = 0 + continue + + # 开空:固定 100U 名义 + cost = notional * (self.commission + self.slippage) + capital -= cost + position = -qty + entry_price = price + direction = 2 + + elif signal == 0 and position != 0: + # 平仓 + capital, trade = self._close_position( + capital, position, entry_price, price, df.index[i]) + trades.append(trade) + position = 0 + direction = 0 + + # 最终平仓 + if position != 0: + price = df.iloc[-1]['price'] + capital, trade = self._close_position( + capital, position, entry_price, price, df.index[-1]) + trades.append(trade) + + equity = pd.Series(equity_curve, index=df.index) + trades_df = pd.DataFrame(trades) if trades else pd.DataFrame() + metrics = self._calc_metrics(equity, trades_df) + monthly_pnl = self._monthly_pnl(equity) + + return { + 'equity_curve': equity, + 'trades': trades_df, + 'metrics': metrics, + 'final_capital': capital, + 'monthly_pnl': monthly_pnl, + } + + def _close_position(self, capital, position, entry_price, exit_price, time): + """平仓并返回更新后的capital和交易记录""" + if position > 0: + pnl = position * (exit_price - entry_price) + cost = position * exit_price * (self.commission + self.slippage) + trade_type = '平多' + else: + pnl = -position * (entry_price - exit_price) + cost = abs(position) * exit_price * (self.commission + self.slippage) + trade_type = '平空' + + capital += pnl - cost + trade = { + 'type': trade_type, + 'entry': entry_price, + 'exit': exit_price, + 'pnl': pnl - cost, + 'return_pct': (exit_price / entry_price - 1) * (1 if position > 0 else -1), + 'time': time, + } + return capital, trade + + def _calc_metrics(self, equity: pd.Series, trades: pd.DataFrame) -> dict: + """计算绩效指标""" + if equity.empty: + return self._empty_metrics() + + total_return = (equity.iloc[-1] / self.initial_capital) - 1 + + n_bars = len(equity) + if n_bars > 1: + # 按日聚合收益率计算夏普 + daily_equity = equity.resample('D').last().dropna() + if len(daily_equity) > 1: + daily_returns = daily_equity.pct_change().dropna() + n_days = len(daily_returns) + annualized_return = (1 + total_return) ** (365 / max(n_days, 1)) - 1 if total_return > -1 else -1.0 + sharpe = (daily_returns.mean() / daily_returns.std() * np.sqrt(365) + if daily_returns.std() > 0 else 0) + else: + annualized_return = 0 + sharpe = 0 + else: + annualized_return = 0 + sharpe = 0 + + # 最大回撤(比例与绝对 USDT) + cummax = equity.cummax() + drawdown = (equity - cummax) / cummax.replace(0, np.nan) + max_drawdown = drawdown.min() if not drawdown.empty else 0 + max_drawdown_usd = (cummax - equity).max() if not equity.empty else 0 + + # 按月收益(自然月):用于目标「每月盈利 ≥ 1000U」 + monthly_pnl_usd = None + months_above_1000 = 0 + if not equity.empty and hasattr(equity.index, 'to_period'): + try: + monthly = equity.resample('ME').last().dropna() + if len(monthly) > 0: + monthly_pnl = monthly.diff() + monthly_pnl.iloc[0] = monthly.iloc[0] - self.initial_capital + monthly_pnl_usd = monthly_pnl + months_above_1000 = (monthly_pnl >= 1000).sum() + except Exception: + pass + + # 交易统计 + n_trades = len(trades) + if n_trades > 0: + wins = trades[trades['pnl'] > 0] + losses = trades[trades['pnl'] <= 0] + win_rate = len(wins) / n_trades + avg_win = wins['pnl'].mean() if len(wins) > 0 else 0 + avg_loss = abs(losses['pnl'].mean()) if len(losses) > 0 else 0 + profit_factor = (wins['pnl'].sum() / abs(losses['pnl'].sum()) + if len(losses) > 0 and losses['pnl'].sum() != 0 else float('inf')) + else: + win_rate = 0 + avg_win = 0 + avg_loss = 0 + profit_factor = 0 + + out = { + '总收益率': f'{total_return:.2%}', + '年化收益率': f'{annualized_return:.2%}', + '最大回撤': f'{max_drawdown:.2%}', + '最大回撤(U)': f'{max_drawdown_usd:.2f}', + '夏普比率': f'{sharpe:.2f}', + '总交易次数': n_trades, + '胜率': f'{win_rate:.2%}', + '平均盈利': f'{avg_win:.2f}', + '平均亏损': f'{avg_loss:.2f}', + '盈亏比': f'{profit_factor:.2f}', + '最终资金': f'{equity.iloc[-1]:.2f}', + } + out['月盈利≥1000U的月数'] = months_above_1000 if monthly_pnl_usd is not None else 0 + if monthly_pnl_usd is not None and len(monthly_pnl_usd) > 0: + out['月均盈利(U)'] = f'{monthly_pnl_usd.mean():.2f}' + else: + out['月均盈利(U)'] = '0.00' + return out + + def _monthly_pnl(self, equity: pd.Series): + """按自然月汇总收益(USDT),首月为当月权益 - 初始资金""" + if equity.empty or not hasattr(equity.index, 'to_period'): + return None + try: + monthly = equity.resample('ME').last().dropna() + if len(monthly) == 0: + return None + pnl = monthly.diff() + pnl.iloc[0] = monthly.iloc[0] - self.initial_capital + return pnl + except Exception: + return None + + def _empty_result(self): + return { + 'equity_curve': pd.Series(dtype=float), + 'trades': pd.DataFrame(), + 'metrics': self._empty_metrics(), + 'final_capital': self.initial_capital, + 'monthly_pnl': None, + } + + def _empty_metrics(self): + return { + '总收益率': '0.00%', '年化收益率': '0.00%', '最大回撤': '0.00%', + '最大回撤(U)': '0.00', '夏普比率': '0.00', '总交易次数': 0, '胜率': '0.00%', + '平均盈利': '0.00', '平均亏损': '0.00', '盈亏比': '0.00', + '最终资金': f'{self.initial_capital:.2f}', '月盈利≥1000U的月数': 0, + '月均盈利(U)': '0.00', + } + + +def print_metrics(metrics: dict, title: str = "回测结果"): + """打印绩效指标""" + logger.info(f"\n{'='*50}") + logger.info(f" {title}") + logger.info(f"{'='*50}") + for k, v in metrics.items(): + logger.info(f" {k:>12}: {v}") + logger.info(f"{'='*50}") diff --git a/strategy/compare.py b/strategy/compare.py new file mode 100644 index 0000000..bdc5795 --- /dev/null +++ b/strategy/compare.py @@ -0,0 +1,222 @@ +""" +方案对比评估 — 方案A(统计筛选) vs 方案B(AI模型) 并排对比 + 报告输出 +""" +import pandas as pd +import numpy as np +import matplotlib +matplotlib.use('Agg') # 非交互式后端 +import matplotlib.pyplot as plt +from matplotlib import font_manager +from pathlib import Path +from loguru import logger + +# 设置中文字体 +_zh_font = None +for fname in ['PingFang SC', 'Heiti SC', 'STHeiti', 'SimHei', 'Microsoft YaHei', 'WenQuanYi Micro Hei']: + try: + _zh_font = font_manager.FontProperties(family=fname) + # 验证字体存在 + font_manager.findfont(_zh_font, fallback_to_default=False) + plt.rcParams['font.sans-serif'] = [fname, 'DejaVu Sans'] + plt.rcParams['axes.unicode_minus'] = False + break + except Exception: + _zh_font = None + continue + +from .config import PRIMARY_PERIOD, PROJECT_ROOT +from .stat_strategy import StatStrategy +from .ai_strategy import AIStrategy +from .backtest import print_metrics + + +REPORT_DIR = PROJECT_ROOT / 'reports' + + +def compare_strategies(period: int = None, start_date: str = None, end_date: str = None, + save_plot: bool = True) -> dict: + """ + 运行两种方案并对比 + :return: {'stat': result_a, 'lgb': result_b, 'xgb': result_c, 'comparison': DataFrame} + """ + if period is None: + period = PRIMARY_PERIOD + + logger.info("=" * 70) + logger.info(" 开始策略对比评估") + logger.info("=" * 70) + + results = {} + + # 方案A:统计筛选 + logger.info("\n>>> 运行方案A: 统计筛选策略") + stat = StatStrategy() + results['stat'] = stat.run(period, start_date, end_date) + + # 方案B-1:LightGBM + logger.info("\n>>> 运行方案B-1: LightGBM AI策略") + lgb_strategy = AIStrategy(model_type='lightgbm') + results['lgb'] = lgb_strategy.run(period, start_date, end_date) + + # 方案B-2:XGBoost + logger.info("\n>>> 运行方案B-2: XGBoost AI策略") + xgb_strategy = AIStrategy(model_type='xgboost') + results['xgb'] = xgb_strategy.run(period, start_date, end_date) + + # 对比表格 + comparison = _build_comparison_table(results) + results['comparison'] = comparison + + # 打印对比 + logger.info("\n" + "=" * 70) + logger.info(" 策略对比总结") + logger.info("=" * 70) + logger.info(f"\n{comparison.to_string()}") + + # 每月盈利(U) + _log_monthly_pnl(results) + + # 保存图表 + if save_plot: + _save_equity_plot(results) + + return results + + +def _build_comparison_table(results: dict) -> pd.DataFrame: + """构建对比表格""" + rows = {} + name_map = { + 'stat': '方案A: 统计筛选', + 'lgb': '方案B-1: LightGBM', + 'xgb': '方案B-2: XGBoost', + } + + for key, name in name_map.items(): + if key in results and 'metrics' in results[key]: + rows[name] = results[key]['metrics'] + + if not rows: + return pd.DataFrame() + + df = pd.DataFrame(rows).T + return df + + +def _log_monthly_pnl(results: dict): + """打印各策略每月盈利(USDT)""" + name_map = { + 'stat': '方案A', + 'lgb': 'LightGBM', + 'xgb': 'XGBoost', + } + cols = [] + for key, name in name_map.items(): + if key not in results or results[key].get('monthly_pnl') is None: + continue + s = results[key]['monthly_pnl'] + if s is None or s.empty: + continue + s = s.astype(float).round(2) + s.name = name + cols.append(s) + if not cols: + return + monthly_df = pd.concat(cols, axis=1) + monthly_df = monthly_df.fillna(0) + logger.info("\n" + "-" * 70) + logger.info(" 每月盈利 (USDT)") + logger.info("-" * 70) + logger.info(f"\n{monthly_df.to_string()}") + logger.info("-" * 70) + + +def _save_equity_plot(results: dict): + """保存权益曲线对比图""" + REPORT_DIR.mkdir(parents=True, exist_ok=True) + + fig, axes = plt.subplots(2, 1, figsize=(14, 10)) + + # 上图:权益曲线 + ax1 = axes[0] + name_map = { + 'stat': '方案A: 统计筛选', + 'lgb': '方案B-1: LightGBM', + 'xgb': '方案B-2: XGBoost', + } + colors = {'stat': '#2196F3', 'lgb': '#4CAF50', 'xgb': '#FF9800'} + + for key, name in name_map.items(): + if key in results and 'equity_curve' in results[key]: + eq = results[key]['equity_curve'] + if not eq.empty: + ax1.plot(eq.index, eq.values, label=name, color=colors.get(key, 'gray'), linewidth=1) + + ax1.set_title('权益曲线对比', fontsize=14) + ax1.set_ylabel('资金 (USDT)') + ax1.legend() + ax1.grid(True, alpha=0.3) + + # 下图:回撤曲线 + ax2 = axes[1] + for key, name in name_map.items(): + if key in results and 'equity_curve' in results[key]: + eq = results[key]['equity_curve'] + if not eq.empty: + cummax = eq.cummax() + drawdown = (eq - cummax) / cummax * 100 + ax2.fill_between(drawdown.index, drawdown.values, 0, + alpha=0.3, label=name, color=colors.get(key, 'gray')) + + ax2.set_title('回撤对比', fontsize=14) + ax2.set_ylabel('回撤 (%)') + ax2.legend() + ax2.grid(True, alpha=0.3) + + plt.tight_layout() + plot_path = REPORT_DIR / 'strategy_comparison.png' + plt.savefig(str(plot_path), dpi=150) + plt.close() + logger.info(f"对比图表已保存: {plot_path}") + + +def run_full_comparison(period: int = None, start_date: str = None, end_date: str = None, save_plot: bool = True): + """完整对比入口(可直接调用)""" + results = compare_strategies(period, start_date, end_date, save_plot=save_plot) + + # 推荐最优方案 + best_key = None + best_return = -float('inf') + + for key in ['stat', 'lgb', 'xgb']: + if key in results and 'metrics' in results[key]: + ret_str = results[key]['metrics'].get('总收益率', '0%') + ret_val = float(ret_str.strip('%')) / 100 + if ret_val > best_return: + best_return = ret_val + best_key = key + + name_map = {'stat': '方案A: 统计筛选', 'lgb': '方案B-1: LightGBM', 'xgb': '方案B-2: XGBoost'} + if best_key: + logger.info(f"\n{'='*50}") + logger.info(f" 推荐方案: {name_map.get(best_key, best_key)}") + logger.info(f" 总收益率: {best_return:.2%}") + logger.info(f"{'='*50}") + + return results + + +if __name__ == '__main__': + import argparse + p = argparse.ArgumentParser(description='运行策略对比:方案A(统计) vs 方案B(LightGBM/XGBoost)') + p.add_argument('--period', type=int, default=None, help='K线周期分钟') + p.add_argument('--start', type=str, default=None, help='开始日期,如 2024-01-01') + p.add_argument('--end', type=str, default=None, help='结束日期,如 2024-12-31') + p.add_argument('--no-plot', action='store_true', help='不保存权益/回撤图') + args = p.parse_args() + run_full_comparison( + period=args.period, + start_date=args.start, + end_date=args.end, + save_plot=not args.no_plot, + ) diff --git a/strategy/config.py b/strategy/config.py new file mode 100644 index 0000000..30fb207 --- /dev/null +++ b/strategy/config.py @@ -0,0 +1,151 @@ +""" +全局配置 — 交易参数、指标参数、模型参数 +""" +from pathlib import Path + +# ============ 路径 ============ +PROJECT_ROOT = Path(__file__).parent.parent +DB_PATH = PROJECT_ROOT / 'models' / 'database.db' + +# ============ 交易参数 ============ +TRADE_CONFIG = { + 'symbol': 'ETHUSDT', + 'commission': 0.0006, # 名义手续费 0.06% + 'commission_rebate': 0.90, # 90% 返佣(次日 8 点结算),实付 = commission * (1 - rebate) + 'slippage': 0.0001, # 滑点 0.01% + 'initial_capital': 10000, # 本金 USDT + 'leverage': 100, # 杠杆倍数 + 'position_notional_usd': 500, # 每笔名义 500U(开 100 倍),目标月均收益约 500U + 'max_drawdown_limit': 300, # 最大回撤硬约束:权益从峰值回落超过 300U 则不再开新仓 + # 兼容旧逻辑(若用比例算仓则用此项) + 'position_size': 0.95, +} + +# ============ K线周期 ============ +KLINE_PERIODS = { + 1: '1m', + 3: '3m', + 5: '5m', + 15: '15m', + 30: '30m', + 60: '1h', +} + +# 主周期(用于生成信号) +PRIMARY_PERIOD = 15 # 15分钟 +# 辅助周期(用于多周期融合特征) +AUX_PERIODS = [5, 60] + +# ============ 指标参数 ============ +INDICATOR_PARAMS = { + # 趋势类 + 'sma_windows': [5, 10, 20, 50, 200], + 'ema_windows': [12, 26], + 'macd_fast': 12, + 'macd_slow': 26, + 'macd_signal': 9, + 'adx_window': 14, + 'ichimoku_conversion': 9, + 'ichimoku_base': 26, + 'ichimoku_span_b': 52, + 'trix_window': 15, + 'aroon_window': 25, + 'cci_window': 20, + 'dpo_window': 20, + 'kst_roc1': 10, + 'kst_roc2': 15, + 'kst_roc3': 20, + 'kst_roc4': 30, + 'vortex_window': 14, + + # 动量类 + 'rsi_window': 14, + 'stoch_window': 14, + 'stoch_smooth': 3, + 'williams_window': 14, + 'roc_window': 12, + 'mfi_window': 14, + 'tsi_slow': 25, + 'tsi_fast': 13, + 'uo_short': 7, + 'uo_medium': 14, + 'uo_long': 28, + 'ao_short': 5, + 'ao_long': 34, + 'kama_window': 10, + 'ppo_slow': 26, + 'ppo_fast': 12, + 'stoch_rsi_window': 14, + 'stoch_rsi_smooth': 3, + + # 波动率类 + 'bb_window': 20, + 'bb_std': 2, + 'atr_window': 14, + 'kc_window': 20, + 'dc_window': 20, + + # 成交量类(部分指标需要volume,K线数据可能无volume则跳过) + 'obv_enabled': True, + 'cmf_window': 20, + 'emv_window': 14, + 'fi_window': 13, +} + +# ============ 特征工程参数 ============ +FEATURE_CONFIG = { + 'label_forward_periods': 10, # 未来N根K线用于生成标签 + 'label_threshold': 0.002, # 涨跌阈值(0.2%以内算震荡) + 'lookback_lags': [1, 3, 5], # 滞后特征的lag值 + 'normalize': True, # 是否标准化 +} + +# ============ 模型参数 ============ +MODEL_CONFIG = { + 'walk_forward_train_size': 20000, # Walk-Forward 训练窗口大小 + 'walk_forward_test_size': 2000, # Walk-Forward 测试窗口大小 + 'walk_forward_step': 2000, # 滚动步长 + 'warm_up_months': 12, # 指定回测区间时向前加载的月数,使区间首月即有预测 + + 'lightgbm': { + 'n_estimators': 300, + 'max_depth': 4, + 'learning_rate': 0.03, + 'num_leaves': 15, + 'min_child_samples': 50, + 'subsample': 0.7, + 'colsample_bytree': 0.6, + 'reg_alpha': 1.0, + 'reg_lambda': 1.0, + 'objective': 'multiclass', + 'num_class': 3, + 'verbose': -1, + }, + + 'xgboost': { + 'n_estimators': 300, + 'max_depth': 4, + 'learning_rate': 0.03, + 'subsample': 0.7, + 'colsample_bytree': 0.6, + 'reg_alpha': 1.0, + 'reg_lambda': 1.0, + 'objective': 'multi:softprob', + 'num_class': 3, + 'verbosity': 0, + }, +} + +# ============ 统计筛选参数 ============ +STAT_CONFIG = { + 'top_n_features': 15, # 筛选Top N个指标 + 'correlation_threshold': 0.9, # 去除高相关特征的阈值 + 'grid_search_cv': 3, # 网格搜索交叉验证折数 +} + +# ============ 信号标签映射 ============ +SIGNAL_MAP = { + 0: '观望', + 1: '做多', + 2: '做空', +} diff --git a/strategy/data_loader.py b/strategy/data_loader.py new file mode 100644 index 0000000..b01e09e --- /dev/null +++ b/strategy/data_loader.py @@ -0,0 +1,119 @@ +""" +数据加载器 — 从SQLite加载K线数据为pandas DataFrame +""" +import sqlite3 +import pandas as pd +from loguru import logger +from .config import DB_PATH, KLINE_PERIODS + + +def load_kline(period: int = 15, start_date: str = None, end_date: str = None) -> pd.DataFrame: + """ + 加载指定周期的K线数据 + :param period: K线周期(分钟),如 1, 3, 5, 15, 30, 60 + :param start_date: 起始日期 'YYYY-MM-DD'(可选) + :param end_date: 结束日期 'YYYY-MM-DD'(可选) + :return: DataFrame,列: timestamp, open, high, low, close + """ + suffix = KLINE_PERIODS.get(period) + if suffix is None: + raise ValueError(f"不支持的周期: {period},可选: {list(KLINE_PERIODS.keys())}") + + table_name = f'bitmart_eth_{suffix}' + conn = sqlite3.connect(str(DB_PATH)) + + query = f"SELECT id as timestamp, open, high, low, close FROM {table_name} ORDER BY id" + df = pd.read_sql_query(query, conn) + conn.close() + + if df.empty: + logger.warning(f"[{suffix}] 表中无数据") + return df + + # id 是毫秒时间戳,转为 datetime 索引 + df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms') + df.set_index('datetime', inplace=True) + + # 按日期过滤 + if start_date: + df = df[df.index >= start_date] + if end_date: + df = df[df.index <= end_date] + + logger.info(f"[{suffix}] 加载 {len(df)} 条K线 | {df.index[0]} ~ {df.index[-1]}") + return df + + +def load_multi_period(periods: list = None, start_date: str = None, end_date: str = None) -> dict: + """ + 加载多个周期的K线数据 + :param periods: 周期列表,如 [5, 15, 60],默认全部 + :param start_date: 起始日期 + :param end_date: 结束日期 + :return: {period: DataFrame} 字典 + """ + if periods is None: + periods = list(KLINE_PERIODS.keys()) + + result = {} + for p in periods: + try: + df = load_kline(p, start_date, end_date) + if not df.empty: + result[p] = df + except Exception as e: + logger.error(f"加载 {p}分钟 K线失败: {e}") + + return result + + +def load_trades(start_date: str = None, end_date: str = None, limit: int = None) -> pd.DataFrame: + """ + 加载原始成交记录 + :return: DataFrame,列: id, timestamp, price, volume, side + """ + conn = sqlite3.connect(str(DB_PATH)) + query = "SELECT id, timestamp, price, volume, side FROM bitmart_eth_trades ORDER BY timestamp" + df = pd.read_sql_query(query, conn) + conn.close() + + if df.empty: + logger.warning("成交记录表中无数据") + return df + + df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms') + df.set_index('datetime', inplace=True) + + if start_date: + df = df[df.index >= start_date] + if end_date: + df = df[df.index <= end_date] + if limit: + df = df.head(limit) + + logger.info(f"加载 {len(df)} 条成交记录") + return df + + +def get_available_tables() -> list: + """列出数据库中所有可用的表""" + conn = sqlite3.connect(str(DB_PATH)) + cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name") + tables = [row[0] for row in cursor.fetchall()] + conn.close() + return tables + + +def get_table_stats() -> dict: + """获取各表的数据统计""" + conn = sqlite3.connect(str(DB_PATH)) + tables = get_available_tables() + stats = {} + for table in tables: + try: + count = pd.read_sql_query(f"SELECT COUNT(*) as cnt FROM {table}", conn).iloc[0]['cnt'] + stats[table] = count + except Exception: + stats[table] = 0 + conn.close() + return stats diff --git a/strategy/feature_engine.py b/strategy/feature_engine.py new file mode 100644 index 0000000..433c475 --- /dev/null +++ b/strategy/feature_engine.py @@ -0,0 +1,204 @@ +""" +特征工程 — 标准化、多周期融合、滞后特征、标签生成 +""" +import pandas as pd +import numpy as np +from sklearn.preprocessing import StandardScaler +from loguru import logger + +from .config import FEATURE_CONFIG as FC, PRIMARY_PERIOD, AUX_PERIODS +from .data_loader import load_kline +from .indicators import compute_all_indicators, get_indicator_names + + +def build_features(primary_df: pd.DataFrame, + aux_dfs: dict = None, + has_volume: bool = False) -> pd.DataFrame: + """ + 构建完整特征矩阵 + :param primary_df: 主周期K线 DataFrame + :param aux_dfs: {period: DataFrame} 辅助周期数据(可选) + :param has_volume: 是否有成交量 + :return: 特征矩阵 DataFrame + """ + # 1. 主周期指标 + logger.info("计算主周期指标...") + df = compute_all_indicators(primary_df, has_volume=has_volume) + + # 2. 滞后特征 + logger.info("生成滞后特征...") + indicator_cols = get_indicator_names(has_volume) + existing_cols = [col for col in indicator_cols if col in df.columns] + + lag_frames = [] + for lag in FC['lookback_lags']: + lagged = df[existing_cols].shift(lag).add_suffix(f'_lag{lag}') + lag_frames.append(lagged) + if lag_frames: + df = pd.concat([df] + lag_frames, axis=1) + + # 3. 多周期融合 + if aux_dfs: + aux_frames = [] + for period, aux_df in aux_dfs.items(): + logger.info(f"融合 {period}分钟 辅助周期特征...") + aux_with_ind = compute_all_indicators(aux_df, has_volume=has_volume) + key_indicators = ['rsi', 'macd', 'adx', 'bb_pband', 'atr', 'cci'] + for ind in key_indicators: + if ind in aux_with_ind.columns: + aligned = aux_with_ind[ind].reindex(df.index, method='ffill') + aux_frames.append(aligned.rename(f'{ind}_{period}m')) + if aux_frames: + df = pd.concat([df] + aux_frames, axis=1) + + # 4. 去除全NaN列 + before_cols = len(df.columns) + df.dropna(axis=1, how='all', inplace=True) + after_cols = len(df.columns) + if before_cols != after_cols: + logger.info(f"移除 {before_cols - after_cols} 个全NaN列") + + logger.info(f"特征矩阵: {df.shape[0]} 行 x {df.shape[1]} 列") + return df + + +def generate_labels(df: pd.DataFrame, forward_periods: int = None, + threshold: float = None) -> pd.Series: + """ + 生成交易标签 + :param df: 包含 close 列的 DataFrame + :param forward_periods: 未来N根K线 + :param threshold: 涨跌阈值 + :return: Series,值为 0=观望, 1=做多, 2=做空 + """ + if forward_periods is None: + forward_periods = FC['label_forward_periods'] + if threshold is None: + threshold = FC['label_threshold'] + + future_return = df['close'].shift(-forward_periods) / df['close'] - 1 + + labels = pd.Series(0, index=df.index, name='label') # 默认观望 + labels[future_return > threshold] = 1 # 做多 + labels[future_return < -threshold] = 2 # 做空 + + # 最后 forward_periods 行无法计算标签,设为NaN + labels.iloc[-forward_periods:] = np.nan + + dist = labels.value_counts().to_dict() + logger.info(f"标签分布: 观望={dist.get(0, 0)}, 做多={dist.get(1, 0)}, 做空={dist.get(2, 0)}") + return labels + + +def prepare_dataset(period: int = None, start_date: str = None, end_date: str = None, + has_volume: bool = False) -> tuple: + """ + 一键准备训练数据集 + :return: (X, y, feature_names) — 已去NaN、已标准化 + """ + if period is None: + period = PRIMARY_PERIOD + + # 加载主周期 + primary_df = load_kline(period, start_date, end_date) + if primary_df.empty: + raise ValueError(f"{period}分钟 K线数据为空") + + # 加载辅助周期 + aux_dfs = {} + for aux_p in AUX_PERIODS: + try: + aux_df = load_kline(aux_p, start_date, end_date) + if not aux_df.empty: + aux_dfs[aux_p] = aux_df + except Exception as e: + logger.warning(f"加载 {aux_p}分钟 辅助数据失败: {e}") + + # 构建特征 + df = build_features(primary_df, aux_dfs, has_volume=has_volume) + + # 生成标签 + labels = generate_labels(df) + df = df.copy() + df['label'] = labels + + # 去除NaN行 + df.dropna(inplace=True) + logger.info(f"去NaN后剩余 {len(df)} 行") + + # 分离 X, y + exclude_cols = ['open', 'high', 'low', 'close', 'timestamp', 'label'] + if 'volume' in df.columns: + exclude_cols.append('volume') + + # 排除价格级别特征(会泄露绝对价格信息导致过拟合) + price_level_patterns = [ + 'sma_', 'ema_', 'bb_upper', 'bb_mid', 'bb_lower', + 'kc_upper', 'kc_mid', 'kc_lower', 'dc_upper', 'dc_mid', 'dc_lower', + 'ichimoku_conv', 'ichimoku_base', 'ichimoku_a', 'ichimoku_b', + 'kama', 'vwap', 'obv', 'adi', 'vpt', 'nvi', + 'momentum_3', 'momentum_5', + ] + feature_cols = [] + for c in df.columns: + if c in exclude_cols: + continue + base_name = c.split('_lag')[0] + # 去掉周期后缀 + for suffix in ['_5m', '_60m', '_3m', '_15m', '_30m', '_1m']: + if base_name.endswith(suffix): + base_name = base_name[:-len(suffix)] + break + if any(base_name.startswith(p) or base_name == p.rstrip('_') for p in price_level_patterns): + continue + feature_cols.append(c) + + logger.info(f"排除价格级别特征后剩余 {len(feature_cols)} 个特征") + + X = df[feature_cols].copy() + y = df['label'].astype(int) + + # 标准化 + scaler = None + if FC['normalize']: + scaler = StandardScaler() + X = pd.DataFrame(scaler.fit_transform(X), columns=feature_cols, index=X.index) + logger.info("特征已标准化") + + logger.info(f"最终数据集: X={X.shape}, y={y.shape}, 特征数={len(feature_cols)}") + return X, y, feature_cols, scaler + + +def get_latest_feature_row(period: int = None, feature_cols: list = None, + start_date: str = None, end_date: str = None) -> pd.DataFrame: + """ + 构建特征并返回最后一行的特征矩阵(用于实盘/模拟盘预测)。 + 需传入已保存的 feature_cols,保证与训练时一致。 + :return: 1 行 DataFrame,列为 feature_cols;若缺列则返回空 DataFrame + """ + if period is None: + period = PRIMARY_PERIOD + if not feature_cols: + return pd.DataFrame() + + primary_df = load_kline(period, start_date, end_date) + if primary_df.empty: + return pd.DataFrame() + aux_dfs = {} + for aux_p in AUX_PERIODS: + try: + aux_df = load_kline(aux_p, start_date, end_date) + if not aux_df.empty: + aux_dfs[aux_p] = aux_df + except Exception: + pass + df = build_features(primary_df, aux_dfs, has_volume=False) + labels = generate_labels(df) + df = df.copy() + df['label'] = labels + df.dropna(inplace=True) + missing = [c for c in feature_cols if c not in df.columns] + if missing: + logger.warning(f"get_latest_feature_row: 缺少特征列 {missing[:5]}{'...' if len(missing) > 5 else ''}") + return pd.DataFrame() + return df[feature_cols].iloc[-1:].copy() diff --git a/strategy/indicators.py b/strategy/indicators.py new file mode 100644 index 0000000..3ca8521 --- /dev/null +++ b/strategy/indicators.py @@ -0,0 +1,278 @@ +""" +52个技术指标计算引擎 — 基于 ta 库 +覆盖趋势、动量、波动率、成交量、自定义衍生特征五大类 +""" +import pandas as pd +import numpy as np +import ta +from .config import INDICATOR_PARAMS as P + + +def compute_all_indicators(df: pd.DataFrame, has_volume: bool = False) -> pd.DataFrame: + """ + 计算全部52个技术指标,返回拼接后的DataFrame + :param df: 必须包含 open, high, low, close 列;可选 volume 列 + :param has_volume: 是否有成交量数据 + :return: 原始列 + 52个指标列 + """ + out = df.copy() + o, h, l, c = out['open'], out['high'], out['low'], out['close'] + v = out['volume'] if has_volume and 'volume' in out.columns else None + + # ========== 趋势类 (14) ========== + out = _add_trend(out, o, h, l, c) + + # ========== 动量类 (12) ========== + out = _add_momentum(out, h, l, c, v) + + # ========== 波动率类 (8) ========== + out = _add_volatility(out, h, l, c) + + # ========== 成交量类 (8) ========== + if has_volume and v is not None: + out = _add_volume(out, h, l, c, v) + + # ========== 自定义衍生特征 (10) ========== + out = _add_custom(out, o, h, l, c) + + return out + + +def _add_trend(out, o, h, l, c): + """趋势类指标 (14个特征)""" + # SMA (5个) + for w in P['sma_windows']: + out[f'sma_{w}'] = ta.trend.sma_indicator(c, window=w) + + # EMA (2个) + for w in P['ema_windows']: + out[f'ema_{w}'] = ta.trend.ema_indicator(c, window=w) + + # MACD (3个) + macd = ta.trend.MACD(c, window_slow=P['macd_slow'], window_fast=P['macd_fast'], + window_sign=P['macd_signal']) + out['macd'] = macd.macd() + out['macd_signal'] = macd.macd_signal() + out['macd_hist'] = macd.macd_diff() + + # ADX + DI (3个) + adx = ta.trend.ADXIndicator(h, l, c, window=P['adx_window']) + out['adx'] = adx.adx() + out['di_plus'] = adx.adx_pos() + out['di_minus'] = adx.adx_neg() + + # Ichimoku (4个) + ichi = ta.trend.IchimokuIndicator(h, l, + window1=P['ichimoku_conversion'], + window2=P['ichimoku_base'], + window3=P['ichimoku_span_b']) + out['ichimoku_conv'] = ichi.ichimoku_conversion_line() + out['ichimoku_base'] = ichi.ichimoku_base_line() + out['ichimoku_a'] = ichi.ichimoku_a() + out['ichimoku_b'] = ichi.ichimoku_b() + + # TRIX + out['trix'] = ta.trend.trix(c, window=P['trix_window']) + + # Aroon (2个) + aroon = ta.trend.AroonIndicator(h, l, window=P['aroon_window']) + out['aroon_up'] = aroon.aroon_up() + out['aroon_down'] = aroon.aroon_down() + + # CCI + out['cci'] = ta.trend.cci(h, l, c, window=P['cci_window']) + + # DPO + out['dpo'] = ta.trend.dpo(c, window=P['dpo_window']) + + # KST + kst = ta.trend.KSTIndicator(c, roc1=P['kst_roc1'], roc2=P['kst_roc2'], + roc3=P['kst_roc3'], roc4=P['kst_roc4']) + out['kst'] = kst.kst() + + # Vortex (2个) + vortex = ta.trend.VortexIndicator(h, l, c, window=P['vortex_window']) + out['vortex_pos'] = vortex.vortex_indicator_pos() + out['vortex_neg'] = vortex.vortex_indicator_neg() + + return out + + +def _add_momentum(out, h, l, c, v): + """动量类指标 (12个特征)""" + # RSI + out['rsi'] = ta.momentum.rsi(c, window=P['rsi_window']) + + # Stochastic %K / %D + stoch = ta.momentum.StochasticOscillator(h, l, c, + window=P['stoch_window'], + smooth_window=P['stoch_smooth']) + out['stoch_k'] = stoch.stoch() + out['stoch_d'] = stoch.stoch_signal() + + # Williams %R + out['williams_r'] = ta.momentum.williams_r(h, l, c, lbp=P['williams_window']) + + # ROC + out['roc'] = ta.momentum.roc(c, window=P['roc_window']) + + # MFI(需要volume) + if v is not None: + out['mfi'] = ta.volume.money_flow_index(h, l, c, v, window=P['mfi_window']) + + # TSI + out['tsi'] = ta.momentum.tsi(c, window_slow=P['tsi_slow'], window_fast=P['tsi_fast']) + + # Ultimate Oscillator + out['uo'] = ta.momentum.ultimate_oscillator(h, l, c, + window1=P['uo_short'], + window2=P['uo_medium'], + window3=P['uo_long']) + + # Awesome Oscillator + out['ao'] = ta.momentum.awesome_oscillator(h, l, + window1=P['ao_short'], + window2=P['ao_long']) + + # KAMA + out['kama'] = ta.momentum.kama(c, window=P['kama_window']) + + # PPO + out['ppo'] = ta.momentum.ppo(c, window_slow=P['ppo_slow'], window_fast=P['ppo_fast']) + + # Stochastic RSI %K / %D + stoch_rsi = ta.momentum.StochRSIIndicator(c, + window=P['stoch_rsi_window'], + smooth1=P['stoch_rsi_smooth'], + smooth2=P['stoch_rsi_smooth']) + out['stoch_rsi_k'] = stoch_rsi.stochrsi_k() + out['stoch_rsi_d'] = stoch_rsi.stochrsi_d() + + return out + + +def _add_volatility(out, h, l, c): + """波动率类指标 (8个特征 — 含子指标共12列)""" + # Bollinger Bands (5个) + bb = ta.volatility.BollingerBands(c, window=P['bb_window'], window_dev=P['bb_std']) + out['bb_upper'] = bb.bollinger_hband() + out['bb_mid'] = bb.bollinger_mavg() + out['bb_lower'] = bb.bollinger_lband() + out['bb_width'] = bb.bollinger_wband() + out['bb_pband'] = bb.bollinger_pband() + + # ATR + out['atr'] = ta.volatility.average_true_range(h, l, c, window=P['atr_window']) + + # Keltner Channel (3个) + kc = ta.volatility.KeltnerChannel(h, l, c, window=P['kc_window']) + out['kc_upper'] = kc.keltner_channel_hband() + out['kc_mid'] = kc.keltner_channel_mband() + out['kc_lower'] = kc.keltner_channel_lband() + + # Donchian Channel (3个) + dc = ta.volatility.DonchianChannel(h, l, c, window=P['dc_window']) + out['dc_upper'] = dc.donchian_channel_hband() + out['dc_mid'] = dc.donchian_channel_mband() + out['dc_lower'] = dc.donchian_channel_lband() + + return out + + +def _add_volume(out, h, l, c, v): + """成交量类指标 (8个特征)""" + # OBV + out['obv'] = ta.volume.on_balance_volume(c, v) + + # VWAP + out['vwap'] = ta.volume.volume_weighted_average_price(h, l, c, v) + + # CMF + out['cmf'] = ta.volume.chaikin_money_flow(h, l, c, v, window=P['cmf_window']) + + # ADI (Accumulation/Distribution Index) + out['adi'] = ta.volume.acc_dist_index(h, l, c, v) + + # EMV (Ease of Movement) + out['emv'] = ta.volume.ease_of_movement(h, l, v, window=P['emv_window']) + + # Force Index + out['fi'] = ta.volume.force_index(c, v, window=P['fi_window']) + + # VPT (Volume Price Trend) + out['vpt'] = ta.volume.volume_price_trend(c, v) + + # NVI (Negative Volume Index) + out['nvi'] = ta.volume.negative_volume_index(c, v) + + return out + + +def _add_custom(out, o, h, l, c): + """自定义衍生特征 (10个)""" + # 价格变化率 + out['price_change_pct'] = c.pct_change() + + # 振幅(High-Low范围 / Close) + out['high_low_range'] = (h - l) / c + + # 实体比率(|Close-Open| / (High-Low)) + body = (c - o).abs() + hl_range = (h - l).replace(0, np.nan) + out['body_ratio'] = body / hl_range + + # 上影线比率 + upper_shadow = h - pd.concat([o, c], axis=1).max(axis=1) + out['upper_shadow'] = upper_shadow / hl_range + + # 下影线比率 + lower_shadow = pd.concat([o, c], axis=1).min(axis=1) - l + out['lower_shadow'] = lower_shadow / hl_range + + # 波动率比率(ATR / Close 的滚动比值) + atr = ta.volatility.average_true_range(h, l, c, window=14) + out['volatility_ratio'] = atr / c + + # Close / SMA20 比率 + sma20 = ta.trend.sma_indicator(c, window=20) + out['close_sma20_ratio'] = c / sma20.replace(0, np.nan) + + # Close / EMA12 比率 + ema12 = ta.trend.ema_indicator(c, window=12) + out['close_ema12_ratio'] = c / ema12.replace(0, np.nan) + + # 动量 3周期 + out['momentum_3'] = c - c.shift(3) + + # 动量 5周期 + out['momentum_5'] = c - c.shift(5) + + return out + + +def get_indicator_names(has_volume: bool = False) -> list: + """返回所有指标列名""" + names = [] + # 趋势 + for w in P['sma_windows']: + names.append(f'sma_{w}') + for w in P['ema_windows']: + names.append(f'ema_{w}') + names += ['macd', 'macd_signal', 'macd_hist', 'adx', 'di_plus', 'di_minus'] + names += ['ichimoku_conv', 'ichimoku_base', 'ichimoku_a', 'ichimoku_b'] + names += ['trix', 'aroon_up', 'aroon_down', 'cci', 'dpo', 'kst', 'vortex_pos', 'vortex_neg'] + # 动量 + names += ['rsi', 'stoch_k', 'stoch_d', 'williams_r', 'roc', 'tsi', 'uo', 'ao', 'kama', 'ppo', + 'stoch_rsi_k', 'stoch_rsi_d'] + if has_volume: + names.append('mfi') + # 波动率 + names += ['bb_upper', 'bb_mid', 'bb_lower', 'bb_width', 'bb_pband', 'atr', + 'kc_upper', 'kc_mid', 'kc_lower', 'dc_upper', 'dc_mid', 'dc_lower'] + # 成交量 + if has_volume: + names += ['obv', 'vwap', 'cmf', 'adi', 'emv', 'fi', 'vpt', 'nvi'] + # 自定义 + names += ['price_change_pct', 'high_low_range', 'body_ratio', 'upper_shadow', 'lower_shadow', + 'volatility_ratio', 'close_sma20_ratio', 'close_ema12_ratio', 'momentum_3', 'momentum_5'] + return names diff --git a/strategy/stat_strategy.py b/strategy/stat_strategy.py new file mode 100644 index 0000000..05528b1 --- /dev/null +++ b/strategy/stat_strategy.py @@ -0,0 +1,256 @@ +""" +方案A:统计筛选 + 规则组合策略 +1. 从52个指标中用统计方法筛选最有效的指标 +2. 用经典规则组合生成交易信号 +3. 网格搜索优化参数 +""" +import pandas as pd +import numpy as np +from sklearn.ensemble import RandomForestClassifier +from sklearn.feature_selection import mutual_info_classif +from loguru import logger + +from .config import STAT_CONFIG as SC, PRIMARY_PERIOD, AUX_PERIODS +from .feature_engine import prepare_dataset +from .backtest import BacktestEngine, print_metrics + + +class StatStrategy: + """统计筛选策略""" + + def __init__(self): + self.top_features = [] + self.feature_scores = {} + self.best_params = {} + + def select_features(self, X: pd.DataFrame, y: pd.Series) -> list: + """ + 用多种统计方法筛选有效指标 + :return: Top N 特征名列表 + """ + logger.info("=" * 50) + logger.info("开始特征筛选...") + scores = {} + + # 1. 皮尔逊相关系数 + logger.info("计算皮尔逊相关系数...") + corr_scores = X.corrwith(y).abs().fillna(0) + for col in X.columns: + scores[col] = scores.get(col, 0) + corr_scores.get(col, 0) + + # 2. 互信息 + logger.info("计算互信息...") + mi = mutual_info_classif(X.fillna(0), y, random_state=42) + mi_series = pd.Series(mi, index=X.columns) + mi_norm = mi_series / mi_series.max() if mi_series.max() > 0 else mi_series + for col in X.columns: + scores[col] = scores.get(col, 0) + mi_norm.get(col, 0) + + # 3. 随机森林特征重要性 + logger.info("训练随机森林评估特征重要性...") + rf = RandomForestClassifier(n_estimators=200, max_depth=8, random_state=42, n_jobs=-1) + rf.fit(X.fillna(0), y) + rf_imp = pd.Series(rf.feature_importances_, index=X.columns) + rf_norm = rf_imp / rf_imp.max() if rf_imp.max() > 0 else rf_imp + for col in X.columns: + scores[col] = scores.get(col, 0) + rf_norm.get(col, 0) + + # 综合排名 + score_series = pd.Series(scores).sort_values(ascending=False) + self.feature_scores = score_series.to_dict() + + # 去除高相关特征 + top_candidates = score_series.head(SC['top_n_features'] * 2).index.tolist() + selected = self._remove_correlated(X[top_candidates], SC['correlation_threshold']) + self.top_features = selected[:SC['top_n_features']] + + logger.info(f"筛选出 Top {len(self.top_features)} 特征:") + for i, feat in enumerate(self.top_features): + logger.info(f" {i+1}. {feat} (综合得分: {score_series[feat]:.4f})") + + return self.top_features + + def _remove_correlated(self, X: pd.DataFrame, threshold: float) -> list: + """去除高度相关的冗余特征""" + corr_matrix = X.corr().abs() + selected = list(X.columns) + to_remove = set() + + for i in range(len(selected)): + if selected[i] in to_remove: + continue + for j in range(i + 1, len(selected)): + if selected[j] in to_remove: + continue + if corr_matrix.loc[selected[i], selected[j]] > threshold: + to_remove.add(selected[j]) + + result = [c for c in selected if c not in to_remove] + if to_remove: + logger.info(f"移除 {len(to_remove)} 个高相关冗余特征") + return result + + def generate_signals(self, df: pd.DataFrame) -> pd.Series: + """ + 基于筛选出的指标,用规则组合生成交易信号 + :param df: 包含指标列的 DataFrame(原始值,非标准化) + :return: 信号 Series (0=观望, 1=做多, 2=做空) + """ + signals = pd.Series(0, index=df.index) + long_score = pd.Series(0.0, index=df.index) + short_score = pd.Series(0.0, index=df.index) + matched = 0 + + for feat in self.top_features: + if feat not in df.columns: + continue + + col = df[feat] + base = feat.split('_lag')[0] # 去掉 _lagN 后缀 + # 去掉辅助周期后缀 _5m / _60m + for suffix in ['_5m', '_60m', '_3m', '_15m', '_30m', '_1m']: + if base.endswith(suffix): + base = base[:-len(suffix)] + break + + if 'rsi' in base: + long_score += (col < 35).astype(float) + short_score += (col > 65).astype(float) + matched += 1 + elif base == 'macd_hist': + long_score += (col > 0).astype(float) + short_score += (col < 0).astype(float) + matched += 1 + elif base == 'macd': + long_score += (col > 0).astype(float) + short_score += (col < 0).astype(float) + matched += 1 + elif 'bb_pband' in base: + long_score += (col < 0.2).astype(float) + short_score += (col > 0.8).astype(float) + matched += 1 + elif 'adx' in base: + long_score += (col > 25).astype(float) + short_score += (col > 25).astype(float) + matched += 1 + elif 'cci' in base: + long_score += (col < -100).astype(float) + short_score += (col > 100).astype(float) + matched += 1 + elif 'stoch_k' in base or 'stoch_rsi_k' in base: + long_score += (col < 25).astype(float) + short_score += (col > 75).astype(float) + matched += 1 + elif 'williams_r' in base: + long_score += (col < -80).astype(float) + short_score += (col > -20).astype(float) + matched += 1 + elif 'ao' in base or 'tsi' in base or 'roc' in base or 'ppo' in base: + long_score += (col > 0).astype(float) + short_score += (col < 0).astype(float) + matched += 1 + elif 'atr' in base or 'volatility_ratio' in base: + # 波动率类:高波动时趋势更强,用均值分界 + median = col.rolling(200, min_periods=50).median() + long_score += (col > median).astype(float) * 0.5 + short_score += (col > median).astype(float) * 0.5 + matched += 1 + elif 'high_low_range' in base or 'body_ratio' in base: + median = col.rolling(200, min_periods=50).median() + long_score += (col > median).astype(float) * 0.3 + short_score += (col > median).astype(float) * 0.3 + matched += 1 + elif 'bb_width' in base: + # 布林带宽度收窄后扩张 = 突破信号,结合价格方向 + median = col.rolling(200, min_periods=50).median() + expanding = col > col.shift(1) # 宽度在扩张 + was_narrow = col.shift(1) < median # 之前是收窄的 + breakout = expanding & was_narrow + if 'close' in df.columns: + price_up = df['close'] > df['close'].shift(1) + long_score += (breakout & price_up).astype(float) + short_score += (breakout & ~price_up).astype(float) + else: + long_score += breakout.astype(float) * 0.5 + short_score += breakout.astype(float) * 0.5 + matched += 1 + elif 'close_sma20_ratio' in base or 'close_ema12_ratio' in base: + # 价格在均线上方=多头,下方=空头 + long_score += (col > 1.0).astype(float) + short_score += (col < 1.0).astype(float) + matched += 1 + elif 'ichimoku' in base: + if 'close' in df.columns: + long_score += (df['close'] > col).astype(float) + short_score += (df['close'] < col).astype(float) + matched += 1 + elif 'momentum' in base or 'price_change' in base: + long_score += (col > 0).astype(float) + short_score += (col < 0).astype(float) + matched += 1 + + logger.info(f"规则匹配: {matched}/{len(self.top_features)} 个特征有对应规则") + + # 阈值:至少50%的匹配特征同时确认(更严格) + threshold = max(3, matched * 0.5) + logger.info(f"信号阈值: {threshold:.1f} (需要至少这么多指标同时确认)") + signals[long_score >= threshold] = 1 + signals[short_score >= threshold] = 2 + # 多空同时满足时取更强的 + both = (long_score >= threshold) & (short_score >= threshold) + signals[both & (long_score > short_score)] = 1 + signals[both & (short_score > long_score)] = 2 + signals[both & (long_score == short_score)] = 0 + + dist = signals.value_counts().to_dict() + logger.info(f"规则信号分布: 观望={dist.get(0, 0)}, 做多={dist.get(1, 0)}, 做空={dist.get(2, 0)}") + return signals + + def run(self, period: int = None, start_date: str = None, end_date: str = None) -> dict: + """ + 完整运行方案A + :return: 回测结果 + """ + if period is None: + period = PRIMARY_PERIOD + + logger.info("=" * 60) + logger.info("方案A:统计筛选 + 规则组合策略") + logger.info("=" * 60) + + # 1. 准备数据(标准化版本,用于特征筛选) + X, y, feature_names, _ = prepare_dataset(period, start_date, end_date) + + # 2. 筛选特征 + self.select_features(X, y) + + # 3. 构建完整特征矩阵(原始值,非标准化,用于规则判断) + from .data_loader import load_kline, load_multi_period + from .feature_engine import build_features + primary_df = load_kline(period, start_date, end_date) + aux_dfs = {} + for aux_p in AUX_PERIODS: + try: + aux_df = load_kline(aux_p, start_date, end_date) + if not aux_df.empty: + aux_dfs[aux_p] = aux_df + except Exception: + pass + df = build_features(primary_df, aux_dfs) + df.dropna(inplace=True) + + # 4. 生成信号 + signals = self.generate_signals(df) + + # 5. 回测 + engine = BacktestEngine() + result = engine.run(df['close'], signals) + + print_metrics(result['metrics'], "方案A: 统计筛选策略") + return result + + +def run_stat_strategy(period: int = None, start_date: str = None, end_date: str = None) -> dict: + """方案A快捷入口""" + strategy = StatStrategy() + return strategy.run(period, start_date, end_date) diff --git a/test.py b/test.py new file mode 100644 index 0000000..842179d --- /dev/null +++ b/test.py @@ -0,0 +1,2 @@ +from strategy.compare import run_full_comparison +results = run_full_comparison(period=15) \ No newline at end of file diff --git a/四分之一,五分钟,反手条件充足.py b/四分之一,五分钟,反手条件充足.py new file mode 100644 index 0000000..66c5888 --- /dev/null +++ b/四分之一,五分钟,反手条件充足.py @@ -0,0 +1,694 @@ +import sys +import time +from pathlib import Path + +from tqdm import tqdm +from loguru import logger +from bit_tools import openBrowser +from DrissionPage import ChromiumPage +from DrissionPage import ChromiumOptions + +from bitmart.api_contract import APIContract + +# 方案B:从 strategy 模块获取实盘信号(需先运行 AI 策略训练保存模型,并保持 models/database.db 有最新 15m/5m/1h K 线,例如运行 抓取多周期K线.py) +sys.path.insert(0, str(Path(__file__).resolve().parent)) +from strategy.ai_strategy import get_live_signal + + +class BitmartFuturesTransaction: + def __init__(self, bit_id): + + self.page: ChromiumPage | None = None + + self.api_key = "a0fb7b98464fd9bcce67e7c519d58ec10d0c38a8" + self.secret_key = "4eaeba78e77aeaab1c2027f846a276d164f264a44c2c1bb1c5f3be50c8de1ca5" + self.memo = "合约交易" + + self.contract_symbol = "ETHUSDT" + + self.contractAPI = APIContract(self.api_key, self.secret_key, self.memo, timeout=(5, 15)) + + self.start = 0 # 持仓状态: -1 空, 0 无, 1 多 + + self.pbar = tqdm(total=30, desc="等待K线", ncols=80) # 可选:用于长时间等待时展示进度 + + self.last_kline_time = None # 上一次出信号的 15 分钟 K 线 id(方案B 每根 15m 只出一次信号) + + # 反手频率控制 + self.reverse_cooldown_seconds = 1.5 * 60 # 反手冷却时间(秒) + self.reverse_min_move_pct = 0.05 # 反手最小价差过滤(百分比) + self.last_reverse_time = None # 上次反手时间 + + # 开仓频率控制 + self.open_cooldown_seconds = 60 # 开仓冷却时间(秒),两次开仓至少间隔此时长 + self.last_open_time = None # 上次开仓时间 + self.last_open_kline_id = None # 上次开仓所在 K 线 id,同一根 K 线只允许开仓一次 + + self.leverage = "100" # 高杠杆(全仓模式下可开更大仓位) + self.open_type = "cross" # 全仓模式 + self.risk_percent = 0 # 未使用;若启用则可为每次开仓占可用余额的百分比 + self.take_profit_usd = 5 # 仓位盈利达到此金额(美元)时平仓止盈 + self.stop_loss_usd = -3 # 固定止损:亏损达到 3 美元平仓 + self.trailing_activation_usd = 2 # 盈利达到此金额后启动移动止损 + self.trailing_distance_usd = 1.5 # 从最高盈利回撤此金额则平仓 + self.max_unrealized_pnl_seen = None # 持仓期间见过的最大盈利(用于移动止损) + + self.open_avg_price = None # 开仓价格 + self.current_amount = None # 持仓量 + + self.bit_id = bit_id + self.default_order_size = 25 # 开仓/反手张数,统一在此修改 + + # 策略相关变量 + self.prev_kline = None # 上一根K线 + self.current_kline = None # 当前K线 + self.prev_entity = None # 上一根K线实体大小 + self.current_open = None # 当前K线开盘价 + + def get_klines(self): + """获取最近2根K线(当前K线和上一根K线)""" + try: + end_time = int(time.time()) + # 获取足够多的条目确保有最新的K线 + response = self.contractAPI.get_kline( + contract_symbol=self.contract_symbol, + step=5, # 5分钟 + start_time=end_time - 3600 * 3, # 取最近3小时 + end_time=end_time + )[0]["data"] + + # 每根: [timestamp, open, high, low, close, volume] + formatted = [] + for k in response: + formatted.append({ + 'id': int(k["timestamp"]), + 'open': float(k["open_price"]), + 'high': float(k["high_price"]), + 'low': float(k["low_price"]), + 'close': float(k["close_price"]) + }) + formatted.sort(key=lambda x: x['id']) + + # 返回最近2根K线:倒数第二根(上一根)和最后一根(当前) + if len(formatted) >= 2: + return formatted[-2], formatted[-1] + return None, None + except Exception as e: + logger.error(f"获取K线异常: {e}") + self.ding(text="获取K线异常", error=True) + return None, None + + def get_current_price(self): + """获取当前最新价格""" + try: + end_time = int(time.time()) + response = self.contractAPI.get_kline( + contract_symbol=self.contract_symbol, + step=1, # 1分钟 + start_time=end_time - 3600 * 1, # 取最近1小时 + end_time=end_time + )[0] + if response['code'] == 1000: + return float(response['data'][-1]["close_price"]) + return None + except Exception as e: + logger.error(f"获取价格异常: {e}") + return None + + def get_available_balance(self): + """获取合约账户可用USDT余额""" + try: + response = self.contractAPI.get_assets_detail()[0] + if response['code'] == 1000: + data = response['data'] + if isinstance(data, dict): + return float(data.get('available_balance', 0)) + elif isinstance(data, list): + for asset in data: + if asset.get('currency') == 'USDT': + return float(asset.get('available_balance', 0)) + return None + except Exception as e: + logger.error(f"余额查询异常: {e}") + return None + + def get_position_status(self): + """获取当前持仓方向""" + try: + response = self.contractAPI.get_position(contract_symbol=self.contract_symbol)[0] + if response['code'] == 1000: + positions = response['data'] + if not positions: + self.start = 0 + self.open_avg_price = None + self.current_amount = None + self.unrealized_pnl = None + return True + pos = positions[0] + self.start = 1 if pos['position_type'] == 1 else -1 + self.open_avg_price = float(pos['open_avg_price']) + self.current_amount = float(pos['current_amount']) + self.position_cross = pos["position_cross"] + # 直接从API获取未实现盈亏(Bitmart返回的是 unrealized_value 字段) + self.unrealized_pnl = float(pos.get('unrealized_value', 0)) + logger.debug(f"持仓详情: 方向={self.start}, 开仓均价={self.open_avg_price}, " + f"持仓量={self.current_amount}, 未实现盈亏={self.unrealized_pnl:.2f}") + return True + else: + return False + except Exception as e: + logger.error(f"持仓查询异常: {e}") + return False + + def get_unrealized_pnl_usd(self): + """ + 获取当前持仓未实现盈亏(美元),直接使用API返回值 + """ + if self.start == 0 or self.unrealized_pnl is None: + return None + return self.unrealized_pnl + + def set_leverage(self): + """程序启动时设置全仓 + 高杠杆""" + try: + response = self.contractAPI.post_submit_leverage( + contract_symbol=self.contract_symbol, + leverage=self.leverage, + open_type=self.open_type + )[0] + if response['code'] == 1000: + logger.success(f"全仓模式 + {self.leverage}x 杠杆设置成功") + return True + else: + logger.error(f"杠杆设置失败: {response}") + return False + except Exception as e: + logger.error(f"设置杠杆异常: {e}") + return False + + def openBrowser(self): + """打开 TGE 对应浏览器实例""" + try: + bit_port = openBrowser(id=self.bit_id) + co = ChromiumOptions() + co.set_local_port(port=bit_port) + self.page = ChromiumPage(addr_or_opts=co) + return True + except: + return False + + def click_safe(self, xpath, sleep=0.5): + """安全点击""" + try: + ele = self.page.ele(xpath) + if not ele: + return False + # ele.scroll.to_see(center=True) + # time.sleep(sleep) + ele.click(by_js=True) + return True + except: + return False + + def 平仓(self): + """平仓操作""" + self.click_safe('x://span[normalize-space(text()) ="市价"]') + + def 开单(self, marketPriceLongOrder=0, limitPriceShortOrder=0, size=None, price=None): + """ + marketPriceLongOrder 市价做多或者做空,1是做多,-1是做空 + limitPriceShortOrder 限价做多或者做空 + """ + if marketPriceLongOrder == -1: + # self.click_safe('x://button[normalize-space(text()) ="市价"]') + # self.page.ele('x://*[@id="size_0"]').input(vals=size, clear=True) + self.click_safe('x://span[normalize-space(text()) ="卖出/做空"]') + elif marketPriceLongOrder == 1: + # self.click_safe('x://button[normalize-space(text()) ="市价"]') + # self.page.ele('x://*[@id="size_0"]').input(vals=size, clear=True) + self.click_safe('x://span[normalize-space(text()) ="买入/做多"]') + + if limitPriceShortOrder == -1: + self.click_safe('x://button[normalize-space(text()) ="限价"]') + self.page.ele('x://*[@id="price_0"]').input(vals=price, clear=True) + time.sleep(1) + self.page.ele('x://*[@id="size_0"]').input(1) + self.click_safe('x://span[normalize-space(text()) ="卖出/做空"]') + elif limitPriceShortOrder == 1: + self.click_safe('x://button[normalize-space(text()) ="限价"]') + self.page.ele('x://*[@id="price_0"]').input(vals=price, clear=True) + time.sleep(1) + self.page.ele('x://*[@id="size_0"]').input(1) + self.click_safe('x://span[normalize-space(text()) ="买入/做多"]') + + def ding(self, text, error=False): + """日志通知""" + if error: + logger.error(text) + else: + logger.info(text) + + def calculate_entity(self, kline): + """计算K线实体大小(绝对值)""" + return abs(kline['close'] - kline['open']) + + def calculate_upper_shadow(self, kline): + """计算上阴线(上影线)涨幅百分比""" + # 上阴线 = (最高价 - max(开盘价, 收盘价)) / max(开盘价, 收盘价) + body_top = max(kline['open'], kline['close']) + if body_top == 0: + return 0 + return (kline['high'] - body_top) / body_top * 100 + + def calculate_lower_shadow(self, kline): + """计算下阴线(下影线)跌幅百分比""" + # 下阴线 = (min(开盘价, 收盘价) - 最低价) / min(开盘价, 收盘价) + body_bottom = min(kline['open'], kline['close']) + if body_bottom == 0: + return 0 + return (body_bottom - kline['low']) / body_bottom * 100 + + def get_entity_edge(self, kline): + """获取K线实体边(收盘价或开盘价,取决于是阳线还是阴线)""" + # 阳线(收盘>开盘):实体上边=收盘价,实体下边=开盘价 + # 阴线(收盘<开盘):实体上边=开盘价,实体下边=收盘价 + return { + 'upper': max(kline['open'], kline['close']), # 实体上边 + 'lower': min(kline['open'], kline['close']) # 实体下边 + } + + def check_signal(self, current_price, prev_kline, current_kline): + """ + 检查交易信号 + 返回: ('long', trigger_price) / ('short', trigger_price) / None + """ + # 计算上一根K线实体 + prev_entity = self.calculate_entity(prev_kline) + + # 实体过小不交易(实体 < 0.1) + if prev_entity < 0.1: + logger.info(f"上一根K线实体过小: {prev_entity:.4f},跳过信号检测") + return None + + # 获取上一根K线的实体上下边 + prev_entity_edge = self.get_entity_edge(prev_kline) + prev_entity_upper = prev_entity_edge['upper'] # 实体上边 + prev_entity_lower = prev_entity_edge['lower'] # 实体下边 + + # 优化:以下两种情况以当前这根的开盘价作为计算基准 + # 1) 上一根阳线 且 当前开盘价 > 上一根收盘价(跳空高开) + # 2) 上一根阴线 且 当前开盘价 < 上一根收盘价(跳空低开) + prev_is_bullish_for_calc = prev_kline['close'] > prev_kline['open'] + prev_is_bearish_for_calc = prev_kline['close'] < prev_kline['open'] + current_open_above_prev_close = current_kline['open'] > prev_kline['close'] + current_open_below_prev_close = current_kline['open'] < prev_kline['close'] + use_current_open_as_base = (prev_is_bullish_for_calc and current_open_above_prev_close) or (prev_is_bearish_for_calc and current_open_below_prev_close) + + if use_current_open_as_base: + # 以当前K线开盘价为基准计算(跳空时用当前开盘价参与计算) + calc_lower = current_kline['open'] + calc_upper = current_kline['open'] # 同一基准,上下四分之一对称 + long_trigger = calc_lower + prev_entity / 4 + short_trigger = calc_upper - prev_entity / 4 + long_breakout = calc_upper + prev_entity / 4 + short_breakout = calc_lower - prev_entity / 4 + else: + # 原有计算方式 + long_trigger = prev_entity_lower + prev_entity / 4 # 做多触发价 = 实体下边 + 实体/4(下四分之一处) + short_trigger = prev_entity_upper - prev_entity / 4 # 做空触发价 = 实体上边 - 实体/4(上四分之一处) + long_breakout = prev_entity_upper + prev_entity / 4 # 做多突破价 = 实体上边 + 实体/4 + short_breakout = prev_entity_lower - prev_entity / 4 # 做空突破价 = 实体下边 - 实体/4 + + # 上一根阴线 + 当前阳线:做多形态,不按上一根K线上三分之一做空 + prev_is_bearish = prev_kline['close'] < prev_kline['open'] + current_is_bullish = current_kline['close'] > current_kline['open'] + skip_short_by_upper_third = prev_is_bearish and current_is_bullish + # 上一根阳线 + 当前阴线:做空形态,不按上一根K线下三分之一做多 + prev_is_bullish = prev_kline['close'] > prev_kline['open'] + current_is_bearish = current_kline['close'] < current_kline['open'] + skip_long_by_lower_third = prev_is_bullish and current_is_bearish + + if use_current_open_as_base: + if prev_is_bullish_for_calc and current_open_above_prev_close: + logger.info(f"上一根阳线且当前开盘价({current_kline['open']:.2f})>上一根收盘价({prev_kline['close']:.2f}),以当前开盘价为基准计算") + else: + logger.info(f"上一根阴线且当前开盘价({current_kline['open']:.2f})<上一根收盘价({prev_kline['close']:.2f}),以当前开盘价为基准计算") + logger.info(f"当前价格: {current_price:.2f}, 上一根实体: {prev_entity:.4f}") + logger.info(f"上一根实体上边: {prev_entity_upper:.2f}, 下边: {prev_entity_lower:.2f}") + logger.info(f"做多触发价(下1/4): {long_trigger:.2f}, 做空触发价(上1/4): {short_trigger:.2f}") + logger.info(f"突破做多价(上1/4外): {long_breakout:.2f}, 突破做空价(下1/4外): {short_breakout:.2f}") + if skip_short_by_upper_third: + logger.info("上一根阴线+当前阳线(做多形态),不按上四分之一做空") + if skip_long_by_lower_third: + logger.info("上一根阳线+当前阴线(做空形态),不按下四分之一做多") + + # 无持仓时检查开仓信号 + if self.start == 0: + if current_price >= long_breakout and not skip_long_by_lower_third: + logger.info(f"触发做多信号!价格 {current_price:.2f} >= 突破价(上1/4外) {long_breakout:.2f}") + return ('long', long_breakout) + elif current_price <= short_breakout and not skip_short_by_upper_third: + logger.info(f"触发做空信号!价格 {current_price:.2f} <= 突破价(下1/4外) {short_breakout:.2f}") + return ('short', short_breakout) + + # 持仓时检查反手信号 + elif self.start == 1: # 持多仓 + # 反手条件1: 价格跌到上一根K线的上三分之一处(做空触发价);上一根阴线+当前阳线做多时跳过 + if current_price <= short_trigger and not skip_short_by_upper_third: + logger.info(f"持多反手做空!价格 {current_price:.2f} <= 触发价(上1/4) {short_trigger:.2f}") + return ('reverse_short', short_trigger) + + # 反手条件2: 上一根K线上阴线涨幅>0.01%,当前跌到上一根实体下边 + upper_shadow_pct = self.calculate_upper_shadow(prev_kline) + if upper_shadow_pct > 0.01 and current_price <= prev_entity_lower: + logger.info(f"持多反手做空!上阴线涨幅 {upper_shadow_pct:.4f}% > 0.01%," + f"价格 {current_price:.2f} <= 实体下边 {prev_entity_lower:.2f}") + return ('reverse_short', prev_entity_lower) + + elif self.start == -1: # 持空仓 + # 反手条件1: 价格涨到上一根K线的下三分之一处(做多触发价);上一根阳线+当前阴线做空时跳过 + if current_price >= long_trigger and not skip_long_by_lower_third: + logger.info(f"持空反手做多!价格 {current_price:.2f} >= 触发价(下1/4) {long_trigger:.2f}") + return ('reverse_long', long_trigger) + + # 反手条件2: 上一根K线下阴线跌幅>0.01%,当前涨到上一根实体上边 + lower_shadow_pct = self.calculate_lower_shadow(prev_kline) + if lower_shadow_pct > 0.01 and current_price >= prev_entity_upper: + logger.info(f"持空反手做多!下阴线跌幅 {lower_shadow_pct:.4f}% > 0.01%," + f"价格 {current_price:.2f} >= 实体上边 {prev_entity_upper:.2f}") + return ('reverse_long', prev_entity_upper) + + return None + + def can_open(self, current_kline_id): + """开仓前过滤:同一根 K 线只开一次 + 开仓冷却时间。仅用于 long/short 新开仓。""" + now = time.time() + if self.last_open_kline_id is not None and self.last_open_kline_id == current_kline_id: + logger.info(f"开仓频率控制:本 K 线({current_kline_id})已开过仓,跳过") + return False + if self.last_open_time is not None and now - self.last_open_time < self.open_cooldown_seconds: + remain = self.open_cooldown_seconds - (now - self.last_open_time) + logger.info(f"开仓冷却中,剩余 {remain:.0f} 秒") + return False + return True + + def can_reverse(self, current_price, trigger_price): + """反手前过滤:冷却时间 + 最小价差""" + now = time.time() + if self.last_reverse_time and now - self.last_reverse_time < self.reverse_cooldown_seconds: + remain = self.reverse_cooldown_seconds - (now - self.last_reverse_time) + logger.info(f"反手冷却中,剩余 {remain:.0f} 秒") + return False + + if trigger_price and trigger_price > 0: + move_pct = abs(current_price - trigger_price) / trigger_price * 100 + if move_pct < self.reverse_min_move_pct: + logger.info(f"反手价差不足: {move_pct:.4f}% < {self.reverse_min_move_pct}%") + return False + + return True + + def verify_no_position(self, max_retries=5, retry_interval=3): + """ + 验证当前无持仓 + 返回: True 表示无持仓可以开仓,False 表示有持仓不能开仓 + """ + for i in range(max_retries): + if self.get_position_status(): + if self.start == 0: + logger.info(f"确认无持仓,可以开仓") + return True + else: + logger.warning( + f"仍有持仓 (方向: {self.start}),等待 {retry_interval} 秒后重试 ({i + 1}/{max_retries})") + time.sleep(retry_interval) + else: + logger.warning(f"查询持仓状态失败,等待 {retry_interval} 秒后重试 ({i + 1}/{max_retries})") + time.sleep(retry_interval) + + logger.error(f"经过 {max_retries} 次重试仍有持仓或查询失败,放弃开仓") + return False + + def verify_position_direction(self, expected_direction): + """ + 验证当前持仓方向是否与预期一致 + expected_direction: 1 多仓, -1 空仓 + 返回: True 表示持仓方向正确,False 表示不正确 + """ + if self.get_position_status(): + if self.start == expected_direction: + logger.info(f"持仓方向验证成功: {self.start}") + return True + else: + logger.warning(f"持仓方向不符: 期望 {expected_direction}, 实际 {self.start}") + return False + else: + logger.error("查询持仓状态失败") + return False + + def execute_trade(self, signal, size=None): + """执行交易。size 不传或为 None 时使用 default_order_size。""" + signal_type, trigger_price = signal + size = self.default_order_size if size is None else size + + if signal_type == 'long': + # 开多前先确认无持仓 + logger.info(f"准备开多,触发价: {trigger_price:.2f}") + if not self.get_position_status(): + logger.error("开仓前查询持仓状态失败,放弃开仓") + return False + if self.start != 0: + logger.warning(f"开多前发现已有持仓 (方向: {self.start}),放弃开仓避免双向持仓") + return False + + logger.info(f"确认无持仓,执行开多") + self.开单(marketPriceLongOrder=1, size=size) + time.sleep(3) # 等待订单执行 + + # 验证开仓是否成功 + if self.verify_position_direction(1): + self.max_unrealized_pnl_seen = None # 新仓位重置移动止损记录 + self.last_open_time = time.time() + self.last_open_kline_id = getattr(self, "_current_kline_id_for_open", None) + logger.success("开多成功") + return True + else: + logger.error("开多后持仓验证失败") + return False + + elif signal_type == 'short': + # 开空前先确认无持仓 + logger.info(f"准备开空,触发价: {trigger_price:.2f}") + if not self.get_position_status(): + logger.error("开仓前查询持仓状态失败,放弃开仓") + return False + if self.start != 0: + logger.warning(f"开空前发现已有持仓 (方向: {self.start}),放弃开仓避免双向持仓") + return False + + logger.info(f"确认无持仓,执行开空") + self.开单(marketPriceLongOrder=-1, size=size) + time.sleep(3) # 等待订单执行 + + # 验证开仓是否成功 + if self.verify_position_direction(-1): + self.max_unrealized_pnl_seen = None # 新仓位重置移动止损记录 + self.last_open_time = time.time() + self.last_open_kline_id = getattr(self, "_current_kline_id_for_open", None) + logger.success("开空成功") + return True + else: + logger.error("开空后持仓验证失败") + return False + + elif signal_type == 'reverse_long': + # 平空 + 开多(反手做多):先平仓,确认无仓后再开多,避免双向持仓 + logger.info(f"执行反手做多,触发价: {trigger_price:.2f}") + self.平仓() + time.sleep(1) # 给交易所处理平仓的时间 + # 轮询确认已无持仓再开多(最多等约 10 秒) + for _ in range(10): + if self.get_position_status() and self.start == 0: + break + time.sleep(1) + if self.start != 0: + logger.warning("反手做多:平仓后仍有持仓,放弃本次开多") + return False + logger.info("已确认无持仓,执行开多") + self.开单(marketPriceLongOrder=1, size=size) + time.sleep(3) + + if self.verify_position_direction(1): + self.max_unrealized_pnl_seen = None + logger.success("反手做多成功") + self.last_reverse_time = time.time() + time.sleep(20) + return True + else: + logger.error("反手做多后持仓验证失败") + return False + + elif signal_type == 'reverse_short': + # 平多 + 开空(反手做空):先平仓,确认无仓后再开空 + logger.info(f"执行反手做空,触发价: {trigger_price:.2f}") + self.平仓() + time.sleep(1) + for _ in range(10): + if self.get_position_status() and self.start == 0: + break + time.sleep(1) + if self.start != 0: + logger.warning("反手做空:平仓后仍有持仓,放弃本次开空") + return False + logger.info("已确认无持仓,执行开空") + self.开单(marketPriceLongOrder=-1, size=size) + time.sleep(3) + + if self.verify_position_direction(-1): + self.max_unrealized_pnl_seen = None + logger.success("反手做空成功") + self.last_reverse_time = time.time() + time.sleep(20) + return True + else: + logger.error("反手做空后持仓验证失败") + return False + + return False + + def action(self): + """主循环""" + + logger.info("开始运行方案B(AI 策略)交易...") + + # 启动时设置全仓高杠杆 + if not self.set_leverage(): + logger.error("杠杆设置失败,程序继续运行但可能下单失败") + return + + page_start = True + + while True: + + if page_start: + # 打开浏览器 + for i in range(5): + if self.openBrowser(): + logger.info("浏览器打开成功") + break + else: + self.ding("打开浏览器失败!", error=True) + return + + # 进入交易页面 + self.page.get("https://derivatives.bitmart.com/zh-CN/futures/ETHUSDT") + self.click_safe('x://button[normalize-space(text()) ="市价"]') + + self.page.ele('x://*[@id="size_0"]').input(vals=25, clear=True) + + page_start = False + + try: + # 1. 获取当前价格 + current_price = self.get_current_price() + if not current_price: + logger.warning("获取价格失败,等待重试...") + time.sleep(2) + continue + + # 2. 每次循环都通过SDK获取真实持仓状态(避免状态不同步导致双向持仓) + if not self.get_position_status(): + logger.warning("获取持仓状态失败,等待重试...") + time.sleep(2) + continue + + logger.debug(f"当前持仓状态: {self.start} (0=无, 1=多, -1=空)") + + # 3. 止损/止盈/移动止损 + if self.start != 0: + pnl_usd = self.get_unrealized_pnl_usd() + if pnl_usd is not None: + # 固定止损:亏损达到 3 美元平仓 + if pnl_usd <= self.stop_loss_usd: + logger.info(f"仓位亏损 {pnl_usd:.2f} 美元 <= 止损 {self.stop_loss_usd} 美元,执行止损平仓") + self.平仓() + self.max_unrealized_pnl_seen = None + time.sleep(3) + continue + # 更新持仓期间最大盈利(用于移动止损) + if self.max_unrealized_pnl_seen is None: + self.max_unrealized_pnl_seen = pnl_usd + else: + self.max_unrealized_pnl_seen = max(self.max_unrealized_pnl_seen, pnl_usd) + # 移动止损:盈利曾达到 activation 后,从最高盈利回撤 trailing_distance 则平仓 + if self.max_unrealized_pnl_seen >= self.trailing_activation_usd: + if pnl_usd < self.max_unrealized_pnl_seen - self.trailing_distance_usd: + logger.info(f"移动止损:当前盈利 {pnl_usd:.2f} 从最高 {self.max_unrealized_pnl_seen:.2f} 回撤 >= {self.trailing_distance_usd} 美元,平仓") + self.平仓() + self.max_unrealized_pnl_seen = None + time.sleep(3) + continue + # 止盈:盈利达到 take_profit_usd 平仓 + if pnl_usd >= self.take_profit_usd: + logger.info(f"仓位盈利 {pnl_usd:.2f} 美元 >= {self.take_profit_usd} 美元,执行止盈平仓") + self.平仓() + self.max_unrealized_pnl_seen = None + time.sleep(3) + continue + + # 4. 方案B:仅在新的 15 分钟 K 线时取一次信号(0=观望, 1=做多, 2=做空) + current_15m_id = int(time.time() // 900) * 900 # 15 分钟 bar 起始时间戳 + signal = None + if current_15m_id != self.last_kline_time: + self.last_kline_time = current_15m_id + logger.info(f"进入新 15m K 线: {current_15m_id}") + raw = get_live_signal(period=15) + if raw == 1: + if self.start == 0: + signal = ('long', current_price) + elif self.start == -1: + signal = ('reverse_long', current_price) + elif raw == 2: + if self.start == 0: + signal = ('short', current_price) + elif self.start == 1: + signal = ('reverse_short', current_price) + + # 5. 反手过滤:冷却时间 + 最小价差 + if signal and signal[0].startswith('reverse_'): + if not self.can_reverse(current_price, signal[1]): + signal = None + + # 5.5 开仓频率过滤:同一根 15m K 线只开一次 + 开仓冷却 + if signal and signal[0] in ('long', 'short'): + if not self.can_open(current_15m_id): + signal = None + else: + self._current_kline_id_for_open = current_15m_id # 供 execute_trade 成功后记录 + + # 6. 有信号则执行交易 + if signal: + trade_success = self.execute_trade(signal) + if trade_success: + logger.success(f"交易执行完成: {signal[0]}, 当前持仓状态: {self.start}") + page_start = True + else: + logger.warning(f"交易执行失败或被阻止: {signal[0]}") + + # 短暂等待后继续循环(同一根K线遇到信号就操作) + time.sleep(0.1) + + if page_start: + self.page.close() + time.sleep(5) + + except KeyboardInterrupt: + logger.info("用户中断,程序退出") + break + except Exception as e: + logger.error(f"主循环异常: {e}") + time.sleep(5) + + +if __name__ == '__main__': + BitmartFuturesTransaction(bit_id="f2320f57e24c45529a009e1541e25961").action()