223 lines
7.0 KiB
Python
223 lines
7.0 KiB
Python
"""
方案对比评估 — 方案A(统计筛选) vs 方案B(AI模型) 并排对比 + 报告输出
"""
|
||
import pandas as pd
|
||
import numpy as np
|
||
import matplotlib
|
||
matplotlib.use('Agg') # 非交互式后端
|
||
import matplotlib.pyplot as plt
|
||
from matplotlib import font_manager
|
||
from pathlib import Path
|
||
from loguru import logger
|
||
|
||
# Configure a CJK-capable font so the Chinese titles/labels in the charts render
# correctly. Candidate families are tried in order; the first one matplotlib can
# actually resolve becomes the primary sans-serif font.
# _zh_font ends up as the resolved FontProperties, or None if no candidate was found
# (it is not referenced elsewhere in this module).
_zh_font = None
for fname in ['PingFang SC', 'Heiti SC', 'STHeiti', 'SimHei', 'Microsoft YaHei', 'WenQuanYi Micro Hei']:
    try:
        candidate = font_manager.FontProperties(family=fname)
        # With fallback_to_default=False, findfont raises instead of silently
        # substituting the default font, which is how we detect a missing family.
        font_manager.findfont(candidate, fallback_to_default=False)
    except Exception:
        # Best-effort probing at import time: any failure just means "try the next one".
        continue
    _zh_font = candidate
    plt.rcParams['font.sans-serif'] = [fname, 'DejaVu Sans']
    # Use an ASCII hyphen for minus signs; CJK fonts often lack the U+2212 glyph.
    plt.rcParams['axes.unicode_minus'] = False
    break
else:
    # Previously this fell through silently; surface it so missing glyphs are explainable.
    logger.warning("未找到可用的中文字体, 图表中的中文可能无法正常显示")
|
||
|
||
from .config import PRIMARY_PERIOD, PROJECT_ROOT
|
||
from .stat_strategy import StatStrategy
|
||
from .ai_strategy import AIStrategy
|
||
from .backtest import print_metrics
|
||
|
||
|
||
# Output directory for generated reports (the strategy comparison chart is saved here).
# It is created on demand by _save_equity_plot via mkdir(parents=True, exist_ok=True).
REPORT_DIR = PROJECT_ROOT / 'reports'
|
||
|
||
|
||
def compare_strategies(period: int = None, start_date: str = None, end_date: str = None,
                       save_plot: bool = True) -> dict:
    """
    Run strategy A (statistical screening) and both AI variants of strategy B, then compare them.

    Each strategy is constructed and run in the fixed order stat -> lgb -> xgb with the
    same ``(period, start_date, end_date)`` arguments. The comparison table and the
    monthly PnL table are logged afterwards, and the equity/drawdown chart is optionally saved.

    :param period: bar period passed to every strategy's ``run``; ``PRIMARY_PERIOD`` when None.
    :param start_date: start of the evaluation window, forwarded unchanged to ``run``.
    :param end_date: end of the evaluation window, forwarded unchanged to ``run``.
    :param save_plot: when True, write the comparison chart via ``_save_equity_plot``.
    :return: {'stat': result_a, 'lgb': result_b, 'xgb': result_c, 'comparison': DataFrame}
    """
    period = PRIMARY_PERIOD if period is None else period

    banner = "=" * 70
    logger.info(banner)
    logger.info(" 开始策略对比评估")
    logger.info(banner)

    # (result key, progress message, strategy factory), executed in this order.
    plan = (
        ('stat', "\n>>> 运行方案A: 统计筛选策略", StatStrategy),
        ('lgb', "\n>>> 运行方案B-1: LightGBM AI策略", lambda: AIStrategy(model_type='lightgbm')),
        ('xgb', "\n>>> 运行方案B-2: XGBoost AI策略", lambda: AIStrategy(model_type='xgboost')),
    )

    results = {}
    for key, headline, make_strategy in plan:
        logger.info(headline)
        strategy = make_strategy()
        results[key] = strategy.run(period, start_date, end_date)

    comparison = _build_comparison_table(results)
    results['comparison'] = comparison

    logger.info("\n" + banner)
    logger.info(" 策略对比总结")
    logger.info(banner)
    logger.info(f"\n{comparison.to_string()}")

    # Per-month profit in USDT, one column per strategy.
    _log_monthly_pnl(results)

    if save_plot:
        _save_equity_plot(results)

    return results
