哈哈
This commit is contained in:
10
.idea/.gitignore
generated
vendored
Normal file
10
.idea/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
# 默认忽略的文件
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# 已忽略包含查询文件的默认文件夹
|
||||
/queries/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
||||
# 基于编辑器的 HTTP 客户端请求
|
||||
/httpRequests/
|
||||
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
||||
8
.idea/jyx_code.iml
generated
Normal file
8
.idea/jyx_code.iml
generated
Normal file
@@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="jyx_code" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
||||
7
.idea/misc.xml
generated
Normal file
7
.idea/misc.xml
generated
Normal file
@@ -0,0 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="jyx_code" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="jyx_code" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/jyx_code.iml" filepath="$PROJECT_DIR$/.idea/jyx_code.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
||||
16
main.py
Normal file
16
main.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# 这是一个示例 Python 脚本。
|
||||
|
||||
# 按 ⌃R 执行或将其替换为您的代码。
|
||||
# 按 双击 ⇧ 在所有地方搜索类、文件、工具窗口、操作和设置。
|
||||
|
||||
|
||||
def print_hi(name):
    """Print a ``Hi, <name>`` greeting to standard output."""
    greeting = f'Hi, {name}'
    print(greeting)


# Script entry point: only greet when this file is executed directly.
if __name__ == '__main__':
    print_hi('PyCharm')
|
||||
|
||||
# 访问 https://www.jetbrains.com/help/pycharm/ 获取 PyCharm 帮助
|
||||
58
models/__init__.py
Normal file
58
models/__init__.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from pathlib import Path
|
||||
|
||||
from peewee import *
|
||||
|
||||
# SQLite database stored next to this package file; SQLite creates the file
# on first connection if it does not exist yet.
db = SqliteDatabase('{}/database.db'.format(Path(__file__).parent))
|
||||
|
||||
import pymysql
|
||||
|
||||
from peewee import *
|
||||
from playhouse.pool import PooledMySQLDatabase
|
||||
|
||||
import os

# Register PyMySQL under the MySQLdb name so code importing MySQLdb gets PyMySQL.
pymysql.install_as_MySQLdb()

# MySQL connection settings. Every value can be overridden with an LM_DB_*
# environment variable so deployments do not need to edit this file.
# NOTE(review): the fallback literals (including the password) are kept only so
# existing setups keep working; credentials committed to version control should
# be rotated and then removed from this file.
db_config = {
    'database': os.environ.get('LM_DB_NAME', 'lm'),
    'user': os.environ.get('LM_DB_USER', 'lm'),
    'password': os.environ.get('LM_DB_PASSWORD', 'HhyAsGbrrbsJfpyy'),
    'host': os.environ.get('LM_DB_HOST', '192.168.1.87'),
    'port': int(os.environ.get('LM_DB_PORT', '3306')),
}

# Global MySQL database handle used by the MySQL-backed models.
db1 = MySQLDatabase(
    db_config['database'],
    user=db_config['user'],
    password=db_config['password'],
    host=db_config['host'],
    port=db_config['port'],
)
|
||||
|
||||
# class BaseModel(Model):
|
||||
# class Meta:
|
||||
# database = db1
|
||||
#
|
||||
# def save(self, *args, **kwargs):
|
||||
# """在调用 save 时自动连接和关闭(若无事务)"""
|
||||
# db.connect(reuse_if_open=True)
|
||||
# try:
|
||||
# result = super().save(*args, **kwargs)
|
||||
# finally:
|
||||
# # 若当前没有事务且连接仍然打开,则关闭连接
|
||||
# if not db.in_transaction() and not db.is_closed():
|
||||
# db.close()
|
||||
# return result
|
||||
#
|
||||
# @classmethod
|
||||
# def get_or_create(cls, defaults=None, **kwargs):
|
||||
# """在调用 get_or_create 时自动连接和关闭(若无事务)"""
|
||||
# db.connect(reuse_if_open=True)
|
||||
# try:
|
||||
# obj, created = super().get_or_create(defaults=defaults, **kwargs)
|
||||
# finally:
|
||||
# # 若当前没有事务且连接仍然打开,则关闭连接
|
||||
# if not db.in_transaction() and not db.is_closed():
|
||||
# db.close()
|
||||
# return obj, created
|
||||
21
models/bitmart.py
Normal file
21
models/bitmart.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from peewee import *
|
||||
|
||||
from models import db
|
||||
|
||||
|
||||
class BitMart30(Model):
    """OHLC candle row stored in the ``bitmart_30`` table."""

    # Primary key; the original note says it holds a millisecond timestamp.
    id = IntegerField(primary_key=True)
    # Price columns; all of them may be NULL.
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        table_name = 'bitmart_30'
        database = db
|
||||
|
||||
|
||||
# Open the shared connection. reuse_if_open=True keeps this import from raising
# when another model module has already connected the same ``db`` object.
db.connect(reuse_if_open=True)
# Create the table if it does not exist yet.
db.create_tables([BitMart30])
|
||||
21
models/bitmart_15.py
Normal file
21
models/bitmart_15.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from peewee import *
|
||||
|
||||
from models import db
|
||||
|
||||
|
||||
class BitMart15(Model):
    """OHLC candle row stored in the ``bitmart_15`` table."""

    # Primary key; the original note says it holds a millisecond timestamp.
    id = IntegerField(primary_key=True)
    # Price columns; all of them may be NULL.
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        table_name = 'bitmart_15'
        database = db
