This commit is contained in:
ddrwode
2026-02-16 22:03:12 +08:00
commit 78d972b59b
21 changed files with 16720 additions and 0 deletions

10
.idea/.gitignore generated vendored Normal file
View File

@@ -0,0 +1,10 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 已忽略包含查询文件的默认文件夹
/queries/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/

View File

@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

8
.idea/jyx_code.iml generated Normal file
View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="jyx_code" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

7
.idea/misc.xml generated Normal file
View File

@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="jyx_code" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="jyx_code" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml generated Normal file
View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/jyx_code.iml" filepath="$PROJECT_DIR$/.idea/jyx_code.iml" />
</modules>
</component>
</project>

16
main.py Normal file
View File

@@ -0,0 +1,16 @@
# 这是一个示例 Python 脚本。
# 按 ⌃R 执行或将其替换为您的代码。
# 按 双击 ⇧ 在所有地方搜索类、文件、工具窗口、操作和设置。
def print_hi(name):
    """Write the greeting ``Hi, <name>`` to standard output."""
    greeting = f'Hi, {name}'
    print(greeting)


# Greet "PyCharm" only when this file is executed as a script.
if __name__ == '__main__':
    print_hi('PyCharm')
# 访问 https://www.jetbrains.com/help/pycharm/ 获取 PyCharm 帮助

58
models/__init__.py Normal file
View File

@@ -0,0 +1,58 @@
from pathlib import Path
from peewee import *
# 连接到 SQLite 数据库,如果文件不存在会自动创建
db = SqliteDatabase(fr'{Path(__file__).parent}/database.db')
import pymysql
from peewee import *
from playhouse.pool import PooledMySQLDatabase
pymysql.install_as_MySQLdb()
# 数据库配置
import os

# Connection settings for the shared MySQL database.
# SECURITY NOTE(review): these credentials were committed in plain text.
# Every value can now be supplied through an environment variable; the
# literals remain only as backward-compatible fallbacks and should be
# rotated and removed from source control.
db_config = {
    'database': os.environ.get('LM_DB_NAME', 'lm'),
    'user': os.environ.get('LM_DB_USER', 'lm'),
    'password': os.environ.get('LM_DB_PASSWORD', 'HhyAsGbrrbsJfpyy'),
    'host': os.environ.get('LM_DB_HOST', '192.168.1.87'),
    # The port must be an int for the driver, so convert the env string.
    'port': int(os.environ.get('LM_DB_PORT', '3306')),
}

# Module-wide MySQL database handle used by the models bound to ``db1``.
db1 = MySQLDatabase(
    db_config['database'],
    user=db_config['user'],
    password=db_config['password'],
    host=db_config['host'],
    port=db_config['port'],
)
# class BaseModel(Model):
# class Meta:
# database = db1
#
# def save(self, *args, **kwargs):
# """在调用 save 时自动连接和关闭(若无事务)"""
# db.connect(reuse_if_open=True)
# try:
# result = super().save(*args, **kwargs)
# finally:
# # 若当前没有事务且连接仍然打开,则关闭连接
# if not db.in_transaction() and not db.is_closed():
# db.close()
# return result
#
# @classmethod
# def get_or_create(cls, defaults=None, **kwargs):
# """在调用 get_or_create 时自动连接和关闭(若无事务)"""
# db.connect(reuse_if_open=True)
# try:
# obj, created = super().get_or_create(defaults=defaults, **kwargs)
# finally:
# # 若当前没有事务且连接仍然打开,则关闭连接
# if not db.in_transaction() and not db.is_closed():
# db.close()
# return obj, created

21
models/bitmart.py Normal file
View File

@@ -0,0 +1,21 @@
from peewee import *
from models import db
class BitMart30(Model):
    """OHLC candle row stored in the ``bitmart_30`` table.

    Presumably 30-minute candles, judging by the table name.
    """
    id = IntegerField(primary_key=True)  # timestamp in milliseconds
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        database = db
        table_name = 'bitmart_30'


# ``db`` is shared by several model modules that each connect at import
# time. A bare connect() raises OperationalError when the connection is
# already open, so reuse an existing connection instead.
db.connect(reuse_if_open=True)
# Create the table if it does not exist yet.
db.create_tables([BitMart30])

21
models/bitmart_15.py Normal file
View File

@@ -0,0 +1,21 @@
from peewee import *
from models import db
class BitMart15(Model):
    """OHLC candle row stored in the ``bitmart_15`` table.

    Presumably 15-minute candles, judging by the table name.
    """
    id = IntegerField(primary_key=True)  # timestamp in milliseconds
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        database = db
        table_name = 'bitmart_15'


# ``db`` is shared by several model modules that each connect at import
# time. A bare connect() raises OperationalError when the connection is
# already open, so reuse an existing connection instead.
db.connect(reuse_if_open=True)
# Create the table if it does not exist yet.
db.create_tables([BitMart15])

97
models/bitmart_klines.py Normal file
View File

@@ -0,0 +1,97 @@
"""
BitMart 多周期K线数据模型
包含 1分钟、3分钟、5分钟、15分钟、30分钟、1小时 K线数据表
"""
from peewee import *
from models import db
# ==================== 1-minute candles ====================
class BitMartETH1M(Model):
    """OHLC candle row stored in the ``bitmart_eth_1m`` table."""
    id = BigIntegerField(primary_key=True)  # timestamp in milliseconds
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        database = db
        table_name = 'bitmart_eth_1m'


# ==================== 3-minute candles ====================
class BitMartETH3M(Model):
    """OHLC candle row stored in the ``bitmart_eth_3m`` table."""
    id = BigIntegerField(primary_key=True)
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        database = db
        table_name = 'bitmart_eth_3m'


# ==================== 5-minute candles ====================
class BitMartETH5M(Model):
    """OHLC candle row stored in the ``bitmart_eth_5m`` table."""
    id = BigIntegerField(primary_key=True)
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        database = db
        table_name = 'bitmart_eth_5m'


# ==================== 15-minute candles ====================
class BitMartETH15M(Model):
    """OHLC candle row stored in the ``bitmart_eth_15m`` table."""
    id = BigIntegerField(primary_key=True)
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        database = db
        table_name = 'bitmart_eth_15m'


# ==================== 30-minute candles ====================
class BitMartETH30M(Model):
    """OHLC candle row stored in the ``bitmart_eth_30m`` table."""
    id = BigIntegerField(primary_key=True)
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        database = db
        table_name = 'bitmart_eth_30m'


# ==================== 1-hour candles ====================
class BitMartETH1H(Model):
    """OHLC candle row stored in the ``bitmart_eth_1h`` table."""
    id = BigIntegerField(primary_key=True)
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        database = db
        table_name = 'bitmart_eth_1h'


# Open the shared connection (reusing it if another module already did)
# and create any of the six tables that do not exist yet.
db.connect(reuse_if_open=True)
db.create_tables([
    BitMartETH1M,
    BitMartETH3M,
    BitMartETH5M,
    BitMartETH15M,
    BitMartETH30M,
    BitMartETH1H,
], safe=True)

BIN
models/database.db Normal file

Binary file not shown.

View File

@@ -0,0 +1,47 @@
[
"bb_pct",
"bb_width",
"keltner_pct",
"keltner_width",
"donchian_pct",
"donchian_width",
"stoch_k",
"stoch_d",
"cci",
"willr",
"rsi",
"atr_band_pct",
"atr_band_width",
"zscore",
"zscore_abs",
"lr_pct",
"lr_width",
"price_vs_median",
"median_band_pct",
"pct_rank",
"chandelier_pct",
"chandelier_width",
"std_band_pct",
"std_band_width",
"elder_bull",
"elder_bear",
"elder_dist",
"ema_fast_slow",
"price_vs_ema120",
"ema8_slope",
"macd",
"macd_signal",
"macd_hist",
"atr_pct",
"ret_1",
"ret_3",
"ret_5",
"ret_10",
"ret_20",
"vol_5",
"vol_20",
"body_pct",
"price_position_20",
"hour_sin",
"hour_cos"
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,40 @@
{
"bb_period": 20,
"bb_std": 2.0,
"keltner_period": 20,
"keltner_atr_mult": 2.0,
"donchian_period": 20,
"stoch_k_period": 14,
"stoch_d_period": 3,
"cci_period": 20,
"willr_period": 14,
"rsi_period": 14,
"atr_band_period": 20,
"atr_band_mult": 2.0,
"zscore_period": 20,
"lr_period": 20,
"lr_std_mult": 2.0,
"median_band_period": 20,
"pct_rank_period": 20,
"chandelier_period": 22,
"chandelier_mult": 3.0,
"std_band_period": 14,
"std_band_mult": 2.0,
"elder_period": 13,
"forward_bars": 10,
"label_threshold": 0.002,
"prob_threshold": 0.45,
"sl_pct": 0.004,
"tp_pct": 0.006,
"min_hold_seconds": 180,
"max_hold_seconds": 1800,
"margin_per_trade": 100.0,
"leverage": 100,
"notional": 10000.0,
"rebate_rate": 0.9,
"train_period": "2020-01-01 ~ 2021-01-01",
"test_period": "2021-01-01 ~ 2022-01-01",
"kline_period": "5m",
"initial_balance": 10000.0,
"max_dd_target": 500.0
}

21
models/ips.py Normal file
View File

@@ -0,0 +1,21 @@
from peewee import *
from models import db1
class Ips(Model):
    """Row of the ``ips`` table on the ``db1`` database.

    NOTE(review): the columns look like proxy endpoints with credentials;
    the meaning of ``start`` is not visible here - confirm with callers.
    """
    id = IntegerField(primary_key=True)
    host = CharField(null=True)
    port = CharField(null=True)  # stored as text, not as an integer
    username = CharField(null=True)
    password = CharField(null=True)
    start = IntegerField(null=True)
    country = CharField(null=True)

    class Meta:
        database = db1
        table_name = 'ips'
# if __name__ == '__main__':
# Ips.create_table()

57
models/mexc.py Normal file
View File

@@ -0,0 +1,57 @@
from peewee import *
from models import db
class Mexc1(Model):
    """OHLC candle row stored in the ``mexc_1`` table."""
    id = IntegerField(primary_key=True)
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        database = db
        table_name = 'mexc_1'


class Mexc15(Model):
    """OHLC candle row stored in the ``mexc_15`` table."""
    id = IntegerField(primary_key=True)
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        database = db
        table_name = 'mexc_15'


class Mexc30(Model):
    """OHLC candle row stored in the ``mexc_30`` table."""
    id = IntegerField(primary_key=True)
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        database = db
        table_name = 'mexc_30'


class Mexc1Hour(Model):
    """OHLC candle row stored in the ``mexc_1_hour`` table."""
    id = IntegerField(primary_key=True)
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        database = db
        table_name = 'mexc_1_hour'


# ``db`` is shared by several model modules that each connect at import
# time. A bare connect() raises OperationalError when the connection is
# already open, so reuse an existing connection instead.
db.connect(reuse_if_open=True)
# Create the tables if they do not exist yet.
db.create_tables([Mexc1, Mexc15, Mexc30, Mexc1Hour])

71
models/weex.py Normal file
View File

@@ -0,0 +1,71 @@
from peewee import *
from models import db
class Weex15(Model):
    """OHLC candle row stored in the ``weex_15`` table."""
    id = IntegerField(primary_key=True)
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        database = db
        table_name = 'weex_15'


class Weex1(Model):
    """OHLC candle row stored in the ``weex_1`` table."""
    id = IntegerField(primary_key=True)
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        database = db
        table_name = 'weex_1'


class Weex1Hour(Model):
    """OHLC candle row stored in the ``weex_1_hour`` table."""
    id = IntegerField(primary_key=True)
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        database = db
        table_name = 'weex_1_hour'


class Weex30(Model):
    """OHLC candle row stored in the ``weex_30`` table."""
    id = IntegerField(primary_key=True)
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        database = db
        table_name = 'weex_30'


class Weex30Copy(Model):
    """OHLC candle row stored in the ``weex_30_copy1`` table."""
    id = IntegerField(primary_key=True)
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        database = db
        table_name = 'weex_30_copy1'


# ``db`` is shared by several model modules that each connect at import
# time. A bare connect() raises OperationalError when the connection is
# already open, so reuse an existing connection instead.
db.connect(reuse_if_open=True)
# NOTE(review): only ``weex_30`` is created here; the other Weex tables
# are assumed to exist already - confirm this is intentional.
# db.create_tables([Weex15])
db.create_tables([Weex30])

22
models/xstart.py Normal file
View File

@@ -0,0 +1,22 @@
from peewee import *
from models import db1
from models.ips import Ips
class Xstart(Model):
    """Row of the ``xstart`` table on the ``db1`` database.

    NOTE(review): ``x_id`` and ``ip_id`` look like references to other
    tables but are plain integers here, not foreign keys - confirm.
    """
    id = AutoField(primary_key=True)  # auto-incrementing primary key
    bit_id = CharField(null=True)
    start = IntegerField(null=True)
    x_id = IntegerField(null=True)
    ip_id = IntegerField(null=True)
    url_id = CharField(null=True)

    class Meta:
        database = db1  # database this model is bound to
        table_name = 'xstart'


# Running this module directly creates the ``xstart`` table.
if __name__ == '__main__':
    Xstart.create_table()

24
models/xtoken.py Normal file
View File

@@ -0,0 +1,24 @@
from peewee import *
# 假设 db 已经在其他地方定义并连接到数据库
from models import db1
class XToken(Model):
    """Row of the ``x_token`` table on the ``db1`` database."""
    id = AutoField(primary_key=True)  # auto-incrementing primary key
    hub_id = IntegerField(null=True)  # nullable integer
    start = IntegerField(null=True)  # nullable integer
    account_start = IntegerField(null=True)  # nullable integer
    user_name = CharField(max_length=255, null=True)  # nullable, up to 255 chars
    password = CharField(max_length=255, null=True)  # nullable, up to 255 chars
    email = CharField(max_length=255, null=True)  # nullable, up to 255 chars
    # "2fa" is not a valid Python identifier, so the attribute is two_fa.
    two_fa = CharField(max_length=255, null=True)  # nullable, up to 255 chars
    token = CharField(max_length=255, null=True)  # nullable, up to 255 chars
    email_pwd = CharField(max_length=255, null=True)  # email password; nullable, up to 255 chars

    class Meta:
        database = db1  # database this model is bound to
        table_name = 'x_token'  # table name


# Running this module directly creates the ``x_token`` table.
if __name__ == '__main__':
    XToken.create_table()

762
抓取多周期K线.py Normal file
View File

@@ -0,0 +1,762 @@
"""
BitMart 多周期K线数据抓取脚本
支持同时获取 1分钟、3分钟、5分钟、15分钟、30分钟、1小时 K线数据
支持秒级价格数据通过成交记录API
支持断点续传,从数据库最新/最早记录继续抓取
"""
import time
import datetime
from pathlib import Path
from loguru import logger
from peewee import *
from bitmart.api_contract import APIContract
# 数据库配置(使用脚本所在项目目录下的 models
DB_PATH = Path(__file__).parent / 'models' / 'database.db'
db = SqliteDatabase(str(DB_PATH))
# Candle period configuration: API ``step`` value (minutes) -> table-name suffix.
KLINE_CONFIGS = {
    1: '1m',  # 1 minute
    3: '3m',  # 3 minutes
    5: '5m',  # 5 minutes
    15: '15m',  # 15 minutes
    30: '30m',  # 30 minutes
    60: '1h',  # 1 hour
}
class BitMartETHTrades(Model):
    """Individual trade record (raw millisecond-level data) in ``bitmart_eth_trades``."""
    id = BigIntegerField(primary_key=True)  # trade ID
    timestamp = BigIntegerField(index=True)  # trade time in milliseconds
    price = FloatField()  # trade price
    volume = FloatField()  # trade volume
    side = IntegerField()  # direction: 1 = buy, -1 = sell

    class Meta:
        database = db
        table_name = 'bitmart_eth_trades'
class BitMartETHSecond(Model):
    """One-second candle aggregated from trade records, stored in ``bitmart_eth_1s``."""
    id = BigIntegerField(primary_key=True)  # millisecond timestamp floored to the second
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)
    volume = FloatField(null=True)
    trade_count = IntegerField(null=True)  # number of trades within that second

    class Meta:
        database = db
        table_name = 'bitmart_eth_1s'
