Files
codex_jxs_code/strategy/data_loader.py
2026-02-26 19:05:17 +08:00

126 lines
4.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
数据加载模块 - 从 SQLite 加载多周期K线数据为 DataFrame
"""
import numpy as np
import pandas as pd
from peewee import SqliteDatabase
from pathlib import Path
# SQLite database file: <parent of this module's directory>/models/database.db
DB_PATH = Path(__file__).parent.parent.joinpath('models', 'database.db')

# Periods that have a table-name mapping, in ascending bar length.
_PERIODS = ('1s', '1m', '3m', '5m', '15m', '30m', '1h')

# period -> table name, one mapping per data source (bitmart / binance)
PERIOD_MAP = {p: f'bitmart_eth_{p}' for p in _PERIODS}
BINANCE_PERIOD_MAP = {p: f'binance_eth_{p}' for p in _PERIODS}
def load_klines(period: str, start_date: str, end_date: str, tz: str | None = None,
                source: str = "bitmart") -> pd.DataFrame:
    """
    Load K-line (OHLC) rows of one period for a half-open date range.

    :param period: '1s','1m','3m','5m','15m','30m','1h'
    :param start_date: 'YYYY-MM-DD', inclusive lower bound
    :param end_date: 'YYYY-MM-DD', exclusive upper bound (that day is not included)
    :param tz: time zone used to interpret the two dates, e.g. 'Asia/Shanghai'.
        When omitted, the dates are parsed as naive timestamps.
        NOTE(review): pandas appears to convert a naive Timestamp to epoch
        seconds as if it were UTC, not the machine's local zone - confirm
        this matches what callers expect.
    :param source: "binance" selects BINANCE_PERIOD_MAP; any other value
        falls back to PERIOD_MAP (bitmart)
    :return: DataFrame indexed by ``datetime`` (built from the millisecond
        ``id`` column) with float columns open, high, low, close
    :raises ValueError: if ``period`` has no table for the selected source
    """
    period_map = BINANCE_PERIOD_MAP if source == "binance" else PERIOD_MAP
    table = period_map.get(period)
    if table is None:
        # List the choices of the map that was actually consulted.
        raise ValueError(f"不支持的周期: {period}, 可选: {list(period_map)}")

    # An empty tz string is treated like None, matching the original truthiness check.
    zone = tz or None
    start_ts = int(pd.Timestamp(start_date, tz=zone).timestamp() * 1000)
    end_ts = int(pd.Timestamp(end_date, tz=zone).timestamp() * 1000)

    db = SqliteDatabase(str(DB_PATH))
    db.connect()
    try:
        # The table name comes from the fixed maps above, never from raw user
        # input; the range bounds are bound as SQL parameters.
        cursor = db.execute_sql(
            f'SELECT id, open, high, low, close FROM [{table}] '
            f'WHERE id >= ? AND id < ? ORDER BY id',
            (start_ts, end_ts)
        )
        rows = cursor.fetchall()
    finally:
        # Release the connection even if the query or fetch fails.
        db.close()

    df = pd.DataFrame(rows, columns=['timestamp_ms', 'open', 'high', 'low', 'close'])
    df['datetime'] = pd.to_datetime(df['timestamp_ms'], unit='ms')
    df = df.set_index('datetime').drop(columns=['timestamp_ms'])
    return df.astype(float)
def load_multi_period(periods: list, start_date: str, end_date: str, source: str = "bitmart",
                      tz: str | None = None) -> dict:
    """
    Load K-line data for several periods over the same date range.

    :param periods: period keys accepted by ``load_klines``, e.g. ['5m', '15m']
    :param start_date: 'YYYY-MM-DD', forwarded to ``load_klines``
    :param end_date: 'YYYY-MM-DD', forwarded to ``load_klines``
    :param source: data source name, forwarded to ``load_klines``
    :param tz: optional time zone for interpreting the dates, forwarded to
        ``load_klines``; the default None keeps the previous behaviour
    :return: {period: DataFrame}
    """
    result = {}
    for p in periods:
        frame = load_klines(p, start_date, end_date, tz=tz, source=source)
        result[p] = frame
        # Progress line: row count per period.
        print(f" 加载 {p}: {len(frame)} 条 ({start_date} ~ {end_date})")
    return result
def get_1m_touch_direction(df_5m: pd.DataFrame, df_1m: pd.DataFrame,
                           arr_mid: np.ndarray, kline_step_min: int = 5) -> np.ndarray:
    """
    Decide, for every main-period bar, whether price first touched the mid
    line while rising or while falling, using the 1-minute bars inside it.

    The 1-minute rows are bucketed by flooring their index to
    ``kline_step_min`` minutes. Inside a bucket, the first row whose
    [low, high] range contains the bar's mid value decides the result:
    1 when that row opened below mid, otherwise -1.

    :param df_5m: main-period frame; only its index is used
    :param df_1m: 1-minute frame with 'open', 'high' and 'low' columns
    :param arr_mid: mid-line values, positionally aligned with ``df_5m``
    :param kline_step_min: bucket width in minutes
    :return: int32 array with one entry per ``df_5m`` row:
        1 = touched after rising (long allowed), -1 = touched after falling
        (short allowed), 0 = never touched, mid is NaN, or no 1-minute rows
    """
    minute_bars = df_1m.copy()
    minute_bars["_bucket"] = minute_bars.index.floor(f"{kline_step_min}min")

    # Look up each bar's mid value by its timestamp.
    mid_by_bar = pd.Series(arr_mid, index=df_5m.index)

    direction_of: dict[pd.Timestamp, int] = {}
    for bar_start, rows in minute_bars.groupby("_bucket", sort=True):
        level = mid_by_bar.get(bar_start, np.nan)
        if pd.isna(level):
            direction_of[bar_start] = 0
            continue

        opens = rows["open"].to_numpy(dtype=float)
        highs = rows["high"].to_numpy(dtype=float)
        lows = rows["low"].to_numpy(dtype=float)

        # Positions of the 1-minute rows whose range straddles the mid value.
        touched = np.flatnonzero((lows <= level) & (level <= highs))
        if touched.size == 0:
            direction_of[bar_start] = 0
        else:
            first = touched[0]
            direction_of[bar_start] = 1 if opens[first] < level else -1

    # Align to the main-period index; bars without an entry default to 0.
    return np.fromiter(
        (direction_of.get(ts, 0) for ts in df_5m.index),
        dtype=np.int32,
        count=len(df_5m),
    )
if __name__ == '__main__':
    # Manual smoke run: load a few periods and print each frame's shape and time span.
    data = load_multi_period(['5m', '15m', '1h'], '2020-01-01', '2024-01-01')
    for k, v in data.items():
        if v.empty:
            # index[0] / index[-1] would raise IndexError on an empty frame.
            print(f"{k}: {v.shape}, (no data)")
            continue
        print(f"{k}: {v.shape}, {v.index[0]} ~ {v.index[-1]}")