226 lines
8.8 KiB
Python
226 lines
8.8 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
Worker WebSocket 客户端。
|
||
负责:连接服务器、注册、心跳、接收任务、上报进度/结果、断线重连。
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import time
|
||
from typing import List, Optional
|
||
|
||
import websockets
|
||
from websockets.exceptions import ConnectionClosed
|
||
|
||
from common.protocol import MsgType, TaskStatus, make_msg
|
||
from worker import config
|
||
from worker.bit_browser import BitBrowserAPI
|
||
from worker.tasks.registry import get_handler
|
||
|
||
# Module-level logger used by WorkerWSClient; registered under the fixed name
# "worker.ws_client" so log configuration can target this client explicitly.
logger: logging.Logger = logging.getLogger("worker.ws_client")
|
||
|
||
|
||
class WorkerWSClient:
    """WebSocket client for a worker node.

    Connects to the server, registers this worker, sends periodic heartbeats,
    receives task assignments, reports task progress/results, and reconnects
    with exponential backoff whenever the connection is lost.
    """

    def __init__(
        self,
        server_url: str,
        worker_id: str,
        worker_name: str,
        bit_api_base: str,
    ) -> None:
        """Store connection settings and build the BitBrowser API client.

        Args:
            server_url: WebSocket URL passed to ``websockets.connect``.
            worker_id: Identifier sent in REGISTER and HEARTBEAT messages.
            worker_name: Display name sent in the REGISTER message.
            bit_api_base: Base URL used to construct the ``BitBrowserAPI`` client.
        """
        self.server_url = server_url
        self.worker_id = worker_id
        self.worker_name = worker_name
        self.bit_api = BitBrowserAPI(bit_api_base)
        # Active connection while connected, otherwise None (closed by stop()).
        self._ws: Optional[websockets.WebSocketClientProtocol] = None
        # Main-loop flag; cleared by stop().
        self._running = False
        # Current reconnect delay in seconds; doubled after every retry up to
        # config.RECONNECT_MAX_DELAY, reset after a successful registration.
        self._reconnect_delay = config.RECONNECT_DELAY
        # Browser list most recently reported to the server (change detection).
        self._last_browsers: List[dict] = []
        # Heartbeats sent since the last browser-list change check.
        self._heartbeat_count = 0

    # ────────────────────────── Main loop ──────────────────────────

    async def run(self) -> None:
        """Connect, register and process messages until ``stop()`` is called.

        Any exception escaping a connection attempt is logged, then the client
        waits ``self._reconnect_delay`` seconds and retries with exponential
        backoff.
        """
        self._running = True
        while self._running:
            try:
                await self._connect_and_loop()
            except Exception as e:
                logger.error("连接异常: %s", e)

            if not self._running:
                break
            # %s instead of %d: the configured delay may be a float.
            logger.info("将在 %s 秒后重连...", self._reconnect_delay)
            await asyncio.sleep(self._reconnect_delay)
            # Exponential backoff, capped at RECONNECT_MAX_DELAY.
            self._reconnect_delay = min(
                self._reconnect_delay * 2,
                config.RECONNECT_MAX_DELAY,
            )

    async def stop(self) -> None:
        """Stop the main loop and close the current connection, if any."""
        self._running = False
        ws = self._ws
        if ws is not None:
            await ws.close()

    # ────────────────────────── Connection flow ──────────────────────────

    async def _connect_and_loop(self) -> None:
        """Run a single connection: connect, register, then receive until closed.

        Raises whatever connecting or registering raises (e.g. ``RuntimeError``
        when the server rejects the REGISTER); a ``ConnectionClosed`` during
        the receive phase is logged and ends the connection normally.
        """
        logger.info("正在连接服务器: %s", self.server_url)
        async with websockets.connect(self.server_url) as ws:
            self._ws = ws
            try:
                logger.info("WebSocket 已连接")
                await self._register(ws)
                # Reset backoff only after the server accepted the registration,
                # so a server that keeps rejecting us still sees growing delays.
                self._reconnect_delay = config.RECONNECT_DELAY

                heartbeat_task = asyncio.create_task(self._heartbeat_loop(ws))
                try:
                    await self._receive_loop(ws)
                except ConnectionClosed as e:
                    logger.warning("WebSocket 连接关闭: %s", e)
                finally:
                    heartbeat_task.cancel()
                    # Wait for the heartbeat to actually finish so it never
                    # outlives this connection. asyncio.wait does not re-raise
                    # the task's CancelledError, so a cancellation of *this*
                    # coroutine is still propagated rather than swallowed.
                    await asyncio.wait({heartbeat_task})
            finally:
                self._ws = None

    async def _receive_loop(self, ws) -> None:
        """Dispatch incoming messages until the connection closes.

        Frames that are not a JSON object are logged and skipped instead of
        tearing down the whole connection.

        NOTE(review): tasks are awaited inline here, so while a handler runs no
        further messages (TASK_CANCEL, HEARTBEAT_ACK, ...) are read. For long
        tasks, confirm the websockets library's incoming-queue/keepalive
        settings cannot drop the connection in the meantime.
        """
        async for raw in ws:
            data = self._decode_message(raw)
            if data is None:
                logger.warning("收到无法解析的消息,已忽略: %.200r", raw)
                continue
            await self._handle_message(ws, data)

    @staticmethod
    def _decode_message(raw) -> Optional[dict]:
        """Parse a text/bytes frame as JSON; return it only if it is an object.

        Returns None for invalid JSON, undecodable bytes, unsupported input
        types, or JSON values that are not objects (lists, numbers, ...).
        """
        try:
            data = json.loads(raw)
        except (TypeError, ValueError):  # JSONDecodeError/UnicodeDecodeError are ValueErrors
            return None
        return data if isinstance(data, dict) else None

    async def _register(self, ws) -> None:
        """Send REGISTER with the current browser list and wait up to 10 s for the ACK.

        Raises:
            RuntimeError: the reply is not a REGISTER_ACK.
            asyncio.TimeoutError: no reply arrived within 10 seconds.
        """
        browsers = await self._fetch_browser_list_async()
        self._last_browsers = browsers
        msg = make_msg(
            MsgType.REGISTER,
            worker_id=self.worker_id,
            worker_name=self.worker_name,
            browsers=browsers,
        )
        await ws.send(json.dumps(msg))

        ack_raw = await asyncio.wait_for(ws.recv(), timeout=10)
        ack = self._decode_message(ack_raw)
        if ack is not None and ack.get("type") == MsgType.REGISTER_ACK.value:
            logger.info("注册成功: worker_id=%s", self.worker_id)
            return

        # Report the raw frame when it could not be parsed as a JSON object.
        reply = ack if ack is not None else ack_raw
        logger.error("注册失败: %s", reply)
        raise RuntimeError(f"注册失败: {reply}")

    # ────────────────────────── Heartbeat ──────────────────────────

    async def _heartbeat_loop(self, ws) -> None:
        """Send a HEARTBEAT every HEARTBEAT_INTERVAL seconds.

        Every third heartbeat also checks for browser-list changes. The loop
        exits (with a log entry) on the first error, typically a failed send
        on a closed connection; cancellation ends it as well.
        """
        while True:
            try:
                await asyncio.sleep(config.HEARTBEAT_INTERVAL)
                msg = make_msg(MsgType.HEARTBEAT, worker_id=self.worker_id)
                await ws.send(json.dumps(msg))
                logger.debug("心跳已发送")

                self._heartbeat_count += 1
                if self._heartbeat_count >= 3:
                    self._heartbeat_count = 0
                    await self._maybe_send_browser_list_update(ws)
            except Exception as e:
                logger.warning("心跳循环异常退出: %s", e)
                break

    async def _maybe_send_browser_list_update(self, ws) -> None:
        """Fetch the browser list and send BROWSER_LIST_UPDATE if it changed.

        Best-effort: failures are logged at debug level and not raised. The
        cached list is updated only after the update was sent, so a failed
        send is retried on the next check.
        """
        try:
            new_list = await self._fetch_browser_list_async()
            if not self._browser_list_changed(self._last_browsers, new_list):
                return
            msg = make_msg(MsgType.BROWSER_LIST_UPDATE, browsers=new_list)
            await ws.send(json.dumps(msg))
            self._last_browsers = new_list
            logger.info("浏览器列表变更,已上报 %d 个环境", len(new_list))
        except Exception as e:
            logger.debug("检查浏览器列表变更失败: %s", e)

    @staticmethod
    def _browser_list_changed(old: List[dict], new: List[dict]) -> bool:
        """Return True if the two browser lists differ, ignoring element order.

        Each entry is canonicalised as sorted-key JSON, so ``None`` or mixed
        value types compare safely (sorting on raw values would raise) and
        every field, not just id/name, participates in the comparison.
        """
        def canon(items: List[dict]) -> List[str]:
            return sorted(json.dumps(b, sort_keys=True, ensure_ascii=False) for b in items)

        return canon(old) != canon(new)

    # ────────────────────────── Message handling ──────────────────────────

    async def _handle_message(self, ws, data: dict) -> None:
        """Dispatch one decoded server message by its ``type`` field."""
        msg_type = data.get("type", "")

        if msg_type == MsgType.HEARTBEAT_ACK.value:
            logger.debug("收到心跳 ACK")

        elif msg_type == MsgType.TASK_ASSIGN.value:
            await self._handle_task(ws, data)

        elif msg_type == MsgType.TASK_CANCEL.value:
            # Cancellation of a running task is not implemented; only logged.
            task_id = data.get("task_id", "")
            logger.info("收到任务取消: %s(暂不支持中途取消)", task_id)

        elif msg_type == MsgType.ERROR.value:
            logger.error("服务器错误: %s", data.get("detail", ""))

        else:
            logger.warning("未知消息: %s", msg_type)

    async def _handle_task(self, ws, data: dict) -> None:
        """Execute a TASK_ASSIGN message and report its outcome.

        Every path ends with exactly one TASK_RESULT: invalid ``params``,
        an unsupported ``task_type`` and handler exceptions are reported as
        errors instead of propagating and killing the connection.
        """
        task_id = data.get("task_id", "")
        task_type = data.get("task_type", "")
        logger.info("收到任务: %s (type=%s)", task_id, task_type)

        try:
            params = self._build_params(data, self.bit_api.base_url)
        except TypeError as e:
            logger.error("任务 %s 参数无效: %s", task_id, e)
            await self._send_result(ws, task_id, error=str(e))
            return

        handler = get_handler(task_type)
        if not handler:
            error_msg = f"不支持的任务类型: {task_type}"
            logger.error(error_msg)
            await self._send_result(ws, task_id, error=error_msg)
            return

        async def progress_cb(tid: str, progress: str) -> None:
            """Report task progress; best-effort, never fails the task."""
            try:
                msg = make_msg(MsgType.TASK_PROGRESS, task_id=tid, progress=progress)
                await ws.send(json.dumps(msg))
            except Exception as e:
                logger.debug("上报进度失败 (task=%s): %s", tid, e)

        try:
            result = await handler.execute(task_id, params, progress_cb)
            await self._send_result(ws, task_id, result=result)
        except Exception as e:
            logger.error("任务 %s 执行失败: %s", task_id, e, exc_info=True)
            await self._send_result(ws, task_id, error=str(e))

    @staticmethod
    def _build_params(data: dict, bit_api_base: str) -> dict:
        """Return a fresh params dict for a task handler.

        Starts from a copy of ``data["params"]`` (missing or null -> ``{}``),
        then adds ``account_name`` (when present in ``data``) and
        ``bit_api_base`` without overriding keys already present in params.

        Raises:
            TypeError: ``params`` is present but not a JSON object.
        """
        raw = data.get("params")
        if raw is None:
            params: dict = {}
        elif isinstance(raw, dict):
            params = dict(raw)  # copy: don't mutate the incoming message
        else:
            raise TypeError(f"params 必须是 JSON 对象, 实际为 {type(raw).__name__}")

        account_name = data.get("account_name", "")
        if account_name:
            params.setdefault("account_name", account_name)
        params.setdefault("bit_api_base", bit_api_base)
        return params

    async def _send_result(self, ws, task_id: str, result=None, error: Optional[str] = None) -> None:
        """Send the final TASK_RESULT for ``task_id``.

        If ``result`` cannot be serialised to JSON, an error result is sent
        instead, so the server is not left waiting for an answer. A failed
        send is logged, not raised.
        """
        msg = make_msg(MsgType.TASK_RESULT, task_id=task_id, result=result, error=error)
        try:
            payload = json.dumps(msg)
        except (TypeError, ValueError) as e:
            error = f"任务结果无法序列化为 JSON: {e}"
            logger.error("任务 %s %s", task_id, error)
            payload = json.dumps(
                make_msg(MsgType.TASK_RESULT, task_id=task_id, result=None, error=error)
            )

        try:
            await ws.send(payload)
            logger.info("任务 %s 结果已上报 (error=%s)", task_id, error)
        except Exception as e:
            logger.error("上报结果失败: %s", e)

    # ────────────────────────── BitBrowser list ──────────────────────────

    def _fetch_browser_list(self) -> List[dict]:
        """Return the local BitBrowser windows as ``{"id", "name", "remark"}`` dicts.

        Any failure is logged as a warning and yields an empty list.
        """
        try:
            items = self.bit_api.list_browsers()
            return [
                {"id": b.get("id", ""), "name": b.get("name", ""), "remark": b.get("remark", "")}
                for b in items
            ]
        except Exception as e:
            logger.warning("获取比特浏览器列表失败(比特浏览器可能未启动): %s", e)
            return []

    async def _fetch_browser_list_async(self) -> List[dict]:
        """Run the synchronous ``_fetch_browser_list`` in the default executor.

        ``bit_api.list_browsers()`` is a plain (non-async) call, presumably
        an HTTP request; running it off the event loop keeps a slow or
        unreachable BitBrowser from stalling heartbeats and message handling.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._fetch_browser_list)
|