def create_kline_model(step: int):
    """Build a peewee model class for the candle table of one period.

    :param step: candle period in minutes
    :return: a new ``Model`` subclass bound to ``db`` whose table is
        ``bitmart_eth_<suffix>``; the suffix comes from ``KLINE_CONFIGS``
        and falls back to ``'<step>m'`` for unknown periods
    """
    suffix = KLINE_CONFIGS.get(step, f'{step}m')

    # Inner Meta class carrying the database binding and the table name.
    meta_cls = type('Meta', (), {
        'database': db,
        'table_name': f'bitmart_eth_{suffix}',
    })

    # Declare the primary key first, then the OHLC columns in order.
    namespace = {'id': BigIntegerField(primary_key=True)}
    for column in ('open', 'high', 'low', 'close'):
        namespace[column] = FloatField(null=True)
    namespace['Meta'] = meta_cls

    return type(f'BitMartETH{suffix.upper()}', (Model,), namespace)
class BitMartMultiKlineCollector:
"""多周期K线数据抓取器"""
def __init__(self):
    """Set up the API client, the per-period model registry and the tables.

    SECURITY NOTE(review): the API key and secret were committed in plain
    text. They can now be supplied through ``BITMART_API_KEY`` and
    ``BITMART_SECRET_KEY``; the literals remain only as backward-compatible
    fallbacks and should be rotated and removed from source control.
    """
    import os

    self.api_key = os.environ.get(
        "BITMART_API_KEY", "a0fb7b98464fd9bcce67e7c519d58ec10d0c38a8")
    self.secret_key = os.environ.get(
        "BITMART_SECRET_KEY",
        "4eaeba78e77aeaab1c2027f846a276d164f264a44c2c1bb1c5f3be50c8de1ca5")
    self.memo = "数据抓取"
    self.contract_symbol = "ETHUSDT"
    self.contractAPI = APIContract(self.api_key, self.secret_key, self.memo, timeout=(5, 15))
    # Maps each period (step in minutes) to its model class.
    self.models = {}
    # Open the database connection and create all tables.
    self._init_database()
def _init_database(self):
    """Open the database and create one table per configured period.

    Also creates the trade-record table and the one-second candle table,
    and fills ``self.models`` with the model class of every period.
    """
    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
    db.connect(reuse_if_open=True)

    for step in KLINE_CONFIGS:
        model = self.models[step] = create_kline_model(step)
        # safe=True: only create the table when it does not exist yet.
        db.create_tables([model], safe=True)
        logger.info(f"初始化表: {model._meta.table_name}")

    # Tables for raw trades and for the one-second candles built from them.
    db.create_tables([BitMartETHTrades, BitMartETHSecond], safe=True)
    for message in ("初始化表: bitmart_eth_trades (成交记录)",
                    "初始化表: bitmart_eth_1s (秒级K线)"):
        logger.info(message)
def get_db_time_range(self, step: int):
    """Return the smallest and largest candle id stored for one period.

    :param step: candle period
    :return: ``(earliest_ts, latest_ts)`` in milliseconds; ``(None, None)``
        when the period has no model, the table is empty or the query fails
    """
    model = self.models.get(step)
    if not model:
        return None, None

    try:
        # Two aggregate queries: MIN(id) first, then MAX(id).
        earliest, latest = (
            model.select(aggregate(model.id)).scalar()
            for aggregate in (fn.MIN, fn.MAX)
        )
    except Exception as e:
        logger.error(f"查询数据库时间范围异常: {e}")
        return None, None
    return earliest, latest
def get_klines(self, step: int, start_time: int, end_time: int, max_retries: int = 3):
    """Fetch candles from the contract API, retrying on failure.

    :param step: candle period in minutes
    :param start_time: start timestamp in seconds
    :param end_time: end timestamp in seconds
    :param max_retries: maximum number of attempts
    :return: list of ``{'id', 'open', 'high', 'low', 'close'}`` dicts sorted
        by ``id`` (millisecond timestamp); an empty list when every attempt
        fails or the API returns no data
    """
    final_attempt = max_retries - 1
    for attempt in range(max_retries):
        try:
            start_time = int(start_time)
            end_time = int(end_time)
            response = self.contractAPI.get_kline(
                contract_symbol=self.contract_symbol,
                step=step,
                start_time=start_time,
                end_time=end_time,
            )[0]

            if response['code'] != 1000:
                logger.warning(f"API返回错误 (尝试 {attempt+1}/{max_retries}): {response}")
                if attempt >= final_attempt:
                    return []
                time.sleep(1)
                continue

            # The API timestamp is multiplied by 1000 to get a millisecond id.
            candles = [
                {
                    'id': int(raw["timestamp"]) * 1000,
                    'open': float(raw["open_price"]),
                    'high': float(raw["high_price"]),
                    'low': float(raw["low_price"]),
                    'close': float(raw["close_price"]),
                }
                for raw in response.get('data', [])
            ]
            return sorted(candles, key=lambda candle: candle['id'])
        except Exception as e:
            logger.error(f"获取K线异常 (尝试 {attempt+1}/{max_retries}): {e}")
            if attempt >= final_attempt:
                return []
            time.sleep(2)
    return []
def save_klines(self, step: int, klines: list):
    """Store candles for one period, skipping ids that already exist.

    :param step: candle period
    :param klines: list of candle dicts with ``id``/``open``/``high``/``low``/``close``
    :return: number of rows newly inserted
    """
    model = self.models.get(step)
    if not model:
        logger.error(f"未找到 {step}分钟 的数据模型")
        return 0

    new_count = 0
    for kline in klines:
        try:
            # get_or_create returns (instance, created); only the flag matters.
            created = model.get_or_create(
                id=kline['id'],
                defaults={
                    column: kline[column]
                    for column in ('open', 'high', 'low', 'close')
                },
            )[1]
        except Exception as e:
            logger.error(f"保存K线数据失败 {kline['id']}: {e}")
            continue
        if created:
            new_count += 1
    return new_count
def get_batch_seconds(self, step: int):
    """Return the request window size, in seconds, for one candle period.

    Unknown periods use the 1-hour window of 72 hours.
    """
    hours_per_batch = {
        1: 4,    # 1-minute candles: 4 hours per request
        3: 8,    # 3-minute candles: 8 hours per request
        5: 12,   # 5-minute candles: 12 hours per request
        15: 24,  # 15-minute candles: 1 day per request
        30: 48,  # 30-minute candles: 2 days per request
    }
    return 3600 * hours_per_batch.get(step, 72)
