Files
jyx_code4/strategy/compare.py
ddrwode 21f2adc4a4 哈哈
2026-02-20 20:57:25 +08:00

223 lines
7.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
方案对比评估 — 方案A(统计筛选) vs 方案B(AI模型) 并排对比 + 报告输出
"""
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use('Agg') # 非交互式后端
import matplotlib.pyplot as plt
from matplotlib import font_manager
from pathlib import Path
from loguru import logger
# Configure a font that can render Chinese labels: try the candidates in
# priority order and keep the first one matplotlib can resolve.
_zh_font = None
for fname in [
    'PingFang SC',
    'Heiti SC',
    'STHeiti',
    'SimHei',
    'Microsoft YaHei',
    'WenQuanYi Micro Hei',
]:
    try:
        _zh_font = font_manager.FontProperties(family=fname)
        # Existence check: with the default fallback disabled, a missing
        # family makes findfont raise instead of silently substituting.
        font_manager.findfont(_zh_font, fallback_to_default=False)
        plt.rcParams['font.sans-serif'] = [fname, 'DejaVu Sans']
        plt.rcParams['axes.unicode_minus'] = False
        break
    except Exception:
        # This candidate is unusable; reset and move on to the next one.
        _zh_font = None
from .config import PRIMARY_PERIOD, PROJECT_ROOT
from .stat_strategy import StatStrategy
from .ai_strategy import AIStrategy
from .backtest import print_metrics
# Output directory for the comparison chart; created on demand by _save_equity_plot.
REPORT_DIR = PROJECT_ROOT / 'reports'
def compare_strategies(period: int = None, start_date: str = None, end_date: str = None,
                       save_plot: bool = True) -> dict:
    """
    Run the statistical strategy and both AI strategies, then compare them.

    :param period: period forwarded to each strategy's ``run``; PRIMARY_PERIOD is used when None
    :param start_date: forwarded unchanged to each strategy's ``run``
    :param end_date: forwarded unchanged to each strategy's ``run``
    :param save_plot: when True, also write the equity/drawdown chart
    :return: {'stat': result_a, 'lgb': result_b, 'xgb': result_c, 'comparison': DataFrame}
    """
    period = PRIMARY_PERIOD if period is None else period

    banner = "=" * 70
    logger.info(banner)
    logger.info(" 开始策略对比评估")
    logger.info(banner)

    # (result key, announcement, factory) in execution order:
    # plan A (statistical filter), plan B-1 (LightGBM), plan B-2 (XGBoost).
    lineup = (
        ('stat', "\n>>> 运行方案A: 统计筛选策略", StatStrategy),
        ('lgb', "\n>>> 运行方案B-1: LightGBM AI策略", lambda: AIStrategy(model_type='lightgbm')),
        ('xgb', "\n>>> 运行方案B-2: XGBoost AI策略", lambda: AIStrategy(model_type='xgboost')),
    )
    results = {}
    for key, announcement, make_strategy in lineup:
        logger.info(announcement)
        results[key] = make_strategy().run(period, start_date, end_date)

    # Side-by-side metrics table, stored next to the raw results.
    comparison = _build_comparison_table(results)
    results['comparison'] = comparison

    logger.info("\n" + banner)
    logger.info(" 策略对比总结")
    logger.info(banner)
    logger.info(f"\n{comparison.to_string()}")

    # Per-month profit breakdown.
    _log_monthly_pnl(results)

    if save_plot:
        _save_equity_plot(results)
    return results
def _build_comparison_table(results: dict) -> pd.DataFrame:
    """
    Build the comparison table with one row per strategy.

    Strategies that are missing from ``results`` or have no 'metrics' entry
    are left out; an empty DataFrame is returned when none qualify.
    """
    labels = (
        ('stat', '方案A: 统计筛选'),
        ('lgb', '方案B-1: LightGBM'),
        ('xgb', '方案B-2: XGBoost'),
    )
    metrics_by_label = {
        label: results[key]['metrics']
        for key, label in labels
        if key in results and 'metrics' in results[key]
    }
    if not metrics_by_label:
        return pd.DataFrame()
    # Built column-per-strategy, then transposed so strategies become rows.
    return pd.DataFrame(metrics_by_label).T
def _log_monthly_pnl(results: dict):
    """
    Log each strategy's monthly profit (USDT) as one table.

    Strategies without a usable 'monthly_pnl' entry (absent, None, or empty)
    are skipped; nothing is logged when no strategy has data. Values are
    rounded to 2 decimals and gaps between strategies are filled with 0.
    """
    short_names = (
        ('stat', '方案A'),
        ('lgb', 'LightGBM'),
        ('xgb', 'XGBoost'),
    )
    columns = []
    for key, label in short_names:
        if key not in results:
            continue
        pnl = results[key].get('monthly_pnl')
        if pnl is None or pnl.empty:
            continue
        # astype/round/rename each return a new object, so the caller's
        # series is never modified.
        columns.append(pnl.astype(float).round(2).rename(label))

    if not columns:
        return

    monthly_df = pd.concat(columns, axis=1).fillna(0)
    rule = "-" * 70
    logger.info("\n" + rule)
    logger.info(" 每月盈利 (USDT)")
    logger.info(rule)
    logger.info(f"\n{monthly_df.to_string()}")
    logger.info(rule)
def _save_equity_plot(results: dict):
    """
    Save the equity-curve / drawdown comparison chart under REPORT_DIR.

    The top panel plots each strategy's 'equity_curve'; the bottom panel shows
    its drawdown in percent relative to the running maximum. Strategies that
    are missing, lack an 'equity_curve', or have an empty curve are skipped.

    :param results: mapping of strategy key ('stat' / 'lgb' / 'xgb') to its result dict
    """
    REPORT_DIR.mkdir(parents=True, exist_ok=True)

    name_map = {
        'stat': '方案A: 统计筛选',
        'lgb': '方案B-1: LightGBM',
        'xgb': '方案B-2: XGBoost',
    }
    colors = {'stat': '#2196F3', 'lgb': '#4CAF50', 'xgb': '#FF9800'}

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10))
    try:
        for key, name in name_map.items():
            if key not in results or 'equity_curve' not in results[key]:
                continue
            eq = results[key]['equity_curve']
            if eq.empty:
                continue
            color = colors.get(key, 'gray')

            # Top panel: equity curve.
            ax1.plot(eq.index, eq.values, label=name, color=color, linewidth=1)

            # Bottom panel: drawdown (%) from the running peak.
            cummax = eq.cummax()
            drawdown = (eq - cummax) / cummax * 100
            ax2.fill_between(drawdown.index, drawdown.values, 0,
                             alpha=0.3, label=name, color=color)

        ax1.set_title('权益曲线对比', fontsize=14)
        ax1.set_ylabel('资金 (USDT)')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        ax2.set_title('回撤对比', fontsize=14)
        ax2.set_ylabel('回撤 (%)')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # Operate on this figure explicitly rather than pyplot's implicit
        # "current figure", which other code may have changed.
        fig.tight_layout()
        plot_path = REPORT_DIR / 'strategy_comparison.png'
        fig.savefig(str(plot_path), dpi=150)
    finally:
        # Always release the figure, even if plotting or saving failed;
        # pyplot keeps figures alive until they are closed.
        plt.close(fig)
    logger.info(f"对比图表已保存: {plot_path}")
def run_full_comparison(period: int = None, start_date: str = None, end_date: str = None, save_plot: bool = True):
    """
    Full comparison entry point: run ``compare_strategies`` and log a recommendation.

    The recommended strategy is the one with the highest '总收益率' (total
    return) metric, which is expected to be a percent string such as '12.34%'.
    A value that cannot be parsed is logged and skipped, so a formatting
    problem in one metric cannot discard the already-computed results.

    :param period: forwarded to ``compare_strategies``
    :param start_date: forwarded to ``compare_strategies``
    :param end_date: forwarded to ``compare_strategies``
    :param save_plot: forwarded to ``compare_strategies``
    :return: the results dict produced by ``compare_strategies``
    """
    results = compare_strategies(period, start_date, end_date, save_plot=save_plot)

    name_map = {'stat': '方案A: 统计筛选', 'lgb': '方案B-1: LightGBM', 'xgb': '方案B-2: XGBoost'}

    # Pick the strategy with the highest total return; the first one wins ties.
    best_key = None
    best_return = -float('inf')
    for key in name_map:
        if key not in results or 'metrics' not in results[key]:
            continue
        ret_str = results[key]['metrics'].get('总收益率', '0%')
        try:
            # Tolerate surrounding whitespace and thousands separators.
            ret_val = float(ret_str.strip().strip('%').replace(',', '')) / 100
        except (AttributeError, ValueError):
            # Non-string or malformed metric: skip it instead of crashing
            # after all the backtests have already finished.
            logger.warning(f"无法解析 {name_map[key]} 的总收益率: {ret_str!r},已跳过")
            continue
        if ret_val > best_return:
            best_return = ret_val
            best_key = key

    if best_key:
        logger.info(f"\n{'='*50}")
        logger.info(f" 推荐方案: {name_map.get(best_key, best_key)}")
        logger.info(f" 总收益率: {best_return:.2%}")
        logger.info(f"{'='*50}")
    return results
if __name__ == '__main__':
    # Command-line entry point: parse the options and run the full comparison.
    import argparse

    parser = argparse.ArgumentParser(
        description='运行策略对比方案A(统计) vs 方案B(LightGBM/XGBoost)',
    )
    parser.add_argument('--period', type=int, default=None, help='K线周期分钟')
    parser.add_argument('--start', type=str, default=None, help='开始日期,如 2024-01-01')
    parser.add_argument('--end', type=str, default=None, help='结束日期,如 2024-12-31')
    parser.add_argument('--no-plot', action='store_true', help='不保存权益/回撤图')
    cli = parser.parse_args()

    # --no-plot is a negative flag, so invert it for save_plot.
    run_full_comparison(
        period=cli.period,
        start_date=cli.start,
        end_date=cli.end,
        save_plot=not cli.no_plot,
    )