Files
boss_dp/server/core/task_dispatcher.py
2026-03-03 02:28:54 +08:00

230 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
任务路由与派发。
根据请求中的 worker_id / account_name 找到目标 Worker，通过 WebSocket 下发任务。
"""
from __future__ import annotations
import logging
import time
from typing import List, Optional
from common.protocol import MsgType, TaskStatus, make_msg
from server.models import TaskCreate, Task
# Module-level logger, registered under a fixed name rather than __name__.
logger = logging.getLogger(name="server.task_dispatcher")
class TaskDispatcher:
    """
    Manage the task lifecycle and dispatch tasks to Workers, treating the
    database as the single source of truth.

    Notes:
    - task state is not kept in memory long-term;
    - every query / state change goes through the ORM model ``Task``;
    - Python objects are only built transiently within a single request or
      dispatch call.
    """

    # Statuses that count as "still in progress". A task in one of these
    # blocks creating another task for the same worker + account.
    _ACTIVE_STATUSES = {TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING}

    @staticmethod
    def _normalize_account_name(account_name: Optional[str]) -> str:
        """Return *account_name* with surrounding whitespace stripped ("" for None)."""
        return (account_name or "").strip()

    @classmethod
    def _active_status_values(cls) -> list:
        """Return the raw DB values of the active statuses, for ``status__in`` filters."""
        return [s.value for s in cls._ACTIVE_STATUSES]

    def _find_active_task(self, worker_id: str, account_name: str) -> Optional[Task]:
        """
        Return the newest active task for ``worker_id`` + ``account_name``.

        The account name is normalized (stripped) before the lookup. Returns
        None when either key is empty, or when no active task exists.
        """
        normalized = self._normalize_account_name(account_name)
        if not worker_id or not normalized:
            return None
        return (
            Task.objects.filter(
                worker_id=worker_id,
                account_name=normalized,
                status__in=self._active_status_values(),
            )
            .order_by("-created_at")
            .first()
        )

    @staticmethod
    def _failure_message(raw) -> str:
        """Turn a reported error value into a failure message, falling back to a default."""
        # An explicit None means "no message"; don't let it become the text "None".
        msg = "" if raw is None else str(raw).strip()
        return msg or "任务执行失败"

    @staticmethod
    def _infer_failure_reason(result, error) -> Optional[str]:
        """
        Infer a failure reason from what the Worker reported.

        Rules:
        1) if ``error`` was reported explicitly (even as an empty string), the
           task is treated as failed;
        2) if ``result`` is a dict with ``success=False``, ``ok=False``, or
           ``status`` in {failed, error, fail} (case-insensitive, whitespace
           stripped), the task is also treated as failed, and
           ``result["error"]`` is used as the reason.

        Returns:
            A non-empty failure message (a generic default when the reported
            message is missing, empty or None), or None if the task succeeded.
        """
        if error is not None:
            return TaskDispatcher._failure_message(error)
        if not isinstance(result, dict):
            return None
        status_text = str(result.get("status", "")).strip().lower()
        if (
            result.get("success") is False
            or result.get("ok") is False
            or status_text in {"failed", "error", "fail"}
        ):
            return TaskDispatcher._failure_message(result.get("error"))
        return None

    # ─── Task creation ───
    def create_task(self, req: TaskCreate) -> Task:
        """
        Create a task record.

        - Reject the request if the same worker + account already has an
          active task (checked against the DB);
        - insert a new PENDING row into the ``Task`` table;
        - return the ORM instance for the subsequent dispatch.

        The account name is stored in the same normalized (stripped) form that
        ``_find_active_task`` looks up. This way the duplicate check also
        catches names that differ only in surrounding whitespace.

        Raises:
            ValueError: an active task already exists for this worker + account.
        """
        import uuid

        from django.utils import timezone as tz

        worker_id = req.worker_id or ""
        account_name = self._normalize_account_name(req.account_name)

        existing = self._find_active_task(worker_id, account_name)
        if existing:
            raise ValueError(
                "账号 '%s' 已有执行中的任务task_id=%s, status=%s"
                % (account_name, existing.task_id, existing.status)
            )
        # NOTE(review): check-then-insert is not atomic; two concurrent requests
        # for the same worker + account could both pass the check above.

        task_type = req.task_type
        now = tz.now()
        task = Task.objects.create(
            task_id=uuid.uuid4().hex[:12],
            task_type=task_type.value if hasattr(task_type, "value") else str(task_type),
            worker_id=worker_id,
            account_name=account_name,
            status=TaskStatus.PENDING.value,
            params=req.params,
            created_at=now,
            updated_at=now,
        )
        logger.info(
            "任务创建: %s type=%s worker=%s account=%s",
            task.task_id, task.task_type, task.worker_id, task.account_name,
        )
        return task

    # ─── Dispatch ───
    async def dispatch(self, task: Task, ws_send) -> bool:
        """
        Send a task to its target Worker over WebSocket.

        Args:
            task: the task to send.
            ws_send: async callable that takes a single dict (the message).

        Returns:
            True if the message was sent (task marked DISPATCHED). False if
            ``ws_send`` raised; in that case the task is marked FAILED and the
            error is recorded.
        """
        from django.utils import timezone as tz

        msg = make_msg(
            MsgType.TASK_ASSIGN,
            task_id=task.task_id,
            task_type=task.task_type,
            account_name=task.account_name or "",
            params=task.params or {},
        )
        # Guard only the send itself. Once it succeeds the Worker has the task,
        # so a later DB error must not be recorded as a dispatch failure.
        try:
            await ws_send(msg)
        except Exception as e:
            task.status = TaskStatus.FAILED.value
            task.error = f"派发失败: {e}"
            task.updated_at = tz.now()
            task.save(update_fields=["status", "error", "updated_at"])
            logger.error("任务 %s 派发失败: %s", task.task_id, e)
            return False

        task.status = TaskStatus.DISPATCHED.value
        task.updated_at = tz.now()
        try:
            task.save(update_fields=["status", "updated_at"])
        except Exception:
            # The send succeeded, so report success. Surface the DB problem in
            # the log rather than mislabelling the task as FAILED.
            logger.exception("任务 %s 已派发，但状态更新失败", task.task_id)
            return True
        logger.info("任务 %s 已派发", task.task_id)
        return True

    # ─── Status updates ───
    def update_progress(self, task_id: str, progress: str) -> None:
        """
        Record progress text and mark the task RUNNING.

        Only tasks that are still in an active status are touched. A late
        progress report for a task that already succeeded, failed or was
        cancelled must not move it back to RUNNING. That would make the task
        look active again and block new tasks for the same account.
        """
        from django.utils import timezone as tz

        updated = Task.objects.filter(
            task_id=task_id,
            status__in=self._active_status_values(),
        ).update(
            status=TaskStatus.RUNNING.value,
            progress=progress,
            updated_at=tz.now(),
        )
        if updated:
            logger.info("任务 %s 进度更新: %s", task_id, progress)

    def complete_task(self, task_id: str, result=None, error: Optional[str] = None) -> None:
        """
        Finish a task: infer the final status from ``result``/``error`` and persist it.

        The task becomes FAILED, with the inferred reason stored in ``error``,
        or SUCCESS, with ``error`` cleared. ``result`` is stored in both cases.
        An unknown ``task_id`` is ignored.
        """
        from django.utils import timezone as tz

        task = Task.objects.filter(task_id=task_id).first()
        if not task:
            return
        failure_reason = self._infer_failure_reason(result, error)
        if failure_reason is not None:
            task.status = TaskStatus.FAILED.value
        else:
            task.status = TaskStatus.SUCCESS.value
        task.error = failure_reason  # None on success clears any earlier error
        task.result = result
        task.updated_at = tz.now()
        task.save(update_fields=["status", "result", "error", "updated_at"])
        logger.info("任务 %s 完成: status=%s", task_id, task.status)

    def cancel_task(self, task_id: str) -> None:
        """
        Mark a task CANCELLED; only tasks in an active status are affected.

        Only the DB record is updated here; this method sends nothing to the Worker.
        """
        from django.utils import timezone as tz

        task = Task.objects.filter(
            task_id=task_id,
            status__in=self._active_status_values(),
        ).first()
        if not task:
            return
        task.status = TaskStatus.CANCELLED.value
        task.updated_at = tz.now()
        task.save(update_fields=["status", "updated_at"])

    # ─── Queries ───
    def get_task(self, task_id: str) -> Optional[Task]:
        """Return the task with ``task_id``, or None if it does not exist."""
        return Task.objects.filter(task_id=task_id).first()

    def get_active_task_by_account(self, worker_id: str, account_name: str) -> Optional[Task]:
        """Return the newest active task for worker + (normalized) account, or None."""
        return self._find_active_task(worker_id, account_name)

    def list_tasks(
        self,
        worker_id: Optional[str] = None,
        status: Optional[TaskStatus] = None,
        limit: int = 50,
    ) -> List[Task]:
        """
        List tasks from the database, newest first.

        Optionally filters by ``worker_id`` and ``status``, and returns at most
        ``limit`` rows.
        """
        qs = Task.objects.all().order_by("-created_at")
        if worker_id:
            qs = qs.filter(worker_id=worker_id)
        if status:
            qs = qs.filter(status=status.value if hasattr(status, "value") else str(status))
        return list(qs[:limit])
# Module-level shared dispatcher instance (global singleton).
task_dispatcher: TaskDispatcher = TaskDispatcher()