Files
codex_jxs_code/strategy/data_loader.py
2026-02-26 16:34:30 +08:00

75 lines
2.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
数据加载模块 - 从 SQLite 加载多周期K线数据为 DataFrame
"""
import pandas as pd
from peewee import SqliteDatabase
from pathlib import Path
# SQLite database file: models/database.db under the parent of this file's directory.
DB_PATH = Path(__file__).parent.parent.joinpath('models', 'database.db')

# Period code -> name of the table holding candles for that period
# (table names suggest BitMart ETH data). Insertion order is preserved.
PERIOD_MAP = {
    period: f'bitmart_eth_{period}'
    for period in ('1m', '3m', '5m', '15m', '30m', '1h')
}
def load_klines(period: str, start_date: str, end_date: str, tz: str | None = None) -> pd.DataFrame:
    """
    Load candlestick (K-line) rows for one period within a half-open date range.

    :param period: one of '1m', '3m', '5m', '15m', '30m', '1h' (keys of PERIOD_MAP)
    :param start_date: inclusive start, 'YYYY-MM-DD'
    :param end_date: exclusive end, 'YYYY-MM-DD' (rows on this day are not included)
    :param tz: timezone used to interpret start_date/end_date, e.g. 'Asia/Shanghai'
        for Beijing time. When None (or empty) the dates are interpreted as UTC,
        because pandas' Timestamp.timestamp() treats naive timestamps as UTC.
    :return: DataFrame indexed by a naive UTC DatetimeIndex named 'datetime'
        (independent of ``tz``), with float columns open, high, low, close,
        sorted by time.
    :raises ValueError: if ``period`` is not a key of PERIOD_MAP.
    """
    table = PERIOD_MAP.get(period)
    if not table:
        raise ValueError(f"不支持的周期: {period}, 可选: {list(PERIOD_MAP.keys())}")

    # Convert the date bounds to epoch milliseconds, the unit the `id` column is
    # compared against below. tz=None is equivalent to a naive Timestamp.
    start_ts, end_ts = (
        int(pd.Timestamp(day, tz=tz or None).timestamp() * 1000)
        for day in (start_date, end_date)
    )

    db = SqliteDatabase(str(DB_PATH))
    db.connect()
    try:
        # The table name is interpolated, but it only ever comes from the
        # PERIOD_MAP whitelist above; the range bounds are bound parameters.
        cursor = db.execute_sql(
            f'SELECT id, open, high, low, close FROM [{table}] '
            'WHERE id >= ? AND id < ? ORDER BY id',
            (start_ts, end_ts),
        )
        rows = cursor.fetchall()
    finally:
        # Release the connection even if the query fails.
        db.close()

    df = pd.DataFrame(rows, columns=['timestamp_ms', 'open', 'high', 'low', 'close'])
    # Epoch-ms -> naive datetime (UTC wall clock); becomes the index.
    df['datetime'] = pd.to_datetime(df['timestamp_ms'], unit='ms')
    df = df.set_index('datetime').drop(columns=['timestamp_ms'])
    return df.astype(float)
def load_multi_period(
    periods: list[str],
    start_date: str,
    end_date: str,
    tz: str | None = None,
) -> dict[str, pd.DataFrame]:
    """
    Load several periods over the same date range by calling load_klines for each.

    :param periods: period codes accepted by load_klines, e.g. ['5m', '15m', '1h']
    :param start_date: inclusive start, 'YYYY-MM-DD'
    :param end_date: exclusive end, 'YYYY-MM-DD'
    :param tz: forwarded to load_klines to interpret the dates; None means UTC
        (the previous behavior of this function).
    :return: {period: DataFrame}, in the order of ``periods``
    :raises ValueError: propagated from load_klines for an unsupported period.
    """
    result = {}
    for p in periods:
        result[p] = load_klines(p, start_date, end_date, tz=tz)
        # Progress line: row count per period for the requested range.
        print(f" 加载 {p}: {len(result[p])} 条 ({start_date} ~ {end_date})")
    return result
if __name__ == '__main__':
    # Manual smoke check: load three periods and print each frame's shape and time span.
    data = load_multi_period(['5m', '15m', '1h'], '2020-01-01', '2024-01-01')
    for k, v in data.items():
        if v.empty:
            # v.index[0] would raise IndexError when the range has no rows.
            print(f"{k}: {v.shape}, (no data)")
            continue
        print(f"{k}: {v.shape}, {v.index[0]} ~ {v.index[-1]}")