""" 数据加载器 — 从SQLite加载K线数据为pandas DataFrame """ import sqlite3 import pandas as pd from loguru import logger from .config import DB_PATH, KLINE_PERIODS def load_kline(period: int = 15, start_date: str = None, end_date: str = None) -> pd.DataFrame: """ 加载指定周期的K线数据 :param period: K线周期(分钟),如 1, 3, 5, 15, 30, 60 :param start_date: 起始日期 'YYYY-MM-DD'(可选) :param end_date: 结束日期 'YYYY-MM-DD'(可选) :return: DataFrame,列: timestamp, open, high, low, close """ suffix = KLINE_PERIODS.get(period) if suffix is None: raise ValueError(f"不支持的周期: {period},可选: {list(KLINE_PERIODS.keys())}") table_name = f'bitmart_eth_{suffix}' conn = sqlite3.connect(str(DB_PATH)) query = f"SELECT id as timestamp, open, high, low, close FROM {table_name} ORDER BY id" df = pd.read_sql_query(query, conn) conn.close() if df.empty: logger.warning(f"[{suffix}] 表中无数据") return df # id 是毫秒时间戳,转为 datetime 索引 df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms') df.set_index('datetime', inplace=True) # 按日期过滤 if start_date: df = df[df.index >= start_date] if end_date: df = df[df.index <= end_date] logger.info(f"[{suffix}] 加载 {len(df)} 条K线 | {df.index[0]} ~ {df.index[-1]}") return df def load_multi_period(periods: list = None, start_date: str = None, end_date: str = None) -> dict: """ 加载多个周期的K线数据 :param periods: 周期列表,如 [5, 15, 60],默认全部 :param start_date: 起始日期 :param end_date: 结束日期 :return: {period: DataFrame} 字典 """ if periods is None: periods = list(KLINE_PERIODS.keys()) result = {} for p in periods: try: df = load_kline(p, start_date, end_date) if not df.empty: result[p] = df except Exception as e: logger.error(f"加载 {p}分钟 K线失败: {e}") return result def load_trades(start_date: str = None, end_date: str = None, limit: int = None) -> pd.DataFrame: """ 加载原始成交记录 :return: DataFrame,列: id, timestamp, price, volume, side """ conn = sqlite3.connect(str(DB_PATH)) query = "SELECT id, timestamp, price, volume, side FROM bitmart_eth_trades ORDER BY timestamp" df = pd.read_sql_query(query, conn) conn.close() if df.empty: logger.warning("成交记录表中无数据") return df df['datetime'] = pd.to_datetime(df['timestamp'], unit='ms') df.set_index('datetime', inplace=True) if start_date: df = df[df.index >= start_date] if end_date: df = df[df.index <= end_date] if limit: df = df.head(limit) logger.info(f"加载 {len(df)} 条成交记录") return df def get_available_tables() -> list: """列出数据库中所有可用的表""" conn = sqlite3.connect(str(DB_PATH)) cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name") tables = [row[0] for row in cursor.fetchall()] conn.close() return tables def get_table_stats() -> dict: """获取各表的数据统计""" conn = sqlite3.connect(str(DB_PATH)) tables = get_available_tables() stats = {} for table in tables: try: count = pd.read_sql_query(f"SELECT COUNT(*) as cnt FROM {table}", conn).iloc[0]['cnt'] stats[table] = count except Exception: stats[table] = 0 conn.close() return stats