|
||
|
||
|
||
def _build_comparison_table(results: dict) -> pd.DataFrame:
|
||
"""构建对比表格"""
|
||
rows = {}
|
||
name_map = {
|
||
'stat': '方案A: 统计筛选',
|
||
'lgb': '方案B-1: LightGBM',
|
||
'xgb': '方案B-2: XGBoost',
|
||
}
|
||
|
||
for key, name in name_map.items():
|
||
if key in results and 'metrics' in results[key]:
|
||
rows[name] = results[key]['metrics']
|
||
|
||
if not rows:
|
||
return pd.DataFrame()
|
||
|
||
df = pd.DataFrame(rows).T
|
||
return df
|
||
|
||
|
||
def _log_monthly_pnl(results: dict):
|
||
"""打印各策略每月盈利(USDT)"""
|
||
name_map = {
|
||
'stat': '方案A',
|
||
'lgb': 'LightGBM',
|
||
'xgb': 'XGBoost',
|
||
}
|
||
cols = []
|
||
for key, name in name_map.items():
|
||
if key not in results or results[key].get('monthly_pnl') is None:
|
||
continue
|
||
s = results[key]['monthly_pnl']
|
||
if s is None or s.empty:
|
||
continue
|
||
s = s.astype(float).round(2)
|
||
s.name = name
|
||
cols.append(s)
|
||
if not cols:
|
||
return
|
||
monthly_df = pd.concat(cols, axis=1)
|
||
monthly_df = monthly_df.fillna(0)
|
||
logger.info("\n" + "-" * 70)
|
||
logger.info(" 每月盈利 (USDT)")
|
||
logger.info("-" * 70)
|
||
logger.info(f"\n{monthly_df.to_string()}")
|
||
logger.info("-" * 70)
|
||
|
||
|
||
def _save_equity_plot(results: dict):
    """
    Save a two-panel chart comparing strategy equity curves (top) and drawdowns (bottom).

    For each of 'stat', 'lgb' and 'xgb' that has a non-empty ``'equity_curve'``:
    - the top panel plots the raw curve;
    - the bottom panel shades its drawdown in percent, computed as
      ``(eq - running_max) / running_max * 100``. This is non-positive while equity stays positive.

    The PNG is written to ``REPORT_DIR / 'strategy_comparison.png'`` at 150 dpi, and the
    directory is created if needed. The figure is always closed, even if plotting or saving
    raises, so repeated runs do not accumulate open figures.

    :param results: mapping as produced by ``compare_strategies``; only the three
        strategy entries are read.
    """
    REPORT_DIR.mkdir(parents=True, exist_ok=True)

    name_map = {
        'stat': '方案A: 统计筛选',
        'lgb': '方案B-1: LightGBM',
        'xgb': '方案B-2: XGBoost',
    }
    colors = {'stat': '#2196F3', 'lgb': '#4CAF50', 'xgb': '#FF9800'}

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10))
    try:
        for key, name in name_map.items():
            if key not in results or 'equity_curve' not in results[key]:
                continue
            eq = results[key]['equity_curve']
            if eq.empty:
                continue
            color = colors.get(key, 'gray')

            # Top panel: equity curve.
            ax1.plot(eq.index, eq.values, label=name, color=color, linewidth=1)

            # Bottom panel: drawdown from the running peak, in percent.
            cummax = eq.cummax()
            drawdown = (eq - cummax) / cummax * 100
            ax2.fill_between(drawdown.index, drawdown.values, 0,
                             alpha=0.3, label=name, color=color)

        ax1.set_title('权益曲线对比', fontsize=14)
        ax1.set_ylabel('资金 (USDT)')
        ax1.grid(True, alpha=0.3)

        ax2.set_title('回撤对比', fontsize=14)
        ax2.set_ylabel('回撤 (%)')
        ax2.grid(True, alpha=0.3)

        # Both panels get the same strategies, so one check covers both. Skipping the
        # legend when nothing was drawn avoids matplotlib's "No artists with labels" warning.
        if ax1.get_legend_handles_labels()[0]:
            ax1.legend()
            ax2.legend()

        fig.tight_layout()
        plot_path = REPORT_DIR / 'strategy_comparison.png'
        fig.savefig(str(plot_path), dpi=150)
    finally:
        plt.close(fig)

    logger.info(f"对比图表已保存: {plot_path}")
