"""
Strategy comparison and evaluation.

Runs plan A (statistical screening, ``StatStrategy``) and plan B (AI models:
LightGBM and XGBoost via ``AIStrategy``) over the same period, logs a
side-by-side metrics table and per-month PnL, optionally saves an
equity/drawdown chart, and recommends the strategy with the highest total
return.
"""
import argparse
import math
import numbers
from pathlib import Path
from typing import Optional

import matplotlib

matplotlib.use('Agg')  # non-interactive backend: charts are only written to disk
import matplotlib.pyplot as plt  # noqa: E402  (must come after matplotlib.use)
import numpy as np
import pandas as pd
from loguru import logger
from matplotlib import font_manager

# CJK-capable fonts tried in order so Chinese titles/labels render as glyphs.
_ZH_FONT_CANDIDATES = (
    'PingFang SC',
    'Heiti SC',
    'STHeiti',
    'SimHei',
    'Microsoft YaHei',
    'WenQuanYi Micro Hei',
)


def _setup_zh_font() -> Optional[font_manager.FontProperties]:
    """
    Configure matplotlib to use the first installed font from ``_ZH_FONT_CANDIDATES``.

    :return: the ``FontProperties`` of the selected font, or ``None`` if no
        candidate is installed (a warning is logged in that case).
    """
    for fname in _ZH_FONT_CANDIDATES:
        try:
            props = font_manager.FontProperties(family=fname)
            # With fallback_to_default=False a missing font raises instead of
            # silently resolving to matplotlib's default font.
            font_manager.findfont(props, fallback_to_default=False)
        except Exception:
            continue
        plt.rcParams['font.sans-serif'] = [fname, 'DejaVu Sans']
        # Use an ASCII hyphen for negative tick labels; many CJK fonts lack U+2212.
        plt.rcParams['axes.unicode_minus'] = False
        return props
    logger.warning("未找到可用的中文字体,图表中的中文可能无法正常显示")
    return None


_zh_font = _setup_zh_font()

from .config import PRIMARY_PERIOD, PROJECT_ROOT  # noqa: E402
from .stat_strategy import StatStrategy  # noqa: E402
from .ai_strategy import AIStrategy  # noqa: E402
from .backtest import print_metrics  # noqa: E402,F401  (not used in this module)

REPORT_DIR = PROJECT_ROOT / 'reports'

# Result keys in run/display order, mapped to their full display labels.
STRATEGY_LABELS = {
    'stat': '方案A: 统计筛选',
    'lgb': '方案B-1: LightGBM',
    'xgb': '方案B-2: XGBoost',
}
# Shorter labels used as column headers in the monthly PnL table.
_MONTHLY_LABELS = {
    'stat': '方案A',
    'lgb': 'LightGBM',
    'xgb': 'XGBoost',
}
# Line/fill colours per strategy in the comparison chart.
_STRATEGY_COLORS = {'stat': '#2196F3', 'lgb': '#4CAF50', 'xgb': '#FF9800'}


def compare_strategies(period: Optional[int] = None, start_date: Optional[str] = None,
                       end_date: Optional[str] = None, save_plot: bool = True) -> dict:
    """
    Run all strategies over the same window and compare them.

    :param period: bar period in minutes; defaults to ``PRIMARY_PERIOD``.
    :param start_date: start date string (e.g. '2024-01-01'), forwarded to each
        strategy's ``run``.
    :param end_date: end date string (e.g. '2024-12-31'), forwarded likewise.
    :param save_plot: when true, write the equity/drawdown chart to ``REPORT_DIR``.
    :return: {'stat': result_a, 'lgb': result_b, 'xgb': result_c, 'comparison': DataFrame}
    """
    if period is None:
        period = PRIMARY_PERIOD

    logger.info("=" * 70)
    logger.info(" 开始策略对比评估")
    logger.info("=" * 70)

    results = {}

    # Plan A: statistical screening.
    logger.info("\n>>> 运行方案A: 统计筛选策略")
    results['stat'] = StatStrategy().run(period, start_date, end_date)

    # Plan B-1: LightGBM.
    logger.info("\n>>> 运行方案B-1: LightGBM AI策略")
    results['lgb'] = AIStrategy(model_type='lightgbm').run(period, start_date, end_date)

    # Plan B-2: XGBoost.
    logger.info("\n>>> 运行方案B-2: XGBoost AI策略")
    results['xgb'] = AIStrategy(model_type='xgboost').run(period, start_date, end_date)

    comparison = _build_comparison_table(results)
    results['comparison'] = comparison

    logger.info("\n" + "=" * 70)
    logger.info(" 策略对比总结")
    logger.info("=" * 70)
    logger.info(f"\n{comparison.to_string()}")

    # Monthly profit in USDT.
    _log_monthly_pnl(results)

    if save_plot:
        _save_equity_plot(results)

    return results


def _build_comparison_table(results: dict) -> pd.DataFrame:
    """
    Build a metrics table with one row per strategy (indexed by display label).

    Strategies missing from ``results`` or lacking a ``'metrics'`` entry are
    skipped; an empty DataFrame is returned if none qualify.
    """
    rows = {}
    for key, label in STRATEGY_LABELS.items():
        result = results.get(key)
        if result is not None and 'metrics' in result:
            rows[label] = result['metrics']
    if not rows:
        return pd.DataFrame()
    # Each metrics mapping becomes a column; transpose so strategies are rows.
    return pd.DataFrame(rows).T