|
||||
|
||||
|
||||
# Open the shared connection. reuse_if_open=True keeps this import from raising
# when another model module has already connected the same ``db`` object.
db.connect(reuse_if_open=True)
# Create the table if it does not exist yet.
db.create_tables([BitMart15])
|
||||
97
models/bitmart_klines.py
Normal file
97
models/bitmart_klines.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
BitMart 多周期K线数据模型
|
||||
包含 1分钟、3分钟、5分钟、15分钟、30分钟、1小时 K线数据表
|
||||
"""
|
||||
|
||||
from peewee import *
|
||||
from models import db
|
||||
|
||||
|
||||
# ==================== 1分钟 K线 ====================
|
||||
class BitMartETH1M(Model):
    """OHLC candle row stored in the ``bitmart_eth_1m`` table."""

    # Primary key; the original note says it holds a millisecond timestamp.
    id = BigIntegerField(primary_key=True)
    # Price columns; all of them may be NULL.
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        table_name = 'bitmart_eth_1m'
        database = db
|
||||
|
||||
|
||||
# ==================== 3分钟 K线 ====================
|
||||
class BitMartETH3M(Model):
    """OHLC candle row stored in the ``bitmart_eth_3m`` table."""

    # Big-integer primary key (no auto-increment is declared).
    id = BigIntegerField(primary_key=True)
    # Price columns; all of them may be NULL.
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        table_name = 'bitmart_eth_3m'
        database = db
|
||||
|
||||
|
||||
# ==================== 5分钟 K线 ====================
|
||||
class BitMartETH5M(Model):
    """OHLC candle row stored in the ``bitmart_eth_5m`` table."""

    # Big-integer primary key (no auto-increment is declared).
    id = BigIntegerField(primary_key=True)
    # Price columns; all of them may be NULL.
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        table_name = 'bitmart_eth_5m'
        database = db
|
||||
|
||||
|
||||
# ==================== 15分钟 K线 ====================
|
||||
class BitMartETH15M(Model):
    """OHLC candle row stored in the ``bitmart_eth_15m`` table."""

    # Big-integer primary key (no auto-increment is declared).
    id = BigIntegerField(primary_key=True)
    # Price columns; all of them may be NULL.
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        table_name = 'bitmart_eth_15m'
        database = db
|
||||
|
||||
|
||||
# ==================== 30分钟 K线 ====================
|
||||
class BitMartETH30M(Model):
    """OHLC candle row stored in the ``bitmart_eth_30m`` table."""

    # Big-integer primary key (no auto-increment is declared).
    id = BigIntegerField(primary_key=True)
    # Price columns; all of them may be NULL.
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        table_name = 'bitmart_eth_30m'
        database = db
|
||||
|
||||
|
||||
# ==================== 1小时 K线 ====================
|
||||
class BitMartETH1H(Model):
    """OHLC candle row stored in the ``bitmart_eth_1h`` table."""

    # Big-integer primary key (no auto-increment is declared).
    id = BigIntegerField(primary_key=True)
    # Price columns; all of them may be NULL.
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        table_name = 'bitmart_eth_1h'
        database = db
|
||||
|
||||
|
||||
# Connect (reusing a connection that is already open) and create whichever of
# the per-period K-line tables are still missing.
_kline_tables = [
    BitMartETH1M,
    BitMartETH3M,
    BitMartETH5M,
    BitMartETH15M,
    BitMartETH30M,
    BitMartETH1H,
]
db.connect(reuse_if_open=True)
db.create_tables(_kline_tables, safe=True)
|
||||
BIN
models/database.db
Normal file
BIN
models/database.db
Normal file
Binary file not shown.
47
models/eth_ai_strategy/feature_cols.json
Normal file
47
models/eth_ai_strategy/feature_cols.json
Normal file
@@ -0,0 +1,47 @@
|
||||
[
|
||||
"bb_pct",
|
||||
"bb_width",
|
||||
"keltner_pct",
|
||||
"keltner_width",
|
||||
"donchian_pct",
|
||||
"donchian_width",
|
||||
"stoch_k",
|
||||
"stoch_d",
|
||||
"cci",
|
||||
"willr",
|
||||
"rsi",
|
||||
"atr_band_pct",
|
||||
"atr_band_width",
|
||||
"zscore",
|
||||
"zscore_abs",
|
||||
"lr_pct",
|
||||
"lr_width",
|
||||
"price_vs_median",
|
||||
"median_band_pct",
|
||||
"pct_rank",
|
||||
"chandelier_pct",
|
||||
"chandelier_width",
|
||||
"std_band_pct",
|
||||
"std_band_width",
|
||||
"elder_bull",
|
||||
"elder_bear",
|
||||
"elder_dist",
|
||||
"ema_fast_slow",
|
||||
"price_vs_ema120",
|
||||
"ema8_slope",
|
||||
"macd",
|
||||
"macd_signal",
|
||||
"macd_hist",
|
||||
"atr_pct",
|
||||
"ret_1",
|
||||
"ret_3",
|
||||
"ret_5",
|
||||
"ret_10",
|
||||
"ret_20",
|
||||
"vol_5",
|
||||
"vol_20",
|
||||
"body_pct",
|
||||
"price_position_20",
|
||||
"hour_sin",
|
||||
"hour_cos"
|
||||
]
|
||||
14435
models/eth_ai_strategy/model.txt
Normal file
14435
models/eth_ai_strategy/model.txt
Normal file
File diff suppressed because it is too large
Load Diff
40
models/eth_ai_strategy/strategy_params.json
Normal file
40
models/eth_ai_strategy/strategy_params.json
Normal file
@@ -0,0 +1,40 @@
|
||||
{
|
||||
"bb_period": 20,
|
||||
"bb_std": 2.0,
|
||||
"keltner_period": 20,
|
||||
"keltner_atr_mult": 2.0,
|
||||
"donchian_period": 20,
|
||||
"stoch_k_period": 14,
|
||||
"stoch_d_period": 3,
|
||||
"cci_period": 20,
|
||||
"willr_period": 14,
|
||||
"rsi_period": 14,
|
||||
"atr_band_period": 20,
|
||||
"atr_band_mult": 2.0,
|
||||
"zscore_period": 20,
|
||||
"lr_period": 20,
|
||||
"lr_std_mult": 2.0,
|
||||
"median_band_period": 20,
|
||||
"pct_rank_period": 20,
|
||||
"chandelier_period": 22,
|
||||
"chandelier_mult": 3.0,
|
||||
"std_band_period": 14,
|
||||
"std_band_mult": 2.0,
|
||||
"elder_period": 13,
|
||||
"forward_bars": 10,
|
||||
"label_threshold": 0.002,
|
||||
"prob_threshold": 0.45,
|
||||
"sl_pct": 0.004,
|
||||
"tp_pct": 0.006,
|
||||
"min_hold_seconds": 180,
|
||||
"max_hold_seconds": 1800,
|
||||
"margin_per_trade": 100.0,
|
||||
"leverage": 100,
|
||||
"notional": 10000.0,
|
||||
"rebate_rate": 0.9,
|
||||
"train_period": "2020-01-01 ~ 2021-01-01",
|
||||
"test_period": "2021-01-01 ~ 2022-01-01",
|
||||
"kline_period": "5m",
|
||||
"initial_balance": 10000.0,
|
||||
"max_dd_target": 500.0
|
||||
}
|
||||
21
models/ips.py
Normal file
21
models/ips.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from peewee import *
|
||||
|
||||
from models import db1
|
||||
|
||||
|
||||
class Ips(Model):
    """Row of the ``ips`` table in the MySQL database bound to ``db1``."""

    # Integer primary key (no auto-increment is declared).
    id = IntegerField(primary_key=True)
    # Connection details; every column is nullable.
    host = CharField(null=True)
    port = CharField(null=True)
    username = CharField(null=True)
    password = CharField(null=True)
    # NOTE(review): meaning of ``start`` is not visible here; confirm with callers.
    start = IntegerField(null=True)
    country = CharField(null=True)

    class Meta:
        table_name = 'ips'
        database = db1
|
||||
|
||||
|
||||
# if __name__ == '__main__':
|
||||
# Ips.create_table()
|
||||
57
models/mexc.py
Normal file
57
models/mexc.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from peewee import *
|
||||
|
||||
from models import db
|
||||
|
||||
|
||||
class Mexc1(Model):
    """OHLC candle row stored in the ``mexc_1`` table."""

    # Integer primary key (no auto-increment is declared).
    id = IntegerField(primary_key=True)
    # Price columns; all of them may be NULL.
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        table_name = 'mexc_1'
        database = db
|
||||
|
||||
|
||||
class Mexc15(Model):
    """OHLC candle row stored in the ``mexc_15`` table."""

    # Integer primary key (no auto-increment is declared).
    id = IntegerField(primary_key=True)
    # Price columns; all of them may be NULL.
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        table_name = 'mexc_15'
        database = db
|
||||
|
||||
|
||||
class Mexc30(Model):
    """OHLC candle row stored in the ``mexc_30`` table."""

    # Integer primary key (no auto-increment is declared).
    id = IntegerField(primary_key=True)
    # Price columns; all of them may be NULL.
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        table_name = 'mexc_30'
        database = db
|
||||
|
||||
|
||||
class Mexc1Hour(Model):
    """OHLC candle row stored in the ``mexc_1_hour`` table."""

    # Integer primary key (no auto-increment is declared).
    id = IntegerField(primary_key=True)
    # Price columns; all of them may be NULL.
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        table_name = 'mexc_1_hour'
        database = db
|
||||
|
||||
|
||||
# Open the shared connection. reuse_if_open=True keeps this import from raising
# when another model module has already connected the same ``db`` object.
db.connect(reuse_if_open=True)
# Create the tables that do not exist yet.
db.create_tables([Mexc1, Mexc15, Mexc30, Mexc1Hour])
|
||||
71
models/weex.py
Normal file
71
models/weex.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from peewee import *
|
||||
|
||||
from models import db
|
||||
|
||||
|
||||
class Weex15(Model):
    """OHLC candle row stored in the ``weex_15`` table."""

    # Integer primary key (no auto-increment is declared).
    id = IntegerField(primary_key=True)
    # Price columns; all of them may be NULL.
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        table_name = 'weex_15'
        database = db
|
||||
|
||||
|
||||
class Weex1(Model):
    """OHLC candle row stored in the ``weex_1`` table."""

    # Integer primary key (no auto-increment is declared).
    id = IntegerField(primary_key=True)
    # Price columns; all of them may be NULL.
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        table_name = 'weex_1'
        database = db
|
||||
|
||||
|
||||
class Weex1Hour(Model):
    """OHLC candle row stored in the ``weex_1_hour`` table."""

    # Integer primary key (no auto-increment is declared).
    id = IntegerField(primary_key=True)
    # Price columns; all of them may be NULL.
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        table_name = 'weex_1_hour'
        database = db
|
||||
|
||||
class Weex30(Model):
    """OHLC candle row stored in the ``weex_30`` table."""

    # Integer primary key (no auto-increment is declared).
    id = IntegerField(primary_key=True)
    # Price columns; all of them may be NULL.
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        table_name = 'weex_30'
        database = db
|
||||
|
||||
class Weex30Copy(Model):
    """OHLC candle row stored in the ``weex_30_copy1`` table."""

    # Integer primary key (no auto-increment is declared).
    id = IntegerField(primary_key=True)
    # Price columns; all of them may be NULL.
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)

    class Meta:
        table_name = 'weex_30_copy1'
        database = db
|
||||
|
||||
|
||||
# Open the shared connection. reuse_if_open=True keeps this import from raising
# when another model module has already connected the same ``db`` object.
db.connect(reuse_if_open=True)

# Only the Weex30 table is created here; creation of Weex15 was already
# disabled in the original code:
# db.create_tables([Weex15])
db.create_tables([Weex30])
|
||||
|
||||
|
||||
22
models/xstart.py
Normal file
22
models/xstart.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from peewee import *
|
||||
|
||||
from models import db1
|
||||
from models.ips import Ips
|
||||
|
||||
|
||||
class Xstart(Model):
    """Row of the ``xstart`` table in the MySQL database bound to ``db1``."""

    # Auto-incrementing primary key.
    id = AutoField(primary_key=True)
    # All remaining columns are nullable.
    # NOTE(review): the *_id columns are plain values, not declared foreign keys.
    bit_id = CharField(null=True)
    start = IntegerField(null=True)
    x_id = IntegerField(null=True)
    ip_id = IntegerField(null=True)
    url_id = CharField(null=True)

    class Meta:
        table_name = 'xstart'
        database = db1  # owning database


# Running this module directly creates the table.
if __name__ == '__main__':
    Xstart.create_table()
|
||||
24
models/xtoken.py
Normal file
24
models/xtoken.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from peewee import *
|
||||
# 假设 db 已经在其他地方定义并连接到数据库
|
||||
from models import db1
|
||||
|
||||
|
||||
class XToken(Model):
    """Row of the ``x_token`` table in the MySQL database bound to ``db1``."""

    # Auto-incrementing primary key.
    id = AutoField(primary_key=True)
    # Nullable integer columns.
    hub_id = IntegerField(null=True)
    start = IntegerField(null=True)
    account_start = IntegerField(null=True)
    # Nullable string columns, each limited to 255 characters.
    user_name = CharField(max_length=255, null=True)
    password = CharField(max_length=255, null=True)
    email = CharField(max_length=255, null=True)
    # The original note says this stands in for "2fa", which is not a valid
    # Python identifier.
    # NOTE(review): no column_name is given, so peewee maps this attribute to a
    # column called "two_fa"; confirm that matches the real table.
    two_fa = CharField(max_length=255, null=True)
    token = CharField(max_length=255, null=True)
    # NOTE(review): presumably the e-mail account password, judging by the name.
    email_pwd = CharField(max_length=255, null=True)

    class Meta:
        table_name = 'x_token'  # table name
        database = db1  # owning database


# Running this module directly creates the table.
if __name__ == '__main__':
    XToken.create_table()
|
||||
762
抓取多周期K线.py
Normal file
762
抓取多周期K线.py
Normal file
@@ -0,0 +1,762 @@
|
||||
"""
|
||||
BitMart 多周期K线数据抓取脚本
|
||||
支持同时获取 1分钟、3分钟、5分钟、15分钟、30分钟、1小时 K线数据
|
||||
支持秒级价格数据(通过成交记录API)
|
||||
支持断点续传,从数据库最新/最早记录继续抓取
|
||||
"""
|
||||
|
||||
import time
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from peewee import *
|
||||
from bitmart.api_contract import APIContract
|
||||
|
||||
# Database location: models/database.db next to this script.
DB_PATH = Path(__file__).parent.joinpath('models', 'database.db')
db = SqliteDatabase(str(DB_PATH))

# K-line periods: step value in minutes -> table-name suffix.
KLINE_CONFIGS = dict((
    (1, '1m'),     # 1 minute
    (3, '3m'),     # 3 minutes
    (5, '5m'),     # 5 minutes
    (15, '15m'),   # 15 minutes
    (30, '30m'),   # 30 minutes
    (60, '1h'),    # 1 hour
))
|
||||
|
||||
|
||||
class BitMartETHTrades(Model):
    """Single executed trade, stored in the ``bitmart_eth_trades`` table."""

    id = BigIntegerField(primary_key=True)    # trade ID
    timestamp = BigIntegerField(index=True)   # trade time in milliseconds (indexed)
    price = FloatField()                      # execution price
    volume = FloatField()                     # executed volume
    side = IntegerField()                     # direction: 1 = buy, -1 = sell

    class Meta:
        table_name = 'bitmart_eth_trades'
        database = db
|
||||
|
||||
|
||||
class BitMartETHSecond(Model):
    """One-second candle (described as aggregated from trades), table ``bitmart_eth_1s``."""

    # Primary key: millisecond timestamp rounded to the second (per the original note).
    id = BigIntegerField(primary_key=True)
    # Price and volume columns; all of them may be NULL.
    open = FloatField(null=True)
    high = FloatField(null=True)
    low = FloatField(null=True)
    close = FloatField(null=True)
    volume = FloatField(null=True)
    # Number of trades that fell into this second.
    trade_count = IntegerField(null=True)

    class Meta:
        table_name = 'bitmart_eth_1s'
        database = db
|
||||
|
||||
|
||||
def create_kline_model(step: int):
    """
    Build a peewee model class for the K-line table of one period.

    The table is named ``bitmart_eth_<suffix>`` and the class
    ``BitMartETH<SUFFIX>``, where the suffix comes from ``KLINE_CONFIGS`` and
    falls back to ``'<step>m'`` for unknown steps.

    :param step: K-line period in minutes
    :return: a newly created Model subclass bound to the module-level ``db``
    """
    period_tag = KLINE_CONFIGS.get(step, f'{step}m')

    # Inner Meta class binding the model to the database and its table.
    meta_cls = type('Meta', (), {
        'database': db,
        'table_name': f'bitmart_eth_{period_tag}',
    })

    # Class namespace: the same id/open/high/low/close columns for every period.
    namespace = {
        'id': BigIntegerField(primary_key=True),
        'open': FloatField(null=True),
        'high': FloatField(null=True),
        'low': FloatField(null=True),
        'close': FloatField(null=True),
        'Meta': meta_cls,
    }

    # type() is used so each call yields a distinct, correctly named class.
    return type(f'BitMartETH{period_tag.upper()}', (Model,), namespace)
|
||||
|
||||
|
||||
class BitMartMultiKlineCollector:
|
||||
"""多周期K线数据抓取器"""
|
||||
|
||||
def __init__(self):
    """
    Set up API credentials, the contract API client and the database tables.

    The API key and secret are read from the ``BITMART_API_KEY`` and
    ``BITMART_SECRET_KEY`` environment variables when present.
    """
    import os  # local import: only needed for the credential lookup below

    # NOTE(review): the literal fallbacks are kept for backward compatibility
    # only; keys committed to version control should be rotated and removed.
    self.api_key = os.environ.get(
        "BITMART_API_KEY",
        "a0fb7b98464fd9bcce67e7c519d58ec10d0c38a8",
    )
    self.secret_key = os.environ.get(
        "BITMART_SECRET_KEY",
        "4eaeba78e77aeaab1c2027f846a276d164f264a44c2c1bb1c5f3be50c8de1ca5",
    )
    self.memo = "数据抓取"
    self.contract_symbol = "ETHUSDT"
    self.contractAPI = APIContract(self.api_key, self.secret_key, self.memo, timeout=(5, 15))

    # Model class per period, keyed by step (minutes); filled by _init_database().
    self.models = {}

    # Open the database connection and create the tables.
    self._init_database()
|
||||
|
||||
def _init_database(self):
    """Connect to the database and create every table this collector writes to."""
    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
    db.connect(reuse_if_open=True)

    # One table per configured K-line period.
    for period in KLINE_CONFIGS:
        kline_model = create_kline_model(period)
        self.models[period] = kline_model
        db.create_tables([kline_model], safe=True)  # no-op if the table exists
        logger.info(f"初始化表: {kline_model._meta.table_name}")

    # Raw trade table and the one-second candle table.
    db.create_tables([BitMartETHTrades, BitMartETHSecond], safe=True)
    for message in (
        "初始化表: bitmart_eth_trades (成交记录)",
        "初始化表: bitmart_eth_1s (秒级K线)",
    ):
        logger.info(message)
|
||||
|
||||
def get_db_time_range(self, step: int):
    """
    Return the smallest and largest ``id`` stored for one K-line period.

    :param step: K-line period (minutes)
    :return: ``(earliest_ts, latest_ts)`` as stored (millisecond timestamps);
             ``(None, None)`` when the period has no model or the query fails
    """
    kline_model = self.models.get(step)
    if not kline_model:
        return None, None

    try:
        oldest = kline_model.select(fn.MIN(kline_model.id)).scalar()
        newest = kline_model.select(fn.MAX(kline_model.id)).scalar()
    except Exception as e:
        logger.error(f"查询数据库时间范围异常: {e}")
        return None, None
    return oldest, newest
|
||||
|
||||
def get_klines(self, step: int, start_time: int, end_time: int, max_retries: int = 3):
    """
    Fetch K-line data from the contract API, retrying on failure.

    :param step: K-line period (minutes)
    :param start_time: start timestamp (seconds)
    :param end_time: end timestamp (seconds)
    :param max_retries: maximum number of attempts
    :return: list of dicts with keys ``id`` (millisecond timestamp), ``open``,
             ``high``, ``low``, ``close``, sorted by ``id``; an empty list when
             every attempt fails or the API returns no rows
    """
    for attempt in range(max_retries):
        is_last_attempt = attempt >= max_retries - 1
        try:
            start_time = int(start_time)
            end_time = int(end_time)

            response = self.contractAPI.get_kline(
                contract_symbol=self.contract_symbol,
                step=step,
                start_time=start_time,
                end_time=end_time,
            )[0]

            if response['code'] != 1000:
                logger.warning(f"API返回错误 (尝试 {attempt+1}/{max_retries}): {response}")
                if not is_last_attempt:
                    time.sleep(1)
                    continue
                return []

            # A successful reply may carry ``data: null``; treat it as "no rows"
            # instead of letting the loop below raise and burn the retries.
            klines = response.get('data') or []
            formatted = [
                {
                    'id': int(k["timestamp"]) * 1000,  # seconds -> milliseconds
                    'open': float(k["open_price"]),
                    'high': float(k["high_price"]),
                    'low': float(k["low_price"]),
                    'close': float(k["close_price"]),
                }
                for k in klines
            ]
            formatted.sort(key=lambda x: x['id'])
            return formatted

        except Exception as e:
            logger.error(f"获取K线异常 (尝试 {attempt+1}/{max_retries}): {e}")
            if not is_last_attempt:
                time.sleep(2)
                continue
            return []

    return []
|
||||
|
||||
def save_klines(self, step: int, klines: list):
    """
    Store K-line rows for one period, skipping ids that already exist.

    :param step: K-line period (minutes)
    :param klines: list of dicts with ``id``, ``open``, ``high``, ``low``, ``close``
    :return: number of rows that were newly created
    """
    kline_model = self.models.get(step)
    if not kline_model:
        logger.error(f"未找到 {step}分钟 的数据模型")
        return 0

    price_fields = ('open', 'high', 'low', 'close')
    inserted = 0
    for kline in klines:
        try:
            # get_or_create returns (instance, created); only "created" matters here.
            outcome = kline_model.get_or_create(
                id=kline['id'],
                defaults={field: kline[field] for field in price_fields},
            )
            if outcome[1]:
                inserted += 1
        except Exception as e:
            # A failing row is logged and skipped so the rest of the batch is kept.
            logger.error(f"保存K线数据失败 {kline['id']}: {e}")

    return inserted
|
||||
|
||||
def get_batch_seconds(self, step: int):
    """
    Return the time window, in seconds, to request per batch for a period.

    :param step: K-line period (minutes)
    :return: window length in seconds; any step not listed below gets 72 hours
    """
    # Hours covered by one request, keyed by period in minutes.
    hours_per_batch = {
        1: 4,     # 1 minute: 4 hours per request
        3: 8,     # 3 minutes: 8 hours
        5: 12,    # 5 minutes: 12 hours
        15: 24,   # 15 minutes: 1 day
        30: 48,   # 30 minutes: 2 days
    }
    return 3600 * hours_per_batch.get(step, 72)  # default (1 hour): 3 days
|
||||
|
||||
def collect_period_range(self, step: int, target_start: int, target_end: int):
    """
    Fetch and store K-lines for a time range, resuming from what is in the DB.

    Phase 1 walks backwards from the earliest stored record (or from
    ``target_end`` when the table is empty) down to ``target_start``.
    Phase 2 walks forwards from the latest stored record up to ``target_end``.

    :param step: K-line period (minutes)
    :param target_start: start of the wanted range (seconds)
    :param target_end: end of the wanted range (seconds)
    :return: total number of newly stored rows
    """
    suffix = KLINE_CONFIGS.get(step, f'{step}m')
    batch_seconds = self.get_batch_seconds(step)

    def _fmt(ts, pattern='%Y-%m-%d %H:%M'):
        """Format a second-resolution timestamp in local time."""
        return time.strftime(pattern, time.localtime(ts))

    # Range already present in the database (stored in milliseconds).
    db_earliest, db_latest = self.get_db_time_range(step)

    if db_earliest and db_latest:
        db_earliest_sec = db_earliest // 1000
        db_latest_sec = db_latest // 1000
        logger.info(f"[{suffix}] 数据库已有数据: "
                    f"{_fmt(db_earliest_sec)} ~ "
                    f"{_fmt(db_latest_sec)}")
    else:
        db_earliest_sec = None
        db_latest_sec = None
        logger.info(f"[{suffix}] 数据库暂无数据")

    total_saved = 0

    # === Phase 1: backwards from the earliest stored record to target_start ===
    backward_end = db_earliest_sec if db_earliest_sec else target_end

    if backward_end > target_start:
        logger.info(f"[{suffix}] === 开始向前抓取历史数据 ===")
        total_backward = backward_end - target_start

        current_end = backward_end
        fail_count = 0
        max_fail = 5

        while current_end > target_start and fail_count < max_fail:
            current_start = max(current_end - batch_seconds, target_start)

            # Share of the backward range already walked, for the log line.
            progress = (backward_end - current_end) / total_backward * 100 if total_backward > 0 else 0
            start_str = _fmt(current_start)
            end_str = _fmt(current_end)

            klines = self.get_klines(step, current_start, current_end)
            if klines:
                saved = self.save_klines(step, klines)
                total_saved += saved
                logger.info(f"[{suffix}] ← 历史 {start_str} ~ {end_str} | "
                            f"获取 {len(klines)} 条, 新增 {saved} 条 | 进度 {progress:.1f}%")
                fail_count = 0
            else:
                fail_count += 1
                logger.warning(f"[{suffix}] ← 历史 {start_str} 无数据 (连续失败 {fail_count}/{max_fail})")
                if fail_count >= max_fail:
                    # Several empty windows in a row: assume no older data exists.
                    earliest_date = _fmt(current_end, '%Y-%m-%d')
                    logger.warning(f"[{suffix}] 已达到API历史数据限制,最早可获取: {earliest_date}")
                    break

            current_end = current_start
            time.sleep(0.3)  # pause between requests

    # === Phase 2: forwards from the latest stored record to target_end ===
    if db_latest_sec:
        forward_start = db_latest_sec
    else:
        # The table was empty, so phase 1 has just walked the whole
        # [target_start, target_end] window. Restarting from target_start (the
        # previous behaviour) downloaded the same range a second time.
        forward_start = target_end

    if forward_start < target_end:
        logger.info(f"[{suffix}] === 开始向后抓取最新数据 ===")
        total_forward = target_end - forward_start

        current_start = forward_start
        fail_count = 0
        max_fail = 3

        while current_start < target_end and fail_count < max_fail:
            current_end = min(current_start + batch_seconds, target_end)

            # Share of the forward range already walked, for the log line.
            progress = (current_start - forward_start) / total_forward * 100 if total_forward > 0 else 0
            start_str = _fmt(current_start)
            end_str = _fmt(current_end)

            klines = self.get_klines(step, current_start, current_end)
            if klines:
                saved = self.save_klines(step, klines)
                total_saved += saved
                logger.info(f"[{suffix}] → 最新 {start_str} ~ {end_str} | "
                            f"获取 {len(klines)} 条, 新增 {saved} 条 | 进度 {progress:.1f}%")
                fail_count = 0
            else:
                fail_count += 1
                logger.warning(f"[{suffix}] → 最新 {start_str} 无数据 (失败 {fail_count}/{max_fail})")

            current_start = current_end
            time.sleep(0.3)  # pause between requests

    # Report the range now present in the database.
    final_earliest, final_latest = self.get_db_time_range(step)
    if final_earliest and final_latest:
        logger.success(f"[{suffix}] 抓取完成!本次新增 {total_saved} 条 | 数据范围: "
                       f"{_fmt(final_earliest//1000, '%Y-%m-%d')} ~ "
                       f"{_fmt(final_latest//1000, '%Y-%m-%d')}")
    else:
        logger.success(f"[{suffix}] 抓取完成!本次新增 {total_saved} 条")

    return total_saved
|
||||
|
||||
def collect_from_date(self, start_date: str, periods: list = None):
    """
    Collect K-lines from a start date up to the current time.

    :param start_date: start date as ``'YYYY-MM-DD'`` (parsed as a naive local datetime)
    :param periods: steps (minutes) to collect, e.g. ``[1, 5, 15]``; defaults
                    to every period in ``KLINE_CONFIGS``. Unsupported steps
                    are skipped with a warning.
    :return: dict mapping period label (e.g. ``'5m'``) to the number of new rows
    :raises ValueError: if ``start_date`` does not match ``'%Y-%m-%d'``
    """
    if periods is None:
        periods = list(KLINE_CONFIGS.keys())

    # Target time range in seconds.
    start_dt = datetime.datetime.strptime(start_date, '%Y-%m-%d')
    target_start = int(start_dt.timestamp())
    target_end = int(time.time())

    start_str = start_dt.strftime('%Y-%m-%d')
    end_str = datetime.datetime.now().strftime('%Y-%m-%d %H:%M')

    # Use .get() with the '<step>m' fallback: indexing KLINE_CONFIGS directly
    # raised KeyError for an unsupported step before the loop below could skip it.
    period_labels = [KLINE_CONFIGS.get(p, f'{p}m') for p in periods]

    logger.info(f"{'='*60}")
    logger.info(f"目标时间范围: {start_str} ~ {end_str}")
    logger.info(f"抓取周期: {period_labels}")
    logger.info(f"{'='*60}")

    results = {}
    for step in periods:
        if step not in KLINE_CONFIGS:
            logger.warning(f"不支持的周期: {step}分钟,跳过")
            continue

        logger.info(f"\n{'='*60}")
        logger.info(f"开始抓取 {KLINE_CONFIGS[step]} K线")
        logger.info(f"{'='*60}")

        saved = self.collect_period_range(step, target_start, target_end)
        results[KLINE_CONFIGS[step]] = saved

        time.sleep(1)  # pause between periods

    # Summary.
    logger.info(f"\n{'='*60}")
    logger.info("所有周期抓取完成!统计:")
    for period, count in results.items():
        logger.info(f" {period}: 新增 {count} 条")
    logger.info(f"{'='*60}")

    return results
|
||||
|
||||
def get_stats(self):
    """Log the row count and covered time range of every table this collector fills."""
    minute_fmt = '%Y-%m-%d %H:%M'
    second_fmt = '%Y-%m-%d %H:%M:%S'

    def _local(ms, pattern):
        """Format a millisecond timestamp in local time."""
        return time.strftime(pattern, time.localtime(ms//1000))

    logger.info(f"\n{'='*60}")
    logger.info("数据库统计:")
    logger.info(f"{'='*60}")

    # Per-period K-line tables.
    for step, model in self.models.items():
        suffix = KLINE_CONFIGS.get(step, f'{step}m')
        try:
            count = model.select().count()
            earliest, latest = self.get_db_time_range(step)
            line = f" {suffix:>4}: {count:>8} 条"
            if earliest and latest:
                line += f" | {_local(earliest, minute_fmt)} ~ {_local(latest, minute_fmt)}"
            logger.info(line)
        except Exception as e:
            logger.error(f" {suffix}: 查询失败 - {e}")

    # Raw trade records.
    try:
        trades_count = BitMartETHTrades.select().count()
        line = f"trades: {trades_count:>8} 条"
        if trades_count > 0:
            first_trade = BitMartETHTrades.select(fn.MIN(BitMartETHTrades.timestamp)).scalar()
            last_trade = BitMartETHTrades.select(fn.MAX(BitMartETHTrades.timestamp)).scalar()
            line += f" | {_local(first_trade, second_fmt)} ~ {_local(last_trade, second_fmt)}"
        logger.info(line)
    except Exception as e:
        logger.error(f"trades: 查询失败 - {e}")

    # One-second candles.
    try:
        second_count = BitMartETHSecond.select().count()
        line = f" 1s: {second_count:>8} 条"
        if second_count > 0:
            first_sec = BitMartETHSecond.select(fn.MIN(BitMartETHSecond.id)).scalar()
            last_sec = BitMartETHSecond.select(fn.MAX(BitMartETHSecond.id)).scalar()
            line += f" | {_local(first_sec, second_fmt)} ~ {_local(last_sec, second_fmt)}"
        logger.info(line)
    except Exception as e:
        logger.error(f" 1s: 查询失败 - {e}")

    logger.info(f"{'='*60}")
|
||||
|
||||
# ==================== 秒级数据相关方法 ====================
|
||||
|
||||
def get_trades(self, limit: int = 100):
    """
    Fetch the trades returned by the contract API for the configured symbol.

    :param limit: requested number of records
                  (NOTE(review): currently not passed to the API or applied
                  to the result; confirm whether it should be)
    :return: list of dicts with ``id``, ``timestamp``, ``price``, ``volume``
             and ``side``; an empty list on an API error or any exception
    """
    try:
        reply = self.contractAPI.get_trades(
            contract_symbol=self.contract_symbol,
        )[0]

        if reply['code'] != 1000:
            logger.error(f"获取成交记录失败: {reply}")
            return []

        raw_trades = reply.get('data', {}).get('trades', [])
        # Missing fields fall back to 0 so a partial record still converts.
        return [
            {
                'id': int(item.get('trade_id', 0)),
                'timestamp': int(item.get('create_time', 0)),
                'price': float(item.get('deal_price', 0)),
                'volume': float(item.get('deal_vol', 0)),
                'side': int(item.get('way', 0)),
            }
            for item in raw_trades
        ]
    except Exception as e:
        logger.error(f"获取成交记录异常: {e}")
        return []
|
||||
|
||||
def save_trades(self, trades: list):
    """
    Persist trade records, skipping ids that are already stored.

    Each trade is written with ``BitMartETHTrades.get_or_create`` keyed on
    ``id``; an existing row is left untouched and simply not counted.

    :param trades: list of dicts with keys ``id``, ``timestamp``,
        ``price``, ``volume`` and ``side``.
    :return: number of rows that were newly created.
    """
    new_count = 0
    for trade in trades:
        try:
            _, created = BitMartETHTrades.get_or_create(
                id=trade['id'],
                defaults={
                    'timestamp': trade['timestamp'],
                    'price': trade['price'],
                    'volume': trade['volume'],
                    'side': trade['side'],
                },
            )
        except Exception as e:
            # Duplicates are reported through created=False, so an exception
            # here is a genuine problem (malformed record, database error).
            # Stay best-effort for the remaining trades, but do not hide it.
            logger.warning(f"保存成交记录失败: {e} | {trade}")
            continue
        if created:
            new_count += 1
    return new_count
|
||||
|
||||
def collect_trades_realtime(self, duration_seconds: int = 3600, interval: float = 0.3):
    """
    Poll recent trades in a loop, store them, then aggregate to 1s bars.

    :param duration_seconds: how long to keep polling, in seconds
        (default one hour).
    :param interval: pause between two polls, in seconds (default 0.3).
    :return: total number of newly stored trade rows.
    """
    logger.info(f"{'='*60}")
    logger.info("开始实时采集成交记录")
    logger.info(f"时长: {duration_seconds}秒 ({duration_seconds/3600:.1f}小时)")
    logger.info(f"间隔: {interval}秒")
    logger.info(f"{'='*60}")

    end_time = time.time() + duration_seconds
    total_saved = 0
    batch_count = 0

    while time.time() < end_time:
        trades = self.get_trades(limit=100)
        if trades:
            saved = self.save_trades(trades)
            total_saved += saved
            batch_count += 1

            # Report progress every 10 non-empty batches. This lives inside
            # the ``if trades`` branch because it reads ``saved`` and
            # ``trades[-1]``, which only exist for a non-empty batch.
            if batch_count % 10 == 0:
                remaining = end_time - time.time()
                latest = trades[-1]
                ts_str = datetime.datetime.fromtimestamp(
                    latest['timestamp'] / 1000
                ).strftime('%H:%M:%S')
                logger.info(f"[{ts_str}] 价格: {latest['price']:.2f} | "
                            f"本批新增: {saved} | 累计: {total_saved} | "
                            f"剩余: {remaining/60:.1f}分钟")

        time.sleep(interval)

    logger.success(f"采集完成!共新增 {total_saved} 条成交记录")

    # Rebuild the 1-second bars from everything stored so far.
    logger.info("正在将成交记录聚合为秒级K线...")
    self.aggregate_trades_to_seconds()

    return total_saved
|
||||
|
||||
def aggregate_trades_to_seconds(self, start_ts: int = None, end_ts: int = None):
    """
    Aggregate stored trades into 1-second OHLCV bars and upsert them.

    :param start_ts: inclusive start timestamp in milliseconds; ``None``
        means no lower bound.
    :param end_ts: inclusive end timestamp in milliseconds; ``None`` means
        no upper bound.
    :return: number of 1-second bars written (inserted or updated).
    """
    # Secondary sort on id keeps open/close deterministic when several
    # trades share the same millisecond.
    query = BitMartETHTrades.select().order_by(
        BitMartETHTrades.timestamp, BitMartETHTrades.id
    )
    # Bounds are widened to whole seconds: the upsert below overwrites the
    # stored bar, so a bar must never be rebuilt from part of its second.
    if start_ts is not None:
        query = query.where(BitMartETHTrades.timestamp >= (start_ts // 1000) * 1000)
    if end_ts is not None:
        query = query.where(BitMartETHTrades.timestamp <= (end_ts // 1000) * 1000 + 999)

    second_data = {}
    trade_count = 0

    for trade in query:
        trade_count += 1
        # Bucket key: the trade time floored to a whole second (still in ms).
        second_ts = (trade.timestamp // 1000) * 1000
        bar = second_data.get(second_ts)
        if bar is None:
            second_data[second_ts] = {
                'open': trade.price,
                'high': trade.price,
                'low': trade.price,
                'close': trade.price,
                'volume': trade.volume,
                'trade_count': 1,
            }
        else:
            bar['high'] = max(bar['high'], trade.price)
            bar['low'] = min(bar['low'], trade.price)
            bar['close'] = trade.price
            bar['volume'] += trade.volume
            bar['trade_count'] += 1

    saved_count = 0
    for ts, ohlc in second_data.items():
        try:
            BitMartETHSecond.insert(
                id=ts,
                open=ohlc['open'],
                high=ohlc['high'],
                low=ohlc['low'],
                close=ohlc['close'],
                volume=ohlc['volume'],
                trade_count=ohlc['trade_count'],
            ).on_conflict(
                conflict_target=[BitMartETHSecond.id],
                update={
                    BitMartETHSecond.open: ohlc['open'],
                    BitMartETHSecond.high: ohlc['high'],
                    BitMartETHSecond.low: ohlc['low'],
                    BitMartETHSecond.close: ohlc['close'],
                    BitMartETHSecond.volume: ohlc['volume'],
                    BitMartETHSecond.trade_count: ohlc['trade_count'],
                },
            ).execute()
            saved_count += 1
        except Exception as e:
            # One failing bar must not abort the rest of the batch.
            logger.error(f"保存秒级K线失败 {ts}: {e}")

    logger.success(f"聚合完成!{trade_count} 条成交记录 → {saved_count} 条秒级K线")
    return saved_count
|
||||
|
||||
def get_second_klines(self, start_ts: int = None, end_ts: int = None):
    """
    Read stored 1-second bars ordered by their timestamp id.

    :param start_ts: inclusive start timestamp in milliseconds; ``None``
        means no lower bound.
    :param end_ts: inclusive end timestamp in milliseconds; ``None`` means
        no upper bound.
    :return: list of dicts with keys ``timestamp``, ``open``, ``high``,
        ``low``, ``close``, ``volume`` and ``trade_count``.
    """
    query = BitMartETHSecond.select().order_by(BitMartETHSecond.id)
    # Compare against None so that an explicit bound of 0 is honoured.
    if start_ts is not None:
        query = query.where(BitMartETHSecond.id >= start_ts)
    if end_ts is not None:
        query = query.where(BitMartETHSecond.id <= end_ts)

    return [
        {
            'timestamp': k.id,
            'open': k.open,
            'high': k.high,
            'low': k.low,
            'close': k.close,
            'volume': k.volume,
            'trade_count': k.trade_count,
        }
        for k in query
    ]
|
||||
|
||||
def aggregate_trades_custom(self, interval_ms: int = 100, start_ts: int = None, end_ts: int = None):
    """
    Aggregate stored trades into bars of an arbitrary millisecond width.

    Nothing is written to the database; the bars are only returned.

    :param interval_ms: bar width in milliseconds, e.g. 100, 500 or 1000.
    :param start_ts: inclusive start timestamp in milliseconds; ``None``
        means no lower bound.
    :param end_ts: inclusive end timestamp in milliseconds; ``None`` means
        no upper bound.
    :return: list of dicts sorted by time with keys ``timestamp``,
        ``datetime``, ``open``, ``high``, ``low``, ``close``, ``volume``
        and ``trade_count``.
    :raises ValueError: if ``interval_ms`` is not positive.
    """
    # Reject a non-positive width before querying: 0 would divide by zero
    # and a negative width would produce meaningless buckets.
    if interval_ms <= 0:
        raise ValueError(f"interval_ms must be positive, got {interval_ms}")

    # Secondary sort on id keeps open/close deterministic when several
    # trades share the same millisecond.
    query = BitMartETHTrades.select().order_by(
        BitMartETHTrades.timestamp, BitMartETHTrades.id
    )
    if start_ts is not None:
        query = query.where(BitMartETHTrades.timestamp >= start_ts)
    if end_ts is not None:
        query = query.where(BitMartETHTrades.timestamp <= end_ts)

    interval_data = {}
    trade_count = 0

    for trade in query:
        trade_count += 1
        # Bucket key: the trade time floored to a multiple of interval_ms.
        interval_ts = (trade.timestamp // interval_ms) * interval_ms
        bar = interval_data.get(interval_ts)
        if bar is None:
            interval_data[interval_ts] = {
                'open': trade.price,
                'high': trade.price,
                'low': trade.price,
                'close': trade.price,
                'volume': trade.volume,
                'trade_count': 1,
            }
        else:
            bar['high'] = max(bar['high'], trade.price)
            bar['low'] = min(bar['low'], trade.price)
            bar['close'] = trade.price
            bar['volume'] += trade.volume
            bar['trade_count'] += 1

    result = [
        {
            'timestamp': ts,
            # Local-time string truncated from microseconds to milliseconds.
            'datetime': datetime.datetime.fromtimestamp(ts/1000).strftime('%Y-%m-%d %H:%M:%S.%f')[:-3],
            'open': ohlc['open'],
            'high': ohlc['high'],
            'low': ohlc['low'],
            'close': ohlc['close'],
            'volume': ohlc['volume'],
            'trade_count': ohlc['trade_count'],
        }
        for ts, ohlc in sorted(interval_data.items())
    ]

    logger.info(f"聚合完成: {trade_count} 条成交记录 → {len(result)} 条 {interval_ms}ms K线")
    return result
|
||||
|
||||
def get_raw_trades(self, start_ts: int = None, end_ts: int = None, limit: int = None):
    """
    Read stored tick-by-tick trades ordered by time.

    :param start_ts: inclusive start timestamp in milliseconds; ``None``
        means no lower bound.
    :param end_ts: inclusive end timestamp in milliseconds; ``None`` means
        no upper bound.
    :param limit: maximum number of rows; ``None`` or 0 means unlimited.
    :return: list of dicts with keys ``id``, ``timestamp``, ``datetime``,
        ``price``, ``volume`` and ``side``.
    """
    # Secondary sort on id makes the row order (and therefore the rows cut
    # off by ``limit``) reproducible when timestamps collide.
    query = BitMartETHTrades.select().order_by(
        BitMartETHTrades.timestamp, BitMartETHTrades.id
    )
    # Compare against None so that an explicit bound of 0 is honoured.
    if start_ts is not None:
        query = query.where(BitMartETHTrades.timestamp >= start_ts)
    if end_ts is not None:
        query = query.where(BitMartETHTrades.timestamp <= end_ts)
    # Truthiness is kept on purpose: limit=0 keeps meaning "no limit".
    if limit:
        query = query.limit(limit)

    return [
        {
            'id': t.id,
            'timestamp': t.timestamp,
            # Local-time string truncated from microseconds to milliseconds.
            'datetime': datetime.datetime.fromtimestamp(t.timestamp/1000).strftime('%Y-%m-%d %H:%M:%S.%f')[:-3],
            'price': t.price,
            'volume': t.volume,
            # Any side value other than 1 is rendered as a sell.
            'side': '买' if t.side == 1 else '卖',
        }
        for t in query
    ]
|
||||
|
||||
def close(self):
    """Close the module-level database connection if it is still open."""
    if db.is_closed():
        return
    db.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    collector = BitMartMultiKlineCollector()

    try:
        # Show the current data statistics.
        collector.get_stats()

        # ============ Choose the task(s) to run ============

        # Task 1: fetch K-line data for every period (1 minute to 1 hour),
        # starting from 2010-01-01 (the original note says the fetch can
        # resume from where it stopped).
        all_periods = [1, 3, 5, 15, 30, 60]
        collector.collect_from_date(
            start_date='2010-01-01',
            periods=all_periods,
        )

        # Task 2: collect second-level data (trade records) in real time.
        # Note from the original author: second-level data can only be
        # collected live, history is not available.
        # collector.collect_trades_realtime(
        #     duration_seconds=3600,  # collect for one hour
        #     interval=0.3            # one request every 0.3 seconds
        # )

        # Task 3: aggregate already collected trades into 1-second bars.
        # collector.aggregate_trades_to_seconds()

        # Show the statistics again after the run.
        collector.get_stats()

    finally:
        collector.close()
|
||||
989
训练AI策略_ETH合约.py
Normal file
989
训练AI策略_ETH合约.py
Normal file
@@ -0,0 +1,989 @@
|
||||
"""
|
||||
ETH 合约 AI 策略训练 + 回测
|
||||
|
||||
规则:
|
||||
- 标的: ETH 合约 (BitMart 1分钟K线)
|
||||
- 同一时间仅 1 个仓位 (多或空)
|
||||
- 每笔固定: 100U 保证金 × 100 倍 = 10000U 名义价值
|
||||
- 手续费 90% 返佣,目标: 月均净利 >= 1000 USDT
|
||||
|
||||
流程:
|
||||
1. 训练: 使用 2020 年全年数据训练 LightGBM(特征含多种可调参数指标)
|
||||
2. 回测: 使用 2021 年全年数据做严格样本外真实回测(模型从未见过 2021 数据)
|
||||
3. 参数搜索在「2021 年月均净利」上选最佳并保存
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import json
|
||||
import sqlite3
|
||||
import time as _time
|
||||
from pathlib import Path
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import lightgbm as lgb
|
||||
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
# Constants: ETH contract, 100 USDT margin at 100x leverage
MARGIN_PER_TRADE = 100.0  # USDT
LEVERAGE = 100
NOTIONAL_PER_TRADE = MARGIN_PER_TRADE * LEVERAGE  # 10000 USDT
TAKER_FEE_RATE = 0.0006
REBATE_RATE = 0.90
# Fees: the full fee is charged first, the rebate only arrives at 08:00 the
# next day, so drawdown is measured on "available funds".
# The factor 2 presumably covers opening plus closing a position — confirm.
FULL_FEE_PER_TRADE = NOTIONAL_PER_TRADE * TAKER_FEE_RATE * 2
REBATE_PER_TRADE = FULL_FEE_PER_TRADE * REBATE_RATE
NET_FEE_PER_TRADE = FULL_FEE_PER_TRADE - REBATE_PER_TRADE  # net fee after rebate

