Files
jyx_code4/strategy/compare.py

223 lines
7.0 KiB
Python
Raw Normal View History

2026-02-20 20:57:25 +08:00
"""
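方案对比评估：方案A（统计筛选）vs 方案B（AI 模型）并排对比 + 报告输出

Strategy comparison: runs plan A (statistical screening) and plan B (AI models,
LightGBM and XGBoost) side by side, then logs a metrics table and monthly PnL and
optionally saves an equity/drawdown chart.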
"""
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg') # 非交互式后端
import matplotlib.pyplot as plt
from matplotlib import font_manager
from pathlib import Path
from loguru import logger
# Configure a CJK-capable font so the Chinese titles/labels in the plots render
# instead of showing placeholder boxes. Candidates are tried in order and the
# first one matplotlib can actually resolve is used.
_zh_font = None
for fname in ['PingFang SC', 'Heiti SC', 'STHeiti', 'SimHei', 'Microsoft YaHei', 'WenQuanYi Micro Hei']:
    try:
        _zh_font = font_manager.FontProperties(family=fname)
        # With fallback disabled, findfont raises instead of silently returning the
        # default (non-CJK) font, so this verifies the family is really installed.
        font_manager.findfont(_zh_font, fallback_to_default=False)
    except Exception:
        # Best-effort: any lookup failure just means "try the next candidate".
        _zh_font = None
        continue
    plt.rcParams['font.sans-serif'] = [fname, 'DejaVu Sans']
    # Render negative tick labels with ASCII '-'; many CJK fonts lack U+2212.
    plt.rcParams['axes.unicode_minus'] = False
    break
else:
    # No candidate resolved: keep matplotlib defaults, but say so instead of
    # silently producing charts whose Chinese text cannot be displayed.
    logger.warning("未找到可用的中文字体，图表中的中文可能无法正常显示")
from .config import PRIMARY_PERIOD, PROJECT_ROOT
from .stat_strategy import StatStrategy
from .ai_strategy import AIStrategy
from .backtest import print_metrics
# Output directory for comparison artifacts such as the equity/drawdown chart.
REPORT_DIR = PROJECT_ROOT / 'reports'
def compare_strategies(period: int = None, start_date: str = None, end_date: str = None,
                       save_plot: bool = True) -> dict:
    """
    Run plan A (statistical screening) and both plan B variants (LightGBM and
    XGBoost) over the same period, then summarise them side by side.

    :param period: K-line period in minutes; ``None`` falls back to PRIMARY_PERIOD.
    :param start_date: optional backtest start date, passed to every strategy.
    :param end_date: optional backtest end date, passed to every strategy.
    :param save_plot: when True, also write the equity/drawdown comparison chart.
    :return: {'stat': result_a, 'lgb': result_b, 'xgb': result_c, 'comparison': DataFrame}
    """
    period = PRIMARY_PERIOD if period is None else period

    banner = "=" * 70
    logger.info(banner)
    logger.info(" 开始策略对比评估")
    logger.info(banner)

    # (result key, progress message, strategy factory) -- executed strictly in this order.
    runs = (
        ('stat', "\n>>> 运行方案A: 统计筛选策略", StatStrategy),
        ('lgb', "\n>>> 运行方案B-1: LightGBM AI策略", lambda: AIStrategy(model_type='lightgbm')),
        ('xgb', "\n>>> 运行方案B-2: XGBoost AI策略", lambda: AIStrategy(model_type='xgboost')),
    )
    results = {}
    for key, message, make_strategy in runs:
        logger.info(message)
        results[key] = make_strategy().run(period, start_date, end_date)

    table = _build_comparison_table(results)
    results['comparison'] = table

    logger.info("\n" + banner)
    logger.info(" 策略对比总结")
    logger.info(banner)
    logger.info(f"\n{table.to_string()}")

    # Per-month profit in USDT for each strategy.
    _log_monthly_pnl(results)

    if save_plot:
        _save_equity_plot(results)
    return results
def _build_comparison_table(results: dict) -> pd.DataFrame:
    """
    Build the side-by-side metrics table.

    Every known strategy present in ``results`` with a ``'metrics'`` entry becomes
    one row, labelled with its display name; metric names become the columns.
    Returns an empty DataFrame when no strategy has metrics.
    """
    labels = (
        ('stat', '方案A: 统计筛选'),
        ('lgb', '方案B-1: LightGBM'),
        ('xgb', '方案B-2: XGBoost'),
    )
    metrics_by_label = {
        label: results[key]['metrics']
        for key, label in labels
        if key in results and 'metrics' in results[key]
    }
    if not metrics_by_label:
        return pd.DataFrame()
    # Built column-per-strategy; transpose so each strategy is a row.
    return pd.DataFrame(metrics_by_label).T
def _log_monthly_pnl(results: dict):
    """
    Log each strategy's profit per month (USDT) as one side-by-side table.

    :param results: strategy key -> result dict, as built by ``compare_strategies``.
        A result may carry a ``'monthly_pnl'`` Series (presumably indexed by
        month); entries that are missing, ``None`` or empty are skipped.
        Nothing is logged when no strategy has monthly data.
    """
    name_map = {
        'stat': '方案A',
        'lgb': 'LightGBM',
        'xgb': 'XGBoost',
    }
    cols = []
    for key, name in name_map.items():
        s = results[key].get('monthly_pnl') if key in results else None
        if s is None or s.empty:
            continue
        # round() returns a new Series, so renaming it leaves the caller's data intact.
        s = s.astype(float).round(2)
        s.name = name
        cols.append(s)
    if not cols:
        return
    # Outer-join the strategies on month; a month without PnL for a strategy shows as 0.
    # Sort afterwards: for non-datetime indexes (e.g. Period or 'YYYY-MM' labels) the
    # joined index keeps first-seen order, which could list months out of chronological order.
    monthly_df = pd.concat(cols, axis=1).fillna(0).sort_index()
    logger.info("\n" + "-" * 70)
    logger.info(" 每月盈利 (USDT)")
    logger.info("-" * 70)
    logger.info(f"\n{monthly_df.to_string()}")
    logger.info("-" * 70)
def _save_equity_plot(results: dict):
    """
    Save the equity-curve and drawdown comparison chart as
    ``REPORT_DIR / 'strategy_comparison.png'``.

    Top panel: each strategy's equity curve (USDT). Bottom panel: drawdown in
    percent, measured against the running peak of that strategy's curve.
    Strategies whose ``'equity_curve'`` is missing, ``None`` or empty are skipped.

    :param results: strategy key -> result dict, as built by ``compare_strategies``.
    """
    REPORT_DIR.mkdir(parents=True, exist_ok=True)
    name_map = {
        'stat': '方案A: 统计筛选',
        'lgb': '方案B-1: LightGBM',
        'xgb': '方案B-2: XGBoost',
    }
    colors = {'stat': '#2196F3', 'lgb': '#4CAF50', 'xgb': '#FF9800'}
    plot_path = REPORT_DIR / 'strategy_comparison.png'

    fig, (ax_equity, ax_dd) = plt.subplots(2, 1, figsize=(14, 10))
    try:
        for key, name in name_map.items():
            eq = results[key].get('equity_curve') if key in results else None
            if eq is None or eq.empty:
                continue
            color = colors.get(key, 'gray')
            # Upper panel: the raw equity curve.
            ax_equity.plot(eq.index, eq.values, label=name, color=color, linewidth=1)
            # Lower panel: percentage distance below the running peak.
            peak = eq.cummax()
            drawdown = (eq - peak) / peak * 100
            ax_dd.fill_between(drawdown.index, drawdown.values, 0,
                               alpha=0.3, label=name, color=color)

        ax_equity.set_title('权益曲线对比', fontsize=14)
        ax_equity.set_ylabel('资金 (USDT)')
        ax_dd.set_title('回撤对比', fontsize=14)
        ax_dd.set_ylabel('回撤 (%)')
        for ax in (ax_equity, ax_dd):
            ax.legend()
            ax.grid(True, alpha=0.3)

        fig.tight_layout()
        fig.savefig(str(plot_path), dpi=150)
    finally:
        # Always release this specific figure, even when layout/saving fails,
        # so repeated runs in one process do not accumulate open figures.
        plt.close(fig)
    logger.info(f"对比图表已保存: {plot_path}")
def _parse_pct_return(value):
    """
    Convert a percentage string such as ``'12.34%'`` into a fraction (``0.1234``).

    Returns ``None`` when ``value`` is not a string or is not numeric once the
    ``%`` is removed, so one malformed metric cannot abort the recommendation.
    """
    if not isinstance(value, str):
        return None
    try:
        return float(value.strip().strip('%')) / 100
    except ValueError:
        return None


def run_full_comparison(period: int = None, start_date: str = None, end_date: str = None, save_plot: bool = True):
    """
    Full comparison entry point (can be called directly).

    Runs ``compare_strategies`` and then logs a recommendation: the strategy
    with the highest '总收益率' (total return) metric. Ties keep the earlier
    strategy in the order stat, lgb, xgb. Metrics that cannot be parsed are
    skipped with a warning instead of aborting after the backtests finished.

    :param period: K-line period in minutes; ``None`` uses the configured default.
    :param start_date: optional backtest start date.
    :param end_date: optional backtest end date.
    :param save_plot: whether to save the equity/drawdown comparison chart.
    :return: the dict returned by ``compare_strategies``.
    """
    results = compare_strategies(period, start_date, end_date, save_plot=save_plot)

    # Recommend the best strategy by total return.
    best_key = None
    best_return = -float('inf')
    for key in ['stat', 'lgb', 'xgb']:
        if key not in results or 'metrics' not in results[key]:
            continue
        raw = results[key]['metrics'].get('总收益率', '0%')
        ret_val = _parse_pct_return(raw)
        if ret_val is None:
            logger.warning(f"无法解析 {key} 的总收益率: {raw!r}，已跳过")
            continue
        if ret_val > best_return:
            best_return = ret_val
            best_key = key

    name_map = {'stat': '方案A: 统计筛选', 'lgb': '方案B-1: LightGBM', 'xgb': '方案B-2: XGBoost'}
    if best_key:
        logger.info(f"\n{'='*50}")
        logger.info(f" 推荐方案: {name_map.get(best_key, best_key)}")
        logger.info(f" 总收益率: {best_return:.2%}")
        logger.info(f"{'='*50}")
    return results
if __name__ == '__main__':
    # Command-line entry. The module uses relative imports, so it has to be run as
    # a module (e.g. `python -m jyx_code4.strategy.compare`; package path assumed
    # from the repository layout) rather than as a plain script.
    import argparse

    parser = argparse.ArgumentParser(description='运行策略对比方案A(统计) vs 方案B(LightGBM/XGBoost)')
    for flag, options in (
        ('--period', dict(type=int, default=None, help='K线周期分钟')),
        ('--start', dict(type=str, default=None, help='开始日期,如 2024-01-01')),
        ('--end', dict(type=str, default=None, help='结束日期,如 2024-12-31')),
        ('--no-plot', dict(action='store_true', help='不保存权益/回撤图')),
    ):
        parser.add_argument(flag, **options)
    cli = parser.parse_args()
    run_full_comparison(
        period=cli.period,
        start_date=cli.start,
        end_date=cli.end,
        save_plot=not cli.no_plot,
    )