# -*- coding: utf-8 -*- """ Worker 启动入口。 启动方式: python -m worker.main [--server ws://IP:8000/ws] [--worker-id pc-a] [--worker-name 电脑A] """ from __future__ import annotations import argparse import asyncio import logging import sys from worker import config from worker.tasks.registry import register_all_handlers from worker.ws_client import WorkerWSClient from tunnel.client import TunnelClient logging.basicConfig( level=logging.INFO, format="%(asctime)s %(name)-28s %(levelname)-5s %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) logger = logging.getLogger("worker.main") def parse_args(): parser = argparse.ArgumentParser(description="Browser Control Worker Agent") parser.add_argument("--worker", action="store_true", help=argparse.SUPPRESS) parser.add_argument( "--server", default=config.SERVER_WS_URL, help=f"中央服务器 WebSocket 地址 (默认: {config.SERVER_WS_URL})", ) parser.add_argument( "--worker-id", default=config.WORKER_ID, help=f"Worker ID (默认: {config.WORKER_ID})", ) parser.add_argument( "--worker-name", default=config.WORKER_NAME, help=f"Worker 名称 (默认: {config.WORKER_NAME})", ) parser.add_argument( "--bit-api", default=config.BIT_API_BASE, help=f"比特浏览器本地 API 地址 (默认: {config.BIT_API_BASE})", ) parser.add_argument( "--no-tunnel", action="store_true", help="禁用内网穿透隧道(不与隧道服务端连接)", ) return parser.parse_args() def _local_port_from_bit_api(bit_api_base: str) -> int: """从 BIT_API_BASE (e.g. http://127.0.0.1:54345) 解析端口。""" try: from urllib.parse import urlparse p = urlparse(bit_api_base if "://" in bit_api_base else "http://" + bit_api_base) return p.port or 54345 except Exception: return config.TUNNEL_LOCAL_PORT def _extract_host_from_ws_url(ws_url: str) -> str: """从 WebSocket URL (如 ws://8.137.99.82:9000/ws) 中提取 host。""" try: from urllib.parse import urlparse p = urlparse(ws_url) return p.hostname or "127.0.0.1" except Exception: return "127.0.0.1" async def run(args): # 注册所有任务处理器 register_all_handlers() logger.info("已注册任务处理器") client = WorkerWSClient( server_url=args.server, worker_id=args.worker_id, worker_name=args.worker_name, bit_api_base=args.bit_api, ) tunnel_enabled = config.TUNNEL_ENABLED and not args.no_tunnel tunnel_client = None if tunnel_enabled: # 从命令行 --server 参数中提取云服务器 host(而非配置文件默认值) tunnel_host = _extract_host_from_ws_url(args.server) tunnel_client = TunnelClient( server_host=tunnel_host, control_port=config.TUNNEL_CONTROL_PORT, stream_port=config.TUNNEL_STREAM_PORT, worker_id=args.worker_id, local_port=_local_port_from_bit_api(args.bit_api), ) logger.info( "隧道已启用: 暴露本地 %s -> %s (worker_id=%s)", _local_port_from_bit_api(args.bit_api), tunnel_host, args.worker_id, ) logger.info( "Worker 启动: id=%s, name=%s, server=%s", args.worker_id, args.worker_name, args.server, ) async def run_worker(): await client.run() async def run_tunnel(): if tunnel_client: await tunnel_client.run() worker_task = asyncio.create_task(run_worker()) tunnel_task = asyncio.create_task(run_tunnel()) if tunnel_client else None try: if tunnel_task is not None: await asyncio.gather(worker_task, tunnel_task) else: await worker_task except KeyboardInterrupt: logger.info("收到中断信号,正在退出...") worker_task.cancel() if tunnel_task is not None: tunnel_task.cancel() try: await worker_task except asyncio.CancelledError: pass if tunnel_task is not None: try: await tunnel_task except asyncio.CancelledError: pass await tunnel_client.stop() await client.stop() def main(): args = parse_args() try: asyncio.run(run(args)) except KeyboardInterrupt: logger.info("Worker 已退出") if __name__ == "__main__": main()