# Capital and risk-control targets
INITIAL_BALANCE = 10000.0  # starting capital: 10000 USDT
MAX_DD_TARGET = 500.0  # target: max drawdown of available funds <= 500 USDT

# Training / out-of-sample backtest date ranges (start inclusive, end exclusive)
TRAIN_START = '2020-01-01'
TRAIN_END_EXCLUSIVE = '2021-01-01'  # training data: [2020-01-01, 2021-01-01) = all of 2020
TEST_START = '2021-01-01'
TEST_END_EXCLUSIVE = '2022-01-01'  # backtest data: [2021-01-01, 2022-01-01) = all of 2021
|
||||
|
||||
|
||||
def get_db_path():
    """Return the path ``models/database.db`` located next to this file."""
    module_dir = Path(__file__).parent
    return module_dir.joinpath('models', 'database.db')
|
||||
|
||||
|
||||
# 多周期:与 抓取多周期K线.py 一致,表名 bitmart_eth_{suffix}
|
||||
PERIOD_TABLES = {
|
||||
'1m': 'bitmart_eth_1m',
|
||||
'3m': 'bitmart_eth_3m',
|
||||
'5m': 'bitmart_eth_5m',
|
||||
'15m': 'bitmart_eth_15m',
|
||||
'30m': 'bitmart_eth_30m',
|
||||
'1h': 'bitmart_eth_1h',
|
||||
}
|
||||
|
||||
|
||||
def _normalize_period(period) -> str:
|
||||
"""5 -> '5m', 15 -> '15m', '5' -> '5m', '15m' -> '15m'"""
|
||||
if isinstance(period, int):
|
||||
return f'{period}m' if period != 60 else '1h'
|
||||
s = str(period).strip().lower()
|
||||
if s in PERIOD_TABLES:
|
||||
return s
|
||||
if s.endswith('m'):
|
||||
return s
|
||||
if s == '60' or s == '1h':
|
||||
return '1h'
|
||||
return f'{s}m' if s.isdigit() else s
|
||||
|
||||
|
||||
# ==================== 数据加载 ====================
|
||||
def load_klines(start_date: str = '2025-01-01', end_date: str = '2026-02-01', period: str = '1m'):
    """
    Load K-lines of one period from the SQLite database.

    :param start_date: inclusive start date, formatted ``YYYY-MM-DD``.
    :param end_date: exclusive end date, formatted ``YYYY-MM-DD``.
    :param period: '1m'|'3m'|'5m'|'15m'|'30m'|'1h' or 1|3|5|15|30|60.
    :return: DataFrame with columns ``ts, open, high, low, close`` indexed
        by a ``datetime`` column derived from ``ts`` (milliseconds).
    :raises ValueError: if the period has no table in ``PERIOD_TABLES``.
    :raises FileNotFoundError: if the database file is missing, or if the
        table holds no rows in the requested range.
    """
    period = _normalize_period(period)
    table = PERIOD_TABLES.get(period)
    if not table:
        raise ValueError(f"不支持的周期: {period},可选: {list(PERIOD_TABLES.keys())}")
    db = get_db_path()
    if not db.exists():
        raise FileNotFoundError(f"数据库不存在: {db},请先运行 抓取多周期K线.py 拉取数据")

    # NOTE(review): strptime() yields naive datetimes and .timestamp()
    # interprets those in the machine's local zone, while the index below is
    # built by pd.to_datetime(unit='ms') without a zone. On a non-UTC host
    # the range boundaries and the index labels are offset — confirm intent.
    start_ts = int(datetime.datetime.strptime(start_date, '%Y-%m-%d').timestamp()) * 1000
    end_ts = int(datetime.datetime.strptime(end_date, '%Y-%m-%d').timestamp()) * 1000

    conn = sqlite3.connect(str(db))
    try:
        # The table name comes from the PERIOD_TABLES whitelist, never from
        # raw input; the range values are bound as parameters.
        df = pd.read_sql_query(
            f"SELECT id as ts, open, high, low, close FROM {table} "
            "WHERE id >= ? AND id < ? ORDER BY id",
            conn, params=(start_ts, end_ts))
    finally:
        # Release the connection even when the query raises.
        conn.close()

    if len(df) == 0:
        raise FileNotFoundError(f"表 {table} 中无 {start_date}~{end_date} 数据,请先抓取该周期 K 线")
    df['datetime'] = pd.to_datetime(df['ts'], unit='ms')
    df.set_index('datetime', inplace=True)
    return df
