初始化:分布式浏览器控制后台

This commit is contained in:
Your Name
2026-02-12 16:27:43 +08:00
commit 2e9ae4c7d7
25 changed files with 1930 additions and 0 deletions

37
.gitignore vendored Normal file
View File

@@ -0,0 +1,37 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
*.egg-info/
dist/
build/
*.egg
# 虚拟环境
venv/
.venv/
env/
ENV/
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
# 操作系统
Thumbs.db
Desktop.ini
.DS_Store
# 环境变量 / 密钥
.env
.env.*
# 日志
*.log
# Cursor
.cursor/

203
README.md Normal file
View File

@@ -0,0 +1,203 @@
# 分布式浏览器控制后台
通过 **中央服务器 + Worker 代理** 架构,远程控制多台电脑上的比特浏览器,并使用 DrissionPage 执行自动化任务。
---
## 架构概览
```
前端 / Postman 中央服务器(第三台机器) 电脑 A / B
─────────────── ───────────────────── ──────────
REST API ──── HTTP ────→ FastAPI Server Worker Agent
├── Worker Manager ←── WebSocket ──┤
└── Task Dispatcher ──── WebSocket ──→ 比特浏览器
+ DrissionPage
```
- **中央服务器**:接收前端 API 请求,管理 Worker 状态,路由并派发任务
- **Worker 代理**:运行在每台电脑上,通过 WebSocket 连接服务器,接收任务后本地控制比特浏览器 + DrissionPage
---
## 快速开始
### 1. 安装依赖
```bash
pip install -r requirements.txt
```
### 2. 启动中央服务器
在第三台机器(或任一台能被其他机器访问的机器)上运行:
```bash
python -m server.main
```
默认监听 `0.0.0.0:8000`。可通过环境变量修改:
```bash
set SERVER_HOST=0.0.0.0
set SERVER_PORT=8000
python -m server.main
```
### 3. 启动 Worker每台电脑上各运行一个
**电脑 A**
```bash
python -m worker.main --server ws://服务器IP:8000/ws --worker-id pc-a --worker-name "电脑A"
```
**电脑 B**
```bash
python -m worker.main --server ws://服务器IP:8000/ws --worker-id pc-b --worker-name "电脑B"
```
Worker 启动后会自动:
1. 连接中央服务器
2. 上报本机比特浏览器窗口列表
3. 等待接收任务
> 注意:每台电脑需先启动比特浏览器客户端(默认 API 端口 54345
---
## API 接口
### 查看在线 Worker
```
GET /api/workers
```
响应示例:
```json
[
{
"worker_id": "pc-a",
"worker_name": "电脑A",
"browsers": [
{"id": "abc123", "name": "BOSS主号", "remark": "张三"}
],
"online": true,
"current_task_id": null
}
]
```
### 提交任务
```
POST /api/tasks
Content-Type: application/json
{
"task_type": "boss_recruit",
"account_name": "BOSS主号",
"params": {
"job_title": "前端工程师",
"max_greet": 5
}
}
```
路由规则:
- 指定 `worker_id` → 直接发到该 Worker
- 指定 `account_name` → 自动找到拥有该浏览器的 Worker
- 两者都传 → `worker_id` 优先
### 查询任务状态
```
GET /api/tasks/{task_id}
```
### 查询任务列表
```
GET /api/tasks?worker_id=pc-a&status=success&limit=20
```
### 健康检查
```
GET /health
```
---
## 项目结构
```
boss_dp/
├── server/ # 中央服务器
│ ├── main.py # FastAPI 入口(含 WebSocket 端点)
│ ├── config.py # 服务器配置
│ ├── models.py # Pydantic 数据模型
│ ├── api/
│ │ ├── workers.py # Worker 查询 API
│ │ └── tasks.py # 任务提交与查询 API
│ └── core/
│ ├── worker_manager.py # Worker 注册、状态管理、账号映射
│ └── task_dispatcher.py # 任务路由与派发
├── worker/ # Worker 代理(每台电脑部署)
│ ├── main.py # 启动入口
│ ├── config.py # Worker 配置
│ ├── ws_client.py # WebSocket 客户端(心跳、重连)
│ ├── bit_browser.py # 比特浏览器 API 封装
│ ├── browser_control.py # DrissionPage 通用控制封装
│ └── tasks/
│ ├── base.py # 任务处理器基类
│ ├── registry.py # 处理器注册表
│ └── boss_recruit.py # BOSS 直聘招聘任务
├── common/
│ └── protocol.py # 共享消息协议
├── requirements.txt
└── README.md
```
---
## 扩展新任务
1.`worker/tasks/` 下新建文件,继承 `BaseTaskHandler`
2. 实现 `execute` 方法
3.`common/protocol.py``TaskType` 中添加新类型
4.`worker/tasks/registry.py``register_all_handlers()` 中注册
```python
# worker/tasks/my_task.py
from worker.tasks.base import BaseTaskHandler
class MyTaskHandler(BaseTaskHandler):
task_type = "my_task"
async def execute(self, task_id, params, progress_cb):
await progress_cb(task_id, "开始执行...")
# ... 你的自动化逻辑 ...
return {"status": "done"}
```
---
## 配置项
### 服务器 (server/config.py)
| 环境变量 | 默认值 | 说明 |
|---------|--------|------|
| `SERVER_HOST` | `0.0.0.0` | 监听地址 |
| `SERVER_PORT` | `8000` | 监听端口 |
| `API_TOKEN` | (空) | 非空时校验 Authorization Header |
### Worker (worker/config.py)
| 环境变量 | 默认值 | 说明 |
|---------|--------|------|
| `SERVER_WS_URL` | `ws://127.0.0.1:8000/ws` | 服务器 WebSocket 地址 |
| `WORKER_ID` | `worker-1` | Worker 唯一标识 |
| `WORKER_NAME` | `本机` | Worker 显示名称 |
| `BIT_API_BASE` | `http://127.0.0.1:54345` | 比特浏览器本地 API |

1
common/__init__.py Normal file
View File

@@ -0,0 +1 @@
# -*- coding: utf-8 -*-

54
common/protocol.py Normal file
View File

@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
"""
共享消息协议定义。
服务器与 Worker 之间通过 WebSocket 传递 JSON 消息,每条消息包含 type 字段标识消息类型。
"""
from enum import Enum
# ────────────────────────── 消息类型 ──────────────────────────
class MsgType(str, Enum):
"""WebSocket 消息类型枚举str 混入方便 JSON 序列化)。"""
# Worker → Server
REGISTER = "register" # 注册:上报 worker 信息与浏览器列表
HEARTBEAT = "heartbeat" # 心跳
BROWSER_LIST_UPDATE = "browser_list_update" # 浏览器列表变更
TASK_PROGRESS = "task_progress" # 任务进度上报
TASK_RESULT = "task_result" # 任务最终结果
# Server → Worker
REGISTER_ACK = "register_ack" # 注册确认
HEARTBEAT_ACK = "heartbeat_ack" # 心跳确认
TASK_ASSIGN = "task_assign" # 派发任务
TASK_CANCEL = "task_cancel" # 取消任务
# 双向
ERROR = "error" # 错误消息
# ────────────────────────── 任务状态 ──────────────────────────
class TaskStatus(str, Enum):
"""任务生命周期状态。"""
PENDING = "pending" # 已创建,等待派发
DISPATCHED = "dispatched" # 已派发给 Worker
RUNNING = "running" # Worker 正在执行
SUCCESS = "success" # 执行成功
FAILED = "failed" # 执行失败
CANCELLED = "cancelled" # 已取消
# ────────────────────────── 任务类型 ──────────────────────────
class TaskType(str, Enum):
"""可扩展的任务类型。新增任务在此追加即可。"""
BOSS_RECRUIT = "boss_recruit" # BOSS 直聘招聘流程
# ────────────────────────── 辅助函数 ──────────────────────────
def make_msg(msg_type: MsgType, **payload) -> dict:
"""构造一条标准 WebSocket JSON 消息。"""
return {"type": msg_type.value, **payload}

9
requirements.txt Normal file
View File

@@ -0,0 +1,9 @@
# ─── 中央服务器 (server/) ───
fastapi>=0.115.0
uvicorn>=0.34.0
pydantic>=2.0.0
# ─── Worker 代理 (worker/) ───
websockets>=14.0
requests>=2.31.0
DrissionPage>=4.1.0

1
server/__init__.py Normal file
View File

@@ -0,0 +1 @@
# -*- coding: utf-8 -*-

1
server/api/__init__.py Normal file
View File

@@ -0,0 +1 @@
# -*- coding: utf-8 -*-

107
server/api/tasks.py Normal file
View File

@@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-
"""
任务提交与查询 API。
"""
from __future__ import annotations
from typing import List, Optional
from fastapi import APIRouter, HTTPException
from common.protocol import TaskStatus
from server.models import TaskCreate, TaskOut
from server.core.worker_manager import worker_manager
from server.core.task_dispatcher import task_dispatcher
router = APIRouter(prefix="/api/tasks", tags=["tasks"])
@router.post("", response_model=TaskOut, status_code=201)
async def create_task(req: TaskCreate):
"""
提交一个新任务。
路由规则(优先级从高到低):
1. 如果指定了 worker_id → 直接发到该 Worker
2. 如果指定了 account_name → 查找拥有该浏览器的 Worker
3. 两者都没有 → 400 错误
"""
# 确定目标 worker_id
target_worker_id = req.worker_id
if not target_worker_id and req.account_name:
target_worker_id = worker_manager.find_worker_by_account(req.account_name)
if not target_worker_id:
raise HTTPException(
status_code=404,
detail=f"未找到拥有浏览器 '{req.account_name}' 的在线 Worker",
)
if not target_worker_id:
raise HTTPException(
status_code=400,
detail="请指定 worker_id 或 account_name",
)
# 检查 Worker 是否在线
if not worker_manager.is_online(target_worker_id):
raise HTTPException(
status_code=503,
detail=f"Worker {target_worker_id} 不在线",
)
# 创建任务
req.worker_id = target_worker_id
task = task_dispatcher.create_task(req)
# 通过 WebSocket 派发
ws = worker_manager.get_ws(target_worker_id)
if not ws:
task.status = TaskStatus.FAILED
task.error = "Worker WebSocket 连接不存在"
raise HTTPException(status_code=503, detail="Worker WebSocket 连接不存在")
success = await task_dispatcher.dispatch(task, ws.send_json)
if not success:
raise HTTPException(status_code=503, detail=f"任务派发失败: {task.error}")
# 更新 Worker 当前任务
worker_manager.set_current_task(target_worker_id, task.task_id)
return _to_out(task)
@router.get("", response_model=List[TaskOut])
async def list_tasks(
worker_id: Optional[str] = None,
status: Optional[TaskStatus] = None,
limit: int = 50,
):
"""查询任务列表,支持按 worker_id / status 过滤。"""
tasks = task_dispatcher.list_tasks(worker_id=worker_id, status=status, limit=limit)
return [_to_out(t) for t in tasks]
@router.get("/{task_id}", response_model=TaskOut)
async def get_task(task_id: str):
"""查询指定任务的状态和结果。"""
task = task_dispatcher.get_task(task_id)
if not task:
raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")
return _to_out(task)
def _to_out(t) -> TaskOut:
return TaskOut(
task_id=t.task_id,
task_type=t.task_type,
status=t.status,
worker_id=t.worker_id,
account_name=t.account_name,
params=t.params,
progress=t.progress,
result=t.result,
error=t.error,
created_at=t.created_at,
updated_at=t.updated_at,
)

42
server/api/workers.py Normal file
View File

@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
"""
Worker 查询 API。
"""
from fastapi import APIRouter, HTTPException
from typing import List
from server.models import WorkerOut
from server.core.worker_manager import worker_manager
router = APIRouter(prefix="/api/workers", tags=["workers"])
@router.get("", response_model=List[WorkerOut])
async def list_workers():
"""获取所有已注册的 Worker含在线状态与浏览器列表"""
workers = worker_manager.get_all_workers()
return [
WorkerOut(
worker_id=w.worker_id,
worker_name=w.worker_name,
browsers=w.browsers,
online=w.online,
current_task_id=w.current_task_id,
)
for w in workers
]
@router.get("/{worker_id}", response_model=WorkerOut)
async def get_worker(worker_id: str):
"""获取指定 Worker 的详情。"""
w = worker_manager.get_worker(worker_id)
if not w:
raise HTTPException(status_code=404, detail=f"Worker {worker_id} 不存在")
return WorkerOut(
worker_id=w.worker_id,
worker_name=w.worker_name,
browsers=w.browsers,
online=w.online,
current_task_id=w.current_task_id,
)

18
server/config.py Normal file
View File

@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
"""
服务器配置。
可通过环境变量或直接修改此文件调整。
"""
import os
# ─── 服务 ───
HOST: str = os.getenv("SERVER_HOST", "0.0.0.0")
PORT: int = int(os.getenv("SERVER_PORT", "8000"))
# ─── WebSocket ───
WS_PATH: str = "/ws" # Worker 连接端点
HEARTBEAT_INTERVAL: int = 30 # 期望 Worker 心跳间隔(秒)
HEARTBEAT_TIMEOUT: int = 90 # 超时未收到心跳视为离线(秒)
# ─── 安全(可选) ───
API_TOKEN: str = os.getenv("API_TOKEN", "") # 非空时校验 Header: Authorization: Bearer <token>

1
server/core/__init__.py Normal file
View File

@@ -0,0 +1 @@
# -*- coding: utf-8 -*-

View File

@@ -0,0 +1,117 @@
# -*- coding: utf-8 -*-
"""
任务路由与派发。
根据请求中的 worker_id / account_name 找到目标 Worker通过 WebSocket 下发任务。
"""
from __future__ import annotations
import logging
import time
from typing import Dict, List, Optional
from common.protocol import MsgType, TaskStatus, make_msg
from server.models import TaskCreate, TaskInfo
logger = logging.getLogger("server.task_dispatcher")
class TaskDispatcher:
"""管理任务生命周期并派发给 Worker。"""
def __init__(self) -> None:
# task_id → TaskInfo
self._tasks: Dict[str, TaskInfo] = {}
# ─── 创建任务 ───
def create_task(self, req: TaskCreate) -> TaskInfo:
task = TaskInfo(
task_type=req.task_type,
worker_id=req.worker_id,
account_name=req.account_name,
params=req.params,
)
self._tasks[task.task_id] = task
logger.info("任务创建: %s type=%s worker=%s account=%s",
task.task_id, task.task_type, task.worker_id, task.account_name)
return task
# ─── 派发 ───
async def dispatch(self, task: TaskInfo, ws_send) -> bool:
"""
将任务通过 WebSocket 发给目标 Worker。
ws_send: 异步可调用,接受一个 dict 参数。
返回是否发送成功。
"""
msg = make_msg(
MsgType.TASK_ASSIGN,
task_id=task.task_id,
task_type=task.task_type.value,
account_name=task.account_name or "",
params=task.params,
)
try:
await ws_send(msg)
task.status = TaskStatus.DISPATCHED
task.updated_at = time.time()
logger.info("任务 %s 已派发", task.task_id)
return True
except Exception as e:
task.status = TaskStatus.FAILED
task.error = f"派发失败: {e}"
task.updated_at = time.time()
logger.error("任务 %s 派发失败: %s", task.task_id, e)
return False
# ─── 更新状态 ───
def update_progress(self, task_id: str, progress: str) -> None:
task = self._tasks.get(task_id)
if task:
task.status = TaskStatus.RUNNING
task.progress = progress
task.updated_at = time.time()
def complete_task(self, task_id: str, result=None, error: str = None) -> None:
task = self._tasks.get(task_id)
if not task:
return
if error:
task.status = TaskStatus.FAILED
task.error = error
else:
task.status = TaskStatus.SUCCESS
task.result = result
task.updated_at = time.time()
logger.info("任务 %s 完成: status=%s", task_id, task.status)
def cancel_task(self, task_id: str) -> None:
task = self._tasks.get(task_id)
if task and task.status in (TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING):
task.status = TaskStatus.CANCELLED
task.updated_at = time.time()
# ─── 查询 ───
def get_task(self, task_id: str) -> Optional[TaskInfo]:
return self._tasks.get(task_id)
def list_tasks(
self,
worker_id: Optional[str] = None,
status: Optional[TaskStatus] = None,
limit: int = 50,
) -> List[TaskInfo]:
result = list(self._tasks.values())
if worker_id:
result = [t for t in result if t.worker_id == worker_id]
if status:
result = [t for t in result if t.status == status]
# 按创建时间倒序
result.sort(key=lambda t: t.created_at, reverse=True)
return result[:limit]
# 全局单例
task_dispatcher = TaskDispatcher()

View File

@@ -0,0 +1,130 @@
# -*- coding: utf-8 -*-
"""
Worker 注册、状态管理、账号 → Worker 映射。
全部在内存中,服务重启后 Worker 重新连接即恢复。
"""
from __future__ import annotations
import asyncio
import logging
import time
from typing import Dict, Optional
from fastapi import WebSocket
from server.config import HEARTBEAT_TIMEOUT
from server.models import BrowserProfile, WorkerInfo
logger = logging.getLogger("server.worker_manager")
class WorkerManager:
"""管理所有已连接的 Worker。"""
def __init__(self) -> None:
# worker_id → WorkerInfo
self._workers: Dict[str, WorkerInfo] = {}
# worker_id → WebSocket 实例
self._connections: Dict[str, WebSocket] = {}
# account_name(lower) → worker_id 快速路由表
self._account_map: Dict[str, str] = {}
# ─── 注册 / 注销 ───
def register(
self,
ws: WebSocket,
worker_id: str,
worker_name: str,
browsers: list[dict],
) -> WorkerInfo:
profiles = [BrowserProfile(**b) for b in browsers]
info = WorkerInfo(
worker_id=worker_id,
worker_name=worker_name,
browsers=profiles,
online=True,
last_heartbeat=time.time(),
connected_at=time.time(),
)
self._workers[worker_id] = info
self._connections[worker_id] = ws
self._rebuild_account_map()
logger.info("Worker 注册: %s (%s), 浏览器 %d", worker_id, worker_name, len(profiles))
return info
def unregister(self, worker_id: str) -> None:
self._workers.pop(worker_id, None)
self._connections.pop(worker_id, None)
self._rebuild_account_map()
logger.info("Worker 注销: %s", worker_id)
# ─── 心跳 ───
def heartbeat(self, worker_id: str) -> None:
info = self._workers.get(worker_id)
if info:
info.last_heartbeat = time.time()
info.online = True
# ─── 浏览器列表更新 ───
def update_browsers(self, worker_id: str, browsers: list[dict]) -> None:
info = self._workers.get(worker_id)
if info:
info.browsers = [BrowserProfile(**b) for b in browsers]
self._rebuild_account_map()
logger.info("Worker %s 浏览器列表更新: %d", worker_id, len(info.browsers))
# ─── 查询 ───
def get_worker(self, worker_id: str) -> Optional[WorkerInfo]:
return self._workers.get(worker_id)
def get_all_workers(self) -> list[WorkerInfo]:
return list(self._workers.values())
def get_ws(self, worker_id: str) -> Optional[WebSocket]:
return self._connections.get(worker_id)
def find_worker_by_account(self, account_name: str) -> Optional[str]:
"""按浏览器窗口 name 查找对应的 worker_id。"""
return self._account_map.get(account_name.lower())
def is_online(self, worker_id: str) -> bool:
info = self._workers.get(worker_id)
return info is not None and info.online
# ─── 任务占用 ───
def set_current_task(self, worker_id: str, task_id: Optional[str]) -> None:
info = self._workers.get(worker_id)
if info:
info.current_task_id = task_id
# ─── 定时巡检(检测超时离线) ───
async def check_heartbeats_loop(self, interval: int = 15) -> None:
"""后台协程,定期检查心跳超时的 Worker。"""
while True:
await asyncio.sleep(interval)
now = time.time()
for wid, info in list(self._workers.items()):
if info.online and (now - info.last_heartbeat) > HEARTBEAT_TIMEOUT:
info.online = False
logger.warning("Worker %s 心跳超时,标记为离线", wid)
# ─── 内部 ───
def _rebuild_account_map(self) -> None:
self._account_map.clear()
for wid, info in self._workers.items():
if not info.online:
continue
for b in info.browsers:
if b.name:
self._account_map[b.name.lower()] = wid
# 全局单例
worker_manager = WorkerManager()

150
server/main.py Normal file
View File

@@ -0,0 +1,150 @@
# -*- coding: utf-8 -*-
"""
中央服务器入口。
启动方式: python -m server.main
"""
from __future__ import annotations
import asyncio
import json
import logging
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from common.protocol import MsgType, make_msg
from server import config
from server.api.workers import router as workers_router
from server.api.tasks import router as tasks_router
from server.core.worker_manager import worker_manager
from server.core.task_dispatcher import task_dispatcher
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(name)-28s %(levelname)-5s %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("server.main")
# ────────────────────────── FastAPI App ──────────────────────────
app = FastAPI(title="Browser Control Server", version="1.0.0")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
app.include_router(workers_router)
app.include_router(tasks_router)
# ────────────────────────── 健康检查 ──────────────────────────
@app.get("/health")
async def health():
return {"status": "ok", "workers_online": len([w for w in worker_manager.get_all_workers() if w.online])}
# ────────────────────────── WebSocket 端点 ──────────────────────────
@app.websocket(config.WS_PATH)
async def ws_endpoint(ws: WebSocket):
"""
Worker 连接流程:
1. 建立连接
2. 等待第一条 register 消息
3. 持续收发消息(心跳、任务进度、任务结果等)
4. 断开时注销
"""
await ws.accept()
worker_id: str | None = None
try:
# ── 等待注册 ──
raw = await asyncio.wait_for(ws.receive_json(), timeout=30)
if raw.get("type") != MsgType.REGISTER.value:
await ws.send_json(make_msg(MsgType.ERROR, detail="首条消息必须是 register"))
await ws.close(code=4001)
return
worker_id = raw.get("worker_id", "")
worker_name = raw.get("worker_name", worker_id)
browsers = raw.get("browsers", [])
if not worker_id:
await ws.send_json(make_msg(MsgType.ERROR, detail="worker_id 不能为空"))
await ws.close(code=4002)
return
worker_manager.register(ws, worker_id, worker_name, browsers)
await ws.send_json(make_msg(MsgType.REGISTER_ACK, worker_id=worker_id))
logger.info("Worker %s 已连接", worker_id)
# ── 消息循环 ──
while True:
data = await ws.receive_json()
msg_type = data.get("type", "")
if msg_type == MsgType.HEARTBEAT.value:
worker_manager.heartbeat(worker_id)
await ws.send_json(make_msg(MsgType.HEARTBEAT_ACK))
elif msg_type == MsgType.BROWSER_LIST_UPDATE.value:
worker_manager.update_browsers(worker_id, data.get("browsers", []))
elif msg_type == MsgType.TASK_PROGRESS.value:
task_id = data.get("task_id", "")
progress = data.get("progress", "")
task_dispatcher.update_progress(task_id, progress)
logger.info("任务 %s 进度: %s", task_id, progress)
elif msg_type == MsgType.TASK_RESULT.value:
task_id = data.get("task_id", "")
result = data.get("result")
error = data.get("error")
task_dispatcher.complete_task(task_id, result=result, error=error)
# 释放 Worker 任务占用
worker_manager.set_current_task(worker_id, None)
logger.info("任务 %s 已完成", task_id)
else:
logger.warning("未知消息类型: %s (from %s)", msg_type, worker_id)
except WebSocketDisconnect:
logger.info("Worker %s WebSocket 断开", worker_id or "unknown")
except asyncio.TimeoutError:
logger.warning("WebSocket 连接超时(未在 30 秒内注册)")
await ws.close(code=4003)
except Exception as e:
logger.error("WebSocket 处理异常: %s", e, exc_info=True)
finally:
if worker_id:
worker_manager.unregister(worker_id)
# ────────────────────────── 生命周期 ──────────────────────────
@app.on_event("startup")
async def startup():
# 启动心跳巡检后台任务
asyncio.create_task(worker_manager.check_heartbeats_loop())
logger.info("服务器启动: http://%s:%s", config.HOST, config.PORT)
# ────────────────────────── 入口 ──────────────────────────
def main():
uvicorn.run(
"server.main:app",
host=config.HOST,
port=config.PORT,
log_level="info",
)
if __name__ == "__main__":
main()

82
server/models.py Normal file
View File

@@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-
"""
Pydantic 数据模型 —— 用于 REST API 请求 / 响应以及内部状态。
"""
from __future__ import annotations
import time
import uuid
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
from common.protocol import TaskStatus, TaskType
# ────────────────────────── Worker ──────────────────────────
class BrowserProfile(BaseModel):
"""比特浏览器窗口信息Worker 上报)。"""
id: str
name: str = ""
remark: str = ""
class WorkerInfo(BaseModel):
"""一台 Worker 的运行时信息(内存中保存)。"""
worker_id: str
worker_name: str = ""
browsers: List[BrowserProfile] = []
online: bool = True
last_heartbeat: float = Field(default_factory=time.time)
connected_at: float = Field(default_factory=time.time)
current_task_id: Optional[str] = None
class WorkerOut(BaseModel):
"""返回给前端的 Worker 信息。"""
worker_id: str
worker_name: str
browsers: List[BrowserProfile]
online: bool
current_task_id: Optional[str] = None
# ────────────────────────── Task ──────────────────────────
class TaskCreate(BaseModel):
"""前端提交任务的请求体。"""
task_type: TaskType
worker_id: Optional[str] = None # 手动指定机器
account_name: Optional[str] = None # 按比特浏览器窗口名自动路由
params: Dict[str, Any] = {} # 任务参数(如 job_title, max_greet 等)
class TaskInfo(BaseModel):
"""任务完整信息(内存 / 返回前端)。"""
task_id: str = Field(default_factory=lambda: uuid.uuid4().hex[:12])
task_type: TaskType
status: TaskStatus = TaskStatus.PENDING
worker_id: Optional[str] = None
account_name: Optional[str] = None
params: Dict[str, Any] = {}
progress: Optional[str] = None # 最新进度描述
result: Any = None # 最终结果
error: Optional[str] = None
created_at: float = Field(default_factory=time.time)
updated_at: float = Field(default_factory=time.time)
class TaskOut(BaseModel):
"""返回给前端的任务信息。"""
task_id: str
task_type: TaskType
status: TaskStatus
worker_id: Optional[str] = None
account_name: Optional[str] = None
params: Dict[str, Any] = {}
progress: Optional[str] = None
result: Any = None
error: Optional[str] = None
created_at: float
updated_at: float

1
worker/__init__.py Normal file
View File

@@ -0,0 +1 @@
# -*- coding: utf-8 -*-

118
worker/bit_browser.py Normal file
View File

@@ -0,0 +1,118 @@
# -*- coding: utf-8 -*-
"""
比特浏览器 API 封装。
提供:列出窗口、打开窗口、关闭窗口,返回 CDP 地址供 DrissionPage 连接。
"""
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional, Tuple
import requests
logger = logging.getLogger("worker.bit_browser")
class BitBrowserAPI:
"""比特浏览器本地 API 客户端。"""
def __init__(self, base_url: str = "http://127.0.0.1:54345") -> None:
self.base_url = base_url.rstrip("/")
def _post(self, path: str, data: dict, timeout: int = 15) -> dict:
url = f"{self.base_url}{path}"
r = requests.post(url, json=data, timeout=timeout)
r.raise_for_status()
return r.json()
# ─── 列出浏览器窗口 ───
def list_browsers(
self,
name: Optional[str] = None,
remark: Optional[str] = None,
page: int = 0,
page_size: int = 100,
) -> List[Dict[str, Any]]:
"""获取比特浏览器窗口列表。可按 name / remark 模糊筛选。"""
data: Dict[str, Any] = {"page": page, "pageSize": page_size}
if name is not None:
data["name"] = name
if remark is not None:
data["remark"] = remark
res = self._post("/browser/list", data)
if not res.get("success"):
raise RuntimeError(f"list 失败: {res.get('msg', res)}")
list_data = res.get("data")
if not list_data:
return []
if isinstance(list_data, list):
return list_data
if isinstance(list_data, dict) and "list" in list_data:
return list_data["list"]
return []
# ─── 打开浏览器窗口 ───
def open_browser(
self,
browser_id: Optional[str] = None,
name: Optional[str] = None,
remark: Optional[str] = None,
) -> Tuple[str, int, str]:
"""
打开指定浏览器窗口。
返回 (cdp_addr, port, browser_id)。
"""
if browser_id is None:
items = self.list_browsers(name=name, remark=remark, page_size=10)
if not items:
raise RuntimeError("没有匹配的浏览器窗口")
browser_id = items[0].get("id")
if not browser_id:
raise RuntimeError("列表项中没有 id")
res = self._post("/browser/open", {"id": browser_id})
if not res.get("success"):
raise RuntimeError(f"open 失败: {res.get('msg', res)}")
data = res.get("data") or {}
http_addr = data.get("http")
if not http_addr:
raise RuntimeError(f"返回中无 http 地址: {data}")
if ":" in str(http_addr):
host, port_str = str(http_addr).strip().rsplit(":", 1)
port = int(port_str)
cdp_addr = f"{host}:{port}"
else:
port = int(http_addr)
cdp_addr = f"127.0.0.1:{port}"
logger.info("已打开浏览器 %s, CDP: %s", browser_id, cdp_addr)
return cdp_addr, port, browser_id
# ─── 关闭浏览器窗口 ───
def close_browser(self, browser_id: str) -> bool:
"""关闭指定浏览器窗口。"""
try:
res = self._post("/browser/close", {"id": browser_id})
return res.get("success", False)
except Exception as e:
logger.warning("关闭浏览器 %s 失败: %s", browser_id, e)
return False
# ─── 便捷方法:获取供 DrissionPage 连接的信息 ───
def get_browser_for_drission(
self,
browser_id: Optional[str] = None,
name: Optional[str] = None,
remark: Optional[str] = None,
) -> Tuple[str, int]:
"""按需打开浏览器,返回 (addr, port) 供 DrissionPage 使用。"""
cdp_addr, port, _ = self.open_browser(browser_id=browser_id, name=name, remark=remark)
return cdp_addr, port

119
worker/browser_control.py Normal file
View File

@@ -0,0 +1,119 @@
# -*- coding: utf-8 -*-
"""
DrissionPage 浏览器控制基础封装。
提供连接浏览器、拟人化操作等通用能力,供各任务处理器复用。
"""
from __future__ import annotations
import logging
import random
import time
from typing import Optional
logger = logging.getLogger("worker.browser_control")
try:
from DrissionPage import Chromium
except ImportError:
Chromium = None # type: ignore
# ─── 拟人化参数 ───
HUMAN_DELAY_CLICK = (0.2, 0.6)
HUMAN_DELAY_BETWEEN = (0.1, 0.4)
HUMAN_MOVE_DURATION = (0.25, 0.7)
HUMAN_CLICK_OFFSET = 12
def ensure_drission():
"""确保 DrissionPage 已安装。"""
if Chromium is None:
raise RuntimeError("请先安装: pip install DrissionPage")
def connect_browser(port: int = None, addr: str = None):
"""使用 DrissionPage 连接浏览器。"""
ensure_drission()
if port is not None:
return Chromium(port)
if addr:
return Chromium(addr)
raise ValueError("需要 port 或 addr 以连接浏览器")
def human_delay(low: float = 0.2, high: float = 0.6):
"""随机延迟,模拟人类操作间隔。"""
time.sleep(random.uniform(low, high))
def human_click(tab, ele, scroll_first: bool = True) -> bool:
"""
拟人化点击:先滚动到可见(可选),再动作链「移动 → 短暂停顿 → 点击」,
移动带随机耗时与元素内随机偏移。
"""
if ele is None:
return False
try:
if scroll_first and hasattr(ele, "scroll") and getattr(ele.scroll, "to_see", None):
ele.scroll.to_see()
human_delay(0.15, 0.4)
ox = random.randint(-HUMAN_CLICK_OFFSET, HUMAN_CLICK_OFFSET)
oy = random.randint(-HUMAN_CLICK_OFFSET, HUMAN_CLICK_OFFSET)
duration = random.uniform(*HUMAN_MOVE_DURATION)
tab.actions.move_to(ele, offset_x=ox, offset_y=oy, duration=duration)
tab.actions.wait(*HUMAN_DELAY_BETWEEN)
tab.actions.click()
return True
except Exception:
try:
ele.click()
return True
except Exception:
return False
def safe_click(tab, ele) -> bool:
"""多种方式尝试点击元素(拟人 → 普通 → JS确保触发成功。"""
if ele is None:
return False
# 方式1: 拟人点击
if human_click(tab, ele):
return True
# 方式2: 普通模拟点击
try:
ele.click(by_js=False)
return True
except Exception:
pass
# 方式3: JS 点击
try:
ele.click(by_js=True)
return True
except Exception:
pass
return False
def find_element(tab, selectors: list[str], timeout: int = 3):
"""按优先级尝试多个选择器,返回第一个找到的元素。"""
for sel in selectors:
try:
ele = tab.ele(sel, timeout=timeout)
if ele:
return ele
except Exception:
continue
return None
def find_elements(tab, selectors: list[str], timeout: int = 3) -> list:
"""按优先级尝试多个选择器,返回第一组找到的元素列表。"""
for sel in selectors:
try:
eles = tab.eles(sel, timeout=timeout)
if eles and len(eles) > 0:
return eles
except Exception:
continue
return []

20
worker/config.py Normal file
View File

@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
"""
Worker 配置。通过命令行参数或环境变量设置。
"""
import os
# ─── 中央服务器 ───
SERVER_WS_URL: str = os.getenv("SERVER_WS_URL", "ws://127.0.0.1:8000/ws")
# ─── Worker 标识 ───
WORKER_ID: str = os.getenv("WORKER_ID", "worker-1")
WORKER_NAME: str = os.getenv("WORKER_NAME", "本机")
# ─── 比特浏览器 ───
BIT_API_BASE: str = os.getenv("BIT_API_BASE", "http://127.0.0.1:54345")
# ─── WebSocket ───
HEARTBEAT_INTERVAL: int = 25 # 心跳发送间隔(秒)
RECONNECT_DELAY: int = 5 # 断线重连等待(秒)
RECONNECT_MAX_DELAY: int = 60 # 重连最大等待(秒,指数退避上限)

84
worker/main.py Normal file
View File

@@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
"""
Worker 启动入口。
启动方式: python -m worker.main [--server ws://IP:8000/ws] [--worker-id pc-a] [--worker-name 电脑A]
"""
from __future__ import annotations
import argparse
import asyncio
import logging
import sys
from worker import config
from worker.tasks.registry import register_all_handlers
from worker.ws_client import WorkerWSClient
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(name)-28s %(levelname)-5s %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("worker.main")
def parse_args():
parser = argparse.ArgumentParser(description="Browser Control Worker Agent")
parser.add_argument(
"--server",
default=config.SERVER_WS_URL,
help=f"中央服务器 WebSocket 地址 (默认: {config.SERVER_WS_URL})",
)
parser.add_argument(
"--worker-id",
default=config.WORKER_ID,
help=f"Worker ID (默认: {config.WORKER_ID})",
)
parser.add_argument(
"--worker-name",
default=config.WORKER_NAME,
help=f"Worker 名称 (默认: {config.WORKER_NAME})",
)
parser.add_argument(
"--bit-api",
default=config.BIT_API_BASE,
help=f"比特浏览器本地 API 地址 (默认: {config.BIT_API_BASE})",
)
return parser.parse_args()
async def run(args):
# 注册所有任务处理器
register_all_handlers()
logger.info("已注册任务处理器")
# 创建 WebSocket 客户端并运行
client = WorkerWSClient(
server_url=args.server,
worker_id=args.worker_id,
worker_name=args.worker_name,
bit_api_base=args.bit_api,
)
logger.info(
"Worker 启动: id=%s, name=%s, server=%s",
args.worker_id, args.worker_name, args.server,
)
try:
await client.run()
except KeyboardInterrupt:
logger.info("收到中断信号,正在退出...")
await client.stop()
def main():
args = parse_args()
try:
asyncio.run(run(args))
except KeyboardInterrupt:
logger.info("Worker 已退出")
if __name__ == "__main__":
main()

1
worker/tasks/__init__.py Normal file
View File

@@ -0,0 +1 @@
# -*- coding: utf-8 -*-

51
worker/tasks/base.py Normal file
View File

@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-
"""
任务处理器基类。
所有具体任务(如 BOSS 招聘)都继承此类,实现 execute 方法。
"""
from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Coroutine, Dict, Optional
class BaseTaskHandler(ABC):
"""
任务处理器基类。
子类需实现:
execute(task_id, params, progress_callback) -> result
progress_callback 是一个异步函数:
await progress_callback(task_id, "当前进度描述")
"""
# 子类覆盖,声明处理哪个 task_type
task_type: str = ""
def __init__(self) -> None:
self.logger = logging.getLogger(f"worker.tasks.{self.task_type or self.__class__.__name__}")
@abstractmethod
async def execute(
self,
task_id: str,
params: Dict[str, Any],
progress_cb: Callable[[str, str], Coroutine],
) -> Any:
"""
执行任务。
Args:
task_id: 任务 ID
params: 任务参数
progress_cb: 进度上报回调 async def(task_id, progress_text)
Returns:
任务结果(可序列化为 JSON
Raises:
Exception: 执行失败时抛出异常
"""
...

View File

@@ -0,0 +1,335 @@
# -*- coding: utf-8 -*-
"""
BOSS 直聘招聘任务处理器。
复用并重构原 boss_drission.py 的核心流程。
"""
from __future__ import annotations
import asyncio
import random
import re
import time
from typing import Any, Callable, Coroutine, Dict, List
from common.protocol import TaskType
from worker.tasks.base import BaseTaskHandler
from worker.bit_browser import BitBrowserAPI
from worker.browser_control import (
connect_browser,
find_element,
find_elements,
human_click,
human_delay,
safe_click,
)
# ─── 常量 ───
CHAT_INDEX_URL = "https://www.zhipin.com/web/chat/index"
ACTION_DELAY = 1.5
BETWEEN_CHAT_DELAY = 2.5
class BossRecruitHandler(BaseTaskHandler):
"""BOSS 直聘招聘自动化任务。"""
task_type = TaskType.BOSS_RECRUIT.value
async def execute(
self,
task_id: str,
params: Dict[str, Any],
progress_cb: Callable[[str, str], Coroutine],
) -> Any:
"""
执行 BOSS 招聘流程。
params:
- job_title: str 招聘岗位名称
- max_greet: int 最大打招呼人数(默认 5
- account_name: str 比特浏览器窗口名(用于打开浏览器)
- account_id: str 比特浏览器窗口 ID可选优先级高于 name
- bit_api_base: str 比特浏览器 API 地址(可选)
"""
job_title = params.get("job_title", "相关岗位")
max_greet = params.get("max_greet", 5)
account_name = params.get("account_name", "")
account_id = params.get("account_id", "")
bit_api_base = params.get("bit_api_base", "http://127.0.0.1:54345")
await progress_cb(task_id, "正在打开比特浏览器...")
# 在线程池中执行同步的浏览器操作DrissionPage 是同步库)
result = await asyncio.get_event_loop().run_in_executor(
None,
self._run_sync,
task_id, job_title, max_greet, account_name, account_id, bit_api_base, progress_cb,
)
return result
def _run_sync(
self,
task_id: str,
job_title: str,
max_greet: int,
account_name: str,
account_id: str,
bit_api_base: str,
progress_cb: Callable,
) -> dict:
"""同步执行浏览器自动化(在线程池中运行)。"""
# 1. 打开比特浏览器
bit_api = BitBrowserAPI(bit_api_base)
addr, port = bit_api.get_browser_for_drission(
browser_id=account_id or None,
name=account_name or None,
)
self.logger.info("已打开浏览器, CDP: %s (port=%d)", addr, port)
# 2. 连接浏览器
browser = connect_browser(port=port)
tab = browser.latest_tab
# 3. 打开 BOSS 直聘聊天页
tab.get(CHAT_INDEX_URL)
tab.wait.load_start()
human_delay(2.5, 4.0)
# 4. 执行招聘流程
collected = self._recruit_flow(tab, job_title, max_greet)
return {
"job_title": job_title,
"total_processed": len(collected),
"wechat_collected": sum(1 for c in collected if c.get("wechat")),
"details": collected,
}
def _recruit_flow(self, tab, job_title: str, max_greet: int) -> List[dict]:
"""核心招聘流程:遍历聊天列表,打招呼、询问微信号、收集结果。"""
greeting = f"您好,我们正在招【{job_title}】,看到您的经历比较匹配,方便简单聊聊吗?"
ask_wechat = "后续沟通会更及时,您方便留一下您的微信号吗?我这边加您。"
collected = []
# 获取左侧会话列表
items = self._get_conversation_items(tab)
if not items:
self.logger.warning("未找到会话列表元素")
return collected
total = min(len(items), max_greet)
self.logger.info("会话数约 %d,本次处理前 %d", len(items), total)
for i in range(total):
try:
human_delay(max(1.8, BETWEEN_CHAT_DELAY - 0.7), BETWEEN_CHAT_DELAY + 1.0)
items = self._get_conversation_items(tab)
if i >= len(items):
break
item = items[i]
# 获取候选人名称
name = self._get_candidate_name(item, i)
# 点击进入聊天
self._click_conversation(tab, item)
human_delay(1.2, 2.2)
# 等待输入框
inp = self._wait_for_input(tab)
if not inp:
self.logger.info("[%s] 未进入聊天,跳过", name)
continue
# 分析聊天上下文
messages = self._get_chat_messages(tab)
ctx = self._analyze_context(messages, job_title)
# 发招呼
if not ctx["already_greeting"]:
if self._send_message(tab, inp, greeting):
self.logger.info("[%s] 已发送招呼", name)
human_delay(1.5, 2.8)
else:
self.logger.info("[%s] 已有招呼记录,跳过", name)
# 询问微信号
if not ctx["already_asked_wechat"]:
if self._send_message(tab, inp, ask_wechat):
self.logger.info("[%s] 已询问微信号", name)
human_delay(1.5, 2.8)
# 收集微信号
human_delay(1.0, 2.0)
messages = self._get_chat_messages(tab)
ctx = self._analyze_context(messages, job_title)
wechats = ctx["wechats"][:2]
collected.append({
"name": name,
"job": job_title,
"wechat": wechats[0] if wechats else "",
})
except Exception as e:
self.logger.error("处理第 %d 个会话出错: %s", i + 1, e)
continue
return collected
# ─── 辅助方法 ───
def _get_conversation_items(self, tab) -> list:
selectors = [
"css:div.chat-container div.geek-item",
"css:div[role='listitem'] div.geek-item",
"css:.geek-item-wrap .geek-item",
"css:div.geek-item",
"css:div[role='listitem']",
]
return find_elements(tab, selectors, timeout=3)
def _get_candidate_name(self, item, index: int) -> str:
try:
name_el = item.ele("css:.geek-name", timeout=1)
if name_el and name_el.text:
return name_el.text.strip()
except Exception:
pass
try:
if item.text:
return item.text.strip()[:20]
except Exception:
pass
return f"候选人{index + 1}"
def _click_conversation(self, tab, item) -> bool:
if safe_click(tab, item):
return True
try:
parent = item.parent()
if parent and parent.attr("role") == "listitem":
return safe_click(tab, parent)
except Exception:
pass
return False
def _wait_for_input(self, tab, retries: int = 6):
for _ in range(retries):
inp = find_element(tab, [
"css:#boss-chat-editor-input",
"css:.boss-chat-editor-input",
], timeout=1)
if inp:
return inp
human_delay(0.5, 0.9)
return None
def _send_message(self, tab, inp, message: str) -> bool:
try:
human_click(tab, inp)
human_delay(0.25, 0.55)
try:
inp.clear()
except Exception:
pass
human_delay(0.15, 0.4)
inp.input(message)
human_delay(0.4, 0.9)
if self._click_send_button(tab):
return True
inp.input("\n")
human_delay(0.35, 0.7)
return True
except Exception:
return False
def _click_send_button(self, tab) -> bool:
human_delay(0.2, 0.45)
btn = find_element(tab, [
"css:.conversation-editor .submit-content .submit.active",
"css:.conversation-editor div.submit.active",
"css:div.submit-content div.submit.active",
"css:div.submit.active",
"css:.submit-content .submit",
"css:div.submit",
"text:发送",
], timeout=1)
if btn:
if safe_click(tab, btn):
return True
# JS 兜底
scripts = [
"var el = document.querySelector('.conversation-editor .submit.active') || document.querySelector('.conversation-editor .submit'); if(el){ el.click(); return true; } return false;",
"var els = document.querySelectorAll('div[class*=\"submit\"]'); for(var i=0;i<els.length;i++){ if(els[i].textContent.trim()==='发送'){ els[i].click(); return true; } } return false;",
]
for script in scripts:
try:
if tab.run_js(script) is True:
return True
except Exception:
continue
return False
def _get_chat_messages(self, tab) -> List[dict]:
result = []
try:
items = tab.eles("css:.message-item", timeout=2)
if not items:
return result
for e in items[-50:]:
t = (e.text or "").strip()
if not t or "沟通的职位" in t:
continue
role = "friend"
try:
if e.ele("css:.item-boss", timeout=0):
role = "boss"
elif e.ele("css:.item-friend", timeout=0):
role = "friend"
else:
cls = (e.attr("class") or "") + " "
if "item-boss" in cls or ("boss" in cls and "friend" not in cls):
role = "boss"
except Exception:
if any(k in t for k in ("", "岗位", "微信号", "方便留", "加您")):
role = "boss"
result.append({"role": role, "text": t})
except Exception:
pass
return result
def _analyze_context(self, messages: list, job_title: str) -> dict:
boss_texts = [m["text"] for m in messages if m.get("role") == "boss"]
friend_texts = [m["text"] for m in messages if m.get("role") == "friend"]
full_boss = " ".join(boss_texts)
wechats = []
for t in friend_texts:
wechats.extend(self._extract_wechat(t))
wechats = list(dict.fromkeys(wechats))[:3]
return {
"already_greeting": job_title in full_boss or "" in full_boss,
"already_asked_wechat": "微信" in full_boss or "微信号" in full_boss,
"wechats": wechats,
}
@staticmethod
def _extract_wechat(text: str) -> list:
if not text or not text.strip():
return []
found = []
patterns = [
r"微信号[:\s]*([a-zA-Z0-9_\-]{6,20})",
r"微信[:\s]*([a-zA-Z0-9_\-]{6,20})",
r"wx[:\s]*([a-zA-Z0-9_\-]{6,20})",
r"wechat[:\s]*([a-zA-Z0-9_\-]{6,20})",
r"([a-zA-Z][a-zA-Z0-9_\-]{5,19})",
]
for p in patterns:
for m in re.finditer(p, text, re.IGNORECASE):
s = m.group(1).strip() if m.lastindex else m.group(0).strip()
if s and s not in found and len(s) >= 6:
found.append(s)
return found[:3]

44
worker/tasks/registry.py Normal file
View File

@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
"""
任务处理器注册表。
Worker 启动时注册所有可用的 handler收到任务时按 task_type 查找对应 handler。
"""
from __future__ import annotations
import logging
from typing import Dict, Optional, Type
from worker.tasks.base import BaseTaskHandler
logger = logging.getLogger("worker.tasks.registry")
# task_type → handler 实例
_registry: Dict[str, BaseTaskHandler] = {}
def register_handler(handler_cls: Type[BaseTaskHandler]) -> None:
"""注册一个任务处理器类(自动实例化)。"""
instance = handler_cls()
if not instance.task_type:
raise ValueError(f"{handler_cls.__name__} 未设置 task_type")
_registry[instance.task_type] = instance
logger.info("注册任务处理器: %s%s", instance.task_type, handler_cls.__name__)
def get_handler(task_type: str) -> Optional[BaseTaskHandler]:
"""根据 task_type 获取对应的处理器实例。"""
return _registry.get(task_type)
def list_handlers() -> list[str]:
"""列出所有已注册的 task_type。"""
return list(_registry.keys())
def register_all_handlers() -> None:
"""注册所有内置任务处理器。在此函数中 import 并注册。"""
from worker.tasks.boss_recruit import BossRecruitHandler
register_handler(BossRecruitHandler)
# 未来扩展:在此处添加新的 handler
# from worker.tasks.xxx import XxxHandler
# register_handler(XxxHandler)

204
worker/ws_client.py Normal file
View File

@@ -0,0 +1,204 @@
# -*- coding: utf-8 -*-
"""
Worker WebSocket 客户端。
负责:连接服务器、注册、心跳、接收任务、上报进度/结果、断线重连。
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from typing import List, Optional
import websockets
from websockets.exceptions import ConnectionClosed
from common.protocol import MsgType, TaskStatus, make_msg
from worker import config
from worker.bit_browser import BitBrowserAPI
from worker.tasks.registry import get_handler
logger = logging.getLogger("worker.ws_client")
class WorkerWSClient:
"""Worker WebSocket 客户端。"""
def __init__(
self,
server_url: str,
worker_id: str,
worker_name: str,
bit_api_base: str,
) -> None:
self.server_url = server_url
self.worker_id = worker_id
self.worker_name = worker_name
self.bit_api = BitBrowserAPI(bit_api_base)
self._ws: Optional[websockets.WebSocketClientProtocol] = None
self._running = False
self._reconnect_delay = config.RECONNECT_DELAY
# ────────────────────────── 主循环 ──────────────────────────
async def run(self) -> None:
"""主循环:连接 → 注册 → 收发消息;断线则重连。"""
self._running = True
while self._running:
try:
await self._connect_and_loop()
except Exception as e:
logger.error("连接异常: %s", e)
if self._running:
logger.info("将在 %d 秒后重连...", self._reconnect_delay)
await asyncio.sleep(self._reconnect_delay)
# 指数退避
self._reconnect_delay = min(
self._reconnect_delay * 2,
config.RECONNECT_MAX_DELAY,
)
async def stop(self) -> None:
self._running = False
if self._ws:
await self._ws.close()
# ────────────────────────── 连接流程 ──────────────────────────
async def _connect_and_loop(self) -> None:
logger.info("正在连接服务器: %s", self.server_url)
async with websockets.connect(self.server_url) as ws:
self._ws = ws
self._reconnect_delay = config.RECONNECT_DELAY # 连接成功重置退避
logger.info("WebSocket 已连接")
# 注册
await self._register(ws)
# 启动心跳协程
heartbeat_task = asyncio.create_task(self._heartbeat_loop(ws))
try:
# 消息接收循环
async for raw in ws:
data = json.loads(raw)
await self._handle_message(ws, data)
except ConnectionClosed as e:
logger.warning("WebSocket 连接关闭: %s", e)
finally:
heartbeat_task.cancel()
self._ws = None
async def _register(self, ws) -> None:
"""发送注册消息。"""
browsers = self._fetch_browser_list()
msg = make_msg(
MsgType.REGISTER,
worker_id=self.worker_id,
worker_name=self.worker_name,
browsers=browsers,
)
await ws.send(json.dumps(msg))
# 等待 ACK
ack_raw = await asyncio.wait_for(ws.recv(), timeout=10)
ack = json.loads(ack_raw)
if ack.get("type") == MsgType.REGISTER_ACK.value:
logger.info("注册成功: worker_id=%s", self.worker_id)
else:
logger.error("注册失败: %s", ack)
raise RuntimeError(f"注册失败: {ack}")
# ────────────────────────── 心跳 ──────────────────────────
async def _heartbeat_loop(self, ws) -> None:
"""定期发送心跳。"""
while True:
try:
await asyncio.sleep(config.HEARTBEAT_INTERVAL)
msg = make_msg(MsgType.HEARTBEAT, worker_id=self.worker_id)
await ws.send(json.dumps(msg))
logger.debug("心跳已发送")
except Exception:
break
# ────────────────────────── 消息处理 ──────────────────────────
async def _handle_message(self, ws, data: dict) -> None:
msg_type = data.get("type", "")
if msg_type == MsgType.HEARTBEAT_ACK.value:
logger.debug("收到心跳 ACK")
elif msg_type == MsgType.TASK_ASSIGN.value:
await self._handle_task(ws, data)
elif msg_type == MsgType.TASK_CANCEL.value:
task_id = data.get("task_id", "")
logger.info("收到任务取消: %s(暂不支持中途取消)", task_id)
elif msg_type == MsgType.ERROR.value:
logger.error("服务器错误: %s", data.get("detail", ""))
else:
logger.warning("未知消息: %s", msg_type)
async def _handle_task(self, ws, data: dict) -> None:
"""接收并执行任务。"""
task_id = data.get("task_id", "")
task_type = data.get("task_type", "")
account_name = data.get("account_name", "")
params = data.get("params", {})
# 将 account_name 注入 params供 handler 使用)
if account_name:
params.setdefault("account_name", account_name)
params.setdefault("bit_api_base", self.bit_api.base_url)
logger.info("收到任务: %s (type=%s)", task_id, task_type)
handler = get_handler(task_type)
if not handler:
error_msg = f"不支持的任务类型: {task_type}"
logger.error(error_msg)
await self._send_result(ws, task_id, error=error_msg)
return
# 上报进度的回调
async def progress_cb(tid: str, progress: str):
msg = make_msg(MsgType.TASK_PROGRESS, task_id=tid, progress=progress)
try:
await ws.send(json.dumps(msg))
except Exception:
pass
# 执行任务
try:
result = await handler.execute(task_id, params, progress_cb)
await self._send_result(ws, task_id, result=result)
except Exception as e:
logger.error("任务 %s 执行失败: %s", task_id, e, exc_info=True)
await self._send_result(ws, task_id, error=str(e))
async def _send_result(self, ws, task_id: str, result=None, error: str = None) -> None:
"""上报任务最终结果。"""
msg = make_msg(MsgType.TASK_RESULT, task_id=task_id, result=result, error=error)
try:
await ws.send(json.dumps(msg))
logger.info("任务 %s 结果已上报 (error=%s)", task_id, error)
except Exception as e:
logger.error("上报结果失败: %s", e)
# ────────────────────────── 比特浏览器列表 ──────────────────────────
def _fetch_browser_list(self) -> List[dict]:
"""获取本机比特浏览器窗口列表(用于注册时上报)。"""
try:
items = self.bit_api.list_browsers()
return [
{"id": b.get("id", ""), "name": b.get("name", ""), "remark": b.get("remark", "")}
for b in items
]
except Exception as e:
logger.warning("获取比特浏览器列表失败(比特浏览器可能未启动): %s", e)
return []