ha'ha
This commit is contained in:
27
README.md
27
README.md
@@ -65,6 +65,28 @@ Worker 启动后会自动:
|
||||
|
||||
---
|
||||
|
||||
## 内网穿透(隧道)
|
||||
|
||||
项目内置 **独立 Python 隧道**,与 Worker 集成。线下电脑无需单独安装 frp 等工具,启动 Worker 时会同时建立隧道,将本机端口(如比特浏览器 API 54345)暴露到云服务器,便于从公网访问该机器上的服务。
|
||||
|
||||
- **云服务器**:随 `python -m server.main` 自动启动隧道服务端(控制端口 8001、流端口 8003、代理端口从 8010 起按 Worker 分配)。
|
||||
- **线下电脑**:Worker 启动时默认开启隧道客户端,连接云服务器并注册 `worker_id`,云上会为该 Worker 分配一个代理端口(如 8010、8011…)。访问 `云服务器IP:8010` 即等价于访问该 Worker 本机的 `127.0.0.1:54345`。
|
||||
|
||||
**环境变量(可选)**
|
||||
|
||||
| 位置 | 变量 | 默认值 | 说明 |
|
||||
|------|------|--------|------|
|
||||
| 服务器 | `TUNNEL_CONTROL_PORT` | 8001 | 隧道控制端口 |
|
||||
| 服务器 | `TUNNEL_STREAM_PORT` | 8003 | 隧道流端口 |
|
||||
| 服务器 | `TUNNEL_PROXY_BASE_PORT` | 8010 | 代理端口起始 |
|
||||
| Worker | `TUNNEL_ENABLED` | 1 | 是否启用隧道 |
|
||||
| Worker | `TUNNEL_SERVER` | (从 SERVER_WS_URL 解析) | 隧道服务端地址 |
|
||||
| Worker | `TUNNEL_LOCAL_PORT` | 54345 | 暴露的本地端口 |
|
||||
|
||||
**禁用隧道**:`python -m worker.main --no-tunnel ...`
|
||||
|
||||
---
|
||||
|
||||
## API 接口
|
||||
|
||||
### 查看在线 Worker
|
||||
@@ -155,6 +177,11 @@ boss_dp/
|
||||
│ └── boss_recruit.py # BOSS 直聘招聘任务
|
||||
├── common/
|
||||
│ └── protocol.py # 共享消息协议
|
||||
├── tunnel/ # 内网穿透(独立隧道,与 Worker 集成)
|
||||
│ ├── __init__.py
|
||||
│ ├── protocol.py # 隧道控制协议
|
||||
│ ├── server.py # 隧道服务端(云上)
|
||||
│ └── client.py # 隧道客户端(线下,随 Worker 启动)
|
||||
├── requirements.txt
|
||||
└── README.md
|
||||
```
|
||||
|
||||
@@ -16,3 +16,8 @@ HEARTBEAT_TIMEOUT: int = 90 # 超时未收到心跳视为离
|
||||
|
||||
# ─── 安全(可选) ───
|
||||
API_TOKEN: str = os.getenv("API_TOKEN", "") # 非空时校验 Header: Authorization: Bearer <token>
|
||||
|
||||
# ─── 隧道(内网穿透) ───
|
||||
TUNNEL_CONTROL_PORT: int = int(os.getenv("TUNNEL_CONTROL_PORT", "8001")) # 隧道客户端连接
|
||||
TUNNEL_STREAM_PORT: int = int(os.getenv("TUNNEL_STREAM_PORT", "8003")) # 流连接(客户端连此端口桥接)
|
||||
TUNNEL_PROXY_BASE_PORT: int = int(os.getenv("TUNNEL_PROXY_BASE_PORT", "8010")) # 代理端口起始(8010=worker1, 8011=worker2...)
|
||||
|
||||
@@ -19,6 +19,7 @@ from server.api.workers import router as workers_router
|
||||
from server.api.tasks import router as tasks_router
|
||||
from server.core.worker_manager import worker_manager
|
||||
from server.core.task_dispatcher import task_dispatcher
|
||||
from tunnel.server import TunnelServer
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
@@ -128,11 +129,33 @@ async def ws_endpoint(ws: WebSocket):
|
||||
|
||||
# ────────────────────────── 生命周期 ──────────────────────────
|
||||
|
||||
_tunnel_server: TunnelServer | None = None
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
# 启动心跳巡检后台任务
|
||||
global _tunnel_server
|
||||
asyncio.create_task(worker_manager.check_heartbeats_loop())
|
||||
logger.info("服务器启动: http://%s:%s", config.HOST, config.PORT)
|
||||
_tunnel_server = TunnelServer(
|
||||
control_port=config.TUNNEL_CONTROL_PORT,
|
||||
stream_port=config.TUNNEL_STREAM_PORT,
|
||||
proxy_base_port=config.TUNNEL_PROXY_BASE_PORT,
|
||||
host=config.HOST,
|
||||
)
|
||||
await _tunnel_server.start()
|
||||
logger.info(
|
||||
"服务器启动: http://%s:%s | 隧道: 控制 %s, 流 %s, 代理起始 %s",
|
||||
config.HOST, config.PORT,
|
||||
config.TUNNEL_CONTROL_PORT, config.TUNNEL_STREAM_PORT, config.TUNNEL_PROXY_BASE_PORT,
|
||||
)
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown():
|
||||
global _tunnel_server
|
||||
if _tunnel_server:
|
||||
await _tunnel_server.stop()
|
||||
_tunnel_server = None
|
||||
|
||||
|
||||
# ────────────────────────── 入口 ──────────────────────────
|
||||
|
||||
6
tunnel/__init__.py
Normal file
6
tunnel/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
独立 Python 隧道模块:内网穿透,与 Worker 集成。
|
||||
- 隧道服务端运行在云服务器,接受隧道客户端连接,并为每个 Worker 分配代理端口。
|
||||
- 隧道客户端运行在每台线下电脑(与 Worker 同进程),将本地端口暴露到云服务器。
|
||||
"""
|
||||
147
tunnel/client.py
Normal file
147
tunnel/client.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
隧道客户端:运行在每台线下电脑,与 Worker 同进程集成。
|
||||
连接隧道服务端,注册 worker_id,收到 open_stream 时连接本地端口并桥接到服务端流端口。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from tunnel.protocol import decode_control, encode_control, register_message
|
||||
|
||||
logger = logging.getLogger("tunnel.client")
|
||||
|
||||
|
||||
class TunnelClient:
|
||||
def __init__(
|
||||
self,
|
||||
server_host: str,
|
||||
control_port: int,
|
||||
stream_port: int,
|
||||
worker_id: str,
|
||||
local_port: int,
|
||||
) -> None:
|
||||
self.server_host = server_host
|
||||
self.control_port = control_port
|
||||
self.stream_port = stream_port
|
||||
self.worker_id = worker_id
|
||||
self.local_port = local_port
|
||||
self._running = False
|
||||
self._reconnect_delay = 5
|
||||
self._reconnect_max_delay = 60
|
||||
|
||||
async def run(self) -> None:
|
||||
self._running = True
|
||||
while self._running:
|
||||
try:
|
||||
await self._connect_and_loop()
|
||||
except Exception as e:
|
||||
logger.error("隧道连接异常: %s", e)
|
||||
if self._running:
|
||||
logger.info("隧道将在 %d 秒后重连...", self._reconnect_delay)
|
||||
await asyncio.sleep(self._reconnect_delay)
|
||||
self._reconnect_delay = min(
|
||||
self._reconnect_delay * 2,
|
||||
self._reconnect_max_delay,
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._running = False
|
||||
|
||||
async def _connect_and_loop(self) -> None:
|
||||
reader, writer = await asyncio.open_connection(
|
||||
self.server_host, self.control_port
|
||||
)
|
||||
self._reconnect_delay = 5
|
||||
logger.info("隧道控制连接已建立: %s:%s", self.server_host, self.control_port)
|
||||
|
||||
msg = register_message(self.worker_id, self.local_port)
|
||||
writer.write(encode_control(msg))
|
||||
await writer.drain()
|
||||
|
||||
line = await asyncio.wait_for(reader.readline(), timeout=15)
|
||||
if not line:
|
||||
raise RuntimeError("隧道服务端未响应")
|
||||
ack = decode_control(line)
|
||||
if ack.get("type") == "error":
|
||||
raise RuntimeError("隧道注册失败: %s" % ack.get("detail", ack))
|
||||
if ack.get("type") != "register_ack":
|
||||
raise RuntimeError("隧道注册异常: %s" % ack)
|
||||
proxy_port = ack.get("proxy_port")
|
||||
logger.info("隧道已注册: worker_id=%s, 代理端口 %s:%s", self.worker_id, self.server_host, proxy_port)
|
||||
|
||||
try:
|
||||
while self._running:
|
||||
line = await asyncio.wait_for(reader.readline(), timeout=60)
|
||||
if not line:
|
||||
break
|
||||
data = decode_control(line)
|
||||
if data.get("type") == "open_stream":
|
||||
stream_id = int(data.get("stream_id", 0))
|
||||
local_port = int(data.get("local_port", self.local_port))
|
||||
asyncio.create_task(
|
||||
self._bridge_stream(stream_id, local_port)
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
async def _bridge_stream(self, stream_id: int, local_port: int) -> None:
|
||||
"""连接本地 local_port,再连服务端 stream 端口,发送 stream_id 后双向桥接。"""
|
||||
try:
|
||||
local_reader, local_writer = await asyncio.open_connection(
|
||||
"127.0.0.1", local_port
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("隧道打开本地 %s 失败: %s", local_port, e)
|
||||
return
|
||||
|
||||
try:
|
||||
stream_reader, stream_writer = await asyncio.open_connection(
|
||||
self.server_host, self.stream_port
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("隧道连流端口失败: %s", e)
|
||||
local_writer.close()
|
||||
await local_writer.wait_closed()
|
||||
return
|
||||
|
||||
try:
|
||||
stream_writer.write(stream_id.to_bytes(4, "big"))
|
||||
await stream_writer.drain()
|
||||
except Exception as e:
|
||||
logger.warning("隧道发送 stream_id 失败: %s", e)
|
||||
stream_writer.close()
|
||||
await stream_writer.wait_closed()
|
||||
local_writer.close()
|
||||
await local_writer.wait_closed()
|
||||
return
|
||||
|
||||
async def pipe(
|
||||
src: asyncio.StreamReader, dst: asyncio.StreamWriter
|
||||
) -> None:
|
||||
try:
|
||||
while True:
|
||||
data = await src.read(65536)
|
||||
if not data:
|
||||
break
|
||||
dst.write(data)
|
||||
await dst.drain()
|
||||
except (ConnectionResetError, BrokenPipeError, asyncio.CancelledError):
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
dst.close()
|
||||
await dst.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.gather(
|
||||
pipe(local_reader, stream_writer),
|
||||
pipe(stream_reader, local_writer),
|
||||
)
|
||||
logger.debug("隧道流 %s 已关闭", stream_id)
|
||||
32
tunnel/protocol.py
Normal file
32
tunnel/protocol.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
隧道控制面协议:服务端与客户端之间 JSON 文本行。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
|
||||
def register_message(worker_id: str, local_port: int) -> dict:
|
||||
"""客户端注册:暴露的 worker_id 与本地端口。"""
|
||||
return {"type": "register", "worker_id": worker_id, "local_port": local_port}
|
||||
|
||||
|
||||
def register_ack_message(proxy_port: int) -> dict:
|
||||
"""服务端确认:分配给的代理端口。"""
|
||||
return {"type": "register_ack", "proxy_port": proxy_port}
|
||||
|
||||
|
||||
def open_stream_message(stream_id: int, local_port: int) -> dict:
|
||||
"""服务端通知客户端:请为 stream_id 打开本地 local_port 并连到 stream 端口。"""
|
||||
return {"type": "open_stream", "stream_id": stream_id, "local_port": local_port}
|
||||
|
||||
|
||||
def encode_control(msg: dict) -> bytes:
|
||||
"""编码为一行 JSON + 换行。"""
|
||||
return (json.dumps(msg, ensure_ascii=False) + "\n").encode("utf-8")
|
||||
|
||||
|
||||
def decode_control(line: bytes) -> dict:
|
||||
"""解码一行 JSON。"""
|
||||
return json.loads(line.decode("utf-8").strip())
|
||||
247
tunnel/server.py
Normal file
247
tunnel/server.py
Normal file
@@ -0,0 +1,247 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
隧道服务端:运行在云服务器。
|
||||
- 控制端口:接受隧道客户端连接,注册 worker_id。
|
||||
- 流端口:客户端连此端口做流桥接。
|
||||
- 为每个 worker 分配一个代理端口,外部连代理端口时转发到该 worker 的本地端口。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from tunnel.protocol import (
|
||||
decode_control,
|
||||
encode_control,
|
||||
open_stream_message,
|
||||
register_ack_message,
|
||||
)
|
||||
|
||||
logger = logging.getLogger("tunnel.server")
|
||||
|
||||
|
||||
class TunnelServer:
|
||||
def __init__(
|
||||
self,
|
||||
control_port: int,
|
||||
stream_port: int,
|
||||
proxy_base_port: int,
|
||||
host: str = "0.0.0.0",
|
||||
) -> None:
|
||||
self.control_port = control_port
|
||||
self.stream_port = stream_port
|
||||
self.proxy_base_port = proxy_base_port
|
||||
self.host = host
|
||||
self._worker_control: dict[str, asyncio.StreamWriter] = {}
|
||||
self._worker_local_port: dict[str, int] = {}
|
||||
self._worker_proxy_port: dict[str, int] = {}
|
||||
self._proxy_worker: dict[int, str] = {}
|
||||
self._next_proxy_port = proxy_base_port
|
||||
self._stream_queues: dict[int, asyncio.Queue[tuple[asyncio.StreamReader, asyncio.StreamWriter]]] = {}
|
||||
self._stream_id = 0
|
||||
self._lock = asyncio.Lock()
|
||||
self._control_server: Optional[asyncio.Server] = None
|
||||
self._stream_server: Optional[asyncio.Server] = None
|
||||
self._proxy_servers: dict[int, asyncio.Server] = {}
|
||||
|
||||
async def start(self) -> None:
|
||||
self._control_server = await asyncio.start_server(
|
||||
self._handle_control_connection, self.host, self.control_port
|
||||
)
|
||||
self._stream_server = await asyncio.start_server(
|
||||
self._handle_stream_connection, self.host, self.stream_port
|
||||
)
|
||||
logger.info(
|
||||
"隧道服务端: 控制端口 %s:%s, 流端口 %s:%s",
|
||||
self.host, self.control_port, self.host, self.stream_port,
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
for s in self._proxy_servers.values():
|
||||
s.close()
|
||||
await s.wait_closed()
|
||||
self._proxy_servers.clear()
|
||||
if self._stream_server:
|
||||
self._stream_server.close()
|
||||
await self._stream_server.wait_closed()
|
||||
if self._control_server:
|
||||
self._control_server.close()
|
||||
await self._control_server.wait_closed()
|
||||
for w in self._worker_control.values():
|
||||
try:
|
||||
w.close()
|
||||
await w.drain()
|
||||
except Exception:
|
||||
pass
|
||||
self._worker_control.clear()
|
||||
|
||||
def _alloc_proxy_port(self, worker_id: str, local_port: int) -> int:
|
||||
port = self._next_proxy_port
|
||||
self._next_proxy_port += 1
|
||||
self._worker_proxy_port[worker_id] = port
|
||||
self._worker_local_port[worker_id] = local_port
|
||||
self._proxy_worker[port] = worker_id
|
||||
return port
|
||||
|
||||
async def _handle_control_connection(
|
||||
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
) -> None:
|
||||
peer = writer.get_extra_info("peername", ("?", "?"))[0]
|
||||
worker_id: Optional[str] = None
|
||||
try:
|
||||
line = await asyncio.wait_for(reader.readline(), timeout=30)
|
||||
if not line:
|
||||
return
|
||||
msg = decode_control(line)
|
||||
if msg.get("type") != "register":
|
||||
writer.write(encode_control({"type": "error", "detail": "首条消息必须是 register"}))
|
||||
await writer.drain()
|
||||
return
|
||||
worker_id = msg.get("worker_id") or ""
|
||||
local_port = int(msg.get("local_port", 54345))
|
||||
if not worker_id:
|
||||
writer.write(encode_control({"type": "error", "detail": "worker_id 不能为空"}))
|
||||
await writer.drain()
|
||||
return
|
||||
|
||||
async with self._lock:
|
||||
if worker_id in self._worker_control:
|
||||
try:
|
||||
self._worker_control[worker_id].close()
|
||||
await self._worker_control[worker_id].drain()
|
||||
except Exception:
|
||||
pass
|
||||
proxy_port = self._alloc_proxy_port(worker_id, local_port)
|
||||
self._worker_control[worker_id] = writer
|
||||
|
||||
writer.write(encode_control(register_ack_message(proxy_port)))
|
||||
await writer.drain()
|
||||
logger.info("隧道客户端注册: worker_id=%s, local_port=%s, proxy_port=%s", worker_id, local_port, proxy_port)
|
||||
|
||||
# 为该 worker 启动代理监听(若尚未监听)
|
||||
await self._ensure_proxy_listener(proxy_port, worker_id, local_port)
|
||||
|
||||
# 保持连接并处理 open_stream 的响应(本实现里服务端只发 open_stream,不期待客户端回控制消息)
|
||||
while True:
|
||||
line = await reader.readline()
|
||||
if not line:
|
||||
break
|
||||
# 可扩展:心跳、其他控制消息
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("控制连接超时: %s", peer)
|
||||
except Exception as e:
|
||||
logger.exception("控制连接异常: %s", e)
|
||||
finally:
|
||||
if worker_id:
|
||||
async with self._lock:
|
||||
if self._worker_control.get(worker_id) is writer:
|
||||
del self._worker_control[worker_id]
|
||||
if worker_id in self._worker_proxy_port:
|
||||
port = self._worker_proxy_port.pop(worker_id)
|
||||
self._proxy_worker.pop(port, None)
|
||||
self._worker_local_port.pop(worker_id, None)
|
||||
srv = self._proxy_servers.pop(port, None)
|
||||
if srv:
|
||||
srv.close()
|
||||
await srv.wait_closed()
|
||||
try:
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _ensure_proxy_listener(self, proxy_port: int, worker_id: str, local_port: int) -> None:
|
||||
if proxy_port in self._proxy_servers:
|
||||
return
|
||||
|
||||
async def on_proxy_connection(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
|
||||
await self._bridge_user_to_worker(proxy_port, worker_id, local_port, reader, writer)
|
||||
|
||||
self._proxy_servers[proxy_port] = await asyncio.start_server(
|
||||
on_proxy_connection, self.host, proxy_port
|
||||
)
|
||||
logger.info("隧道代理监听: %s:%s -> worker %s local:%s", self.host, proxy_port, worker_id, local_port)
|
||||
|
||||
async def _bridge_user_to_worker(
|
||||
self,
|
||||
proxy_port: int,
|
||||
worker_id: str,
|
||||
local_port: int,
|
||||
user_reader: asyncio.StreamReader,
|
||||
user_writer: asyncio.StreamWriter,
|
||||
) -> None:
|
||||
stream_id = await self._alloc_stream_id()
|
||||
control_writer = self._worker_control.get(worker_id)
|
||||
if not control_writer or control_writer.is_closing():
|
||||
user_writer.close()
|
||||
await user_writer.wait_closed()
|
||||
return
|
||||
|
||||
queue: asyncio.Queue[tuple[asyncio.StreamReader, asyncio.StreamWriter]] = asyncio.Queue(maxsize=1)
|
||||
self._stream_queues[stream_id] = queue
|
||||
|
||||
try:
|
||||
control_writer.write(encode_control(open_stream_message(stream_id, local_port)))
|
||||
await control_writer.drain()
|
||||
except Exception as e:
|
||||
logger.warning("发送 open_stream 失败: %s", e)
|
||||
self._stream_queues.pop(stream_id, None)
|
||||
user_writer.close()
|
||||
await user_writer.wait_closed()
|
||||
return
|
||||
|
||||
try:
|
||||
client_reader, client_writer = await asyncio.wait_for(queue.get(), timeout=15)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("流 %s 等待客户端连接超时", stream_id)
|
||||
self._stream_queues.pop(stream_id, None)
|
||||
user_writer.close()
|
||||
await user_writer.wait_closed()
|
||||
return
|
||||
self._stream_queues.pop(stream_id, None)
|
||||
|
||||
async def pipe(src: asyncio.StreamReader, dst: asyncio.StreamWriter) -> None:
|
||||
try:
|
||||
while True:
|
||||
data = await src.read(65536)
|
||||
if not data:
|
||||
break
|
||||
dst.write(data)
|
||||
await dst.drain()
|
||||
except (ConnectionResetError, BrokenPipeError, asyncio.CancelledError):
|
||||
pass
|
||||
finally:
|
||||
try:
|
||||
dst.close()
|
||||
await dst.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.gather(
|
||||
pipe(user_reader, client_writer),
|
||||
pipe(client_reader, user_writer),
|
||||
)
|
||||
|
||||
async def _handle_stream_connection(
|
||||
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
|
||||
) -> None:
|
||||
"""客户端连到流端口后发送 stream_id(4 字节大端),然后参与桥接。"""
|
||||
try:
|
||||
sid_bytes = await asyncio.wait_for(reader.readexactly(4), timeout=10)
|
||||
stream_id = int.from_bytes(sid_bytes, "big")
|
||||
queue = self._stream_queues.get(stream_id)
|
||||
if not queue:
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
return
|
||||
queue.put_nowait((reader, writer))
|
||||
except (asyncio.TimeoutError, asyncio.IncompleteReadError) as e:
|
||||
logger.warning("流连接读 stream_id 失败: %s", e)
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
|
||||
async def _alloc_stream_id(self) -> int:
|
||||
async with self._lock:
|
||||
self._stream_id += 1
|
||||
return self._stream_id
|
||||
@@ -18,3 +18,26 @@ BIT_API_BASE: str = os.getenv("BIT_API_BASE", "http://127.0.0.1:54345")
|
||||
HEARTBEAT_INTERVAL: int = 25 # 心跳发送间隔(秒)
|
||||
RECONNECT_DELAY: int = 5 # 断线重连等待(秒)
|
||||
RECONNECT_MAX_DELAY: int = 60 # 重连最大等待(秒,指数退避上限)
|
||||
|
||||
# ─── 隧道(内网穿透,与 Worker 集成) ───
|
||||
TUNNEL_ENABLED: bool = os.getenv("TUNNEL_ENABLED", "1").strip().lower() in ("1", "true", "yes")
|
||||
TUNNEL_SERVER: str = os.getenv("TUNNEL_SERVER", "") # 留空则从 SERVER_WS_URL 解析 host
|
||||
TUNNEL_CONTROL_PORT: int = int(os.getenv("TUNNEL_CONTROL_PORT", "8001"))
|
||||
TUNNEL_STREAM_PORT: int = int(os.getenv("TUNNEL_STREAM_PORT", "8003"))
|
||||
TUNNEL_LOCAL_PORT: int = int(os.getenv("TUNNEL_LOCAL_PORT", "54345")) # 暴露的本地端口(如比特浏览器 API)
|
||||
|
||||
|
||||
def get_tunnel_server_host() -> str:
|
||||
"""隧道服务端地址。未设置 TUNNEL_SERVER 时从 SERVER_WS_URL 解析。"""
|
||||
if TUNNEL_SERVER:
|
||||
return TUNNEL_SERVER.split(":")[0] if ":" in TUNNEL_SERVER else TUNNEL_SERVER
|
||||
url = SERVER_WS_URL.strip()
|
||||
for prefix in ("wss://", "ws://"):
|
||||
if url.startswith(prefix):
|
||||
rest = url[len(prefix):]
|
||||
if "/" in rest:
|
||||
rest = rest.split("/")[0]
|
||||
if ":" in rest:
|
||||
return rest.split(":")[0]
|
||||
return rest
|
||||
return "127.0.0.1"
|
||||
|
||||
@@ -13,6 +13,8 @@ import sys
|
||||
from worker import config
|
||||
from worker.tasks.registry import register_all_handlers
|
||||
from worker.ws_client import WorkerWSClient
|
||||
from tunnel.client import TunnelClient
|
||||
from worker.config import get_tunnel_server_host
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
@@ -44,15 +46,29 @@ def parse_args():
|
||||
default=config.BIT_API_BASE,
|
||||
help=f"比特浏览器本地 API 地址 (默认: {config.BIT_API_BASE})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-tunnel",
|
||||
action="store_true",
|
||||
help="禁用内网穿透隧道(不与隧道服务端连接)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _local_port_from_bit_api(bit_api_base: str) -> int:
|
||||
"""从 BIT_API_BASE (e.g. http://127.0.0.1:54345) 解析端口。"""
|
||||
try:
|
||||
from urllib.parse import urlparse
|
||||
p = urlparse(bit_api_base if "://" in bit_api_base else "http://" + bit_api_base)
|
||||
return p.port or 54345
|
||||
except Exception:
|
||||
return config.TUNNEL_LOCAL_PORT
|
||||
|
||||
|
||||
async def run(args):
|
||||
# 注册所有任务处理器
|
||||
register_all_handlers()
|
||||
logger.info("已注册任务处理器")
|
||||
|
||||
# 创建 WebSocket 客户端并运行
|
||||
client = WorkerWSClient(
|
||||
server_url=args.server,
|
||||
worker_id=args.worker_id,
|
||||
@@ -60,15 +76,56 @@ async def run(args):
|
||||
bit_api_base=args.bit_api,
|
||||
)
|
||||
|
||||
tunnel_enabled = config.TUNNEL_ENABLED and not args.no_tunnel
|
||||
tunnel_client = None
|
||||
if tunnel_enabled:
|
||||
tunnel_client = TunnelClient(
|
||||
server_host=get_tunnel_server_host(),
|
||||
control_port=config.TUNNEL_CONTROL_PORT,
|
||||
stream_port=config.TUNNEL_STREAM_PORT,
|
||||
worker_id=args.worker_id,
|
||||
local_port=_local_port_from_bit_api(args.bit_api),
|
||||
)
|
||||
logger.info(
|
||||
"隧道已启用: 暴露本地 %s -> 云服务器 (worker_id=%s)",
|
||||
_local_port_from_bit_api(args.bit_api), args.worker_id,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Worker 启动: id=%s, name=%s, server=%s",
|
||||
args.worker_id, args.worker_name, args.server,
|
||||
)
|
||||
|
||||
try:
|
||||
async def run_worker():
|
||||
await client.run()
|
||||
|
||||
async def run_tunnel():
|
||||
if tunnel_client:
|
||||
await tunnel_client.run()
|
||||
|
||||
worker_task = asyncio.create_task(run_worker())
|
||||
tunnel_task = asyncio.create_task(run_tunnel()) if tunnel_client else None
|
||||
|
||||
try:
|
||||
if tunnel_task is not None:
|
||||
await asyncio.gather(worker_task, tunnel_task)
|
||||
else:
|
||||
await worker_task
|
||||
except KeyboardInterrupt:
|
||||
logger.info("收到中断信号,正在退出...")
|
||||
worker_task.cancel()
|
||||
if tunnel_task is not None:
|
||||
tunnel_task.cancel()
|
||||
try:
|
||||
await worker_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if tunnel_task is not None:
|
||||
try:
|
||||
await tunnel_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
await tunnel_client.stop()
|
||||
await client.stop()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user