Files
boss_dp/tunnel/server.py
ddrwode 4a520306b1 ha'ha
2026-02-12 16:46:01 +08:00

248 lines
9.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
Tunnel server: runs on the cloud server.
- Control port: accepts tunnel client connections and registers their worker_id.
- Stream port: clients connect to this port for stream bridging.
- Each worker is assigned a proxy port; external connections to the proxy port
  are forwarded to that worker's local port.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Optional
from tunnel.protocol import (
decode_control,
encode_control,
open_stream_message,
register_ack_message,
)
# Module-level logger shared by everything in the tunnel server.
logger = logging.getLogger("tunnel.server")
class TunnelServer:
def __init__(
self,
control_port: int,
stream_port: int,
proxy_base_port: int,
host: str = "0.0.0.0",
) -> None:
self.control_port = control_port
self.stream_port = stream_port
self.proxy_base_port = proxy_base_port
self.host = host
self._worker_control: dict[str, asyncio.StreamWriter] = {}
self._worker_local_port: dict[str, int] = {}
self._worker_proxy_port: dict[str, int] = {}
self._proxy_worker: dict[int, str] = {}
self._next_proxy_port = proxy_base_port
self._stream_queues: dict[int, asyncio.Queue[tuple[asyncio.StreamReader, asyncio.StreamWriter]]] = {}
self._stream_id = 0
self._lock = asyncio.Lock()
self._control_server: Optional[asyncio.Server] = None
self._stream_server: Optional[asyncio.Server] = None
self._proxy_servers: dict[int, asyncio.Server] = {}
async def start(self) -> None:
self._control_server = await asyncio.start_server(
self._handle_control_connection, self.host, self.control_port
)
self._stream_server = await asyncio.start_server(
self._handle_stream_connection, self.host, self.stream_port
)
logger.info(
"隧道服务端: 控制端口 %s:%s, 流端口 %s:%s",
self.host, self.control_port, self.host, self.stream_port,
)
async def stop(self) -> None:
for s in self._proxy_servers.values():
s.close()
await s.wait_closed()
self._proxy_servers.clear()
if self._stream_server:
self._stream_server.close()
await self._stream_server.wait_closed()
if self._control_server:
self._control_server.close()
await self._control_server.wait_closed()
for w in self._worker_control.values():
try:
w.close()
await w.drain()
except Exception:
pass
self._worker_control.clear()
def _alloc_proxy_port(self, worker_id: str, local_port: int) -> int:
port = self._next_proxy_port
self._next_proxy_port += 1
self._worker_proxy_port[worker_id] = port
self._worker_local_port[worker_id] = local_port
self._proxy_worker[port] = worker_id
return port
async def _handle_control_connection(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
peer = writer.get_extra_info("peername", ("?", "?"))[0]
worker_id: Optional[str] = None
try:
line = await asyncio.wait_for(reader.readline(), timeout=30)
if not line:
return
msg = decode_control(line)
if msg.get("type") != "register":
writer.write(encode_control({"type": "error", "detail": "首条消息必须是 register"}))
await writer.drain()
return
worker_id = msg.get("worker_id") or ""
local_port = int(msg.get("local_port", 54345))
if not worker_id:
writer.write(encode_control({"type": "error", "detail": "worker_id 不能为空"}))
await writer.drain()
return
async with self._lock:
if worker_id in self._worker_control:
try:
self._worker_control[worker_id].close()
await self._worker_control[worker_id].drain()
except Exception:
pass
proxy_port = self._alloc_proxy_port(worker_id, local_port)
self._worker_control[worker_id] = writer
writer.write(encode_control(register_ack_message(proxy_port)))
await writer.drain()
logger.info("隧道客户端注册: worker_id=%s, local_port=%s, proxy_port=%s", worker_id, local_port, proxy_port)
# 为该 worker 启动代理监听(若尚未监听)
await self._ensure_proxy_listener(proxy_port, worker_id, local_port)
# 保持连接并处理 open_stream 的响应(本实现里服务端只发 open_stream不期待客户端回控制消息
while True:
line = await reader.readline()
if not line:
break
# 可扩展:心跳、其他控制消息
except asyncio.TimeoutError:
logger.warning("控制连接超时: %s", peer)
except Exception as e:
logger.exception("控制连接异常: %s", e)
finally:
if worker_id:
async with self._lock:
if self._worker_control.get(worker_id) is writer:
del self._worker_control[worker_id]
if worker_id in self._worker_proxy_port:
port = self._worker_proxy_port.pop(worker_id)
self._proxy_worker.pop(port, None)
self._worker_local_port.pop(worker_id, None)
srv = self._proxy_servers.pop(port, None)
if srv:
srv.close()
await srv.wait_closed()
try:
writer.close()
await writer.wait_closed()
except Exception:
pass
async def _ensure_proxy_listener(self, proxy_port: int, worker_id: str, local_port: int) -> None:
if proxy_port in self._proxy_servers:
return
async def on_proxy_connection(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
await self._bridge_user_to_worker(proxy_port, worker_id, local_port, reader, writer)
self._proxy_servers[proxy_port] = await asyncio.start_server(
on_proxy_connection, self.host, proxy_port
)
logger.info("隧道代理监听: %s:%s -> worker %s local:%s", self.host, proxy_port, worker_id, local_port)
async def _bridge_user_to_worker(
self,
proxy_port: int,
worker_id: str,
local_port: int,
user_reader: asyncio.StreamReader,
user_writer: asyncio.StreamWriter,
) -> None:
stream_id = await self._alloc_stream_id()
control_writer = self._worker_control.get(worker_id)
if not control_writer or control_writer.is_closing():
user_writer.close()
await user_writer.wait_closed()
return
queue: asyncio.Queue[tuple[asyncio.StreamReader, asyncio.StreamWriter]] = asyncio.Queue(maxsize=1)
self._stream_queues[stream_id] = queue
try:
control_writer.write(encode_control(open_stream_message(stream_id, local_port)))
await control_writer.drain()
except Exception as e:
logger.warning("发送 open_stream 失败: %s", e)
self._stream_queues.pop(stream_id, None)
user_writer.close()
await user_writer.wait_closed()
return
try:
client_reader, client_writer = await asyncio.wait_for(queue.get(), timeout=15)
except asyncio.TimeoutError:
logger.warning("%s 等待客户端连接超时", stream_id)
self._stream_queues.pop(stream_id, None)
user_writer.close()
await user_writer.wait_closed()
return
self._stream_queues.pop(stream_id, None)
async def pipe(src: asyncio.StreamReader, dst: asyncio.StreamWriter) -> None:
try:
while True:
data = await src.read(65536)
if not data:
break
dst.write(data)
await dst.drain()
except (ConnectionResetError, BrokenPipeError, asyncio.CancelledError):
pass
finally:
try:
dst.close()
await dst.wait_closed()
except Exception:
pass
await asyncio.gather(
pipe(user_reader, client_writer),
pipe(client_reader, user_writer),
)
async def _handle_stream_connection(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
"""客户端连到流端口后发送 stream_id4 字节大端),然后参与桥接。"""
try:
sid_bytes = await asyncio.wait_for(reader.readexactly(4), timeout=10)
stream_id = int.from_bytes(sid_bytes, "big")
queue = self._stream_queues.get(stream_id)
if not queue:
writer.close()
await writer.wait_closed()
return
queue.put_nowait((reader, writer))
except (asyncio.TimeoutError, asyncio.IncompleteReadError) as e:
logger.warning("流连接读 stream_id 失败: %s", e)
writer.close()
await writer.wait_closed()
async def _alloc_stream_id(self) -> int:
async with self._lock:
self._stream_id += 1
return self._stream_id