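# -*- coding: utf-8 -*-
"""
隧道客户端:运行在每台线下电脑,与 Worker 同进程集成。

连接隧道服务端,注册 worker_id,收到 open_stream 时连接本地端口并桥接到服务端流端口。

Tunnel client: runs on each on-premises (offline) machine, integrated into
the same process as the Worker. It connects to the tunnel server, registers
its worker_id, and on every open_stream message connects to a local port and
bridges it to the server's stream port.
"""
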
from __future__ import annotations

import asyncio
import contextlib
import logging
from typing import Optional

from tunnel.protocol import decode_control, encode_control, register_message


# Logger used by TunnelClient for connection, registration and stream
# diagnostics; its name is fixed to "tunnel.client".
logger: logging.Logger = logging.getLogger(name="tunnel.client")


class TunnelClient:
    """Client side of the tunnel: publishes local TCP ports via a tunnel server.

    A long-lived control connection to ``server_host:control_port`` carries
    newline-delimited control messages (encoded/decoded by
    ``tunnel.protocol``). After registering ``worker_id``, every
    ``open_stream`` message makes the client dial ``127.0.0.1:<local_port>``
    and ``server_host:stream_port`` and copy bytes between the two.

    :meth:`run` re-establishes the control connection with exponential
    back-off (starting at 5 s, doubling, capped at 60 s) until :meth:`stop`
    is called.
    """

    def __init__(
        self,
        server_host: str,
        control_port: int,
        stream_port: int,
        worker_id: str,
        local_port: int,
    ) -> None:
        """Store connection parameters; no I/O happens until :meth:`run`.

        Args:
            server_host: Host name or address of the tunnel server.
            control_port: Server port of the line-based control channel.
            stream_port: Server port dialed once per bridged stream.
            worker_id: Identifier this client registers under.
            local_port: Local port announced at registration and used when an
                ``open_stream`` message does not name one.
        """
        self.server_host = server_host
        self.control_port = control_port
        self.stream_port = stream_port
        self.worker_id = worker_id
        self.local_port = local_port
        self._running = False
        self._reconnect_delay = 5
        self._reconnect_max_delay = 60
        # Strong references to in-flight bridge tasks: the event loop only
        # keeps weak references to tasks, so an unreferenced task can be
        # garbage-collected before it finishes.
        self._stream_tasks: set[asyncio.Task] = set()
        # Writer of the current control connection, so stop() can close it
        # and unblock the pending readline().
        self._control_writer: Optional[asyncio.StreamWriter] = None
        # Created inside run() so it belongs to the running event loop.
        self._stop_event: Optional[asyncio.Event] = None

    async def run(self) -> None:
        """Keep the control connection alive until :meth:`stop` is called.

        Each session end (clean or failed) is followed by a back-off sleep
        before reconnecting; stop() cuts that sleep short.
        """
        self._running = True
        self._stop_event = asyncio.Event()
        while self._running:
            try:
                await self._connect_and_loop()
            except Exception as e:
                # repr() so exceptions whose str() is empty (e.g. a timeout)
                # still show up as something identifiable in the log.
                logger.error("隧道连接异常: %s", repr(e))
            if self._running:
                logger.info("隧道将在 %d 秒后重连...", self._reconnect_delay)
                await self._sleep_unless_stopped(self._reconnect_delay)
                self._reconnect_delay = min(
                    self._reconnect_delay * 2,
                    self._reconnect_max_delay,
                )

    async def stop(self) -> None:
        """Make :meth:`run` return promptly.

        Clears the running flag, wakes a pending back-off sleep and closes the
        current control connection (which ends its read loop). Streams that
        are already being bridged are left to finish on their own.
        """
        self._running = False
        if self._stop_event is not None:
            self._stop_event.set()
        writer = self._control_writer
        if writer is not None:
            writer.close()

    async def _sleep_unless_stopped(self, delay: float) -> None:
        """Sleep for *delay* seconds, returning early once stop() is called."""
        event = self._stop_event
        if event is None:
            await asyncio.sleep(delay)
            return
        try:
            await asyncio.wait_for(event.wait(), timeout=delay)
        except asyncio.TimeoutError:
            pass

    async def _connect_and_loop(self) -> None:
        """Run one control session: connect, register, then dispatch messages.

        Returns when the server closes the connection, when no control
        message arrives for 60 s, or after stop() is called. The control
        socket is closed on every exit path.

        Raises:
            OSError: the control connection could not be opened.
            asyncio.TimeoutError: no registration reply within 15 s.
            RuntimeError: the server gave no reply, rejected the
                registration, or answered with an unexpected message.
        """
        reader, writer = await asyncio.open_connection(
            self.server_host, self.control_port
        )
        # Connected: restart the back-off schedule from its base value.
        self._reconnect_delay = 5
        self._control_writer = writer
        try:
            logger.info(
                "隧道控制连接已建立: %s:%s", self.server_host, self.control_port
            )
            await self._register(reader, writer)
            await self._dispatch_control(reader)
        finally:
            # Also reached when registration fails, so the socket never leaks.
            self._control_writer = None
            await self._close_writer(writer)

    async def _register(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
    ) -> None:
        """Send the register message and validate the server's reply.

        Raises:
            asyncio.TimeoutError: no reply line within 15 s.
            RuntimeError: EOF, an ``error`` reply, or a reply whose type is
                not ``register_ack``.
        """
        msg = register_message(self.worker_id, self.local_port)
        writer.write(encode_control(msg))
        await writer.drain()

        line = await asyncio.wait_for(reader.readline(), timeout=15)
        if not line:
            raise RuntimeError("隧道服务端未响应")
        ack = decode_control(line)
        if ack.get("type") == "error":
            raise RuntimeError("隧道注册失败: %s" % (ack.get("detail", ack),))
        if ack.get("type") != "register_ack":
            raise RuntimeError("隧道注册异常: %s" % (ack,))
        logger.info(
            "隧道已注册: worker_id=%s, 代理端口 %s:%s",
            self.worker_id,
            self.server_host,
            ack.get("proxy_port"),
        )

    async def _dispatch_control(self, reader: asyncio.StreamReader) -> None:
        """Read control messages and start one bridge task per ``open_stream``.

        Returns on EOF, when stop() has cleared the running flag, or after
        60 s without any control message.
        """
        try:
            while self._running:
                line = await asyncio.wait_for(reader.readline(), timeout=60)
                if not line:
                    break
                data = decode_control(line)
                if data.get("type") == "open_stream":
                    stream_id = int(data.get("stream_id", 0))
                    # NOTE(review): the server picks which 127.0.0.1 port is
                    # dialed; a compromised server could reach any local
                    # service. Consider restricting to self.local_port.
                    local_port = int(data.get("local_port", self.local_port))
                    self._spawn_bridge(stream_id, local_port)
        except asyncio.TimeoutError:
            # 60 s of silence: end this session; run() will reconnect.
            pass

    def _spawn_bridge(self, stream_id: int, local_port: int) -> None:
        """Start a bridge task and keep a reference to it until it finishes."""
        task = asyncio.create_task(self._bridge_stream(stream_id, local_port))
        self._stream_tasks.add(task)
        task.add_done_callback(self._stream_tasks.discard)

    @staticmethod
    async def _close_writer(writer: asyncio.StreamWriter) -> None:
        """Close *writer*; best-effort, errors from a broken peer are ignored."""
        writer.close()
        with contextlib.suppress(Exception):
            await writer.wait_closed()

    @classmethod
    async def _pipe(
        cls, src: asyncio.StreamReader, dst: asyncio.StreamWriter
    ) -> None:
        """Copy bytes from *src* to *dst* until EOF or a socket error, then close *dst*.

        Cancellation is not swallowed; *dst* is still closed on the way out.
        """
        try:
            while True:
                data = await src.read(65536)
                if not data:
                    break
                dst.write(data)
                await dst.drain()
        except OSError:
            # Reset / broken pipe / aborted connection: end this direction.
            pass
        finally:
            await cls._close_writer(dst)

    async def _bridge_stream(self, stream_id: int, local_port: int) -> None:
        """Bridge one tunnel stream between 127.0.0.1:*local_port* and the server.

        Connects to the local port first, then to ``server_host:stream_port``,
        sends *stream_id* as a 4-byte big-endian header, and then copies bytes
        in both directions until either side closes. Setup failures are
        logged as warnings and the stream is dropped.
        """
        try:
            local_reader, local_writer = await asyncio.open_connection(
                "127.0.0.1", local_port
            )
        except Exception as e:
            logger.warning("隧道打开本地 %s 失败: %s", local_port, e)
            return

        try:
            stream_reader, stream_writer = await asyncio.open_connection(
                self.server_host, self.stream_port
            )
        except Exception as e:
            logger.warning("隧道连流端口失败: %s", e)
            await self._close_writer(local_writer)
            return

        try:
            stream_writer.write(stream_id.to_bytes(4, "big"))
            await stream_writer.drain()
        except Exception as e:
            logger.warning("隧道发送 stream_id 失败: %s", e)
            await self._close_writer(stream_writer)
            await self._close_writer(local_writer)
            return

        # Each direction closes its destination when done, which in turn ends
        # the opposite direction, so both pipes finish together.
        await asyncio.gather(
            self._pipe(local_reader, stream_writer),
            self._pipe(stream_reader, local_writer),
        )
        logger.debug("隧道流 %s 已关闭", stream_id)