This commit is contained in:
ddrwode
2026-03-03 13:32:20 +08:00
parent 6d27568c61
commit a66cbc7424
7 changed files with 383 additions and 23 deletions

View File

@@ -13,6 +13,9 @@ PORT: int = int(os.getenv("SERVER_PORT", "9000")) # 云服务器主端口
WS_PATH: str = "/ws" # WebSocket endpoint that Workers connect to
HEARTBEAT_INTERVAL: int = 30 # Expected Worker heartbeat interval (seconds)
HEARTBEAT_TIMEOUT: int = 90 # A Worker with no heartbeat for this long is treated as offline (seconds)
TASK_EXEC_TIMEOUT: int = int(os.getenv("TASK_EXEC_TIMEOUT", "300")) # Task execution timeout (seconds); env-overridable
TASK_HEALTH_CHECK_INTERVAL: int = int(os.getenv("TASK_HEALTH_CHECK_INTERVAL", "30")) # Interval between timed-out-task inspections (seconds); env-overridable
TASK_STATUS_QUERY_TIMEOUT: int = int(os.getenv("TASK_STATUS_QUERY_TIMEOUT", "10")) # How long to wait for a Worker's status report (seconds); env-overridable
# ─── Security (optional) ───
API_TOKEN: str = os.getenv("API_TOKEN", "") # When non-empty, requests are checked against Header: Authorization: Bearer <token>

View File

