248 lines
9.6 KiB
Python
248 lines
9.6 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
隧道服务端:运行在云服务器。
|
||
- 控制端口:接受隧道客户端连接,注册 worker_id。
|
||
- 流端口:客户端连此端口做流桥接。
|
||
- 为每个 worker 分配一个代理端口,外部连代理端口时转发到该 worker 的本地端口。
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import logging
|
||
from typing import Optional
|
||
|
||
from tunnel.protocol import (
|
||
decode_control,
|
||
encode_control,
|
||
open_stream_message,
|
||
register_ack_message,
|
||
)
|
||
|
||
logger = logging.getLogger("tunnel.server")
|
||
|
||
|
||
class TunnelServer:
|
||
def __init__(
|
||
self,
|
||
control_port: int,
|
||
stream_port: int,
|
||
proxy_base_port: int,
|
||
host: str = "0.0.0.0",
|
||
) -> None:
|
||
self.control_port = control_port
|
||
self.stream_port = stream_port
|
||
self.proxy_base_port = proxy_base_port
|
||
self.host = host
|
||
self._worker_control: dict[str, asyncio.StreamWriter] = {}
|
||
self._worker_local_port: dict[str, int] = {}
|
||
self._worker_proxy_port: dict[str, int] = {}
|
||
self._proxy_worker: dict[int, str] = {}
|
||
self._next_proxy_port = proxy_base_port
|
||
self._stream_queues: dict[int, asyncio.Queue[tuple[asyncio.StreamReader, asyncio.StreamWriter]]] = {}
|
||
self._stream_id = 0
|
||
self._lock = asyncio.Lock()
|
||
self._control_server: Optional[asyncio.Server] = None
|
||
self._stream_server: Optional[asyncio.Server] = None
|
||
self._proxy_servers: dict[int, asyncio.Server] = {}
|
||
|
||
async def start(self) -> None:
|
||
self._control_server = await asyncio.start_server(
|
||
self._handle_control_connection, self.host, self.control_port
|
||
)
|
||
self._stream_server = await asyncio.start_server(
|
||
self._handle_stream_connection, self.host, self.stream_port
|
||
)
|
||
logger.info(
|
||
"隧道服务端: 控制端口 %s:%s, 流端口 %s:%s",
|
||
self.host, self.control_port, self.host, self.stream_port,
|
||
)
|
||
|
||
async def stop(self) -> None:
|
||
for s in self._proxy_servers.values():
|
||
s.close()
|
||
await s.wait_closed()
|
||
self._proxy_servers.clear()
|
||
if self._stream_server:
|
||
self._stream_server.close()
|
||
await self._stream_server.wait_closed()
|
||
if self._control_server:
|
||
self._control_server.close()
|
||
await self._control_server.wait_closed()
|
||
for w in self._worker_control.values():
|
||
try:
|
||
w.close()
|
||
await w.drain()
|
||
except Exception:
|
||
pass
|
||
self._worker_control.clear()
|
||
|
||
def _alloc_proxy_port(self, worker_id: str, local_port: int) -> int:
|
||
port = self._next_proxy_port
|
||
self._next_proxy_port += 1
|
||
self._worker_proxy_port[worker_id] = port
|
||
self._worker_local_port[worker_id] = local_port
|
||
self._proxy_worker[port] = worker_id
|
||
return port
|
||
|
||
async def _handle_control_connection(
|
||
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||
) -> None:
|
||
peer = writer.get_extra_info("peername", ("?", "?"))[0]
|
||
worker_id: Optional[str] = None
|
||
try:
|
||
line = await asyncio.wait_for(reader.readline(), timeout=30)
|
||
if not line:
|
||
return
|
||
msg = decode_control(line)
|
||
if msg.get("type") != "register":
|
||
writer.write(encode_control({"type": "error", "detail": "首条消息必须是 register"}))
|
||
await writer.drain()
|
||
return
|
||
worker_id = msg.get("worker_id") or ""
|
||
local_port = int(msg.get("local_port", 54345))
|
||
if not worker_id:
|
||
writer.write(encode_control({"type": "error", "detail": "worker_id 不能为空"}))
|
||
await writer.drain()
|
||
return
|
||
|
||
async with self._lock:
|
||
if worker_id in self._worker_control:
|
||
try:
|
||
self._worker_control[worker_id].close()
|
||
await self._worker_control[worker_id].drain()
|
||
except Exception:
|
||
pass
|
||
proxy_port = self._alloc_proxy_port(worker_id, local_port)
|
||
self._worker_control[worker_id] = writer
|
||
|
||
writer.write(encode_control(register_ack_message(proxy_port)))
|
||
await writer.drain()
|
||
logger.info("隧道客户端注册: worker_id=%s, local_port=%s, proxy_port=%s", worker_id, local_port, proxy_port)
|
||
|
||
# 为该 worker 启动代理监听(若尚未监听)
|
||
await self._ensure_proxy_listener(proxy_port, worker_id, local_port)
|
||
|
||
# 保持连接并处理 open_stream 的响应(本实现里服务端只发 open_stream,不期待客户端回控制消息)
|
||
while True:
|
||
line = await reader.readline()
|
||
if not line:
|
||
break
|
||
# 可扩展:心跳、其他控制消息
|
||
except asyncio.TimeoutError:
|
||
logger.warning("控制连接超时: %s", peer)
|
||
except Exception as e:
|
||
logger.exception("控制连接异常: %s", e)
|
||
finally:
|
||
if worker_id:
|
||
async with self._lock:
|
||
if self._worker_control.get(worker_id) is writer:
|
||
del self._worker_control[worker_id]
|
||
if worker_id in self._worker_proxy_port:
|
||
port = self._worker_proxy_port.pop(worker_id)
|
||
self._proxy_worker.pop(port, None)
|
||
self._worker_local_port.pop(worker_id, None)
|
||
srv = self._proxy_servers.pop(port, None)
|
||
if srv:
|
||
srv.close()
|
||
await srv.wait_closed()
|
||
try:
|
||
writer.close()
|
||
await writer.wait_closed()
|
||
except Exception:
|
||
pass
|
||
|
||
async def _ensure_proxy_listener(self, proxy_port: int, worker_id: str, local_port: int) -> None:
|
||
if proxy_port in self._proxy_servers:
|
||
return
|
||
|
||
async def on_proxy_connection(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
|
||
await self._bridge_user_to_worker(proxy_port, worker_id, local_port, reader, writer)
|
||
|
||
self._proxy_servers[proxy_port] = await asyncio.start_server(
|
||
on_proxy_connection, self.host, proxy_port
|
||
)
|
||
logger.info("隧道代理监听: %s:%s -> worker %s local:%s", self.host, proxy_port, worker_id, local_port)
|
||
|
||
async def _bridge_user_to_worker(
|
||
self,
|
||
proxy_port: int,
|
||
worker_id: str,
|
||
local_port: int,
|
||
user_reader: asyncio.StreamReader,
|
||
user_writer: asyncio.StreamWriter,
|
||
) -> None:
|
||
stream_id = await self._alloc_stream_id()
|
||
control_writer = self._worker_control.get(worker_id)
|
||
if not control_writer or control_writer.is_closing():
|
||
user_writer.close()
|
||
await user_writer.wait_closed()
|
||
return
|
||
|
||
queue: asyncio.Queue[tuple[asyncio.StreamReader, asyncio.StreamWriter]] = asyncio.Queue(maxsize=1)
|
||
self._stream_queues[stream_id] = queue
|
||
|
||
try:
|
||
control_writer.write(encode_control(open_stream_message(stream_id, local_port)))
|
||
await control_writer.drain()
|
||
except Exception as e:
|
||
logger.warning("发送 open_stream 失败: %s", e)
|
||
self._stream_queues.pop(stream_id, None)
|
||
user_writer.close()
|
||
await user_writer.wait_closed()
|
||
return
|
||
|
||
try:
|
||
client_reader, client_writer = await asyncio.wait_for(queue.get(), timeout=15)
|
||
except asyncio.TimeoutError:
|
||
logger.warning("流 %s 等待客户端连接超时", stream_id)
|
||
self._stream_queues.pop(stream_id, None)
|
||
user_writer.close()
|
||
await user_writer.wait_closed()
|
||
return
|
||
self._stream_queues.pop(stream_id, None)
|
||
|
||
async def pipe(src: asyncio.StreamReader, dst: asyncio.StreamWriter) -> None:
|
||
try:
|
||
while True:
|
||
data = await src.read(65536)
|
||
if not data:
|
||
break
|
||
dst.write(data)
|
||
await dst.drain()
|
||
except (ConnectionResetError, BrokenPipeError, asyncio.CancelledError):
|
||
pass
|
||
finally:
|
||
try:
|
||
dst.close()
|
||
await dst.wait_closed()
|
||
except Exception:
|
||
pass
|
||
|
||
await asyncio.gather(
|
||
pipe(user_reader, client_writer),
|
||
pipe(client_reader, user_writer),
|
||
)
|
||
|
||
async def _handle_stream_connection(
|
||
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||
) -> None:
|
||
"""客户端连到流端口后发送 stream_id(4 字节大端),然后参与桥接。"""
|
||
try:
|
||
sid_bytes = await asyncio.wait_for(reader.readexactly(4), timeout=10)
|
||
stream_id = int.from_bytes(sid_bytes, "big")
|
||
queue = self._stream_queues.get(stream_id)
|
||
if not queue:
|
||
writer.close()
|
||
await writer.wait_closed()
|
||
return
|
||
queue.put_nowait((reader, writer))
|
||
except (asyncio.TimeoutError, asyncio.IncompleteReadError) as e:
|
||
logger.warning("流连接读 stream_id 失败: %s", e)
|
||
writer.close()
|
||
await writer.wait_closed()
|
||
|
||
async def _alloc_stream_id(self) -> int:
|
||
async with self._lock:
|
||
self._stream_id += 1
|
||
return self._stream_id
|