2026-02-14 16:50:02 +08:00
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
"""
|
|
|
|
|
|
Django Channels WebSocket Consumer:处理 Worker 连接。
|
|
|
|
|
|
"""
|
|
|
|
|
|
import asyncio
|
|
|
|
|
|
import json
|
|
|
|
|
|
import logging
|
|
|
|
|
|
|
|
|
|
|
|
from channels.generic.websocket import AsyncWebsocketConsumer
|
|
|
|
|
|
|
|
|
|
|
|
from common.protocol import MsgType, TaskStatus, TaskType, make_msg
|
|
|
|
|
|
from server.core.worker_manager import worker_manager
|
|
|
|
|
|
from server.core.task_dispatcher import task_dispatcher
|
|
|
|
|
|
|
|
|
|
|
|
# Logger for the Worker WebSocket channel, addressed by the explicit name "server.ws".
logger = logging.getLogger(name="server.ws")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class WorkerConsumer(AsyncWebsocketConsumer):
    """WebSocket consumer serving a single Worker connection.

    Connection lifecycle:

    1. Accept the connection and start a registration watchdog.
    2. Wait for the first message, which must be ``register``. Any other
       first message closes the socket with code 4001; an empty
       ``worker_id`` closes it with code 4002; no registration within
       ``register_timeout_seconds`` closes it with code 4003.
    3. Keep exchanging messages (heartbeats, browser list updates, task
       progress, task results).
    4. Unregister the worker from ``worker_manager`` on disconnect.

    All Django ORM access goes through :meth:`_run_db`. The ORM is
    synchronous-only and raises ``SynchronousOnlyOperation`` when called
    directly from the event loop.
    """

    # Seconds a freshly connected socket may stay unregistered before it is
    # closed with code 4003.
    register_timeout_seconds = 30

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.worker_id = None
        self._registered = False
        # Watchdog task that closes the socket if no register message arrives.
        self._register_timeout = None

    # ───────────────────────── connection lifecycle ─────────────────────────

    async def connect(self):
        """Accept the socket and arm the registration watchdog."""
        await self.accept()
        # Hold a reference to the task so it is not garbage-collected before
        # it fires, and so it can be cancelled once the worker registers.
        self._register_timeout = asyncio.create_task(self._register_watchdog())

    async def _register_watchdog(self):
        """Wait out the registration window, then close if still unregistered."""
        await asyncio.sleep(self.register_timeout_seconds)
        await self._timeout_close()

    async def _timeout_close(self):
        """Close the socket with code 4003 if the worker has not registered."""
        if not self._registered:
            logger.warning(
                "WebSocket 连接超时(未在 %s 秒内注册)", self.register_timeout_seconds
            )
            await self.close(code=4003)

    def _cancel_register_timeout(self):
        """Cancel the registration watchdog if one is pending."""
        if self._register_timeout is not None:
            self._register_timeout.cancel()
            self._register_timeout = None

    async def disconnect(self, close_code):
        """Stop the watchdog and unregister the worker if this socket registered it."""
        self._cancel_register_timeout()
        if self._registered:
            # NOTE(review): if the same worker_id reconnected on a new socket
            # before this one closed, this call may remove the newer
            # registration, depending on worker_manager.unregister — confirm.
            worker_manager.unregister(self.worker_id)
            logger.info("Worker %s WebSocket 断开", self.worker_id)

    async def _send_json(self, data: dict):
        """Send *data* as a JSON text frame (stored by worker_manager as send_fn)."""
        await self.send(text_data=json.dumps(data, ensure_ascii=False))

    # ─────────────────────────── message handling ───────────────────────────

    async def receive(self, text_data=None, bytes_data=None):
        """Decode one incoming frame and dispatch it on its ``type`` field.

        Binary frames are ignored, and undecodable JSON is logged and ignored.
        Before registration only ``register`` is accepted. Afterwards,
        heartbeat, browser-list, task-progress and task-result messages are
        handled; anything else is logged as unknown.
        """
        if text_data is None:
            return
        try:
            data = json.loads(text_data)
        except json.JSONDecodeError:
            logger.warning("收到无法解析的 JSON 消息 (from %s)", self.worker_id)
            return
        if not isinstance(data, dict):
            # A JSON array/scalar carries no ``type``. Treat it as an empty
            # message instead of crashing on ``data.get``.
            logger.warning("消息格式错误,应为 JSON 对象 (from %s)", self.worker_id)
            data = {}

        msg_type = data.get("type", "")

        # ── Unregistered: only ``register`` is accepted ──
        if not self._registered:
            await self._handle_register(msg_type, data)
            return

        # ── Registered: dispatch by message type ──
        if msg_type == MsgType.HEARTBEAT.value:
            await self._handle_heartbeat()
        elif msg_type == MsgType.BROWSER_LIST_UPDATE.value:
            await self._handle_browser_list_update(data)
        elif msg_type == MsgType.TASK_PROGRESS.value:
            await self._handle_task_progress(data)
        elif msg_type == MsgType.TASK_RESULT.value:
            await self._handle_task_result(data)
        else:
            logger.warning("未知消息类型: %s (from %s)", msg_type, self.worker_id)

    async def _handle_register(self, msg_type, data):
        """Validate the first message and register this worker with worker_manager."""
        if msg_type != MsgType.REGISTER.value:
            await self._send_json(make_msg(MsgType.ERROR, detail="首条消息必须是 register"))
            await self.close(code=4001)
            return

        worker_id = data.get("worker_id", "")
        if not worker_id:
            await self._send_json(make_msg(MsgType.ERROR, detail="worker_id 不能为空"))
            await self.close(code=4002)
            return

        worker_name = data.get("worker_name", worker_id)
        browsers = data.get("browsers", [])

        worker_manager.register(self._send_json, worker_id, worker_name, browsers)
        # Record the id only after a successful register, so disconnect never
        # unregisters an id this socket did not actually register.
        self.worker_id = worker_id
        self._registered = True
        self._cancel_register_timeout()
        await self._send_json(make_msg(MsgType.REGISTER_ACK, worker_id=worker_id))
        logger.info("Worker %s 已连接", worker_id)

    async def _handle_heartbeat(self):
        """Refresh the worker's heartbeat and acknowledge it."""
        worker_manager.heartbeat(self.worker_id)
        await self._send_json(make_msg(MsgType.HEARTBEAT_ACK))

    async def _handle_browser_list_update(self, data):
        """Store the reported browser list, then sync browser names into BossAccount."""
        browsers = data.get("browsers", [])
        worker_manager.update_browsers(self.worker_id, browsers)
        try:
            await self._run_db(self._sync_boss_account_browsers, self.worker_id, browsers)
        except Exception as sync_err:
            logger.warning("同步 BossAccount 环境失败: %s", sync_err)

    async def _handle_task_progress(self, data):
        """Record task progress and mark the linked account task as running."""
        task_id = data.get("task_id", "")
        progress = data.get("progress", "")
        task_dispatcher.update_progress(task_id, progress)
        logger.info("任务 %s 进度: %s", task_id, progress)
        # Best effort: a failed status write must not break message handling,
        # but it is logged rather than silently dropped.
        try:
            await self._run_db(
                self._update_account_task_status, task_id, TaskStatus.RUNNING.value
            )
        except Exception as db_err:
            logger.warning("任务 %s 更新账号任务状态失败: %s", task_id, db_err)

    async def _handle_task_result(self, data):
        """Complete the task in the dispatcher and persist its outcome to the database."""
        task_id = data.get("task_id", "")
        result = data.get("result")
        error = data.get("error")
        task_dispatcher.complete_task(task_id, result=result, error=error)
        worker_manager.set_current_task(self.worker_id, None)
        logger.info("任务 %s 已完成", task_id)

        # ── Write the result to the database (failures are logged, not raised) ──
        try:
            task_info = task_dispatcher.get_task(task_id)
            if task_info:
                await self._run_db(
                    self._persist_task_result, task_id, task_info, result, error
                )
        except Exception as db_err:
            logger.exception("任务 %s 写入数据库失败: %s", task_id, db_err)

    # ──────────────── Database operations (synchronous; call via _run_db) ────────────────

    @staticmethod
    async def _run_db(func, *args):
        """Run the synchronous ORM callable *func* off the event loop and return its result.

        Django's ORM is async-unsafe, and calling it from a coroutine raises
        ``SynchronousOnlyOperation``. Channels' ``database_sync_to_async``
        runs *func* in a thread and cleans up database connections around
        the call.
        """
        from channels.db import database_sync_to_async

        return await database_sync_to_async(func)(*args)

    @staticmethod
    def _enum_value(value):
        """Return ``value.value`` when it has one (e.g. an Enum member), else ``str(value)``."""
        return value.value if hasattr(value, "value") else str(value)

    def _persist_task_result(self, task_id, task_info, result, error):
        """Write the TaskLog row and the account task status for a finished task.

        For a successful ``check_login`` task, also upsert the account's
        login state. Any exception aborts the remaining writes and propagates
        to the caller.
        """
        final_status = self._enum_value(task_info.status)
        self._save_task_log(task_id, task_info, result, error, final_status)
        self._update_account_task_status(task_id, final_status)

        # check_login task: update the account's login status.
        task_type_val = self._enum_value(task_info.task_type)
        if task_type_val == TaskType.CHECK_LOGIN.value and result and not error:
            self._upsert_account_status(result)

    @staticmethod
    def _save_task_log(task_id, task_info, result, error, final_status):
        """Create or update the TaskLog row keyed by *task_id*."""
        from server.models import TaskLog

        TaskLog.objects.update_or_create(
            task_id=task_id,
            defaults={
                "task_type": task_info.task_type.value if hasattr(task_info.task_type, "value") else str(task_info.task_type),
                "worker_id": task_info.worker_id or "",
                "status": final_status,
                "params": task_info.params,
                "result": result,
                "error": error,
            },
        )

    @staticmethod
    def _update_account_task_status(task_id, task_status):
        """Set current_task_status on every BossAccount whose current_task_id is *task_id*."""
        from server.models import BossAccount

        BossAccount.objects.filter(current_task_id=task_id).update(current_task_status=task_status)

    def _upsert_account_status(self, result):
        """Update or create this worker's BossAccount from a check_login *result* dict.

        The account is matched by (worker_id, browser_name) first, then by
        (worker_id, browser_id). If *result* carries neither identifier,
        nothing is written, since the browser could not be identified.
        """
        from server.models import BossAccount
        from django.utils import timezone as tz

        browser_id = result.get("browser_id", "")
        browser_name = result.get("browser_name", "")
        boss_username = result.get("boss_username", "")
        boss_id = result.get("boss_id", "")
        is_logged_in = result.get("is_logged_in", False)

        if not browser_id and not browser_name:
            # Without either identifier the create branch would insert a
            # junk row with browser_id="name:".
            logger.warning(
                "check_login 结果缺少 browser_id/browser_name,跳过账号状态更新: worker=%s",
                self.worker_id,
            )
            return

        # Match on worker_id + browser_name first, then fall back to browser_id.
        account = None
        if browser_name:
            account = BossAccount.objects.filter(
                worker_id=self.worker_id, browser_name=browser_name,
            ).first()
        if not account and browser_id:
            account = BossAccount.objects.filter(
                worker_id=self.worker_id, browser_id=browser_id,
            ).first()

        now = tz.now()
        if account:
            account.browser_id = browser_id or account.browser_id
            account.browser_name = browser_name or account.browser_name
            account.boss_username = boss_username
            # Keep a previously known boss_id when this check did not report one.
            if boss_id:
                account.boss_id = boss_id
            account.is_logged_in = is_logged_in
            account.checked_at = now
            account.save()
        else:
            BossAccount.objects.create(
                worker_id=self.worker_id,
                browser_id=browser_id or f"name:{browser_name}",
                browser_name=browser_name,
                boss_username=boss_username,
                boss_id=boss_id,
                is_logged_in=is_logged_in,
                checked_at=now,
            )
        logger.info(
            "账号状态更新: worker=%s, browser=%s(%s), username=%s, boss_id=%s, logged_in=%s",
            self.worker_id, browser_name, browser_id, boss_username, boss_id, is_logged_in,
        )

    @staticmethod
    def _sync_boss_account_browsers(worker_id: str, browsers: list) -> None:
        """Refresh BossAccount.browser_name from the browser list reported by *worker_id*.

        Accounts are matched on their stored browser_id against each
        browser's ``id``. Only a non-empty name that differs from the stored
        one is written. Entries that are not dicts or lack an ``id`` are
        skipped.
        """
        from server.models import BossAccount

        if not browsers:
            return
        browser_map = {
            str(b.get("id", "")).strip(): b
            for b in browsers
            if isinstance(b, dict) and b.get("id")
        }
        if not browser_map:
            return

        accounts = BossAccount.objects.filter(worker_id=worker_id)
        updated = 0
        for acc in accounts:
            bid = (acc.browser_id or "").strip()
            if not bid:
                continue
            new_info = browser_map.get(bid)
            if new_info:
                new_name = (new_info.get("name") or "").strip()
                if new_name and new_name != acc.browser_name:
                    acc.browser_name = new_name
                    acc.save(update_fields=["browser_name"])
                    updated += 1
        if updated:
            logger.info("BossAccount 环境同步: worker=%s, 更新 %d 条", worker_id, updated)