diff --git a/common/protocol.py b/common/protocol.py index 4b78755..4b5ed93 100644 --- a/common/protocol.py +++ b/common/protocol.py @@ -17,12 +17,14 @@ class MsgType(str, Enum): BROWSER_LIST_UPDATE = "browser_list_update" # 浏览器列表变更 TASK_PROGRESS = "task_progress" # 任务进度上报 TASK_RESULT = "task_result" # 任务最终结果 + TASK_STATUS_REPORT = "task_status_report" # 任务执行状态回报(响应服务端查询) # Server → Worker REGISTER_ACK = "register_ack" # 注册确认 HEARTBEAT_ACK = "heartbeat_ack" # 心跳确认 TASK_ASSIGN = "task_assign" # 派发任务 TASK_CANCEL = "task_cancel" # 取消任务 + TASK_STATUS_QUERY = "task_status_query" # 查询任务执行状态 # 双向 ERROR = "error" # 错误消息 diff --git a/server/config.py b/server/config.py index d2c33d6..5d8bf86 100644 --- a/server/config.py +++ b/server/config.py @@ -13,6 +13,9 @@ PORT: int = int(os.getenv("SERVER_PORT", "9000")) # 云服务器主端口 WS_PATH: str = "/ws" # Worker 连接端点 HEARTBEAT_INTERVAL: int = 30 # 期望 Worker 心跳间隔(秒) HEARTBEAT_TIMEOUT: int = 90 # 超时未收到心跳视为离线(秒) +TASK_EXEC_TIMEOUT: int = int(os.getenv("TASK_EXEC_TIMEOUT", "300")) # 任务执行超时时间(秒) +TASK_HEALTH_CHECK_INTERVAL: int = int(os.getenv("TASK_HEALTH_CHECK_INTERVAL", "30")) # 超时任务巡检间隔(秒) +TASK_STATUS_QUERY_TIMEOUT: int = int(os.getenv("TASK_STATUS_QUERY_TIMEOUT", "10")) # 等待 worker 状态回报超时(秒) # ─── 安全(可选) ─── API_TOKEN: str = os.getenv("API_TOKEN", "") # 非空时校验 Header: Authorization: Bearer diff --git a/server/core/task_health_monitor.py b/server/core/task_health_monitor.py new file mode 100644 index 0000000..2453f11 --- /dev/null +++ b/server/core/task_health_monitor.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- +""" +超时任务巡检: +1) 扫描超过阈值仍未结束的任务; +2) 向对应 Worker 查询该任务是否仍在执行; +3) 未执行 / 查询超时 / Worker 不可用则标记任务失败。 +""" +from __future__ import annotations + +import asyncio +import logging +import uuid +from datetime import timedelta +from typing import Dict, Optional + +from asgiref.sync import sync_to_async + +from common.protocol import MsgType, TaskStatus, make_msg +from server import config +from server.core.worker_manager import worker_manager + +logger = logging.getLogger("server.task_health") + + +class TaskHealthMonitor: + """任务执行健康巡检器。""" + + def __init__(self) -> None: + self._pending_reports: Dict[str, asyncio.Future] = {} + self._lock = asyncio.Lock() + + async def run_loop(self) -> None: + """后台循环:定期探测超时未结束任务。""" + interval = max(5, int(config.TASK_HEALTH_CHECK_INTERVAL)) + while True: + try: + await self.check_once() + except Exception as e: + logger.error("超时任务巡检异常: %s", e, exc_info=True) + await asyncio.sleep(interval) + + async def check_once(self) -> None: + timeout_seconds = max(60, int(config.TASK_EXEC_TIMEOUT)) + candidates = await self._list_timed_out_active_tasks(timeout_seconds) + if not candidates: + return + + results = await asyncio.gather( + *(self._probe_task(item) for item in candidates), + return_exceptions=True, + ) + for item, result in zip(candidates, results): + if isinstance(result, Exception): + logger.error( + "巡检任务异常: task_id=%s, worker_id=%s, err=%s", + item.get("task_id", ""), + item.get("worker_id", ""), + result, + ) + + async def report_task_status( + self, + request_id: str, + worker_id: str, + task_id: str, + running: bool, + detail: str = "", + ) -> None: + """接收 Worker 回报,并唤醒对应的等待方。""" + rid = (request_id or "").strip() + if not rid: + return + + future = await self._pop_pending(rid) + if not future: + return + if future.done(): + return + + future.set_result( + { + "request_id": rid, + "worker_id": (worker_id or "").strip(), + "task_id": (task_id or "").strip(), + "running": bool(running), + "detail": (detail or "").strip(), + } + ) + + async def _probe_task(self, item: dict) -> None: + task_id = str(item.get("task_id", "")).strip() + worker_id = str(item.get("worker_id", "")).strip() + if not task_id: + return + + if not worker_id: + await self._mark_task_failed_if_active(task_id, "任务执行异常:缺少 worker_id,无法确认执行状态") + return + + if not worker_manager.is_online(worker_id): + updated = await self._mark_task_failed_if_active(task_id, "任务执行异常:Worker 不在线") + if updated: + worker_manager.set_current_task(worker_id, None) + return + + send_fn = worker_manager.get_send_fn(worker_id) + if not send_fn: + updated = await self._mark_task_failed_if_active(task_id, "任务执行异常:Worker 通道不可用") + if updated: + worker_manager.set_current_task(worker_id, None) + return + + request_id = uuid.uuid4().hex[:16] + future = asyncio.get_running_loop().create_future() + await self._put_pending(request_id, future) + + try: + query = make_msg( + MsgType.TASK_STATUS_QUERY, + request_id=request_id, + task_id=task_id, + ) + await send_fn(query) + except Exception as e: + await self._remove_pending(request_id, future) + updated = await self._mark_task_failed_if_active(task_id, f"任务执行异常:状态查询下发失败 ({e})") + if updated: + worker_manager.set_current_task(worker_id, None) + return + + timeout = max(3, int(config.TASK_STATUS_QUERY_TIMEOUT)) + try: + report = await asyncio.wait_for(future, timeout=timeout) + except asyncio.TimeoutError: + await self._remove_pending(request_id, future) + updated = await self._mark_task_failed_if_active(task_id, "任务执行异常:状态查询超时,客户端未确认执行中") + if updated: + worker_manager.set_current_task(worker_id, None) + return + finally: + await self._remove_pending(request_id, future) + + is_running = bool(report.get("running", False)) if isinstance(report, dict) else False + if is_running: + return + + detail = "" + if isinstance(report, dict): + detail = str(report.get("detail", "")).strip() + error = "任务执行异常:客户端未在执行该任务" + if detail: + error = f"{error} ({detail})" + updated = await self._mark_task_failed_if_active(task_id, error) + if updated: + worker_manager.set_current_task(worker_id, None) + + async def _put_pending(self, request_id: str, future: asyncio.Future) -> None: + async with self._lock: + self._pending_reports[request_id] = future + + async def _pop_pending(self, request_id: str) -> Optional[asyncio.Future]: + async with self._lock: + return self._pending_reports.pop(request_id, None) + + async def _remove_pending(self, request_id: str, fallback_future: asyncio.Future) -> None: + removed = await self._pop_pending(request_id) + future = removed or fallback_future + if future and not future.done(): + future.cancel() + + @staticmethod + @sync_to_async + def _list_timed_out_active_tasks(timeout_seconds: int) -> list[dict]: + from django.utils import timezone as tz + from server.models import Task + + threshold = tz.now() - timedelta(seconds=max(1, timeout_seconds)) + active_statuses = [TaskStatus.DISPATCHED.value, TaskStatus.RUNNING.value] + rows = ( + Task.objects + .filter(status__in=active_statuses, created_at__lte=threshold) + .values("task_id", "worker_id") + ) + return list(rows) + + @staticmethod + @sync_to_async + def _mark_task_failed_if_active(task_id: str, error: str) -> bool: + from django.utils import timezone as tz + from server.models import BossAccount, Task, TaskLog + + active_statuses = [TaskStatus.DISPATCHED.value, TaskStatus.RUNNING.value] + task = Task.objects.filter(task_id=task_id, status__in=active_statuses).first() + if not task: + return False + + now = tz.now() + task.status = TaskStatus.FAILED.value + task.error = error + task.updated_at = now + task.save(update_fields=["status", "error", "updated_at"]) + + TaskLog.objects.update_or_create( + task_id=task_id, + defaults={ + "task_type": task.task_type, + "worker_id": task.worker_id or "", + "status": TaskStatus.FAILED.value, + "params": task.params, + "result": task.result, + "error": error, + }, + ) + BossAccount.objects.filter(current_task_id=task_id).update(current_task_status=TaskStatus.FAILED.value) + + logger.warning("任务超时巡检判定失败: task_id=%s, worker_id=%s, error=%s", task_id, task.worker_id, error) + return True + + +task_health_monitor = TaskHealthMonitor() diff --git a/server/main.py b/server/main.py index ae3fce4..23df051 100644 --- a/server/main.py +++ b/server/main.py @@ -17,6 +17,7 @@ django.setup() # noqa: E402 import uvicorn # noqa: E402 from server import config # noqa: E402 +from server.core.task_health_monitor import task_health_monitor # noqa: E402 from server.core.worker_manager import worker_manager # noqa: E402 from tunnel.server import TunnelServer # noqa: E402 @@ -32,6 +33,8 @@ async def run_server(): """启动服务器:Django ASGI + 心跳巡检 + 隧道。""" # 启动心跳巡检 asyncio.create_task(worker_manager.check_heartbeats_loop()) + # 启动超时任务巡检 + asyncio.create_task(task_health_monitor.run_loop()) # 启动隧道服务 tunnel_server = TunnelServer( diff --git a/server/ws/consumers.py b/server/ws/consumers.py index 4c32d5a..386305c 100644 --- a/server/ws/consumers.py +++ b/server/ws/consumers.py @@ -10,6 +10,7 @@ from asgiref.sync import sync_to_async from channels.generic.websocket import AsyncWebsocketConsumer from common.protocol import MsgType, TaskStatus, TaskType, make_msg +from server.core.task_health_monitor import task_health_monitor from server.core.worker_manager import worker_manager logger = logging.getLogger("server.ws") @@ -135,6 +136,22 @@ class WorkerConsumer(AsyncWebsocketConsumer): except Exception as db_err: logger.error("任务 %s 写入数据库失败: %s", task_id, db_err) + elif msg_type == MsgType.TASK_STATUS_REPORT.value: + request_id = str(data.get("request_id", "")).strip() + task_id = str(data.get("task_id", "")).strip() + running = bool(data.get("running", False)) + detail = str(data.get("detail", "")).strip() + try: + await task_health_monitor.report_task_status( + request_id=request_id, + worker_id=self.worker_id or "", + task_id=task_id, + running=running, + detail=detail, + ) + except Exception as e: + logger.error("处理任务状态回报失败 (request_id=%s, task_id=%s): %s", request_id, task_id, e) + else: logger.warning("未知消息类型: %s (from %s)", msg_type, self.worker_id) diff --git a/worker/tasks/boss_recruit.py b/worker/tasks/boss_recruit.py index e08f053..1a3adde 100644 --- a/worker/tasks/boss_recruit.py +++ b/worker/tasks/boss_recruit.py @@ -320,7 +320,7 @@ class BossRecruitHandler(BaseTaskHandler): return False def _send_with_confirm(self, tab, input_box, message: str, max_attempts: int = 2) -> bool: - """发送后检查末条消息,若未成功则自动补发。""" + """发送后检查末条消息,避免因确认误判导致重复补发。""" msg = (message or "").strip() if not msg: return False @@ -332,18 +332,57 @@ class BossRecruitHandler(BaseTaskHandler): if self._confirm_last_sent_message(tab, msg): return True - try: - input_box.click(by_js=True) - except Exception: - pass - try: - input_box.clear() - except Exception: - pass - input_box.input(msg) + # 已点击发送但未在列表中确认时,不直接补发,先判断输入框是否已清空。 + # 若已清空,通常表示消息已发出,只是 UI 刷新慢或识别未命中,避免重复发送。 + if not self._editor_has_text(input_box, msg): + return True time.sleep(0.35) - return self._confirm_last_sent_message(tab, msg) + sent = self._confirm_last_sent_message(tab, msg) + if sent: + return True + + self._clear_editor(input_box) + return False + + @staticmethod + def _normalize_text(text: str) -> str: + return re.sub(r"\s+", "", text or "") + + def _editor_has_text(self, input_box, expected: str = "") -> bool: + """判断输入框是否仍残留指定文本。""" + expected_norm = self._normalize_text(expected) + current = "" + try: + current = input_box.run_js("return (this.innerText || this.textContent || this.value || '').trim();") + except Exception: + try: + current = input_box.text + except Exception: + current = "" + current_norm = self._normalize_text(str(current)) + if not current_norm: + return False + if not expected_norm: + return True + return expected_norm in current_norm + + def _clear_editor(self, input_box) -> None: + """清空聊天输入框,避免残留预输入内容。""" + try: + input_box.click(by_js=True) + except Exception: + pass + try: + input_box.clear() + except Exception: + pass + try: + input_box.run_js( + "if (this.isContentEditable) { this.innerHTML=''; this.textContent=''; }" + ) + except Exception: + pass def _confirm_last_sent_message(self, tab, message: str) -> bool: """确认当前聊天窗口末条是否为刚发送内容。""" diff --git a/worker/ws_client.py b/worker/ws_client.py index f803441..440eb5a 100644 --- a/worker/ws_client.py +++ b/worker/ws_client.py @@ -15,7 +15,7 @@ from typing import Dict, List, Optional, Set import websockets from websockets.exceptions import ConnectionClosed -from common.protocol import MsgType, TaskStatus, make_msg +from common.protocol import MsgType, make_msg from worker import config from worker.bit_browser import BitBrowserAPI from worker.tasks.base import TaskCancelledError @@ -45,6 +45,8 @@ class WorkerWSClient: self._heartbeat_count = 0 self._cancel_events: Dict[str, threading.Event] = {} self._cancelled_tasks: Set[str] = set() + self._task_runners: Dict[str, asyncio.Task] = {} + self._send_lock: Optional[asyncio.Lock] = None # ────────────────────────── 主循环 ────────────────────────── @@ -76,6 +78,7 @@ class WorkerWSClient: logger.info("正在连接服务器: %s", self.server_url) async with websockets.connect(self.server_url) as ws: self._ws = ws + self._send_lock = asyncio.Lock() self._reconnect_delay = config.RECONNECT_DELAY # 连接成功重置退避 logger.info("WebSocket 已连接") @@ -94,6 +97,14 @@ class WorkerWSClient: logger.warning("WebSocket 连接关闭: %s", e) finally: heartbeat_task.cancel() + self._send_lock = None + for event in list(self._cancel_events.values()): + event.set() + for runner in list(self._task_runners.values()): + runner.cancel() + if self._task_runners: + await asyncio.gather(*self._task_runners.values(), return_exceptions=True) + self._task_runners.clear() self._ws = None async def _register(self, ws) -> None: @@ -106,7 +117,7 @@ class WorkerWSClient: worker_name=self.worker_name, browsers=browsers, ) - await ws.send(json.dumps(msg)) + await self._safe_send(ws, msg) # 等待 ACK ack_raw = await asyncio.wait_for(ws.recv(), timeout=10) ack = json.loads(ack_raw) @@ -124,7 +135,7 @@ class WorkerWSClient: try: await asyncio.sleep(config.HEARTBEAT_INTERVAL) msg = make_msg(MsgType.HEARTBEAT, worker_id=self.worker_id) - await ws.send(json.dumps(msg)) + await self._safe_send(ws, msg) logger.debug("心跳已发送") self._heartbeat_count += 1 @@ -142,7 +153,7 @@ class WorkerWSClient: if sorted(new_list, key=key) != sorted(self._last_browsers, key=key): self._last_browsers = new_list msg = make_msg(MsgType.BROWSER_LIST_UPDATE, browsers=new_list) - await ws.send(json.dumps(msg)) + await self._safe_send(ws, msg) logger.info("浏览器列表变更,已上报 %d 个环境", len(new_list)) except Exception as e: logger.debug("检查浏览器列表变更失败: %s", e) @@ -156,24 +167,43 @@ class WorkerWSClient: logger.debug("收到心跳 ACK") elif msg_type == MsgType.TASK_ASSIGN.value: - await self._handle_task(ws, data) + await self._start_task(ws, data) elif msg_type == MsgType.TASK_CANCEL.value: task_id = data.get("task_id", "") await self._handle_task_cancel(task_id) + elif msg_type == MsgType.TASK_STATUS_QUERY.value: + await self._handle_task_status_query(ws, data) + elif msg_type == MsgType.ERROR.value: logger.error("服务器错误: %s", data.get("detail", "")) else: logger.warning("未知消息: %s", msg_type) + async def _start_task(self, ws, data: dict) -> None: + """启动任务执行协程,不阻塞消息接收循环。""" + task_id = str(data.get("task_id", "")).strip() + if not task_id: + logger.warning("收到无 task_id 的任务消息,已忽略") + return + + runner = self._task_runners.get(task_id) + if runner and not runner.done(): + logger.warning("任务 %s 已在执行中,忽略重复派发", task_id) + return + + self._task_runners[task_id] = asyncio.create_task(self._handle_task(ws, data)) + async def _handle_task(self, ws, data: dict) -> None: """接收并执行任务。""" - task_id = data.get("task_id", "") - task_type = data.get("task_type", "") - account_name = data.get("account_name", "") + task_id = str(data.get("task_id", "")).strip() + task_type = str(data.get("task_type", "")).strip() + account_name = str(data.get("account_name", "")).strip() params = data.get("params", {}) + if not isinstance(params, dict): + params = {} pre_cancelled = task_id in self._cancelled_tasks cancel_event = threading.Event() @@ -188,25 +218,29 @@ class WorkerWSClient: params.setdefault("bit_api_base", self.bit_api.base_url) params["_cancel_event"] = cancel_event - logger.info("收到任务: %s (type=%s)", task_id, task_type) + if task_id: + logger.info("收到任务: %s (type=%s)", task_id, task_type) handler = get_handler(task_type) if not handler: error_msg = f"不支持的任务类型: {task_type}" logger.error(error_msg) await self._send_result(ws, task_id, error=error_msg) + params.pop("_cancel_event", None) + self._cancel_events.pop(task_id, None) + self._task_runners.pop(task_id, None) return # 上报进度的回调 async def progress_cb(tid: str, progress: str): msg = make_msg(MsgType.TASK_PROGRESS, task_id=tid, progress=progress) try: - await ws.send(json.dumps(msg)) + await self._safe_send(ws, msg) except Exception: pass - # 执行任务 try: + # 执行任务 if cancel_event.is_set(): raise TaskCancelledError("任务已取消") result = await handler.execute(task_id, params, progress_cb) @@ -216,12 +250,16 @@ class WorkerWSClient: except TaskCancelledError: logger.info("任务 %s 已取消", task_id) await self._send_result(ws, task_id, error="任务已取消") + except asyncio.CancelledError: + logger.info("任务 %s 执行协程已取消", task_id) + raise except Exception as e: logger.error("任务 %s 执行失败: %s", task_id, e, exc_info=True) await self._send_result(ws, task_id, error=str(e)) finally: params.pop("_cancel_event", None) self._cancel_events.pop(task_id, None) + self._task_runners.pop(task_id, None) async def _handle_task_cancel(self, task_id: str) -> None: """处理服务端下发的任务取消。""" @@ -236,15 +274,53 @@ class WorkerWSClient: self._cancelled_tasks.add(task_id) logger.info("收到任务取消: %s,任务尚未执行,已记录预取消", task_id) + async def _handle_task_status_query(self, ws, data: dict) -> None: + """响应服务端任务状态探测。""" + task_id = str(data.get("task_id", "")).strip() + request_id = str(data.get("request_id", "")).strip() + + running = False + detail = "" + if task_id: + runner = self._task_runners.get(task_id) + running = bool(runner and not runner.done()) + if not running: + detail = "task_not_running" + else: + running = any(r and not r.done() for r in self._task_runners.values()) + if not running: + detail = "no_running_task" + + msg = make_msg( + MsgType.TASK_STATUS_REPORT, + request_id=request_id, + task_id=task_id, + running=running, + detail=detail, + worker_id=self.worker_id, + ) + try: + await self._safe_send(ws, msg) + except Exception as e: + logger.debug("发送任务状态回报失败: %s", e) + async def _send_result(self, ws, task_id: str, result=None, error: str = None) -> None: """上报任务最终结果。""" msg = make_msg(MsgType.TASK_RESULT, task_id=task_id, result=result, error=error) try: - await ws.send(json.dumps(msg)) + await self._safe_send(ws, msg) logger.info("任务 %s 结果已上报 (error=%s)", task_id, error) except Exception as e: logger.error("上报结果失败: %s", e) + async def _safe_send(self, ws, msg: dict) -> None: + payload = json.dumps(msg) + if self._send_lock is None: + await ws.send(payload) + return + async with self._send_lock: + await ws.send(payload) + # ────────────────────────── 比特浏览器列表 ────────────────────────── def _fetch_browser_list(self) -> List[dict]: