# -*- coding: utf-8 -*-
"""
Worker registration, status tracking, and account → worker routing.

All state is held in memory; after a service restart it is restored as the
workers reconnect and register again.
"""
from __future__ import annotations

import asyncio
import logging
import time
from typing import Dict, Optional

from fastapi import WebSocket

from server.config import HEARTBEAT_TIMEOUT
from server.models import BrowserProfile, WorkerInfo

logger = logging.getLogger("server.worker_manager")


class WorkerManager:
    """Manage every connected worker and route account names to workers."""

    def __init__(self) -> None:
        # worker_id → WorkerInfo
        self._workers: Dict[str, WorkerInfo] = {}
        # worker_id → WebSocket instance
        self._connections: Dict[str, WebSocket] = {}
        # lower-cased browser profile name → worker_id (fast routing table);
        # only contains profiles of workers currently marked online.
        self._account_map: Dict[str, str] = {}

    # ─── Register / unregister ───

    def register(
        self,
        ws: WebSocket,
        worker_id: str,
        worker_name: str,
        browsers: list[dict],
    ) -> WorkerInfo:
        """Register (or re-register) a worker and its WebSocket connection.

        Args:
            ws: The WebSocket the worker is connected through.
            worker_id: Unique identifier of the worker; an existing entry
                with the same id is overwritten.
            worker_name: Human-readable worker name.
            browsers: Raw browser profile dicts; each one is expanded as
                keyword arguments into ``BrowserProfile``.

        Returns:
            The freshly created ``WorkerInfo`` stored for this worker.
        """
        profiles = [BrowserProfile(**b) for b in browsers]
        # Use one timestamp so connected_at and the first heartbeat agree.
        now = time.time()
        info = WorkerInfo(
            worker_id=worker_id,
            worker_name=worker_name,
            browsers=profiles,
            online=True,
            last_heartbeat=now,
            connected_at=now,
        )
        self._workers[worker_id] = info
        self._connections[worker_id] = ws
        self._rebuild_account_map()
        logger.info("Worker 注册: %s (%s), 浏览器 %d 个", worker_id, worker_name, len(profiles))
        return info

    def unregister(self, worker_id: str) -> None:
        """Forget a worker and its connection, then refresh the routing table.

        Unknown ids are ignored (the log line is still emitted).
        """
        self._workers.pop(worker_id, None)
        self._connections.pop(worker_id, None)
        self._rebuild_account_map()
        logger.info("Worker 注销: %s", worker_id)

    # ─── Heartbeat ───

    def heartbeat(self, worker_id: str) -> None:
        """Record a heartbeat for ``worker_id`` and mark it online.

        Unknown ids are ignored. If the worker had been marked offline, the
        routing table is rebuilt so its accounts become routable again; the
        table may have been rebuilt without them while it was offline.
        """
        info = self._workers.get(worker_id)
        if not info:
            return
        was_online = info.online
        info.last_heartbeat = time.time()
        info.online = True
        if not was_online:
            self._rebuild_account_map()

    # ─── Browser list update ───

    def update_browsers(self, worker_id: str, browsers: list[dict]) -> None:
        """Replace the browser profile list of a known worker.

        Args:
            worker_id: Worker whose profile list is replaced; unknown ids
                are ignored.
            browsers: Raw browser profile dicts; each one is expanded as
                keyword arguments into ``BrowserProfile``.
        """
        info = self._workers.get(worker_id)
        if not info:
            return
        info.browsers = [BrowserProfile(**b) for b in browsers]
        self._rebuild_account_map()
        logger.info("Worker %s 浏览器列表更新: %d 个", worker_id, len(info.browsers))

    # ─── Queries ───

    def get_worker(self, worker_id: str) -> Optional[WorkerInfo]:
        """Return the ``WorkerInfo`` for ``worker_id``, or ``None`` if unknown."""
        return self._workers.get(worker_id)

    def get_all_workers(self) -> list[WorkerInfo]:
        """Return a new list holding every registered ``WorkerInfo``."""
        return list(self._workers.values())

    def get_ws(self, worker_id: str) -> Optional[WebSocket]:
        """Return the WebSocket stored for ``worker_id``, or ``None`` if unknown."""
        return self._connections.get(worker_id)

    def find_worker_by_account(self, account_name: str) -> Optional[str]:
        """Look up the worker_id by browser window name (case-insensitive).

        Returns ``None`` when no online worker exposes a profile with that name.
        """
        return self._account_map.get(account_name.lower())

    def is_online(self, worker_id: str) -> bool:
        """Return ``True`` only if the worker is registered and marked online."""
        info = self._workers.get(worker_id)
        return info is not None and info.online

    # ─── Task occupancy ───

    def set_current_task(self, worker_id: str, task_id: Optional[str]) -> None:
        """Set ``current_task_id`` on a known worker; unknown ids are ignored.

        Args:
            worker_id: Worker to update.
            task_id: Value stored as the worker's ``current_task_id``
                (may be ``None``).
        """
        info = self._workers.get(worker_id)
        if info:
            info.current_task_id = task_id

    # ─── Periodic sweep (detect heartbeat timeouts) ───

    async def check_heartbeats_loop(self, interval: int = 15) -> None:
        """Background coroutine that periodically flags timed-out workers.

        Every ``interval`` seconds, each online worker whose last heartbeat is
        older than ``HEARTBEAT_TIMEOUT`` is marked offline. The worker stays
        registered; a later heartbeat brings it back online. Runs forever.

        Args:
            interval: Seconds to sleep between sweeps.
        """
        while True:
            await asyncio.sleep(interval)
            now = time.time()
            any_timed_out = False
            # Iterate over a snapshot of the registry.
            for wid, info in list(self._workers.items()):
                if info.online and (now - info.last_heartbeat) > HEARTBEAT_TIMEOUT:
                    info.online = False
                    any_timed_out = True
                    logger.warning("Worker %s 心跳超时,标记为离线", wid)
            # Keep the routing table consistent with the online flags: it is
            # defined to contain only online workers, so drop the accounts of
            # workers that just timed out (one rebuild per sweep).
            if any_timed_out:
                self._rebuild_account_map()

    # ─── Internal ───

    def _rebuild_account_map(self) -> None:
        """Recompute the name → worker_id table from the online workers.

        Offline workers and profiles with an empty name are skipped. When
        several online workers expose the same (lower-cased) name, the worker
        that comes later in registry iteration order wins.
        """
        self._account_map.clear()
        for wid, info in self._workers.items():
            if not info.online:
                continue
            for b in info.browsers:
                if b.name:
                    self._account_map[b.name.lower()] = wid


# Global singleton
worker_manager = WorkerManager()