Files
boss_dp/server/core/task_dispatcher.py
2026-02-12 16:27:43 +08:00

118 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
任务路由与派发。
根据请求中的 worker_id / account_name 找到目标 Worker，通过 WebSocket 下发任务。
"""
from __future__ import annotations
import logging
import time
from typing import Dict, List, Optional
from common.protocol import MsgType, TaskStatus, make_msg
from server.models import TaskCreate, TaskInfo
# Logger for this module, registered under the explicit name "server.task_dispatcher".
logger = logging.getLogger("server.task_dispatcher")
class TaskDispatcher:
    """Manage the task lifecycle and dispatch tasks to Workers.

    Tasks are held in an in-memory dict keyed by ``task_id``. Nothing in this
    class removes entries, so the dict grows for the life of the process.
    """

    def __init__(self) -> None:
        # task_id -> TaskInfo
        # NOTE(review): entries are never evicted; consider pruning finished
        # tasks if the server is long-running with many tasks.
        self._tasks: Dict[str, TaskInfo] = {}

    # ─── Create tasks ───

    def create_task(self, req: TaskCreate) -> TaskInfo:
        """Build a TaskInfo from ``req``, register it, and return it.

        This only stores the task; sending it to a Worker is done separately
        by :meth:`dispatch`.
        """
        task = TaskInfo(
            task_type=req.task_type,
            worker_id=req.worker_id,
            account_name=req.account_name,
            params=req.params,
        )
        self._tasks[task.task_id] = task
        logger.info("任务创建: %s type=%s worker=%s account=%s",
                    task.task_id, task.task_type, task.worker_id, task.account_name)
        return task

    # ─── Dispatch ───

    async def dispatch(self, task: TaskInfo, ws_send) -> bool:
        """Send ``task`` to its target Worker over WebSocket.

        Args:
            task: The task to send.
            ws_send: Async callable that accepts a single dict argument (the
                TASK_ASSIGN message built here) and sends it.

        Returns:
            True if ``ws_send`` completed, False if it raised. On failure the
            task is marked FAILED and the error text is stored on it.
        """
        msg = make_msg(
            MsgType.TASK_ASSIGN,
            task_id=task.task_id,
            task_type=task.task_type.value,
            account_name=task.account_name or "",
            params=task.params,
        )
        status_before_send = task.status
        try:
            await ws_send(msg)
        except Exception as e:
            # Transport boundary: any send error means the Worker did not
            # receive the task, so record the failure instead of propagating.
            task.status = TaskStatus.FAILED
            task.error = f"派发失败: {e}"
            task.updated_at = time.time()
            logger.error("任务 %s 派发失败: %s", task.task_id, e, exc_info=True)
            return False
        # Other coroutines can run while ws_send is awaiting (presumably the
        # handler for Worker replies). If one of them already moved the task
        # on (RUNNING / SUCCESS / ...), do not roll it back to DISPATCHED.
        if task.status == status_before_send:
            task.status = TaskStatus.DISPATCHED
            task.updated_at = time.time()
        logger.info("任务 %s 已派发", task.task_id)
        return True

    # ─── Status updates ───

    def update_progress(self, task_id: str, progress: str) -> None:
        """Record a progress report for ``task_id`` and mark it RUNNING.

        Unknown task ids are ignored. Reports for a task that has already
        finished (SUCCESS / FAILED / CANCELLED) are ignored as well, so a late
        progress message cannot turn a finished task back into RUNNING.
        """
        task = self._tasks.get(task_id)
        if task is None:
            return
        if task.status in (
            TaskStatus.SUCCESS,
            TaskStatus.FAILED,
            TaskStatus.CANCELLED,
        ):
            logger.debug("任务 %s 已结束 (status=%s)，忽略进度更新", task_id, task.status)
            return
        task.status = TaskStatus.RUNNING
        task.progress = progress
        task.updated_at = time.time()

    def complete_task(self, task_id: str, result=None, error: Optional[str] = None) -> None:
        """Record the final outcome of ``task_id``.

        A truthy ``error`` marks the task FAILED and stores the error text.
        Otherwise the task is marked SUCCESS and ``result`` is stored. Unknown
        task ids are ignored. The current status is overwritten whatever it
        is, including CANCELLED.
        """
        task = self._tasks.get(task_id)
        if task is None:
            return
        if error:
            task.status = TaskStatus.FAILED
            task.error = error
        else:
            task.status = TaskStatus.SUCCESS
            task.result = result
        task.updated_at = time.time()
        logger.info("任务 %s 完成: status=%s", task_id, task.status)

    def cancel_task(self, task_id: str) -> bool:
        """Mark ``task_id`` CANCELLED if it is PENDING, DISPATCHED or RUNNING.

        Only the stored status changes; this method sends nothing to the Worker.

        Returns:
            True if the task was cancelled. False if the id is unknown or the
            task is in any other status.
        """
        task = self._tasks.get(task_id)
        if task is None or task.status not in (
            TaskStatus.PENDING,
            TaskStatus.DISPATCHED,
            TaskStatus.RUNNING,
        ):
            return False
        task.status = TaskStatus.CANCELLED
        task.updated_at = time.time()
        logger.info("任务 %s 已取消", task_id)
        return True

    # ─── Queries ───

    def get_task(self, task_id: str) -> Optional[TaskInfo]:
        """Return the task registered under ``task_id``, or None if unknown."""
        return self._tasks.get(task_id)

    def list_tasks(
        self,
        worker_id: Optional[str] = None,
        status: Optional[TaskStatus] = None,
        limit: int = 50,
    ) -> List[TaskInfo]:
        """Return up to ``limit`` tasks, newest first, optionally filtered.

        Args:
            worker_id: Keep only tasks for this Worker. None or "" disables
                the filter.
            status: Keep only tasks in this status. None disables the filter.
            limit: Maximum number of tasks returned. Values <= 0 yield [].
        """
        matches = [
            t for t in self._tasks.values()
            if (not worker_id or t.worker_id == worker_id)
            # Compare against None explicitly rather than relying on enum truthiness.
            and (status is None or t.status == status)
        ]
        # Newest first by creation time.
        matches.sort(key=lambda t: t.created_at, reverse=True)
        # Clamp: a negative slice bound would drop items from the end instead.
        return matches[:max(limit, 0)]
# Global singleton: a single module-level TaskDispatcher whose task registry
# is an in-memory dict within this process.
task_dispatcher = TaskDispatcher()