# -*- coding: utf-8 -*- """ 任务路由与派发。 根据请求中的 worker_id / account_name 找到目标 Worker,通过 WebSocket 下发任务。 """ from __future__ import annotations import logging import time from typing import Dict, List, Optional from common.protocol import MsgType, TaskStatus, make_msg from server.models import TaskCreate, TaskInfo logger = logging.getLogger("server.task_dispatcher") class TaskDispatcher: """管理任务生命周期并派发给 Worker。""" def __init__(self) -> None: # task_id → TaskInfo self._tasks: Dict[str, TaskInfo] = {} # ─── 创建任务 ─── def create_task(self, req: TaskCreate) -> TaskInfo: task = TaskInfo( task_type=req.task_type, worker_id=req.worker_id, account_name=req.account_name, params=req.params, ) self._tasks[task.task_id] = task logger.info("任务创建: %s type=%s worker=%s account=%s", task.task_id, task.task_type, task.worker_id, task.account_name) return task # ─── 派发 ─── async def dispatch(self, task: TaskInfo, ws_send) -> bool: """ 将任务通过 WebSocket 发给目标 Worker。 ws_send: 异步可调用,接受一个 dict 参数。 返回是否发送成功。 """ msg = make_msg( MsgType.TASK_ASSIGN, task_id=task.task_id, task_type=task.task_type.value, account_name=task.account_name or "", params=task.params, ) try: await ws_send(msg) task.status = TaskStatus.DISPATCHED task.updated_at = time.time() logger.info("任务 %s 已派发", task.task_id) return True except Exception as e: task.status = TaskStatus.FAILED task.error = f"派发失败: {e}" task.updated_at = time.time() logger.error("任务 %s 派发失败: %s", task.task_id, e) return False # ─── 更新状态 ─── def update_progress(self, task_id: str, progress: str) -> None: task = self._tasks.get(task_id) if task: task.status = TaskStatus.RUNNING task.progress = progress task.updated_at = time.time() def complete_task(self, task_id: str, result=None, error: str = None) -> None: task = self._tasks.get(task_id) if not task: return if error: task.status = TaskStatus.FAILED task.error = error else: task.status = TaskStatus.SUCCESS task.result = result task.updated_at = time.time() logger.info("任务 %s 完成: status=%s", task_id, task.status) def cancel_task(self, task_id: str) -> None: task = self._tasks.get(task_id) if task and task.status in (TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING): task.status = TaskStatus.CANCELLED task.updated_at = time.time() # ─── 查询 ─── def get_task(self, task_id: str) -> Optional[TaskInfo]: return self._tasks.get(task_id) def list_tasks( self, worker_id: Optional[str] = None, status: Optional[TaskStatus] = None, limit: int = 50, ) -> List[TaskInfo]: result = list(self._tasks.values()) if worker_id: result = [t for t in result if t.worker_id == worker_id] if status: result = [t for t in result if t.status == status] # 按创建时间倒序 result.sort(key=lambda t: t.created_at, reverse=True) return result[:limit] # 全局单例 task_dispatcher = TaskDispatcher()