2026-02-23 04:09:34 +08:00
|
|
|
|
"""
|
|
|
|
|
|
数据加载模块 - 从 SQLite 加载多周期K线数据为 DataFrame
|
|
|
|
|
|
"""
|
2026-02-26 19:05:17 +08:00
|
|
|
|
import numpy as np
|
2026-02-23 04:09:34 +08:00
|
|
|
|
import pandas as pd
|
|
|
|
|
|
from peewee import SqliteDatabase
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
# SQLite database file: <parent of this file's directory>/models/database.db
DB_PATH = Path(__file__).parent.parent.joinpath('models', 'database.db')

# Supported K-line periods; dict insertion order below follows this tuple.
_KLINE_PERIODS = ('1s', '1m', '3m', '5m', '15m', '30m', '1h')

# Period -> table name, one map per exchange (bitmart / binance).
PERIOD_MAP = {period: f'bitmart_eth_{period}' for period in _KLINE_PERIODS}

BINANCE_PERIOD_MAP = {period: f'binance_eth_{period}' for period in _KLINE_PERIODS}
|
|
|
|
|
|
|
2026-02-23 04:09:34 +08:00
|
|
|
|
|
2026-02-26 19:05:17 +08:00
|
|
|
|
def _date_to_epoch_ms(date_str: str, tz: str | None) -> int:
    """Convert a date string to epoch milliseconds.

    With ``tz`` the date is localized to that zone first. Without it the
    timestamp stays naive, and pandas' ``Timestamp.timestamp()`` interprets a
    naive timestamp as UTC.
    """
    # A falsy tz (None or '') keeps the "no time zone" path, as before.
    return int(pd.Timestamp(date_str, tz=tz or None).timestamp() * 1000)


def load_klines(period: str, start_date: str, end_date: str, tz: str | None = None,
                source: str = "bitmart") -> pd.DataFrame:
    """
    Load K-line rows for one period within a date range.

    :param period: '1s','1m','3m','5m','15m','30m','1h'
    :param start_date: 'YYYY-MM-DD' (inclusive)
    :param end_date: 'YYYY-MM-DD' (exclusive)
    :param tz: time zone used to interpret the dates, e.g. 'Asia/Shanghai' for
        Beijing time. None means the dates are interpreted as UTC (pandas treats
        naive timestamps as UTC when converting to epoch time).
    :param source: "binance" selects the binance tables; any other value falls
        back to the bitmart tables.
    :return: DataFrame indexed by a naive ``datetime`` index (UTC, built from the
        millisecond ``id`` column) with float columns open, high, low, close.
        Empty if no rows fall in the range.
    :raises ValueError: if ``period`` is not supported by the selected source.
    """
    period_map = BINANCE_PERIOD_MAP if source == "binance" else PERIOD_MAP
    table = period_map.get(period)
    if not table:
        # List the periods of the map actually used for this source.
        raise ValueError(f"不支持的周期: {period}, 可选: {list(period_map.keys())}")

    start_ts = _date_to_epoch_ms(start_date, tz)
    end_ts = _date_to_epoch_ms(end_date, tz)

    db = SqliteDatabase(str(DB_PATH))
    db.connect()
    try:
        # `table` comes from the whitelist maps above, so interpolating it is
        # safe; the range bounds are passed as bound parameters.
        cursor = db.execute_sql(
            f'SELECT id, open, high, low, close FROM [{table}] '
            f'WHERE id >= ? AND id < ? ORDER BY id',
            (start_ts, end_ts),
        )
        rows = cursor.fetchall()
    finally:
        # Release the connection even when the query fails (e.g. missing table).
        db.close()

    df = pd.DataFrame(rows, columns=['timestamp_ms', 'open', 'high', 'low', 'close'])
    df['datetime'] = pd.to_datetime(df['timestamp_ms'], unit='ms')
    df.set_index('datetime', inplace=True)
    df.drop(columns=['timestamp_ms'], inplace=True)
    return df.astype(float)
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-02-26 19:05:17 +08:00
|
|
|
|
def load_multi_period(periods: list, start_date: str, end_date: str, source: str = "bitmart",
                      tz: str | None = None) -> dict:
    """
    Load several K-line periods over the same date range.

    Each period is loaded via ``load_klines``, and a row-count line is printed.

    :param periods: period keys accepted by ``load_klines``, e.g. ['5m', '1h']
    :param start_date: 'YYYY-MM-DD' (inclusive)
    :param end_date: 'YYYY-MM-DD' (exclusive)
    :param source: data source forwarded to ``load_klines``
    :param tz: time zone used to interpret the dates, forwarded to
        ``load_klines``; the default None keeps the previous behaviour
    :return: {period: DataFrame}
    """
    result = {}
    for p in periods:
        df = load_klines(p, start_date, end_date, tz=tz, source=source)
        result[p] = df
        print(f" 加载 {p}: {len(df)} 条 ({start_date} ~ {end_date})")
    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-02-26 19:05:17 +08:00