|
||
|
||
|
||
def _parse_pct(value) -> float:
    """
    Convert a percentage string such as ``'12.34%'`` into a fraction (``0.1234``).

    Surrounding whitespace and '%' signs are ignored.

    :raises AttributeError: if ``value`` is not a string.
    :raises ValueError: if the remaining text is not a valid float.
    """
    return float(value.strip().strip('%')) / 100


def run_full_comparison(period: int = None, start_date: str = None, end_date: str = None, save_plot: bool = True):
    """
    Full comparison entry point: run every strategy, then log the recommended one.

    The recommendation is the strategy with the highest '总收益率' (total return) metric.
    Ties keep the first in the order stat -> lgb -> xgb, and a missing metric counts as '0%'.
    A strategy whose metric cannot be parsed is skipped with a warning instead of
    raising, so an odd metric value no longer aborts the run after all backtests finished.

    :param period: forwarded to ``compare_strategies``.
    :param start_date: forwarded to ``compare_strategies``.
    :param end_date: forwarded to ``compare_strategies``.
    :param save_plot: forwarded to ``compare_strategies``.
    :return: the result dict produced by ``compare_strategies``.
    """
    results = compare_strategies(period, start_date, end_date, save_plot=save_plot)

    name_map = {'stat': '方案A: 统计筛选', 'lgb': '方案B-1: LightGBM', 'xgb': '方案B-2: XGBoost'}

    # Recommend the strategy with the best total return.
    best_key = None
    best_return = -float('inf')

    for key in ['stat', 'lgb', 'xgb']:
        if key not in results or 'metrics' not in results[key]:
            continue
        ret_str = results[key]['metrics'].get('总收益率', '0%')
        try:
            ret_val = _parse_pct(ret_str)
        except (AttributeError, TypeError, ValueError):
            logger.warning(f"无法解析 {name_map.get(key, key)} 的总收益率: {ret_str!r}, 跳过")
            continue
        if ret_val > best_return:
            best_return = ret_val
            best_key = key

    if best_key:
        logger.info(f"\n{'='*50}")
        logger.info(f" 推荐方案: {name_map.get(best_key, best_key)}")
        logger.info(f" 总收益率: {best_return:.2%}")
        logger.info(f"{'='*50}")

    return results
|
||
|
||
|
||
if __name__ == '__main__':
    # Command-line entry point: parse the period / date-window / plot options and
    # run the full comparison with them.
    import argparse

    cli = argparse.ArgumentParser(description='运行策略对比:方案A(统计) vs 方案B(LightGBM/XGBoost)')
    cli.add_argument('--period', type=int, default=None, help='K线周期分钟')
    cli.add_argument('--start', type=str, default=None, help='开始日期,如 2024-01-01')
    cli.add_argument('--end', type=str, default=None, help='结束日期,如 2024-12-31')
    cli.add_argument('--no-plot', action='store_true', help='不保存权益/回撤图')
    opts = cli.parse_args()

    plot_wanted = not opts.no_plot
    run_full_comparison(
        period=opts.period,
        start_date=opts.start,
        end_date=opts.end,
        save_plot=plot_wanted,
    )
|