Files
boss_dp/worker/ws_client.py
2026-03-06 10:05:49 +08:00

338 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
Worker WebSocket 客户端。
负责:连接服务器、注册、心跳、接收任务、上报进度/结果、断线重连。
"""
from __future__ import annotations
import asyncio
import json
import logging
import threading
import time
from typing import Dict, List, Optional, Set
import websockets
from websockets.exceptions import ConnectionClosed
from common.protocol import MsgType, make_msg
from worker import config
from worker.bit_browser import BitBrowserAPI
from worker.tasks.base import TaskCancelledError
from worker.tasks.registry import get_handler
# Module-level logger used by the WebSocket client code in this file.
logger = logging.getLogger("worker.ws_client")
class WorkerWSClient:
    """WebSocket client run by a worker.

    Connects to the server, registers, sends heartbeats, receives and runs
    tasks, reports progress/results, and reconnects with exponential backoff
    when the connection is lost.
    """

    # Seconds to wait for the server's reply to the register message.
    _REGISTER_ACK_TIMEOUT = 10
    # The browser list is re-checked once every this many heartbeats.
    _BROWSER_CHECK_EVERY = 3

    def __init__(
        self,
        server_url: str,
        worker_id: str,
        worker_name: str,
        bit_api_base: str,
    ) -> None:
        """Store connection settings and initialise runtime state.

        Args:
            server_url: URL passed to ``websockets.connect``.
            worker_id: Identifier sent in register/heartbeat/status messages.
            worker_name: Name sent in the register message.
            bit_api_base: Base URL handed to ``BitBrowserAPI``.
        """
        self.server_url = server_url
        self.worker_id = worker_id
        self.worker_name = worker_name
        self.bit_api = BitBrowserAPI(bit_api_base)
        self._ws: Optional[websockets.WebSocketClientProtocol] = None
        self._running = False
        self._reconnect_delay = config.RECONNECT_DELAY
        # Last browser list reported to the server (used for change detection).
        self._last_browsers: List[dict] = []
        self._heartbeat_count = 0
        # task_id -> event that is set when the task should stop.
        self._cancel_events: Dict[str, threading.Event] = {}
        # Ids cancelled before the matching task started running.
        # NOTE(review): entries are only removed when the task later starts,
        # so ids that never arrive stay here — confirm this is acceptable.
        self._cancelled_tasks: Set[str] = set()
        # task_id -> asyncio task executing that job.
        self._task_runners: Dict[str, asyncio.Task] = {}
        # Serialises ws.send calls; only set while a connection is open.
        self._send_lock: Optional[asyncio.Lock] = None

    # ────────────────────────── helpers ──────────────────────────
    @staticmethod
    def _clean_str(value) -> str:
        """Return ``value`` as a stripped string; ``None`` becomes ``""``.

        Every id coming from the wire goes through this helper so that the
        assign, cancel and status paths all use the same dictionary key.
        """
        if value is None:
            return ""
        return str(value).strip()

    @staticmethod
    def _browser_sort_key(browser: dict):
        """Sort key used to compare browser lists independent of order."""
        return (browser.get("id", ""), browser.get("name", ""))

    # ────────────────────────── main loop ──────────────────────────
    async def run(self) -> None:
        """Main loop: connect, register, exchange messages; reconnect on loss.

        Runs until ``stop`` clears the running flag. After every connection
        attempt that ends (cleanly or with an error) it sleeps for the current
        reconnect delay and then doubles it, capped at
        ``config.RECONNECT_MAX_DELAY``.
        """
        self._running = True
        while self._running:
            try:
                await self._connect_and_loop()
            except Exception as e:
                logger.error("连接异常: %s", e)
            if self._running:
                logger.info("将在 %d 秒后重连...", self._reconnect_delay)
                await asyncio.sleep(self._reconnect_delay)
                # Exponential backoff.
                self._reconnect_delay = min(
                    self._reconnect_delay * 2,
                    config.RECONNECT_MAX_DELAY,
                )

    async def stop(self) -> None:
        """Ask the main loop to exit and close the current connection, if any."""
        self._running = False
        ws = self._ws
        if ws is not None:
            await ws.close()

    # ────────────────────────── connection flow ──────────────────────────
    async def _connect_and_loop(self) -> None:
        """Open one connection, register, and serve messages until it ends.

        Raises:
            RuntimeError: if the server does not acknowledge registration.
            Exception: anything raised while connecting or registering is
                propagated to ``run``, which logs it and retries.
        """
        logger.info("正在连接服务器: %s", self.server_url)
        async with websockets.connect(self.server_url) as ws:
            self._ws = ws
            self._send_lock = asyncio.Lock()
            heartbeat_task: Optional[asyncio.Task] = None
            try:
                logger.info("WebSocket 已连接")
                await self._register(ws)
                # Reset the backoff only once registration succeeded, so a
                # server that accepts the socket but rejects the worker is
                # still retried with a growing delay.
                self._reconnect_delay = config.RECONNECT_DELAY
                heartbeat_task = asyncio.create_task(self._heartbeat_loop(ws))
                try:
                    await self._receive_loop(ws)
                except ConnectionClosed as e:
                    logger.warning("WebSocket 连接关闭: %s", e)
            finally:
                # Also runs when registration fails, so no stale socket or
                # lock is left behind on the instance.
                await self._teardown_connection(heartbeat_task)

    async def _receive_loop(self, ws) -> None:
        """Read frames from ``ws`` and dispatch each decoded message.

        Frames that are not valid JSON, or whose JSON value is not an object,
        are logged and skipped so that one bad frame does not drop the
        connection and cancel every running task.
        """
        async for raw in ws:
            try:
                data = json.loads(raw)
            except (TypeError, ValueError) as e:  # JSONDecodeError is a ValueError
                logger.warning("收到无法解析的消息,已忽略: %s", e)
                continue
            if not isinstance(data, dict):
                logger.warning("收到非对象类型的消息,已忽略: %r", data)
                continue
            await self._handle_message(ws, data)

    async def _teardown_connection(self, heartbeat_task: Optional[asyncio.Task]) -> None:
        """Release per-connection state and stop everything tied to it.

        Sets every cancel event, cancels the heartbeat task and all task
        runners, waits for them to finish, then clears the runner table and
        the stored socket.
        """
        self._send_lock = None
        for event in list(self._cancel_events.values()):
            event.set()
        pending = list(self._task_runners.values())
        if heartbeat_task is not None:
            pending.append(heartbeat_task)
        for task in pending:
            task.cancel()
        if pending:
            # return_exceptions=True: collect cancellations/errors instead of
            # letting the first one abort the teardown.
            await asyncio.gather(*pending, return_exceptions=True)
        self._task_runners.clear()
        self._ws = None

    async def _register(self, ws) -> None:
        """Send the register message and wait for the server's ACK.

        Raises:
            asyncio.TimeoutError: no reply within ``_REGISTER_ACK_TIMEOUT`` s.
            RuntimeError: the reply is not a ``REGISTER_ACK`` message.
        """
        browsers = self._fetch_browser_list()
        self._last_browsers = browsers
        msg = make_msg(
            MsgType.REGISTER,
            worker_id=self.worker_id,
            worker_name=self.worker_name,
            browsers=browsers,
        )
        await self._safe_send(ws, msg)
        # Wait for the ACK.
        ack_raw = await asyncio.wait_for(ws.recv(), timeout=self._REGISTER_ACK_TIMEOUT)
        ack = json.loads(ack_raw)
        if isinstance(ack, dict) and ack.get("type") == MsgType.REGISTER_ACK.value:
            logger.info("注册成功: worker_id=%s", self.worker_id)
            return
        logger.error("注册失败: %s", ack)
        raise RuntimeError(f"注册失败: {ack}")

    # ────────────────────────── heartbeat ──────────────────────────
    async def _heartbeat_loop(self, ws) -> None:
        """Send a heartbeat every ``config.HEARTBEAT_INTERVAL`` seconds.

        Every ``_BROWSER_CHECK_EVERY`` heartbeats the browser list is
        re-fetched and reported if it changed. The loop ends on the first
        send failure, which is logged.
        """
        try:
            while True:
                await asyncio.sleep(config.HEARTBEAT_INTERVAL)
                msg = make_msg(MsgType.HEARTBEAT, worker_id=self.worker_id)
                await self._safe_send(ws, msg)
                logger.debug("心跳已发送")
                self._heartbeat_count += 1
                if self._heartbeat_count >= self._BROWSER_CHECK_EVERY:
                    self._heartbeat_count = 0
                    await self._maybe_send_browser_list_update(ws)
        except asyncio.CancelledError:
            # Normal shutdown path: the connection teardown cancels this task.
            raise
        except Exception as e:
            logger.warning("心跳发送失败,心跳循环退出: %s", e)

    async def _maybe_send_browser_list_update(self, ws) -> None:
        """Fetch the browser list and send BROWSER_LIST_UPDATE if it changed.

        Failures are logged at debug level and otherwise ignored.
        """
        try:
            new_list = self._fetch_browser_list()
            sort_key = self._browser_sort_key
            if sorted(new_list, key=sort_key) != sorted(self._last_browsers, key=sort_key):
                self._last_browsers = new_list
                msg = make_msg(MsgType.BROWSER_LIST_UPDATE, browsers=new_list)
                await self._safe_send(ws, msg)
                logger.info("浏览器列表变更,已上报 %d 个环境", len(new_list))
        except Exception as e:
            logger.debug("检查浏览器列表变更失败: %s", e)

    # ────────────────────────── message handling ──────────────────────────
    async def _handle_message(self, ws, data: dict) -> None:
        """Dispatch one decoded server message by its ``type`` field."""
        msg_type = data.get("type", "")
        if msg_type == MsgType.HEARTBEAT_ACK.value:
            logger.debug("收到心跳 ACK")
        elif msg_type == MsgType.TASK_ASSIGN.value:
            await self._start_task(ws, data)
        elif msg_type == MsgType.TASK_CANCEL.value:
            await self._handle_task_cancel(data.get("task_id", ""))
        elif msg_type == MsgType.TASK_STATUS_QUERY.value:
            await self._handle_task_status_query(ws, data)
        elif msg_type == MsgType.ERROR.value:
            logger.error("服务器错误: %s", data.get("detail", ""))
        else:
            logger.warning("未知消息: %s", msg_type)

    async def _start_task(self, ws, data: dict) -> None:
        """Spawn the task runner without blocking the receive loop.

        Messages without a task id, and ids that are already running, are
        ignored with a warning.
        """
        task_id = self._clean_str(data.get("task_id", ""))
        if not task_id:
            logger.warning("收到无 task_id 的任务消息,已忽略")
            return
        runner = self._task_runners.get(task_id)
        if runner and not runner.done():
            logger.warning("任务 %s 已在执行中,忽略重复派发", task_id)
            return
        self._task_runners[task_id] = asyncio.create_task(self._handle_task(ws, data))

    async def _handle_task(self, ws, data: dict) -> None:
        """Run one assigned task and report its result or error.

        The handler is looked up by ``task_type``. ``params`` is extended with
        ``account_name``, ``worker_id`` and ``bit_api_base`` (without
        overwriting existing keys) and with a ``_cancel_event`` the handler
        can poll. The event and the runner entry are always removed at the
        end.
        """
        task_id = self._clean_str(data.get("task_id", ""))
        task_type = self._clean_str(data.get("task_type", ""))
        account_name = self._clean_str(data.get("account_name", ""))
        params = data.get("params", {})
        if not isinstance(params, dict):
            params = {}

        cancel_event = threading.Event()
        if task_id in self._cancelled_tasks:
            # The cancel arrived before the task started: start it cancelled.
            cancel_event.set()
            self._cancelled_tasks.discard(task_id)
        self._cancel_events[task_id] = cancel_event

        # Inject context for the handler.
        if account_name:
            params.setdefault("account_name", account_name)
        params.setdefault("worker_id", self.worker_id)
        params.setdefault("bit_api_base", self.bit_api.base_url)
        params["_cancel_event"] = cancel_event
        if task_id:
            logger.info("收到任务: %s (type=%s)", task_id, task_type)

        async def progress_cb(tid: str, progress: str) -> None:
            """Best-effort progress report; a failed send never fails the task."""
            msg = make_msg(MsgType.TASK_PROGRESS, task_id=tid, progress=progress)
            try:
                await self._safe_send(ws, msg)
            except Exception as e:
                logger.debug("上报任务 %s 进度失败: %s", tid, e)

        try:
            handler = get_handler(task_type)
            if not handler:
                error_msg = f"不支持的任务类型: {task_type}"
                logger.error(error_msg)
                await self._send_result(ws, task_id, error=error_msg)
                return
            if cancel_event.is_set():
                raise TaskCancelledError("任务已取消")
            result = await handler.execute(task_id, params, progress_cb)
            # A cancel that arrived while the handler ran wins over its result.
            if cancel_event.is_set():
                raise TaskCancelledError("任务已取消")
            await self._send_result(ws, task_id, result=result)
        except TaskCancelledError:
            logger.info("任务 %s 已取消", task_id)
            await self._send_result(ws, task_id, error="任务已取消")
        except asyncio.CancelledError:
            logger.info("任务 %s 执行协程已取消", task_id)
            raise
        except Exception as e:
            logger.error("任务 %s 执行失败: %s", task_id, e, exc_info=True)
            await self._send_result(ws, task_id, error=str(e))
        finally:
            params.pop("_cancel_event", None)
            self._cancel_events.pop(task_id, None)
            self._task_runners.pop(task_id, None)

    async def _handle_task_cancel(self, task_id: str) -> None:
        """Handle a cancel request from the server.

        The id is normalised the same way as on assignment, so a numeric or
        padded id still matches the running task. If the task has a cancel
        event it is set; otherwise the id is remembered so that the task is
        cancelled as soon as it starts.
        """
        task_id = self._clean_str(task_id)
        if not task_id:
            return
        cancel_event = self._cancel_events.get(task_id)
        if cancel_event:
            cancel_event.set()
            logger.info("收到任务取消: %s,已标记取消", task_id)
        else:
            self._cancelled_tasks.add(task_id)
            logger.info("收到任务取消: %s,任务尚未执行,已记录预取消", task_id)

    async def _handle_task_status_query(self, ws, data: dict) -> None:
        """Answer a status probe with a TASK_STATUS_REPORT message.

        With a ``task_id`` the report says whether that task is running;
        without one it says whether any task is running.
        """
        task_id = self._clean_str(data.get("task_id", ""))
        request_id = self._clean_str(data.get("request_id", ""))
        detail = ""
        if task_id:
            runner = self._task_runners.get(task_id)
            running = bool(runner and not runner.done())
            if not running:
                detail = "task_not_running"
        else:
            running = any(r and not r.done() for r in self._task_runners.values())
            if not running:
                detail = "no_running_task"
        msg = make_msg(
            MsgType.TASK_STATUS_REPORT,
            request_id=request_id,
            task_id=task_id,
            running=running,
            detail=detail,
            worker_id=self.worker_id,
        )
        try:
            await self._safe_send(ws, msg)
        except Exception as e:
            logger.debug("发送任务状态回报失败: %s", e)

    async def _send_result(
        self, ws, task_id: str, result=None, error: Optional[str] = None
    ) -> None:
        """Report a task's final result; send failures are logged, not raised."""
        msg = make_msg(MsgType.TASK_RESULT, task_id=task_id, result=result, error=error)
        try:
            await self._safe_send(ws, msg)
            logger.info("任务 %s 结果已上报 (error=%s)", task_id, error)
        except Exception as e:
            logger.error("上报结果失败: %s", e)

    async def _safe_send(self, ws, msg: dict) -> None:
        """Serialise ``msg`` to JSON and send it, holding the send lock if set."""
        payload = json.dumps(msg)
        lock = self._send_lock
        if lock is None:
            # No lock is installed (e.g. during teardown): send directly.
            await ws.send(payload)
            return
        async with lock:
            await ws.send(payload)

    # ────────────────────────── browser list ──────────────────────────
    def _fetch_browser_list(self) -> List[dict]:
        """Return the local browser windows as ``id``/``name``/``remark`` dicts.

        Returns an empty list (and logs a warning) if the lookup fails.
        """
        try:
            items = self.bit_api.list_browsers()
            return [
                {
                    "id": b.get("id", ""),
                    "name": b.get("name", ""),
                    "remark": b.get("remark", ""),
                }
                for b in items
            ]
        except Exception as e:
            logger.warning("获取比特浏览器列表失败(比特浏览器可能未启动): %s", e)
            return []