Files
jyx_code4/strategy/data_loader.py
ddrwode 21f2adc4a4 哈哈
2026-02-20 20:57:25 +08:00

120 lines
3.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
数据加载器 — 从SQLite加载K线数据为pandas DataFrame
"""
import sqlite3
import pandas as pd
from loguru import logger
from .config import DB_PATH, KLINE_PERIODS
def load_kline(period: int = 15, start_date: str = None, end_date: str = None) -> pd.DataFrame:
    """
    Load the K-line (candlestick) rows of one period from SQLite into a DataFrame.

    :param period: K-line period in minutes; must be a key of KLINE_PERIODS
    :param start_date: optional inclusive lower bound, e.g. 'YYYY-MM-DD'
    :param end_date: optional upper bound, e.g. 'YYYY-MM-DD'
    :return: DataFrame with columns timestamp, open, high, low, close. When rows
        exist it is indexed by ``datetime`` (derived from the millisecond
        timestamp). An empty DataFrame is returned when the table has no rows
        or when no row falls inside the requested date range.
    :raises ValueError: if ``period`` is not a key of KLINE_PERIODS

    NOTE(review): a bare 'YYYY-MM-DD' ``end_date`` is presumably compared by
    pandas as midnight of that day, so later rows on that day are excluded --
    confirm this is the intended semantics.
    """
    suffix = KLINE_PERIODS.get(period)
    if suffix is None:
        raise ValueError(f"不支持的周期: {period},可选: {list(KLINE_PERIODS.keys())}")

    # The table name comes from the KLINE_PERIODS mapping, not directly from
    # caller input, so interpolating it into the SQL text is acceptable here.
    table_name = f'bitmart_eth_{suffix}'
    query = f"SELECT id as timestamp, open, high, low, close FROM {table_name} ORDER BY id"

    conn = sqlite3.connect(str(DB_PATH))
    try:
        df = pd.read_sql_query(query, conn)
    finally:
        # Close even when the query fails (e.g. missing table) so the
        # connection handle is never leaked.
        conn.close()

    if df.empty:
        logger.warning(f"[{suffix}] 表中无数据")
        return df

    # ``id`` is a millisecond timestamp; convert it into a datetime index.
    df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms')
    df.set_index('datetime', inplace=True)

    # Filter by date range.
    if start_date:
        df = df[df.index >= start_date]
    if end_date:
        df = df[df.index <= end_date]

    # The filter may have removed every row; indexing df.index[0] below would
    # then raise IndexError, so report the situation and return early.
    if df.empty:
        logger.warning(f"[{suffix}] 指定日期范围内无数据")
        return df

    logger.info(f"[{suffix}] 加载 {len(df)} 条K线 | {df.index[0]} ~ {df.index[-1]}")
    return df
def load_multi_period(periods: list = None, start_date: str = None, end_date: str = None) -> dict:
    """
    Load K-line data for several periods in one call.

    :param periods: periods to load, e.g. [5, 15, 60]; ``None`` means every
        key of KLINE_PERIODS
    :param start_date: start date, forwarded to ``load_kline``
    :param end_date: end date, forwarded to ``load_kline``
    :return: ``{period: DataFrame}`` dict; periods whose load failed or
        produced an empty DataFrame are left out
    """
    wanted = list(KLINE_PERIODS.keys()) if periods is None else periods

    frames = {}
    for minutes in wanted:
        try:
            frame = load_kline(minutes, start_date, end_date)
            if frame.empty:
                # Nothing to keep for this period.
                continue
            frames[minutes] = frame
        except Exception as exc:
            # Best effort: one failing period must not abort the others.
            logger.error(f"加载 {minutes}分钟 K线失败: {exc}")
    return frames
def load_trades(start_date: str = None, end_date: str = None, limit: int = None) -> pd.DataFrame:
    """
    Load raw trade records from the ``bitmart_eth_trades`` table.

    :param start_date: optional inclusive lower bound, e.g. 'YYYY-MM-DD'
    :param end_date: optional upper bound, e.g. 'YYYY-MM-DD'
    :param limit: keep only the first ``limit`` rows after date filtering;
        a falsy value (``None`` or 0) means no limit
    :return: DataFrame with columns id, timestamp, price, volume, side. When
        rows exist it is indexed by ``datetime`` (derived from ``timestamp``
        interpreted as milliseconds); empty when the table has no rows.
    """
    query = "SELECT id, timestamp, price, volume, side FROM bitmart_eth_trades ORDER BY timestamp"

    conn = sqlite3.connect(str(DB_PATH))
    try:
        df = pd.read_sql_query(query, conn)
    finally:
        # Close even when the query fails so the connection is never leaked.
        conn.close()

    if df.empty:
        logger.warning("成交记录表中无数据")
        return df

    df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms')
    df.set_index('datetime', inplace=True)

    if start_date:
        df = df[df.index >= start_date]
    if end_date:
        df = df[df.index <= end_date]
    if limit:
        # Applied after the date filter, so it caps the filtered result.
        df = df.head(limit)

    logger.info(f"加载 {len(df)} 条成交记录")
    return df
def get_available_tables() -> list:
    """
    List the names of all tables in the database, sorted by name.

    :return: list of table names read from ``sqlite_master``
    """
    conn = sqlite3.connect(str(DB_PATH))
    try:
        cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
        return [row[0] for row in cursor]
    finally:
        # Runs on both the success and the error path, so the connection
        # is always released.
        conn.close()
def get_table_stats() -> dict:
    """
    Collect the row count of every table in the database.

    :return: ``{table_name: row_count}`` dict; a table whose count query
        fails is reported as 0 and the failure is logged
    """
    # Resolve the table list first so that a failure there cannot leak the
    # connection opened below.
    tables = get_available_tables()

    stats = {}
    conn = sqlite3.connect(str(DB_PATH))
    try:
        for table in tables:
            # Quote the identifier (doubling embedded quotes) so that names
            # containing spaces, dashes or SQL keywords are counted instead
            # of failing and being silently reported as 0.
            quoted = '"' + table.replace('"', '""') + '"'
            try:
                row = conn.execute(f"SELECT COUNT(*) as cnt FROM {quoted}").fetchone()
                stats[table] = int(row[0])
            except sqlite3.Error as e:
                # Best effort: keep going, but leave a trace of the failure.
                logger.warning(f"统计表 {table} 行数失败: {e}")
                stats[table] = 0
    finally:
        conn.close()
    return stats