Files
boss_dp/worker/main.py
2026-02-26 20:55:59 +08:00

154 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
Worker 启动入口。
启动方式: python -m worker.main [--server ws://IP:8000/ws] [--worker-id pc-a] [--worker-name 电脑A]
"""
from __future__ import annotations
import argparse
import asyncio
import logging
import sys
from worker import config
from worker.tasks.registry import register_all_handlers
from worker.ws_client import WorkerWSClient
from tunnel.client import TunnelClient
# Process-wide logging setup: INFO level, with timestamp, padded logger name
# and padded level name in front of every message.
logging.basicConfig(
    format="%(asctime)s %(name)-28s %(levelname)-5s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)

# Logger used by the functions in this module.
logger = logging.getLogger("worker.main")
def parse_args(argv=None):
    """Parse the worker's command-line options.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, which is what
            the zero-argument call has always done.

    Returns:
        argparse.Namespace with the attributes ``worker``, ``server``,
        ``worker_id``, ``worker_name``, ``bit_api`` and ``no_tunnel``.
    """
    parser = argparse.ArgumentParser(description="Browser Control Worker Agent")

    # Accepted but hidden from --help.
    parser.add_argument("--worker", action="store_true", help=argparse.SUPPRESS)

    # The help texts use argparse's own %(default)s placeholder instead of
    # interpolating the config value with an f-string: argparse %-formats
    # every help string, so a literal "%" inside a configured value would
    # otherwise be misread as a format specifier.
    parser.add_argument(
        "--server",
        default=config.SERVER_WS_URL,
        help="中央服务器 WebSocket 地址 (默认: %(default)s)",
    )
    parser.add_argument(
        "--worker-id",
        default=config.WORKER_ID,
        help="Worker ID (默认: %(default)s)",
    )
    parser.add_argument(
        "--worker-name",
        default=config.WORKER_NAME,
        help="Worker 名称 (默认: %(default)s)",
    )
    parser.add_argument(
        "--bit-api",
        default=config.BIT_API_BASE,
        help="比特浏览器本地 API 地址 (默认: %(default)s)",
    )
    parser.add_argument(
        "--no-tunnel",
        action="store_true",
        help="禁用内网穿透隧道(不与隧道服务端连接)",
    )
    return parser.parse_args(argv)
def _local_port_from_bit_api(bit_api_base: str) -> int:
"""从 BIT_API_BASE (e.g. http://127.0.0.1:54345) 解析端口。"""
try:
from urllib.parse import urlparse
p = urlparse(bit_api_base if "://" in bit_api_base else "http://" + bit_api_base)
return p.port or 54345
except Exception:
return config.TUNNEL_LOCAL_PORT
def _extract_host_from_ws_url(ws_url: str) -> str:
"""从 WebSocket URL (如 ws://8.137.99.82:9000/ws) 中提取 host。"""
try:
from urllib.parse import urlparse
p = urlparse(ws_url)
return p.hostname or "127.0.0.1"
except Exception:
return "127.0.0.1"
async def run(args):
    """Run the worker WebSocket client and, when enabled, the tunnel client.

    Registers the task handlers, builds the WebSocket client (plus a tunnel
    client unless tunnelling is disabled by ``config.TUNNEL_ENABLED`` or
    ``--no-tunnel``), runs them as concurrent tasks and waits for them.

    Shutdown handling lives in a ``finally`` block rather than in an
    ``except KeyboardInterrupt`` handler: when the program is started through
    ``asyncio.run`` a Ctrl+C normally reaches this coroutine as a
    cancellation, not as ``KeyboardInterrupt``, so the old handler was
    skipped and neither client was stopped. The same block also cancels the
    remaining task when the other one fails.

    NOTE(review): ``stop()`` is now also awaited after a normal return or a
    failure of the tasks - this assumes ``stop()`` is safe to call once
    ``run()`` has finished; confirm against the client classes.

    Args:
        args: Parsed command-line namespace; ``server``, ``worker_id``,
            ``worker_name``, ``bit_api`` and ``no_tunnel`` are read.

    Raises:
        Whatever the client tasks raise; a cancellation or interrupt is
        re-raised after cleanup so the caller still sees it.
    """
    # Register all task handlers.
    register_all_handlers()
    logger.info("已注册任务处理器")

    client = WorkerWSClient(
        server_url=args.server,
        worker_id=args.worker_id,
        worker_name=args.worker_name,
        bit_api_base=args.bit_api,
    )

    tunnel_client = None
    if config.TUNNEL_ENABLED and not args.no_tunnel:
        # Take the tunnel host from the --server argument rather than from
        # the configuration default.
        tunnel_host = _extract_host_from_ws_url(args.server)
        local_port = _local_port_from_bit_api(args.bit_api)
        tunnel_client = TunnelClient(
            server_host=tunnel_host,
            control_port=config.TUNNEL_CONTROL_PORT,
            stream_port=config.TUNNEL_STREAM_PORT,
            worker_id=args.worker_id,
            local_port=local_port,
        )
        logger.info(
            "隧道已启用: 暴露本地 %s -> %s (worker_id=%s)",
            local_port, tunnel_host, args.worker_id,
        )

    logger.info(
        "Worker 启动: id=%s, name=%s, server=%s",
        args.worker_id, args.worker_name, args.server,
    )

    async def run_worker():
        await client.run()

    async def run_tunnel():
        await tunnel_client.run()

    tasks = [asyncio.create_task(run_worker())]
    if tunnel_client is not None:
        tasks.append(asyncio.create_task(run_tunnel()))

    try:
        await asyncio.gather(*tasks)
    except (KeyboardInterrupt, asyncio.CancelledError):
        logger.info("收到中断信号,正在退出...")
        # Re-raise so asyncio.run() can still report the interruption.
        raise
    finally:
        # Do not leave a task running (or its exception unretrieved) once
        # this coroutine is on its way out.
        for task in tasks:
            if not task.done():
                task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)

        # Same order as before: tunnel first, then the WebSocket client.
        # A failure in one stop() is logged so the other one still runs.
        if tunnel_client is not None:
            try:
                await tunnel_client.stop()
            except Exception:
                logger.exception("停止隧道客户端时出错")
        try:
            await client.stop()
        except Exception:
            logger.exception("停止 Worker 客户端时出错")
def main():
    """Console entry point: parse the command line and run the worker until interrupted."""
    cli_args = parse_args()
    entry = run(cli_args)
    try:
        asyncio.run(entry)
    except KeyboardInterrupt:
        # Ctrl+C is the expected way to stop the worker; log it instead of
        # showing a traceback.
        logger.info("Worker 已退出")
# Start the worker only when this module is executed as a script
# (e.g. `python -m worker.main`), not when it is imported.
if __name__ == "__main__":
    main()