Files
boss_dp/tunnel/client.py
ddrwode 4a520306b1 ha'ha
2026-02-12 16:46:01 +08:00

148 lines
5.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
Tunnel client: runs on each on-site (local) machine, inside the same process as the Worker.

Connects to the tunnel server and registers this machine's worker_id; when an
open_stream message arrives, it connects to a local port and bridges that
connection to the server's stream port.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Optional
from tunnel.protocol import decode_control, encode_control, register_message
# Module logger; the name is fixed to "tunnel.client" rather than derived from __name__.
logger = logging.getLogger("tunnel.client")
class TunnelClient:
    """Client side of the tunnel: registers a worker and bridges server-requested streams.

    The client keeps one line-oriented control connection to
    ``server_host:control_port``. After registering ``worker_id`` it waits for
    ``open_stream`` messages; for each one it connects to ``127.0.0.1:<port>``
    and to ``server_host:stream_port`` and relays bytes between the two.

    Args:
        server_host: Tunnel server host name or address.
        control_port: Server port for the control connection.
        stream_port: Server port that each bridged stream connects to.
        worker_id: Identifier sent in the registration message.
        local_port: Local port sent at registration, and used for streams whose
            ``open_stream`` message carries no ``local_port`` field.
    """

    def __init__(
        self,
        server_host: str,
        control_port: int,
        stream_port: int,
        worker_id: str,
        local_port: int,
    ) -> None:
        self.server_host = server_host
        self.control_port = control_port
        self.stream_port = stream_port
        self.worker_id = worker_id
        self.local_port = local_port
        self._running = False
        # Reconnect backoff in seconds: starts at the minimum, doubles after each
        # attempt up to the maximum, and resets once registration succeeds.
        self._reconnect_min_delay = 5
        self._reconnect_delay = self._reconnect_min_delay
        self._reconnect_max_delay = 60
        # Writer of the live control connection (None while disconnected), kept
        # so stop() can close it and wake the blocked read loop.
        self._writer: Optional[asyncio.StreamWriter] = None
        # Strong references to running bridge tasks. The event loop only holds
        # weak references to tasks, so an unreferenced task can be garbage
        # collected before it finishes.
        self._stream_tasks: set[asyncio.Task[None]] = set()

    async def run(self) -> None:
        """Keep the control connection up until stop() is called.

        Each session ends when the server disconnects, the control channel is
        idle for 60 seconds, or an error occurs. The client then sleeps for the
        current backoff delay and reconnects. Session exceptions are logged,
        not raised.
        """
        self._running = True
        while self._running:
            try:
                await self._connect_and_loop()
            except Exception as e:
                if self._running:
                    logger.error("隧道连接异常: %s", e)
                else:
                    # stop() closes the socket, so an error here is the expected
                    # side effect of shutting down rather than a real failure.
                    logger.debug("隧道连接异常: %s", e)
            if self._running:
                logger.info("隧道将在 %d 秒后重连...", self._reconnect_delay)
                await asyncio.sleep(self._reconnect_delay)
                self._reconnect_delay = min(
                    self._reconnect_delay * 2,
                    self._reconnect_max_delay,
                )

    async def stop(self) -> None:
        """Stop reconnecting and close the current control connection, if any.

        Bridge streams that are already running are left to finish on their
        own. A backoff sleep already in progress inside run() is not
        interrupted; run() returns once it ends.
        """
        self._running = False
        writer = self._writer
        if writer is not None:
            # Closing the transport feeds EOF to the pending readline(), so the
            # control loop exits now instead of after the next message/timeout.
            writer.close()

    async def _connect_and_loop(self) -> None:
        """Run one control session: connect, register, then dispatch control messages.

        Returns when the server closes the connection, when no control line
        arrives for 60 seconds, or when stop() is called. Connection and
        registration errors propagate to the caller. The control connection is
        closed on every exit path, including a failed registration.
        """
        reader, writer = await asyncio.open_connection(
            self.server_host, self.control_port
        )
        self._writer = writer
        try:
            logger.info(
                "隧道控制连接已建立: %s:%s", self.server_host, self.control_port
            )
            await self._register(reader, writer)
            # Reset the backoff only after registration succeeds, so a server
            # that accepts TCP but rejects registration is still backed off.
            self._reconnect_delay = self._reconnect_min_delay
            while self._running:
                try:
                    line = await asyncio.wait_for(reader.readline(), timeout=60)
                except asyncio.TimeoutError:
                    logger.info("隧道控制连接 60 秒无消息,将重连")
                    break
                if not line:
                    break
                data = decode_control(line)
                if data.get("type") == "open_stream":
                    stream_id = int(data.get("stream_id", 0))
                    local_port = int(data.get("local_port", self.local_port))
                    self._spawn_bridge(stream_id, local_port)
        finally:
            self._writer = None
            await self._close_writer(writer)

    async def _register(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """Send the registration message and validate the server's reply.

        Raises:
            RuntimeError: if the server does not answer within 15 seconds,
                closes the connection, replies with an ``error`` message, or
                replies with anything other than ``register_ack``.
            Errors from writing or from decode_control() also propagate.
        """
        msg = register_message(self.worker_id, self.local_port)
        writer.write(encode_control(msg))
        await writer.drain()
        try:
            line = await asyncio.wait_for(reader.readline(), timeout=15)
        except asyncio.TimeoutError:
            # A bare TimeoutError has an empty message, which would make the
            # error line logged by run() uninformative.
            raise RuntimeError("隧道服务端注册响应超时") from None
        if not line:
            raise RuntimeError("隧道服务端未响应")
        ack = decode_control(line)
        if ack.get("type") == "error":
            raise RuntimeError("隧道注册失败: %s" % ack.get("detail", ack))
        if ack.get("type") != "register_ack":
            raise RuntimeError("隧道注册异常: %s" % ack)
        proxy_port = ack.get("proxy_port")
        logger.info(
            "隧道已注册: worker_id=%s, 代理端口 %s:%s",
            self.worker_id,
            self.server_host,
            proxy_port,
        )

    def _spawn_bridge(self, stream_id: int, local_port: int) -> None:
        """Start _bridge_stream() as a background task and keep a reference to it."""
        task = asyncio.create_task(self._bridge_stream(stream_id, local_port))
        self._stream_tasks.add(task)
        task.add_done_callback(self._stream_tasks.discard)

    async def _bridge_stream(self, stream_id: int, local_port: int) -> None:
        """Connect 127.0.0.1:local_port to the server stream port and relay bytes.

        After connecting to ``server_host:stream_port``, the stream_id is sent
        first as 4 big-endian bytes. Connection or handshake failures are
        logged as warnings, and whatever was already opened is closed.
        """
        try:
            local_reader, local_writer = await asyncio.open_connection(
                "127.0.0.1", local_port
            )
        except Exception as e:
            logger.warning("隧道打开本地 %s 失败: %s", local_port, e)
            return
        try:
            stream_reader, stream_writer = await asyncio.open_connection(
                self.server_host, self.stream_port
            )
        except Exception as e:
            logger.warning("隧道连流端口失败: %s", e)
            await self._close_writer(local_writer)
            return
        try:
            stream_writer.write(stream_id.to_bytes(4, "big"))
            await stream_writer.drain()
        except Exception as e:
            logger.warning("隧道发送 stream_id 失败: %s", e)
            await self._close_writer(stream_writer)
            await self._close_writer(local_writer)
            return
        await asyncio.gather(
            self._pipe(local_reader, stream_writer),
            self._pipe(stream_reader, local_writer),
        )
        logger.debug("隧道流 %s 已关闭", stream_id)

    async def _pipe(
        self, src: asyncio.StreamReader, dst: asyncio.StreamWriter
    ) -> None:
        """Copy bytes from src to dst until EOF or a socket error, then close dst.

        Closing dst closes that whole socket. The opposite relay direction,
        which reads from the same socket, therefore sees EOF and finishes too.
        """
        try:
            while True:
                data = await src.read(65536)
                if not data:
                    break
                dst.write(data)
                await dst.drain()
        except OSError:
            # OSError covers ConnectionResetError, BrokenPipeError,
            # ConnectionAbortedError (common on Windows) and other socket
            # errors: any of them simply ends this relay direction.
            # CancelledError is deliberately not caught so cancellation works.
            pass
        finally:
            await self._close_writer(dst)

    @staticmethod
    async def _close_writer(writer: asyncio.StreamWriter) -> None:
        """Close writer and wait for its transport to shut down, best effort.

        OSErrors from an already broken connection are ignored: the connection
        is being discarded anyway, and raising from cleanup would mask the
        error that triggered it.
        """
        writer.close()
        try:
            await writer.wait_closed()
        except OSError:
            pass