231 lines
7.9 KiB
Python
231 lines
7.9 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
Task routing and dispatch.

Locates the target Worker from the worker_id / account_name in the request
and sends the task to it over WebSocket.
"""
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
import time
|
||
from typing import List, Optional
|
||
|
||
from common.protocol import MsgType, TaskStatus, make_msg
|
||
from server.models import TaskCreate, Task
|
||
|
||
# Module-level logger shared by everything in this file.
logger = logging.getLogger("server.task_dispatcher")
|
||
|
||
|
||
class TaskDispatcher:
    """Manage the task lifecycle and dispatch tasks to Workers.

    The database is the single source of truth:

    - task state is not kept in memory between calls;
    - every query and state change goes through the ``Task`` ORM model;
    - Python objects are only built transiently for one request/dispatch.
    """

    # Statuses in which a task is still considered "in flight".
    _ACTIVE_STATUSES = {TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING}

    # Fallback failure reason; it is persisted in ``Task.error``, so the text
    # is runtime data and intentionally kept as-is.
    _DEFAULT_FAILURE_MESSAGE = "任务执行失败"

    # ``result["status"]`` values (lower-cased) that mean the task failed.
    _FAILED_STATUS_TEXTS = frozenset({"failed", "error", "fail"})

    @classmethod
    def _active_status_values(cls) -> List[str]:
        """Return the stored (``.value``) form of every active status."""
        return [status.value for status in cls._ACTIVE_STATUSES]

    @staticmethod
    def _normalize_account_name(account_name: Optional[str]) -> str:
        """Return ``account_name`` stripped of surrounding whitespace.

        ``None`` (or any other falsy value) becomes the empty string.
        """
        return (account_name or "").strip()

    @staticmethod
    def _find_active_task(worker_id: str, account_name: str) -> Optional[Task]:
        """Return the newest active task for ``worker_id`` + ``account_name``.

        Returns ``None`` when either identifier is empty (after normalizing
        the account name) or when no matching active task exists.
        """
        # This is a staticmethod, so there is no ``self``; the helpers must be
        # reached through the class (the original ``self.`` raised NameError).
        normalized = TaskDispatcher._normalize_account_name(account_name)
        if not worker_id or not normalized:
            return None

        return (
            Task.objects.filter(
                worker_id=worker_id,
                account_name=normalized,
                status__in=TaskDispatcher._active_status_values(),
            )
            .order_by("-created_at")
            .first()
        )

    @staticmethod
    def _infer_failure_reason(result, error) -> Optional[str]:
        """Infer a failure reason from what the Worker reported.

        Rules:

        1. Any explicitly reported ``error`` (even an empty string) means
           failure.
        2. A dict ``result`` with ``success`` / ``ok`` set to ``False``, or a
           ``status`` of ``failed`` / ``error`` / ``fail`` (case-insensitive),
           also means failure.

        Returns:
            A non-empty failure message, or ``None`` when nothing indicates a
            failure.
        """
        default_message = TaskDispatcher._DEFAULT_FAILURE_MESSAGE

        if error is not None:
            return str(error).strip() or default_message

        if not isinstance(result, dict):
            return None

        status_text = str(result.get("status", "")).strip().lower()
        failed = (
            result.get("success") is False
            or result.get("ok") is False
            or status_text in TaskDispatcher._FAILED_STATUS_TEXTS
        )
        if not failed:
            return None

        # An explicit ``"error": None`` must not be rendered as the literal
        # string "None"; treat it like a missing message.
        detail = result.get("error")
        message = "" if detail is None else str(detail).strip()
        return message or default_message

    # --- Create ---

    def create_task(self, req: TaskCreate) -> Task:
        """Create and persist a new task record.

        - Rejects the request when the same worker + account already has an
          active task (checked against the database).
        - Inserts a new ``Task`` row in the ``PENDING`` status.

        Returns:
            The created ORM instance, ready to be dispatched.

        Raises:
            ValueError: if an active task already exists for the account.
        """
        import uuid

        from django.utils import timezone as tz

        worker_id = req.worker_id or ""
        # Store the same normalized form that ``_find_active_task`` queries
        # on; otherwise a name with surrounding whitespace would never match
        # and the duplicate guard could be bypassed.
        account_name = self._normalize_account_name(req.account_name)

        # NOTE(review): check-then-insert is not atomic; two concurrent
        # requests could both pass this guard. Confirm whether a DB-level
        # constraint or lock is needed.
        existing = self._find_active_task(worker_id, account_name)
        if existing:
            raise ValueError(
                "账号 '%s' 已有执行中的任务(task_id=%s, status=%s)"
                % (
                    req.account_name or "",
                    existing.task_id,
                    existing.status,
                )
            )

        task_type = req.task_type
        now = tz.now()
        task = Task.objects.create(
            task_id=uuid.uuid4().hex[:12],
            # Accept both enum members and plain values.
            task_type=task_type.value if hasattr(task_type, "value") else str(task_type),
            worker_id=worker_id,
            account_name=account_name,
            status=TaskStatus.PENDING.value,
            params=req.params,
            created_at=now,
            updated_at=now,
        )
        logger.info(
            "任务创建: %s type=%s worker=%s account=%s",
            task.task_id, task.task_type, task.worker_id, task.account_name,
        )
        return task

    # --- Dispatch ---

    async def dispatch(self, task: Task, ws_send) -> bool:
        """Send ``task`` to its Worker through ``ws_send``.

        Args:
            task: The task to send.
            ws_send: Async callable that accepts the message built by
                ``make_msg``.

        Returns:
            ``True`` when the message was sent and the task was saved as
            ``DISPATCHED``; ``False`` when an exception occurred, in which
            case the task is saved as ``FAILED`` with the error text.
        """
        from django.utils import timezone as tz

        msg = make_msg(
            MsgType.TASK_ASSIGN,
            task_id=task.task_id,
            task_type=task.task_type,
            account_name=task.account_name or "",
            params=task.params or {},
        )
        # NOTE(review): ``task.save()`` is a synchronous ORM call made inside
        # a coroutine. Confirm the deployment permits this (Django guards
        # sync DB access from a running event loop by default).
        try:
            await ws_send(msg)
            task.status = TaskStatus.DISPATCHED.value
            task.updated_at = tz.now()
            task.save(update_fields=["status", "updated_at"])
            logger.info("任务 %s 已派发", task.task_id)
            return True
        except Exception as e:
            # Also reached when the send succeeded but the save above failed.
            task.status = TaskStatus.FAILED.value
            task.error = f"派发失败: {e}"
            task.updated_at = tz.now()
            task.save(update_fields=["status", "error", "updated_at"])
            logger.error("任务 %s 派发失败: %s", task.task_id, e)
            return False

    # --- Status updates ---

    def update_progress(self, task_id: str, progress: str) -> None:
        """Record ``progress`` for a task and mark it ``RUNNING``.

        Only tasks that are still active are touched, so a progress report
        that arrives after the task finished or was cancelled cannot move it
        back to ``RUNNING``.
        """
        from django.utils import timezone as tz

        updated = Task.objects.filter(
            task_id=task_id,
            status__in=self._active_status_values(),
        ).update(
            status=TaskStatus.RUNNING.value,
            progress=progress,
            updated_at=tz.now(),
        )
        if updated:
            logger.info("任务 %s 进度更新: %s", task_id, progress)
        else:
            logger.debug("任务 %s 进度上报被忽略: 任务不存在或已结束", task_id)

    def complete_task(self, task_id: str, result=None, error: Optional[str] = None) -> None:
        """Finish a task and persist its final status.

        The status is ``FAILED`` when ``_infer_failure_reason`` finds a
        failure in ``result`` / ``error``, otherwise ``SUCCESS``. ``result``
        is stored in both cases. Does nothing when the task does not exist.
        """
        from django.utils import timezone as tz

        task = Task.objects.filter(task_id=task_id).first()
        if not task:
            return

        failure_reason = self._infer_failure_reason(result, error)
        if failure_reason is not None:
            task.status = TaskStatus.FAILED.value
            task.error = failure_reason
        else:
            task.status = TaskStatus.SUCCESS.value
            task.error = None
        task.result = result
        task.updated_at = tz.now()
        task.save(update_fields=["status", "result", "error", "updated_at"])
        logger.info("任务 %s 完成: status=%s", task_id, task.status)

    def cancel_task(self, task_id: str) -> None:
        """Mark a task ``CANCELLED``; only active tasks are affected."""
        from django.utils import timezone as tz

        task = Task.objects.filter(
            task_id=task_id,
            status__in=self._active_status_values(),
        ).first()
        if not task:
            return

        task.status = TaskStatus.CANCELLED.value
        task.updated_at = tz.now()
        task.save(update_fields=["status", "updated_at"])

    # --- Queries ---

    def get_task(self, task_id: str) -> Optional[Task]:
        """Return the task with ``task_id``, or ``None`` if there is none."""
        return Task.objects.filter(task_id=task_id).first()

    def get_active_task_by_account(self, worker_id: str, account_name: str) -> Optional[Task]:
        """Return the newest active task for the worker + account, if any."""
        return self._find_active_task(worker_id, account_name)

    def list_tasks(
        self,
        worker_id: Optional[str] = None,
        status: Optional[TaskStatus] = None,
        limit: int = 50,
    ) -> List[Task]:
        """List tasks from the database, newest first.

        Args:
            worker_id: When given, only tasks of this worker are returned.
            status: When given, only tasks in this status are returned; an
                enum member or a plain value is accepted.
            limit: Maximum number of tasks to return.
        """
        qs = Task.objects.all().order_by("-created_at")
        if worker_id:
            qs = qs.filter(worker_id=worker_id)
        if status:
            qs = qs.filter(status=status.value if hasattr(status, "value") else str(status))
        return list(qs[:limit])
|
||
|
||
|
||
# Global singleton instance, created when this module is imported.
task_dispatcher = TaskDispatcher()
|