223 lines
7.0 KiB
Python
223 lines
7.0 KiB
Python
|
|
"""
|
|||
|
|
方案对比评估 — 方案A(统计筛选) vs 方案B(AI模型) 并排对比 + 报告输出
|
|||
|
|
"""
|
|||
|
|
import pandas as pd
|
|||
|
|
import numpy as np
|
|||
|
|
import matplotlib
|
|||
|
|
matplotlib.use('Agg') # 非交互式后端
|
|||
|
|
import matplotlib.pyplot as plt
|
|||
|
|
from matplotlib import font_manager
|
|||
|
|
from pathlib import Path
|
|||
|
|
from loguru import logger
|
|||
|
|
|
|||
|
|
# Configure a CJK-capable font so the Chinese titles/labels in the charts render.
# Candidates are probed in order; the first one matplotlib can actually locate wins.
_zh_font = None
for fname in ['PingFang SC', 'Heiti SC', 'STHeiti', 'SimHei', 'Microsoft YaHei', 'WenQuanYi Micro Hei']:
    try:
        _zh_font = font_manager.FontProperties(family=fname)
        # With fallback_to_default=False, findfont raises instead of silently
        # returning the default font, which is what makes this a real existence check.
        font_manager.findfont(_zh_font, fallback_to_default=False)
    except Exception:
        # Best-effort probe: any failure just means "try the next candidate".
        _zh_font = None
        continue
    plt.rcParams['font.sans-serif'] = [fname, 'DejaVu Sans']
    # Draw the minus sign as ASCII '-' (CJK fonts may lack the Unicode minus glyph).
    plt.rcParams['axes.unicode_minus'] = False
    break
else:
    # No candidate found: charts still render, but Chinese text will likely show as boxes.
    logger.warning("未找到可用的中文字体, 图表中的中文可能无法正常显示")