def collect_period_range(self, step: int, target_start: int, target_end: int):
    """Fetch candles for one period over a time range, resuming from the DB.

    Phase 1 walks backwards from the earliest stored candle (or from
    ``target_end`` when the table is empty) down to ``target_start``.
    Phase 2 walks forwards from the latest stored candle up to
    ``target_end``.

    :param step: candle period in minutes
    :param target_start: range start, timestamp in seconds
    :param target_end: range end, timestamp in seconds
    :return: total number of newly stored rows
    """
    def _fmt(ts, pattern='%Y-%m-%d %H:%M'):
        # Render a second-level timestamp in local time.
        return time.strftime(pattern, time.localtime(ts))

    suffix = KLINE_CONFIGS.get(step, f'{step}m')
    batch_seconds = self.get_batch_seconds(step)

    # Range already present in the database (ids are in milliseconds).
    db_earliest, db_latest = self.get_db_time_range(step)
    if db_earliest and db_latest:
        db_earliest_sec = db_earliest // 1000
        db_latest_sec = db_latest // 1000
        logger.info(f"[{suffix}] 数据库已有数据: "
                    f"{_fmt(db_earliest_sec)} ~ "
                    f"{_fmt(db_latest_sec)}")
    else:
        db_earliest_sec = None
        logger.info(f"[{suffix}] 数据库暂无数据")

    total_saved = 0

    # === Phase 1: walk backwards through history down to target_start ===
    backward_end = db_earliest_sec if db_earliest_sec else target_end
    if backward_end > target_start:
        logger.info(f"[{suffix}] === 开始向前抓取历史数据 ===")
        total_backward = backward_end - target_start
        current_end = backward_end
        fail_count = 0
        max_fail = 5
        while current_end > target_start and fail_count < max_fail:
            current_start = max(current_end - batch_seconds, target_start)
            progress = (backward_end - current_end) / total_backward * 100 if total_backward > 0 else 0
            start_str = _fmt(current_start)
            end_str = _fmt(current_end)
            klines = self.get_klines(step, current_start, current_end)
            if klines:
                saved = self.save_klines(step, klines)
                total_saved += saved
                logger.info(f"[{suffix}] ← 历史 {start_str} ~ {end_str} | "
                            f"获取 {len(klines)} 条, 新增 {saved} 条 | 进度 {progress:.1f}%")
                fail_count = 0
            else:
                fail_count += 1
                logger.warning(f"[{suffix}] ← 历史 {start_str} 无数据 (连续失败 {fail_count}/{max_fail})")
                if fail_count >= max_fail:
                    # Consecutive empty windows are treated as the API's
                    # history limit.
                    earliest_date = _fmt(current_end, '%Y-%m-%d')
                    logger.warning(f"[{suffix}] 已达到API历史数据限制最早可获取: {earliest_date}")
                    break
            current_end = current_start
            time.sleep(0.3)

    # === Phase 2: walk forwards from the latest stored candle ===
    # Re-read the latest stored candle after phase 1. With an initially
    # empty table the old code restarted at target_start and downloaded the
    # whole range a second time. When the table already had data, phase 1
    # only adds older rows, so this value is unchanged.
    _, latest_now = self.get_db_time_range(step)
    forward_start = latest_now // 1000 if latest_now else target_start
    if forward_start < target_end:
        logger.info(f"[{suffix}] === 开始向后抓取最新数据 ===")
        total_forward = target_end - forward_start
        current_start = forward_start
        fail_count = 0
        max_fail = 3
        while current_start < target_end and fail_count < max_fail:
            current_end = min(current_start + batch_seconds, target_end)
            progress = (current_start - forward_start) / total_forward * 100 if total_forward > 0 else 0
            start_str = _fmt(current_start)
            end_str = _fmt(current_end)
            klines = self.get_klines(step, current_start, current_end)
            if klines:
                saved = self.save_klines(step, klines)
                total_saved += saved
                logger.info(f"[{suffix}] → 最新 {start_str} ~ {end_str} | "
                            f"获取 {len(klines)} 条, 新增 {saved} 条 | 进度 {progress:.1f}%")
                fail_count = 0
            else:
                fail_count += 1
                logger.warning(f"[{suffix}] → 最新 {start_str} 无数据 (失败 {fail_count}/{max_fail})")
            current_start = current_end
            time.sleep(0.3)

    # Report the final stored range.
    final_earliest, final_latest = self.get_db_time_range(step)
    if final_earliest and final_latest:
        logger.success(f"[{suffix}] 抓取完成!本次新增 {total_saved} 条 | 数据范围: "
                       f"{_fmt(final_earliest//1000, '%Y-%m-%d')} ~ "
                       f"{_fmt(final_latest//1000, '%Y-%m-%d')}")
    else:
        logger.success(f"[{suffix}] 抓取完成!本次新增 {total_saved}")
    return total_saved
def collect_from_date(self, start_date: str, periods: list = None):
    """Fetch candles from ``start_date`` up to now for the given periods.

    :param start_date: start date as ``'YYYY-MM-DD'`` (parsed as local time)
    :param periods: periods to fetch, e.g. ``[1, 5, 15]``; defaults to all
        keys of ``KLINE_CONFIGS``. Unsupported periods are skipped with a
        warning.
    :return: dict mapping period suffix to the number of newly stored rows
    """
    if periods is None:
        periods = list(KLINE_CONFIGS.keys())

    # Target range in seconds.
    start_dt = datetime.datetime.strptime(start_date, '%Y-%m-%d')
    target_start = int(start_dt.timestamp())
    target_end = int(time.time())
    start_str = start_dt.strftime('%Y-%m-%d')
    end_str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')

    # Use a fallback label for unknown periods: indexing KLINE_CONFIGS
    # directly raised KeyError here, before the skip logic below could run.
    period_labels = [KLINE_CONFIGS.get(p, f'{p}m') for p in periods]

    logger.info(f"{'='*60}")
    logger.info(f"目标时间范围: {start_str} ~ {end_str}")
    logger.info(f"抓取周期: {period_labels}")
    logger.info(f"{'='*60}")

    results = {}
    for step in periods:
        if step not in KLINE_CONFIGS:
            logger.warning(f"不支持的周期: {step}分钟,跳过")
            continue
        logger.info(f"\n{'='*60}")
        logger.info(f"开始抓取 {KLINE_CONFIGS[step]} K线")
        logger.info(f"{'='*60}")
        saved = self.collect_period_range(step, target_start, target_end)
        results[KLINE_CONFIGS[step]] = saved
        time.sleep(1)  # pause between periods

    # Summary.
    logger.info(f"\n{'='*60}")
    logger.info("所有周期抓取完成!统计:")
    for period, count in results.items():
        logger.info(f" {period}: 新增 {count}")
    logger.info(f"{'='*60}")
    return results