|
||||
|
||||
|
||||
# ==================== 指标参数(类布林带/均值回归,均可训练优化) ====================
|
||||
def default_indicator_params():
    """
    Return a fresh dict of default indicator parameters.

    Covers Bollinger, Keltner, Donchian, Stochastic, CCI, Williams %R, RSI,
    ATR band, Z-score, linear-regression channel, median band, percent
    rank, Chandelier, standard-deviation band and Elder ray settings.
    """
    return dict(
        bb_period=20,
        bb_std=2.0,
        keltner_period=20,
        keltner_atr_mult=2.0,
        donchian_period=20,
        stoch_k_period=14,
        stoch_d_period=3,
        cci_period=20,
        willr_period=14,
        rsi_period=14,
        atr_band_period=20,
        atr_band_mult=2.0,
        # Mean-reversion additions
        zscore_period=20,
        lr_period=20,
        lr_std_mult=2.0,
        median_band_period=20,
        pct_rank_period=20,
        chandelier_period=22,
        chandelier_mult=3.0,
        std_band_period=14,
        std_band_mult=2.0,
        elder_period=13,
    )
|
||||
|
||||
|
||||
# ==================== 特征工程(所有带带/通道类指标参数可调) ====================
|
||||
def add_features(df: pd.DataFrame, ind: dict = None) -> pd.DataFrame:
    """
    Add indicator and derived feature columns to ``df`` and return it.

    The frame is modified in place (callers that need the original should
    pass a copy). It must contain ``open``, ``high``, ``low`` and ``close``
    columns and have a datetime index, because the hour of the index is
    used for the time features.

    :param ind: indicator parameter dict with the keys produced by
        ``default_indicator_params()``; ``None`` selects those defaults.
    :return: the same DataFrame object with the feature columns appended.
    """
    if ind is None:
        ind = default_indicator_params()
    c = df['close']
    h = df['high']
    l = df['low']
    o = df['open']
    # Close with zeros masked out, used as a safe divisor for ratios.
    cp = c.replace(0, np.nan)

    # --- Bollinger bands ---
    bp = ind['bb_period']
    bstd = ind['bb_std']
    mid = c.rolling(bp).mean()
    std = c.rolling(bp).std()
    df['bb_upper'] = mid + bstd * std
    df['bb_lower'] = mid - bstd * std
    df['bb_mid'] = mid
    df['bb_pct'] = (c - df['bb_lower']) / (df['bb_upper'] - df['bb_lower']).replace(0, np.nan)
    df['bb_width'] = (df['bb_upper'] - df['bb_lower']) / mid.replace(0, np.nan)

    # --- Keltner channel ---
    kp = ind['keltner_period']
    katr = ind['keltner_atr_mult']
    # True range: the largest of high-low and the gaps to the previous close.
    tr = pd.concat([h - l, (h - c.shift(1)).abs(), (l - c.shift(1)).abs()], axis=1).max(axis=1)
    atr_k = tr.rolling(kp).mean()
    k_mid = c.ewm(span=kp, adjust=False).mean()
    df['keltner_upper'] = k_mid + katr * atr_k
    df['keltner_lower'] = k_mid - katr * atr_k
    df['keltner_mid'] = k_mid
    df['keltner_pct'] = (c - df['keltner_lower']) / (df['keltner_upper'] - df['keltner_lower']).replace(0, np.nan)
    df['keltner_width'] = (df['keltner_upper'] - df['keltner_lower']) / k_mid.replace(0, np.nan)

    # --- Donchian channel ---
    dp = ind['donchian_period']
    du = h.rolling(dp).max()
    dd = l.rolling(dp).min()
    dm = (du + dd) / 2
    df['donchian_upper'] = du
    df['donchian_lower'] = dd
    df['donchian_mid'] = dm
    df['donchian_pct'] = (c - dd) / (du - dd).replace(0, np.nan)
    df['donchian_width'] = (du - dd) / dm.replace(0, np.nan)

    # --- Stochastic oscillator ---
    sk, sd = ind['stoch_k_period'], ind['stoch_d_period']
    low_k = l.rolling(sk).min()
    high_k = h.rolling(sk).max()
    df['stoch_k'] = (c - low_k) / (high_k - low_k).replace(0, np.nan) * 100
    df['stoch_d'] = df['stoch_k'].rolling(sd).mean()

    # --- CCI ---
    cci_p = ind['cci_period']
    typical = (h + l + c) / 3
    cci_m = typical.rolling(cci_p).mean()
    cci_s = typical.rolling(cci_p).std()
    df['cci'] = (typical - cci_m) / (0.015 * cci_s.replace(0, np.nan))

    # --- Williams %R ---
    wp = ind['willr_period']
    high_w = h.rolling(wp).max()
    low_w = l.rolling(wp).min()
    df['willr'] = -100 * (high_w - c) / (high_w - low_w).replace(0, np.nan)

    # --- RSI (simple rolling means of gains and losses) ---
    rp = ind['rsi_period']
    delta = c.diff()
    gain = delta.clip(lower=0)
    loss = (-delta).clip(lower=0)
    avg_gain = gain.rolling(rp).mean()
    avg_loss = loss.rolling(rp).mean()
    # A window without any loss yields NaN here rather than RSI 100.
    rs = avg_gain / avg_loss.replace(0, np.nan)
    df['rsi'] = 100 - 100 / (1 + rs)

    # --- ATR band (SMA mid line, width is a multiple of ATR) ---
    abp, abm = ind['atr_band_period'], ind['atr_band_mult']
    atr_ab = tr.rolling(abp).mean()
    ab_mid = c.rolling(abp).mean()
    df['atr_band_upper'] = ab_mid + abm * atr_ab
    df['atr_band_lower'] = ab_mid - abm * atr_ab
    df['atr_band_mid'] = ab_mid
    df['atr_band_pct'] = (c - df['atr_band_lower']) / (df['atr_band_upper'] - df['atr_band_lower']).replace(0, np.nan)
    df['atr_band_width'] = (df['atr_band_upper'] - df['atr_band_lower']) / ab_mid.replace(0, np.nan)

    # --- Z-score (distance from the rolling mean in standard deviations) ---
    zp = ind['zscore_period']
    z_mid = c.rolling(zp).mean()
    z_std = c.rolling(zp).std()
    df['zscore'] = (c - z_mid) / z_std.replace(0, np.nan)
    df['zscore_abs'] = df['zscore'].abs()

    # --- Linear-regression channel, vectorised least squares ---
    lrp = ind['lr_period']
    lr_mult = ind['lr_std_mult']
    # x runs 0..lrp-1 from the oldest to the newest bar of the window.
    sum_x = lrp * (lrp - 1) / 2.0
    sum_x2 = lrp * (lrp - 1) * (2 * lrp - 1) / 6.0
    # shift(j) is the bar j steps back, whose x is (lrp - 1 - j). Weighting
    # by j instead regresses against the lag, which flips the slope and
    # makes the "current" fitted value the one at the oldest bar.
    s_xy = sum(c.shift(j) * (lrp - 1 - j) for j in range(lrp))
    s_y = c.rolling(lrp).sum()
    denom = lrp * sum_x2 - sum_x * sum_x
    # denom is 0 only for lrp == 1; the fallback avoids dividing by zero.
    slope = (lrp * s_xy - sum_x * s_y) / (denom or 1e-10)
    intercept = s_y / lrp - slope * (lrp - 1) / 2.0
    # Fitted value at the newest bar (x = lrp - 1).
    lr_mid = intercept + slope * (lrp - 1)
    # Channel width uses the rolling price std, not the residual std.
    lr_resid = c.rolling(lrp).std()
    df['lr_upper'] = lr_mid + lr_mult * lr_resid
    df['lr_lower'] = lr_mid - lr_mult * lr_resid
    df['lr_mid'] = lr_mid
    df['lr_pct'] = (c - df['lr_lower']) / (df['lr_upper'] - df['lr_lower']).replace(0, np.nan)
    df['lr_width'] = (df['lr_upper'] - df['lr_lower']) / lr_mid.replace(0, np.nan)

    # --- Median band (price versus rolling median, fixed 2-std width) ---
    mdp = ind['median_band_period']
    med = c.rolling(mdp).median()
    mstd = c.rolling(mdp).std()
    df['median_band_mid'] = med
    df['price_vs_median'] = (c - med) / mstd.replace(0, np.nan)
    df['median_band_upper'] = med + 2 * mstd
    df['median_band_lower'] = med - 2 * mstd
    df['median_band_pct'] = (c - df['median_band_lower']) / (df['median_band_upper'] - df['median_band_lower']).replace(0, np.nan)

    # --- Percent rank: position of the close inside its rolling range ---
    prp = ind['pct_rank_period']
    rmin = c.rolling(prp).min()
    rmax = c.rolling(prp).max()
    df['pct_rank'] = (c - rmin) / (rmax - rmin).replace(0, np.nan)

    # --- Chandelier exit channel ---
    cep, cem = ind['chandelier_period'], ind['chandelier_mult']
    atr_ce = tr.rolling(cep).mean()
    high_max = h.rolling(cep).max()
    low_min = l.rolling(cep).min()
    df['chandelier_upper'] = high_max - cem * atr_ce
    df['chandelier_lower'] = low_min + cem * atr_ce
    df['chandelier_mid'] = (df['chandelier_upper'] + df['chandelier_lower']) / 2
    df['chandelier_pct'] = (c - df['chandelier_lower']) / (df['chandelier_upper'] - df['chandelier_lower']).replace(0, np.nan)
    df['chandelier_width'] = (df['chandelier_upper'] - df['chandelier_lower']) / df['chandelier_mid'].replace(0, np.nan)

    # --- Standard-deviation band (own period/multiplier, separate from BB) ---
    sbp, sbm = ind['std_band_period'], ind['std_band_mult']
    sb_mid = c.rolling(sbp).mean()
    sb_std = c.rolling(sbp).std()
    df['std_band_upper'] = sb_mid + sbm * sb_std
    df['std_band_lower'] = sb_mid - sbm * sb_std
    df['std_band_mid'] = sb_mid
    df['std_band_pct'] = (c - df['std_band_lower']) / (df['std_band_upper'] - df['std_band_lower']).replace(0, np.nan)
    df['std_band_width'] = (df['std_band_upper'] - df['std_band_lower']) / sb_mid.replace(0, np.nan)

    # --- Elder ray (distance of high/low/close from an EMA) ---
    ep = ind['elder_period']
    ema_elder = c.ewm(span=ep, adjust=False).mean()
    df['elder_bull'] = (h - ema_elder) / cp
    df['elder_bear'] = (l - ema_elder) / cp
    df['elder_dist'] = (c - ema_elder) / cp

    # --- EMAs with fixed spans, used as helpers ---
    for p in [5, 8, 13, 21, 50, 120]:
        df[f'ema_{p}'] = c.ewm(span=p, adjust=False).mean()
    df['ema_fast_slow'] = (df['ema_8'] - df['ema_21']) / cp
    df['price_vs_ema120'] = (c - df['ema_120']) / cp
    df['ema8_slope'] = df['ema_8'].pct_change(5)

    # --- MACD, normalised by price ---
    ema12 = c.ewm(span=12, adjust=False).mean()
    ema26 = c.ewm(span=26, adjust=False).mean()
    df['macd'] = (ema12 - ema26) / cp
    df['macd_signal'] = df['macd'].ewm(span=9, adjust=False).mean()
    df['macd_hist'] = df['macd'] - df['macd_signal']

    # --- ATR as a fraction of price ---
    df['atr_14'] = tr.rolling(14).mean()
    df['atr_pct'] = df['atr_14'] / cp

    # --- Momentum and volatility ---
    for p in [1, 3, 5, 10, 20]:
        df[f'ret_{p}'] = c.pct_change(p)
    df['vol_5'] = c.pct_change().rolling(5).std()
    df['vol_20'] = c.pct_change().rolling(20).std()

    # --- Candle shape and time of day ---
    body = (c - o).abs()
    df['body_pct'] = body / cp
    df['price_position_20'] = (c - l.rolling(20).min()) / (h.rolling(20).max() - l.rolling(20).min()).replace(0, np.nan)
    # The hour is encoded on a circle so that 23:00 and 00:00 stay close.
    df['hour_sin'] = np.sin(2 * np.pi * df.index.hour / 24)
    df['hour_cos'] = np.cos(2 * np.pi * df.index.hour / 24)

    return df
