diff --git a/README.md b/README.md index 14028e1..23f8ace 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,28 @@ Worker 启动后会自动: --- +## 内网穿透(隧道) + +项目内置 **独立 Python 隧道**,与 Worker 集成。线下电脑无需单独安装 frp 等工具,启动 Worker 时会同时建立隧道,将本机端口(如比特浏览器 API 54345)暴露到云服务器,便于从公网访问该机器上的服务。 + +- **云服务器**:随 `python -m server.main` 自动启动隧道服务端(控制端口 8001、流端口 8003、代理端口从 8010 起按 Worker 分配)。 +- **线下电脑**:Worker 启动时默认开启隧道客户端,连接云服务器并注册 `worker_id`,云上会为该 Worker 分配一个代理端口(如 8010、8011…)。访问 `云服务器IP:8010` 即等价于访问该 Worker 本机的 `127.0.0.1:54345`。 + +**环境变量(可选)** + +| 位置 | 变量 | 默认值 | 说明 | +|------|------|--------|------| +| 服务器 | `TUNNEL_CONTROL_PORT` | 8001 | 隧道控制端口 | +| 服务器 | `TUNNEL_STREAM_PORT` | 8003 | 隧道流端口 | +| 服务器 | `TUNNEL_PROXY_BASE_PORT` | 8010 | 代理端口起始 | +| Worker | `TUNNEL_ENABLED` | 1 | 是否启用隧道 | +| Worker | `TUNNEL_SERVER` | (从 SERVER_WS_URL 解析) | 隧道服务端地址 | +| Worker | `TUNNEL_LOCAL_PORT` | 54345 | 暴露的本地端口 | + +**禁用隧道**:`python -m worker.main --no-tunnel ...` + +--- + ## API 接口 ### 查看在线 Worker @@ -155,6 +177,11 @@ boss_dp/ │ └── boss_recruit.py # BOSS 直聘招聘任务 ├── common/ │ └── protocol.py # 共享消息协议 +├── tunnel/ # 内网穿透(独立隧道,与 Worker 集成) +│ ├── __init__.py +│ ├── protocol.py # 隧道控制协议 +│ ├── server.py # 隧道服务端(云上) +│ └── client.py # 隧道客户端(线下,随 Worker 启动) ├── requirements.txt └── README.md ``` diff --git a/server/config.py b/server/config.py index e88a440..696105f 100644 --- a/server/config.py +++ b/server/config.py @@ -16,3 +16,8 @@ HEARTBEAT_TIMEOUT: int = 90 # 超时未收到心跳视为离 # ─── 安全(可选) ─── API_TOKEN: str = os.getenv("API_TOKEN", "") # 非空时校验 Header: Authorization: Bearer + +# ─── 隧道(内网穿透) ─── +TUNNEL_CONTROL_PORT: int = int(os.getenv("TUNNEL_CONTROL_PORT", "8001")) # 隧道客户端连接 +TUNNEL_STREAM_PORT: int = int(os.getenv("TUNNEL_STREAM_PORT", "8003")) # 流连接(客户端连此端口桥接) +TUNNEL_PROXY_BASE_PORT: int = int(os.getenv("TUNNEL_PROXY_BASE_PORT", "8010")) # 代理端口起始(8010=worker1, 8011=worker2...) diff --git a/server/main.py b/server/main.py index d5be494..3e8e640 100644 --- a/server/main.py +++ b/server/main.py @@ -19,6 +19,7 @@ from server.api.workers import router as workers_router from server.api.tasks import router as tasks_router from server.core.worker_manager import worker_manager from server.core.task_dispatcher import task_dispatcher +from tunnel.server import TunnelServer logging.basicConfig( level=logging.INFO, @@ -128,11 +129,33 @@ async def ws_endpoint(ws: WebSocket): # ────────────────────────── 生命周期 ────────────────────────── +_tunnel_server: TunnelServer | None = None + + @app.on_event("startup") async def startup(): - # 启动心跳巡检后台任务 + global _tunnel_server asyncio.create_task(worker_manager.check_heartbeats_loop()) - logger.info("服务器启动: http://%s:%s", config.HOST, config.PORT) + _tunnel_server = TunnelServer( + control_port=config.TUNNEL_CONTROL_PORT, + stream_port=config.TUNNEL_STREAM_PORT, + proxy_base_port=config.TUNNEL_PROXY_BASE_PORT, + host=config.HOST, + ) + await _tunnel_server.start() + logger.info( + "服务器启动: http://%s:%s | 隧道: 控制 %s, 流 %s, 代理起始 %s", + config.HOST, config.PORT, + config.TUNNEL_CONTROL_PORT, config.TUNNEL_STREAM_PORT, config.TUNNEL_PROXY_BASE_PORT, + ) + + +@app.on_event("shutdown") +async def shutdown(): + global _tunnel_server + if _tunnel_server: + await _tunnel_server.stop() + _tunnel_server = None # ────────────────────────── 入口 ────────────────────────── diff --git a/tunnel/__init__.py b/tunnel/__init__.py new file mode 100644 index 0000000..404cc2d --- /dev/null +++ b/tunnel/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +""" +独立 Python 隧道模块:内网穿透,与 Worker 集成。 +- 隧道服务端运行在云服务器,接受隧道客户端连接,并为每个 Worker 分配代理端口。 +- 隧道客户端运行在每台线下电脑(与 Worker 同进程),将本地端口暴露到云服务器。 +""" diff --git a/tunnel/client.py b/tunnel/client.py new file mode 100644 index 0000000..e430820 --- /dev/null +++ b/tunnel/client.py @@ -0,0 +1,147 @@ +# -*- coding: utf-8 -*- +""" +隧道客户端:运行在每台线下电脑,与 Worker 同进程集成。 +连接隧道服务端,注册 worker_id,收到 open_stream 时连接本地端口并桥接到服务端流端口。 +""" +from __future__ import annotations + +import asyncio +import logging +from typing import Optional + +from tunnel.protocol import decode_control, encode_control, register_message + +logger = logging.getLogger("tunnel.client") + + +class TunnelClient: + def __init__( + self, + server_host: str, + control_port: int, + stream_port: int, + worker_id: str, + local_port: int, + ) -> None: + self.server_host = server_host + self.control_port = control_port + self.stream_port = stream_port + self.worker_id = worker_id + self.local_port = local_port + self._running = False + self._reconnect_delay = 5 + self._reconnect_max_delay = 60 + + async def run(self) -> None: + self._running = True + while self._running: + try: + await self._connect_and_loop() + except Exception as e: + logger.error("隧道连接异常: %s", e) + if self._running: + logger.info("隧道将在 %d 秒后重连...", self._reconnect_delay) + await asyncio.sleep(self._reconnect_delay) + self._reconnect_delay = min( + self._reconnect_delay * 2, + self._reconnect_max_delay, + ) + + async def stop(self) -> None: + self._running = False + + async def _connect_and_loop(self) -> None: + reader, writer = await asyncio.open_connection( + self.server_host, self.control_port + ) + self._reconnect_delay = 5 + logger.info("隧道控制连接已建立: %s:%s", self.server_host, self.control_port) + + msg = register_message(self.worker_id, self.local_port) + writer.write(encode_control(msg)) + await writer.drain() + + line = await asyncio.wait_for(reader.readline(), timeout=15) + if not line: + raise RuntimeError("隧道服务端未响应") + ack = decode_control(line) + if ack.get("type") == "error": + raise RuntimeError("隧道注册失败: %s" % ack.get("detail", ack)) + if ack.get("type") != "register_ack": + raise RuntimeError("隧道注册异常: %s" % ack) + proxy_port = ack.get("proxy_port") + logger.info("隧道已注册: worker_id=%s, 代理端口 %s:%s", self.worker_id, self.server_host, proxy_port) + + try: + while self._running: + line = await asyncio.wait_for(reader.readline(), timeout=60) + if not line: + break + data = decode_control(line) + if data.get("type") == "open_stream": + stream_id = int(data.get("stream_id", 0)) + local_port = int(data.get("local_port", self.local_port)) + asyncio.create_task( + self._bridge_stream(stream_id, local_port) + ) + except asyncio.TimeoutError: + pass + finally: + writer.close() + await writer.wait_closed() + + async def _bridge_stream(self, stream_id: int, local_port: int) -> None: + """连接本地 local_port,再连服务端 stream 端口,发送 stream_id 后双向桥接。""" + try: + local_reader, local_writer = await asyncio.open_connection( + "127.0.0.1", local_port + ) + except Exception as e: + logger.warning("隧道打开本地 %s 失败: %s", local_port, e) + return + + try: + stream_reader, stream_writer = await asyncio.open_connection( + self.server_host, self.stream_port + ) + except Exception as e: + logger.warning("隧道连流端口失败: %s", e) + local_writer.close() + await local_writer.wait_closed() + return + + try: + stream_writer.write(stream_id.to_bytes(4, "big")) + await stream_writer.drain() + except Exception as e: + logger.warning("隧道发送 stream_id 失败: %s", e) + stream_writer.close() + await stream_writer.wait_closed() + local_writer.close() + await local_writer.wait_closed() + return + + async def pipe( + src: asyncio.StreamReader, dst: asyncio.StreamWriter + ) -> None: + try: + while True: + data = await src.read(65536) + if not data: + break + dst.write(data) + await dst.drain() + except (ConnectionResetError, BrokenPipeError, asyncio.CancelledError): + pass + finally: + try: + dst.close() + await dst.wait_closed() + except Exception: + pass + + await asyncio.gather( + pipe(local_reader, stream_writer), + pipe(stream_reader, local_writer), + ) + logger.debug("隧道流 %s 已关闭", stream_id) diff --git a/tunnel/protocol.py b/tunnel/protocol.py new file mode 100644 index 0000000..0de0935 --- /dev/null +++ b/tunnel/protocol.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +""" +隧道控制面协议:服务端与客户端之间 JSON 文本行。 +""" +from __future__ import annotations + +import json + + +def register_message(worker_id: str, local_port: int) -> dict: + """客户端注册:暴露的 worker_id 与本地端口。""" + return {"type": "register", "worker_id": worker_id, "local_port": local_port} + + +def register_ack_message(proxy_port: int) -> dict: + """服务端确认:分配给的代理端口。""" + return {"type": "register_ack", "proxy_port": proxy_port} + + +def open_stream_message(stream_id: int, local_port: int) -> dict: + """服务端通知客户端:请为 stream_id 打开本地 local_port 并连到 stream 端口。""" + return {"type": "open_stream", "stream_id": stream_id, "local_port": local_port} + + +def encode_control(msg: dict) -> bytes: + """编码为一行 JSON + 换行。""" + return (json.dumps(msg, ensure_ascii=False) + "\n").encode("utf-8") + + +def decode_control(line: bytes) -> dict: + """解码一行 JSON。""" + return json.loads(line.decode("utf-8").strip()) diff --git a/tunnel/server.py b/tunnel/server.py new file mode 100644 index 0000000..0ba3ae6 --- /dev/null +++ b/tunnel/server.py @@ -0,0 +1,247 @@ +# -*- coding: utf-8 -*- +""" +隧道服务端:运行在云服务器。 +- 控制端口:接受隧道客户端连接,注册 worker_id。 +- 流端口:客户端连此端口做流桥接。 +- 为每个 worker 分配一个代理端口,外部连代理端口时转发到该 worker 的本地端口。 +""" +from __future__ import annotations + +import asyncio +import logging +from typing import Optional + +from tunnel.protocol import ( + decode_control, + encode_control, + open_stream_message, + register_ack_message, +) + +logger = logging.getLogger("tunnel.server") + + +class TunnelServer: + def __init__( + self, + control_port: int, + stream_port: int, + proxy_base_port: int, + host: str = "0.0.0.0", + ) -> None: + self.control_port = control_port + self.stream_port = stream_port + self.proxy_base_port = proxy_base_port + self.host = host + self._worker_control: dict[str, asyncio.StreamWriter] = {} + self._worker_local_port: dict[str, int] = {} + self._worker_proxy_port: dict[str, int] = {} + self._proxy_worker: dict[int, str] = {} + self._next_proxy_port = proxy_base_port + self._stream_queues: dict[int, asyncio.Queue[tuple[asyncio.StreamReader, asyncio.StreamWriter]]] = {} + self._stream_id = 0 + self._lock = asyncio.Lock() + self._control_server: Optional[asyncio.Server] = None + self._stream_server: Optional[asyncio.Server] = None + self._proxy_servers: dict[int, asyncio.Server] = {} + + async def start(self) -> None: + self._control_server = await asyncio.start_server( + self._handle_control_connection, self.host, self.control_port + ) + self._stream_server = await asyncio.start_server( + self._handle_stream_connection, self.host, self.stream_port + ) + logger.info( + "隧道服务端: 控制端口 %s:%s, 流端口 %s:%s", + self.host, self.control_port, self.host, self.stream_port, + ) + + async def stop(self) -> None: + for s in self._proxy_servers.values(): + s.close() + await s.wait_closed() + self._proxy_servers.clear() + if self._stream_server: + self._stream_server.close() + await self._stream_server.wait_closed() + if self._control_server: + self._control_server.close() + await self._control_server.wait_closed() + for w in self._worker_control.values(): + try: + w.close() + await w.drain() + except Exception: + pass + self._worker_control.clear() + + def _alloc_proxy_port(self, worker_id: str, local_port: int) -> int: + port = self._next_proxy_port + self._next_proxy_port += 1 + self._worker_proxy_port[worker_id] = port + self._worker_local_port[worker_id] = local_port + self._proxy_worker[port] = worker_id + return port + + async def _handle_control_connection( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ) -> None: + peer = writer.get_extra_info("peername", ("?", "?"))[0] + worker_id: Optional[str] = None + try: + line = await asyncio.wait_for(reader.readline(), timeout=30) + if not line: + return + msg = decode_control(line) + if msg.get("type") != "register": + writer.write(encode_control({"type": "error", "detail": "首条消息必须是 register"})) + await writer.drain() + return + worker_id = msg.get("worker_id") or "" + local_port = int(msg.get("local_port", 54345)) + if not worker_id: + writer.write(encode_control({"type": "error", "detail": "worker_id 不能为空"})) + await writer.drain() + return + + async with self._lock: + if worker_id in self._worker_control: + try: + self._worker_control[worker_id].close() + await self._worker_control[worker_id].drain() + except Exception: + pass + proxy_port = self._alloc_proxy_port(worker_id, local_port) + self._worker_control[worker_id] = writer + + writer.write(encode_control(register_ack_message(proxy_port))) + await writer.drain() + logger.info("隧道客户端注册: worker_id=%s, local_port=%s, proxy_port=%s", worker_id, local_port, proxy_port) + + # 为该 worker 启动代理监听(若尚未监听) + await self._ensure_proxy_listener(proxy_port, worker_id, local_port) + + # 保持连接并处理 open_stream 的响应(本实现里服务端只发 open_stream,不期待客户端回控制消息) + while True: + line = await reader.readline() + if not line: + break + # 可扩展:心跳、其他控制消息 + except asyncio.TimeoutError: + logger.warning("控制连接超时: %s", peer) + except Exception as e: + logger.exception("控制连接异常: %s", e) + finally: + if worker_id: + async with self._lock: + if self._worker_control.get(worker_id) is writer: + del self._worker_control[worker_id] + if worker_id in self._worker_proxy_port: + port = self._worker_proxy_port.pop(worker_id) + self._proxy_worker.pop(port, None) + self._worker_local_port.pop(worker_id, None) + srv = self._proxy_servers.pop(port, None) + if srv: + srv.close() + await srv.wait_closed() + try: + writer.close() + await writer.wait_closed() + except Exception: + pass + + async def _ensure_proxy_listener(self, proxy_port: int, worker_id: str, local_port: int) -> None: + if proxy_port in self._proxy_servers: + return + + async def on_proxy_connection(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: + await self._bridge_user_to_worker(proxy_port, worker_id, local_port, reader, writer) + + self._proxy_servers[proxy_port] = await asyncio.start_server( + on_proxy_connection, self.host, proxy_port + ) + logger.info("隧道代理监听: %s:%s -> worker %s local:%s", self.host, proxy_port, worker_id, local_port) + + async def _bridge_user_to_worker( + self, + proxy_port: int, + worker_id: str, + local_port: int, + user_reader: asyncio.StreamReader, + user_writer: asyncio.StreamWriter, + ) -> None: + stream_id = await self._alloc_stream_id() + control_writer = self._worker_control.get(worker_id) + if not control_writer or control_writer.is_closing(): + user_writer.close() + await user_writer.wait_closed() + return + + queue: asyncio.Queue[tuple[asyncio.StreamReader, asyncio.StreamWriter]] = asyncio.Queue(maxsize=1) + self._stream_queues[stream_id] = queue + + try: + control_writer.write(encode_control(open_stream_message(stream_id, local_port))) + await control_writer.drain() + except Exception as e: + logger.warning("发送 open_stream 失败: %s", e) + self._stream_queues.pop(stream_id, None) + user_writer.close() + await user_writer.wait_closed() + return + + try: + client_reader, client_writer = await asyncio.wait_for(queue.get(), timeout=15) + except asyncio.TimeoutError: + logger.warning("流 %s 等待客户端连接超时", stream_id) + self._stream_queues.pop(stream_id, None) + user_writer.close() + await user_writer.wait_closed() + return + self._stream_queues.pop(stream_id, None) + + async def pipe(src: asyncio.StreamReader, dst: asyncio.StreamWriter) -> None: + try: + while True: + data = await src.read(65536) + if not data: + break + dst.write(data) + await dst.drain() + except (ConnectionResetError, BrokenPipeError, asyncio.CancelledError): + pass + finally: + try: + dst.close() + await dst.wait_closed() + except Exception: + pass + + await asyncio.gather( + pipe(user_reader, client_writer), + pipe(client_reader, user_writer), + ) + + async def _handle_stream_connection( + self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter + ) -> None: + """客户端连到流端口后发送 stream_id(4 字节大端),然后参与桥接。""" + try: + sid_bytes = await asyncio.wait_for(reader.readexactly(4), timeout=10) + stream_id = int.from_bytes(sid_bytes, "big") + queue = self._stream_queues.get(stream_id) + if not queue: + writer.close() + await writer.wait_closed() + return + queue.put_nowait((reader, writer)) + except (asyncio.TimeoutError, asyncio.IncompleteReadError) as e: + logger.warning("流连接读 stream_id 失败: %s", e) + writer.close() + await writer.wait_closed() + + async def _alloc_stream_id(self) -> int: + async with self._lock: + self._stream_id += 1 + return self._stream_id diff --git a/worker/config.py b/worker/config.py index 901376c..ef9ade4 100644 --- a/worker/config.py +++ b/worker/config.py @@ -18,3 +18,26 @@ BIT_API_BASE: str = os.getenv("BIT_API_BASE", "http://127.0.0.1:54345") HEARTBEAT_INTERVAL: int = 25 # 心跳发送间隔(秒) RECONNECT_DELAY: int = 5 # 断线重连等待(秒) RECONNECT_MAX_DELAY: int = 60 # 重连最大等待(秒,指数退避上限) + +# ─── 隧道(内网穿透,与 Worker 集成) ─── +TUNNEL_ENABLED: bool = os.getenv("TUNNEL_ENABLED", "1").strip().lower() in ("1", "true", "yes") +TUNNEL_SERVER: str = os.getenv("TUNNEL_SERVER", "") # 留空则从 SERVER_WS_URL 解析 host +TUNNEL_CONTROL_PORT: int = int(os.getenv("TUNNEL_CONTROL_PORT", "8001")) +TUNNEL_STREAM_PORT: int = int(os.getenv("TUNNEL_STREAM_PORT", "8003")) +TUNNEL_LOCAL_PORT: int = int(os.getenv("TUNNEL_LOCAL_PORT", "54345")) # 暴露的本地端口(如比特浏览器 API) + + +def get_tunnel_server_host() -> str: + """隧道服务端地址。未设置 TUNNEL_SERVER 时从 SERVER_WS_URL 解析。""" + if TUNNEL_SERVER: + return TUNNEL_SERVER.split(":")[0] if ":" in TUNNEL_SERVER else TUNNEL_SERVER + url = SERVER_WS_URL.strip() + for prefix in ("wss://", "ws://"): + if url.startswith(prefix): + rest = url[len(prefix):] + if "/" in rest: + rest = rest.split("/")[0] + if ":" in rest: + return rest.split(":")[0] + return rest + return "127.0.0.1" diff --git a/worker/main.py b/worker/main.py index fcc5cd3..ee91072 100644 --- a/worker/main.py +++ b/worker/main.py @@ -13,6 +13,8 @@ import sys from worker import config from worker.tasks.registry import register_all_handlers from worker.ws_client import WorkerWSClient +from tunnel.client import TunnelClient +from worker.config import get_tunnel_server_host logging.basicConfig( level=logging.INFO, @@ -44,15 +46,29 @@ def parse_args(): default=config.BIT_API_BASE, help=f"比特浏览器本地 API 地址 (默认: {config.BIT_API_BASE})", ) + parser.add_argument( + "--no-tunnel", + action="store_true", + help="禁用内网穿透隧道(不与隧道服务端连接)", + ) return parser.parse_args() +def _local_port_from_bit_api(bit_api_base: str) -> int: + """从 BIT_API_BASE (e.g. http://127.0.0.1:54345) 解析端口。""" + try: + from urllib.parse import urlparse + p = urlparse(bit_api_base if "://" in bit_api_base else "http://" + bit_api_base) + return p.port or 54345 + except Exception: + return config.TUNNEL_LOCAL_PORT + + async def run(args): # 注册所有任务处理器 register_all_handlers() logger.info("已注册任务处理器") - # 创建 WebSocket 客户端并运行 client = WorkerWSClient( server_url=args.server, worker_id=args.worker_id, @@ -60,15 +76,56 @@ async def run(args): bit_api_base=args.bit_api, ) + tunnel_enabled = config.TUNNEL_ENABLED and not args.no_tunnel + tunnel_client = None + if tunnel_enabled: + tunnel_client = TunnelClient( + server_host=get_tunnel_server_host(), + control_port=config.TUNNEL_CONTROL_PORT, + stream_port=config.TUNNEL_STREAM_PORT, + worker_id=args.worker_id, + local_port=_local_port_from_bit_api(args.bit_api), + ) + logger.info( + "隧道已启用: 暴露本地 %s -> 云服务器 (worker_id=%s)", + _local_port_from_bit_api(args.bit_api), args.worker_id, + ) + logger.info( "Worker 启动: id=%s, name=%s, server=%s", args.worker_id, args.worker_name, args.server, ) - try: + async def run_worker(): await client.run() + + async def run_tunnel(): + if tunnel_client: + await tunnel_client.run() + + worker_task = asyncio.create_task(run_worker()) + tunnel_task = asyncio.create_task(run_tunnel()) if tunnel_client else None + + try: + if tunnel_task is not None: + await asyncio.gather(worker_task, tunnel_task) + else: + await worker_task except KeyboardInterrupt: logger.info("收到中断信号,正在退出...") + worker_task.cancel() + if tunnel_task is not None: + tunnel_task.cancel() + try: + await worker_task + except asyncio.CancelledError: + pass + if tunnel_task is not None: + try: + await tunnel_task + except asyncio.CancelledError: + pass + await tunnel_client.stop() await client.stop()