# Files
# boss_dp/server/core/worker_manager.py
# 2026-02-12 16:27:43 +08:00
#
# 131 lines
# 4.2 KiB
# Python

# -*- coding: utf-8 -*-
"""
Worker 注册、状态管理、账号 → Worker 映射。
全部在内存中,服务重启后 Worker 重新连接即恢复。
"""
from __future__ import annotations
import asyncio
import logging
import time
from typing import Dict, Optional
from fastapi import WebSocket
from server.config import HEARTBEAT_TIMEOUT
from server.models import BrowserProfile, WorkerInfo
# Module-level logger, registered under the fixed name "server.worker_manager".
logger: logging.Logger = logging.getLogger(name="server.worker_manager")
class WorkerManager:
    """Track every connected Worker, its WebSocket, and account -> Worker routing.

    All state lives in memory; after a server restart it is rebuilt as Workers
    reconnect. ``_account_map`` is a derived index over the browser profiles of
    *online* workers only, and it is rebuilt whenever membership, browser lists,
    or online status change, so lookups never route to a worker that is
    currently marked offline.
    """

    def __init__(self) -> None:
        # worker_id -> WorkerInfo
        self._workers: Dict[str, WorkerInfo] = {}
        # worker_id -> WebSocket the worker registered with
        self._connections: Dict[str, WebSocket] = {}
        # lower-cased browser-profile name -> worker_id; fast routing table
        # derived from _workers by _rebuild_account_map()
        self._account_map: Dict[str, str] = {}

    # ─── Registration / unregistration ───

    def register(
        self,
        ws: WebSocket,
        worker_id: str,
        worker_name: str,
        browsers: list[dict],
    ) -> WorkerInfo:
        """Register (or re-register) a worker and make its accounts routable.

        Args:
            ws: The worker's WebSocket connection.
            worker_id: Worker id; an existing entry with the same id is replaced.
            worker_name: Human-readable worker name.
            browsers: Raw browser-profile dicts, each expanded as keyword
                arguments into ``BrowserProfile``.

        Returns:
            The newly created ``WorkerInfo``, marked online.
        """
        profiles = [BrowserProfile(**b) for b in browsers]
        now = time.time()  # single timestamp so last_heartbeat == connected_at
        info = WorkerInfo(
            worker_id=worker_id,
            worker_name=worker_name,
            browsers=profiles,
            online=True,
            last_heartbeat=now,
            connected_at=now,
        )
        self._workers[worker_id] = info
        self._connections[worker_id] = ws
        self._rebuild_account_map()
        logger.info("Worker 注册: %s (%s), 浏览器 %d", worker_id, worker_name, len(profiles))
        return info

    def unregister(self, worker_id: str, ws: Optional[WebSocket] = None) -> None:
        """Remove a worker and drop its accounts from the routing table.

        Args:
            worker_id: Worker to remove. Unknown ids are a no-op (apart from logging).
            ws: Optional connection that is going away. When given and the
                worker has since re-registered with a *different* WebSocket,
                the call is ignored, so a late disconnect of the old connection
                cannot evict the new registration. ``None`` (the default)
                removes unconditionally, as before.
        """
        current = self._connections.get(worker_id)
        if ws is not None and current is not None and current is not ws:
            logger.info("忽略 Worker %s 的过期注销(已使用新连接重新注册)", worker_id)
            return
        self._workers.pop(worker_id, None)
        self._connections.pop(worker_id, None)
        self._rebuild_account_map()
        logger.info("Worker 注销: %s", worker_id)

    # ─── Heartbeat ───

    def heartbeat(self, worker_id: str) -> None:
        """Refresh a worker's heartbeat time and mark it online; unknown ids are ignored."""
        info = self._workers.get(worker_id)
        if info is None:
            return
        info.last_heartbeat = time.time()
        if not info.online:
            info.online = True
            # Offline workers are excluded from the routing table, so their
            # accounts must be re-added when they come back.
            self._rebuild_account_map()

    # ─── Browser list updates ───

    def update_browsers(self, worker_id: str, browsers: list[dict]) -> None:
        """Replace a worker's browser profiles and refresh routing; unknown ids are ignored."""
        info = self._workers.get(worker_id)
        if info:
            info.browsers = [BrowserProfile(**b) for b in browsers]
            self._rebuild_account_map()
            logger.info("Worker %s 浏览器列表更新: %d", worker_id, len(info.browsers))

    # ─── Queries ───

    def get_worker(self, worker_id: str) -> Optional[WorkerInfo]:
        """Return the worker's info, or None if it is not registered."""
        return self._workers.get(worker_id)

    def get_all_workers(self) -> list[WorkerInfo]:
        """Return a snapshot list of all registered workers, online or not."""
        return list(self._workers.values())

    def get_ws(self, worker_id: str) -> Optional[WebSocket]:
        """Return the WebSocket the worker registered with, or None."""
        return self._connections.get(worker_id)

    def find_worker_by_account(self, account_name: str) -> Optional[str]:
        """Return the id of the online worker owning browser profile *account_name*.

        The lookup is case-insensitive. Returns None if no online worker exposes that name.
        """
        return self._account_map.get(account_name.lower())

    def is_online(self, worker_id: str) -> bool:
        """True only if the worker is registered and currently flagged online."""
        info = self._workers.get(worker_id)
        return info is not None and info.online

    # ─── Task assignment ───

    def set_current_task(self, worker_id: str, task_id: Optional[str]) -> None:
        """Record the task a worker is running (None clears it); unknown ids are ignored."""
        info = self._workers.get(worker_id)
        if info:
            info.current_task_id = task_id

    # ─── Periodic sweep (heartbeat timeouts) ───

    async def check_heartbeats_loop(self, interval: int = 15) -> None:
        """Background coroutine: every *interval* seconds, expire silent workers.

        A worker is marked offline when more than ``HEARTBEAT_TIMEOUT`` seconds
        have passed since its last heartbeat. Runs forever; cancel the task to stop it.
        """
        while True:
            await asyncio.sleep(interval)
            for wid in self._expire_stale_workers(time.time(), HEARTBEAT_TIMEOUT):
                logger.warning("Worker %s 心跳超时,标记为离线", wid)

    def _expire_stale_workers(self, now: float, timeout: float) -> list[str]:
        """Mark online workers whose last heartbeat is older than *timeout* as offline.

        Args:
            now: Current time, in the same clock as ``last_heartbeat`` (``time.time()``).
            timeout: Maximum allowed seconds since the last heartbeat.

        Returns:
            The ids newly marked offline, in registration order. The routing
            table is rebuilt when any worker expired, so its accounts stop
            being routable.
        """
        expired = [
            wid
            for wid, info in self._workers.items()
            if info.online and (now - info.last_heartbeat) > timeout
        ]
        for wid in expired:
            self._workers[wid].online = False
        if expired:
            self._rebuild_account_map()
        return expired

    # ─── Internal ───

    def _rebuild_account_map(self) -> None:
        """Recompute the account -> worker routing table from online workers only.

        Profiles with an empty/falsy name are skipped. If several online
        workers expose the same (case-insensitive) name, the worker iterated
        last in ``_workers`` wins.
        """
        self._account_map.clear()  # mutate in place; keep the same dict object
        for wid, info in self._workers.items():
            if not info.online:
                continue
            for b in info.browsers:
                if b.name:
                    self._account_map[b.name.lower()] = wid
# Global singleton: the single WorkerManager instance shared by every importer of this module.
worker_manager: WorkerManager = WorkerManager()