初始化:分布式浏览器控制后台
This commit is contained in:
37
.gitignore
vendored
Normal file
37
.gitignore
vendored
Normal file
@@ -0,0 +1,37 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
*.egg-info/
|
||||
dist/
|
||||
build/
|
||||
*.egg
|
||||
|
||||
# 虚拟环境
|
||||
venv/
|
||||
.venv/
|
||||
env/
|
||||
ENV/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# 操作系统
|
||||
Thumbs.db
|
||||
Desktop.ini
|
||||
.DS_Store
|
||||
|
||||
# 环境变量 / 密钥
|
||||
.env
|
||||
.env.*
|
||||
|
||||
# 日志
|
||||
*.log
|
||||
|
||||
# Cursor
|
||||
.cursor/
|
||||
203
README.md
Normal file
203
README.md
Normal file
@@ -0,0 +1,203 @@
|
||||
# 分布式浏览器控制后台
|
||||
|
||||
通过 **中央服务器 + Worker 代理** 架构,远程控制多台电脑上的比特浏览器,并使用 DrissionPage 执行自动化任务。
|
||||
|
||||
---
|
||||
|
||||
## 架构概览
|
||||
|
||||
```
|
||||
前端 / Postman 中央服务器(第三台机器) 电脑 A / B
|
||||
─────────────── ───────────────────── ──────────
|
||||
REST API ──── HTTP ────→ FastAPI Server Worker Agent
|
||||
├── Worker Manager ←── WebSocket ──┤
|
||||
└── Task Dispatcher ──── WebSocket ──→ 比特浏览器
|
||||
+ DrissionPage
|
||||
```
|
||||
|
||||
- **中央服务器**:接收前端 API 请求,管理 Worker 状态,路由并派发任务
|
||||
- **Worker 代理**:运行在每台电脑上,通过 WebSocket 连接服务器,接收任务后本地控制比特浏览器 + DrissionPage
|
||||
|
||||
---
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 1. 安装依赖
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 2. 启动中央服务器
|
||||
|
||||
在第三台机器(或任一台能被其他机器访问的机器)上运行:
|
||||
|
||||
```bash
|
||||
python -m server.main
|
||||
```
|
||||
|
||||
默认监听 `0.0.0.0:8000`。可通过环境变量修改:
|
||||
|
||||
```bash
|
||||
set SERVER_HOST=0.0.0.0
|
||||
set SERVER_PORT=8000
|
||||
python -m server.main
|
||||
```
|
||||
|
||||
### 3. 启动 Worker(每台电脑上各运行一个)
|
||||
|
||||
**电脑 A:**
|
||||
```bash
|
||||
python -m worker.main --server ws://服务器IP:8000/ws --worker-id pc-a --worker-name "电脑A"
|
||||
```
|
||||
|
||||
**电脑 B:**
|
||||
```bash
|
||||
python -m worker.main --server ws://服务器IP:8000/ws --worker-id pc-b --worker-name "电脑B"
|
||||
```
|
||||
|
||||
Worker 启动后会自动:
|
||||
1. 连接中央服务器
|
||||
2. 上报本机比特浏览器窗口列表
|
||||
3. 等待接收任务
|
||||
|
||||
> 注意:每台电脑需先启动比特浏览器客户端(默认 API 端口 54345)。
|
||||
|
||||
---
|
||||
|
||||
## API 接口
|
||||
|
||||
### 查看在线 Worker
|
||||
|
||||
```
|
||||
GET /api/workers
|
||||
```
|
||||
|
||||
响应示例:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"worker_id": "pc-a",
|
||||
"worker_name": "电脑A",
|
||||
"browsers": [
|
||||
{"id": "abc123", "name": "BOSS主号", "remark": "张三"}
|
||||
],
|
||||
"online": true,
|
||||
"current_task_id": null
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
### 提交任务
|
||||
|
||||
```
|
||||
POST /api/tasks
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"task_type": "boss_recruit",
|
||||
"account_name": "BOSS主号",
|
||||
"params": {
|
||||
"job_title": "前端工程师",
|
||||
"max_greet": 5
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
路由规则:
|
||||
- 指定 `worker_id` → 直接发到该 Worker
|
||||
- 指定 `account_name` → 自动找到拥有该浏览器的 Worker
|
||||
- 两者都传 → `worker_id` 优先
|
||||
|
||||
### 查询任务状态
|
||||
|
||||
```
|
||||
GET /api/tasks/{task_id}
|
||||
```
|
||||
|
||||
### 查询任务列表
|
||||
|
||||
```
|
||||
GET /api/tasks?worker_id=pc-a&status=success&limit=20
|
||||
```
|
||||
|
||||
### 健康检查
|
||||
|
||||
```
|
||||
GET /health
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
boss_dp/
|
||||
├── server/ # 中央服务器
|
||||
│ ├── main.py # FastAPI 入口(含 WebSocket 端点)
|
||||
│ ├── config.py # 服务器配置
|
||||
│ ├── models.py # Pydantic 数据模型
|
||||
│ ├── api/
|
||||
│ │ ├── workers.py # Worker 查询 API
|
||||
│ │ └── tasks.py # 任务提交与查询 API
|
||||
│ └── core/
|
||||
│ ├── worker_manager.py # Worker 注册、状态管理、账号映射
|
||||
│ └── task_dispatcher.py # 任务路由与派发
|
||||
├── worker/ # Worker 代理(每台电脑部署)
|
||||
│ ├── main.py # 启动入口
|
||||
│ ├── config.py # Worker 配置
|
||||
│ ├── ws_client.py # WebSocket 客户端(心跳、重连)
|
||||
│ ├── bit_browser.py # 比特浏览器 API 封装
|
||||
│ ├── browser_control.py # DrissionPage 通用控制封装
|
||||
│ └── tasks/
|
||||
│ ├── base.py # 任务处理器基类
|
||||
│ ├── registry.py # 处理器注册表
|
||||
│ └── boss_recruit.py # BOSS 直聘招聘任务
|
||||
├── common/
|
||||
│ └── protocol.py # 共享消息协议
|
||||
├── requirements.txt
|
||||
└── README.md
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 扩展新任务
|
||||
|
||||
1. 在 `worker/tasks/` 下新建文件,继承 `BaseTaskHandler`
|
||||
2. 实现 `execute` 方法
|
||||
3. 在 `common/protocol.py` 的 `TaskType` 中添加新类型
|
||||
4. 在 `worker/tasks/registry.py` 的 `register_all_handlers()` 中注册
|
||||
|
||||
```python
|
||||
# worker/tasks/my_task.py
|
||||
from worker.tasks.base import BaseTaskHandler
|
||||
|
||||
class MyTaskHandler(BaseTaskHandler):
|
||||
task_type = "my_task"
|
||||
|
||||
async def execute(self, task_id, params, progress_cb):
|
||||
await progress_cb(task_id, "开始执行...")
|
||||
# ... 你的自动化逻辑 ...
|
||||
return {"status": "done"}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 配置项
|
||||
|
||||
### 服务器 (server/config.py)
|
||||
|
||||
| 环境变量 | 默认值 | 说明 |
|
||||
|---------|--------|------|
|
||||
| `SERVER_HOST` | `0.0.0.0` | 监听地址 |
|
||||
| `SERVER_PORT` | `8000` | 监听端口 |
|
||||
| `API_TOKEN` | (空) | 非空时校验 Authorization Header |
|
||||
|
||||
### Worker (worker/config.py)
|
||||
|
||||
| 环境变量 | 默认值 | 说明 |
|
||||
|---------|--------|------|
|
||||
| `SERVER_WS_URL` | `ws://127.0.0.1:8000/ws` | 服务器 WebSocket 地址 |
|
||||
| `WORKER_ID` | `worker-1` | Worker 唯一标识 |
|
||||
| `WORKER_NAME` | `本机` | Worker 显示名称 |
|
||||
| `BIT_API_BASE` | `http://127.0.0.1:54345` | 比特浏览器本地 API |
|
||||
1
common/__init__.py
Normal file
1
common/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
54
common/protocol.py
Normal file
54
common/protocol.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
共享消息协议定义。
|
||||
服务器与 Worker 之间通过 WebSocket 传递 JSON 消息,每条消息包含 type 字段标识消息类型。
|
||||
"""
|
||||
from enum import Enum
|
||||
|
||||
|
||||
# ────────────────────────── 消息类型 ──────────────────────────
|
||||
|
||||
class MsgType(str, Enum):
|
||||
"""WebSocket 消息类型枚举(str 混入方便 JSON 序列化)。"""
|
||||
|
||||
# Worker → Server
|
||||
REGISTER = "register" # 注册:上报 worker 信息与浏览器列表
|
||||
HEARTBEAT = "heartbeat" # 心跳
|
||||
BROWSER_LIST_UPDATE = "browser_list_update" # 浏览器列表变更
|
||||
TASK_PROGRESS = "task_progress" # 任务进度上报
|
||||
TASK_RESULT = "task_result" # 任务最终结果
|
||||
|
||||
# Server → Worker
|
||||
REGISTER_ACK = "register_ack" # 注册确认
|
||||
HEARTBEAT_ACK = "heartbeat_ack" # 心跳确认
|
||||
TASK_ASSIGN = "task_assign" # 派发任务
|
||||
TASK_CANCEL = "task_cancel" # 取消任务
|
||||
|
||||
# 双向
|
||||
ERROR = "error" # 错误消息
|
||||
|
||||
|
||||
# ────────────────────────── 任务状态 ──────────────────────────
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
"""任务生命周期状态。"""
|
||||
PENDING = "pending" # 已创建,等待派发
|
||||
DISPATCHED = "dispatched" # 已派发给 Worker
|
||||
RUNNING = "running" # Worker 正在执行
|
||||
SUCCESS = "success" # 执行成功
|
||||
FAILED = "failed" # 执行失败
|
||||
CANCELLED = "cancelled" # 已取消
|
||||
|
||||
|
||||
# ────────────────────────── 任务类型 ──────────────────────────
|
||||
|
||||
class TaskType(str, Enum):
|
||||
"""可扩展的任务类型。新增任务在此追加即可。"""
|
||||
BOSS_RECRUIT = "boss_recruit" # BOSS 直聘招聘流程
|
||||
|
||||
|
||||
# ────────────────────────── 辅助函数 ──────────────────────────
|
||||
|
||||
def make_msg(msg_type: MsgType, **payload) -> dict:
|
||||
"""构造一条标准 WebSocket JSON 消息。"""
|
||||
return {"type": msg_type.value, **payload}
|
||||
9
requirements.txt
Normal file
9
requirements.txt
Normal file
@@ -0,0 +1,9 @@
|
||||
# ─── 中央服务器 (server/) ───
|
||||
fastapi>=0.115.0
|
||||
uvicorn>=0.34.0
|
||||
pydantic>=2.0.0
|
||||
|
||||
# ─── Worker 代理 (worker/) ───
|
||||
websockets>=14.0
|
||||
requests>=2.31.0
|
||||
DrissionPage>=4.1.0
|
||||
1
server/__init__.py
Normal file
1
server/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
1
server/api/__init__.py
Normal file
1
server/api/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
107
server/api/tasks.py
Normal file
107
server/api/tasks.py
Normal file
@@ -0,0 +1,107 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
任务提交与查询 API。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from common.protocol import TaskStatus
|
||||
from server.models import TaskCreate, TaskOut
|
||||
from server.core.worker_manager import worker_manager
|
||||
from server.core.task_dispatcher import task_dispatcher
|
||||
|
||||
router = APIRouter(prefix="/api/tasks", tags=["tasks"])
|
||||
|
||||
|
||||
@router.post("", response_model=TaskOut, status_code=201)
|
||||
async def create_task(req: TaskCreate):
|
||||
"""
|
||||
提交一个新任务。
|
||||
|
||||
路由规则(优先级从高到低):
|
||||
1. 如果指定了 worker_id → 直接发到该 Worker
|
||||
2. 如果指定了 account_name → 查找拥有该浏览器的 Worker
|
||||
3. 两者都没有 → 400 错误
|
||||
"""
|
||||
# 确定目标 worker_id
|
||||
target_worker_id = req.worker_id
|
||||
|
||||
if not target_worker_id and req.account_name:
|
||||
target_worker_id = worker_manager.find_worker_by_account(req.account_name)
|
||||
if not target_worker_id:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"未找到拥有浏览器 '{req.account_name}' 的在线 Worker",
|
||||
)
|
||||
|
||||
if not target_worker_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="请指定 worker_id 或 account_name",
|
||||
)
|
||||
|
||||
# 检查 Worker 是否在线
|
||||
if not worker_manager.is_online(target_worker_id):
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=f"Worker {target_worker_id} 不在线",
|
||||
)
|
||||
|
||||
# 创建任务
|
||||
req.worker_id = target_worker_id
|
||||
task = task_dispatcher.create_task(req)
|
||||
|
||||
# 通过 WebSocket 派发
|
||||
ws = worker_manager.get_ws(target_worker_id)
|
||||
if not ws:
|
||||
task.status = TaskStatus.FAILED
|
||||
task.error = "Worker WebSocket 连接不存在"
|
||||
raise HTTPException(status_code=503, detail="Worker WebSocket 连接不存在")
|
||||
|
||||
success = await task_dispatcher.dispatch(task, ws.send_json)
|
||||
if not success:
|
||||
raise HTTPException(status_code=503, detail=f"任务派发失败: {task.error}")
|
||||
|
||||
# 更新 Worker 当前任务
|
||||
worker_manager.set_current_task(target_worker_id, task.task_id)
|
||||
|
||||
return _to_out(task)
|
||||
|
||||
|
||||
@router.get("", response_model=List[TaskOut])
|
||||
async def list_tasks(
|
||||
worker_id: Optional[str] = None,
|
||||
status: Optional[TaskStatus] = None,
|
||||
limit: int = 50,
|
||||
):
|
||||
"""查询任务列表,支持按 worker_id / status 过滤。"""
|
||||
tasks = task_dispatcher.list_tasks(worker_id=worker_id, status=status, limit=limit)
|
||||
return [_to_out(t) for t in tasks]
|
||||
|
||||
|
||||
@router.get("/{task_id}", response_model=TaskOut)
|
||||
async def get_task(task_id: str):
|
||||
"""查询指定任务的状态和结果。"""
|
||||
task = task_dispatcher.get_task(task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")
|
||||
return _to_out(task)
|
||||
|
||||
|
||||
def _to_out(t) -> TaskOut:
|
||||
return TaskOut(
|
||||
task_id=t.task_id,
|
||||
task_type=t.task_type,
|
||||
status=t.status,
|
||||
worker_id=t.worker_id,
|
||||
account_name=t.account_name,
|
||||
params=t.params,
|
||||
progress=t.progress,
|
||||
result=t.result,
|
||||
error=t.error,
|
||||
created_at=t.created_at,
|
||||
updated_at=t.updated_at,
|
||||
)
|
||||
42
server/api/workers.py
Normal file
42
server/api/workers.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Worker 查询 API。
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from typing import List
|
||||
|
||||
from server.models import WorkerOut
|
||||
from server.core.worker_manager import worker_manager
|
||||
|
||||
router = APIRouter(prefix="/api/workers", tags=["workers"])
|
||||
|
||||
|
||||
@router.get("", response_model=List[WorkerOut])
|
||||
async def list_workers():
|
||||
"""获取所有已注册的 Worker(含在线状态与浏览器列表)。"""
|
||||
workers = worker_manager.get_all_workers()
|
||||
return [
|
||||
WorkerOut(
|
||||
worker_id=w.worker_id,
|
||||
worker_name=w.worker_name,
|
||||
browsers=w.browsers,
|
||||
online=w.online,
|
||||
current_task_id=w.current_task_id,
|
||||
)
|
||||
for w in workers
|
||||
]
|
||||
|
||||
|
||||
@router.get("/{worker_id}", response_model=WorkerOut)
|
||||
async def get_worker(worker_id: str):
|
||||
"""获取指定 Worker 的详情。"""
|
||||
w = worker_manager.get_worker(worker_id)
|
||||
if not w:
|
||||
raise HTTPException(status_code=404, detail=f"Worker {worker_id} 不存在")
|
||||
return WorkerOut(
|
||||
worker_id=w.worker_id,
|
||||
worker_name=w.worker_name,
|
||||
browsers=w.browsers,
|
||||
online=w.online,
|
||||
current_task_id=w.current_task_id,
|
||||
)
|
||||
18
server/config.py
Normal file
18
server/config.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
服务器配置。
|
||||
可通过环境变量或直接修改此文件调整。
|
||||
"""
|
||||
import os
|
||||
|
||||
# ─── 服务 ───
|
||||
HOST: str = os.getenv("SERVER_HOST", "0.0.0.0")
|
||||
PORT: int = int(os.getenv("SERVER_PORT", "8000"))
|
||||
|
||||
# ─── WebSocket ───
|
||||
WS_PATH: str = "/ws" # Worker 连接端点
|
||||
HEARTBEAT_INTERVAL: int = 30 # 期望 Worker 心跳间隔(秒)
|
||||
HEARTBEAT_TIMEOUT: int = 90 # 超时未收到心跳视为离线(秒)
|
||||
|
||||
# ─── 安全(可选) ───
|
||||
API_TOKEN: str = os.getenv("API_TOKEN", "") # 非空时校验 Header: Authorization: Bearer <token>
|
||||
1
server/core/__init__.py
Normal file
1
server/core/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
117
server/core/task_dispatcher.py
Normal file
117
server/core/task_dispatcher.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
任务路由与派发。
|
||||
根据请求中的 worker_id / account_name 找到目标 Worker,通过 WebSocket 下发任务。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from common.protocol import MsgType, TaskStatus, make_msg
|
||||
from server.models import TaskCreate, TaskInfo
|
||||
|
||||
logger = logging.getLogger("server.task_dispatcher")
|
||||
|
||||
|
||||
class TaskDispatcher:
|
||||
"""管理任务生命周期并派发给 Worker。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# task_id → TaskInfo
|
||||
self._tasks: Dict[str, TaskInfo] = {}
|
||||
|
||||
# ─── 创建任务 ───
|
||||
|
||||
def create_task(self, req: TaskCreate) -> TaskInfo:
|
||||
task = TaskInfo(
|
||||
task_type=req.task_type,
|
||||
worker_id=req.worker_id,
|
||||
account_name=req.account_name,
|
||||
params=req.params,
|
||||
)
|
||||
self._tasks[task.task_id] = task
|
||||
logger.info("任务创建: %s type=%s worker=%s account=%s",
|
||||
task.task_id, task.task_type, task.worker_id, task.account_name)
|
||||
return task
|
||||
|
||||
# ─── 派发 ───
|
||||
|
||||
async def dispatch(self, task: TaskInfo, ws_send) -> bool:
|
||||
"""
|
||||
将任务通过 WebSocket 发给目标 Worker。
|
||||
ws_send: 异步可调用,接受一个 dict 参数。
|
||||
返回是否发送成功。
|
||||
"""
|
||||
msg = make_msg(
|
||||
MsgType.TASK_ASSIGN,
|
||||
task_id=task.task_id,
|
||||
task_type=task.task_type.value,
|
||||
account_name=task.account_name or "",
|
||||
params=task.params,
|
||||
)
|
||||
try:
|
||||
await ws_send(msg)
|
||||
task.status = TaskStatus.DISPATCHED
|
||||
task.updated_at = time.time()
|
||||
logger.info("任务 %s 已派发", task.task_id)
|
||||
return True
|
||||
except Exception as e:
|
||||
task.status = TaskStatus.FAILED
|
||||
task.error = f"派发失败: {e}"
|
||||
task.updated_at = time.time()
|
||||
logger.error("任务 %s 派发失败: %s", task.task_id, e)
|
||||
return False
|
||||
|
||||
# ─── 更新状态 ───
|
||||
|
||||
def update_progress(self, task_id: str, progress: str) -> None:
|
||||
task = self._tasks.get(task_id)
|
||||
if task:
|
||||
task.status = TaskStatus.RUNNING
|
||||
task.progress = progress
|
||||
task.updated_at = time.time()
|
||||
|
||||
def complete_task(self, task_id: str, result=None, error: str = None) -> None:
|
||||
task = self._tasks.get(task_id)
|
||||
if not task:
|
||||
return
|
||||
if error:
|
||||
task.status = TaskStatus.FAILED
|
||||
task.error = error
|
||||
else:
|
||||
task.status = TaskStatus.SUCCESS
|
||||
task.result = result
|
||||
task.updated_at = time.time()
|
||||
logger.info("任务 %s 完成: status=%s", task_id, task.status)
|
||||
|
||||
def cancel_task(self, task_id: str) -> None:
|
||||
task = self._tasks.get(task_id)
|
||||
if task and task.status in (TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING):
|
||||
task.status = TaskStatus.CANCELLED
|
||||
task.updated_at = time.time()
|
||||
|
||||
# ─── 查询 ───
|
||||
|
||||
def get_task(self, task_id: str) -> Optional[TaskInfo]:
|
||||
return self._tasks.get(task_id)
|
||||
|
||||
def list_tasks(
|
||||
self,
|
||||
worker_id: Optional[str] = None,
|
||||
status: Optional[TaskStatus] = None,
|
||||
limit: int = 50,
|
||||
) -> List[TaskInfo]:
|
||||
result = list(self._tasks.values())
|
||||
if worker_id:
|
||||
result = [t for t in result if t.worker_id == worker_id]
|
||||
if status:
|
||||
result = [t for t in result if t.status == status]
|
||||
# 按创建时间倒序
|
||||
result.sort(key=lambda t: t.created_at, reverse=True)
|
||||
return result[:limit]
|
||||
|
||||
|
||||
# 全局单例
|
||||
task_dispatcher = TaskDispatcher()
|
||||
130
server/core/worker_manager.py
Normal file
130
server/core/worker_manager.py
Normal file
@@ -0,0 +1,130 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Worker 注册、状态管理、账号 → Worker 映射。
|
||||
全部在内存中,服务重启后 Worker 重新连接即恢复。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Optional
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
from server.config import HEARTBEAT_TIMEOUT
|
||||
from server.models import BrowserProfile, WorkerInfo
|
||||
|
||||
logger = logging.getLogger("server.worker_manager")
|
||||
|
||||
|
||||
class WorkerManager:
|
||||
"""管理所有已连接的 Worker。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# worker_id → WorkerInfo
|
||||
self._workers: Dict[str, WorkerInfo] = {}
|
||||
# worker_id → WebSocket 实例
|
||||
self._connections: Dict[str, WebSocket] = {}
|
||||
# account_name(lower) → worker_id 快速路由表
|
||||
self._account_map: Dict[str, str] = {}
|
||||
|
||||
# ─── 注册 / 注销 ───
|
||||
|
||||
def register(
|
||||
self,
|
||||
ws: WebSocket,
|
||||
worker_id: str,
|
||||
worker_name: str,
|
||||
browsers: list[dict],
|
||||
) -> WorkerInfo:
|
||||
profiles = [BrowserProfile(**b) for b in browsers]
|
||||
info = WorkerInfo(
|
||||
worker_id=worker_id,
|
||||
worker_name=worker_name,
|
||||
browsers=profiles,
|
||||
online=True,
|
||||
last_heartbeat=time.time(),
|
||||
connected_at=time.time(),
|
||||
)
|
||||
self._workers[worker_id] = info
|
||||
self._connections[worker_id] = ws
|
||||
self._rebuild_account_map()
|
||||
logger.info("Worker 注册: %s (%s), 浏览器 %d 个", worker_id, worker_name, len(profiles))
|
||||
return info
|
||||
|
||||
def unregister(self, worker_id: str) -> None:
|
||||
self._workers.pop(worker_id, None)
|
||||
self._connections.pop(worker_id, None)
|
||||
self._rebuild_account_map()
|
||||
logger.info("Worker 注销: %s", worker_id)
|
||||
|
||||
# ─── 心跳 ───
|
||||
|
||||
def heartbeat(self, worker_id: str) -> None:
|
||||
info = self._workers.get(worker_id)
|
||||
if info:
|
||||
info.last_heartbeat = time.time()
|
||||
info.online = True
|
||||
|
||||
# ─── 浏览器列表更新 ───
|
||||
|
||||
def update_browsers(self, worker_id: str, browsers: list[dict]) -> None:
|
||||
info = self._workers.get(worker_id)
|
||||
if info:
|
||||
info.browsers = [BrowserProfile(**b) for b in browsers]
|
||||
self._rebuild_account_map()
|
||||
logger.info("Worker %s 浏览器列表更新: %d 个", worker_id, len(info.browsers))
|
||||
|
||||
# ─── 查询 ───
|
||||
|
||||
def get_worker(self, worker_id: str) -> Optional[WorkerInfo]:
|
||||
return self._workers.get(worker_id)
|
||||
|
||||
def get_all_workers(self) -> list[WorkerInfo]:
|
||||
return list(self._workers.values())
|
||||
|
||||
def get_ws(self, worker_id: str) -> Optional[WebSocket]:
|
||||
return self._connections.get(worker_id)
|
||||
|
||||
def find_worker_by_account(self, account_name: str) -> Optional[str]:
|
||||
"""按浏览器窗口 name 查找对应的 worker_id。"""
|
||||
return self._account_map.get(account_name.lower())
|
||||
|
||||
def is_online(self, worker_id: str) -> bool:
|
||||
info = self._workers.get(worker_id)
|
||||
return info is not None and info.online
|
||||
|
||||
# ─── 任务占用 ───
|
||||
|
||||
def set_current_task(self, worker_id: str, task_id: Optional[str]) -> None:
|
||||
info = self._workers.get(worker_id)
|
||||
if info:
|
||||
info.current_task_id = task_id
|
||||
|
||||
# ─── 定时巡检(检测超时离线) ───
|
||||
|
||||
async def check_heartbeats_loop(self, interval: int = 15) -> None:
|
||||
"""后台协程,定期检查心跳超时的 Worker。"""
|
||||
while True:
|
||||
await asyncio.sleep(interval)
|
||||
now = time.time()
|
||||
for wid, info in list(self._workers.items()):
|
||||
if info.online and (now - info.last_heartbeat) > HEARTBEAT_TIMEOUT:
|
||||
info.online = False
|
||||
logger.warning("Worker %s 心跳超时,标记为离线", wid)
|
||||
|
||||
# ─── 内部 ───
|
||||
|
||||
def _rebuild_account_map(self) -> None:
|
||||
self._account_map.clear()
|
||||
for wid, info in self._workers.items():
|
||||
if not info.online:
|
||||
continue
|
||||
for b in info.browsers:
|
||||
if b.name:
|
||||
self._account_map[b.name.lower()] = wid
|
||||
|
||||
|
||||
# 全局单例
|
||||
worker_manager = WorkerManager()
|
||||
150
server/main.py
Normal file
150
server/main.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
中央服务器入口。
|
||||
启动方式: python -m server.main
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from common.protocol import MsgType, make_msg
|
||||
from server import config
|
||||
from server.api.workers import router as workers_router
|
||||
from server.api.tasks import router as tasks_router
|
||||
from server.core.worker_manager import worker_manager
|
||||
from server.core.task_dispatcher import task_dispatcher
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(name)-28s %(levelname)-5s %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
logger = logging.getLogger("server.main")
|
||||
|
||||
# ────────────────────────── FastAPI App ──────────────────────────
|
||||
|
||||
app = FastAPI(title="Browser Control Server", version="1.0.0")
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(workers_router)
|
||||
app.include_router(tasks_router)
|
||||
|
||||
|
||||
# ────────────────────────── 健康检查 ──────────────────────────
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok", "workers_online": len([w for w in worker_manager.get_all_workers() if w.online])}
|
||||
|
||||
|
||||
# ────────────────────────── WebSocket 端点 ──────────────────────────
|
||||
|
||||
@app.websocket(config.WS_PATH)
|
||||
async def ws_endpoint(ws: WebSocket):
|
||||
"""
|
||||
Worker 连接流程:
|
||||
1. 建立连接
|
||||
2. 等待第一条 register 消息
|
||||
3. 持续收发消息(心跳、任务进度、任务结果等)
|
||||
4. 断开时注销
|
||||
"""
|
||||
await ws.accept()
|
||||
worker_id: str | None = None
|
||||
|
||||
try:
|
||||
# ── 等待注册 ──
|
||||
raw = await asyncio.wait_for(ws.receive_json(), timeout=30)
|
||||
if raw.get("type") != MsgType.REGISTER.value:
|
||||
await ws.send_json(make_msg(MsgType.ERROR, detail="首条消息必须是 register"))
|
||||
await ws.close(code=4001)
|
||||
return
|
||||
|
||||
worker_id = raw.get("worker_id", "")
|
||||
worker_name = raw.get("worker_name", worker_id)
|
||||
browsers = raw.get("browsers", [])
|
||||
|
||||
if not worker_id:
|
||||
await ws.send_json(make_msg(MsgType.ERROR, detail="worker_id 不能为空"))
|
||||
await ws.close(code=4002)
|
||||
return
|
||||
|
||||
worker_manager.register(ws, worker_id, worker_name, browsers)
|
||||
await ws.send_json(make_msg(MsgType.REGISTER_ACK, worker_id=worker_id))
|
||||
logger.info("Worker %s 已连接", worker_id)
|
||||
|
||||
# ── 消息循环 ──
|
||||
while True:
|
||||
data = await ws.receive_json()
|
||||
msg_type = data.get("type", "")
|
||||
|
||||
if msg_type == MsgType.HEARTBEAT.value:
|
||||
worker_manager.heartbeat(worker_id)
|
||||
await ws.send_json(make_msg(MsgType.HEARTBEAT_ACK))
|
||||
|
||||
elif msg_type == MsgType.BROWSER_LIST_UPDATE.value:
|
||||
worker_manager.update_browsers(worker_id, data.get("browsers", []))
|
||||
|
||||
elif msg_type == MsgType.TASK_PROGRESS.value:
|
||||
task_id = data.get("task_id", "")
|
||||
progress = data.get("progress", "")
|
||||
task_dispatcher.update_progress(task_id, progress)
|
||||
logger.info("任务 %s 进度: %s", task_id, progress)
|
||||
|
||||
elif msg_type == MsgType.TASK_RESULT.value:
|
||||
task_id = data.get("task_id", "")
|
||||
result = data.get("result")
|
||||
error = data.get("error")
|
||||
task_dispatcher.complete_task(task_id, result=result, error=error)
|
||||
# 释放 Worker 任务占用
|
||||
worker_manager.set_current_task(worker_id, None)
|
||||
logger.info("任务 %s 已完成", task_id)
|
||||
|
||||
else:
|
||||
logger.warning("未知消息类型: %s (from %s)", msg_type, worker_id)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info("Worker %s WebSocket 断开", worker_id or "unknown")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("WebSocket 连接超时(未在 30 秒内注册)")
|
||||
await ws.close(code=4003)
|
||||
except Exception as e:
|
||||
logger.error("WebSocket 处理异常: %s", e, exc_info=True)
|
||||
finally:
|
||||
if worker_id:
|
||||
worker_manager.unregister(worker_id)
|
||||
|
||||
|
||||
# ────────────────────────── 生命周期 ──────────────────────────
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
# 启动心跳巡检后台任务
|
||||
asyncio.create_task(worker_manager.check_heartbeats_loop())
|
||||
logger.info("服务器启动: http://%s:%s", config.HOST, config.PORT)
|
||||
|
||||
|
||||
# ────────────────────────── 入口 ──────────────────────────
|
||||
|
||||
def main():
|
||||
uvicorn.run(
|
||||
"server.main:app",
|
||||
host=config.HOST,
|
||||
port=config.PORT,
|
||||
log_level="info",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
82
server/models.py
Normal file
82
server/models.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Pydantic 数据模型 —— 用于 REST API 请求 / 响应以及内部状态。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from common.protocol import TaskStatus, TaskType
|
||||
|
||||
|
||||
# ────────────────────────── Worker ──────────────────────────
|
||||
|
||||
class BrowserProfile(BaseModel):
|
||||
"""比特浏览器窗口信息(Worker 上报)。"""
|
||||
id: str
|
||||
name: str = ""
|
||||
remark: str = ""
|
||||
|
||||
|
||||
class WorkerInfo(BaseModel):
|
||||
"""一台 Worker 的运行时信息(内存中保存)。"""
|
||||
worker_id: str
|
||||
worker_name: str = ""
|
||||
browsers: List[BrowserProfile] = []
|
||||
online: bool = True
|
||||
last_heartbeat: float = Field(default_factory=time.time)
|
||||
connected_at: float = Field(default_factory=time.time)
|
||||
current_task_id: Optional[str] = None
|
||||
|
||||
|
||||
class WorkerOut(BaseModel):
|
||||
"""返回给前端的 Worker 信息。"""
|
||||
worker_id: str
|
||||
worker_name: str
|
||||
browsers: List[BrowserProfile]
|
||||
online: bool
|
||||
current_task_id: Optional[str] = None
|
||||
|
||||
|
||||
# ────────────────────────── Task ──────────────────────────
|
||||
|
||||
class TaskCreate(BaseModel):
|
||||
"""前端提交任务的请求体。"""
|
||||
task_type: TaskType
|
||||
worker_id: Optional[str] = None # 手动指定机器
|
||||
account_name: Optional[str] = None # 按比特浏览器窗口名自动路由
|
||||
params: Dict[str, Any] = {} # 任务参数(如 job_title, max_greet 等)
|
||||
|
||||
|
||||
class TaskInfo(BaseModel):
|
||||
"""任务完整信息(内存 / 返回前端)。"""
|
||||
task_id: str = Field(default_factory=lambda: uuid.uuid4().hex[:12])
|
||||
task_type: TaskType
|
||||
status: TaskStatus = TaskStatus.PENDING
|
||||
worker_id: Optional[str] = None
|
||||
account_name: Optional[str] = None
|
||||
params: Dict[str, Any] = {}
|
||||
progress: Optional[str] = None # 最新进度描述
|
||||
result: Any = None # 最终结果
|
||||
error: Optional[str] = None
|
||||
created_at: float = Field(default_factory=time.time)
|
||||
updated_at: float = Field(default_factory=time.time)
|
||||
|
||||
|
||||
class TaskOut(BaseModel):
|
||||
"""返回给前端的任务信息。"""
|
||||
task_id: str
|
||||
task_type: TaskType
|
||||
status: TaskStatus
|
||||
worker_id: Optional[str] = None
|
||||
account_name: Optional[str] = None
|
||||
params: Dict[str, Any] = {}
|
||||
progress: Optional[str] = None
|
||||
result: Any = None
|
||||
error: Optional[str] = None
|
||||
created_at: float
|
||||
updated_at: float
|
||||
1
worker/__init__.py
Normal file
1
worker/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
118
worker/bit_browser.py
Normal file
118
worker/bit_browser.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
比特浏览器 API 封装。
|
||||
提供:列出窗口、打开窗口、关闭窗口,返回 CDP 地址供 DrissionPage 连接。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import requests
|
||||
|
||||
logger = logging.getLogger("worker.bit_browser")
|
||||
|
||||
|
||||
class BitBrowserAPI:
|
||||
"""比特浏览器本地 API 客户端。"""
|
||||
|
||||
def __init__(self, base_url: str = "http://127.0.0.1:54345") -> None:
|
||||
self.base_url = base_url.rstrip("/")
|
||||
|
||||
def _post(self, path: str, data: dict, timeout: int = 15) -> dict:
|
||||
url = f"{self.base_url}{path}"
|
||||
r = requests.post(url, json=data, timeout=timeout)
|
||||
r.raise_for_status()
|
||||
return r.json()
|
||||
|
||||
# ─── 列出浏览器窗口 ───
|
||||
|
||||
def list_browsers(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
remark: Optional[str] = None,
|
||||
page: int = 0,
|
||||
page_size: int = 100,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取比特浏览器窗口列表。可按 name / remark 模糊筛选。"""
|
||||
data: Dict[str, Any] = {"page": page, "pageSize": page_size}
|
||||
if name is not None:
|
||||
data["name"] = name
|
||||
if remark is not None:
|
||||
data["remark"] = remark
|
||||
|
||||
res = self._post("/browser/list", data)
|
||||
if not res.get("success"):
|
||||
raise RuntimeError(f"list 失败: {res.get('msg', res)}")
|
||||
|
||||
list_data = res.get("data")
|
||||
if not list_data:
|
||||
return []
|
||||
if isinstance(list_data, list):
|
||||
return list_data
|
||||
if isinstance(list_data, dict) and "list" in list_data:
|
||||
return list_data["list"]
|
||||
return []
|
||||
|
||||
# ─── 打开浏览器窗口 ───
|
||||
|
||||
def open_browser(
|
||||
self,
|
||||
browser_id: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
remark: Optional[str] = None,
|
||||
) -> Tuple[str, int, str]:
|
||||
"""
|
||||
打开指定浏览器窗口。
|
||||
返回 (cdp_addr, port, browser_id)。
|
||||
"""
|
||||
if browser_id is None:
|
||||
items = self.list_browsers(name=name, remark=remark, page_size=10)
|
||||
if not items:
|
||||
raise RuntimeError("没有匹配的浏览器窗口")
|
||||
browser_id = items[0].get("id")
|
||||
if not browser_id:
|
||||
raise RuntimeError("列表项中没有 id")
|
||||
|
||||
res = self._post("/browser/open", {"id": browser_id})
|
||||
if not res.get("success"):
|
||||
raise RuntimeError(f"open 失败: {res.get('msg', res)}")
|
||||
|
||||
data = res.get("data") or {}
|
||||
http_addr = data.get("http")
|
||||
if not http_addr:
|
||||
raise RuntimeError(f"返回中无 http 地址: {data}")
|
||||
|
||||
if ":" in str(http_addr):
|
||||
host, port_str = str(http_addr).strip().rsplit(":", 1)
|
||||
port = int(port_str)
|
||||
cdp_addr = f"{host}:{port}"
|
||||
else:
|
||||
port = int(http_addr)
|
||||
cdp_addr = f"127.0.0.1:{port}"
|
||||
|
||||
logger.info("已打开浏览器 %s, CDP: %s", browser_id, cdp_addr)
|
||||
return cdp_addr, port, browser_id
|
||||
|
||||
# ─── 关闭浏览器窗口 ───
|
||||
|
||||
def close_browser(self, browser_id: str) -> bool:
|
||||
"""关闭指定浏览器窗口。"""
|
||||
try:
|
||||
res = self._post("/browser/close", {"id": browser_id})
|
||||
return res.get("success", False)
|
||||
except Exception as e:
|
||||
logger.warning("关闭浏览器 %s 失败: %s", browser_id, e)
|
||||
return False
|
||||
|
||||
# ─── 便捷方法:获取供 DrissionPage 连接的信息 ───
|
||||
|
||||
def get_browser_for_drission(
|
||||
self,
|
||||
browser_id: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
remark: Optional[str] = None,
|
||||
) -> Tuple[str, int]:
|
||||
"""按需打开浏览器,返回 (addr, port) 供 DrissionPage 使用。"""
|
||||
cdp_addr, port, _ = self.open_browser(browser_id=browser_id, name=name, remark=remark)
|
||||
return cdp_addr, port
|
||||
119
worker/browser_control.py
Normal file
119
worker/browser_control.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
DrissionPage 浏览器控制基础封装。
|
||||
提供连接浏览器、拟人化操作等通用能力,供各任务处理器复用。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger("worker.browser_control")
|
||||
|
||||
try:
|
||||
from DrissionPage import Chromium
|
||||
except ImportError:
|
||||
Chromium = None # type: ignore
|
||||
|
||||
|
||||
# ─── 拟人化参数 ───
|
||||
|
||||
HUMAN_DELAY_CLICK = (0.2, 0.6)
|
||||
HUMAN_DELAY_BETWEEN = (0.1, 0.4)
|
||||
HUMAN_MOVE_DURATION = (0.25, 0.7)
|
||||
HUMAN_CLICK_OFFSET = 12
|
||||
|
||||
|
||||
def ensure_drission():
|
||||
"""确保 DrissionPage 已安装。"""
|
||||
if Chromium is None:
|
||||
raise RuntimeError("请先安装: pip install DrissionPage")
|
||||
|
||||
|
||||
def connect_browser(port: int = None, addr: str = None):
|
||||
"""使用 DrissionPage 连接浏览器。"""
|
||||
ensure_drission()
|
||||
if port is not None:
|
||||
return Chromium(port)
|
||||
if addr:
|
||||
return Chromium(addr)
|
||||
raise ValueError("需要 port 或 addr 以连接浏览器")
|
||||
|
||||
|
||||
def human_delay(low: float = 0.2, high: float = 0.6):
|
||||
"""随机延迟,模拟人类操作间隔。"""
|
||||
time.sleep(random.uniform(low, high))
|
||||
|
||||
|
||||
def human_click(tab, ele, scroll_first: bool = True) -> bool:
|
||||
"""
|
||||
拟人化点击:先滚动到可见(可选),再动作链「移动 → 短暂停顿 → 点击」,
|
||||
移动带随机耗时与元素内随机偏移。
|
||||
"""
|
||||
if ele is None:
|
||||
return False
|
||||
try:
|
||||
if scroll_first and hasattr(ele, "scroll") and getattr(ele.scroll, "to_see", None):
|
||||
ele.scroll.to_see()
|
||||
human_delay(0.15, 0.4)
|
||||
ox = random.randint(-HUMAN_CLICK_OFFSET, HUMAN_CLICK_OFFSET)
|
||||
oy = random.randint(-HUMAN_CLICK_OFFSET, HUMAN_CLICK_OFFSET)
|
||||
duration = random.uniform(*HUMAN_MOVE_DURATION)
|
||||
tab.actions.move_to(ele, offset_x=ox, offset_y=oy, duration=duration)
|
||||
tab.actions.wait(*HUMAN_DELAY_BETWEEN)
|
||||
tab.actions.click()
|
||||
return True
|
||||
except Exception:
|
||||
try:
|
||||
ele.click()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def safe_click(tab, ele) -> bool:
|
||||
"""多种方式尝试点击元素(拟人 → 普通 → JS),确保触发成功。"""
|
||||
if ele is None:
|
||||
return False
|
||||
# 方式1: 拟人点击
|
||||
if human_click(tab, ele):
|
||||
return True
|
||||
# 方式2: 普通模拟点击
|
||||
try:
|
||||
ele.click(by_js=False)
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
# 方式3: JS 点击
|
||||
try:
|
||||
ele.click(by_js=True)
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
|
||||
def find_element(tab, selectors: list[str], timeout: int = 3):
|
||||
"""按优先级尝试多个选择器,返回第一个找到的元素。"""
|
||||
for sel in selectors:
|
||||
try:
|
||||
ele = tab.ele(sel, timeout=timeout)
|
||||
if ele:
|
||||
return ele
|
||||
except Exception:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def find_elements(tab, selectors: list[str], timeout: int = 3) -> list:
|
||||
"""按优先级尝试多个选择器,返回第一组找到的元素列表。"""
|
||||
for sel in selectors:
|
||||
try:
|
||||
eles = tab.eles(sel, timeout=timeout)
|
||||
if eles and len(eles) > 0:
|
||||
return eles
|
||||
except Exception:
|
||||
continue
|
||||
return []
|
||||
20
worker/config.py
Normal file
20
worker/config.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Worker 配置。通过命令行参数或环境变量设置。
|
||||
"""
|
||||
import os
|
||||
|
||||
# ─── 中央服务器 ───
|
||||
SERVER_WS_URL: str = os.getenv("SERVER_WS_URL", "ws://127.0.0.1:8000/ws")
|
||||
|
||||
# ─── Worker 标识 ───
|
||||
WORKER_ID: str = os.getenv("WORKER_ID", "worker-1")
|
||||
WORKER_NAME: str = os.getenv("WORKER_NAME", "本机")
|
||||
|
||||
# ─── 比特浏览器 ───
|
||||
BIT_API_BASE: str = os.getenv("BIT_API_BASE", "http://127.0.0.1:54345")
|
||||
|
||||
# ─── WebSocket ───
|
||||
HEARTBEAT_INTERVAL: int = 25 # 心跳发送间隔(秒)
|
||||
RECONNECT_DELAY: int = 5 # 断线重连等待(秒)
|
||||
RECONNECT_MAX_DELAY: int = 60 # 重连最大等待(秒,指数退避上限)
|
||||
84
worker/main.py
Normal file
84
worker/main.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Worker 启动入口。
|
||||
启动方式: python -m worker.main [--server ws://IP:8000/ws] [--worker-id pc-a] [--worker-name 电脑A]
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from worker import config
|
||||
from worker.tasks.registry import register_all_handlers
|
||||
from worker.ws_client import WorkerWSClient
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(name)-28s %(levelname)-5s %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
logger = logging.getLogger("worker.main")
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Browser Control Worker Agent")
|
||||
parser.add_argument(
|
||||
"--server",
|
||||
default=config.SERVER_WS_URL,
|
||||
help=f"中央服务器 WebSocket 地址 (默认: {config.SERVER_WS_URL})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--worker-id",
|
||||
default=config.WORKER_ID,
|
||||
help=f"Worker ID (默认: {config.WORKER_ID})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--worker-name",
|
||||
default=config.WORKER_NAME,
|
||||
help=f"Worker 名称 (默认: {config.WORKER_NAME})",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bit-api",
|
||||
default=config.BIT_API_BASE,
|
||||
help=f"比特浏览器本地 API 地址 (默认: {config.BIT_API_BASE})",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
async def run(args):
|
||||
# 注册所有任务处理器
|
||||
register_all_handlers()
|
||||
logger.info("已注册任务处理器")
|
||||
|
||||
# 创建 WebSocket 客户端并运行
|
||||
client = WorkerWSClient(
|
||||
server_url=args.server,
|
||||
worker_id=args.worker_id,
|
||||
worker_name=args.worker_name,
|
||||
bit_api_base=args.bit_api,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Worker 启动: id=%s, name=%s, server=%s",
|
||||
args.worker_id, args.worker_name, args.server,
|
||||
)
|
||||
|
||||
try:
|
||||
await client.run()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("收到中断信号,正在退出...")
|
||||
await client.stop()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
try:
|
||||
asyncio.run(run(args))
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Worker 已退出")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
1
worker/tasks/__init__.py
Normal file
1
worker/tasks/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
51
worker/tasks/base.py
Normal file
51
worker/tasks/base.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
任务处理器基类。
|
||||
所有具体任务(如 BOSS 招聘)都继承此类,实现 execute 方法。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Coroutine, Dict, Optional
|
||||
|
||||
|
||||
class BaseTaskHandler(ABC):
|
||||
"""
|
||||
任务处理器基类。
|
||||
|
||||
子类需实现:
|
||||
execute(task_id, params, progress_callback) -> result
|
||||
|
||||
progress_callback 是一个异步函数:
|
||||
await progress_callback(task_id, "当前进度描述")
|
||||
"""
|
||||
|
||||
# 子类覆盖,声明处理哪个 task_type
|
||||
task_type: str = ""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.logger = logging.getLogger(f"worker.tasks.{self.task_type or self.__class__.__name__}")
|
||||
|
||||
@abstractmethod
|
||||
async def execute(
|
||||
self,
|
||||
task_id: str,
|
||||
params: Dict[str, Any],
|
||||
progress_cb: Callable[[str, str], Coroutine],
|
||||
) -> Any:
|
||||
"""
|
||||
执行任务。
|
||||
|
||||
Args:
|
||||
task_id: 任务 ID
|
||||
params: 任务参数
|
||||
progress_cb: 进度上报回调 async def(task_id, progress_text)
|
||||
|
||||
Returns:
|
||||
任务结果(可序列化为 JSON)
|
||||
|
||||
Raises:
|
||||
Exception: 执行失败时抛出异常
|
||||
"""
|
||||
...
|
||||
335
worker/tasks/boss_recruit.py
Normal file
335
worker/tasks/boss_recruit.py
Normal file
@@ -0,0 +1,335 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
BOSS 直聘招聘任务处理器。
|
||||
复用并重构原 boss_drission.py 的核心流程。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Callable, Coroutine, Dict, List
|
||||
|
||||
from common.protocol import TaskType
|
||||
from worker.tasks.base import BaseTaskHandler
|
||||
from worker.bit_browser import BitBrowserAPI
|
||||
from worker.browser_control import (
|
||||
connect_browser,
|
||||
find_element,
|
||||
find_elements,
|
||||
human_click,
|
||||
human_delay,
|
||||
safe_click,
|
||||
)
|
||||
|
||||
# ─── 常量 ───
|
||||
CHAT_INDEX_URL = "https://www.zhipin.com/web/chat/index"
|
||||
ACTION_DELAY = 1.5
|
||||
BETWEEN_CHAT_DELAY = 2.5
|
||||
|
||||
|
||||
class BossRecruitHandler(BaseTaskHandler):
|
||||
"""BOSS 直聘招聘自动化任务。"""
|
||||
|
||||
task_type = TaskType.BOSS_RECRUIT.value
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
task_id: str,
|
||||
params: Dict[str, Any],
|
||||
progress_cb: Callable[[str, str], Coroutine],
|
||||
) -> Any:
|
||||
"""
|
||||
执行 BOSS 招聘流程。
|
||||
|
||||
params:
|
||||
- job_title: str 招聘岗位名称
|
||||
- max_greet: int 最大打招呼人数(默认 5)
|
||||
- account_name: str 比特浏览器窗口名(用于打开浏览器)
|
||||
- account_id: str 比特浏览器窗口 ID(可选,优先级高于 name)
|
||||
- bit_api_base: str 比特浏览器 API 地址(可选)
|
||||
"""
|
||||
job_title = params.get("job_title", "相关岗位")
|
||||
max_greet = params.get("max_greet", 5)
|
||||
account_name = params.get("account_name", "")
|
||||
account_id = params.get("account_id", "")
|
||||
bit_api_base = params.get("bit_api_base", "http://127.0.0.1:54345")
|
||||
|
||||
await progress_cb(task_id, "正在打开比特浏览器...")
|
||||
|
||||
# 在线程池中执行同步的浏览器操作(DrissionPage 是同步库)
|
||||
result = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
self._run_sync,
|
||||
task_id, job_title, max_greet, account_name, account_id, bit_api_base, progress_cb,
|
||||
)
|
||||
return result
|
||||
|
||||
def _run_sync(
|
||||
self,
|
||||
task_id: str,
|
||||
job_title: str,
|
||||
max_greet: int,
|
||||
account_name: str,
|
||||
account_id: str,
|
||||
bit_api_base: str,
|
||||
progress_cb: Callable,
|
||||
) -> dict:
|
||||
"""同步执行浏览器自动化(在线程池中运行)。"""
|
||||
# 1. 打开比特浏览器
|
||||
bit_api = BitBrowserAPI(bit_api_base)
|
||||
addr, port = bit_api.get_browser_for_drission(
|
||||
browser_id=account_id or None,
|
||||
name=account_name or None,
|
||||
)
|
||||
self.logger.info("已打开浏览器, CDP: %s (port=%d)", addr, port)
|
||||
|
||||
# 2. 连接浏览器
|
||||
browser = connect_browser(port=port)
|
||||
tab = browser.latest_tab
|
||||
|
||||
# 3. 打开 BOSS 直聘聊天页
|
||||
tab.get(CHAT_INDEX_URL)
|
||||
tab.wait.load_start()
|
||||
human_delay(2.5, 4.0)
|
||||
|
||||
# 4. 执行招聘流程
|
||||
collected = self._recruit_flow(tab, job_title, max_greet)
|
||||
|
||||
return {
|
||||
"job_title": job_title,
|
||||
"total_processed": len(collected),
|
||||
"wechat_collected": sum(1 for c in collected if c.get("wechat")),
|
||||
"details": collected,
|
||||
}
|
||||
|
||||
def _recruit_flow(self, tab, job_title: str, max_greet: int) -> List[dict]:
|
||||
"""核心招聘流程:遍历聊天列表,打招呼、询问微信号、收集结果。"""
|
||||
greeting = f"您好,我们正在招【{job_title}】,看到您的经历比较匹配,方便简单聊聊吗?"
|
||||
ask_wechat = "后续沟通会更及时,您方便留一下您的微信号吗?我这边加您。"
|
||||
collected = []
|
||||
|
||||
# 获取左侧会话列表
|
||||
items = self._get_conversation_items(tab)
|
||||
if not items:
|
||||
self.logger.warning("未找到会话列表元素")
|
||||
return collected
|
||||
|
||||
total = min(len(items), max_greet)
|
||||
self.logger.info("会话数约 %d,本次处理前 %d 个", len(items), total)
|
||||
|
||||
for i in range(total):
|
||||
try:
|
||||
human_delay(max(1.8, BETWEEN_CHAT_DELAY - 0.7), BETWEEN_CHAT_DELAY + 1.0)
|
||||
|
||||
items = self._get_conversation_items(tab)
|
||||
if i >= len(items):
|
||||
break
|
||||
item = items[i]
|
||||
|
||||
# 获取候选人名称
|
||||
name = self._get_candidate_name(item, i)
|
||||
|
||||
# 点击进入聊天
|
||||
self._click_conversation(tab, item)
|
||||
human_delay(1.2, 2.2)
|
||||
|
||||
# 等待输入框
|
||||
inp = self._wait_for_input(tab)
|
||||
if not inp:
|
||||
self.logger.info("[%s] 未进入聊天,跳过", name)
|
||||
continue
|
||||
|
||||
# 分析聊天上下文
|
||||
messages = self._get_chat_messages(tab)
|
||||
ctx = self._analyze_context(messages, job_title)
|
||||
|
||||
# 发招呼
|
||||
if not ctx["already_greeting"]:
|
||||
if self._send_message(tab, inp, greeting):
|
||||
self.logger.info("[%s] 已发送招呼", name)
|
||||
human_delay(1.5, 2.8)
|
||||
else:
|
||||
self.logger.info("[%s] 已有招呼记录,跳过", name)
|
||||
|
||||
# 询问微信号
|
||||
if not ctx["already_asked_wechat"]:
|
||||
if self._send_message(tab, inp, ask_wechat):
|
||||
self.logger.info("[%s] 已询问微信号", name)
|
||||
human_delay(1.5, 2.8)
|
||||
|
||||
# 收集微信号
|
||||
human_delay(1.0, 2.0)
|
||||
messages = self._get_chat_messages(tab)
|
||||
ctx = self._analyze_context(messages, job_title)
|
||||
wechats = ctx["wechats"][:2]
|
||||
|
||||
collected.append({
|
||||
"name": name,
|
||||
"job": job_title,
|
||||
"wechat": wechats[0] if wechats else "",
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("处理第 %d 个会话出错: %s", i + 1, e)
|
||||
continue
|
||||
|
||||
return collected
|
||||
|
||||
# ─── 辅助方法 ───
|
||||
|
||||
def _get_conversation_items(self, tab) -> list:
|
||||
selectors = [
|
||||
"css:div.chat-container div.geek-item",
|
||||
"css:div[role='listitem'] div.geek-item",
|
||||
"css:.geek-item-wrap .geek-item",
|
||||
"css:div.geek-item",
|
||||
"css:div[role='listitem']",
|
||||
]
|
||||
return find_elements(tab, selectors, timeout=3)
|
||||
|
||||
def _get_candidate_name(self, item, index: int) -> str:
|
||||
try:
|
||||
name_el = item.ele("css:.geek-name", timeout=1)
|
||||
if name_el and name_el.text:
|
||||
return name_el.text.strip()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if item.text:
|
||||
return item.text.strip()[:20]
|
||||
except Exception:
|
||||
pass
|
||||
return f"候选人{index + 1}"
|
||||
|
||||
def _click_conversation(self, tab, item) -> bool:
|
||||
if safe_click(tab, item):
|
||||
return True
|
||||
try:
|
||||
parent = item.parent()
|
||||
if parent and parent.attr("role") == "listitem":
|
||||
return safe_click(tab, parent)
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _wait_for_input(self, tab, retries: int = 6):
|
||||
for _ in range(retries):
|
||||
inp = find_element(tab, [
|
||||
"css:#boss-chat-editor-input",
|
||||
"css:.boss-chat-editor-input",
|
||||
], timeout=1)
|
||||
if inp:
|
||||
return inp
|
||||
human_delay(0.5, 0.9)
|
||||
return None
|
||||
|
||||
def _send_message(self, tab, inp, message: str) -> bool:
|
||||
try:
|
||||
human_click(tab, inp)
|
||||
human_delay(0.25, 0.55)
|
||||
try:
|
||||
inp.clear()
|
||||
except Exception:
|
||||
pass
|
||||
human_delay(0.15, 0.4)
|
||||
inp.input(message)
|
||||
human_delay(0.4, 0.9)
|
||||
if self._click_send_button(tab):
|
||||
return True
|
||||
inp.input("\n")
|
||||
human_delay(0.35, 0.7)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _click_send_button(self, tab) -> bool:
|
||||
human_delay(0.2, 0.45)
|
||||
btn = find_element(tab, [
|
||||
"css:.conversation-editor .submit-content .submit.active",
|
||||
"css:.conversation-editor div.submit.active",
|
||||
"css:div.submit-content div.submit.active",
|
||||
"css:div.submit.active",
|
||||
"css:.submit-content .submit",
|
||||
"css:div.submit",
|
||||
"text:发送",
|
||||
], timeout=1)
|
||||
if btn:
|
||||
if safe_click(tab, btn):
|
||||
return True
|
||||
# JS 兜底
|
||||
scripts = [
|
||||
"var el = document.querySelector('.conversation-editor .submit.active') || document.querySelector('.conversation-editor .submit'); if(el){ el.click(); return true; } return false;",
|
||||
"var els = document.querySelectorAll('div[class*=\"submit\"]'); for(var i=0;i<els.length;i++){ if(els[i].textContent.trim()==='发送'){ els[i].click(); return true; } } return false;",
|
||||
]
|
||||
for script in scripts:
|
||||
try:
|
||||
if tab.run_js(script) is True:
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
return False
|
||||
|
||||
def _get_chat_messages(self, tab) -> List[dict]:
|
||||
result = []
|
||||
try:
|
||||
items = tab.eles("css:.message-item", timeout=2)
|
||||
if not items:
|
||||
return result
|
||||
for e in items[-50:]:
|
||||
t = (e.text or "").strip()
|
||||
if not t or "沟通的职位" in t:
|
||||
continue
|
||||
role = "friend"
|
||||
try:
|
||||
if e.ele("css:.item-boss", timeout=0):
|
||||
role = "boss"
|
||||
elif e.ele("css:.item-friend", timeout=0):
|
||||
role = "friend"
|
||||
else:
|
||||
cls = (e.attr("class") or "") + " "
|
||||
if "item-boss" in cls or ("boss" in cls and "friend" not in cls):
|
||||
role = "boss"
|
||||
except Exception:
|
||||
if any(k in t for k in ("招", "岗位", "微信号", "方便留", "加您")):
|
||||
role = "boss"
|
||||
result.append({"role": role, "text": t})
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
def _analyze_context(self, messages: list, job_title: str) -> dict:
|
||||
boss_texts = [m["text"] for m in messages if m.get("role") == "boss"]
|
||||
friend_texts = [m["text"] for m in messages if m.get("role") == "friend"]
|
||||
full_boss = " ".join(boss_texts)
|
||||
|
||||
wechats = []
|
||||
for t in friend_texts:
|
||||
wechats.extend(self._extract_wechat(t))
|
||||
wechats = list(dict.fromkeys(wechats))[:3]
|
||||
|
||||
return {
|
||||
"already_greeting": job_title in full_boss or "招" in full_boss,
|
||||
"already_asked_wechat": "微信" in full_boss or "微信号" in full_boss,
|
||||
"wechats": wechats,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _extract_wechat(text: str) -> list:
|
||||
if not text or not text.strip():
|
||||
return []
|
||||
found = []
|
||||
patterns = [
|
||||
r"微信号[::\s]*([a-zA-Z0-9_\-]{6,20})",
|
||||
r"微信[::\s]*([a-zA-Z0-9_\-]{6,20})",
|
||||
r"wx[::\s]*([a-zA-Z0-9_\-]{6,20})",
|
||||
r"wechat[::\s]*([a-zA-Z0-9_\-]{6,20})",
|
||||
r"([a-zA-Z][a-zA-Z0-9_\-]{5,19})",
|
||||
]
|
||||
for p in patterns:
|
||||
for m in re.finditer(p, text, re.IGNORECASE):
|
||||
s = m.group(1).strip() if m.lastindex else m.group(0).strip()
|
||||
if s and s not in found and len(s) >= 6:
|
||||
found.append(s)
|
||||
return found[:3]
|
||||
44
worker/tasks/registry.py
Normal file
44
worker/tasks/registry.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
任务处理器注册表。
|
||||
Worker 启动时注册所有可用的 handler,收到任务时按 task_type 查找对应 handler。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Dict, Optional, Type
|
||||
|
||||
from worker.tasks.base import BaseTaskHandler
|
||||
|
||||
logger = logging.getLogger("worker.tasks.registry")
|
||||
|
||||
# task_type → handler 实例
|
||||
_registry: Dict[str, BaseTaskHandler] = {}
|
||||
|
||||
|
||||
def register_handler(handler_cls: Type[BaseTaskHandler]) -> None:
|
||||
"""注册一个任务处理器类(自动实例化)。"""
|
||||
instance = handler_cls()
|
||||
if not instance.task_type:
|
||||
raise ValueError(f"{handler_cls.__name__} 未设置 task_type")
|
||||
_registry[instance.task_type] = instance
|
||||
logger.info("注册任务处理器: %s → %s", instance.task_type, handler_cls.__name__)
|
||||
|
||||
|
||||
def get_handler(task_type: str) -> Optional[BaseTaskHandler]:
|
||||
"""根据 task_type 获取对应的处理器实例。"""
|
||||
return _registry.get(task_type)
|
||||
|
||||
|
||||
def list_handlers() -> list[str]:
|
||||
"""列出所有已注册的 task_type。"""
|
||||
return list(_registry.keys())
|
||||
|
||||
|
||||
def register_all_handlers() -> None:
|
||||
"""注册所有内置任务处理器。在此函数中 import 并注册。"""
|
||||
from worker.tasks.boss_recruit import BossRecruitHandler
|
||||
register_handler(BossRecruitHandler)
|
||||
# 未来扩展:在此处添加新的 handler
|
||||
# from worker.tasks.xxx import XxxHandler
|
||||
# register_handler(XxxHandler)
|
||||
204
worker/ws_client.py
Normal file
204
worker/ws_client.py
Normal file
@@ -0,0 +1,204 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Worker WebSocket 客户端。
|
||||
负责:连接服务器、注册、心跳、接收任务、上报进度/结果、断线重连。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
import websockets
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
|
||||
from common.protocol import MsgType, TaskStatus, make_msg
|
||||
from worker import config
|
||||
from worker.bit_browser import BitBrowserAPI
|
||||
from worker.tasks.registry import get_handler
|
||||
|
||||
logger = logging.getLogger("worker.ws_client")
|
||||
|
||||
|
||||
class WorkerWSClient:
|
||||
"""Worker WebSocket 客户端。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_url: str,
|
||||
worker_id: str,
|
||||
worker_name: str,
|
||||
bit_api_base: str,
|
||||
) -> None:
|
||||
self.server_url = server_url
|
||||
self.worker_id = worker_id
|
||||
self.worker_name = worker_name
|
||||
self.bit_api = BitBrowserAPI(bit_api_base)
|
||||
self._ws: Optional[websockets.WebSocketClientProtocol] = None
|
||||
self._running = False
|
||||
self._reconnect_delay = config.RECONNECT_DELAY
|
||||
|
||||
# ────────────────────────── 主循环 ──────────────────────────
|
||||
|
||||
async def run(self) -> None:
|
||||
"""主循环:连接 → 注册 → 收发消息;断线则重连。"""
|
||||
self._running = True
|
||||
while self._running:
|
||||
try:
|
||||
await self._connect_and_loop()
|
||||
except Exception as e:
|
||||
logger.error("连接异常: %s", e)
|
||||
if self._running:
|
||||
logger.info("将在 %d 秒后重连...", self._reconnect_delay)
|
||||
await asyncio.sleep(self._reconnect_delay)
|
||||
# 指数退避
|
||||
self._reconnect_delay = min(
|
||||
self._reconnect_delay * 2,
|
||||
config.RECONNECT_MAX_DELAY,
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._running = False
|
||||
if self._ws:
|
||||
await self._ws.close()
|
||||
|
||||
# ────────────────────────── 连接流程 ──────────────────────────
|
||||
|
||||
async def _connect_and_loop(self) -> None:
|
||||
logger.info("正在连接服务器: %s", self.server_url)
|
||||
async with websockets.connect(self.server_url) as ws:
|
||||
self._ws = ws
|
||||
self._reconnect_delay = config.RECONNECT_DELAY # 连接成功重置退避
|
||||
logger.info("WebSocket 已连接")
|
||||
|
||||
# 注册
|
||||
await self._register(ws)
|
||||
|
||||
# 启动心跳协程
|
||||
heartbeat_task = asyncio.create_task(self._heartbeat_loop(ws))
|
||||
|
||||
try:
|
||||
# 消息接收循环
|
||||
async for raw in ws:
|
||||
data = json.loads(raw)
|
||||
await self._handle_message(ws, data)
|
||||
except ConnectionClosed as e:
|
||||
logger.warning("WebSocket 连接关闭: %s", e)
|
||||
finally:
|
||||
heartbeat_task.cancel()
|
||||
self._ws = None
|
||||
|
||||
async def _register(self, ws) -> None:
|
||||
"""发送注册消息。"""
|
||||
browsers = self._fetch_browser_list()
|
||||
msg = make_msg(
|
||||
MsgType.REGISTER,
|
||||
worker_id=self.worker_id,
|
||||
worker_name=self.worker_name,
|
||||
browsers=browsers,
|
||||
)
|
||||
await ws.send(json.dumps(msg))
|
||||
# 等待 ACK
|
||||
ack_raw = await asyncio.wait_for(ws.recv(), timeout=10)
|
||||
ack = json.loads(ack_raw)
|
||||
if ack.get("type") == MsgType.REGISTER_ACK.value:
|
||||
logger.info("注册成功: worker_id=%s", self.worker_id)
|
||||
else:
|
||||
logger.error("注册失败: %s", ack)
|
||||
raise RuntimeError(f"注册失败: {ack}")
|
||||
|
||||
# ────────────────────────── 心跳 ──────────────────────────
|
||||
|
||||
async def _heartbeat_loop(self, ws) -> None:
|
||||
"""定期发送心跳。"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(config.HEARTBEAT_INTERVAL)
|
||||
msg = make_msg(MsgType.HEARTBEAT, worker_id=self.worker_id)
|
||||
await ws.send(json.dumps(msg))
|
||||
logger.debug("心跳已发送")
|
||||
except Exception:
|
||||
break
|
||||
|
||||
# ────────────────────────── 消息处理 ──────────────────────────
|
||||
|
||||
async def _handle_message(self, ws, data: dict) -> None:
|
||||
msg_type = data.get("type", "")
|
||||
|
||||
if msg_type == MsgType.HEARTBEAT_ACK.value:
|
||||
logger.debug("收到心跳 ACK")
|
||||
|
||||
elif msg_type == MsgType.TASK_ASSIGN.value:
|
||||
await self._handle_task(ws, data)
|
||||
|
||||
elif msg_type == MsgType.TASK_CANCEL.value:
|
||||
task_id = data.get("task_id", "")
|
||||
logger.info("收到任务取消: %s(暂不支持中途取消)", task_id)
|
||||
|
||||
elif msg_type == MsgType.ERROR.value:
|
||||
logger.error("服务器错误: %s", data.get("detail", ""))
|
||||
|
||||
else:
|
||||
logger.warning("未知消息: %s", msg_type)
|
||||
|
||||
async def _handle_task(self, ws, data: dict) -> None:
|
||||
"""接收并执行任务。"""
|
||||
task_id = data.get("task_id", "")
|
||||
task_type = data.get("task_type", "")
|
||||
account_name = data.get("account_name", "")
|
||||
params = data.get("params", {})
|
||||
|
||||
# 将 account_name 注入 params(供 handler 使用)
|
||||
if account_name:
|
||||
params.setdefault("account_name", account_name)
|
||||
params.setdefault("bit_api_base", self.bit_api.base_url)
|
||||
|
||||
logger.info("收到任务: %s (type=%s)", task_id, task_type)
|
||||
|
||||
handler = get_handler(task_type)
|
||||
if not handler:
|
||||
error_msg = f"不支持的任务类型: {task_type}"
|
||||
logger.error(error_msg)
|
||||
await self._send_result(ws, task_id, error=error_msg)
|
||||
return
|
||||
|
||||
# 上报进度的回调
|
||||
async def progress_cb(tid: str, progress: str):
|
||||
msg = make_msg(MsgType.TASK_PROGRESS, task_id=tid, progress=progress)
|
||||
try:
|
||||
await ws.send(json.dumps(msg))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 执行任务
|
||||
try:
|
||||
result = await handler.execute(task_id, params, progress_cb)
|
||||
await self._send_result(ws, task_id, result=result)
|
||||
except Exception as e:
|
||||
logger.error("任务 %s 执行失败: %s", task_id, e, exc_info=True)
|
||||
await self._send_result(ws, task_id, error=str(e))
|
||||
|
||||
async def _send_result(self, ws, task_id: str, result=None, error: str = None) -> None:
|
||||
"""上报任务最终结果。"""
|
||||
msg = make_msg(MsgType.TASK_RESULT, task_id=task_id, result=result, error=error)
|
||||
try:
|
||||
await ws.send(json.dumps(msg))
|
||||
logger.info("任务 %s 结果已上报 (error=%s)", task_id, error)
|
||||
except Exception as e:
|
||||
logger.error("上报结果失败: %s", e)
|
||||
|
||||
# ────────────────────────── 比特浏览器列表 ──────────────────────────
|
||||
|
||||
def _fetch_browser_list(self) -> List[dict]:
|
||||
"""获取本机比特浏览器窗口列表(用于注册时上报)。"""
|
||||
try:
|
||||
items = self.bit_api.list_browsers()
|
||||
return [
|
||||
{"id": b.get("id", ""), "name": b.get("name", ""), "remark": b.get("remark", "")}
|
||||
for b in items
|
||||
]
|
||||
except Exception as e:
|
||||
logger.warning("获取比特浏览器列表失败(比特浏览器可能未启动): %s", e)
|
||||
return []
|
||||
Reference in New Issue
Block a user