@@ -0,0 +1,220 @@
# -*- coding: utf-8 -*-
"""
Timed-out task inspection:
1) scan for tasks that are still unfinished past the configured threshold;
2) ask the owning Worker whether it is still executing that task;
3) mark the task as failed when the Worker is not executing it, the query
   times out, or the Worker is unavailable.
"""
from __future__ import annotations
import asyncio
import logging
import uuid
from datetime import timedelta
from typing import Dict, Optional
from asgiref.sync import sync_to_async
from common.protocol import MsgType, TaskStatus, make_msg
from server import config
from server.core.worker_manager import worker_manager
logger = logging.getLogger("server.task_health")
class TaskHealthMonitor:
    """Periodic watchdog for tasks that stay active past the execution timeout.

    Each inspection round lists tasks still in DISPATCHED/RUNNING state whose
    ``created_at`` is older than the timeout, asks the owning Worker whether
    it is still executing the task, and marks the task FAILED when the Worker
    is offline, unreachable, silent, or reports that it is not running it.
    """

    def __init__(self) -> None:
        # request_id -> future resolved by report_task_status() when the
        # Worker's TASK_STATUS_REPORT for that query arrives.
        self._pending_reports: Dict[str, asyncio.Future] = {}
        self._lock = asyncio.Lock()

    async def run_loop(self) -> None:
        """Run inspection rounds forever, sleeping between rounds.

        The interval is read once from config and clamped to at least 5
        seconds. A failing round is logged and does not stop the loop.
        """
        interval = max(5, int(config.TASK_HEALTH_CHECK_INTERVAL))
        while True:
            try:
                await self.check_once()
            except Exception as e:
                logger.error("超时任务巡检异常: %s", e, exc_info=True)
            await asyncio.sleep(interval)

    async def check_once(self) -> None:
        """Run one inspection round over all timed-out active tasks.

        Probes run concurrently; a failure in one probe is logged and does
        not affect the others.
        """
        timeout_seconds = max(60, int(config.TASK_EXEC_TIMEOUT))
        candidates = await self._list_timed_out_active_tasks(timeout_seconds)
        if not candidates:
            return
        results = await asyncio.gather(
            *(self._probe_task(item) for item in candidates),
            return_exceptions=True,
        )
        for item, result in zip(candidates, results):
            # BaseException (not Exception): asyncio.CancelledError derives
            # from BaseException on Python 3.8+, and a cancelled probe would
            # otherwise be dropped without any log line.
            if isinstance(result, BaseException):
                logger.error(
                    "巡检任务异常: task_id=%s, worker_id=%s, err=%s",
                    item.get("task_id", ""),
                    item.get("worker_id", ""),
                    result,
                )

    async def report_task_status(
        self,
        request_id: str,
        worker_id: str,
        task_id: str,
        running: bool,
        detail: str = "",
    ) -> None:
        """Deliver a Worker's status report to the probe waiting on it.

        Args:
            request_id: Correlation id echoed back from the status query.
            worker_id: Id of the Worker that sent the report.
            task_id: Task the report refers to.
            running: Whether the Worker says it is still executing the task.
            detail: Optional free-text explanation from the Worker.

        Reports with an empty or unknown ``request_id`` (for example one that
        arrives after the probe already timed out) are ignored.
        """
        rid = (request_id or "").strip()
        if not rid:
            return
        future = await self._pop_pending(rid)
        if not future or future.done():
            return
        future.set_result(
            {
                "request_id": rid,
                "worker_id": (worker_id or "").strip(),
                "task_id": (task_id or "").strip(),
                "running": bool(running),
                "detail": (detail or "").strip(),
            }
        )

    async def _probe_task(self, item: dict) -> None:
        """Check one timed-out task and fail it unless its Worker confirms it.

        Args:
            item: Row dict with ``task_id`` and ``worker_id`` keys.
        """
        # `or ""` matters: a NULL column comes back as None, and str(None)
        # would produce the non-empty string "None" and skip the guards below.
        task_id = str(item.get("task_id") or "").strip()
        worker_id = str(item.get("worker_id") or "").strip()
        if not task_id:
            return
        if not worker_id:
            # No worker to release, so only the task row is updated.
            await self._mark_task_failed_if_active(task_id, "任务执行异常:缺少 worker_id无法确认执行状态")
            return
        if not worker_manager.is_online(worker_id):
            await self._fail_and_release(task_id, worker_id, "任务执行异常Worker 不在线")
            return
        send_fn = worker_manager.get_send_fn(worker_id)
        if not send_fn:
            await self._fail_and_release(task_id, worker_id, "任务执行异常Worker 通道不可用")
            return
        reason = await self._query_failure_reason(send_fn, task_id)
        if reason is not None:
            await self._fail_and_release(task_id, worker_id, reason)

    async def _fail_and_release(self, task_id: str, worker_id: str, error: str) -> None:
        """Mark the task failed and, if that changed it, clear the worker's current task."""
        if await self._mark_task_failed_if_active(task_id, error):
            worker_manager.set_current_task(worker_id, None)

    async def _query_failure_reason(self, send_fn, task_id: str) -> Optional[str]:
        """Ask the Worker about a task; return why it should fail, or None.

        Args:
            send_fn: Awaitable callable that sends one message to the Worker.
            task_id: Task to ask about.

        Returns:
            None when the Worker confirms the task is running, otherwise the
            error message to store on the task.
        """
        request_id = uuid.uuid4().hex[:16]
        future = asyncio.get_running_loop().create_future()
        await self._put_pending(request_id, future)
        try:
            try:
                query = make_msg(
                    MsgType.TASK_STATUS_QUERY,
                    request_id=request_id,
                    task_id=task_id,
                )
                await send_fn(query)
            except Exception as e:
                return f"任务执行异常:状态查询下发失败 ({e})"
            timeout = max(3, int(config.TASK_STATUS_QUERY_TIMEOUT))
            try:
                report = await asyncio.wait_for(future, timeout=timeout)
            except asyncio.TimeoutError:
                return "任务执行异常:状态查询超时,客户端未确认执行中"
        finally:
            # Single cleanup point for every exit path; it runs before any
            # `return` above completes, so the entry is gone before the
            # caller touches the database.
            await self._remove_pending(request_id, future)
        error = "任务执行异常:客户端未在执行该任务"
        if not isinstance(report, dict):
            return error
        if bool(report.get("running", False)):
            return None
        detail = str(report.get("detail", "")).strip()
        return f"{error} ({detail})" if detail else error

    async def _put_pending(self, request_id: str, future: asyncio.Future) -> None:
        """Register the future that waits for the report to ``request_id``."""
        async with self._lock:
            self._pending_reports[request_id] = future

    async def _pop_pending(self, request_id: str) -> Optional[asyncio.Future]:
        """Remove and return the pending future for ``request_id``, if any."""
        async with self._lock:
            return self._pending_reports.pop(request_id, None)

    async def _remove_pending(self, request_id: str, fallback_future: asyncio.Future) -> None:
        """Drop the pending entry and cancel its future if still unresolved.

        ``fallback_future`` is cancelled instead when the entry was already
        popped (for example by a report that arrived concurrently).
        """
        removed = await self._pop_pending(request_id)
        future = removed or fallback_future
        if future and not future.done():
            future.cancel()

    @staticmethod
    @sync_to_async
    def _list_timed_out_active_tasks(timeout_seconds: int) -> list[dict]:
        """Return ``{"task_id", "worker_id"}`` rows for overdue active tasks.

        A task is overdue when its status is DISPATCHED or RUNNING and its
        ``created_at`` is at least ``timeout_seconds`` (minimum 1) in the past.
        """
        from django.utils import timezone as tz
        from server.models import Task

        # NOTE(review): age is measured from created_at, so time spent queued
        # before dispatch counts towards the timeout - confirm this is intended.
        threshold = tz.now() - timedelta(seconds=max(1, timeout_seconds))
        active_statuses = [TaskStatus.DISPATCHED.value, TaskStatus.RUNNING.value]
        rows = (
            Task.objects
            .filter(status__in=active_statuses, created_at__lte=threshold)
            .values("task_id", "worker_id")
        )
        return list(rows)

    @staticmethod
    @sync_to_async
    def _mark_task_failed_if_active(task_id: str, error: str) -> bool:
        """Atomically mark a still-active task as FAILED.

        Updates the Task row, upserts the matching TaskLog row and updates
        ``current_task_status`` on any BossAccount pointing at the task.

        Args:
            task_id: Task to fail.
            error: Error message stored on the task and its log row.

        Returns:
            True when the task was active and has been marked failed; False
            when no active task with that id exists.
        """
        from django.db import transaction
        from django.utils import timezone as tz
        from server.models import BossAccount, Task, TaskLog

        active_statuses = [TaskStatus.DISPATCHED.value, TaskStatus.RUNNING.value]
        # One transaction plus a row lock: without it a result written between
        # the status check and save() would be overwritten with FAILED, and
        # the three writes could be left half-applied.
        with transaction.atomic():
            task = (
                Task.objects.select_for_update()
                .filter(task_id=task_id, status__in=active_statuses)
                .first()
            )
            if not task:
                return False
            now = tz.now()
            task.status = TaskStatus.FAILED.value
            task.error = error
            task.updated_at = now
            task.save(update_fields=["status", "error", "updated_at"])
            TaskLog.objects.update_or_create(
                task_id=task_id,
                defaults={
                    "task_type": task.task_type,
                    "worker_id": task.worker_id or "",
                    "status": TaskStatus.FAILED.value,
                    "params": task.params,
                    "result": task.result,
                    "error": error,
                },
            )
            BossAccount.objects.filter(current_task_id=task_id).update(current_task_status=TaskStatus.FAILED.value)
        logger.warning("任务超时巡检判定失败: task_id=%s, worker_id=%s, error=%s", task_id, task.worker_id, error)
        return True
