120 lines
3.6 KiB
Python
120 lines
3.6 KiB
Python
"""
|
||
数据加载器 — 从SQLite加载K线数据为pandas DataFrame
|
||
"""
|
||
import sqlite3
|
||
import pandas as pd
|
||
from loguru import logger
|
||
from .config import DB_PATH, KLINE_PERIODS
|
||
|
||
|
||
def load_kline(period: int = 15, start_date: str = None, end_date: str = None) -> pd.DataFrame:
    """
    Load the K-line (candlestick) rows for one period into a DataFrame.

    :param period: candle period in minutes; must be a key of ``KLINE_PERIODS``
    :param start_date: optional lower bound 'YYYY-MM-DD'; rows whose datetime
        index is ``>= start_date`` are kept
    :param end_date: optional upper bound 'YYYY-MM-DD'; rows whose datetime
        index is ``<= end_date`` are kept.
        NOTE(review): a bare date is presumably parsed as midnight by pandas,
        so later rows on ``end_date`` itself would be dropped — confirm this
        is the intended behaviour.
    :return: DataFrame with columns timestamp, open, high, low, close, indexed
        by ``datetime``. If the table itself is empty, the raw empty frame is
        returned without the datetime index.
    :raises ValueError: if ``period`` is not a key of ``KLINE_PERIODS``
    """
    suffix = KLINE_PERIODS.get(period)
    if suffix is None:
        raise ValueError(f"不支持的周期: {period},可选: {list(KLINE_PERIODS.keys())}")

    # The table name comes from the KLINE_PERIODS whitelist, never from raw
    # caller input, so formatting it into the SQL text is safe here.
    table_name = f'bitmart_eth_{suffix}'
    query = f"SELECT id as timestamp, open, high, low, close FROM {table_name} ORDER BY id"

    conn = sqlite3.connect(str(DB_PATH))
    try:
        df = pd.read_sql_query(query, conn)
    finally:
        # Close even when the query fails (e.g. missing table) so the
        # connection is not leaked.
        conn.close()

    if df.empty:
        logger.warning(f"[{suffix}] 表中无数据")
        return df

    # id is a millisecond timestamp; convert it into a datetime index.
    df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms')
    df.set_index('datetime', inplace=True)

    # Filter by date range.
    if start_date:
        df = df[df.index >= start_date]
    if end_date:
        df = df[df.index <= end_date]

    # The filter may have removed every row; df.index[0] below would then
    # raise IndexError, so return the empty frame explicitly.
    if df.empty:
        logger.warning(f"[{suffix}] 指定日期范围内无K线数据")
        return df

    logger.info(f"[{suffix}] 加载 {len(df)} 条K线 | {df.index[0]} ~ {df.index[-1]}")
    return df
|
||
|
||
|
||
def load_multi_period(periods: list = None, start_date: str = None, end_date: str = None) -> dict:
    """
    Load K-line data for several periods at once.

    Each period is loaded through ``load_kline``. A period whose load raises
    is logged and skipped, and a period that yields an empty frame is left
    out of the result.

    :param periods: list of periods in minutes, e.g. [5, 15, 60]; ``None``
        means every key of ``KLINE_PERIODS``
    :param start_date: start date, forwarded to ``load_kline``
    :param end_date: end date, forwarded to ``load_kline``
    :return: dict mapping period -> DataFrame
    """
    wanted = list(KLINE_PERIODS.keys()) if periods is None else periods

    frames = {}
    for p in wanted:
        try:
            kline = load_kline(p, start_date, end_date)
            if kline.empty:
                continue
            frames[p] = kline
        except Exception as e:
            # Best effort: one failing period must not abort the others.
            logger.error(f"加载 {p}分钟 K线失败: {e}")

    return frames
|
||
|
||
|
||
def load_trades(start_date: str = None, end_date: str = None, limit: int = None) -> pd.DataFrame:
    """
    Load raw trade records from the ``bitmart_eth_trades`` table.

    :param start_date: optional lower bound; rows whose datetime index is
        ``>= start_date`` are kept
    :param end_date: optional upper bound; rows whose datetime index is
        ``<= end_date`` are kept
    :param limit: optional maximum number of rows, applied after the date
        filter and keeping the earliest rows; ``None`` or ``0`` means no limit
    :return: DataFrame with columns id, timestamp, price, volume, side,
        indexed by ``datetime``. If the table is empty, the raw empty frame is
        returned without the datetime index.
    """
    query = "SELECT id, timestamp, price, volume, side FROM bitmart_eth_trades ORDER BY timestamp"

    conn = sqlite3.connect(str(DB_PATH))
    try:
        df = pd.read_sql_query(query, conn)
    finally:
        # Close even when the query fails so the connection is not leaked.
        conn.close()

    if df.empty:
        logger.warning("成交记录表中无数据")
        return df

    # timestamp is interpreted as milliseconds and becomes the datetime index.
    df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms')
    df.set_index('datetime', inplace=True)

    if start_date:
        df = df[df.index >= start_date]
    if end_date:
        df = df[df.index <= end_date]
    if limit:
        df = df.head(limit)

    logger.info(f"加载 {len(df)} 条成交记录")
    return df
|
||
|
||
|
||
def get_available_tables() -> list:
    """
    List every table in the database.

    :return: table names read from ``sqlite_master``, ordered by name
    """
    conn = sqlite3.connect(str(DB_PATH))
    try:
        cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
        return [row[0] for row in cursor.fetchall()]
    finally:
        # Close even if the query raises so the connection is not leaked.
        conn.close()
|
||
|
||
|
||
def get_table_stats() -> dict:
    """
    Collect the row count of every table in the database.

    A table whose count query fails is reported as 0 and the failure is
    logged, so one bad table does not abort the whole report.

    :return: dict mapping table name -> row count (int)
    """
    # Resolve the table list first: it uses its own connection, and doing it
    # before opening ours means a failure there cannot leak this connection.
    tables = get_available_tables()

    stats = {}
    conn = sqlite3.connect(str(DB_PATH))
    try:
        for table in tables:
            # Quote the identifier (doubling embedded quotes) so names with
            # spaces, reserved words or quote characters are still valid SQL.
            quoted = '"' + table.replace('"', '""') + '"'
            try:
                stats[table] = conn.execute(f"SELECT COUNT(*) FROM {quoted}").fetchone()[0]
            except sqlite3.Error as e:
                logger.warning(f"[{table}] 统计行数失败: {e}")
                stats[table] = 0
    finally:
        conn.close()
    return stats
|