|
||||
|
||||
|
||||
def get_feature_cols(df: pd.DataFrame):
    """
    Return the names of the derived numeric feature columns, in frame order.

    Raw prices, the label and the raw channel/EMA level columns are left
    out; only float64/float32/int64/int32 columns are kept.
    """
    raw_columns = {
        'ts', 'open', 'high', 'low', 'close', 'label',
        'bb_upper', 'bb_lower', 'bb_mid', 'atr_14',
        'keltner_upper', 'keltner_lower', 'keltner_mid',
        'donchian_upper', 'donchian_lower', 'donchian_mid',
        'atr_band_upper', 'atr_band_lower', 'atr_band_mid',
        'lr_upper', 'lr_lower', 'lr_mid',
        'median_band_upper', 'median_band_lower', 'median_band_mid',
        'chandelier_upper', 'chandelier_lower', 'chandelier_mid',
        'std_band_upper', 'std_band_lower', 'std_band_mid',
        'ema_5', 'ema_8', 'ema_13', 'ema_21', 'ema_50', 'ema_120',
    }
    numeric_dtypes = ('float64', 'float32', 'int64', 'int32')
    selected = []
    for name in df.columns:
        if name in raw_columns:
            continue
        if df[name].dtype in numeric_dtypes:
            selected.append(name)
    return selected
