# Listing metadata: 338 lines, 13 KiB, Python
# -*- coding: utf-8 -*-
|
||
"""
Worker WebSocket client.

Responsibilities: connect to the server, register, send heartbeats, receive
tasks, report progress/results, and reconnect after a disconnect.
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import threading
|
||
import time
|
||
from typing import Dict, List, Optional, Set
|
||
|
||
import websockets
|
||
from websockets.exceptions import ConnectionClosed
|
||
|
||
from common.protocol import MsgType, make_msg
|
||
from worker import config
|
||
from worker.bit_browser import BitBrowserAPI
|
||
from worker.tasks.base import TaskCancelledError
|
||
from worker.tasks.registry import get_handler
|
||
|
||
# Module-level logger shared by everything in this file.
logger = logging.getLogger("worker.ws_client")
|
||
|
||
|
||
class WorkerWSClient:
    """WebSocket client that connects a worker process to the server.

    The class connects, registers, sends periodic heartbeats, receives and
    runs tasks, reports progress/results, and reconnects with exponential
    back-off after the connection drops.
    """

    def __init__(
        self,
        server_url: str,
        worker_id: str,
        worker_name: str,
        bit_api_base: str,
    ) -> None:
        """Store connection settings and initialise runtime state.

        Args:
            server_url: WebSocket URL passed to ``websockets.connect``.
            worker_id: Identifier sent in register/heartbeat/status messages.
            worker_name: Name sent in the register message.
            bit_api_base: Base URL handed to ``BitBrowserAPI``.
        """
        self.server_url = server_url
        self.worker_id = worker_id
        self.worker_name = worker_name
        self.bit_api = BitBrowserAPI(bit_api_base)
        self._ws: Optional[websockets.WebSocketClientProtocol] = None
        self._running = False
        self._reconnect_delay = config.RECONNECT_DELAY
        # Last browser list reported to the server (used for change detection).
        self._last_browsers: List[dict] = []
        self._heartbeat_count = 0
        # task_id -> event set when the task should stop.
        self._cancel_events: Dict[str, threading.Event] = {}
        # Cancels received for tasks that have not started yet.
        # NOTE(review): entries are only removed when a task with the same id
        # starts, so cancels for unknown/finished tasks accumulate here.
        self._cancelled_tasks: Set[str] = set()
        # task_id -> asyncio task running ``_handle_task``.
        self._task_runners: Dict[str, asyncio.Task] = {}
        # Serialises ``ws.send`` calls; created per connection.
        self._send_lock: Optional[asyncio.Lock] = None

    # ────────────────────────── Main loop ──────────────────────────

    async def run(self) -> None:
        """Run sessions until ``stop()`` is called, reconnecting on failure.

        After every ended or failed session the loop sleeps for the current
        reconnect delay and then doubles it, capped at
        ``config.RECONNECT_MAX_DELAY``.
        """
        self._running = True
        while self._running:
            try:
                await self._connect_and_loop()
            except Exception as e:
                logger.error("连接异常: %s", e)
            if self._running:
                logger.info("将在 %d 秒后重连...", self._reconnect_delay)
                await asyncio.sleep(self._reconnect_delay)
                # Exponential back-off.
                self._reconnect_delay = min(
                    self._reconnect_delay * 2,
                    config.RECONNECT_MAX_DELAY,
                )

    async def stop(self) -> None:
        """Ask the main loop to exit and close the current socket, if any."""
        self._running = False
        if self._ws:
            await self._ws.close()

    # ────────────────────────── Connection flow ──────────────────────────

    @staticmethod
    def _parse_frame(raw) -> Optional[dict]:
        """Decode one incoming frame into a message dict.

        Args:
            raw: Text or bytes payload received from the socket.

        Returns:
            The decoded JSON object, or ``None`` when the payload is not valid
            JSON or does not decode to a JSON object.
        """
        try:
            data = json.loads(raw)
        except (TypeError, ValueError):
            return None
        return data if isinstance(data, dict) else None

    async def _connect_and_loop(self) -> None:
        """Run one session: connect, register, heartbeat, dispatch messages.

        On exit the heartbeat and every task runner are cancelled and awaited,
        and all pending cancel events are set.
        """
        logger.info("正在连接服务器: %s", self.server_url)
        async with websockets.connect(self.server_url) as ws:
            self._ws = ws
            self._send_lock = asyncio.Lock()
            logger.info("WebSocket 已连接")

            await self._register(ws)
            # Reset the back-off only once registration succeeded; resetting
            # on bare connect would defeat back-off when the server accepts
            # the socket but keeps rejecting the registration.
            self._reconnect_delay = config.RECONNECT_DELAY

            heartbeat_task = asyncio.create_task(self._heartbeat_loop(ws))

            try:
                async for raw in ws:
                    data = self._parse_frame(raw)
                    if data is None:
                        # One bad frame must not tear down the session and
                        # cancel every running task.
                        logger.warning("收到无法解析的消息,已忽略")
                        continue
                    await self._handle_message(ws, data)
            except ConnectionClosed as e:
                logger.warning("WebSocket 连接关闭: %s", e)
            finally:
                heartbeat_task.cancel()
                self._send_lock = None
                for event in list(self._cancel_events.values()):
                    event.set()
                # Snapshot: runners remove themselves from the dict on exit.
                runners = list(self._task_runners.values())
                for runner in runners:
                    runner.cancel()
                try:
                    await asyncio.gather(
                        heartbeat_task, *runners, return_exceptions=True
                    )
                finally:
                    self._task_runners.clear()
                    self._ws = None

    async def _register(self, ws) -> None:
        """Send the register message and wait up to 10 s for the ACK.

        Raises:
            RuntimeError: If the reply is not a ``REGISTER_ACK`` message.
            asyncio.TimeoutError: If no reply arrives within the timeout.
        """
        browsers = self._fetch_browser_list()
        self._last_browsers = browsers
        msg = make_msg(
            MsgType.REGISTER,
            worker_id=self.worker_id,
            worker_name=self.worker_name,
            browsers=browsers,
        )
        await self._safe_send(ws, msg)

        ack_raw = await asyncio.wait_for(ws.recv(), timeout=10)
        ack = self._parse_frame(ack_raw)
        if ack is not None and ack.get("type") == MsgType.REGISTER_ACK.value:
            logger.info("注册成功: worker_id=%s", self.worker_id)
            return

        # Report the raw payload when it could not be decoded to an object.
        rejected = ack if ack is not None else ack_raw
        logger.error("注册失败: %s", rejected)
        raise RuntimeError(f"注册失败: {rejected}")

    # ────────────────────────── Heartbeat ──────────────────────────

    async def _heartbeat_loop(self, ws) -> None:
        """Send a heartbeat every ``config.HEARTBEAT_INTERVAL`` seconds.

        Every third heartbeat also checks whether the browser list changed.
        The loop ends on the first send failure (logged) or on cancellation.
        """
        while True:
            try:
                await asyncio.sleep(config.HEARTBEAT_INTERVAL)
                msg = make_msg(MsgType.HEARTBEAT, worker_id=self.worker_id)
                await self._safe_send(ws, msg)
                logger.debug("心跳已发送")

                self._heartbeat_count += 1
                if self._heartbeat_count >= 3:
                    self._heartbeat_count = 0
                    await self._maybe_send_browser_list_update(ws)
            except asyncio.CancelledError:
                # Let cancellation propagate on every Python version.
                raise
            except Exception as e:
                logger.warning("心跳循环终止: %s", e)
                break

    async def _maybe_send_browser_list_update(self, ws) -> None:
        """Fetch the browser list and send BROWSER_LIST_UPDATE if it changed.

        Failures are logged at debug level and otherwise ignored.
        """

        def sort_key(browser: dict):
            return (browser.get("id", ""), browser.get("name", ""))

        try:
            new_list = self._fetch_browser_list()
            unchanged = sorted(new_list, key=sort_key) == sorted(
                self._last_browsers, key=sort_key
            )
            if unchanged:
                return
            self._last_browsers = new_list
            msg = make_msg(MsgType.BROWSER_LIST_UPDATE, browsers=new_list)
            await self._safe_send(ws, msg)
            logger.info("浏览器列表变更,已上报 %d 个环境", len(new_list))
        except Exception as e:
            logger.debug("检查浏览器列表变更失败: %s", e)

    # ────────────────────────── Message handling ──────────────────────────

    async def _handle_message(self, ws, data: dict) -> None:
        """Dispatch one decoded server message by its ``type`` field."""
        msg_type = data.get("type", "")

        if msg_type == MsgType.HEARTBEAT_ACK.value:
            logger.debug("收到心跳 ACK")

        elif msg_type == MsgType.TASK_ASSIGN.value:
            await self._start_task(ws, data)

        elif msg_type == MsgType.TASK_CANCEL.value:
            # Normalise exactly like the assign path so the ids match.
            task_id = str(data.get("task_id", "")).strip()
            await self._handle_task_cancel(task_id)

        elif msg_type == MsgType.TASK_STATUS_QUERY.value:
            await self._handle_task_status_query(ws, data)

        elif msg_type == MsgType.ERROR.value:
            logger.error("服务器错误: %s", data.get("detail", ""))

        else:
            logger.warning("未知消息: %s", msg_type)

    async def _start_task(self, ws, data: dict) -> None:
        """Spawn a task runner so the receive loop is not blocked.

        Messages without a task id, and duplicates of a task that is still
        running, are ignored.
        """
        task_id = str(data.get("task_id", "")).strip()
        if not task_id:
            logger.warning("收到无 task_id 的任务消息,已忽略")
            return

        runner = self._task_runners.get(task_id)
        if runner and not runner.done():
            logger.warning("任务 %s 已在执行中,忽略重复派发", task_id)
            return

        self._task_runners[task_id] = asyncio.create_task(self._handle_task(ws, data))

    async def _handle_task(self, ws, data: dict) -> None:
        """Execute one assigned task and report its result.

        The handler receives ``params`` augmented with ``account_name`` (when
        given), ``worker_id``, ``bit_api_base`` and a ``_cancel_event``
        (``threading.Event``) that is set when the task is cancelled.
        """
        task_id = str(data.get("task_id", "")).strip()
        task_type = str(data.get("task_type", "")).strip()
        account_name = str(data.get("account_name", "")).strip()
        params = data.get("params", {})
        if not isinstance(params, dict):
            params = {}

        cancel_event = threading.Event()
        if task_id in self._cancelled_tasks:
            # The cancel arrived before this runner started.
            cancel_event.set()
            self._cancelled_tasks.discard(task_id)
        self._cancel_events[task_id] = cancel_event

        # Inject context used by handlers; existing keys win.
        if account_name:
            params.setdefault("account_name", account_name)
        params.setdefault("worker_id", self.worker_id)
        params.setdefault("bit_api_base", self.bit_api.base_url)
        params["_cancel_event"] = cancel_event

        if task_id:
            logger.info("收到任务: %s (type=%s)", task_id, task_type)

        async def progress_cb(tid: str, progress: str) -> None:
            """Best-effort progress report; failures never fail the task."""
            msg = make_msg(MsgType.TASK_PROGRESS, task_id=tid, progress=progress)
            try:
                await self._safe_send(ws, msg)
            except Exception as e:
                logger.debug("上报进度失败: %s", e)

        try:
            # Looked up inside the try so the finally block always cleans up.
            handler = get_handler(task_type)
            if not handler:
                error_msg = f"不支持的任务类型: {task_type}"
                logger.error(error_msg)
                await self._send_result(ws, task_id, error=error_msg)
                return

            if cancel_event.is_set():
                raise TaskCancelledError("任务已取消")
            result = await handler.execute(task_id, params, progress_cb)
            # A cancel that arrived during execution overrides the result.
            if cancel_event.is_set():
                raise TaskCancelledError("任务已取消")
            await self._send_result(ws, task_id, result=result)
        except TaskCancelledError:
            logger.info("任务 %s 已取消", task_id)
            await self._send_result(ws, task_id, error="任务已取消")
        except asyncio.CancelledError:
            logger.info("任务 %s 执行协程已取消", task_id)
            raise
        except Exception as e:
            logger.error("任务 %s 执行失败: %s", task_id, e, exc_info=True)
            await self._send_result(ws, task_id, error=str(e))
        finally:
            params.pop("_cancel_event", None)
            self._cancel_events.pop(task_id, None)
            self._task_runners.pop(task_id, None)

    async def _handle_task_cancel(self, task_id: str) -> None:
        """Handle a server-issued cancel for ``task_id``.

        Sets the task's cancel event if it is registered; otherwise records
        the id so the task is cancelled as soon as it starts.
        """
        if not task_id:
            return

        cancel_event = self._cancel_events.get(task_id)
        if cancel_event:
            cancel_event.set()
            logger.info("收到任务取消: %s,已标记取消", task_id)
        else:
            self._cancelled_tasks.add(task_id)
            logger.info("收到任务取消: %s,任务尚未执行,已记录预取消", task_id)

    async def _handle_task_status_query(self, ws, data: dict) -> None:
        """Answer a status probe with a TASK_STATUS_REPORT message.

        With a ``task_id`` the report covers that task only; without one it
        says whether any task is running.
        """
        task_id = str(data.get("task_id", "")).strip()
        request_id = str(data.get("request_id", "")).strip()

        running = False
        detail = ""
        if task_id:
            runner = self._task_runners.get(task_id)
            running = bool(runner and not runner.done())
            if not running:
                detail = "task_not_running"
        else:
            running = any(r and not r.done() for r in self._task_runners.values())
            if not running:
                detail = "no_running_task"

        msg = make_msg(
            MsgType.TASK_STATUS_REPORT,
            request_id=request_id,
            task_id=task_id,
            running=running,
            detail=detail,
            worker_id=self.worker_id,
        )
        try:
            await self._safe_send(ws, msg)
        except Exception as e:
            logger.debug("发送任务状态回报失败: %s", e)

    async def _send_result(
        self, ws, task_id: str, result=None, error: Optional[str] = None
    ) -> None:
        """Report a task's final result; send failures are logged, not raised."""
        msg = make_msg(MsgType.TASK_RESULT, task_id=task_id, result=result, error=error)
        try:
            await self._safe_send(ws, msg)
            logger.info("任务 %s 结果已上报 (error=%s)", task_id, error)
        except Exception as e:
            logger.error("上报结果失败: %s", e)

    async def _safe_send(self, ws, msg: dict) -> None:
        """JSON-encode ``msg`` and send it, serialised by the send lock.

        When no lock exists (outside an active session) the payload is sent
        directly.
        """
        payload = json.dumps(msg)
        if self._send_lock is None:
            await ws.send(payload)
            return
        async with self._send_lock:
            await ws.send(payload)

    # ────────────────────────── BitBrowser list ──────────────────────────

    def _fetch_browser_list(self) -> List[dict]:
        """Return the local BitBrowser entries as ``id``/``name``/``remark`` dicts.

        Returns an empty list (and logs a warning) when the lookup fails.

        NOTE(review): this is called synchronously from coroutines; if
        ``list_browsers`` performs blocking I/O it will stall the event loop.
        Confirm against ``BitBrowserAPI``.
        """
        try:
            items = self.bit_api.list_browsers()
            return [
                {"id": b.get("id", ""), "name": b.get("name", ""), "remark": b.get("remark", "")}
                for b in items
            ]
        except Exception as e:
            logger.warning("获取比特浏览器列表失败(比特浏览器可能未启动): %s", e)
            return []
|