# -*- coding: utf-8 -*-
"""
Worker WebSocket client.

Responsibilities: connect to the server, register, send heartbeats, receive
tasks, report progress/results, and reconnect after the connection drops.
"""
from __future__ import annotations

import asyncio
import json
import logging
import time
from typing import List, Optional

import websockets
from websockets.exceptions import ConnectionClosed

from common.protocol import MsgType, TaskStatus, make_msg
from worker import config
from worker.bit_browser import BitBrowserAPI
from worker.tasks.registry import get_handler

logger = logging.getLogger("worker.ws_client")


class WorkerWSClient:
    """WebSocket client run by a worker process.

    Registers with the server, sends periodic heartbeats, executes assigned
    tasks one at a time (in the receive loop), and reconnects with
    exponential backoff when the connection is lost.
    """

    # Seconds to wait for the server's REGISTER_ACK after sending REGISTER.
    REGISTER_ACK_TIMEOUT = 10

    def __init__(
        self,
        server_url: str,
        worker_id: str,
        worker_name: str,
        bit_api_base: str,
    ) -> None:
        self.server_url = server_url
        self.worker_id = worker_id
        self.worker_name = worker_name
        self.bit_api = BitBrowserAPI(bit_api_base)

        # Currently open connection, or None while disconnected.
        self._ws: Optional[websockets.WebSocketClientProtocol] = None
        self._running = False
        # Current delay before the next reconnect attempt (grows on failure).
        self._reconnect_delay = config.RECONNECT_DELAY
        # Browser list last reported to the server, used for change detection.
        self._last_browsers: List[dict] = []
        # Heartbeats sent since the last browser-list check.
        self._heartbeat_count = 0

    # ────────────────────────── Main loop ──────────────────────────

    async def run(self) -> None:
        """Main loop: connect -> register -> exchange messages; reconnect on loss.

        Runs until :meth:`stop` clears ``_running``. After every attempt the
        reconnect delay doubles (capped at ``config.RECONNECT_MAX_DELAY``); it
        is reset to ``config.RECONNECT_DELAY`` once a connection registers
        successfully.
        """
        self._running = True
        while self._running:
            try:
                await self._connect_and_loop()
            except Exception as e:
                logger.error("连接异常: %s", e)

            if self._running:
                logger.info("将在 %d 秒后重连...", self._reconnect_delay)
                await asyncio.sleep(self._reconnect_delay)
                # Exponential backoff.
                self._reconnect_delay = min(
                    self._reconnect_delay * 2,
                    config.RECONNECT_MAX_DELAY,
                )

    async def stop(self) -> None:
        """Ask :meth:`run` to exit and close the current connection, if any."""
        self._running = False
        if self._ws:
            await self._ws.close()

    # ────────────────────────── Connection flow ──────────────────────────

    async def _connect_and_loop(self) -> None:
        """Open one connection, register, and process messages until it closes.

        Exceptions from ``websockets.connect`` and :meth:`_register` propagate
        to :meth:`run`. A ``ConnectionClosed`` raised while receiving messages
        is logged and swallowed so :meth:`run` simply reconnects.
        """
        logger.info("正在连接服务器: %s", self.server_url)
        async with websockets.connect(self.server_url) as ws:
            self._ws = ws
            try:
                logger.info("WebSocket 已连接")

                await self._register(ws)
                # Reset the backoff only after registration succeeded: resetting
                # on a bare connect would pin the delay at its minimum while the
                # server keeps rejecting REGISTER.
                self._reconnect_delay = config.RECONNECT_DELAY

                heartbeat_task = asyncio.create_task(self._heartbeat_loop(ws))
                try:
                    # Message receive loop. Tasks are awaited inline, so a
                    # long-running task delays handling of later messages.
                    async for raw in ws:
                        data = self._parse_message(raw)
                        if data is not None:
                            await self._handle_message(ws, data)
                except ConnectionClosed as e:
                    logger.warning("WebSocket 连接关闭: %s", e)
                finally:
                    heartbeat_task.cancel()
            finally:
                # Also cleared when registration fails, so stop() never
                # touches a dead connection.
                self._ws = None

    @staticmethod
    def _parse_message(raw) -> Optional[dict]:
        """Decode one incoming frame into a message dict; None if unusable.

        A malformed frame is logged and skipped instead of raising, so a
        single bad message does not tear down the whole connection.
        """
        try:
            data = json.loads(raw)
        except ValueError as e:  # covers JSONDecodeError and UnicodeDecodeError
            logger.warning("收到无法解析的消息,已忽略: %s", e)
            return None
        if not isinstance(data, dict):
            logger.warning("收到非对象格式的消息,已忽略: %r", data)
            return None
        return data

    async def _register(self, ws) -> None:
        """Send REGISTER (with the local browser list) and wait for REGISTER_ACK.

        Raises:
            RuntimeError: if no reply arrives within ``REGISTER_ACK_TIMEOUT``
                seconds, or the first reply is not a REGISTER_ACK.
        """
        browsers = self._fetch_browser_list()
        self._last_browsers = browsers
        msg = make_msg(
            MsgType.REGISTER,
            worker_id=self.worker_id,
            worker_name=self.worker_name,
            browsers=browsers,
        )
        await ws.send(json.dumps(msg))

        # The server's first reply is expected to be the ACK.
        try:
            ack_raw = await asyncio.wait_for(ws.recv(), timeout=self.REGISTER_ACK_TIMEOUT)
        except asyncio.TimeoutError:
            # A bare TimeoutError has an empty message; make the log useful.
            logger.error("注册失败: 等待 ACK 超时")
            raise RuntimeError("注册失败: 等待 ACK 超时") from None

        ack = json.loads(ack_raw)
        if isinstance(ack, dict) and ack.get("type") == MsgType.REGISTER_ACK.value:
            logger.info("注册成功: worker_id=%s", self.worker_id)
        else:
            logger.error("注册失败: %s", ack)
            raise RuntimeError(f"注册失败: {ack}")

    # ────────────────────────── Heartbeat ──────────────────────────

    async def _heartbeat_loop(self, ws) -> None:
        """Send a heartbeat every ``config.HEARTBEAT_INTERVAL`` seconds.

        Every third heartbeat also checks the browser list and reports changes.
        The loop ends on the first failure (e.g. the connection closed) or when
        the task is cancelled.
        """
        while True:
            try:
                await asyncio.sleep(config.HEARTBEAT_INTERVAL)
                msg = make_msg(MsgType.HEARTBEAT, worker_id=self.worker_id)
                await ws.send(json.dumps(msg))
                logger.debug("心跳已发送")

                self._heartbeat_count += 1
                if self._heartbeat_count >= 3:
                    self._heartbeat_count = 0
                    await self._maybe_send_browser_list_update(ws)
            except Exception as e:
                logger.debug("心跳循环退出: %s", e)
                break

    @staticmethod
    def _browser_sort_key(browser: dict) -> tuple:
        """Order-insensitive comparison key; tolerates None / non-str ids."""
        return (str(browser.get("id") or ""), str(browser.get("name") or ""))

    async def _maybe_send_browser_list_update(self, ws) -> None:
        """Fetch the BitBrowser list and send BROWSER_LIST_UPDATE if it changed.

        Failures are logged at debug level and otherwise ignored (best effort).
        """
        try:
            new_list = self._fetch_browser_list()
            key = self._browser_sort_key
            if sorted(new_list, key=key) != sorted(self._last_browsers, key=key):
                self._last_browsers = new_list
                msg = make_msg(MsgType.BROWSER_LIST_UPDATE, browsers=new_list)
                await ws.send(json.dumps(msg))
                logger.info("浏览器列表变更,已上报 %d 个环境", len(new_list))
        except Exception as e:
            logger.debug("检查浏览器列表变更失败: %s", e)

    # ────────────────────────── Message handling ──────────────────────────

    async def _handle_message(self, ws, data: dict) -> None:
        """Dispatch one decoded server message by its ``type`` field."""
        msg_type = data.get("type", "")

        if msg_type == MsgType.HEARTBEAT_ACK.value:
            logger.debug("收到心跳 ACK")
        elif msg_type == MsgType.TASK_ASSIGN.value:
            await self._handle_task(ws, data)
        elif msg_type == MsgType.TASK_CANCEL.value:
            task_id = data.get("task_id", "")
            logger.info("收到任务取消: %s(暂不支持中途取消)", task_id)
        elif msg_type == MsgType.ERROR.value:
            logger.error("服务器错误: %s", data.get("detail", ""))
        else:
            logger.warning("未知消息: %s", msg_type)

    def _build_task_params(self, data: dict) -> dict:
        """Return the handler params for a TASK_ASSIGN message.

        Injects ``account_name`` (when the message carries one) and
        ``bit_api_base`` without overriding values the server already set.

        Raises:
            TypeError: if ``params`` is present but is not a JSON object.
        """
        params = data.get("params")
        if params is None:
            # Covers both a missing key and an explicit JSON null.
            params = {}
        if not isinstance(params, dict):
            raise TypeError(f"任务参数必须是对象: {params!r}")

        account_name = data.get("account_name", "")
        if account_name:
            params.setdefault("account_name", account_name)
        params.setdefault("bit_api_base", self.bit_api.base_url)
        return params

    async def _handle_task(self, ws, data: dict) -> None:
        """Execute an assigned task and always report a TASK_RESULT for it."""
        task_id = data.get("task_id", "")
        task_type = data.get("task_type", "")
        logger.info("收到任务: %s (type=%s)", task_id, task_type)

        handler = get_handler(task_type)
        if not handler:
            error_msg = f"不支持的任务类型: {task_type}"
            logger.error(error_msg)
            await self._send_result(ws, task_id, error=error_msg)
            return

        async def progress_cb(tid: str, progress: str) -> None:
            """Report intermediate progress; failures are non-fatal."""
            msg = make_msg(MsgType.TASK_PROGRESS, task_id=tid, progress=progress)
            try:
                await ws.send(json.dumps(msg))
            except Exception as e:
                # Progress is best effort; the final TASK_RESULT is what counts.
                logger.debug("上报进度失败: %s", e)

        try:
            # Param building is inside the try so a malformed TASK_ASSIGN is
            # reported as a task error instead of breaking the receive loop.
            params = self._build_task_params(data)
            result = await handler.execute(task_id, params, progress_cb)
        except Exception as e:
            logger.error("任务 %s 执行失败: %s", task_id, e, exc_info=True)
            await self._send_result(ws, task_id, error=str(e))
        else:
            await self._send_result(ws, task_id, result=result)

    async def _send_result(
        self, ws, task_id: str, result=None, error: Optional[str] = None
    ) -> None:
        """Report a task's final outcome as a TASK_RESULT message.

        If *result* cannot be JSON-encoded, an error result is sent instead so
        the server is not left waiting for a result that never arrives.
        Send failures are logged, not raised.
        """
        msg = make_msg(MsgType.TASK_RESULT, task_id=task_id, result=result, error=error)
        try:
            payload = json.dumps(msg)
        except (TypeError, ValueError) as e:
            logger.error("任务 %s 结果无法序列化: %s", task_id, e)
            error = f"任务结果无法序列化为 JSON: {e}"
            msg = make_msg(MsgType.TASK_RESULT, task_id=task_id, result=None, error=error)
            payload = json.dumps(msg)

        try:
            await ws.send(payload)
            logger.info("任务 %s 结果已上报 (error=%s)", task_id, error)
        except Exception as e:
            logger.error("上报结果失败: %s", e)

    # ────────────────────────── BitBrowser list ──────────────────────────

    def _fetch_browser_list(self) -> List[dict]:
        """Return the local BitBrowser windows as ``{"id", "name", "remark"}`` dicts.

        Returns an empty list (after logging a warning) if the BitBrowser API
        call fails. NOTE(review): this is a synchronous call made from async
        code; if ``list_browsers`` does blocking I/O it stalls the event loop
        while it runs — consider offloading to an executor.
        """
        try:
            items = self.bit_api.list_browsers()
            return [
                {"id": b.get("id", ""), "name": b.get("name", ""), "remark": b.get("remark", "")}
                for b in items
            ]
        except Exception as e:
            logger.warning("获取比特浏览器列表失败(比特浏览器可能未启动): %s", e)
            return []