|
|||
|
|
|
|||
|
|
from .config import PRIMARY_PERIOD, PROJECT_ROOT
|
|||
|
|
from .stat_strategy import StatStrategy
|
|||
|
|
from .ai_strategy import AIStrategy
|
|||
|
|
from .backtest import print_metrics
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Directory where comparison report artifacts (e.g. strategy_comparison.png) are written.
REPORT_DIR = PROJECT_ROOT / 'reports'
|
|||
|
|
|
|||
|
|
|
|||
|
|
def compare_strategies(period: int = None, start_date: str = None, end_date: str = None,
                       save_plot: bool = True) -> dict:
    """Run the statistical strategy and both AI strategies, then compare them.

    :param period: K-line period in minutes; ``None`` falls back to ``PRIMARY_PERIOD``.
    :param start_date: start date passed through to each strategy's ``run``.
    :param end_date: end date passed through to each strategy's ``run``.
    :param save_plot: when True, also write the equity/drawdown comparison chart.
    :return: {'stat': result_a, 'lgb': result_b, 'xgb': result_c, 'comparison': DataFrame}
    """
    period = PRIMARY_PERIOD if period is None else period

    banner = "=" * 70
    logger.info(banner)
    logger.info(" 开始策略对比评估")
    logger.info(banner)

    # (result key, progress message, strategy factory) -- executed in this order.
    runs = (
        ('stat', "\n>>> 运行方案A: 统计筛选策略", StatStrategy),
        ('lgb', "\n>>> 运行方案B-1: LightGBM AI策略", lambda: AIStrategy(model_type='lightgbm')),
        ('xgb', "\n>>> 运行方案B-2: XGBoost AI策略", lambda: AIStrategy(model_type='xgboost')),
    )
    results = {}
    for key, message, make_strategy in runs:
        logger.info(message)
        results[key] = make_strategy().run(period, start_date, end_date)

    comparison = _build_comparison_table(results)
    results['comparison'] = comparison

    logger.info("\n" + banner)
    logger.info(" 策略对比总结")
    logger.info(banner)
    logger.info(f"\n{comparison.to_string()}")

    _log_monthly_pnl(results)

    if save_plot:
        _save_equity_plot(results)

    return results
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _build_comparison_table(results: dict) -> pd.DataFrame:
|
|||
|
|
"""构建对比表格"""
|
|||
|
|
rows = {}
|
|||
|
|
name_map = {
|
|||
|
|
'stat': '方案A: 统计筛选',
|
|||
|
|
'lgb': '方案B-1: LightGBM',
|
|||
|
|
'xgb': '方案B-2: XGBoost',
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for key, name in name_map.items():
|
|||
|
|
if key in results and 'metrics' in results[key]:
|
|||
|
|
rows[name] = results[key]['metrics']
|
|||
|
|
|
|||
|
|
if not rows:
|
|||
|
|
return pd.DataFrame()
|
|||
|
|
|
|||
|
|
df = pd.DataFrame(rows).T
|
|||
|
|
return df
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _log_monthly_pnl(results: dict):
|
|||
|
|
"""打印各策略每月盈利(USDT)"""
|
|||
|
|
name_map = {
|
|||
|
|
'stat': '方案A',
|
|||
|
|
'lgb': 'LightGBM',
|
|||
|
|
'xgb': 'XGBoost',
|
|||
|
|
}
|
|||
|
|
cols = []
|
|||
|
|
for key, name in name_map.items():
|
|||
|
|
if key not in results or results[key].get('monthly_pnl') is None:
|
|||
|
|
continue
|
|||
|
|
s = results[key]['monthly_pnl']
|
|||
|
|
if s is None or s.empty:
|
|||
|
|
continue
|
|||
|
|
s = s.astype(float).round(2)
|
|||
|
|
s.name = name
|
|||
|
|
cols.append(s)
|
|||
|
|
if not cols:
|
|||
|
|
return
|
|||
|
|
monthly_df = pd.concat(cols, axis=1)
|
|||
|
|
monthly_df = monthly_df.fillna(0)
|
|||
|
|
logger.info("\n" + "-" * 70)
|
|||
|
|
logger.info(" 每月盈利 (USDT)")
|
|||
|
|
logger.info("-" * 70)
|
|||
|
|
logger.info(f"\n{monthly_df.to_string()}")
|
|||
|
|
logger.info("-" * 70)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _save_equity_plot(results: dict):
    """Save an equity-curve and drawdown comparison chart to ``REPORT_DIR``.

    The top panel plots each strategy's ``'equity_curve'``. The bottom panel shades
    its percentage drawdown from the running peak. Strategies with a missing,
    ``None`` or empty equity curve are skipped. The figure is always closed, even
    when layout or saving raises, so repeated calls do not leak figures.
    """
    REPORT_DIR.mkdir(parents=True, exist_ok=True)

    name_map = {
        'stat': '方案A: 统计筛选',
        'lgb': '方案B-1: LightGBM',
        'xgb': '方案B-2: XGBoost',
    }
    colors = {'stat': '#2196F3', 'lgb': '#4CAF50', 'xgb': '#FF9800'}

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10))
    try:
        plotted = False
        for key, name in name_map.items():
            if key not in results or results[key].get('equity_curve') is None:
                continue
            eq = results[key]['equity_curve']
            if eq.empty:
                continue
            color = colors.get(key, 'gray')

            # Top panel: equity curve.
            ax1.plot(eq.index, eq.values, label=name, color=color, linewidth=1)

            # Bottom panel: drawdown in percent relative to the running peak.
            cummax = eq.cummax()
            drawdown = (eq - cummax) / cummax * 100
            ax2.fill_between(drawdown.index, drawdown.values, 0,
                             alpha=0.3, label=name, color=color)
            plotted = True

        ax1.set_title('权益曲线对比', fontsize=14)
        ax1.set_ylabel('资金 (USDT)')
        ax2.set_title('回撤对比', fontsize=14)
        ax2.set_ylabel('回撤 (%)')
        for ax in (ax1, ax2):
            # legend() with no labelled artists only emits a warning; skip it then.
            if plotted:
                ax.legend()
            ax.grid(True, alpha=0.3)

        fig.tight_layout()
        plot_path = REPORT_DIR / 'strategy_comparison.png'
        fig.savefig(str(plot_path), dpi=150)
    finally:
        # Close this specific figure (not merely the "current" one), even on error.
        plt.close(fig)
    logger.info(f"对比图表已保存: {plot_path}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _parse_total_return(value):
    """Parse a percentage metric string such as ``'12.34%'`` into a fraction (``0.1234``).

    Same rule as before (strip ``%``, divide by 100), but also tolerates surrounding
    whitespace and thousands separators. Returns ``None`` when the value is not a
    string or cannot be parsed as a number.
    """
    if not isinstance(value, str):
        return None
    try:
        return float(value.strip().strip('%').replace(',', '')) / 100
    except ValueError:
        return None


def run_full_comparison(period: int = None, start_date: str = None, end_date: str = None, save_plot: bool = True):
    """Full comparison entry point (can be called directly).

    Runs ``compare_strategies`` and then logs the strategy with the highest
    ``'总收益率'`` (total return) metric as the recommendation.

    :param period: K-line period in minutes; ``None`` uses the configured default.
    :param start_date: start date forwarded to ``compare_strategies``.
    :param end_date: end date forwarded to ``compare_strategies``.
    :param save_plot: whether to save the equity/drawdown chart.
    :return: the result dict produced by ``compare_strategies``.
    """
    results = compare_strategies(period, start_date, end_date, save_plot=save_plot)

    name_map = {'stat': '方案A: 统计筛选', 'lgb': '方案B-1: LightGBM', 'xgb': '方案B-2: XGBoost'}

    # Recommend the strategy with the highest total return.
    best_key = None
    best_return = -float('inf')
    for key in ['stat', 'lgb', 'xgb']:
        if key not in results or 'metrics' not in results[key]:
            continue
        raw = results[key]['metrics'].get('总收益率', '0%')
        ret_val = _parse_total_return(raw)
        if ret_val is None:
            # One malformed metric must not abort the report after every backtest has run.
            logger.warning(f"无法解析 {name_map.get(key, key)} 的总收益率: {raw!r}, 已跳过")
            continue
        if ret_val > best_return:
            best_return = ret_val
            best_key = key

    if best_key:
        logger.info(f"\n{'='*50}")
        logger.info(f" 推荐方案: {name_map.get(best_key, best_key)}")
        logger.info(f" 总收益率: {best_return:.2%}")
        logger.info(f"{'='*50}")

    return results
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == '__main__':
    # Command-line entry point: parse options and run the full comparison.
    import argparse

    parser = argparse.ArgumentParser(description='运行策略对比:方案A(统计) vs 方案B(LightGBM/XGBoost)')
    parser.add_argument('--period', type=int, default=None, help='K线周期分钟')
    parser.add_argument('--start', type=str, default=None, help='开始日期,如 2024-01-01')
    parser.add_argument('--end', type=str, default=None, help='结束日期,如 2024-12-31')
    parser.add_argument('--no-plot', action='store_true', help='不保存权益/回撤图')
    cli = parser.parse_args()

    # --no-plot disables chart output; otherwise the chart is saved.
    run_full_comparison(period=cli.period, start_date=cli.start, end_date=cli.end,
                        save_plot=not cli.no_plot)
|