Files
lm_code/价格展示/test2.py
2025-09-25 18:29:53 +08:00

713 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
import numpy as np
import plotly.graph_objects as go
from datetime import datetime, timezone, timedelta
import warnings
import os
import uuid
# ========== 配置 ==========
KLINE_XLSX = "kline_data.xlsx" # K线数据文件名
ORDERS_XLSX = "做市策略.xls" # 订单数据文件名
OUTPUT_HTML = "kline_with_trades.html"
SYMBOL = "ETH-USDT" # 交易对筛选
# 时间与对齐配置
ORDERS_TIME_IS_LOCAL_ASIA_SH = True # 订单时间是否为东八区时间
SNAP_TRADES_TO_NEAREST_CANDLE = True # 对齐交易点到最近的K线时间
SNAP_TOLERANCE_MULTIPLIER = 1.5 # 对齐容忍度倍数
# 图表尺寸配置 - 更宽更扁
CHART_WIDTH = 2200 # 更宽的图表
CHART_HEIGHT = 600 # 更矮的图表
FONT_SIZE = 12 # 字体大小
ANNOTATION_FONT_SIZE = 10 # 标注字体大小
MARKER_SIZE = 10 # 标记大小
LINE_WIDTH = 1.5 # 连接线宽度
# 颜色配置 - 所有文本使用黑色
TEXT_COLOR = "black" # 所有文本使用黑色
TEXT_OFFSET = 10 # 文本偏移量(像素)
# ========== 工具函数 ==========
def parse_numeric(x):
"""高效解析数值类型,支持多种格式"""
if pd.isna(x):
return np.nan
try:
# 尝试直接转换(大多数情况)
return float(x)
except:
# 处理特殊格式
s = str(x).replace(",", "").replace("USDT", "").replace("", "").strip()
if s.endswith("%"):
s = s[:-1]
return float(s) if s else np.nan
def epoch_to_dt(x):
"""将时间戳转换为上海时区时间"""
try:
return pd.to_datetime(int(x), unit="s", utc=True).tz_convert("Asia/Shanghai")
except:
return pd.NaT
def zh_side(row):
"""解析交易方向"""
direction = str(row.get("方向", "")).strip()
if "开多" in direction: return "long_open"
if "平多" in direction: return "long_close"
if "开空" in direction: return "short_open"
if "平空" in direction: return "short_close"
return "unknown"
# ========== 数据加载与预处理 ==========
def load_kline_data():
"""加载并预处理K线数据"""
if not os.path.exists(KLINE_XLSX):
raise FileNotFoundError(f"K线数据文件不存在: {KLINE_XLSX}")
kdf = pd.read_excel(KLINE_XLSX, dtype=str)
kdf.columns = [str(c).strip().lower() for c in kdf.columns]
# 验证必要列
required_cols = {"id", "open", "close", "low", "high"}
missing = required_cols - set(kdf.columns)
if missing:
raise ValueError(f"K线表缺少列: {missing}")
# 时间转换 - 确保id是秒级时间戳
kdf["time"] = kdf["id"].apply(epoch_to_dt)
# 数值转换(向量化操作提升性能)
for col in ["open", "close", "low", "high"]:
kdf[col] = pd.to_numeric(kdf[col].apply(parse_numeric), errors="coerce")
# 清理无效数据
kdf = kdf.dropna(subset=["time", "open", "close", "low", "high"])
kdf = kdf.sort_values("time").reset_index(drop=True)
# 计算K线周期用于交易点对齐
if len(kdf) >= 3:
median_step = kdf["time"].diff().median()
else:
median_step = pd.Timedelta(minutes=1)
return kdf, median_step
def load_order_data():
"""加载并预处理订单数据"""
if not os.path.exists(ORDERS_XLSX):
raise FileNotFoundError(f"订单数据文件不存在: {ORDERS_XLSX}")
odf = pd.read_excel(ORDERS_XLSX, dtype=str)
# 验证必要列
need_order_cols = ["时间", "交易对", "方向", "模式", "数量(张)", "成交价", "交易额", "消耗手续费", "用户盈亏"]
missing = set(need_order_cols) - set(odf.columns)
if missing:
raise ValueError(f"订单表缺少列: {missing}")
# 筛选交易对
if SYMBOL and "交易对" in odf.columns:
odf = odf[odf["交易对"].astype(str).str.strip() == SYMBOL]
# 时间处理 - 确保时间格式正确
with warnings.catch_warnings():
warnings.simplefilter("ignore")
if ORDERS_TIME_IS_LOCAL_ASIA_SH:
# 尝试多种格式解析时间
odf["时间"] = pd.to_datetime(odf["时间"], errors="coerce", format="mixed")
# 本地化为上海时区
odf["时间"] = odf["时间"].dt.tz_localize("Asia/Shanghai", ambiguous="NaT", nonexistent="shift_forward")
else:
# 如果Excel时间已经是UTC
odf["时间"] = pd.to_datetime(odf["时间"], utc=True, errors="coerce").dt.tz_convert("Asia/Shanghai")
# 数值转换
numeric_cols = {
"数量(张)": "数量",
"成交价": "价格",
"交易额": "交易额_num",
"消耗手续费": "手续费",
"用户盈亏": "盈亏"
}
for src, dest in numeric_cols.items():
odf[dest] = pd.to_numeric(odf[src].apply(parse_numeric), errors="coerce")
# 解析交易方向
odf["side"] = odf.apply(zh_side, axis=1)
# 为每个订单生成唯一ID
odf["order_id"] = [str(uuid.uuid4()) for _ in range(len(odf))]
# 计算本金(数量 * 价格)
odf["本金"] = odf["数量"] * odf["价格"]
# 清理无效数据
odf = odf.dropna(subset=["时间", "价格"])
odf = odf.sort_values("时间").reset_index(drop=True)
return odf
def align_trades_to_candles(kdf, odf, median_step):
"""将交易点对齐到最近的K线时间"""
if not SNAP_TRADES_TO_NEAREST_CANDLE or kdf.empty or odf.empty:
return odf.assign(时间_x=odf["时间"])
snap_tolerance = pd.Timedelta(seconds=max(1, int(median_step.total_seconds() * SNAP_TOLERANCE_MULTIPLIER)))
# 使用merge_asof高效对齐 - 使用方向为'backward'确保交易点对齐到前一个K线
anchor = kdf[["time"]].copy().rename(columns={"time": "k_time"})
odf_sorted = odf.sort_values("时间")
# 关键优化:使用'backward'方向确保交易点对齐到前一个K线
aligned = pd.merge_asof(
odf_sorted,
anchor,
left_on="时间",
right_on="k_time",
direction="backward", # 使用'backward'确保交易点对齐到前一个K线
tolerance=snap_tolerance
)
# 保留原始时间作为参考
aligned["原始时间"] = aligned["时间"]
aligned["时间_x"] = aligned["k_time"].fillna(aligned["时间"])
return aligned
# ========== 持仓跟踪与盈亏计算 ==========
class PositionTracker:
"""FIFO持仓跟踪器支持订单走向可视化"""
def __init__(self):
self.long_lots = [] # (数量, 价格, 时间, 手续费, 订单ID)
self.short_lots = [] # (数量, 价格, 时间, 手续费, 订单ID)
self.realized_pnl = 0.0
self.history = [] # 记录所有交易历史
self.trade_connections = [] # 记录开平仓连接关系
def open_long(self, qty, price, time, fee, order_id):
"""开多仓"""
if qty > 1e-9:
self.long_lots.append((qty, price, time, fee, order_id))
def close_long(self, qty, price, time, fee, order_id):
"""平多仓"""
remaining = qty
local_pnl = 0.0
connections = [] # 本次平仓的连接关系
while remaining > 1e-9 and self.long_lots:
lot_qty, lot_price, lot_time, lot_fee, open_order_id = self.long_lots[0]
take = min(lot_qty, remaining)
pnl = (price - lot_price) * take
local_pnl += pnl
lot_qty -= take
remaining -= take
# 记录开平仓连接
connection = {
"open_time": lot_time,
"close_time": time,
"open_price": lot_price,
"close_price": price,
"qty": take,
"pnl": pnl,
"type": "long",
"open_order_id": open_order_id,
"close_order_id": order_id
}
self.trade_connections.append(connection)
connections.append(connection)
# 记录平仓详情
self.history.append({
"开仓时间": lot_time,
"平仓时间": time,
"数量": take,
"开仓价": lot_price,
"平仓价": price,
"盈亏": pnl,
"类型": "平多",
"开仓订单ID": open_order_id,
"平仓订单ID": order_id
})
if lot_qty <= 1e-9:
self.long_lots.pop(0)
else:
self.long_lots[0] = (lot_qty, lot_price, lot_time, lot_fee, open_order_id)
local_pnl -= fee
self.realized_pnl += local_pnl
return local_pnl, connections
def open_short(self, qty, price, time, fee, order_id):
"""开空仓"""
if qty > 1e-9:
self.short_lots.append((qty, price, time, fee, order_id))
def close_short(self, qty, price, time, fee, order_id):
"""平空仓"""
remaining = qty
local_pnl = 0.0
connections = [] # 本次平仓的连接关系
while remaining > 1e-9 and self.short_lots:
lot_qty, lot_price, lot_time, lot_fee, open_order_id = self.short_lots[0]
take = min(lot_qty, remaining)
pnl = (lot_price - price) * take
local_pnl += pnl
lot_qty -= take
remaining -= take
# 记录开平仓连接
connection = {
"open_time": lot_time,
"close_time": time,
"open_price": lot_price,
"close_price": price,
"q极": take,
"pnl": pnl,
"type": "short",
"open_order_id": open_order_id,
"close_order_id": order_id
}
self.trade_connections.append(connection)
connections.append(connection)
# 记录平仓详情
self.history.append({
"开仓时间": lot_time,
"平仓时间": time,
"数量": take,
"开仓价": lot_price,
"平仓价": price,
"盈亏": pnl,
"类型": "平空",
"开仓订单ID": open_order_id,
"平仓订单ID": order_id
})
if lot_qty <= 1e-9:
self.short_lots.pop(0)
else:
self.short_lots[0] = (lot_qty, lot_price, lot_time, lot_fee, open_order_id)
local_pnl -= fee
self.realized_pnl += local_pnl
return local_pnl, connections
def calculate_pnl(odf):
"""计算持仓盈亏和订单连接关系"""
tracker = PositionTracker()
all_connections = []
for idx, r in odf.iterrows():
qty = r["数量"]
price = r["价格"]
ts = r["时间"]
fee = r["手续费"]
side = r["side"]
order_id = r["order_id"]
if side == "long_open":
tracker.open_long(qty, price, ts, fee, order_id)
elif side == "long_close":
_, connections = tracker.close_long(qty, price, ts, fee, order_id)
all_connections.extend(connections)
elif side == "short_open":
tracker.open_short(qty, price, ts, fee, order_id)
elif side == "short_close":
_, connections = tracker.close_short(qty, price, ts, fee, order_id)
all_connections.extend(connections)
# 创建盈亏DataFrame
if tracker.history:
pnl_df = pd.DataFrame(tracker.history)
# 添加对齐后的时间
pnl_df["时间_x"] = pnl_df["平仓时间"].apply(
lambda x: odf.loc[odf["时间"] == x, "时间_x"].values[0] if not odf.empty else x
)
else:
pnl_df = pd.DataFrame()
# 创建连接关系DataFrame
connections_df = pd.DataFrame(all_connections) if all_connections else pd.DataFrame()
return pnl_df, tracker.realized_pnl, connections_df
# ========== 可视化 ==========
def create_trade_scatter(df, name, color, symbol):
"""创建交易点散点图"""
if df.empty:
return None
# 为不同类型的交易点创建不同的文本标签
if name == "开多":
text = "开多\n" + df["价格"].apply(lambda x: f"{x:.2f}") + "\n" + df["本金"].apply(lambda x: f"{x:.0f}")
elif name == "平多":
text = "平多\n" + df["价格"].apply(lambda x: f"{x:.2f}") + "\n" + df["盈亏"].apply(lambda x: f"{x:.0f}")
elif name == "开空":
text = "开空\n" + df["价格"].apply(lambda x: f"{x:.2f}") + "\n" + df["本金"].apply(lambda x: f"{x:.0f}")
elif name == "平空":
text = "平空\n" + df["价格"].apply(lambda x: f"{x:.2f}") + "\n" + df["盈亏"].apply(lambda x: f"{x:.0f}")
else:
text = name
return go.Scatter(
x=df["时间_x"],
y=df["价格"],
mode="markers+text",
name=name,
text=text,
textposition="middle right", # 文本放在右侧中间位置
textfont=dict(size=ANNOTATION_FONT_SIZE, color=TEXT_COLOR), # 使用黑色文本
marker=dict(
size=MARKER_SIZE,
color=color,
symbol=symbol,
line=dict(width=1.5, color="black")
),
customdata=np.stack([
df["数量"].to_numpy(),
df["价格"].to_numpy(),
df["手续费"].to_numpy(),
df.get("盈亏", np.nan).to_numpy(),
df.get("原始时间", df["时间"]).dt.strftime("%Y-%m-%d %H:%M:%S").to_numpy(),
df["order_id"].to_numpy(),
df["本金"].to_numpy()
], axis=-1),
hovertemplate=(
f"<b>{name}</b><br>"
"数量: %{customdata[0]:.0f}张<br>"
"价格: %{customdata[1]:.2f}<br>"
"手续费: %{customdata[2]:.6f}<br>"
"盈亏: %{customdata[3]:.4f}<br>"
"本金: %{customdata[6]:.0f}<br>"
"时间: %{customdata[4]}<br>"
"订单ID: %{customdata[5]}<extra></extra>"
)
)
def add_trade_connections(fig, connections_df, odf):
"""添加开平仓连接线"""
if connections_df.empty:
return
# 为盈利和亏损的连接线分别创建轨迹
profit_lines = []
loss_lines = []
for _, conn in connections_df.iterrows():
# 获取开仓点和平仓点的坐标
open_point = odf[odf["order_id"] == conn["open_order_id"]].iloc[0]
close_point = odf[odf["order_id"] == conn["close_order_id"]].iloc[0]
line_data = {
"x": [open_point["时间_x"], close_point["时间_x"]],
"y": [open_point["价格"], close_point["价格"]],
"pnl": conn["pnl"],
"type": conn["type"],
"open_order_id": conn["open_order_id"],
"close_order_id": conn["close_order_id"]
}
if conn["pnl"] >= 0:
profit_lines.append(line_data)
else:
loss_lines.append(line_data)
# 添加盈利连接线(绿色)
if profit_lines:
x_profit = []
y_profit = []
customdata_profit = []
for line in profit_lines:
x_profit.extend(line["x"])
y_profit.extend(line["y"])
x_profit.append(None)
y_profit.append(None)
# 为每个点添加自定义数据
customdata_profit.append([
line["open_order_id"],
line["close_order_id"],
line["pnl"],
line["type"]
])
customdata_profit.append([
line["open_order_id"],
line["close_order_id"],
line["pnl"],
line["type"]
])
customdata_profit.append(None)
fig.add_trace(go.Scatter(
x=x_profit,
y=y_profit,
mode="lines",
name="盈利订单",
line=dict(color="rgba(46, 204, 113, 0.7)", width=LINE_WIDTH),
hoverinfo="text",
text=[f"盈利: {d[2]:.2f}" if d else None for d in customdata_profit],
customdata=customdata_profit,
hovertemplate=(
"<b>%{text}</b><br>"
"类型: %{customdata[3]}<br>"
"开仓订单ID: %{customdata[0]}<br>"
"平仓订单ID: %{customdata[1]}<extra></extra>"
)
))
# 添加亏损连接线(红色)
if loss_lines:
x_loss = []
y_loss = []
customdata_loss = []
for line in loss_lines:
x_loss.extend(line["x"])
y_loss.extend(line["y"])
x_loss.append(None)
y_loss.append(None)
# 为每个点添加自定义数据
customdata_loss.append([
line["open_order_id"],
line["close_order_id"],
line["pnl"],
line["type"]
])
customdata_loss.append([
line["open_order_id"],
line["close_order_id"],
line["pnl"],
line["type"]
])
customdata_loss.append(None)
fig.add_trace(go.Scatter(
x=x_loss,
y=y_loss,
mode="lines",
name="亏损订单",
line=dict(color="rgba(231, 76, 60, 0.7)", width=LINE_WIDTH),
hoverinfo="text",
text=[f"亏损: {abs(d[2]):.2f}" if d else None for d in customdata_loss],
customdata=customdata_loss,
hovertemplate=(
"<b>%{text}</b><br>"
"类型: %{customdata[3]}<br>"
"开仓订单ID: %{customdata[0]}<极>"
"平仓订单ID: %{customdata[1]}<extra></extra>"
)
))
def generate_chart(kdf, odf, pnl_df, cum_realized, connections_df):
"""生成K线图与交易标注"""
fig = go.Figure()
# K线主图
fig.add_trace(go.Candlestick(
x=kdf["time"],
open=kdf["open"],
high=kdf["high"],
low=kdf["low"],
close=kdf["close"],
name="K线",
increasing_line_color="#2ecc71",
decreasing_line_color="#e74c3c"
))
# 添加交易点
trade_types = [
(odf[odf["side"] == "long_open"], "开多", "#2ecc71", "triangle-up"),
(odf[odf["side"] == "long_close"], "平多", "#27ae60", "circle"),
(odf[odf["side"] == "short_open"], "开空", "#e74c3c", "triangle-down"),
(odf[odf["side"] == "short_close"], "平空", "#c0392b", "x")
]
for data, name, color, symbol in trade_types:
trace = create_trade_scatter(data, name, color, symbol)
if trace:
fig.add_trace(trace)
# 添加开平仓连接线
add_trade_connections(fig, connections_df, odf)
# 计算时间范围,确保所有点都显示在图表中
all_times = pd.concat([kdf["time"], odf["时间_x"]])
min_time = all_times.min() - pd.Timedelta(minutes=10)
max_time = all_times.max() + pd.Timedelta(minutes=10)
# 计算价格范围,确保所有点都显示在图表中
min_price = min(kdf["low"].min(), odf["价格"].min()) * 0.99
max_price = max(kdf["high"].max(), odf["价格"].max()) * 1.01
# 布局配置 - 更宽更扁的图表
fig.update_layout(
xaxis_title="时间",
yaxis_title="价格 (USDT)",
legend=dict(
orientation="h",
yanchor="bottom",
y=1.02,
xanchor="left",
x=0,
font=dict(size=FONT_SIZE, color=TEXT_COLOR) # 图例文字使用黑色
),
xaxis=dict(
rangeslider=dict(visible=False),
type="date",
gridcolor="rgba(128, 128, 128, 0.2)",
range=[min_time, max_time], # 设置时间范围
title_font=dict(color=TEXT_COLOR), # 坐标轴标题使用黑色
tickfont=dict(color=TEXT_COLOR) # 刻度标签使用黑色
),
yaxis=dict(
gridcolor="rgba(128, 128, 128, 0.2)",
range=[min_price, max_price], # 设置价格范围
title_font=dict(color=TEXT_COLOR), # 坐标轴标题使用黑色
tickfont=dict(color=TEXT_COLOR) # 刻度标签使用黑色
),
hovermode="x unified",
hoverlabel=dict(
namelength=-1,
bgcolor="rgba(255, 255, 255, 0.9)",
font_size=FONT_SIZE,
font_color=TEXT_COLOR # 悬停标签文字使用黑色
),
margin=dict(l=50, r=50, t=80, b=50),
plot_bgcolor="rgba(240, 240, 240, 1)",
width=CHART_WIDTH, # 使用配置的宽度
height=CHART_HEIGHT, # 使用配置的高度
font=dict(size=FONT_SIZE, color=TEXT_COLOR), # 全局字体大小和颜色
# 增强交互性配置
dragmode="pan", # 默认拖拽模式为平移
clickmode="event+select", # 点击模式
selectdirection="h", # 水平选择方向
modebar=dict(
orientation="h", # 水平方向工具栏
bgcolor="rgba(255, 255, 255, 0.7)", # 半透明背景
color="rgba(0, 0, 0, 0.7)", # 图标颜色
activecolor="rgba(0, 0, 0, 0.9)" # 激活图标颜色
)
)
# 添加模式栏按钮
fig.update_layout(
modebar_add=[
"zoom2d",
"pan2d",
"select2d",
"lasso2d",
"zoomIn2d",
"zoomOut2d",
"autoScale2d",
"resetScale2d",
"toImage"
]
)
# 配置缩放行为 - 确保滚轮缩放正常工作
fig.update_xaxes(
autorange=False,
fixedrange=False, # 允许缩放
constrain="domain", # 约束在域内
rangeslider=dict(visible=False) # 禁用范围滑块
)
fig.update_yaxes(
autorange=False,
fixedrange=False, # 允许缩放
scaleanchor="x", # 保持纵横比
scaleratio=1, # 缩放比例
constrain="domain" # 约束在域内
)
# 保存并打开结果 - 启用滚轮缩放
fig.write_html(
OUTPUT_HTML,
include_plotlyjs="cdn",
auto_open=True,
config={
'scrollZoom': True, # 启用滚轮缩放
'displayModeBar': True, # 显示工具栏
'displaylogo': False, # 隐藏Plotly标志
'responsive': True # 响应式布局
}
)
print(f"图表已生成: {OUTPUT_HTML}")
# 返回盈亏详情
if not pnl_df.empty:
pnl_df.to_csv("pnl_details.csv", index=False)
print(f"盈亏详情已保存: pnl_details.csv")
if not connections_df.empty:
connections_df.to_csv("trade_connections.csv", index=False)
print(f"订单连接关系已保存: trade_connections.csv")
return fig
# ========== 主执行流程 ==========
def main():
print("开始处理数据...")
# 加载数据
kdf, median_step = load_kline_data()
odf = load_order_data()
print(f"加载K线数据: {len(kdf)}")
print(f"加载订单数据: {len(odf)}")
# 对齐交易时间
odf = align_trades_to_candles(kdf, odf, median_step)
# 检查时间范围
kline_min_time = kdf["time"].min()
kline_max_time = kdf["time"].max()
order_min_time = odf["时间"].min()
order_max_time = odf["时间"].max()
print(f"K线时间范围: {kline_min_time}{kline_max_time}")
print(f"订单时间范围: {order_min_time}{order_max_time}")
# 检查是否有订单在K线时间范围外
outside_orders = odf[(odf["时间"] < kline_min_time) | (odf["时间"] > kline_max_time)]
if not outside_orders.empty:
print(f"警告: 有 {len(outside_orders)} 个订单在K线时间范围外")
print(outside_orders[["时间", "方向", "价格"]])
# 计算盈亏和订单连接关系
pnl_df, cum_realized, connections_df = calculate_pnl(odf)
print(f"累计已实现盈亏: {cum_realized:.2f} USDT")
print(f"订单连接关系: {len(connections_df)}")
# 生成图表
generate_chart(kdf, odf, pnl_df, cum_realized, connections_df)
print("处理完成")
if __name__ == "__main__":
with warnings.catch_warnings():
warnings.simplefilter("ignore")
main()