Files
jyx_code4/strategy/data_loader.py

120 lines
3.6 KiB
Python
Raw Normal View History

2026-02-20 20:57:25 +08:00
"""
数据加载器 从SQLite加载K线数据为pandas DataFrame
"""
import sqlite3
import pandas as pd
from loguru import logger
from .config import DB_PATH, KLINE_PERIODS
def load_kline(period: int = 15, start_date: str = None, end_date: str = None) -> pd.DataFrame:
    """
    Load the candlestick (K-line) rows of one period from SQLite as a DataFrame.

    :param period: candle period in minutes; must be a key of KLINE_PERIODS
    :param start_date: optional inclusive lower bound, e.g. 'YYYY-MM-DD'
    :param end_date: optional inclusive upper bound, e.g. 'YYYY-MM-DD'; it is
        compared as a timestamp, so a bare date means 00:00 of that day
    :return: DataFrame with columns timestamp, open, high, low, close. When the
        table holds rows the frame is indexed by ``datetime``; it may be empty
        if the table is empty or no row falls inside the requested range
    :raises ValueError: if ``period`` is not a key of KLINE_PERIODS
    """
    suffix = KLINE_PERIODS.get(period)
    if suffix is None:
        raise ValueError(f"不支持的周期: {period},可选: {list(KLINE_PERIODS.keys())}")

    table_name = f'bitmart_eth_{suffix}'
    query = f"SELECT id as timestamp, open, high, low, close FROM {table_name} ORDER BY id"

    conn = sqlite3.connect(str(DB_PATH))
    try:
        df = pd.read_sql_query(query, conn)
    finally:
        # Close even when the query fails (e.g. missing table) so the
        # connection is not leaked.
        conn.close()

    if df.empty:
        logger.warning(f"[{suffix}] 表中无数据")
        return df

    # The id column is read as a millisecond timestamp; use it as a datetime index.
    df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms')
    df.set_index('datetime', inplace=True)

    # Optional date-range filter (both bounds inclusive).
    if start_date:
        df = df[df.index >= start_date]
    if end_date:
        df = df[df.index <= end_date]

    # The filter can remove every row; df.index[0] below would then raise
    # IndexError, so return the empty frame explicitly.
    if df.empty:
        logger.warning(f"[{suffix}] 指定日期范围内无数据")
        return df

    logger.info(f"[{suffix}] 加载 {len(df)} 条K线 | {df.index[0]} ~ {df.index[-1]}")
    return df
def load_multi_period(periods: list = None, start_date: str = None, end_date: str = None) -> dict:
    """
    Load K-line data for several periods in one call.

    Each period is loaded through ``load_kline``. A period whose load raises is
    logged and skipped, and a period that yields an empty frame is omitted.

    :param periods: list of periods in minutes, e.g. [5, 15, 60]; when None,
        every key of KLINE_PERIODS is loaded
    :param start_date: start date forwarded to ``load_kline``
    :param end_date: end date forwarded to ``load_kline``
    :return: dict mapping period -> DataFrame for the periods that produced data
    """
    wanted = list(KLINE_PERIODS.keys()) if periods is None else periods

    frames = {}
    for minutes in wanted:
        try:
            frame = load_kline(minutes, start_date, end_date)
        except Exception as exc:
            # Best effort: one failing period must not abort the others.
            logger.error(f"加载 {minutes}分钟 K线失败: {exc}")
            continue
        if frame.empty:
            continue
        frames[minutes] = frame
    return frames
def load_trades(start_date: str = None, end_date: str = None, limit: int = None) -> pd.DataFrame:
    """
    Load raw trade records from the ``bitmart_eth_trades`` table.

    :param start_date: optional inclusive lower bound, e.g. 'YYYY-MM-DD'
    :param end_date: optional inclusive upper bound, e.g. 'YYYY-MM-DD'; it is
        compared as a timestamp, so a bare date means 00:00 of that day
    :param limit: keep only the first ``limit`` rows after date filtering;
        a falsy value (None or 0) means no limit
    :return: DataFrame with columns id, timestamp, price, volume, side, ordered
        by timestamp. When the table holds rows the frame is indexed by
        ``datetime``; it may be empty
    """
    query = "SELECT id, timestamp, price, volume, side FROM bitmart_eth_trades ORDER BY timestamp"

    conn = sqlite3.connect(str(DB_PATH))
    try:
        df = pd.read_sql_query(query, conn)
    finally:
        # Close even when the query fails so the connection is not leaked.
        conn.close()

    if df.empty:
        logger.warning("成交记录表中无数据")
        return df

    # The timestamp column is read as milliseconds; use it as a datetime index.
    df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms')
    df.set_index('datetime', inplace=True)

    # Optional date-range filter (both bounds inclusive).
    if start_date:
        df = df[df.index >= start_date]
    if end_date:
        df = df[df.index <= end_date]

    # The limit is applied after filtering, so it counts rows inside the range.
    if limit:
        df = df.head(limit)

    logger.info(f"加载 {len(df)} 条成交记录")
    return df
def get_available_tables() -> list:
    """
    List every table in the SQLite database at DB_PATH.

    :return: table names read from ``sqlite_master``, sorted by name
    """
    conn = sqlite3.connect(str(DB_PATH))
    try:
        rows = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
        ).fetchall()
    finally:
        # Close even when the query fails so the connection is not leaked.
        conn.close()
    return [name for (name,) in rows]
def get_table_stats() -> dict:
    """
    Count the rows of every table in the database.

    :return: dict mapping table name -> row count; a table whose count query
        fails is reported as 0 instead of aborting the whole scan
    """
    # Resolve the table list before opening our own connection, so a failure
    # here cannot leave that connection open.
    tables = get_available_tables()

    stats = {}
    conn = sqlite3.connect(str(DB_PATH))
    try:
        for table in tables:
            # Quote the identifier (doubling embedded quotes) so names with
            # special characters or SQL keywords are counted, not reported as 0.
            quoted = '"' + table.replace('"', '""') + '"'
            try:
                count = pd.read_sql_query(f"SELECT COUNT(*) as cnt FROM {quoted}", conn).iloc[0]['cnt']
                stats[table] = count
            except Exception:
                # Deliberate best effort: an unreadable table counts as 0.
                stats[table] = 0
    finally:
        conn.close()
    return stats