|
||||
|
||||
|
||||
# ==================== 标签 ====================
|
||||
def add_labels(df: pd.DataFrame, forward_bars: int = 10, threshold: float = 0.002) -> pd.DataFrame:
    """
    Return a copy of ``df`` with a ``label`` column from the forward return.

    The forward return is ``close[t + forward_bars] / close[t] - 1``:
    above ``threshold`` gives 1 (long), below ``-threshold`` gives -1
    (short), anything else gives 0. Rows whose forward return is unknown
    (the last ``forward_bars`` rows, or rows with a missing close) get NaN
    instead of a made-up neutral label, so a ``dropna`` on ``label``
    removes them before training.

    :param forward_bars: how many bars ahead the return is measured.
    :param threshold: absolute return needed for a long/short label.
    :return: new DataFrame; the input frame is not modified.
    """
    close = df['close']
    future_ret = close.shift(-forward_bars) / close - 1
    labels = np.select(
        [future_ret > threshold, future_ret < -threshold],
        [1.0, -1.0],
        default=0.0,
    )
    out = df.copy()
    out['label'] = labels
    # Comparisons with NaN are False, which would silently yield label 0.
    out.loc[future_ret.isna(), 'label'] = np.nan
    return out
|
||||
|
||||
|
||||
# ==================== 滚动训练 ====================
|
||||
def train_predict_walkforward(
    df: pd.DataFrame,
    feature_cols: list,
    train_months: int = 3,
    lgb_rounds: int = 250,
):
    """
    Walk forward month by month: fit on the previous ``train_months``
    calendar months, then predict the following month.

    :param df: frame with a datetime index, the feature columns and a
        ``label`` column holding -1/0/1.
    :param feature_cols: names of the model input columns.
    :param train_months: number of preceding months used for each fit.
    :param lgb_rounds: boosting rounds per fit.
    :return: ``(proba_long, proba_short, last_model)``; both series are
        aligned with ``df.index`` and stay 0.0 where no prediction was
        made; ``last_model`` is ``None`` when no month could be trained.
    """
    frame = df.copy()
    frame['month'] = frame.index.to_period('M')
    month_list = sorted(frame['month'].unique())
    proba_long_all = pd.Series(0.0, index=frame.index)
    proba_short_all = pd.Series(0.0, index=frame.index)
    newest_model = None

    for pos in range(train_months, len(month_list)):
        target_month = month_list[pos]
        window_start = month_list[pos - train_months]
        in_window = (frame['month'] >= window_start) & (frame['month'] < target_month)
        fit_rows = frame[in_window].dropna(subset=feature_cols + ['label'])
        score_rows = frame[frame['month'] == target_month].dropna(subset=feature_cols)
        # Skip months where either side is too small to be meaningful.
        if len(fit_rows) < 1000 or len(score_rows) < 100:
            continue

        fit_x = fit_rows[feature_cols].values
        # Shift labels -1/0/1 to class ids 0/1/2.
        fit_y = (fit_rows['label'].values + 1).astype(int)
        score_x = score_rows[feature_cols].values
        # NOTE(review): LightGBM documents that bagging only takes effect
        # with a non-zero bagging frequency; none is set here, so
        # 'subsample' may have no effect — confirm.
        lgb_params = {
            'objective': 'multiclass',
            'num_class': 3,
            'metric': 'multi_logloss',
            'learning_rate': 0.05,
            'num_leaves': 31,
            'max_depth': 6,
            'min_child_samples': 50,
            'subsample': 0.8,
            'colsample_bytree': 0.8,
            'reg_alpha': 0.1,
            'reg_lambda': 0.1,
            'verbose': -1,
            'n_jobs': -1,
            'seed': 42,
        }
        booster = lgb.train(
            lgb_params,
            lgb.Dataset(fit_x, label=fit_y),
            num_boost_round=lgb_rounds,
        )
        newest_model = booster
        # Column order follows the class ids: [P(short), P(neutral), P(long)].
        class_proba = booster.predict(score_x)
        scored_index = score_rows.index
        proba_short_all.loc[scored_index] = class_proba[:, 0]
        proba_long_all.loc[scored_index] = class_proba[:, 2]

    return proba_long_all, proba_short_all, newest_model
|
||||
|
||||
|
||||
def train_on_period_predict_on_other(
    df_train: pd.DataFrame,
    df_test: pd.DataFrame,
    feature_cols: list,
    forward_bars: int,
    label_threshold: float,
    lgb_rounds: int = 250,
):
    """
    Fit one model on ``df_train`` and predict on ``df_test``.

    :param df_train: frame with the feature columns and a ``label`` column.
    :param df_test: frame with the feature columns only.
    :param feature_cols: names of the model input columns.
    :param forward_bars: accepted for signature compatibility; not read.
    :param label_threshold: accepted for signature compatibility; not read.
    :param lgb_rounds: boosting rounds.
    :return: ``(proba_long, proba_short, model)`` with both series aligned
        to ``df_test.index`` (0.0 where a row had missing features), or
        ``(None, None, None)`` when fewer than 2000 complete training rows
        or fewer than 100 complete test rows are available.
    """
    fit_rows = df_train.dropna(subset=feature_cols + ['label'])
    if len(fit_rows) < 2000:
        return None, None, None
    fit_x = fit_rows[feature_cols].values
    # Shift labels -1/0/1 to class ids 0/1/2.
    fit_y = (fit_rows['label'].values + 1).astype(int)

    score_rows = df_test.dropna(subset=feature_cols)
    if len(score_rows) < 100:
        return None, None, None
    score_x = score_rows[feature_cols].values

    # NOTE(review): LightGBM documents that bagging only takes effect with a
    # non-zero bagging frequency; none is set here, so 'subsample' may have
    # no effect — confirm.
    lgb_params = {
        'objective': 'multiclass',
        'num_class': 3,
        'metric': 'multi_logloss',
        'learning_rate': 0.05,
        'num_leaves': 31,
        'max_depth': 6,
        'min_child_samples': 50,
        'subsample': 0.8,
        'colsample_bytree': 0.8,
        'reg_alpha': 0.1,
        'reg_lambda': 0.1,
        'verbose': -1,
        'n_jobs': -1,
        'seed': 42,
    }
    model = lgb.train(
        lgb_params,
        lgb.Dataset(fit_x, label=fit_y),
        num_boost_round=lgb_rounds,
    )
    # Column order follows the class ids: [P(short), P(neutral), P(long)].
    class_proba = model.predict(score_x)

    proba_long_all = pd.Series(0.0, index=df_test.index)
    proba_short_all = pd.Series(0.0, index=df_test.index)
    proba_short_all.loc[score_rows.index] = class_proba[:, 0]
    proba_long_all.loc[score_rows.index] = class_proba[:, 2]
    return proba_long_all, proba_short_all, model
|
||||
|
||||
|
||||
# ==================== 回测(单仓位、100U×100倍) ====================
|
||||
def backtest(
    df: pd.DataFrame,
    proba_long: pd.Series,
    proba_short: pd.Series,
    notional: float = NOTIONAL_PER_TRADE,
    prob_threshold: float = 0.45,
    min_hold_seconds: int = 180,
    max_hold_seconds: int = 1800,
    sl_pct: float = 0.004,
    tp_pct: float = 0.006,
) -> list:
    """
    Simulate a single-position strategy on close prices.

    At most one position (long or short) is open at a time and every
    position has the same notional size. Per bar, an open position is
    checked in this order: hard stop (loss >= 1.5 x ``sl_pct``, at any
    time); then, once ``min_hold_seconds`` have passed: stop loss, take
    profit, timeout at ``max_hold_seconds``, and finally an opposite model
    signal above ``prob_threshold + 0.05``. After an opposite-signal exit a
    new position may open on the same bar; after any other exit the bar is
    finished. A position still open on the last bar is closed there.

    :param df: frame with a datetime index and a ``close`` column.
    :param proba_long: long probabilities, positionally aligned with ``df``.
    :param proba_short: short probabilities, positionally aligned with ``df``.
    :param notional: notional value of every position in USDT.
    :return: list of tuples ``(side, open_price, close_price, pnl_usdt,
        hold_seconds, reason, open_time, close_time)``; fees are not
        deducted here.
    """
    # Pull the columns out once: positional .iloc lookups inside the loop
    # are far slower than array indexing and return the same values.
    times = df.index
    closes = df['close'].to_numpy()
    p_long = proba_long.to_numpy()
    p_short = proba_short.to_numpy()
    hard_sl_pct = sl_pct * 1.5
    reverse_threshold = prob_threshold + 0.05

    pos = 0
    open_price = 0.0
    open_time = None
    trades = []

    for i in range(len(df)):
        dt = times[i]
        price = closes[i]
        pl = p_long[i]
        ps = p_short[i]

        if pos != 0 and open_time is not None:
            pnl_pct = (price - open_price) / open_price if pos == 1 else (open_price - price) / open_price
            hold_sec = (dt - open_time).total_seconds()

            reason = None
            if -pnl_pct >= hard_sl_pct:
                reason = '硬止损'
            elif hold_sec >= min_hold_seconds:
                if -pnl_pct >= sl_pct:
                    reason = '止损'
                elif pnl_pct >= tp_pct:
                    reason = '止盈'
                elif hold_sec >= max_hold_seconds:
                    reason = '超时'
                elif pos == 1 and ps > reverse_threshold:
                    reason = 'AI反转'
                elif pos == -1 and pl > reverse_threshold:
                    reason = 'AI反转'

            if reason is not None:
                trades.append((pos, open_price, price, notional * pnl_pct, hold_sec, reason, open_time, dt))
                pos = 0
                # Only a model reversal may flip into a new position on the
                # same bar; every other exit waits for the next bar.
                if reason != 'AI反转':
                    continue

        if pos == 0:
            if pl > prob_threshold and pl > ps:
                pos = 1
                open_price = price
                open_time = dt
            elif ps > prob_threshold and ps > pl:
                pos = -1
                open_price = price
                open_time = dt

    # Close whatever is still open at the end of the data.
    if pos != 0:
        price = closes[-1]
        dt = times[-1]
        pnl_pct = (price - open_price) / open_price if pos == 1 else (open_price - price) / open_price
        hold_sec = (dt - open_time).total_seconds()
        trades.append((pos, open_price, price, notional * pnl_pct, hold_sec, '结束', open_time, dt))

    return trades
