Files
boss_dp/server/core/task_dispatcher.py
ddrwode 5c9cfada28 haha
2026-03-03 10:50:32 +08:00

254 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
任务路由与派发。
根据请求中的 worker_id / account_name 找到目标 Worker通过 WebSocket 下发任务。
"""
from __future__ import annotations
import logging
import time
from typing import List, Optional
from asgiref.sync import sync_to_async
from common.protocol import MsgType, TaskStatus, make_msg
from server.models import TaskCreate, Task
# Module-level logger; the name is fixed rather than derived from __name__.
logger = logging.getLogger("server.task_dispatcher")
class TaskDispatcher:
    """Manage the task lifecycle and dispatch tasks to Workers.

    The database is the single source of truth:

    - task state is not kept in memory long-term;
    - every query and state change goes through the ``Task`` ORM model;
    - Python objects are only built transiently while handling a single
      request or dispatch.
    """

    # Statuses treated as "in flight". ``create_task`` refuses to create a new
    # task while the same worker/account pair has a task in one of these.
    _ACTIVE_STATUSES = {TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING}

    # Fallback failure message when the Worker report carries no usable detail.
    _DEFAULT_FAILURE_REASON = "任务执行失败"

    # Values of a result's ``status`` field that mean the task failed.
    _FAILURE_STATUS_TEXTS = frozenset({"failed", "error", "fail"})

    # ─── Internal helpers ───

    @classmethod
    def _active_status_values(cls) -> list:
        """Return the raw ``.value`` of every active status, for ORM filters."""
        return [status.value for status in cls._ACTIVE_STATUSES]

    @staticmethod
    def _normalize_account_name(account_name: Optional[str]) -> str:
        """Return ``account_name`` stripped of surrounding whitespace ('' for None)."""
        return (account_name or "").strip()

    def _find_active_task(self, worker_id: str, account_name: str) -> Optional[Task]:
        """Return the most recently created active task for a worker/account pair.

        Returns None when ``worker_id`` is empty, when the account name is
        empty after normalization, or when no matching active task exists.
        """
        normalized = self._normalize_account_name(account_name)
        if not worker_id or not normalized:
            return None
        return (
            Task.objects.filter(
                worker_id=worker_id,
                account_name=normalized,
                status__in=self._active_status_values(),
            )
            .order_by("-created_at")
            .first()
        )

    @staticmethod
    def _infer_failure_reason(result, error) -> Optional[str]:
        """Infer a failure reason from a Worker report.

        Rules:
        1) If ``error`` was reported explicitly (anything but None, even an
           empty string), the task is considered failed.
        2) If ``result`` is a dict containing ``success`` is False, ``ok`` is
           False, or ``status`` in failed/error/fail (case-insensitive), the
           task is also considered failed.

        Returns:
            A non-empty failure message, or None when no failure is detected.
        """
        default_reason = TaskDispatcher._DEFAULT_FAILURE_REASON

        if error is not None:
            return str(error).strip() or default_reason

        if not isinstance(result, dict):
            return None

        failed = (
            result.get("success") is False
            or result.get("ok") is False
            or str(result.get("status", "")).strip().lower()
            in TaskDispatcher._FAILURE_STATUS_TEXTS
        )
        if not failed:
            return None

        # An explicit ``"error": None`` must not be stringified to "None";
        # treat it like a missing detail and fall back to the default text.
        detail = result.get("error")
        message = "" if detail is None else str(detail).strip()
        return message or default_reason

    # ─── Persistence helpers awaited by dispatch() ───

    @staticmethod
    @sync_to_async
    def _mark_dispatched(task_id: str):
        """Set the task's status to DISPATCHED in the database.

        Wrapped with ``sync_to_async`` so the ``dispatch`` coroutine can await it.
        Returns the timestamp written to ``updated_at``.
        """
        from django.utils import timezone as tz

        now = tz.now()
        Task.objects.filter(task_id=task_id).update(
            status=TaskStatus.DISPATCHED.value,
            updated_at=now,
        )
        return now

    @staticmethod
    @sync_to_async
    def _mark_dispatch_failed(task_id: str, error: str):
        """Set the task's status to FAILED and store ``error`` in the database.

        Wrapped with ``sync_to_async`` so the ``dispatch`` coroutine can await it.
        Returns the timestamp written to ``updated_at``.
        """
        from django.utils import timezone as tz

        now = tz.now()
        Task.objects.filter(task_id=task_id).update(
            status=TaskStatus.FAILED.value,
            error=error,
            updated_at=now,
        )
        return now

    # ─── Create ───

    def create_task(self, req: TaskCreate) -> Task:
        """Create a task record.

        - checks (by querying the table) whether the same worker + account
          already has an active task;
        - inserts a new row into the ``Task`` table with status PENDING;
        - returns the ORM instance for the subsequent dispatch.

        The account name is stored in normalized (stripped) form, the same
        form ``_find_active_task`` filters on, so that later duplicate checks
        can find this row.

        Raises:
            ValueError: if the worker/account pair already has an active task.
        """
        from django.utils import timezone as tz
        import uuid

        worker_id = req.worker_id or ""
        account_name = self._normalize_account_name(req.account_name)

        # NOTE(review): this check and the insert below are not atomic here;
        # concurrent calls for the same pair could both pass unless the caller
        # or a database constraint serializes them.
        existing = self._find_active_task(worker_id, account_name)
        if existing:
            raise ValueError(
                "账号 '%s' 已有执行中的任务task_id=%s, status=%s"
                % (
                    req.account_name or "",
                    existing.task_id,
                    existing.status,
                )
            )

        # task_type may be an enum member or already a plain value.
        raw_type = req.task_type
        task_type = raw_type.value if hasattr(raw_type, "value") else str(raw_type)

        now = tz.now()
        task = Task.objects.create(
            task_id=uuid.uuid4().hex[:12],
            task_type=task_type,
            worker_id=worker_id,
            account_name=account_name,
            status=TaskStatus.PENDING.value,
            params=req.params,
            created_at=now,
            updated_at=now,
        )
        logger.info(
            "任务创建: %s type=%s worker=%s account=%s",
            task.task_id, task.task_type, task.worker_id, task.account_name,
        )
        return task

    # ─── Dispatch ───

    async def dispatch(self, task: Task, ws_send) -> bool:
        """Send the task to the target Worker over WebSocket.

        Args:
            task: the task to send; its ``status``, ``error`` and
                ``updated_at`` attributes are updated in place.
            ws_send: async callable that accepts a single dict argument.

        Returns:
            True if the message was sent and the task marked DISPATCHED.
            False if sending or the DISPATCHED write raised, in which case the
            task is marked FAILED with the exception text.
        """
        msg = make_msg(
            MsgType.TASK_ASSIGN,
            task_id=task.task_id,
            task_type=task.task_type,
            account_name=task.account_name or "",
            params=task.params or {},
        )
        try:
            await ws_send(msg)
            task.status = TaskStatus.DISPATCHED.value
            task.error = None
            task.updated_at = await self._mark_dispatched(task.task_id)
        except Exception as e:
            task.status = TaskStatus.FAILED.value
            task.error = f"派发失败: {e}"
            task.updated_at = await self._mark_dispatch_failed(task.task_id, task.error)
            logger.error("任务 %s 派发失败: %s", task.task_id, e)
            return False
        logger.info("任务 %s 已派发", task.task_id)
        return True

    # ─── Status updates ───

    def update_progress(self, task_id: str, progress: str) -> None:
        """Record task progress and move the task to RUNNING.

        Only tasks that are still active are touched, so a progress message
        that arrives after the task was cancelled or finished cannot put it
        back into RUNNING. Unknown task ids are ignored.
        """
        from django.utils import timezone as tz

        updated = Task.objects.filter(
            task_id=task_id,
            status__in=self._active_status_values(),
        ).update(
            status=TaskStatus.RUNNING.value,
            progress=progress,
            updated_at=tz.now(),
        )
        if updated:
            logger.info("任务 %s 进度更新: %s", task_id, progress)

    def complete_task(self, task_id: str, result=None, error: Optional[str] = None) -> None:
        """Finish a task: infer the final status from result/error and persist it.

        The task becomes FAILED when ``_infer_failure_reason`` reports a
        reason, otherwise SUCCESS. ``result`` is stored in both cases. The row
        is updated whatever its current status is. Unknown task ids are ignored.
        """
        from django.utils import timezone as tz

        task = Task.objects.filter(task_id=task_id).first()
        if not task:
            return

        failure_reason = self._infer_failure_reason(result, error)
        task.result = result
        if failure_reason:
            task.status = TaskStatus.FAILED.value
            task.error = failure_reason
        else:
            task.status = TaskStatus.SUCCESS.value
            task.error = None
        task.updated_at = tz.now()
        task.save(update_fields=["status", "result", "error", "updated_at"])
        logger.info("任务 %s 完成: status=%s", task_id, task.status)

    def cancel_task(self, task_id: str, error: str = "任务已取消") -> Optional[Task]:
        """Cancel a task; only tasks in an active status are affected.

        Returns:
            The cancelled task, or None if the task does not exist or is not
            in an active status.
        """
        from django.utils import timezone as tz

        task = Task.objects.filter(
            task_id=task_id,
            status__in=self._active_status_values(),
        ).first()
        if not task:
            return None

        task.status = TaskStatus.CANCELLED.value
        task.error = error
        task.updated_at = tz.now()
        task.save(update_fields=["status", "error", "updated_at"])
        return task

    # ─── Queries ───

    def get_task(self, task_id: str) -> Optional[Task]:
        """Return the task with the given id, or None if there is none."""
        return Task.objects.filter(task_id=task_id).first()

    def get_active_task_by_account(self, worker_id: str, account_name: str) -> Optional[Task]:
        """Return the newest active task for the worker/account pair, or None."""
        return self._find_active_task(worker_id, account_name)

    def list_tasks(
        self,
        worker_id: Optional[str] = None,
        status: Optional[TaskStatus] = None,
        limit: int = 50,
    ) -> List[Task]:
        """Query tasks from the database, newest first.

        Args:
            worker_id: when truthy, only tasks of this worker are returned.
            status: when truthy, only tasks in this status are returned; an
                enum member or a plain value are both accepted.
            limit: maximum number of tasks to return.
        """
        qs = Task.objects.all().order_by("-created_at")
        if worker_id:
            qs = qs.filter(worker_id=worker_id)
        if status:
            status_value = status.value if hasattr(status, "value") else str(status)
            qs = qs.filter(status=status_value)
        return list(qs[:limit])
# Global singleton instance.
task_dispatcher = TaskDispatcher()