This commit is contained in:
ddrwode
2026-02-12 16:46:01 +08:00
parent 2e9ae4c7d7
commit 4a520306b1
9 changed files with 571 additions and 4 deletions

6
tunnel/__init__.py Normal file
View File

@@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
"""
独立 Python 隧道模块:内网穿透,与 Worker 集成。
- 隧道服务端运行在云服务器,接受隧道客户端连接,并为每个 Worker 分配代理端口。
- 隧道客户端运行在每台线下电脑(与 Worker 同进程),将本地端口暴露到云服务器。
"""

147
tunnel/client.py Normal file
View File

@@ -0,0 +1,147 @@
# -*- coding: utf-8 -*-
"""
隧道客户端:运行在每台线下电脑,与 Worker 同进程集成。
连接隧道服务端,注册 worker_id收到 open_stream 时连接本地端口并桥接到服务端流端口。
"""
from __future__ import annotations
import asyncio
import logging
from typing import Optional
from tunnel.protocol import decode_control, encode_control, register_message
logger = logging.getLogger("tunnel.client")
class TunnelClient:
def __init__(
self,
server_host: str,
control_port: int,
stream_port: int,
worker_id: str,
local_port: int,
) -> None:
self.server_host = server_host
self.control_port = control_port
self.stream_port = stream_port
self.worker_id = worker_id
self.local_port = local_port
self._running = False
self._reconnect_delay = 5
self._reconnect_max_delay = 60
async def run(self) -> None:
self._running = True
while self._running:
try:
await self._connect_and_loop()
except Exception as e:
logger.error("隧道连接异常: %s", e)
if self._running:
logger.info("隧道将在 %d 秒后重连...", self._reconnect_delay)
await asyncio.sleep(self._reconnect_delay)
self._reconnect_delay = min(
self._reconnect_delay * 2,
self._reconnect_max_delay,
)
async def stop(self) -> None:
self._running = False
async def _connect_and_loop(self) -> None:
reader, writer = await asyncio.open_connection(
self.server_host, self.control_port
)
self._reconnect_delay = 5
logger.info("隧道控制连接已建立: %s:%s", self.server_host, self.control_port)
msg = register_message(self.worker_id, self.local_port)
writer.write(encode_control(msg))
await writer.drain()
line = await asyncio.wait_for(reader.readline(), timeout=15)
if not line:
raise RuntimeError("隧道服务端未响应")
ack = decode_control(line)
if ack.get("type") == "error":
raise RuntimeError("隧道注册失败: %s" % ack.get("detail", ack))
if ack.get("type") != "register_ack":
raise RuntimeError("隧道注册异常: %s" % ack)
proxy_port = ack.get("proxy_port")
logger.info("隧道已注册: worker_id=%s, 代理端口 %s:%s", self.worker_id, self.server_host, proxy_port)
try:
while self._running:
line = await asyncio.wait_for(reader.readline(), timeout=60)
if not line:
break
data = decode_control(line)
if data.get("type") == "open_stream":
stream_id = int(data.get("stream_id", 0))
local_port = int(data.get("local_port", self.local_port))
asyncio.create_task(
self._bridge_stream(stream_id, local_port)
)
except asyncio.TimeoutError:
pass
finally:
writer.close()
await writer.wait_closed()
async def _bridge_stream(self, stream_id: int, local_port: int) -> None:
"""连接本地 local_port再连服务端 stream 端口,发送 stream_id 后双向桥接。"""
try:
local_reader, local_writer = await asyncio.open_connection(
"127.0.0.1", local_port
)
except Exception as e:
logger.warning("隧道打开本地 %s 失败: %s", local_port, e)
return
try:
stream_reader, stream_writer = await asyncio.open_connection(
self.server_host, self.stream_port
)
except Exception as e:
logger.warning("隧道连流端口失败: %s", e)
local_writer.close()
await local_writer.wait_closed()
return
try:
stream_writer.write(stream_id.to_bytes(4, "big"))
await stream_writer.drain()
except Exception as e:
logger.warning("隧道发送 stream_id 失败: %s", e)
stream_writer.close()
await stream_writer.wait_closed()
local_writer.close()
await local_writer.wait_closed()
return
async def pipe(
src: asyncio.StreamReader, dst: asyncio.StreamWriter
) -> None:
try:
while True:
data = await src.read(65536)
if not data:
break
dst.write(data)
await dst.drain()
except (ConnectionResetError, BrokenPipeError, asyncio.CancelledError):
pass
finally:
try:
dst.close()
await dst.wait_closed()
except Exception:
pass
await asyncio.gather(
pipe(local_reader, stream_writer),
pipe(stream_reader, local_writer),
)
logger.debug("隧道流 %s 已关闭", stream_id)