|
||||
|
||||
|
||||
# ==================== 结果分析 ====================
|
||||
def analyze_trades(trades: list, notional: float = NOTIONAL_PER_TRADE, initial_balance: float = INITIAL_BALANCE) -> dict:
    """
    Summarise a trade list: net profit, win rate, drawdowns, monthly net.

    Each trade is a tuple whose index 3 is the gross PnL in USDT and whose
    index 7 is the close time. Two drawdowns are reported:

    * ``dd`` — on the cumulative PnL with the net fee charged per trade;
    * ``dd_available`` — on available funds: the full fee is charged at
      close and the rebate is credited one day later (a simplification of
      "next day 08:00"), which makes it the more conservative figure.

    The win rate counts trades with gross PnL > 0 (before fees). The
    monthly average divides by the number of months that contain at least
    one closed trade.

    :param trades: list of trade tuples.
    :param notional: accepted for signature compatibility; not read.
    :param initial_balance: starting balance for the available-funds curve.
    :return: dict with keys ``n, net, wr, dd, dd_available, total_pnl,
        avg_pnl, monthly_net, months, monthly_detail``; the same keys are
        present (zeroed) for an empty trade list.
    """
    if not trades:
        # Same key set as the regular result, so callers can read any field.
        return {
            'n': 0, 'net': 0.0, 'wr': 0.0, 'dd': 0.0, 'dd_available': 0.0,
            'total_pnl': 0.0, 'avg_pnl': 0.0, 'monthly_net': 0.0, 'months': 0,
            'monthly_detail': {},
        }

    n = len(trades)
    total_pnl = sum(t[3] for t in trades)
    net = total_pnl - NET_FEE_PER_TRADE * n
    wins = sum(1 for t in trades if t[3] > 0)
    wr = wins / n * 100

    cum = 0.0
    peak = 0.0
    dd = 0.0
    events = []  # (datetime, balance change)
    monthly_net = defaultdict(float)
    for t in trades:
        pnl = t[3]
        close_time = t[7]

        # Simple drawdown on the net-fee equity curve.
        cum += pnl - NET_FEE_PER_TRADE
        if cum > peak:
            peak = cum
        if peak - cum > dd:
            dd = peak - cum

        # Cash-flow events for the available-funds curve.
        events.append((close_time, pnl - FULL_FEE_PER_TRADE))
        try:
            next_day = close_time + datetime.timedelta(days=1)
        except Exception:
            # Best effort: a close time that cannot be shifted gets its
            # rebate credited at the close time itself.
            next_day = close_time
        events.append((next_day, REBATE_PER_TRADE))

        # Net profit grouped by the month of the close time.
        month_key = close_time.strftime('%Y-%m') if hasattr(close_time, 'strftime') else str(close_time)[:7]
        monthly_net[month_key] += pnl - NET_FEE_PER_TRADE

    events.sort(key=lambda x: x[0])
    balance = initial_balance
    peak_bal = balance
    dd_available = 0.0
    for _, delta in events:
        balance += delta
        if balance > peak_bal:
            peak_bal = balance
        if peak_bal - balance > dd_available:
            dd_available = peak_bal - balance

    num_months = len(monthly_net) or 1
    avg_monthly_net = sum(monthly_net.values()) / num_months
    return {
        'n': n,
        'net': net,
        'wr': wr,
        'dd': dd,
        'dd_available': dd_available,
        'total_pnl': total_pnl,
        'avg_pnl': net / n,
        'monthly_net': avg_monthly_net,
        'months': num_months,
        'monthly_detail': dict(monthly_net),
    }
|
||||
|
||||
|
||||
TARGET_MONTHLY_NET = 1000.0  # target average monthly net profit (USDT)


def print_report(trades: list, label: str = ""):
    """
    Print a backtest summary for ``trades`` under the heading ``label``.

    Shows capital, trade count, total net profit, win rate, the maximum
    drawdown of available funds against ``MAX_DD_TARGET``, the average
    monthly net profit against ``TARGET_MONTHLY_NET``, and how often each
    close reason occurred. Prints a single "no trades" line for an empty
    list.
    """
    if not trades:
        print(f" [{label}] 无交易", flush=True)
        return

    r = analyze_trades(trades)

    # Count close reasons (tuple index 5), keeping first-seen order.
    reasons = {}
    for t in trades:
        reasons[t[5]] = reasons.get(t[5], 0) + 1

    dd_ok = r['dd_available'] <= MAX_DD_TARGET
    monthly_ok = r['monthly_net'] >= TARGET_MONTHLY_NET

    report_lines = (
        f"\n === {label} ===",
        f" 本金: {INITIAL_BALANCE:.0f} USDT | 交易: {r['n']} 笔 | 总净利: {r['net']:+.2f} USDT | 胜率: {r['wr']:.1f}%",
        f" 可用资金最大回撤: {r['dd_available']:.2f} USDT (返佣次日到账) | 目标 ≤{MAX_DD_TARGET:.0f}U: {'达标' if dd_ok else '未达标'}",
        f" 月均净利: {r['monthly_net']:+.2f} USDT (共 {r['months']} 个月) | 目标 ≥{TARGET_MONTHLY_NET:.0f}U/月: {'达标' if monthly_ok else '未达标'}",
        f" 平仓原因: {dict(reasons)}",
    )
    for line in report_lines:
        print(line, flush=True)
|
||||
|
||||
|
||||
# ==================== 模型保存 ====================
|
||||
def save_model_and_params(model, feature_cols: list, params: dict, path: Path = None):
    """Write the model, its feature column list and the strategy parameters to disk.

    Args:
        model: Object exposing ``save_model(str_path)``.
        feature_cols: Feature column names, dumped to ``feature_cols.json``.
        params: Strategy parameters, dumped to ``strategy_params.json``.
        path: Target directory; defaults to ``models/eth_ai_strategy`` next
            to this file. Created (with parents) when missing.
    """
    if path is None:
        out_dir = Path(__file__).parent / 'models' / 'eth_ai_strategy'
    else:
        out_dir = Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)

    model.save_model(str(out_dir / 'model.txt'))

    json_payloads = (
        ('feature_cols.json', feature_cols),
        ('strategy_params.json', params),
    )
    for file_name, payload in json_payloads:
        with open(out_dir / file_name, 'w', encoding='utf-8') as fh:
            # ensure_ascii=False keeps non-ASCII text readable in the file.
            json.dump(payload, fh, ensure_ascii=False, indent=2)

    print(f" 模型与参数已保存至: {out_dir}", flush=True)
|
||||
|
||||
|
||||
# ==================== 主流程:2020 训练 + 2021 真实回测 ====================
|
||||
def run_single(
    df_train: pd.DataFrame,
    df_test: pd.DataFrame,
    ind_params: dict,
    forward_bars: int,
    label_threshold: float,
    prob_threshold: float,
    sl_pct: float,
    tp_pct: float,
    min_hold: int,
    max_hold: int,
):
    """Fit on ``df_train``, predict on ``df_test`` and backtest the predictions.

    Both frames are copied and passed through ``add_features`` with
    ``ind_params``; only the training frame is labelled. The file uses 2020
    data for training and 2021 data for the out-of-sample test.

    Returns:
        ``(trades, model, feature_cols)``, or ``([], None, None)`` when
        ``train_on_period_predict_on_other`` yields no model.
    """
    train_frame = add_features(df_train.copy(), ind=ind_params)
    test_frame = add_features(df_test.copy(), ind=ind_params)

    # Feature columns are taken before the label columns are added.
    feature_cols = get_feature_cols(train_frame)
    train_frame = add_labels(train_frame, forward_bars=forward_bars, threshold=label_threshold)

    proba_long, proba_short, model = train_on_period_predict_on_other(
        train_frame, test_frame, feature_cols, forward_bars, label_threshold,
    )
    if model is not None:
        backtest_kwargs = {
            'notional': NOTIONAL_PER_TRADE,
            'prob_threshold': prob_threshold,
            'min_hold_seconds': min_hold,
            'max_hold_seconds': max_hold,
            'sl_pct': sl_pct,
            'tp_pct': tp_pct,
        }
        trades = backtest(test_frame, proba_long, proba_short, **backtest_kwargs)
        return trades, model, feature_cols

    return [], None, None
|
||||
|
||||
|
||||
def _grid_indicator_and_strategy(max_configs: int = 220, seed: int = 42):
    """Randomly sample ``max_configs`` joint indicator + strategy configurations.

    Each configuration starts from ``default_indicator_params()`` and
    overrides one option per indicator family, then draws a label setting,
    a stop-loss/take-profit pair and a probability threshold. Sampling uses
    ``np.random.default_rng(seed)``, so a given seed yields the same list.

    Returns:
        A list of dicts with keys ``ind``, ``forward_bars``,
        ``label_threshold``, ``sl_pct``, ``tp_pct``, ``prob_threshold``,
        ``min_hold`` and ``max_hold``.
    """
    rng = np.random.default_rng(seed)
    ind_base = default_indicator_params()

    def draw(options):
        # One RNG call per draw; the draw order below must stay fixed so a
        # given seed keeps producing the same configurations.
        return options[rng.integers(0, len(options))]

    # (parameter names, candidate value tuples), listed in draw order.
    indicator_space = [
        # Bollinger-style bands and classic oscillators
        (('bb_period', 'bb_std'),
         [(20, 2.0), (20, 2.5), (30, 2.0), (15, 2.0), (14, 2.0)]),
        (('keltner_period', 'keltner_atr_mult'),
         [(20, 2.0), (20, 2.5), (14, 2.0), (22, 2.5)]),
        (('donchian_period',), [(14,), (20,), (30,)]),
        (('stoch_k_period', 'stoch_d_period'), [(14, 3), (10, 3), (20, 5)]),
        (('cci_period',), [(14,), (20,)]),
        (('rsi_period',), [(7,), (14,), (21,)]),
        (('atr_band_period', 'atr_band_mult'),
         [(20, 2.0), (14, 2.0), (22, 2.5)]),
        # Additional mean-reversion indicators
        (('zscore_period',), [(14,), (20,), (30,)]),
        (('lr_period', 'lr_std_mult'), [(20, 2.0), (14, 2.0), (30, 2.5)]),
        (('median_band_period',), [(14,), (20,), (30,)]),
        (('pct_rank_period',), [(14,), (20,), (30,)]),
        (('chandelier_period', 'chandelier_mult'),
         [(22, 3.0), (14, 2.5), (30, 3.5)]),
        (('std_band_period', 'std_band_mult'),
         [(14, 2.0), (20, 2.0), (14, 2.5)]),
        (('elder_period',), [(10,), (13,), (20,)]),
    ]

    # Strategy options; includes conservative choices (tight stops, higher
    # confidence) intended to keep drawdown low.
    label_opts = [(8, 0.0015), (10, 0.002), (10, 0.003), (15, 0.002), (20, 0.003)]
    sl_tp_opts = [
        (0.002, 0.004), (0.0025, 0.005), (0.003, 0.005), (0.003, 0.006),
        (0.004, 0.006), (0.004, 0.008), (0.005, 0.008), (0.005, 0.010),
    ]
    prob_opts = [0.42, 0.45, 0.48, 0.50]

    configs = []
    for _ in range(max_configs):
        ind = dict(ind_base)
        for names, candidates in indicator_space:
            ind.update(zip(names, draw(candidates)))

        forward_bars, label_threshold = draw(label_opts)
        sl_pct, tp_pct = draw(sl_tp_opts)
        prob_threshold = draw(prob_opts)

        configs.append({
            'ind': ind,
            'forward_bars': forward_bars,
            'label_threshold': label_threshold,
            'sl_pct': sl_pct,
            'tp_pct': tp_pct,
            'prob_threshold': prob_threshold,
            'min_hold': 180,
            'max_hold': 1800,
        })
    return configs