def get_stats(self):
    """Log row counts and time ranges for every table this collector uses.

    Covers each configured candle period, the trade-record table and the
    one-second candle table. A failed query is logged and does not stop
    the remaining sections.
    """
    def _local(ms, pattern):
        # Render a millisecond timestamp in local time.
        return time.strftime(pattern, time.localtime(ms//1000))

    minute_fmt = '%Y-%m-%d %H:%M'
    second_fmt = '%Y-%m-%d %H:%M:%S'

    logger.info(f"\n{'='*60}")
    logger.info("数据库统计:")
    logger.info(f"{'='*60}")

    # Per-period candle tables.
    for step, model in self.models.items():
        suffix = KLINE_CONFIGS.get(step, f'{step}m')
        try:
            count = model.select().count()
            earliest, latest = self.get_db_time_range(step)
            if not (earliest and latest):
                logger.info(f" {suffix:>4}: {count:>8}")
                continue
            earliest_str = _local(earliest, minute_fmt)
            latest_str = _local(latest, minute_fmt)
            logger.info(f" {suffix:>4}: {count:>8} 条 | {earliest_str} ~ {latest_str}")
        except Exception as e:
            logger.error(f" {suffix}: 查询失败 - {e}")

    # Trade records.
    try:
        trades_count = BitMartETHTrades.select().count()
        if trades_count > 0:
            earliest_trade = BitMartETHTrades.select(fn.MIN(BitMartETHTrades.timestamp)).scalar()
            latest_trade = BitMartETHTrades.select(fn.MAX(BitMartETHTrades.timestamp)).scalar()
            earliest_str = _local(earliest_trade, second_fmt)
            latest_str = _local(latest_trade, second_fmt)
            logger.info(f"trades: {trades_count:>8} 条 | {earliest_str} ~ {latest_str}")
        else:
            logger.info(f"trades: {trades_count:>8}")
    except Exception as e:
        logger.error(f"trades: 查询失败 - {e}")

    # One-second candles.
    try:
        second_count = BitMartETHSecond.select().count()
        if second_count > 0:
            earliest_sec = BitMartETHSecond.select(fn.MIN(BitMartETHSecond.id)).scalar()
            latest_sec = BitMartETHSecond.select(fn.MAX(BitMartETHSecond.id)).scalar()
            earliest_str = _local(earliest_sec, second_fmt)
            latest_str = _local(latest_sec, second_fmt)
            logger.info(f" 1s: {second_count:>8} 条 | {earliest_str} ~ {latest_str}")
        else:
            logger.info(f" 1s: {second_count:>8}")
    except Exception as e:
        logger.error(f" 1s: 查询失败 - {e}")

    logger.info(f"{'='*60}")
# ==================== 秒级数据相关方法 ====================
def get_trades(self, limit: int = 100):
    """Fetch the most recent trades for the configured contract.

    :param limit: number of records requested.
        NOTE(review): this value is not passed to the API and the result
        is not truncated - confirm whether it should be.
    :return: list of ``{'id', 'timestamp', 'price', 'volume', 'side'}``
        dicts; an empty list on an API error or any exception
    """
    try:
        response = self.contractAPI.get_trades(
            contract_symbol=self.contract_symbol,
        )[0]
        if response['code'] != 1000:
            logger.error(f"获取成交记录失败: {response}")
            return []

        raw_trades = response.get('data', {}).get('trades', [])
        # Missing keys default to 0 before numeric conversion.
        return [
            {
                'id': int(item.get('trade_id', 0)),
                'timestamp': int(item.get('create_time', 0)),
                'price': float(item.get('deal_price', 0)),
                'volume': float(item.get('deal_vol', 0)),
                'side': int(item.get('way', 0)),
            }
            for item in raw_trades
        ]
    except Exception as e:
        logger.error(f"获取成交记录异常: {e}")
        return []
def save_trades(self, trades: list):
    """Store trade records, skipping ids that already exist.

    :param trades: list of dicts with ``id``/``timestamp``/``price``/
        ``volume``/``side``
    :return: number of rows newly inserted
    """
    new_count = 0
    for trade in trades:
        try:
            _, created = BitMartETHTrades.get_or_create(
                id=trade['id'],
                defaults={
                    'timestamp': trade['timestamp'],
                    'price': trade['price'],
                    'volume': trade['volume'],
                    'side': trade['side'],
                }
            )
        except Exception as e:
            # get_or_create already returns created=False for an existing
            # id, so an exception here is a real failure (bad row, locked
            # database, ...). Log it instead of silently dropping the
            # trade, and keep going so the batch stays best-effort.
            logger.warning(f"保存成交记录失败 {trade.get('id') if isinstance(trade, dict) else trade}: {e}")
            continue
        if created:
            new_count += 1
    return new_count
def collect_trades_realtime(self, duration_seconds: int = 3600, interval: float = 0.3):
    """Poll recent trades for a fixed duration, then build one-second candles.

    :param duration_seconds: how long to keep polling (default 1 hour)
    :param interval: pause between polls in seconds (default 0.3)
    :return: number of newly stored trade records
    """
    logger.info(f"{'='*60}")
    logger.info(f"开始实时采集成交记录")
    logger.info(f"时长: {duration_seconds}秒 ({duration_seconds/3600:.1f}小时)")
    logger.info(f"间隔: {interval}")
    logger.info(f"{'='*60}")

    end_time = time.time() + duration_seconds
    total_saved = 0
    batch_count = 0

    while time.time() < end_time:
        trades = self.get_trades(limit=100)
        if not trades:
            time.sleep(interval)
            continue

        saved = self.save_trades(trades)
        total_saved += saved
        batch_count += 1

        # Report progress on every 10th non-empty batch.
        if batch_count % 10 == 0:
            remaining = end_time - time.time()
            latest = trades[-1]
            ts_str = datetime.datetime.fromtimestamp(
                latest['timestamp']/1000
            ).strftime('%H:%M:%S')
            logger.info(f"[{ts_str}] 价格: {latest['price']:.2f} | "
                        f"本批新增: {saved} | 累计: {total_saved} | "
                        f"剩余: {remaining/60:.1f}分钟")
        time.sleep(interval)

    logger.success(f"采集完成!共新增 {total_saved} 条成交记录")

    # Aggregate the stored trades into one-second candles.
    logger.info("正在将成交记录聚合为秒级K线...")
    self.aggregate_trades_to_seconds()
    return total_saved
def aggregate_trades_to_seconds(self, start_ts: int = None, end_ts: int = None):
    """Aggregate stored trades into one-second candles and upsert them.

    :param start_ts: inclusive lower bound in milliseconds; ``None`` = no bound
    :param end_ts: inclusive upper bound in milliseconds; ``None`` = no bound
    :return: number of one-second candles written
    """
    # Order by timestamp, then by id. With the timestamp alone, trades
    # sharing a millisecond came back in an undefined order, which made
    # open/close nondeterministic.
    # NOTE(review): assumes trade ids increase over time - confirm.
    query = BitMartETHTrades.select().order_by(
        BitMartETHTrades.timestamp, BitMartETHTrades.id)
    # Compare with None so that a bound of 0 is still applied.
    if start_ts is not None:
        query = query.where(BitMartETHTrades.timestamp >= start_ts)
    if end_ts is not None:
        query = query.where(BitMartETHTrades.timestamp <= end_ts)

    # Group trades per second; keys are millisecond timestamps floored
    # to the second.
    second_data = {}
    trade_count = 0
    for trade in query:
        trade_count += 1
        second_ts = (trade.timestamp // 1000) * 1000
        bucket = second_data.get(second_ts)
        if bucket is None:
            second_data[second_ts] = {
                'open': trade.price,
                'high': trade.price,
                'low': trade.price,
                'close': trade.price,
                'volume': trade.volume,
                'trade_count': 1,
            }
        else:
            bucket['high'] = max(bucket['high'], trade.price)
            bucket['low'] = min(bucket['low'], trade.price)
            bucket['close'] = trade.price
            bucket['volume'] += trade.volume
            bucket['trade_count'] += 1

    # Upsert: insert the candle, or overwrite every column when the id
    # already exists.
    saved_count = 0
    for ts, ohlc in second_data.items():
        try:
            BitMartETHSecond.insert(
                id=ts,
                open=ohlc['open'],
                high=ohlc['high'],
                low=ohlc['low'],
                close=ohlc['close'],
                volume=ohlc['volume'],
                trade_count=ohlc['trade_count']
            ).on_conflict(
                conflict_target=[BitMartETHSecond.id],
                update={
                    BitMartETHSecond.open: ohlc['open'],
                    BitMartETHSecond.high: ohlc['high'],
                    BitMartETHSecond.low: ohlc['low'],
                    BitMartETHSecond.close: ohlc['close'],
                    BitMartETHSecond.volume: ohlc['volume'],
                    BitMartETHSecond.trade_count: ohlc['trade_count'],
                }
            ).execute()
            saved_count += 1
        except Exception as e:
            logger.error(f"保存秒级K线失败 {ts}: {e}")

    logger.success(f"聚合完成!{trade_count} 条成交记录 → {saved_count} 条秒级K线")
    return saved_count
def get_second_klines(self, start_ts: int = None, end_ts: int = None):
    """Read stored one-second candles, ordered by timestamp.

    :param start_ts: inclusive lower bound in milliseconds (ignored when falsy)
    :param end_ts: inclusive upper bound in milliseconds (ignored when falsy)
    :return: list of dicts with ``timestamp``, ``open``, ``high``, ``low``,
        ``close``, ``volume`` and ``trade_count``
    """
    query = BitMartETHSecond.select().order_by(BitMartETHSecond.id)
    if start_ts:
        query = query.where(BitMartETHSecond.id >= start_ts)
    if end_ts:
        query = query.where(BitMartETHSecond.id <= end_ts)

    columns = ('open', 'high', 'low', 'close', 'volume', 'trade_count')
    rows = []
    for candle in query:
        row = {'timestamp': candle.id}
        for column in columns:
            row[column] = getattr(candle, column)
        rows.append(row)
    return rows
def aggregate_trades_custom(self, interval_ms: int = 100, start_ts: int = None, end_ts: int = None):
    """Aggregate stored trades into candles of a custom millisecond interval.

    Nothing is written to the database; the candles are returned.

    :param interval_ms: bucket size in milliseconds (100, 500, 1000, ...);
        must be positive
    :param start_ts: inclusive lower bound in milliseconds; ``None`` = no bound
    :param end_ts: inclusive upper bound in milliseconds; ``None`` = no bound
    :return: list of dicts with ``timestamp``, ``datetime``, ``open``,
        ``high``, ``low``, ``close``, ``volume`` and ``trade_count``,
        sorted by ``timestamp``
    :raises ValueError: if ``interval_ms`` is not positive
    """
    # A zero interval previously failed with ZeroDivisionError in the
    # middle of the loop, and a negative one produced meaningless buckets.
    if interval_ms <= 0:
        raise ValueError(f"interval_ms must be positive, got {interval_ms}")

    # Order by timestamp, then by id, so trades sharing a millisecond are
    # processed in a stable order and open/close are deterministic.
    # NOTE(review): assumes trade ids increase over time - confirm.
    query = BitMartETHTrades.select().order_by(
        BitMartETHTrades.timestamp, BitMartETHTrades.id)
    # Compare with None so that a bound of 0 is still applied.
    if start_ts is not None:
        query = query.where(BitMartETHTrades.timestamp >= start_ts)
    if end_ts is not None:
        query = query.where(BitMartETHTrades.timestamp <= end_ts)

    # Group trades by the start of their interval.
    interval_data = {}
    trade_count = 0
    for trade in query:
        trade_count += 1
        interval_ts = (trade.timestamp // interval_ms) * interval_ms
        bucket = interval_data.get(interval_ts)
        if bucket is None:
            interval_data[interval_ts] = {
                'open': trade.price,
                'high': trade.price,
                'low': trade.price,
                'close': trade.price,
                'volume': trade.volume,
                'trade_count': 1,
            }
        else:
            bucket['high'] = max(bucket['high'], trade.price)
            bucket['low'] = min(bucket['low'], trade.price)
            bucket['close'] = trade.price
            bucket['volume'] += trade.volume
            bucket['trade_count'] += 1

    # Flatten to a list sorted by timestamp; ``datetime`` is local time
    # with millisecond precision.
    result = [
        {
            'timestamp': ts,
            'datetime': datetime.datetime.fromtimestamp(ts/1000).strftime('%Y-%m-%d %H:%M:%S.%f')[:-3],
            'open': ohlc['open'],
            'high': ohlc['high'],
            'low': ohlc['low'],
            'close': ohlc['close'],
            'volume': ohlc['volume'],
            'trade_count': ohlc['trade_count'],
        }
        for ts, ohlc in sorted(interval_data.items())
    ]

    # The candle count and the interval used to be printed back to back
    # ("5100ms"); separate them so the message is readable.
    logger.info(f"聚合完成: {trade_count} 条成交记录 → {len(result)} 条 {interval_ms}ms K线")
    return result
def get_raw_trades(self, start_ts: int = None, end_ts: int = None, limit: int = None):
    """Read stored raw trade records, ordered by timestamp.

    :param start_ts: inclusive lower bound in milliseconds; ``None`` = no bound
    :param end_ts: inclusive upper bound in milliseconds; ``None`` = no bound
    :param limit: maximum number of rows (ignored when falsy)
    :return: list of dicts with ``id``, ``timestamp``, ``datetime`` (local
        time, millisecond precision), ``price``, ``volume`` and ``side``
    """
    query = BitMartETHTrades.select().order_by(BitMartETHTrades.timestamp)
    # Compare with None so that a bound of 0 is still applied.
    if start_ts is not None:
        query = query.where(BitMartETHTrades.timestamp >= start_ts)
    if end_ts is not None:
        query = query.where(BitMartETHTrades.timestamp <= end_ts)
    if limit:
        query = query.limit(limit)

    return [{
        'id': t.id,
        'timestamp': t.timestamp,
        'datetime': datetime.datetime.fromtimestamp(t.timestamp/1000).strftime('%Y-%m-%d %H:%M:%S.%f')[:-3],
        'price': t.price,
        'volume': t.volume,
        # Both branches used to be the same empty string, so the direction
        # was lost. The labels follow the BitMartETHTrades.side comment
        # (1 = buy, otherwise sell).
        # NOTE(review): confirm the intended label text.
        'side': '买' if t.side == 1 else '卖'
    } for t in query]
def close(self):
    """Close the shared database connection if it is currently open."""
    if db.is_closed():
        return
    db.close()
if __name__ == '__main__':
    collector = BitMartMultiKlineCollector()
    try:
        # Show the current database statistics.
        collector.get_stats()
        # ============ Choose the task to run ============
        # Task 1: fetch candles for the 1-minute to 1-hour periods,
        # from 2010-01-01 up to now (resumes from what is already stored).
        collector.collect_from_date(
            start_date='2010-01-01',
            periods=[1, 3, 5, 15, 30, 60]  # all periods
        )
        # Task 2: collect second-level data (trade records) in real time.
        # Note: per the original author, second-level data can only be
        # collected live; history is not available.
        # collector.collect_trades_realtime(
        #     duration_seconds=3600,  # collect for 1 hour
        #     interval=0.3  # one request every 0.3 seconds
        # )
        # Task 3: aggregate the collected trades into one-second candles.
        # collector.aggregate_trades_to_seconds()
        # Show the statistics again.
        collector.get_stats()
    finally:
        collector.close()

989
训练AI策略_ETH合约.py Normal file
View File

@@ -0,0 +1,989 @@
"""
ETH 合约 AI 策略训练 + 回测
规则:
- 标的: ETH 合约 (BitMart 1分钟K线)
- 同一时间仅 1 个仓位 (多或空)
- 每笔固定: 100U 保证金 × 100 倍 = 10000U 名义价值
- 手续费 90% 返佣,目标: 月均净利 >= 1000 USDT
流程:
1. 训练: 使用 2020 年全年数据训练 LightGBM特征含多种可调参数指标
2. 回测: 使用 2021 年全年数据做严格样本外真实回测(模型从未见过 2021 数据)
3. 参数搜索在「2021 年月均净利」上选最佳并保存
"""
import datetime
import json
import sqlite3
import time as _time
from pathlib import Path
from collections import defaultdict
import numpy as np
import pandas as pd
import lightgbm as lgb
import warnings
warnings.filterwarnings('ignore')
# Constants: ETH contract, 100 USDT margin at 100x leverage.
MARGIN_PER_TRADE = 100.0  # USDT
LEVERAGE = 100
NOTIONAL_PER_TRADE = MARGIN_PER_TRADE * LEVERAGE  # 10000 USDT
TAKER_FEE_RATE = 0.0006
REBATE_RATE = 0.90
# Fees: the full fee is charged first and the rebate only arrives at 08:00
# the next day, so drawdown is measured on available funds.
FULL_FEE_PER_TRADE = NOTIONAL_PER_TRADE * TAKER_FEE_RATE * 2  # fee rate applied twice per trade
REBATE_PER_TRADE = FULL_FEE_PER_TRADE * REBATE_RATE
NET_FEE_PER_TRADE = FULL_FEE_PER_TRADE - REBATE_PER_TRADE  # net fee after rebate
# Capital and risk-control target.
INITIAL_BALANCE = 10000.0  # starting capital: 10000 USDT
MAX_DD_TARGET = 500.0  # target: max drawdown of available funds <= 500 USDT
# Training / out-of-sample backtest ranges (start inclusive, end exclusive).
TRAIN_START = '2020-01-01'
TRAIN_END_EXCLUSIVE = '2021-01-01'  # training data: [2020-01-01, 2021-01-01) = all of 2020
TEST_START = '2021-01-01'
TEST_END_EXCLUSIVE = '2022-01-01'  # backtest data: [2021-01-01, 2022-01-01) = all of 2021
def get_db_path():
    """Return the SQLite database path: ``models/database.db`` next to this file."""
    here = Path(__file__).parent
    return here.joinpath('models', 'database.db')
# Multi-period tables: same naming as the collector script 抓取多周期K线.py,
# i.e. one table called bitmart_eth_{suffix} per period.
PERIOD_TABLES = {
    suffix: f'bitmart_eth_{suffix}'
    for suffix in ('1m', '3m', '5m', '15m', '30m', '1h')
}
def _normalize_period(period) -> str:
"""5 -> '5m', 15 -> '15m', '5' -> '5m', '15m' -> '15m'"""
if isinstance(period, int):
return f'{period}m' if period != 60 else '1h'
s = str(period).strip().lower()
if s in PERIOD_TABLES:
return s
if s.endswith('m'):
return s
if s == '60' or s == '1h':
return '1h'
return f'{s}m' if s.isdigit() else s
# ==================== 数据加载 ====================
def load_klines(start_date: str = '2025-01-01', end_date: str = '2026-02-01', period: str = '1m'):
    """Read the K-lines of one period from SQLite for ``[start_date, end_date)``.

    Args:
        start_date: Inclusive start, formatted ``YYYY-MM-DD``.
        end_date: Exclusive end, formatted ``YYYY-MM-DD``.
        period: '1m'|'3m'|'5m'|'15m'|'30m'|'1h' or 1|3|5|15|30|60.

    Returns:
        DataFrame with columns ``ts, open, high, low, close`` ordered by ``ts``
        and indexed by ``datetime`` (``ts`` interpreted as epoch milliseconds).

    Raises:
        ValueError: the period does not map to a known table.
        FileNotFoundError: the database file is missing, or the table has no
            rows in the requested range.
    """
    period = _normalize_period(period)
    table = PERIOD_TABLES.get(period)
    if not table:
        raise ValueError(f"不支持的周期: {period},可选: {list(PERIOD_TABLES.keys())}")
    db = get_db_path()
    if not db.exists():
        raise FileNotFoundError(f"数据库不存在: {db},请先运行 抓取多周期K线.py 拉取数据")
    # NOTE(review): strptime() yields naive datetimes, and .timestamp() reads
    # them in the machine's local timezone, while the index below is built from
    # the raw epoch values. Confirm the collector writes ids the same way.
    start_ts = int(datetime.datetime.strptime(start_date, '%Y-%m-%d').timestamp()) * 1000
    end_ts = int(datetime.datetime.strptime(end_date, '%Y-%m-%d').timestamp()) * 1000
    conn = sqlite3.connect(str(db))
    try:
        # The table name comes from the PERIOD_TABLES whitelist, never from raw
        # user input; the range bounds are bound as SQL parameters.
        df = pd.read_sql_query(
            f"SELECT id as ts, open, high, low, close FROM {table} "
            "WHERE id >= ? AND id < ? ORDER BY id",
            conn, params=(start_ts, end_ts))
    finally:
        # Close the connection even when the query raises (e.g. missing table).
        conn.close()
    if len(df) == 0:
        raise FileNotFoundError(f"{table} 中无 {start_date}~{end_date} 数据,请先抓取该周期 K 线")
    df['datetime'] = pd.to_datetime(df['ts'], unit='ms')
    df.set_index('datetime', inplace=True)
    return df
# ==================== 指标参数(类布林带/均值回归,均可训练优化) ====================
def default_indicator_params():
    """Return a fresh dict of default indicator parameters.

    Every value can be overridden by the parameter search. The set covers
    Bollinger, Keltner, Donchian, Stochastic, CCI, Williams %R, RSI, ATR band,
    Z-Score, linear-regression channel, median band, percent rank, Chandelier,
    standard-deviation band and Elder ray.
    """
    params = dict(
        bb_period=20,
        bb_std=2.0,
        keltner_period=20,
        keltner_atr_mult=2.0,
        donchian_period=20,
        stoch_k_period=14,
        stoch_d_period=3,
        cci_period=20,
        willr_period=14,
        rsi_period=14,
        atr_band_period=20,
        atr_band_mult=2.0,
    )
    # Mean-reversion style additions.
    params.update(
        zscore_period=20,
        lr_period=20,
        lr_std_mult=2.0,
        median_band_period=20,
        pct_rank_period=20,
        chandelier_period=22,
        chandelier_mult=3.0,
        std_band_period=14,
        std_band_mult=2.0,
        elder_period=13,
    )
    return params
# ==================== 特征工程(所有带带/通道类指标参数可调) ====================
def add_features(df: pd.DataFrame, ind: dict = None) -> pd.DataFrame:
    """Add indicator and derived feature columns to ``df`` and return it.

    The frame is modified in place (columns are assigned directly on ``df``),
    so pass a copy when the caller's frame must stay untouched.

    Args:
        df: OHLC frame with ``open``/``high``/``low``/``close`` columns. The
            index must expose ``.hour`` (the time features read
            ``df.index.hour``).
        ind: Indicator parameters using the keys of
            ``default_indicator_params()``; that default set is used when
            ``None``.

    Returns:
        The same ``df`` object with the feature columns added. Ratios whose
        denominator is zero are left as NaN rather than inf.
    """
    if ind is None:
        ind = default_indicator_params()
    c = df['close']
    h = df['high']
    l = df['low']
    o = df['open']
    # Close with zeros masked out, used as a safe divisor for price ratios.
    cp = c.replace(0, np.nan)

    # --- Bollinger bands ---
    bp = ind['bb_period']
    bstd = ind['bb_std']
    mid = c.rolling(bp).mean()
    std = c.rolling(bp).std()
    df['bb_upper'] = mid + bstd * std
    df['bb_lower'] = mid - bstd * std
    df['bb_mid'] = mid
    df['bb_pct'] = (c - df['bb_lower']) / (df['bb_upper'] - df['bb_lower']).replace(0, np.nan)
    df['bb_width'] = (df['bb_upper'] - df['bb_lower']) / mid.replace(0, np.nan)

    # --- Keltner channel ---
    kp = ind['keltner_period']
    katr = ind['keltner_atr_mult']
    # True range; reused below by every ATR-based indicator.
    tr = pd.concat([h - l, (h - c.shift(1)).abs(), (l - c.shift(1)).abs()], axis=1).max(axis=1)
    atr_k = tr.rolling(kp).mean()
    k_mid = c.ewm(span=kp, adjust=False).mean()
    df['keltner_upper'] = k_mid + katr * atr_k
    df['keltner_lower'] = k_mid - katr * atr_k
    df['keltner_mid'] = k_mid
    df['keltner_pct'] = (c - df['keltner_lower']) / (df['keltner_upper'] - df['keltner_lower']).replace(0, np.nan)
    df['keltner_width'] = (df['keltner_upper'] - df['keltner_lower']) / k_mid.replace(0, np.nan)

    # --- Donchian channel ---
    dp = ind['donchian_period']
    du = h.rolling(dp).max()
    dd = l.rolling(dp).min()
    dm = (du + dd) / 2
    df['donchian_upper'] = du
    df['donchian_lower'] = dd
    df['donchian_mid'] = dm
    df['donchian_pct'] = (c - dd) / (du - dd).replace(0, np.nan)
    df['donchian_width'] = (du - dd) / dm.replace(0, np.nan)

    # --- Stochastic oscillator ---
    sk, sd = ind['stoch_k_period'], ind['stoch_d_period']
    low_k = l.rolling(sk).min()
    high_k = h.rolling(sk).max()
    df['stoch_k'] = (c - low_k) / (high_k - low_k).replace(0, np.nan) * 100
    df['stoch_d'] = df['stoch_k'].rolling(sd).mean()

    # --- CCI (scaled by the rolling std of the typical price) ---
    cci_p = ind['cci_period']
    typical = (h + l + c) / 3
    cci_m = typical.rolling(cci_p).mean()
    cci_s = typical.rolling(cci_p).std()
    df['cci'] = (typical - cci_m) / (0.015 * cci_s.replace(0, np.nan))

    # --- Williams %R ---
    wp = ind['willr_period']
    high_w = h.rolling(wp).max()
    low_w = l.rolling(wp).min()
    df['willr'] = -100 * (high_w - c) / (high_w - low_w).replace(0, np.nan)

    # --- RSI (simple rolling means of gains and losses) ---
    rp = ind['rsi_period']
    delta = c.diff()
    gain = delta.clip(lower=0)
    loss = (-delta).clip(lower=0)
    avg_gain = gain.rolling(rp).mean()
    avg_loss = loss.rolling(rp).mean()
    # A window without any loss gives NaN here, not RSI = 100.
    rs = avg_gain / avg_loss.replace(0, np.nan)
    df['rsi'] = 100 - 100 / (1 + rs)

    # --- ATR band (SMA mid line, width = ATR multiple) ---
    abp, abm = ind['atr_band_period'], ind['atr_band_mult']
    atr_ab = tr.rolling(abp).mean()
    ab_mid = c.rolling(abp).mean()
    df['atr_band_upper'] = ab_mid + abm * atr_ab
    df['atr_band_lower'] = ab_mid - abm * atr_ab
    df['atr_band_mid'] = ab_mid
    df['atr_band_pct'] = (c - df['atr_band_lower']) / (df['atr_band_upper'] - df['atr_band_lower']).replace(0, np.nan)
    df['atr_band_width'] = (df['atr_band_upper'] - df['atr_band_lower']) / ab_mid.replace(0, np.nan)

    # --- Z-Score (distance from the rolling mean in standard deviations) ---
    zp = ind['zscore_period']
    z_mid = c.rolling(zp).mean()
    z_std = c.rolling(zp).std()
    df['zscore'] = (c - z_mid) / z_std.replace(0, np.nan)
    df['zscore_abs'] = df['zscore'].abs()

    # --- Linear regression channel, vectorised ---
    lrp = ind['lr_period']
    lr_mult = ind['lr_std_mult']
    sum_x = lrp * (lrp - 1) / 2.0
    sum_x2 = lrp * (lrp - 1) * (2 * lrp - 1) / 6.0
    # x runs forward in time inside the window: x = 0 is the oldest bar and
    # x = lrp - 1 the current bar, so the bar lagged by j carries x = lrp - 1 - j.
    # (Weighting by the lag j instead would make the line below evaluate at
    # the oldest bar of the window rather than at the current one.)
    s_xy = sum(c.shift(j) * (lrp - 1 - j) for j in range(lrp))
    s_y = c.rolling(lrp).sum()
    denom = lrp * sum_x2 - sum_x * sum_x
    slope = (lrp * s_xy - sum_x * s_y) / (denom or 1e-10)
    intercept = s_y / lrp - slope * (lrp - 1) / 2.0
    # Fitted value at the current bar.
    lr_mid = intercept + slope * (lrp - 1)
    # Band half-width uses the rolling std of close, not the regression residual std.
    lr_resid = c.rolling(lrp).std()
    df['lr_upper'] = lr_mid + lr_mult * lr_resid
    df['lr_lower'] = lr_mid - lr_mult * lr_resid
    df['lr_mid'] = lr_mid
    df['lr_pct'] = (c - df['lr_lower']) / (df['lr_upper'] - df['lr_lower']).replace(0, np.nan)
    df['lr_width'] = (df['lr_upper'] - df['lr_lower']) / lr_mid.replace(0, np.nan)

    # --- Median band (price vs rolling median; band fixed at +/- 2 std) ---
    mdp = ind['median_band_period']
    med = c.rolling(mdp).median()
    mstd = c.rolling(mdp).std()
    df['median_band_mid'] = med
    df['price_vs_median'] = (c - med) / mstd.replace(0, np.nan)
    df['median_band_upper'] = med + 2 * mstd
    df['median_band_lower'] = med - 2 * mstd
    df['median_band_pct'] = (c - df['median_band_lower']) / (df['median_band_upper'] - df['median_band_lower']).replace(0, np.nan)

    # --- Percent rank (0..1): position of close inside the rolling min/max range ---
    prp = ind['pct_rank_period']
    rmin = c.rolling(prp).min()
    rmax = c.rolling(prp).max()
    df['pct_rank'] = (c - rmin) / (rmax - rmin).replace(0, np.nan)

    # --- Chandelier exit channel ---
    cep, cem = ind['chandelier_period'], ind['chandelier_mult']
    atr_ce = tr.rolling(cep).mean()
    high_max = h.rolling(cep).max()
    low_min = l.rolling(cep).min()
    df['chandelier_upper'] = high_max - cem * atr_ce
    df['chandelier_lower'] = low_min + cem * atr_ce
    df['chandelier_mid'] = (df['chandelier_upper'] + df['chandelier_lower']) / 2
    df['chandelier_pct'] = (c - df['chandelier_lower']) / (df['chandelier_upper'] - df['chandelier_lower']).replace(0, np.nan)
    df['chandelier_width'] = (df['chandelier_upper'] - df['chandelier_lower']) / df['chandelier_mid'].replace(0, np.nan)

    # --- Std band (same form as Bollinger, with its own period/multiplier) ---
    sbp, sbm = ind['std_band_period'], ind['std_band_mult']
    sb_mid = c.rolling(sbp).mean()
    sb_std = c.rolling(sbp).std()
    df['std_band_upper'] = sb_mid + sbm * sb_std
    df['std_band_lower'] = sb_mid - sbm * sb_std
    df['std_band_mid'] = sb_mid
    df['std_band_pct'] = (c - df['std_band_lower']) / (df['std_band_upper'] - df['std_band_lower']).replace(0, np.nan)
    df['std_band_width'] = (df['std_band_upper'] - df['std_band_lower']) / sb_mid.replace(0, np.nan)

    # --- Elder ray (high/low/close distance from an EMA, relative to close) ---
    ep = ind['elder_period']
    ema_elder = c.ewm(span=ep, adjust=False).mean()
    df['elder_bull'] = (h - ema_elder) / cp
    df['elder_bear'] = (l - ema_elder) / cp
    df['elder_dist'] = (c - ema_elder) / cp

    # --- EMAs (fixed periods, auxiliary) ---
    for p in [5, 8, 13, 21, 50, 120]:
        df[f'ema_{p}'] = c.ewm(span=p, adjust=False).mean()
    df['ema_fast_slow'] = (df['ema_8'] - df['ema_21']) / cp
    df['price_vs_ema120'] = (c - df['ema_120']) / cp
    df['ema8_slope'] = df['ema_8'].pct_change(5)

    # --- MACD (normalised by close) ---
    ema12 = c.ewm(span=12, adjust=False).mean()
    ema26 = c.ewm(span=26, adjust=False).mean()
    df['macd'] = (ema12 - ema26) / cp
    df['macd_signal'] = df['macd'].ewm(span=9, adjust=False).mean()
    df['macd_hist'] = df['macd'] - df['macd_signal']

    # --- ATR relative to price ---
    df['atr_14'] = tr.rolling(14).mean()
    df['atr_pct'] = df['atr_14'] / cp

    # --- Momentum and volatility ---
    for p in [1, 3, 5, 10, 20]:
        df[f'ret_{p}'] = c.pct_change(p)
    df['vol_5'] = c.pct_change().rolling(5).std()
    df['vol_20'] = c.pct_change().rolling(20).std()

    # --- Candle shape and time of day ---
    body = (c - o).abs()
    df['body_pct'] = body / cp
    df['price_position_20'] = (c - l.rolling(20).min()) / (h.rolling(20).max() - l.rolling(20).min()).replace(0, np.nan)
    df['hour_sin'] = np.sin(2 * np.pi * df.index.hour / 24)
    df['hour_cos'] = np.cos(2 * np.pi * df.index.hour / 24)
    return df
def get_feature_cols(df: pd.DataFrame):
    """Return the names of the derived feature columns of ``df``.

    Raw prices, the label and the raw upper/lower/mid lines of every channel
    are left out; of the remaining columns only those with a
    float64/float32/int64/int32 dtype are kept, in column order.
    """
    raw_columns = {
        'ts', 'open', 'high', 'low', 'close', 'label',
        'bb_upper', 'bb_lower', 'bb_mid', 'atr_14',
        'keltner_upper', 'keltner_lower', 'keltner_mid',
        'donchian_upper', 'donchian_lower', 'donchian_mid',
        'atr_band_upper', 'atr_band_lower', 'atr_band_mid',
        'lr_upper', 'lr_lower', 'lr_mid',
        'median_band_upper', 'median_band_lower', 'median_band_mid',
        'chandelier_upper', 'chandelier_lower', 'chandelier_mid',
        'std_band_upper', 'std_band_lower', 'std_band_mid',
        'ema_5', 'ema_8', 'ema_13', 'ema_21', 'ema_50', 'ema_120',
    }
    numeric_dtypes = ('float64', 'float32', 'int64', 'int32')
    selected = []
    for name in df.columns:
        if name in raw_columns:
            continue
        if df[name].dtype in numeric_dtypes:
            selected.append(name)
    return selected
# ==================== 标签 ====================
def add_labels(df: pd.DataFrame, forward_bars: int = 10, threshold: float = 0.002) -> pd.DataFrame:
    """Return a copy of ``df`` with a ``label`` column from the forward return.

    The return over the next ``forward_bars`` bars is compared to
    ``threshold``: above it -> 1 (long), below ``-threshold`` -> -1 (short),
    otherwise 0. Rows whose forward return is unknown (the last
    ``forward_bars`` rows, or rows with a missing close) get NaN instead of
    being presented as neutral, so that ``dropna(subset=[..., 'label'])``
    excludes them from training. The column is therefore float-typed.

    Args:
        df: Frame with a ``close`` column; it is not modified.
        forward_bars: Look-ahead horizon in bars.
        threshold: Minimum absolute forward return for a directional label.
    """
    future_ret = df['close'].shift(-forward_bars) / df['close'] - 1
    df = df.copy()
    df['label'] = 0.0
    df.loc[future_ret > threshold, 'label'] = 1.0
    df.loc[future_ret < -threshold, 'label'] = -1.0
    # No observable future price -> no label.
    df.loc[future_ret.isna(), 'label'] = np.nan
    return df
# ==================== 滚动训练 ====================
def train_predict_walkforward(
    df: pd.DataFrame,
    feature_cols: list,
    train_months: int = 3,
    lgb_rounds: int = 250,
):
    """Walk-forward training: fit on the previous ``train_months`` months, predict the next one.

    Args:
        df: Frame with the feature columns and a ``label`` column in
            {-1, 0, 1}; the index must support ``to_period('M')``.
        feature_cols: Feature column names fed to the model.
        train_months: Number of calendar months in each training window.
        lgb_rounds: Boosting rounds per model.

    Returns:
        ``(proba_long, proba_short, last_model)``. Both series share ``df``'s
        index and stay 0.0 wherever no prediction was made (warm-up months,
        rows with missing features, or windows with fewer than 1000 training
        rows / 100 test rows). ``last_model`` is the most recently trained
        model, or ``None`` when no window was trained.
    """
    frame = df.copy()
    frame['month'] = frame.index.to_period('M')
    month_list = sorted(frame['month'].unique())
    long_proba = pd.Series(0.0, index=frame.index)
    short_proba = pd.Series(0.0, index=frame.index)
    newest_model = None

    booster_params = {
        'objective': 'multiclass',
        'num_class': 3,
        'metric': 'multi_logloss',
        'learning_rate': 0.05,
        'num_leaves': 31,
        'max_depth': 6,
        'min_child_samples': 50,
        'subsample': 0.8,
        'colsample_bytree': 0.8,
        'reg_alpha': 0.1,
        'reg_lambda': 0.1,
        'verbose': -1,
        'n_jobs': -1,
        'seed': 42,
    }
    fit_required = feature_cols + ['label']

    for pos in range(train_months, len(month_list)):
        target_month = month_list[pos]
        window_start = month_list[pos - train_months]
        month_col = frame['month']

        in_window = (month_col >= window_start) & (month_col < target_month)
        fit_rows = frame[in_window].dropna(subset=fit_required)
        score_rows = frame[month_col == target_month].dropna(subset=feature_cols)
        # Skip windows that are too thin to train on or to evaluate.
        if len(fit_rows) < 1000 or len(score_rows) < 100:
            continue

        fit_x = fit_rows[feature_cols].values
        # Shift labels -1/0/1 to class ids 0/1/2.
        fit_y = (fit_rows['label'].values + 1).astype(int)
        score_x = score_rows[feature_cols].values

        dataset = lgb.Dataset(fit_x, label=fit_y)
        newest_model = lgb.train(dict(booster_params), dataset, num_boost_round=lgb_rounds)

        # Columns of the (n, 3) output: [P(short), P(neutral), P(long)].
        class_proba = newest_model.predict(score_x)
        short_proba.loc[score_rows.index] = class_proba[:, 0]
        long_proba.loc[score_rows.index] = class_proba[:, 2]

    return long_proba, short_proba, newest_model
def train_on_period_predict_on_other(
    df_train: pd.DataFrame,
    df_test: pd.DataFrame,
    feature_cols: list,
    forward_bars: int,
    label_threshold: float,
    lgb_rounds: int = 250,
):
    """Fit one model on ``df_train`` and predict on ``df_test``.

    ``df_train`` must contain the ``label`` column; ``df_test`` only needs
    ``feature_cols``. ``forward_bars`` and ``label_threshold`` are accepted
    but not read by this function.

    Returns:
        ``(proba_long, proba_short, model)``. Both series are indexed like
        ``df_test`` and stay 0.0 on rows with missing features. Returns
        ``(None, None, None)`` when fewer than 2000 usable training rows or
        fewer than 100 usable test rows remain after dropping NaNs.
    """
    nothing = (None, None, None)

    fit_rows = df_train.dropna(subset=feature_cols + ['label'])
    if len(fit_rows) < 2000:
        return nothing
    fit_x = fit_rows[feature_cols].values
    # Shift labels -1/0/1 to class ids 0/1/2.
    fit_y = (fit_rows['label'].values + 1).astype(int)

    score_rows = df_test.dropna(subset=feature_cols)
    if len(score_rows) < 100:
        return nothing
    score_x = score_rows[feature_cols].values

    booster_params = dict(
        objective='multiclass',
        num_class=3,
        metric='multi_logloss',
        learning_rate=0.05,
        num_leaves=31,
        max_depth=6,
        min_child_samples=50,
        subsample=0.8,
        colsample_bytree=0.8,
        reg_alpha=0.1,
        reg_lambda=0.1,
        verbose=-1,
        n_jobs=-1,
        seed=42,
    )
    model = lgb.train(booster_params, lgb.Dataset(fit_x, label=fit_y), num_boost_round=lgb_rounds)

    # Columns of the output: [P(short), P(neutral), P(long)].
    class_proba = model.predict(score_x)
    long_proba = pd.Series(0.0, index=df_test.index)
    short_proba = pd.Series(0.0, index=df_test.index)
    short_proba.loc[score_rows.index] = class_proba[:, 0]
    long_proba.loc[score_rows.index] = class_proba[:, 2]
    return long_proba, short_proba, model
# ==================== 回测单仓位、100U×100倍 ====================
def backtest(
    df: pd.DataFrame,
    proba_long: pd.Series,
    proba_short: pd.Series,
    notional: float = NOTIONAL_PER_TRADE,
    prob_threshold: float = 0.45,
    min_hold_seconds: int = 180,
    max_hold_seconds: int = 1800,
    sl_pct: float = 0.004,
    tp_pct: float = 0.006,
) -> list:
    """Single-position backtest on bar closes.

    At most one position is open at a time and every entry uses the full
    ``notional``. All entries and exits are priced at the bar's close.

    Exit rules, checked in this order on every bar while a position is open:
      1. hard stop: loss >= 1.5 * ``sl_pct`` (ignores the minimum hold);
      2. stop loss / take profit: only after ``min_hold_seconds``;
      3. timeout: held for ``max_hold_seconds`` or longer;
      4. model reversal: opposite-side probability > ``prob_threshold`` + 0.05.
    After exits 1-3 no new position is opened on the same bar; after a
    reversal exit the entry rule is evaluated immediately on that bar.

    Args:
        df: Frame with a ``close`` column and a datetime-like index.
        proba_long: Long probabilities, aligned to ``df`` by position.
        proba_short: Short probabilities, aligned to ``df`` by position.
        notional: Position size in USDT used to convert returns to PnL.
        prob_threshold: Minimum probability required to open a position.
        min_hold_seconds: Holding time before stop loss / take profit apply.
        max_hold_seconds: Holding time after which the position is closed.
        sl_pct: Stop-loss distance as a fraction of the entry price.
        tp_pct: Take-profit distance as a fraction of the entry price.

    Returns:
        List of tuples ``(side, open_price, close_price, pnl_usdt, hold_sec,
        reason, open_time, close_time)`` with ``side`` 1 (long) or -1 (short).
        PnL is before fees. A position still open on the last bar is closed
        there with reason '结束'.
    """
    # Pull the columns out once: per-bar Series.iloc access dominates the
    # runtime on minute data spanning a full year.
    closes = df['close'].to_numpy()
    long_p = proba_long.to_numpy()
    short_p = proba_short.to_numpy()
    stamps = df.index
    hard_stop = sl_pct * 1.5
    reverse_threshold = prob_threshold + 0.05

    pos = 0
    open_price = 0.0
    open_time = None
    trades = []

    for i in range(len(df)):
        dt = stamps[i]
        price = closes[i]
        pl = long_p[i]
        ps = short_p[i]

        if pos != 0:
            if pos == 1:
                pnl_pct = (price - open_price) / open_price
            else:
                pnl_pct = (open_price - price) / open_price
            hold_sec = (dt - open_time).total_seconds()
            matured = hold_sec >= min_hold_seconds

            if -pnl_pct >= hard_stop:
                reason = '硬止损'
            elif matured and -pnl_pct >= sl_pct:
                reason = '止损'
            elif matured and pnl_pct >= tp_pct:
                reason = '止盈'
            elif hold_sec >= max_hold_seconds:
                reason = '超时'
            else:
                reason = None

            if reason is not None:
                trades.append((pos, open_price, price, notional * pnl_pct, hold_sec, reason, open_time, dt))
                pos = 0
                continue  # risk/time exits never re-enter on the same bar

            # Model flipped against the position: close, then fall through to
            # the entry rule so the opposite side can open on this bar.
            opposite = ps if pos == 1 else pl
            if opposite > reverse_threshold:
                trades.append((pos, open_price, price, notional * pnl_pct, hold_sec, 'AI反转', open_time, dt))
                pos = 0

        if pos == 0:
            if pl > prob_threshold and pl > ps:
                pos = 1
                open_price = price
                open_time = dt
            elif ps > prob_threshold and ps > pl:
                pos = -1
                open_price = price
                open_time = dt

    # Force-close whatever is still open at the end of the data.
    if pos != 0:
        price = closes[-1]
        dt = stamps[-1]
        pnl_pct = (price - open_price) / open_price if pos == 1 else (open_price - price) / open_price
        hold_sec = (dt - open_time).total_seconds()
        trades.append((pos, open_price, price, notional * pnl_pct, hold_sec, '结束', open_time, dt))
    return trades
# ==================== 结果分析 ====================
def analyze_trades(trades: list, notional: float = NOTIONAL_PER_TRADE, initial_balance: float = INITIAL_BALANCE) -> dict:
    """Summarise backtest trades: net profit, win rate, drawdowns, monthly net.

    Two drawdowns are reported:
      * ``dd``: on cumulative PnL with the net fee charged per trade;
      * ``dd_available``: on available funds, where the full fee is charged at
        close and the rebate is credited later. The rebate time is simplified
        to close time + 1 day (not 08:00 of the next day), so this figure is
        the more conservative one.

    Args:
        trades: Tuples as produced by ``backtest``; index 3 is the PnL in
            USDT and index 7 the close time.
        notional: Accepted for interface compatibility; not used here.
        initial_balance: Starting balance for the available-funds curve.

    Returns:
        Dict with keys ``n, net, wr, dd, dd_available, total_pnl, avg_pnl,
        monthly_net, months, monthly_detail``. ``monthly_net`` is the average
        over months that contain at least one closed trade. The same keys are
        returned, zeroed, when ``trades`` is empty.
    """
    if not trades:
        # Same schema as the normal result so callers can index it blindly.
        return {
            'n': 0, 'net': 0.0, 'wr': 0.0, 'dd': 0.0, 'dd_available': 0.0,
            'total_pnl': 0.0, 'avg_pnl': 0.0, 'monthly_net': 0.0, 'months': 0,
            'monthly_detail': {},
        }

    n = len(trades)
    total_pnl = sum(t[3] for t in trades)
    net = total_pnl - NET_FEE_PER_TRADE * n
    wins = sum(1 for t in trades if t[3] > 0)
    wr = wins / n * 100

    # Simple drawdown: net fee deducted per trade, rebate delay ignored.
    cum = 0.0
    peak = 0.0
    dd = 0.0
    for t in trades:
        cum += t[3] - NET_FEE_PER_TRADE
        peak = max(peak, cum)
        dd = max(dd, peak - cum)

    # Available-funds drawdown: full fee at close, rebate one day later.
    events = []  # (datetime, delta_balance)
    for t in trades:
        close_time = t[7]
        events.append((close_time, t[3] - FULL_FEE_PER_TRADE))
        try:
            rebate_time = close_time + datetime.timedelta(days=1)
        except (TypeError, ValueError, OverflowError):
            # Best effort: if the close time cannot be shifted, credit the
            # rebate at the close time itself.
            rebate_time = close_time
        events.append((rebate_time, REBATE_PER_TRADE))
    events.sort(key=lambda x: x[0])

    balance = initial_balance
    peak_bal = balance
    dd_available = 0.0
    for _, delta in events:
        balance += delta
        peak_bal = max(peak_bal, balance)
        dd_available = max(dd_available, peak_bal - balance)

    # Net profit grouped by the month in which each trade was closed.
    monthly_net = defaultdict(float)
    for t in trades:
        close_time = t[7]
        month_key = close_time.strftime('%Y-%m') if hasattr(close_time, 'strftime') else str(close_time)[:7]
        monthly_net[month_key] += t[3] - NET_FEE_PER_TRADE
    num_months = len(monthly_net) or 1
    avg_monthly_net = sum(monthly_net.values()) / num_months

    return {
        'n': n,
        'net': net,
        'wr': wr,
        'dd': dd,
        'dd_available': dd_available,
        'total_pnl': total_pnl,
        'avg_pnl': net / n,
        'monthly_net': avg_monthly_net,
        'months': num_months,
        'monthly_detail': dict(monthly_net),
    }
TARGET_MONTHLY_NET = 1000.0 # target average monthly net profit (USDT)
def print_report(trades: list, label: str = ""):
    """Print a backtest report for ``trades``.

    Shows starting capital, trade count, total net profit, win rate, the
    available-funds drawdown and the average monthly net profit, each target
    marked as met or missed, plus the number of exits per reason. Prints a
    single "no trades" line when ``trades`` is empty.
    """
    if not trades:
        print(f" [{label}] 无交易", flush=True)
        return

    stats = analyze_trades(trades)

    # Count exits per reason, keeping first-seen order.
    reason_counts = {}
    for trade in trades:
        reason_counts[trade[5]] = reason_counts.get(trade[5], 0) + 1

    dd_text = '达标' if stats['dd_available'] <= MAX_DD_TARGET else '未达标'
    monthly_text = '达标' if stats['monthly_net'] >= TARGET_MONTHLY_NET else '未达标'

    print(f"\n === {label} ===", flush=True)
    print(f" 本金: {INITIAL_BALANCE:.0f} USDT | 交易: {stats['n']} 笔 | 总净利: {stats['net']:+.2f} USDT | 胜率: {stats['wr']:.1f}%", flush=True)
    print(f" 可用资金最大回撤: {stats['dd_available']:.2f} USDT (返佣次日到账) | 目标 ≤{MAX_DD_TARGET:.0f}U: {dd_text}", flush=True)
    print(f" 月均净利: {stats['monthly_net']:+.2f} USDT (共 {stats['months']} 个月) | 目标 ≥{TARGET_MONTHLY_NET:.0f}U/月: {monthly_text}", flush=True)
    print(f" 平仓原因: {reason_counts}", flush=True)
# ==================== 模型保存 ====================
def save_model_and_params(model, feature_cols: list, params: dict, path: Path = None):
    """Write the model, its feature list and the strategy parameters to a directory.

    Args:
        model: Object exposing ``save_model(filename)``.
        feature_cols: Feature names, written to ``feature_cols.json``.
        params: Strategy parameters, written to ``strategy_params.json``.
        path: Target directory, created if missing. Defaults to
            ``models/eth_ai_strategy`` next to this file.
    """
    if path is None:
        target_dir = Path(__file__).parent / 'models' / 'eth_ai_strategy'
    else:
        target_dir = Path(path)
    target_dir.mkdir(parents=True, exist_ok=True)

    model.save_model(str(target_dir / 'model.txt'))
    json_outputs = (
        ('feature_cols.json', feature_cols),
        ('strategy_params.json', params),
    )
    for filename, payload in json_outputs:
        with open(target_dir / filename, 'w', encoding='utf-8') as fh:
            json.dump(payload, fh, ensure_ascii=False, indent=2)
    print(f" 模型与参数已保存至: {target_dir}", flush=True)
# ==================== 主流程2020 训练 + 2021 真实回测 ====================
def run_single(
    df_train: pd.DataFrame,
    df_test: pd.DataFrame,
    ind_params: dict,
    forward_bars: int,
    label_threshold: float,
    prob_threshold: float,
    sl_pct: float,
    tp_pct: float,
    min_hold: int,
    max_hold: int,
):
    """Train on ``df_train``, then predict and backtest on ``df_test``.

    Features are computed independently on copies of both frames with the same
    indicator parameters; labels are added to the training copy only.

    Returns:
        ``(trades, model, feature_cols)`` where ``trades`` come from the
        backtest on ``df_test``, or ``([], None, None)`` when no model could
        be trained.
    """
    train_frame = add_features(df_train.copy(), ind=ind_params)
    test_frame = add_features(df_test.copy(), ind=ind_params)
    # Feature names are taken before the label column exists.
    feature_cols = get_feature_cols(train_frame)
    train_frame = add_labels(train_frame, forward_bars=forward_bars, threshold=label_threshold)

    proba_long, proba_short, model = train_on_period_predict_on_other(
        train_frame, test_frame, feature_cols, forward_bars, label_threshold,
    )
    if model is None:
        return [], None, None

    backtest_options = dict(
        notional=NOTIONAL_PER_TRADE,
        prob_threshold=prob_threshold,
        min_hold_seconds=min_hold,
        max_hold_seconds=max_hold,
        sl_pct=sl_pct,
        tp_pct=tp_pct,
    )
    trades = backtest(test_frame, proba_long, proba_short, **backtest_options)
    return trades, model, feature_cols
def _grid_indicator_and_strategy(max_configs: int = 220, seed: int = 42):
    """Randomly sample ``max_configs`` combinations of indicator and strategy parameters.

    Each configuration draws one option per indicator family plus one label
    setting, one stop-loss/take-profit pair and one probability threshold,
    using a generator seeded with ``seed`` (so the list is reproducible).
    Duplicated combinations are not filtered out.

    Returns:
        List of dicts with keys ``ind, forward_bars, label_threshold, sl_pct,
        tp_pct, prob_threshold, min_hold, max_hold``.
    """
    rng = np.random.default_rng(seed)
    defaults = default_indicator_params()

    def pick(options):
        # One uniform draw per call; the call order below fixes the sequence.
        return options[rng.integers(0, len(options))]

    # Bollinger-like families
    bb_opts = [(20, 2.0), (20, 2.5), (30, 2.0), (15, 2.0), (14, 2.0)]
    keltner_opts = [(20, 2.0), (20, 2.5), (14, 2.0), (22, 2.5)]
    donchian_opts = [14, 20, 30]
    stoch_opts = [(14, 3), (10, 3), (20, 5)]
    cci_opts = [14, 20]
    rsi_opts = [7, 14, 21]
    atr_band_opts = [(20, 2.0), (14, 2.0), (22, 2.5)]
    # Mean-reversion additions
    zscore_opts = [14, 20, 30]
    lr_opts = [(20, 2.0), (14, 2.0), (30, 2.5)]
    median_band_opts = [14, 20, 30]
    pct_rank_opts = [14, 20, 30]
    chandelier_opts = [(22, 3.0), (14, 2.5), (30, 3.5)]
    std_band_opts = [(14, 2.0), (20, 2.0), (14, 2.5)]
    elder_opts = [10, 13, 20]
    # Strategy: includes conservative choices (tight stops, high confidence)
    # meant to keep drawdown low.
    label_opts = [(8, 0.0015), (10, 0.002), (10, 0.003), (15, 0.002), (20, 0.003)]
    sl_tp_opts = [
        (0.002, 0.004), (0.0025, 0.005), (0.003, 0.005), (0.003, 0.006),
        (0.004, 0.006), (0.004, 0.008), (0.005, 0.008), (0.005, 0.010),
    ]
    prob_opts = [0.42, 0.45, 0.48, 0.50]

    configs = []
    for _ in range(max_configs):
        bb_period, bb_std = pick(bb_opts)
        keltner_period, keltner_mult = pick(keltner_opts)
        donchian_period = pick(donchian_opts)
        stoch_k, stoch_d = pick(stoch_opts)
        cci_period = pick(cci_opts)
        rsi_period = pick(rsi_opts)
        atr_band_period, atr_band_mult = pick(atr_band_opts)
        zscore_period = pick(zscore_opts)
        lr_period, lr_mult = pick(lr_opts)
        median_period = pick(median_band_opts)
        pct_rank_period = pick(pct_rank_opts)
        chandelier_period, chandelier_mult = pick(chandelier_opts)
        std_band_period, std_band_mult = pick(std_band_opts)
        elder_period = pick(elder_opts)
        forward_bars, label_threshold = pick(label_opts)
        stop_loss, take_profit = pick(sl_tp_opts)
        prob_threshold = pick(prob_opts)

        indicator_params = dict(defaults)
        indicator_params.update(
            bb_period=bb_period, bb_std=bb_std,
            keltner_period=keltner_period, keltner_atr_mult=keltner_mult,
            donchian_period=donchian_period,
            stoch_k_period=stoch_k, stoch_d_period=stoch_d,
            cci_period=cci_period, rsi_period=rsi_period,
            atr_band_period=atr_band_period, atr_band_mult=atr_band_mult,
            zscore_period=zscore_period,
            lr_period=lr_period, lr_std_mult=lr_mult,
            median_band_period=median_period,
            pct_rank_period=pct_rank_period,
            chandelier_period=chandelier_period, chandelier_mult=chandelier_mult,
            std_band_period=std_band_period, std_band_mult=std_band_mult,
            elder_period=elder_period,
        )
        configs.append({
            'ind': indicator_params,
            'forward_bars': forward_bars,
            'label_threshold': label_threshold,
            'sl_pct': stop_loss,
            'tp_pct': take_profit,
            'prob_threshold': prob_threshold,
            'min_hold': 180,
            'max_hold': 1800,
        })
    return configs
def run_cycle_compare(periods: list, do_save_best: bool = True):
    """Compare K-line periods using one fixed set of default parameters.

    For every period the model is trained on the training range and backtested
    on the test range; a comparison table is printed and the best period is
    chosen by, in order: drawdown target met, monthly-profit target met,
    higher average monthly net, lower available-funds drawdown.

    Args:
        periods: e.g. ``[5, 15]`` for 5m and 15m, or ``[1, 3, 5, 15, 30, 60]``
            for every period.
        do_save_best: Save the best period's model and parameters when True.

    Periods whose data cannot be loaded, or that yield no trades, are reported
    and skipped. Returns ``None``.
    """
    t0 = _time.time()
    period_labels = [_normalize_period(p) for p in periods]
    print("=" * 60, flush=True)
    print(" ETH 合约 — 多周期回测对比 (2020 训练 / 2021 回测)", flush=True)
    print(f" 参与周期: {', '.join(period_labels)}", flush=True)
    print("=" * 60, flush=True)

    ind = default_indicator_params()
    # Single source for the strategy settings: used for the run and for the
    # saved parameter file, so the two cannot drift apart.
    strategy = {
        'forward_bars': 10, 'label_threshold': 0.002,
        'prob_threshold': 0.45, 'sl_pct': 0.004, 'tp_pct': 0.006,
        'min_hold': 180, 'max_hold': 1800,
    }

    results = []
    for i, period in enumerate(period_labels):
        print(f"\n [{i+1}/{len(period_labels)}] 周期 {period} ...", flush=True)
        try:
            df_2020 = load_klines(TRAIN_START, TRAIN_END_EXCLUSIVE, period=period)
            df_2021 = load_klines(TEST_START, TEST_END_EXCLUSIVE, period=period)
        except Exception as e:
            # Any load failure (missing database, table or rows) skips this period.
            print(f" 跳过 {period}: {e}", flush=True)
            results.append({'period': period, 'ok': False, 'error': str(e)})
            continue
        print(f" 数据: 2020 {len(df_2020):,} 根 | 2021 {len(df_2021):,}", flush=True)

        trades, model, feature_cols = run_single(df_2020, df_2021, ind, **strategy)
        if not trades or model is None:
            print(f" {period}: 无有效交易", flush=True)
            results.append({'period': period, 'ok': False})
            continue

        r = analyze_trades(trades)
        results.append({
            'period': period,
            'ok': True,
            'n': r['n'],
            'net': r['net'],
            'monthly_net': r['monthly_net'],
            'wr': r['wr'],
            'dd': r['dd'],
            'dd_available': r['dd_available'],
            'trades': trades,
            'model': model,
            'feature_cols': feature_cols,
        })
        print(f" {period}: 交易 {r['n']} 笔 | 月均 {r['monthly_net']:+.2f} USDT | 可用回撤 {r['dd_available']:.2f} USDT | 胜率 {r['wr']:.1f}%", flush=True)

    # Comparison table
    ok_results = [x for x in results if x.get('ok')]
    print("\n" + "=" * 60, flush=True)
    print(" 多周期对比结果 (2021 年样本外)", flush=True)
    print("=" * 60, flush=True)
    if not ok_results:
        print(" 无有效回测结果。", flush=True)
        return
    print(f" {'周期':<6} {'交易数':>8} {'总净利':>12} {'月均净利':>12} {'可用回撤':>10} {'胜率':>8}", flush=True)
    print(" " + "-" * 62, flush=True)
    for x in ok_results:
        print(f" {x['period']:<6} {x['n']:>8} {x['net']:>+12.2f} {x['monthly_net']:>+12.2f} {x['dd_available']:>10.2f} {x['wr']:>7.1f}%", flush=True)

    # Prefer drawdown <= target and monthly net >= target; otherwise the
    # highest monthly net, with the smaller drawdown breaking ties.
    def _cmp_cycle(a):
        dd_ok = a['dd_available'] <= MAX_DD_TARGET
        mo_ok = a['monthly_net'] >= TARGET_MONTHLY_NET
        return (dd_ok, mo_ok, a['monthly_net'], -a['dd_available'])

    best = max(ok_results, key=_cmp_cycle)
    print(" " + "-" * 62, flush=True)
    b_dd = best['dd_available']
    print(f" 最佳周期: {best['period']} (月均 {best['monthly_net']:+.2f} USDT, 可用回撤 {b_dd:.2f} USDT)", flush=True)
    if b_dd <= MAX_DD_TARGET and best['monthly_net'] >= TARGET_MONTHLY_NET:
        print(f" 双目标达标: 回撤≤{MAX_DD_TARGET:.0f}U 且 月盈≥{TARGET_MONTHLY_NET:.0f}U", flush=True)

    if do_save_best and best.get('model') is not None:
        flat_params = {
            **ind,
            'forward_bars': strategy['forward_bars'],
            'label_threshold': strategy['label_threshold'],
            'prob_threshold': strategy['prob_threshold'],
            'sl_pct': strategy['sl_pct'],
            'tp_pct': strategy['tp_pct'],
            'min_hold_seconds': strategy['min_hold'],
            'max_hold_seconds': strategy['max_hold'],
            'margin_per_trade': MARGIN_PER_TRADE, 'leverage': LEVERAGE,
            'notional': NOTIONAL_PER_TRADE, 'rebate_rate': REBATE_RATE,
            'train_period': f'{TRAIN_START} ~ {TRAIN_END_EXCLUSIVE}',
            'test_period': f'{TEST_START} ~ {TEST_END_EXCLUSIVE}',
            'kline_period': best['period'],
            'initial_balance': INITIAL_BALANCE,
            'max_dd_target': MAX_DD_TARGET,
        }
        save_model_and_params(best['model'], best.get('feature_cols', []), flat_params)
        print(f" 已保存最佳周期 {best['period']} 的模型与参数。", flush=True)

    print(f"\n 总耗时: {_time.time() - t0:.1f}s", flush=True)
    print("=" * 60, flush=True)
def main(do_grid_search: bool = True, period: str = '1m'):
    """Train on the training range, backtest on the test range, save the best setup.

    Args:
        do_grid_search: When True, evaluate every configuration returned by
            ``_grid_indicator_and_strategy()`` and keep the best one; when
            False, run the default parameters once.
        period: K-line period such as '1m', '5m', '15m'.

    The best configuration is chosen by, in order: drawdown target met,
    monthly-profit target met, higher average monthly net, lower
    available-funds drawdown. Note that with the grid search the winner is
    selected on the same test-range backtest that is then reported, so the
    reported figures carry selection bias. Returns ``None``.
    """
    period = _normalize_period(period)
    t0 = _time.time()
    print("=" * 60, flush=True)
    print(f" ETH 合约 AI 策略 — 2020 训练 / 2021 真实回测 | K线周期 {period} | 单仓 100U×100倍 | 90% 返佣", flush=True)
    print(f" 本金: {INITIAL_BALANCE:.0f} USDT | 返佣次日 8 点到账,回撤按可用资金计算", flush=True)
    print(f" 目标: 可用资金最大回撤 ≤ {MAX_DD_TARGET:.0f} USDT月均净利 ≥ {TARGET_MONTHLY_NET:.0f} USDT", flush=True)
    print("=" * 60, flush=True)

    print(f"\n[1/4] 加载 K 线 (周期 {period}, 训练 2020 / 回测 2021)...", flush=True)
    df_2020 = load_klines(TRAIN_START, TRAIN_END_EXCLUSIVE, period=period)
    df_2021 = load_klines(TEST_START, TEST_END_EXCLUSIVE, period=period)
    print(f" 2020 训练: {len(df_2020):,} 根 | 2021 回测: {len(df_2021):,}", flush=True)
    if len(df_2020) < 10000:
        print(" 警告: 2020 年数据不足 10000 根,请先运行 抓取多周期K线.py 拉取 2020 年数据。", flush=True)
    if len(df_2021) < 1000:
        print(" 警告: 2021 年数据不足,回测结果可能不可靠。", flush=True)

    # Ranking key: drawdown target first, then monthly target, then monthly
    # net, then the smaller drawdown.
    def _score(r):
        dd_ok = r['dd_available'] <= MAX_DD_TARGET
        monthly_ok = r['monthly_net'] >= TARGET_MONTHLY_NET
        return (dd_ok, monthly_ok, r['monthly_net'], -r['dd_available'])

    # Parameter record stored next to the model; shared by both branches.
    def _params_record(ind_params, forward_bars, label_threshold, prob_threshold,
                       sl_pct, tp_pct, min_hold, max_hold):
        return {
            'indicator_params': ind_params,
            'forward_bars': forward_bars,
            'label_threshold': label_threshold,
            'prob_threshold': prob_threshold,
            'sl_pct': sl_pct,
            'tp_pct': tp_pct,
            'min_hold_seconds': min_hold,
            'max_hold_seconds': max_hold,
            'margin_per_trade': MARGIN_PER_TRADE,
            'leverage': LEVERAGE,
            'notional': NOTIONAL_PER_TRADE,
            'rebate_rate': REBATE_RATE,
            'train_period': f'{TRAIN_START} ~ {TRAIN_END_EXCLUSIVE}',
            'test_period': f'{TEST_START} ~ {TEST_END_EXCLUSIVE}',
            'kline_period': period,
            'initial_balance': INITIAL_BALANCE,
            'max_dd_target': MAX_DD_TARGET,
        }

    best_score = (-1, -1, -1e9, 1e9)
    best_result = None
    if do_grid_search:
        configs = list(_grid_indicator_and_strategy())
        print(f"\n[2/4] 参数搜索 (共 {len(configs)} 组): 目标 回撤≤{MAX_DD_TARGET:.0f}U 且 月盈≥{TARGET_MONTHLY_NET:.0f}U ...", flush=True)
        for i, cfg in enumerate(configs):
            _t_start = _time.time()
            trades, model, feature_cols = run_single(
                df_2020,
                df_2021,
                cfg['ind'],
                cfg['forward_bars'],
                cfg['label_threshold'],
                cfg['prob_threshold'],
                cfg['sl_pct'],
                cfg['tp_pct'],
                cfg['min_hold'],
                cfg['max_hold'],
            )
            _elapsed = _time.time() - _t_start
            # Progress is printed for the first config and every tenth one.
            report_now = (i + 1) % 10 == 0 or i == 0
            if not trades or model is None:
                if report_now:
                    print(f"{i+1}/{len(configs)} 完成 (本组 {_elapsed:.0f}s无有效交易)", flush=True)
                continue
            r = analyze_trades(trades)
            score = _score(r)
            if score > best_score:
                best_score = score
                best_result = (trades, model, feature_cols, _params_record(
                    cfg['ind'], cfg['forward_bars'], cfg['label_threshold'],
                    cfg['prob_threshold'], cfg['sl_pct'], cfg['tp_pct'],
                    cfg['min_hold'], cfg['max_hold'],
                ))
            _total_so_far = _time.time() - t0
            _avg_per = _total_so_far / (i + 1)
            _left = _avg_per * (len(configs) - i - 1)
            if report_now:
                print(f"{i+1}/{len(configs)} 完成 | 本组 {_elapsed:.0f}s | 回撤 {r['dd_available']:.0f} 月均 {r['monthly_net']:+.0f} | 预计剩余 ~{_left/60:.0f}min", flush=True)
    else:
        print("\n[2/4] 使用默认参数: 2020 训练 -> 2021 回测...", flush=True)
        ind = default_indicator_params()
        defaults = {
            'forward_bars': 10, 'label_threshold': 0.002,
            'prob_threshold': 0.45, 'sl_pct': 0.004, 'tp_pct': 0.006,
            'min_hold': 180, 'max_hold': 1800,
        }
        trades, model, feature_cols = run_single(df_2020, df_2021, ind, **defaults)
        if trades and model is not None:
            best_result = (trades, model, feature_cols, _params_record(ind, **defaults))

    if best_result is None:
        print(" 未得到有效模型/交易(可能 2020 或 2021 数据不足)。", flush=True)
        return

    trades, model, feature_cols, full_params = best_result
    print("\n[3/4] 2021 年真实回测结果 (最佳参数)...", flush=True)
    print_report(trades, "2021 年样本外回测")
    # Flatten: indicator parameters first, then the strategy/meta fields.
    flat_params = {**full_params.get('indicator_params', {}), **{k: v for k, v in full_params.items() if k != 'indicator_params'}}
    save_model_and_params(model, feature_cols, flat_params)

    r = analyze_trades(trades)
    elapsed = _time.time() - t0
    dd_ok = r['dd_available'] <= MAX_DD_TARGET
    monthly_ok = r['monthly_net'] >= TARGET_MONTHLY_NET
    print(f"\n[4/4] 总耗时: {elapsed:.1f}s | 可用资金回撤: {r['dd_available']:.2f} USDT (目标≤{MAX_DD_TARGET:.0f}): {'达标' if dd_ok else '未达标'} | 月均净利: {r['monthly_net']:+.2f} USDT (目标≥{TARGET_MONTHLY_NET:.0f}): {'达标' if monthly_ok else '未达标'}", flush=True)
    print("=" * 60, flush=True)
if __name__ == '__main__':
    import sys

    cli = sys.argv
    compare_mode = '--compare' in cli or '-c' in cli
    if compare_mode:
        # Multi-period comparison: 5m and 15m by default; with --all compare
        # 1m, 3m, 5m, 15m, 30m and 1h.
        periods = [1, 3, 5, 15, 30, 60] if '--all' in cli else [5, 15]
        run_cycle_compare(periods, do_save_best=True)
    else:
        # Single period: "--period 5m" or "--period 15" (also -p); the first
        # such flag that is followed by a value wins, otherwise 1m is used.
        period = next(
            (cli[k + 1] for k, token in enumerate(cli[:-1]) if token in ('--period', '-p')),
            '1m',
        )
        do_search = '--no-search' not in cli
        main(do_grid_search=do_search, period=period)