# -*- coding: utf-8 -*-
"""
Worker entry point.

Usage: python -m worker.main [--server ws://IP:8000/ws] [--worker-id pc-a] [--worker-name 电脑A]
"""
from __future__ import annotations

import argparse
import asyncio
import logging
import sys

from tunnel.client import TunnelClient
from worker import config
from worker.tasks.registry import register_all_handlers
from worker.ws_client import WorkerWSClient

# Configure logging once, at import time: INFO threshold, timestamped records
# with the logger name and level padded into fixed-width columns.
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s %(name)-28s %(levelname)-5s %(message)s",
    level=logging.INFO,
)

# The logger name is spelled out rather than derived from __name__ —
# presumably so records stay labelled "worker.main" when the module is run
# as a script (where __name__ is "__main__").
logger = logging.getLogger("worker.main")
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the worker's command-line options.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            function always did before this parameter existed.

    Returns:
        The parsed namespace with the attributes ``worker``, ``server``,
        ``worker_id``, ``worker_name``, ``bit_api`` and ``no_tunnel``.
    """
    parser = argparse.ArgumentParser(description="Browser Control Worker Agent")

    # Accepted but hidden from --help; nothing visible here reads its value.
    parser.add_argument("--worker", action="store_true", help=argparse.SUPPRESS)

    # (flag, default, help). The help texts use argparse's own "%(default)s"
    # placeholder instead of interpolating the config value with an f-string:
    # argparse %-formats every help string, so a literal "%" inside a
    # configured URL or name would otherwise break --help.
    value_options = (
        (
            "--server",
            config.SERVER_WS_URL,
            "中央服务器 WebSocket 地址 (默认: %(default)s)",
        ),
        (
            "--worker-id",
            config.WORKER_ID,
            "Worker ID (默认: %(default)s)",
        ),
        (
            "--worker-name",
            config.WORKER_NAME,
            "Worker 名称 (默认: %(default)s)",
        ),
        (
            "--bit-api",
            config.BIT_API_BASE,
            "比特浏览器本地 API 地址 (默认: %(default)s)",
        ),
    )
    for flag, default, help_text in value_options:
        parser.add_argument(flag, default=default, help=help_text)

    parser.add_argument(
        "--no-tunnel",
        action="store_true",
        help="禁用内网穿透隧道(不与隧道服务端连接)",
    )
    return parser.parse_args(argv)
def _local_port_from_bit_api(bit_api_base: str) -> int:
    """Return the port found in a BIT_API_BASE value.

    Args:
        bit_api_base: Address such as ``http://127.0.0.1:54345``. A value
            without a scheme (``127.0.0.1:54345``) is accepted; ``http://``
            is prepended before parsing.

    Returns:
        The port from the address. If the address carries no port (or the
        port is 0) the result is 54345; if parsing raises, the result is
        ``config.TUNNEL_LOCAL_PORT``.
    """
    # NOTE(review): the two fallbacks differ (literal 54345 vs.
    # config.TUNNEL_LOCAL_PORT) — confirm that this is intended.
    try:
        from urllib.parse import urlparse

        if "://" in bit_api_base:
            address = bit_api_base
        else:
            address = "http://" + bit_api_base
        port = urlparse(address).port
        if port:
            return port
        return 54345
    except Exception:
        return config.TUNNEL_LOCAL_PORT
def _extract_host_from_ws_url(ws_url: str) -> str:
    """Return the host part of a WebSocket URL.

    Args:
        ws_url: URL such as ``ws://8.137.99.82:9000/ws``. A value without a
            scheme (``8.137.99.82:9000/ws``) is accepted as well; ``ws://``
            is prepended before parsing, mirroring how
            ``_local_port_from_bit_api`` treats scheme-less addresses.

    Returns:
        The host name as reported by ``urlparse`` (lower-cased, without port),
        or ``"127.0.0.1"`` when the value has no host, is not a string, or
        cannot be parsed.
    """
    from urllib.parse import urlparse

    try:
        # Without a scheme urlparse does not recognise a network location and
        # the host would silently fall back to loopback.
        address = ws_url if "://" in ws_url else "ws://" + ws_url
        host = urlparse(address).hostname
    except (TypeError, ValueError):
        # TypeError: ws_url is not a string; ValueError: malformed URL
        # (for example an unterminated IPv6 literal).
        return "127.0.0.1"
    return host or "127.0.0.1"
async def run(args) -> None:
    """Run the worker WebSocket client and, when enabled, the tunnel client.

    Registers the task handlers, builds the clients from the parsed
    command-line options and runs them concurrently until they finish, one of
    them fails, or this coroutine is cancelled.

    Shutdown always happens in a ``finally`` block: both tasks are cancelled
    and awaited, then ``stop()`` is called on the tunnel client (if any) and
    on the worker client. Under ``asyncio.run`` a Ctrl+C normally reaches this
    coroutine as ``asyncio.CancelledError`` rather than ``KeyboardInterrupt``,
    so cleanup placed only in an ``except KeyboardInterrupt`` handler would
    not run.

    NOTE(review): ``stop()`` is now also called after the clients return
    normally — confirm that both ``stop()`` implementations tolerate that.

    Args:
        args: Namespace produced by ``parse_args``; reads ``server``,
            ``worker_id``, ``worker_name``, ``bit_api`` and ``no_tunnel``.

    Raises:
        asyncio.CancelledError: Re-raised after cleanup when the coroutine is
            cancelled.
        Exception: Whatever a client's ``run()`` raised, after cleanup.
    """
    # Register all task handlers.
    register_all_handlers()
    logger.info("已注册任务处理器")

    client = WorkerWSClient(
        server_url=args.server,
        worker_id=args.worker_id,
        worker_name=args.worker_name,
        bit_api_base=args.bit_api,
    )

    tunnel_client = None
    if config.TUNNEL_ENABLED and not args.no_tunnel:
        # Take the tunnel host from the --server argument actually in effect
        # rather than from the configuration default.
        tunnel_host = _extract_host_from_ws_url(args.server)
        local_port = _local_port_from_bit_api(args.bit_api)
        tunnel_client = TunnelClient(
            server_host=tunnel_host,
            control_port=config.TUNNEL_CONTROL_PORT,
            stream_port=config.TUNNEL_STREAM_PORT,
            worker_id=args.worker_id,
            local_port=local_port,
        )
        logger.info(
            "隧道已启用: 暴露本地 %s -> %s (worker_id=%s)",
            local_port, tunnel_host, args.worker_id,
        )

    logger.info(
        "Worker 启动: id=%s, name=%s, server=%s",
        args.worker_id, args.worker_name, args.server,
    )

    tasks = [asyncio.create_task(client.run())]
    if tunnel_client is not None:
        tasks.append(asyncio.create_task(tunnel_client.run()))

    try:
        await asyncio.gather(*tasks)
    except KeyboardInterrupt:
        # Swallowed, as before, so direct callers see a normal return.
        logger.info("收到中断信号,正在退出...")
    except asyncio.CancelledError:
        logger.info("收到中断信号,正在退出...")
        # Propagate so asyncio.run can turn the cancellation back into
        # KeyboardInterrupt for the caller.
        raise
    finally:
        # gather() does not cancel the sibling when one task fails, so cancel
        # explicitly and wait for every task to wind down.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        try:
            if tunnel_client is not None:
                await tunnel_client.stop()
        finally:
            # Stop the worker client even if stopping the tunnel failed.
            await client.stop()
def main():
    """Parse the command line and drive ``run`` on a new asyncio event loop.

    A ``KeyboardInterrupt`` escaping the event loop is treated as a normal
    exit: it is logged and not propagated.
    """
    cli_args = parse_args()
    entry = run(cli_args)
    try:
        asyncio.run(entry)
    except KeyboardInterrupt:
        logger.info("Worker 已退出")
# Start the worker only when this module is executed directly
# (for example ``python -m worker.main``), not when it is imported.
if __name__ == "__main__":
    main()