Files
boss_dp/server/ws/consumers.py
ddrwode 530d7fe135 haha
2026-02-27 13:56:15 +08:00

230 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
Django Channels WebSocket Consumer处理 Worker 连接。
"""
import asyncio
import json
import logging

from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncWebsocketConsumer

from common.protocol import MsgType, TaskStatus, TaskType, make_msg
from server.core.task_dispatcher import task_dispatcher
from server.core.worker_manager import worker_manager
# Logger for the Worker WebSocket endpoint, registered under the fixed name "server.ws".
logger = logging.getLogger("server.ws")
class WorkerConsumer(AsyncWebsocketConsumer):
    """
    WebSocket consumer serving a single Worker connection.

    Connection lifecycle:
    1. Accept the connection and start a registration watchdog.
    2. Wait for the first message, which must be ``register``.
    3. Keep exchanging messages (heartbeat, task progress, task result, ...).
    4. Unregister the worker when the socket disconnects.

    Django ORM calls are synchronous and Django refuses to run them on the
    event loop (``SynchronousOnlyOperation``), so every database helper in
    this class is a plain sync function invoked through
    ``database_sync_to_async``.
    """

    # Seconds a freshly accepted connection may stay unregistered before being closed (code 4003).
    register_timeout_seconds = 30

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.worker_id = None
        self._registered = False
        # Pending registration watchdog task; None when not running / already cancelled.
        self._register_timeout = None

    # ────────────────────────── Connection lifecycle ──────────────────────────

    async def connect(self):
        """Accept the socket and start the registration watchdog."""
        await self.accept()
        # Hold a reference to the task so it can be cancelled and is not garbage-collected early.
        self._register_timeout = asyncio.create_task(self._register_watchdog())

    async def _register_watchdog(self):
        """Wait ``register_timeout_seconds`` and then close the socket if still unregistered."""
        await asyncio.sleep(self.register_timeout_seconds)
        await self._timeout_close()

    async def _timeout_close(self):
        """Close the connection with code 4003 unless the worker has registered."""
        if not self._registered:
            logger.warning(
                "WebSocket 连接超时(未在 %s 秒内注册)", self.register_timeout_seconds
            )
            await self.close(code=4003)

    def _cancel_register_timeout(self):
        """Cancel the registration watchdog if it is still pending."""
        if self._register_timeout is not None:
            self._register_timeout.cancel()
            self._register_timeout = None

    async def disconnect(self, close_code):
        """Stop the watchdog and unregister the worker id bound to this connection."""
        self._cancel_register_timeout()
        if self.worker_id:
            worker_manager.unregister(self.worker_id)
            logger.info("Worker %s WebSocket 断开", self.worker_id)

    async def _send_json(self, data: dict):
        """Serialize ``data`` to JSON and send it as a text frame.

        This bound method is passed to ``worker_manager.register`` as the
        connection's send function.
        """
        await self.send(text_data=json.dumps(data, ensure_ascii=False))

    # ────────────────────────── Message dispatch ──────────────────────────

    async def receive(self, text_data=None, bytes_data=None):
        """Handle one incoming frame.

        Binary frames, malformed JSON and non-object JSON payloads are ignored
        (the latter two with a warning). Until registration succeeds only a
        ``register`` message is accepted; anything else closes the socket.
        """
        if text_data is None:
            return
        try:
            data = json.loads(text_data)
        except json.JSONDecodeError:
            logger.warning("收到无法解析的 JSON 消息 (from %s)", self.worker_id)
            return
        # json.loads can return a list/str/number; every handler below needs a dict.
        if not isinstance(data, dict):
            logger.warning("收到非对象 JSON 消息 (from %s)", self.worker_id)
            return

        msg_type = data.get("type", "")

        if not self._registered:
            await self._handle_register(msg_type, data)
            return

        if msg_type == MsgType.HEARTBEAT.value:
            await self._handle_heartbeat()
        elif msg_type == MsgType.BROWSER_LIST_UPDATE.value:
            await self._handle_browser_list_update(data)
        elif msg_type == MsgType.TASK_PROGRESS.value:
            await self._handle_task_progress(data)
        elif msg_type == MsgType.TASK_RESULT.value:
            await self._handle_task_result(data)
        else:
            logger.warning("未知消息类型: %s (from %s)", msg_type, self.worker_id)

    async def _handle_register(self, msg_type, data: dict):
        """Validate the first message and register this connection with worker_manager.

        Closes with 4001 if the first message is not ``register`` and with 4002
        if ``worker_id`` is missing/empty.
        """
        if msg_type != MsgType.REGISTER.value:
            await self._send_json(make_msg(MsgType.ERROR, detail="首条消息必须是 register"))
            await self.close(code=4001)
            return
        self.worker_id = data.get("worker_id", "")
        worker_name = data.get("worker_name", self.worker_id)
        browsers = data.get("browsers", [])
        if not self.worker_id:
            await self._send_json(make_msg(MsgType.ERROR, detail="worker_id 不能为空"))
            await self.close(code=4002)
            return
        worker_manager.register(self._send_json, self.worker_id, worker_name, browsers)
        self._registered = True
        self._cancel_register_timeout()
        await self._send_json(make_msg(MsgType.REGISTER_ACK, worker_id=self.worker_id))
        logger.info("Worker %s 已连接", self.worker_id)

    async def _handle_heartbeat(self):
        """Record the heartbeat and acknowledge it."""
        worker_manager.heartbeat(self.worker_id)
        await self._send_json(make_msg(MsgType.HEARTBEAT_ACK))

    async def _handle_browser_list_update(self, data: dict):
        """Store the new browser list and mirror browser names onto BossAccount rows."""
        browsers = data.get("browsers", [])
        worker_manager.update_browsers(self.worker_id, browsers)
        try:
            await database_sync_to_async(self._sync_boss_account_browsers)(
                self.worker_id, browsers
            )
        except Exception as sync_err:
            logger.warning("同步 BossAccount 环境失败: %s", sync_err)

    async def _handle_task_progress(self, data: dict):
        """Forward task progress to the dispatcher and mark the owning account as running."""
        task_id = data.get("task_id", "")
        progress = data.get("progress", "")
        task_dispatcher.update_progress(task_id, progress)
        logger.info("任务 %s 进度: %s", task_id, progress)
        # Best effort: a DB failure here must not break the message loop, but it is logged.
        try:
            await database_sync_to_async(self._update_account_task_status)(
                task_id, TaskStatus.RUNNING.value
            )
        except Exception as db_err:
            logger.warning("任务 %s 同步账号任务状态失败: %s", task_id, db_err)

    async def _handle_task_result(self, data: dict):
        """Complete the task in the dispatcher, free the worker and persist the outcome."""
        task_id = data.get("task_id", "")
        result = data.get("result")
        error = data.get("error")
        task_dispatcher.complete_task(task_id, result=result, error=error)
        worker_manager.set_current_task(self.worker_id, None)
        logger.info("任务 %s 已完成", task_id)
        try:
            task_info = task_dispatcher.get_task(task_id)
            if task_info:
                await database_sync_to_async(self._persist_task_result)(
                    task_id, task_info, result, error
                )
        except Exception as db_err:
            # logger.exception keeps the traceback for DB failures.
            logger.exception("任务 %s 写入数据库失败: %s", task_id, db_err)

    # ─────── Database operations (synchronous; call via database_sync_to_async) ───────

    @staticmethod
    def _enum_value(value):
        """Return ``value.value`` for enum-like objects, otherwise ``str(value)``."""
        return value.value if hasattr(value, "value") else str(value)

    def _persist_task_result(self, task_id, task_info, result, error):
        """Write all DB rows for a finished task in a single thread hop.

        Saves the TaskLog, copies the final status onto the owning BossAccount
        and, for a successful ``check_login`` task with a truthy result, upserts
        the account's login state.
        """
        final_status = self._enum_value(task_info.status)
        self._save_task_log(task_id, task_info, result, error, final_status)
        self._update_account_task_status(task_id, final_status)
        task_type_val = self._enum_value(task_info.task_type)
        if task_type_val == TaskType.CHECK_LOGIN.value and result and not error:
            self._upsert_account_status(result)

    @staticmethod
    def _save_task_log(task_id, task_info, result, error, final_status):
        """Create or update the TaskLog row keyed by ``task_id``."""
        from server.models import TaskLog

        TaskLog.objects.update_or_create(
            task_id=task_id,
            defaults={
                "task_type": WorkerConsumer._enum_value(task_info.task_type),
                "worker_id": task_info.worker_id or "",
                "status": final_status,
                "params": task_info.params,
                "result": result,
                "error": error,
            },
        )

    @staticmethod
    def _update_account_task_status(task_id, task_status):
        """Set ``current_task_status`` on every BossAccount whose current task is ``task_id``."""
        from server.models import BossAccount

        BossAccount.objects.filter(current_task_id=task_id).update(
            current_task_status=task_status
        )

    def _upsert_account_status(self, result):
        """Update (or create) this worker's BossAccount from a check_login result.

        The account is matched by ``worker_id`` + ``browser_name`` first, then
        ``worker_id`` + ``browser_id``. Results that are not dicts, or that carry
        neither ``browser_id`` nor ``browser_name``, are skipped with a warning
        instead of creating an unidentifiable account row.
        """
        from server.models import BossAccount
        from django.utils import timezone as tz

        if not isinstance(result, dict):
            logger.warning("check_login 结果格式异常, 跳过账号状态更新: %r", result)
            return

        browser_id = result.get("browser_id", "")
        browser_name = result.get("browser_name", "")
        boss_username = result.get("boss_username", "")
        boss_id = result.get("boss_id", "")
        is_logged_in = result.get("is_logged_in", False)

        if not browser_id and not browser_name:
            logger.warning(
                "check_login 结果缺少 browser_id/browser_name, 跳过账号状态更新 (worker=%s)",
                self.worker_id,
            )
            return

        # Prefer matching by worker_id + browser_name, fall back to worker_id + browser_id.
        account = None
        if browser_name:
            account = BossAccount.objects.filter(
                worker_id=self.worker_id, browser_name=browser_name,
            ).first()
        if not account and browser_id:
            account = BossAccount.objects.filter(
                worker_id=self.worker_id, browser_id=browser_id,
            ).first()

        now = tz.now()
        if account:
            account.browser_id = browser_id or account.browser_id
            account.browser_name = browser_name or account.browser_name
            account.boss_username = boss_username
            if boss_id:
                account.boss_id = boss_id
            account.is_logged_in = is_logged_in
            account.checked_at = now
            account.save()
        else:
            BossAccount.objects.create(
                worker_id=self.worker_id,
                browser_id=browser_id or f"name:{browser_name}",
                browser_name=browser_name,
                boss_username=boss_username,
                boss_id=boss_id,
                is_logged_in=is_logged_in,
                checked_at=now,
            )
        logger.info(
            "账号状态更新: worker=%s, browser=%s(%s), username=%s, boss_id=%s, logged_in=%s",
            self.worker_id, browser_name, browser_id, boss_username, boss_id, is_logged_in,
        )

    @staticmethod
    def _sync_boss_account_browsers(worker_id: str, browsers: list) -> None:
        """Copy browser names reported by a Worker onto its BossAccount rows.

        Accounts are matched on ``browser_id`` (stripped) against each reported
        browser's ``id``; only ``browser_name`` is updated, and only when the
        reported name is non-empty and differs. Non-dict entries and entries
        without an ``id`` are ignored.
        """
        from server.models import BossAccount

        if not browsers:
            return
        browser_map = {}
        for browser in browsers:
            # Skip malformed entries instead of aborting the whole sync.
            if not isinstance(browser, dict) or not browser.get("id"):
                continue
            browser_map[str(browser["id"]).strip()] = browser
        if not browser_map:
            return

        updated = 0
        for acc in BossAccount.objects.filter(worker_id=worker_id):
            bid = (acc.browser_id or "").strip()
            if not bid:
                continue
            new_info = browser_map.get(bid)
            if not new_info:
                continue
            new_name = (new_info.get("name") or "").strip()
            if new_name and new_name != acc.browser_name:
                acc.browser_name = new_name
                acc.save(update_fields=["browser_name"])
                updated += 1
        if updated:
            logger.info("BossAccount 环境同步: worker=%s, 更新 %d", worker_id, updated)