|
||||
|
||||
|
||||
def run_cycle_compare(periods: list, do_save_best: bool = True):
    """Backtest several K-line periods with one default parameter set and compare them.

    For every period the 2020 data is used for training and the 2021 data
    for the out-of-sample backtest (via ``run_single``). A comparison table
    is printed, the best period is selected, and optionally its model and
    parameters are saved.

    Args:
        periods: Periods to compare, e.g. ``[5, 15]`` for 5m and 15m, or
            ``[1, 3, 5, 15, 30, 60]`` for all of them. Each entry is passed
            through ``_normalize_period``.
        do_save_best: When true, save the best period's model and parameters
            with ``save_model_and_params``.

    Returns:
        None. Returns early (before the timing footer) when no period
        produced a valid backtest.
    """
    t0 = _time.time()
    period_labels = [_normalize_period(p) for p in periods]
    print("=" * 60, flush=True)
    print(" ETH 合约 — 多周期回测对比 (2020 训练 / 2021 回测)", flush=True)
    print(f" 参与周期: {', '.join(period_labels)}", flush=True)
    print("=" * 60, flush=True)

    ind = default_indicator_params()
    # Single source of truth for the strategy settings: the same values are
    # passed to run_single and written to the saved parameter file.
    strategy = {
        'forward_bars': 10,
        'label_threshold': 0.002,
        'prob_threshold': 0.45,
        'sl_pct': 0.004,
        'tp_pct': 0.006,
        'min_hold': 180,
        'max_hold': 1800,
    }

    results = []
    for i, period in enumerate(period_labels):
        print(f"\n [{i+1}/{len(period_labels)}] 周期 {period} ...", flush=True)
        try:
            df_2020 = load_klines(TRAIN_START, TRAIN_END_EXCLUSIVE, period=period)
            df_2021 = load_klines(TEST_START, TEST_END_EXCLUSIVE, period=period)
        except Exception as e:
            # Best effort: a period whose data cannot be loaded is reported
            # and skipped so the remaining periods still run.
            print(f" 跳过 {period}: {e}", flush=True)
            results.append({'period': period, 'ok': False, 'error': str(e)})
            continue
        print(f" 数据: 2020 {len(df_2020):,} 根 | 2021 {len(df_2021):,} 根", flush=True)

        trades, model, feature_cols = run_single(df_2020, df_2021, ind, **strategy)
        if not trades or model is None:
            print(f" {period}: 无有效交易", flush=True)
            results.append({'period': period, 'ok': False})
            continue

        r = analyze_trades(trades)
        results.append({
            'period': period,
            'ok': True,
            'n': r['n'],
            'net': r['net'],
            'monthly_net': r['monthly_net'],
            'wr': r['wr'],
            'dd': r['dd'],
            'dd_available': r['dd_available'],
            'trades': trades,
            'model': model,
            'feature_cols': feature_cols,
        })
        print(f" {period}: 交易 {r['n']} 笔 | 月均 {r['monthly_net']:+.2f} USDT | 可用回撤 {r['dd_available']:.2f} USDT | 胜率 {r['wr']:.1f}%", flush=True)

    # Comparison table
    ok_results = [x for x in results if x.get('ok')]
    print("\n" + "=" * 60, flush=True)
    print(" 多周期对比结果 (2021 年样本外)", flush=True)
    print("=" * 60, flush=True)
    if not ok_results:
        print(" 无有效回测结果。", flush=True)
        return

    print(f" {'周期':<6} {'交易数':>8} {'总净利':>12} {'月均净利':>12} {'可用回撤':>10} {'胜率':>8}", flush=True)
    print(" " + "-" * 62, flush=True)
    for x in ok_results:
        print(f" {x['period']:<6} {x['n']:>8} {x['net']:>+12.2f} {x['monthly_net']:>+12.2f} {x['dd_available']:>10.2f} {x['wr']:>7.1f}%", flush=True)

    def _cmp_cycle(a):
        # Ranking: drawdown target met, then monthly target met, then higher
        # monthly net, then smaller available-funds drawdown.
        dd_ok = a['dd_available'] <= MAX_DD_TARGET
        mo_ok = a['monthly_net'] >= TARGET_MONTHLY_NET
        return (dd_ok, mo_ok, a['monthly_net'], -a['dd_available'])

    best = max(ok_results, key=_cmp_cycle)
    print(" " + "-" * 62, flush=True)
    b_dd = best['dd_available']
    print(f" 最佳周期: {best['period']} (月均 {best['monthly_net']:+.2f} USDT, 可用回撤 {b_dd:.2f} USDT)", flush=True)
    if b_dd <= MAX_DD_TARGET and best['monthly_net'] >= TARGET_MONTHLY_NET:
        print(f" 双目标达标: 回撤≤{MAX_DD_TARGET:.0f}U 且 月盈≥{TARGET_MONTHLY_NET:.0f}U", flush=True)

    if do_save_best and best.get('model') is not None:
        flat_params = {
            **ind,
            'forward_bars': strategy['forward_bars'],
            'label_threshold': strategy['label_threshold'],
            'prob_threshold': strategy['prob_threshold'],
            'sl_pct': strategy['sl_pct'],
            'tp_pct': strategy['tp_pct'],
            'min_hold_seconds': strategy['min_hold'],
            'max_hold_seconds': strategy['max_hold'],
            'margin_per_trade': MARGIN_PER_TRADE,
            'leverage': LEVERAGE,
            'notional': NOTIONAL_PER_TRADE,
            'rebate_rate': REBATE_RATE,
            'train_period': f'{TRAIN_START} ~ {TRAIN_END_EXCLUSIVE}',
            'test_period': f'{TEST_START} ~ {TEST_END_EXCLUSIVE}',
            'kline_period': best['period'],
            'initial_balance': INITIAL_BALANCE,
            'max_dd_target': MAX_DD_TARGET,
        }
        save_model_and_params(best['model'], best.get('feature_cols', []), flat_params)
        print(f" 已保存最佳周期 {best['period']} 的模型与参数。", flush=True)

    print(f"\n 总耗时: {_time.time() - t0:.1f}s", flush=True)
    print("=" * 60, flush=True)
|
||||
|
||||
|
||||
def main(do_grid_search: bool = True, period: str = '1m'):
    """Train on 2020 data, backtest on 2021 data, then report and save the best run.

    Args:
        do_grid_search: When true, evaluate every configuration from
            ``_grid_indicator_and_strategy()`` and keep the best-scoring one;
            otherwise run once with the default parameters.
        period: K-line period such as '1m', '5m' or '15m'; normalised with
            ``_normalize_period``.

    The best configuration is chosen on the 2021 results: drawdown target
    met first, then monthly-net target met, then higher monthly net, then
    smaller available-funds drawdown. Returns None.
    """
    period = _normalize_period(period)
    t0 = _time.time()
    print("=" * 60, flush=True)
    print(f" ETH 合约 AI 策略 — 2020 训练 / 2021 真实回测 | K线周期 {period} | 单仓 100U×100倍 | 90% 返佣", flush=True)
    print(f" 本金: {INITIAL_BALANCE:.0f} USDT | 返佣次日 8 点到账,回撤按可用资金计算", flush=True)
    print(f" 目标: 可用资金最大回撤 ≤ {MAX_DD_TARGET:.0f} USDT,月均净利 ≥ {TARGET_MONTHLY_NET:.0f} USDT", flush=True)
    print("=" * 60, flush=True)

    print(f"\n[1/4] 加载 K 线 (周期 {period}, 训练 2020 / 回测 2021)...", flush=True)
    df_2020 = load_klines(TRAIN_START, TRAIN_END_EXCLUSIVE, period=period)
    df_2021 = load_klines(TEST_START, TEST_END_EXCLUSIVE, period=period)
    print(f" 2020 训练: {len(df_2020):,} 根 | 2021 回测: {len(df_2021):,} 根", flush=True)
    if len(df_2020) < 10000:
        print(" 警告: 2020 年数据不足 10000 根,请先运行 抓取多周期K线.py 拉取 2020 年数据。", flush=True)
    if len(df_2021) < 1000:
        print(" 警告: 2021 年数据不足,回测结果可能不可靠。", flush=True)

    def _score(r):
        # Tuples compare left to right: both targets met beats drawdown-only,
        # which beats monthly-only; ties fall back to monthly net, then to
        # the smaller drawdown.
        dd_ok = r['dd_available'] <= MAX_DD_TARGET
        monthly_ok = r['monthly_net'] >= TARGET_MONTHLY_NET
        return (dd_ok, monthly_ok, r['monthly_net'], -r['dd_available'])

    def _pack_params(ind_params, forward_bars, label_threshold, prob_threshold,
                     sl_pct, tp_pct, min_hold, max_hold):
        # One builder for the saved parameter set, shared by both branches
        # so the two code paths cannot drift apart.
        return {
            'indicator_params': ind_params,
            'forward_bars': forward_bars,
            'label_threshold': label_threshold,
            'prob_threshold': prob_threshold,
            'sl_pct': sl_pct,
            'tp_pct': tp_pct,
            'min_hold_seconds': min_hold,
            'max_hold_seconds': max_hold,
            'margin_per_trade': MARGIN_PER_TRADE,
            'leverage': LEVERAGE,
            'notional': NOTIONAL_PER_TRADE,
            'rebate_rate': REBATE_RATE,
            'train_period': f'{TRAIN_START} ~ {TRAIN_END_EXCLUSIVE}',
            'test_period': f'{TEST_START} ~ {TEST_END_EXCLUSIVE}',
            'kline_period': period,
            'initial_balance': INITIAL_BALANCE,
            'max_dd_target': MAX_DD_TARGET,
        }

    # Sentinel that any real score beats (real scores start with booleans).
    best_score = (-1, -1, -1e9, 1e9)
    best_result = None

    if do_grid_search:
        configs = list(_grid_indicator_and_strategy())
        print(f"\n[2/4] 参数搜索 (共 {len(configs)} 组): 目标 回撤≤{MAX_DD_TARGET:.0f}U 且 月盈≥{TARGET_MONTHLY_NET:.0f}U ...", flush=True)
        # The ETA is measured from the start of the search, so the time
        # spent loading K-lines above does not inflate the per-config average.
        search_t0 = _time.time()
        for i, cfg in enumerate(configs):
            _t_start = _time.time()
            trades, model, feature_cols = run_single(
                df_2020,
                df_2021,
                cfg['ind'],
                cfg['forward_bars'],
                cfg['label_threshold'],
                cfg['prob_threshold'],
                cfg['sl_pct'],
                cfg['tp_pct'],
                cfg['min_hold'],
                cfg['max_hold'],
            )
            _elapsed = _time.time() - _t_start
            # Progress is reported for the first config and every tenth one.
            report_progress = (i + 1) % 10 == 0 or i == 0
            if not trades or model is None:
                if report_progress:
                    print(f" 组 {i+1}/{len(configs)} 完成 (本组 {_elapsed:.0f}s,无有效交易)", flush=True)
                continue

            r = analyze_trades(trades)
            score = _score(r)
            if score > best_score:
                best_score = score
                best_result = (trades, model, feature_cols, _pack_params(
                    cfg['ind'],
                    cfg['forward_bars'],
                    cfg['label_threshold'],
                    cfg['prob_threshold'],
                    cfg['sl_pct'],
                    cfg['tp_pct'],
                    cfg['min_hold'],
                    cfg['max_hold'],
                ))

            if report_progress:
                _avg_per = (_time.time() - search_t0) / (i + 1)
                _left = _avg_per * (len(configs) - i - 1)
                print(f" 组 {i+1}/{len(configs)} 完成 | 本组 {_elapsed:.0f}s | 回撤 {r['dd_available']:.0f} 月均 {r['monthly_net']:+.0f} | 预计剩余 ~{_left/60:.0f}min", flush=True)
    else:
        print("\n[2/4] 使用默认参数: 2020 训练 -> 2021 回测...", flush=True)
        ind = default_indicator_params()
        defaults = {
            'forward_bars': 10,
            'label_threshold': 0.002,
            'prob_threshold': 0.45,
            'sl_pct': 0.004,
            'tp_pct': 0.006,
            'min_hold': 180,
            'max_hold': 1800,
        }
        trades, model, feature_cols = run_single(df_2020, df_2021, ind, **defaults)
        if trades and model is not None:
            best_result = (trades, model, feature_cols, _pack_params(ind, **defaults))

    if best_result is None:
        print(" 未得到有效模型/交易(可能 2020 或 2021 数据不足)。", flush=True)
        return

    trades, model, feature_cols, full_params = best_result
    print("\n[3/4] 2021 年真实回测结果 (最佳参数)...", flush=True)
    print_report(trades, "2021 年样本外回测")

    # Flatten: indicator parameters first, then the remaining strategy keys.
    flat_params = {
        **full_params.get('indicator_params', {}),
        **{k: v for k, v in full_params.items() if k != 'indicator_params'},
    }
    save_model_and_params(model, feature_cols, flat_params)

    r = analyze_trades(trades)
    elapsed = _time.time() - t0
    dd_ok = r['dd_available'] <= MAX_DD_TARGET
    monthly_ok = r['monthly_net'] >= TARGET_MONTHLY_NET
    print(f"\n[4/4] 总耗时: {elapsed:.1f}s | 可用资金回撤: {r['dd_available']:.2f} USDT (目标≤{MAX_DD_TARGET:.0f}): {'达标' if dd_ok else '未达标'} | 月均净利: {r['monthly_net']:+.2f} USDT (目标≥{TARGET_MONTHLY_NET:.0f}): {'达标' if monthly_ok else '未达标'}", flush=True)
    print("=" * 60, flush=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    import sys

    argv = sys.argv
    compare_requested = '--compare' in argv or '-c' in argv
    if compare_requested:
        # Multi-period comparison: 5m and 15m by default; --all compares
        # 1m, 3m, 5m, 15m, 30m and 1h.
        periods = [1, 3, 5, 15, 30, 60] if '--all' in argv else [5, 15]
        run_cycle_compare(periods, do_save_best=True)
    else:
        # Single period: "--period 5m" or "-p 15" picks the period; the
        # first flag that is followed by a value wins, otherwise 1m is used.
        period = next(
            (
                argv[pos + 1]
                for pos, arg in enumerate(argv)
                if arg in ('--period', '-p') and pos + 1 < len(argv)
            ),
            '1m',
        )
        do_search = '--no-search' not in argv
        main(do_grid_search=do_search, period=period)
|
||||
Reference in New Issue
Block a user