230 lines
9.3 KiB
Python
230 lines
9.3 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
Django Channels WebSocket Consumer:处理 Worker 连接。
|
||
"""
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
|
||
from channels.generic.websocket import AsyncWebsocketConsumer
|
||
|
||
from common.protocol import MsgType, TaskStatus, TaskType, make_msg
|
||
from server.core.worker_manager import worker_manager
|
||
from server.core.task_dispatcher import task_dispatcher
|
||
|
||
# Named logger for the worker WebSocket endpoint, shared by every consumer instance.
logger: logging.Logger = logging.getLogger(name="server.ws")
|
||
|
||
|
||
class WorkerConsumer(AsyncWebsocketConsumer):
    """WebSocket consumer serving a single remote Worker connection.

    Connection lifecycle:

    1. Accept the socket and start a registration watchdog.
    2. Require the first message to be ``register``; any other first message
       closes the socket with code 4001, an empty ``worker_id`` with 4002.
    3. Once registered, handle heartbeats, browser-list updates, task
       progress and task results.
    4. On disconnect, unregister the worker from ``worker_manager``.

    A connection that does not register within ``REGISTER_TIMEOUT_SECONDS``
    is closed with code 4003.

    All Django ORM work lives in the synchronous helpers at the bottom of the
    class and is executed through ``_run_db`` so it never runs on the
    event-loop thread (Django raises ``SynchronousOnlyOperation`` for ORM
    calls made directly from async code).
    """

    # Seconds a fresh connection may stay open without sending ``register``.
    REGISTER_TIMEOUT_SECONDS = 30

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.worker_id = None
        self._registered = False
        # asyncio.Task running _register_watchdog(); a reference is kept so the
        # task can be cancelled and is not garbage-collected while pending.
        self._register_timeout = None

    # ────────────────────────── connection lifecycle ──────────────────────────

    async def connect(self):
        """Accept the socket and start the registration watchdog."""
        await self.accept()
        self._register_timeout = asyncio.ensure_future(self._register_watchdog())

    async def _register_watchdog(self):
        """Sleep for the registration window, then close if still unregistered."""
        await asyncio.sleep(self.REGISTER_TIMEOUT_SECONDS)
        await self._timeout_close()

    async def _timeout_close(self):
        """Close the socket with code 4003 unless the worker has registered."""
        if not self._registered:
            logger.warning(
                "WebSocket 连接超时(未在 %s 秒内注册)", self.REGISTER_TIMEOUT_SECONDS
            )
            await self.close(code=4003)

    def _cancel_register_timeout(self):
        """Cancel the registration watchdog if it is still pending."""
        if self._register_timeout is not None:
            self._register_timeout.cancel()
            self._register_timeout = None

    async def disconnect(self, close_code):
        """Stop the watchdog and unregister the worker if it had registered."""
        self._cancel_register_timeout()
        # Only a connection that completed registration owns its worker_id.
        # NOTE(review): if the same worker reconnects before this old socket
        # disconnects, this may unregister the newer connection — confirm how
        # worker_manager.unregister handles that case.
        if self._registered and self.worker_id:
            worker_manager.unregister(self.worker_id)
            logger.info("Worker %s WebSocket 断开", self.worker_id)

    async def _send_json(self, data: dict):
        """Send *data* as a JSON text frame (stored by worker_manager as send_fn)."""
        await self.send(text_data=json.dumps(data, ensure_ascii=False))

    async def _run_db(self, func, *args):
        """Run the synchronous ORM helper *func* off the event loop; return its result."""
        # Local import, matching how this file imports server.models lazily.
        from channels.db import database_sync_to_async

        return await database_sync_to_async(func)(*args)

    # ────────────────────────── message routing ──────────────────────────

    async def receive(self, text_data=None, bytes_data=None):
        """Decode one text frame and route it by its ``type`` field.

        Binary frames are ignored. Frames that are not a JSON object are
        logged and dropped without closing the connection.
        """
        if text_data is None:
            return
        try:
            data = json.loads(text_data)
        except json.JSONDecodeError:
            logger.warning("收到无法解析的 JSON 消息,已忽略 (from %s)", self.worker_id)
            return
        # json.loads can also return a list/str/number; every handler needs a dict.
        if not isinstance(data, dict):
            logger.warning("收到非对象 JSON 消息,已忽略 (from %s)", self.worker_id)
            return

        msg_type = data.get("type", "")

        # Before registration only a register message is accepted.
        if not self._registered:
            await self._handle_register(msg_type, data)
            return

        if msg_type == MsgType.HEARTBEAT.value:
            await self._handle_heartbeat(data)
        elif msg_type == MsgType.BROWSER_LIST_UPDATE.value:
            await self._handle_browser_list_update(data)
        elif msg_type == MsgType.TASK_PROGRESS.value:
            await self._handle_task_progress(data)
        elif msg_type == MsgType.TASK_RESULT.value:
            await self._handle_task_result(data)
        else:
            logger.warning("未知消息类型: %s (from %s)", msg_type, self.worker_id)

    async def _handle_register(self, msg_type, data):
        """Validate the first message and register the worker with worker_manager."""
        if msg_type != MsgType.REGISTER.value:
            await self._send_json(make_msg(MsgType.ERROR, detail="首条消息必须是 register"))
            await self.close(code=4001)
            return

        worker_id = data.get("worker_id", "")
        if not worker_id:
            await self._send_json(make_msg(MsgType.ERROR, detail="worker_id 不能为空"))
            await self.close(code=4002)
            return

        self.worker_id = worker_id
        worker_name = data.get("worker_name", worker_id)
        browsers = data.get("browsers", [])

        worker_manager.register(self._send_json, self.worker_id, worker_name, browsers)
        self._registered = True
        self._cancel_register_timeout()
        await self._send_json(make_msg(MsgType.REGISTER_ACK, worker_id=self.worker_id))
        logger.info("Worker %s 已连接", self.worker_id)

    async def _handle_heartbeat(self, data):
        """Record the heartbeat and acknowledge it."""
        worker_manager.heartbeat(self.worker_id)
        await self._send_json(make_msg(MsgType.HEARTBEAT_ACK))

    async def _handle_browser_list_update(self, data):
        """Store the worker's new browser list and sync BossAccount names (best-effort)."""
        browsers = data.get("browsers", [])
        worker_manager.update_browsers(self.worker_id, browsers)
        try:
            await self._run_db(self._sync_boss_account_browsers, self.worker_id, browsers)
        except Exception as sync_err:
            logger.warning("同步 BossAccount 环境失败: %s", sync_err)

    async def _handle_task_progress(self, data):
        """Record task progress and mark the linked account's task as running."""
        task_id = data.get("task_id", "")
        progress = data.get("progress", "")
        task_dispatcher.update_progress(task_id, progress)
        logger.info("任务 %s 进度: %s", task_id, progress)
        # Best-effort: a failed status update must not break the session.
        try:
            await self._run_db(
                self._update_account_task_status, task_id, TaskStatus.RUNNING.value
            )
        except Exception as db_err:
            logger.warning("任务 %s 更新账号任务状态失败: %s", task_id, db_err)

    async def _handle_task_result(self, data):
        """Complete the task in memory, then persist its outcome to the database."""
        task_id = data.get("task_id", "")
        result = data.get("result")
        error = data.get("error")
        task_dispatcher.complete_task(task_id, result=result, error=error)
        worker_manager.set_current_task(self.worker_id, None)
        logger.info("任务 %s 已完成", task_id)

        # Persistence is best-effort: failures are logged with a traceback and
        # do not break the WebSocket session.
        try:
            task_info = task_dispatcher.get_task(task_id)
            if not task_info:
                return
            final_status = self._enum_value(task_info.status)
            await self._run_db(
                self._save_task_log, task_id, task_info, result, error, final_status
            )
            await self._run_db(self._update_account_task_status, task_id, final_status)
            # A successful check_login result also carries the account login state.
            if (
                self._enum_value(task_info.task_type) == TaskType.CHECK_LOGIN.value
                and result
                and not error
            ):
                await self._run_db(self._upsert_account_status, result)
        except Exception as db_err:
            logger.exception("任务 %s 写入数据库失败: %s", task_id, db_err)

    # ────────────────────────── 数据库操作(同步) ──────────────────────────
    # Synchronous ORM helpers; call them only via _run_db from async code.

    @staticmethod
    def _enum_value(value):
        """Return ``value.value`` for enum-like objects, otherwise ``str(value)``."""
        return value.value if hasattr(value, "value") else str(value)

    @staticmethod
    def _save_task_log(task_id, task_info, result, error, final_status):
        """Create or update the TaskLog row for *task_id*."""
        from server.models import TaskLog

        TaskLog.objects.update_or_create(
            task_id=task_id,
            defaults={
                "task_type": WorkerConsumer._enum_value(task_info.task_type),
                "worker_id": task_info.worker_id or "",
                "status": final_status,
                "params": task_info.params,
                "result": result,
                "error": error,
            },
        )

    @staticmethod
    def _update_account_task_status(task_id, task_status):
        """Set current_task_status on every BossAccount whose current task is *task_id*."""
        from server.models import BossAccount

        BossAccount.objects.filter(current_task_id=task_id).update(
            current_task_status=task_status
        )

    def _upsert_account_status(self, result):
        """Update (or create) this worker's BossAccount from a check_login *result*.

        The account is matched by ``browser_name`` first, then ``browser_id``,
        always scoped to ``self.worker_id``.
        """
        from server.models import BossAccount
        from django.utils import timezone as tz

        if not isinstance(result, dict):
            logger.warning("check_login 结果格式异常,跳过账号状态更新: %r", result)
            return

        browser_id = result.get("browser_id", "")
        browser_name = result.get("browser_name", "")
        boss_username = result.get("boss_username", "")
        boss_id = result.get("boss_id", "")
        is_logged_in = result.get("is_logged_in", False)

        # Prefer matching on worker_id + browser_name, fall back to browser_id.
        account = None
        if browser_name:
            account = BossAccount.objects.filter(
                worker_id=self.worker_id, browser_name=browser_name,
            ).first()
        if not account and browser_id:
            account = BossAccount.objects.filter(
                worker_id=self.worker_id, browser_id=browser_id,
            ).first()

        now = tz.now()
        if account:
            account.browser_id = browser_id or account.browser_id
            account.browser_name = browser_name or account.browser_name
            account.boss_username = boss_username
            # Keep a previously known boss_id when the result omits it.
            if boss_id:
                account.boss_id = boss_id
            account.is_logged_in = is_logged_in
            account.checked_at = now
            account.save()
        else:
            BossAccount.objects.create(
                worker_id=self.worker_id,
                # Placeholder id when the worker reported only a browser name.
                browser_id=browser_id or f"name:{browser_name}",
                browser_name=browser_name,
                boss_username=boss_username,
                boss_id=boss_id,
                is_logged_in=is_logged_in,
                checked_at=now,
            )
        logger.info(
            "账号状态更新: worker=%s, browser=%s(%s), username=%s, boss_id=%s, logged_in=%s",
            self.worker_id, browser_name, browser_id, boss_username, boss_id, is_logged_in,
        )

    @staticmethod
    def _sync_boss_account_browsers(worker_id: str, browsers: list) -> None:
        """Rename this worker's BossAccounts to match the reported browser list.

        Accounts are matched by ``browser_id`` against each browser entry's
        ``id``; only ``browser_name`` is updated, and only when it changed.
        """
        from server.models import BossAccount

        if not browsers:
            return
        # Skip malformed entries (non-dicts or missing id) instead of failing.
        browser_map = {
            str(b.get("id", "")).strip(): b
            for b in browsers
            if isinstance(b, dict) and b.get("id")
        }
        if not browser_map:
            return

        updated = 0
        for acc in BossAccount.objects.filter(worker_id=worker_id):
            bid = (acc.browser_id or "").strip()
            if not bid:
                continue
            new_info = browser_map.get(bid)
            if not new_info:
                continue
            new_name = str(new_info.get("name") or "").strip()
            if new_name and new_name != acc.browser_name:
                acc.browser_name = new_name
                acc.save(update_fields=["browser_name"])
                updated += 1
        if updated:
            logger.info("BossAccount 环境同步: worker=%s, 更新 %d 条", worker_id, updated)