32
tunnel/protocol.py Normal file
View File

@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
"""
隧道控制面协议:服务端与客户端之间 JSON 文本行。
"""
from __future__ import annotations
import json
def register_message(worker_id: str, local_port: int) -> dict:
"""客户端注册:暴露的 worker_id 与本地端口。"""
return {"type": "register", "worker_id": worker_id, "local_port": local_port}
def register_ack_message(proxy_port: int) -> dict:
"""服务端确认:分配给的代理端口。"""
return {"type": "register_ack", "proxy_port": proxy_port}
def open_stream_message(stream_id: int, local_port: int) -> dict:
"""服务端通知客户端:请为 stream_id 打开本地 local_port 并连到 stream 端口。"""
return {"type": "open_stream", "stream_id": stream_id, "local_port": local_port}
def encode_control(msg: dict) -> bytes:
"""编码为一行 JSON + 换行。"""
return (json.dumps(msg, ensure_ascii=False) + "\n").encode("utf-8")
def decode_control(line: bytes) -> dict:
"""解码一行 JSON。"""
return json.loads(line.decode("utf-8").strip())

247
tunnel/server.py Normal file
View File

@@ -0,0 +1,247 @@
# -*- coding: utf-8 -*-
"""
隧道服务端:运行在云服务器。
- 控制端口:接受隧道客户端连接,注册 worker_id。
- 流端口:客户端连此端口做流桥接。
- 为每个 worker 分配一个代理端口,外部连代理端口时转发到该 worker 的本地端口。
"""
from __future__ import annotations
import asyncio
import logging
from typing import Optional
from tunnel.protocol import (
decode_control,
encode_control,
open_stream_message,
register_ack_message,
)
logger = logging.getLogger("tunnel.server")
class TunnelServer:
def __init__(
self,
control_port: int,
stream_port: int,
proxy_base_port: int,
host: str = "0.0.0.0",
) -> None:
self.control_port = control_port
self.stream_port = stream_port
self.proxy_base_port = proxy_base_port
self.host = host
self._worker_control: dict[str, asyncio.StreamWriter] = {}
self._worker_local_port: dict[str, int] = {}
self._worker_proxy_port: dict[str, int] = {}
self._proxy_worker: dict[int, str] = {}
self._next_proxy_port = proxy_base_port
self._stream_queues: dict[int, asyncio.Queue[tuple[asyncio.StreamReader, asyncio.StreamWriter]]] = {}
self._stream_id = 0
self._lock = asyncio.Lock()
self._control_server: Optional[asyncio.Server] = None
self._stream_server: Optional[asyncio.Server] = None
self._proxy_servers: dict[int, asyncio.Server] = {}
async def start(self) -> None:
self._control_server = await asyncio.start_server(
self._handle_control_connection, self.host, self.control_port
)
self._stream_server = await asyncio.start_server(
self._handle_stream_connection, self.host, self.stream_port
)
logger.info(
"隧道服务端: 控制端口 %s:%s, 流端口 %s:%s",
self.host, self.control_port, self.host, self.stream_port,
)
async def stop(self) -> None:
for s in self._proxy_servers.values():
s.close()
await s.wait_closed()
self._proxy_servers.clear()
if self._stream_server:
self._stream_server.close()
await self._stream_server.wait_closed()
if self._control_server:
self._control_server.close()
await self._control_server.wait_closed()
for w in self._worker_control.values():
try:
w.close()
await w.drain()
except Exception:
pass
self._worker_control.clear()
def _alloc_proxy_port(self, worker_id: str, local_port: int) -> int:
port = self._next_proxy_port
self._next_proxy_port += 1
self._worker_proxy_port[worker_id] = port
self._worker_local_port[worker_id] = local_port
self._proxy_worker[port] = worker_id
return port
async def _handle_control_connection(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
peer = writer.get_extra_info("peername", ("?", "?"))[0]
worker_id: Optional[str] = None
try:
line = await asyncio.wait_for(reader.readline(), timeout=30)
if not line:
return
msg = decode_control(line)
if msg.get("type") != "register":
writer.write(encode_control({"type": "error", "detail": "首条消息必须是 register"}))
await writer.drain()
return
worker_id = msg.get("worker_id") or ""
local_port = int(msg.get("local_port", 54345))
if not worker_id:
writer.write(encode_control({"type": "error", "detail": "worker_id 不能为空"}))
await writer.drain()
return
async with self._lock:
if worker_id in self._worker_control:
try:
self._worker_control[worker_id].close()
await self._worker_control[worker_id].drain()
except Exception:
pass
proxy_port = self._alloc_proxy_port(worker_id, local_port)
self._worker_control[worker_id] = writer
writer.write(encode_control(register_ack_message(proxy_port)))
await writer.drain()
logger.info("隧道客户端注册: worker_id=%s, local_port=%s, proxy_port=%s", worker_id, local_port, proxy_port)
# 为该 worker 启动代理监听(若尚未监听)
await self._ensure_proxy_listener(proxy_port, worker_id, local_port)
# 保持连接并处理 open_stream 的响应(本实现里服务端只发 open_stream不期待客户端回控制消息
while True:
line = await reader.readline()
if not line:
break
# 可扩展:心跳、其他控制消息
except asyncio.TimeoutError:
logger.warning("控制连接超时: %s", peer)
except Exception as e:
logger.exception("控制连接异常: %s", e)
finally:
if worker_id:
async with self._lock:
if self._worker_control.get(worker_id) is writer:
del self._worker_control[worker_id]
if worker_id in self._worker_proxy_port:
port = self._worker_proxy_port.pop(worker_id)
self._proxy_worker.pop(port, None)
self._worker_local_port.pop(worker_id, None)
srv = self._proxy_servers.pop(port, None)
if srv:
srv.close()
await srv.wait_closed()
try:
writer.close()
await writer.wait_closed()
except Exception:
pass
async def _ensure_proxy_listener(self, proxy_port: int, worker_id: str, local_port: int) -> None:
if proxy_port in self._proxy_servers:
return
async def on_proxy_connection(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
await self._bridge_user_to_worker(proxy_port, worker_id, local_port, reader, writer)
self._proxy_servers[proxy_port] = await asyncio.start_server(
on_proxy_connection, self.host, proxy_port
)
logger.info("隧道代理监听: %s:%s -> worker %s local:%s", self.host, proxy_port, worker_id, local_port)
async def _bridge_user_to_worker(
self,
proxy_port: int,
worker_id: str,
local_port: int,
user_reader: asyncio.StreamReader,
user_writer: asyncio.StreamWriter,
) -> None:
stream_id = await self._alloc_stream_id()
control_writer = self._worker_control.get(worker_id)
if not control_writer or control_writer.is_closing():
user_writer.close()
await user_writer.wait_closed()
return
queue: asyncio.Queue[tuple[asyncio.StreamReader, asyncio.StreamWriter]] = asyncio.Queue(maxsize=1)
self._stream_queues[stream_id] = queue
try:
control_writer.write(encode_control(open_stream_message(stream_id, local_port)))
await control_writer.drain()
except Exception as e:
logger.warning("发送 open_stream 失败: %s", e)
self._stream_queues.pop(stream_id, None)
user_writer.close()
await user_writer.wait_closed()
return
try:
client_reader, client_writer = await asyncio.wait_for(queue.get(), timeout=15)
except asyncio.TimeoutError:
logger.warning("%s 等待客户端连接超时", stream_id)
self._stream_queues.pop(stream_id, None)
user_writer.close()
await user_writer.wait_closed()
return
self._stream_queues.pop(stream_id, None)
async def pipe(src: asyncio.StreamReader, dst: asyncio.StreamWriter) -> None:
try:
while True:
data = await src.read(65536)
if not data:
break
dst.write(data)
await dst.drain()
except (ConnectionResetError, BrokenPipeError, asyncio.CancelledError):
pass
finally:
try:
dst.close()
await dst.wait_closed()
except Exception:
pass
await asyncio.gather(
pipe(user_reader, client_writer),
pipe(client_reader, user_writer),
)
async def _handle_stream_connection(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
"""客户端连到流端口后发送 stream_id4 字节大端),然后参与桥接。"""
try:
sid_bytes = await asyncio.wait_for(reader.readexactly(4), timeout=10)
stream_id = int.from_bytes(sid_bytes, "big")
queue = self._stream_queues.get(stream_id)
if not queue:
writer.close()
await writer.wait_closed()
return
queue.put_nowait((reader, writer))
except (asyncio.TimeoutError, asyncio.IncompleteReadError) as e:
logger.warning("流连接读 stream_id 失败: %s", e)
writer.close()
await writer.wait_closed()
async def _alloc_stream_id(self) -> int:
async with self._lock:
self._stream_id += 1
return self._stream_id