|
|
|
|
def get_1m_touch_direction(df_5m: pd.DataFrame, df_1m: pd.DataFrame,
                           arr_mid: np.ndarray, kline_step_min: int = 5) -> np.ndarray:
    """
    Decide, from 1-minute bars, how each main-period bar first reached its mid line.

    Each 1m bar is assigned to the main-period bar whose label equals its index
    floored to ``kline_step_min`` minutes. Within that group, the first 1m bar
    (in row order) whose [low, high] range contains the mid value decides the
    result. The main period is ``kline_step_min`` minutes (5m by default).

    :param df_5m: main-period bars; only its index is used. Assumed unique and
        aligned with the floored 1m timestamps — TODO confirm with callers.
    :param df_1m: 1-minute bars with 'open', 'high' and 'low' columns and a
        DatetimeIndex. It is not modified.
    :param arr_mid: mid-line value per row of ``df_5m`` (same length).
    :param kline_step_min: main-period length in minutes, used for flooring.
    :return: int32 array aligned with ``df_5m``. The values mean:
        1 = first touch came from below (open < mid; can open long);
        -1 = first touch came from above or at the level (open >= mid; can open short);
        0 = mid is NaN, no 1m bars fall in the bucket, or the mid was never touched.
        Price gaps between consecutive 1m bars are not treated as touches.
    """
    mid_sr = pd.Series(arr_mid, index=df_5m.index)
    buckets = df_1m.index.floor(f"{kline_step_min}min")

    # Mid value of the parent bar for every 1m row (NaN when there is no parent
    # bar or its mid is NaN). NaN comparisons below are False, so such rows
    # never count as touches.
    mid_1m = mid_sr.reindex(buckets).to_numpy(dtype=float)
    opens = df_1m["open"].to_numpy(dtype=float)
    highs = df_1m["high"].to_numpy(dtype=float)
    lows = df_1m["low"].to_numpy(dtype=float)

    touched = (lows <= mid_1m) & (mid_1m <= highs)
    direction = np.where(opens < mid_1m, 1, -1)

    # Within each bucket keep the direction of the first touching 1m bar
    # (groupby preserves row order inside each group).
    first_touch = (
        pd.Series(direction[touched], index=buckets[touched])
        .groupby(level=0, sort=False)
        .first()
    )

    # Align to the main-period index; buckets without a touch get 0.
    return first_touch.reindex(df_5m.index, fill_value=0).to_numpy(dtype=np.int32)
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-02-23 04:09:34 +08:00
|
|
|
|
if __name__ == '__main__':
    # Smoke test: load a few periods and print their shape and time span.
    data = load_multi_period(['5m', '15m', '1h'], '2020-01-01', '2024-01-01')
    for k, v in data.items():
        if v.empty:
            # v.index[0] would raise IndexError on an empty frame.
            print(f"{k}: {v.shape}, 无数据")
            continue
        print(f"{k}: {v.shape}, {v.index[0]} ~ {v.index[-1]}")
|