134 lines
4.5 KiB
Python
134 lines
4.5 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
Worker 注册、状态管理、账号 → Worker 映射。
|
|
全部在内存中,服务重启后 Worker 重新连接即恢复。
|
|
框架无关:存储的是异步 send_json 可调用对象,不依赖具体 WebSocket 实现。
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
from typing import Callable, Coroutine, Dict, Optional
|
|
|
|
from server.config import HEARTBEAT_TIMEOUT
|
|
from server.models import BrowserProfile, WorkerInfo
|
|
|
|
# Type of a Worker's send function: called with a single dict payload and
# returning a coroutine.
SendJsonFn = Callable[[dict], Coroutine]

# Module logger, registered under a fixed name rather than derived from __name__.
logger = logging.getLogger(name="server.worker_manager")
|
|
|
|
|
|
class WorkerManager:
    """Manage every connected Worker.

    All state lives in memory in three dicts:

    * ``_workers``: worker_id -> ``WorkerInfo`` (metadata, online flag,
      heartbeat timestamp, current task).
    * ``_connections``: worker_id -> async ``send_json`` callable.
    * ``_account_map``: lower-cased browser/account name -> worker_id. This
      routing table only contains accounts of workers whose ``online`` flag
      is set.

    Because ``_account_map`` is derived from the online flags and the browser
    lists, it is rebuilt whenever either of them changes.
    """

    def __init__(self) -> None:
        # worker_id -> WorkerInfo
        self._workers: Dict[str, WorkerInfo] = {}
        # worker_id -> async send_json callable
        self._connections: Dict[str, SendJsonFn] = {}
        # account_name (lower-cased) -> worker_id; fast routing table
        self._account_map: Dict[str, str] = {}

    # ─── Registration / unregistration ───

    def register(
        self,
        send_json: SendJsonFn,
        worker_id: str,
        worker_name: str,
        browsers: list[dict],
    ) -> WorkerInfo:
        """Register a Worker as online, replacing any entry with the same id.

        Args:
            send_json: Async callable used to push a dict message to the worker.
            worker_id: Worker identifier; an existing registration with the
                same id is overwritten.
            worker_name: Display name of the worker.
            browsers: Browser profile dicts; each is unpacked as keyword
                arguments into ``BrowserProfile``.

        Returns:
            The freshly created ``WorkerInfo``.
        """
        profiles = [BrowserProfile(**b) for b in browsers]
        # Take one timestamp so last_heartbeat and connected_at agree exactly.
        now = time.time()
        info = WorkerInfo(
            worker_id=worker_id,
            worker_name=worker_name,
            browsers=profiles,
            online=True,
            last_heartbeat=now,
            connected_at=now,
        )
        self._workers[worker_id] = info
        self._connections[worker_id] = send_json
        self._rebuild_account_map()
        logger.info("Worker 注册: %s (%s), 浏览器 %d 个", worker_id, worker_name, len(profiles))
        return info

    def unregister(self, worker_id: str, send_json: Optional[SendJsonFn] = None) -> None:
        """Remove a Worker, its connection and its routing entries.

        Args:
            worker_id: Worker to remove; unknown ids are ignored.
            send_json: Optional guard. When given, the worker is removed only
                if this is the connection currently registered for
                ``worker_id``. This lets the cleanup of a stale connection
                leave alone a newer registration made after the worker
                reconnected under the same id. When omitted, the worker is
                removed unconditionally.
        """
        # Compare with != rather than `is`: bound methods (e.g. ws.send_json)
        # are new objects on every attribute access but compare equal when
        # they wrap the same function on the same object.
        if send_json is not None and self._connections.get(worker_id) != send_json:
            logger.debug("Worker %s 旧连接注销已忽略(连接已被替换)", worker_id)
            return
        self._workers.pop(worker_id, None)
        self._connections.pop(worker_id, None)
        self._rebuild_account_map()
        logger.info("Worker 注销: %s", worker_id)

    # ─── Heartbeat ───

    def heartbeat(self, worker_id: str) -> None:
        """Record a heartbeat and mark the worker online; unknown ids are ignored."""
        info = self._workers.get(worker_id)
        if info is None:
            return
        info.last_heartbeat = time.time()
        if not info.online:
            # The worker was marked offline by the timeout sweep, which removed
            # its accounts from the routing table. Restore them now.
            info.online = True
            self._rebuild_account_map()

    # ─── Browser list updates ───

    def update_browsers(self, worker_id: str, browsers: list[dict]) -> None:
        """Replace a worker's browser profiles and refresh the routing table.

        Args:
            worker_id: Worker to update; unknown ids are ignored.
            browsers: Browser profile dicts, unpacked into ``BrowserProfile``.
        """
        info = self._workers.get(worker_id)
        if info is None:
            return
        info.browsers = [BrowserProfile(**b) for b in browsers]
        self._rebuild_account_map()
        logger.info("Worker %s 浏览器列表更新: %d 个", worker_id, len(info.browsers))

    # ─── Queries ───

    def get_worker(self, worker_id: str) -> Optional[WorkerInfo]:
        """Return the ``WorkerInfo`` for ``worker_id``, or None if unknown."""
        return self._workers.get(worker_id)

    def get_all_workers(self) -> list[WorkerInfo]:
        """Return a snapshot list of all registered workers (online or not)."""
        return list(self._workers.values())

    def get_send_fn(self, worker_id: str) -> Optional[SendJsonFn]:
        """Return the async send_json callable for ``worker_id``, or None."""
        return self._connections.get(worker_id)

    def find_worker_by_account(self, account_name: str) -> Optional[str]:
        """Look up the worker_id serving the browser window named ``account_name``.

        The lookup is case-insensitive and only considers online workers.
        Returns None when no online worker exposes that name.
        """
        return self._account_map.get(account_name.lower())

    def is_online(self, worker_id: str) -> bool:
        """Return True if the worker is registered and currently marked online."""
        info = self._workers.get(worker_id)
        return info is not None and info.online

    # ─── Task occupancy ───

    def set_current_task(self, worker_id: str, task_id: Optional[str]) -> None:
        """Set (or clear with None) the task a worker is running; unknown ids are ignored."""
        info = self._workers.get(worker_id)
        if info is not None:
            info.current_task_id = task_id

    # ─── Periodic sweep (heartbeat-timeout detection) ───

    async def check_heartbeats_loop(self, interval: int = 15) -> None:
        """Background coroutine that marks timed-out workers offline.

        Every ``interval`` seconds, each online worker whose last heartbeat is
        older than ``HEARTBEAT_TIMEOUT`` seconds is marked offline. If any
        worker changed state, its accounts are removed from the routing table.
        Runs forever; cancel the task to stop it.
        """
        while True:
            await asyncio.sleep(interval)
            now = time.time()
            any_timed_out = False
            # No awaits and no dict mutation inside this loop, so iterating
            # the live view is safe.
            for wid, info in self._workers.items():
                if info.online and (now - info.last_heartbeat) > HEARTBEAT_TIMEOUT:
                    info.online = False
                    any_timed_out = True
                    logger.warning("Worker %s 心跳超时,标记为离线", wid)
            if any_timed_out:
                # Offline workers must not remain routable.
                self._rebuild_account_map()

    # ─── Internal ───

    def _rebuild_account_map(self) -> None:
        """Recompute the account -> worker_id routing table from online workers.

        Browsers with an empty name are skipped. If two online workers expose
        the same name (case-insensitively), the one iterated later in
        ``_workers`` wins. The dict is cleared in place, not replaced.
        """
        self._account_map.clear()
        for wid, info in self._workers.items():
            if not info.online:
                continue
            for browser in info.browsers:
                if browser.name:
                    self._account_map[browser.name.lower()] = wid


# Global singleton instance shared by the rest of the server.
worker_manager = WorkerManager()
|