# -*- coding: utf-8 -*- """ 任务路由与派发。 根据请求中的 worker_id / account_name 找到目标 Worker,通过 WebSocket 下发任务。 """ from __future__ import annotations import logging import time from typing import List, Optional from asgiref.sync import sync_to_async from common.protocol import MsgType, TaskStatus, make_msg from server.models import TaskCreate, Task logger = logging.getLogger("server.task_dispatcher") class TaskDispatcher: """ 管理任务生命周期并派发给 Worker(以数据库为唯一真相)。 注意: - 不再在内存中长期保存任务状态; - 所有查询 / 状态变更均落在 ORM 模型 Task 上; - 仅在单次请求/派发过程中临时构造 Python 对象。 """ _ACTIVE_STATUSES = {TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING} @staticmethod def _normalize_account_name(account_name: Optional[str]) -> str: return (account_name or "").strip() def _find_active_task(self, worker_id: str, account_name: str) -> Optional[Task]: normalized = self._normalize_account_name(account_name) if not worker_id or not normalized: return None from django.db.models import Q active_values = [s.value for s in TaskDispatcher._ACTIVE_STATUSES] return ( Task.objects.filter( Q(worker_id=worker_id), Q(account_name=normalized), Q(status__in=active_values), ) .order_by("-created_at") .first() ) @staticmethod def _infer_failure_reason(result, error) -> Optional[str]: """ 从 Worker 上报中推断失败原因。 规则: 1) 只要显式上报了 error(即便是空串)就视为失败; 2) 若 result 明确包含 success=false / ok=false / status=failed|error|fail,也视为失败。 """ if error is not None: msg = str(error).strip() return msg or "任务执行失败" if isinstance(result, dict): if result.get("success") is False: msg = str(result.get("error", "")).strip() return msg or "任务执行失败" if result.get("ok") is False: msg = str(result.get("error", "")).strip() return msg or "任务执行失败" status_text = str(result.get("status", "")).strip().lower() if status_text in {"failed", "error", "fail"}: msg = str(result.get("error", "")).strip() return msg or "任务执行失败" return None # ─── 创建任务 ─── @staticmethod @sync_to_async def _mark_dispatched(task_id: str): from django.utils import timezone as tz now = tz.now() Task.objects.filter(task_id=task_id).update( status=TaskStatus.DISPATCHED.value, updated_at=now, ) return now @staticmethod @sync_to_async def _mark_dispatch_failed(task_id: str, error: str): from django.utils import timezone as tz now = tz.now() Task.objects.filter(task_id=task_id).update( status=TaskStatus.FAILED.value, error=error, updated_at=now, ) return now def create_task(self, req: TaskCreate) -> Task: """ 创建任务记录: - 校验同一 worker + account 是否已有执行中的任务(查表); - 在 Task 表中插入一条新记录; - 返回 ORM 实例,供后续派发使用。 """ from django.utils import timezone as tz import uuid existing = self._find_active_task(req.worker_id or "", req.account_name or "") if existing: raise ValueError( "账号 '%s' 已有执行中的任务(task_id=%s, status=%s)" % ( req.account_name or "", existing.task_id, existing.status, ) ) task_id = uuid.uuid4().hex[:12] now = tz.now() task = Task.objects.create( task_id=task_id, task_type=req.task_type.value if hasattr(req.task_type, "value") else str(req.task_type), worker_id=req.worker_id or "", account_name=req.account_name or "", status=TaskStatus.PENDING.value, params=req.params, created_at=now, updated_at=now, ) logger.info( "任务创建: %s type=%s worker=%s account=%s", task.task_id, task.task_type, task.worker_id, task.account_name, ) return task # ─── 派发 ─── async def dispatch(self, task: Task, ws_send) -> bool: """ 将任务通过 WebSocket 发给目标 Worker。 ws_send: 异步可调用,接受一个 dict 参数。 返回是否发送成功。 """ msg = make_msg( MsgType.TASK_ASSIGN, task_id=task.task_id, task_type=task.task_type, account_name=task.account_name or "", params=task.params or {}, ) try: await ws_send(msg) task.status = TaskStatus.DISPATCHED.value task.error = None task.updated_at = await self._mark_dispatched(task.task_id) logger.info("任务 %s 已派发", task.task_id) return True except Exception as e: task.status = TaskStatus.FAILED.value task.error = f"派发失败: {e}" task.updated_at = await self._mark_dispatch_failed(task.task_id, task.error) logger.error("任务 %s 派发失败: %s", task.task_id, e) return False # ─── 更新状态 ─── def update_progress(self, task_id: str, progress: str) -> None: """ 更新任务进度和状态(RUNNING)。 """ from django.utils import timezone as tz updated = Task.objects.filter(task_id=task_id).update( status=TaskStatus.RUNNING.value, progress=progress, updated_at=tz.now(), ) if updated: logger.info("任务 %s 进度更新: %s", task_id, progress) def complete_task(self, task_id: str, result=None, error: str = None) -> None: """ 完成任务:根据 result/error 推断最终状态并写入数据库。 """ from django.utils import timezone as tz task = Task.objects.filter(task_id=task_id).first() if not task: return failure_reason = self._infer_failure_reason(result, error) if failure_reason: task.status = TaskStatus.FAILED.value task.error = failure_reason task.result = result else: task.status = TaskStatus.SUCCESS.value task.result = result task.error = None task.updated_at = tz.now() task.save(update_fields=["status", "result", "error", "updated_at"]) logger.info("任务 %s 完成: status=%s", task_id, task.status) def cancel_task(self, task_id: str, error: str = "任务已取消") -> Optional[Task]: """ 取消任务:仅对活动状态的任务生效。 返回取消后的任务对象,若任务不存在或不可取消则返回 None。 """ from django.utils import timezone as tz active_values = [s.value for s in self._ACTIVE_STATUSES] task = Task.objects.filter(task_id=task_id, status__in=active_values).first() if not task: return None task.status = TaskStatus.CANCELLED.value task.error = error task.updated_at = tz.now() task.save(update_fields=["status", "error", "updated_at"]) return task # ─── 查询 ─── def get_task(self, task_id: str) -> Optional[Task]: return Task.objects.filter(task_id=task_id).first() def get_active_task_by_account(self, worker_id: str, account_name: str) -> Optional[Task]: return self._find_active_task(worker_id, account_name) def list_tasks( self, worker_id: Optional[str] = None, status: Optional[TaskStatus] = None, limit: int = 50, ) -> List[Task]: """ 从数据库中查询任务列表。 """ qs = Task.objects.all().order_by("-created_at") if worker_id: qs = qs.filter(worker_id=worker_id) if status: qs = qs.filter(status=status.value if hasattr(status, "value") else str(status)) return list(qs[:limit]) # 全局单例 task_dispatcher = TaskDispatcher()