# Module-level shared instance; other modules import this object rather than
# constructing their own monitor, so probes and incoming reports meet here.
task_health_monitor = TaskHealthMonitor()

View File

@@ -17,6 +17,7 @@ django.setup() # noqa: E402
import uvicorn # noqa: E402
from server import config # noqa: E402
from server.core.task_health_monitor import task_health_monitor # noqa: E402
from server.core.worker_manager import worker_manager # noqa: E402
from tunnel.server import TunnelServer # noqa: E402
@@ -32,6 +33,8 @@ async def run_server():
"""启动服务器Django ASGI + 心跳巡检 + 隧道。"""
# 启动心跳巡检
asyncio.create_task(worker_manager.check_heartbeats_loop())
# 启动超时任务巡检
asyncio.create_task(task_health_monitor.run_loop())
# 启动隧道服务
tunnel_server = TunnelServer(

View File

@@ -10,6 +10,7 @@ from asgiref.sync import sync_to_async
from channels.generic.websocket import AsyncWebsocketConsumer
from common.protocol import MsgType, TaskStatus, TaskType, make_msg
from server.core.task_health_monitor import task_health_monitor
from server.core.worker_manager import worker_manager
logger = logging.getLogger("server.ws")
@@ -135,6 +136,22 @@ class WorkerConsumer(AsyncWebsocketConsumer):
except Exception as db_err:
logger.error("任务 %s 写入数据库失败: %s", task_id, db_err)
elif msg_type == MsgType.TASK_STATUS_REPORT.value:
request_id = str(data.get("request_id", "")).strip()
task_id = str(data.get("task_id", "")).strip()
running = bool(data.get("running", False))
detail = str(data.get("detail", "")).strip()
try:
await task_health_monitor.report_task_status(
request_id=request_id,
worker_id=self.worker_id or "",
task_id=task_id,
running=running,
detail=detail,
)
except Exception as e:
logger.error("处理任务状态回报失败 (request_id=%s, task_id=%s): %s", request_id, task_id, e)
else:
logger.warning("未知消息类型: %s (from %s)", msg_type, self.worker_id)