def _log_monthly_pnl(results: dict) -> None:
    """
    Log each strategy's monthly PnL (USDT) side by side.

    Values are rounded to 2 decimals; months a strategy has no entry for are
    shown as 0. Nothing is logged if no strategy provides a non-empty
    ``'monthly_pnl'`` series.
    """
    cols = []
    for key, label in _MONTHLY_LABELS.items():
        pnl = (results.get(key) or {}).get('monthly_pnl')
        if pnl is None or pnl.empty:
            continue
        cols.append(pnl.astype(float).round(2).rename(label))
    if not cols:
        return

    # concat aligns on the month index; gaps become NaN and are shown as 0.
    monthly_df = pd.concat(cols, axis=1).fillna(0)
    logger.info("\n" + "-" * 70)
    logger.info(" 每月盈利 (USDT)")
    logger.info("-" * 70)
    logger.info(f"\n{monthly_df.to_string()}")
    logger.info("-" * 70)


def _save_equity_plot(results: dict) -> Path:
    """
    Save a two-panel chart (equity curves on top, drawdown % below).

    :return: path of the written PNG (``REPORT_DIR / 'strategy_comparison.png'``).
    """
    REPORT_DIR.mkdir(parents=True, exist_ok=True)

    curves = {}
    for key in STRATEGY_LABELS:
        result = results.get(key)
        if result is None or 'equity_curve' not in result:
            continue
        eq = result['equity_curve']
        if not eq.empty:
            curves[key] = eq

    fig, (ax_eq, ax_dd) = plt.subplots(2, 1, figsize=(14, 10))
    try:
        for key, eq in curves.items():
            label = STRATEGY_LABELS[key]
            color = _STRATEGY_COLORS.get(key, 'gray')
            ax_eq.plot(eq.index, eq.values, label=label, color=color, linewidth=1)

            running_max = eq.cummax()
            # A running peak of 0 would divide by zero (inf/NaN drawdown);
            # leave those points undefined instead.
            drawdown = (eq - running_max) / running_max.replace(0, np.nan) * 100
            ax_dd.fill_between(drawdown.index, drawdown.values, 0,
                               alpha=0.3, label=label, color=color)

        ax_eq.set_title('权益曲线对比', fontsize=14)
        ax_eq.set_ylabel('资金 (USDT)')
        ax_dd.set_title('回撤对比', fontsize=14)
        ax_dd.set_ylabel('回撤 (%)')
        for ax in (ax_eq, ax_dd):
            if curves:  # legend() with no labelled artists only emits a warning
                ax.legend()
            ax.grid(True, alpha=0.3)

        fig.tight_layout()
        plot_path = REPORT_DIR / 'strategy_comparison.png'
        fig.savefig(str(plot_path), dpi=150)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

    logger.info(f"对比图表已保存: {plot_path}")
    return plot_path


def _parse_total_return(value) -> Optional[float]:
    """
    Convert a '总收益率' (total return) metric into a fraction.

    Strings such as ``'12.5%'`` become ``0.125``: ``%`` is stripped and the
    number divided by 100, as the original parsing did. Real numbers are
    taken as already being fractions. NOTE(review): confirm this against the
    metrics producer. Returns ``None`` for anything unparseable or non-finite,
    so one malformed metric cannot abort the recommendation step.
    """
    if isinstance(value, str):
        try:
            parsed = float(value.strip().strip('%')) / 100
        except ValueError:
            return None
    elif isinstance(value, numbers.Real) and not isinstance(value, bool):
        parsed = float(value)
    else:
        return None
    return parsed if math.isfinite(parsed) else None


def run_full_comparison(period: Optional[int] = None, start_date: Optional[str] = None,
                        end_date: Optional[str] = None, save_plot: bool = True) -> dict:
    """
    Full comparison entry point (safe to call directly).

    Runs ``compare_strategies`` and logs the strategy with the highest total
    return; on ties the earliest strategy in ``STRATEGY_LABELS`` wins.

    :return: the results dict produced by ``compare_strategies``.
    """
    results = compare_strategies(period, start_date, end_date, save_plot=save_plot)

    returns = {}
    for key in STRATEGY_LABELS:
        result = results.get(key)
        if result is None or 'metrics' not in result:
            continue
        raw = result['metrics'].get('总收益率', '0%')
        value = _parse_total_return(raw)
        if value is None:
            logger.warning(f"无法解析总收益率 ({key}): {raw!r}")
            continue
        returns[key] = value

    if returns:
        # max() returns the first maximal key, matching the strict '>' scan.
        best_key = max(returns, key=returns.get)
        best_return = returns[best_key]
        logger.info(f"\n{'='*50}")
        logger.info(f" 推荐方案: {STRATEGY_LABELS.get(best_key, best_key)}")
        logger.info(f" 总收益率: {best_return:.2%}")
        logger.info(f"{'='*50}")

    return results


if __name__ == '__main__':
    p = argparse.ArgumentParser(description='运行策略对比:方案A(统计) vs 方案B(LightGBM/XGBoost)')
    p.add_argument('--period', type=int, default=None, help='K线周期分钟')
    p.add_argument('--start', type=str, default=None, help='开始日期,如 2024-01-01')
    p.add_argument('--end', type=str, default=None, help='结束日期,如 2024-12-31')
    p.add_argument('--no-plot', action='store_true', help='不保存权益/回撤图')
    args = p.parse_args()
    run_full_comparison(
        period=args.period,
        start_date=args.start,
        end_date=args.end,
        save_plot=not args.no_plot,
    )