Files
boss_dp/server/core/task_dispatcher.py
ddrwode e9b359b4fb ha'ha
2026-03-02 00:51:20 +08:00

190 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
Task routing and dispatch.
Locates the target Worker from the worker_id / account_name in the request and sends the task to it over WebSocket.
"""
from __future__ import annotations
import logging
import time
import threading
from typing import Dict, List, Optional
from common.protocol import MsgType, TaskStatus, make_msg
from server.models import TaskCreate, TaskInfo
logger = logging.getLogger("server.task_dispatcher")
class TaskDispatcher:
    """Manage task life cycles and dispatch tasks to Workers.

    Tasks are kept in an in-memory table keyed by ``task_id``. Every access
    to that table, and every status mutation performed here, happens while
    holding a re-entrant lock.
    """

    # Statuses that mean a task is still in flight. An account with a task in
    # one of these states cannot be given another task on the same worker.
    _ACTIVE_STATUSES = frozenset(
        {TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING}
    )

    def __init__(self) -> None:
        # task_id -> TaskInfo
        self._tasks: Dict[str, TaskInfo] = {}
        self._lock = threading.RLock()

    @staticmethod
    def _normalize_account_name(account_name: Optional[str]) -> str:
        """Return ``account_name`` stripped of surrounding whitespace ("" for None)."""
        return (account_name or "").strip()

    def _find_active_task_unlocked(
        self, worker_id: str, account_name: str
    ) -> Optional[TaskInfo]:
        """Return the first active task for ``(worker_id, account_name)``.

        The caller must already hold ``self._lock``. Returns ``None`` when
        either key is empty (after normalization) or no task matches.
        """
        normalized = self._normalize_account_name(account_name)
        if not worker_id or not normalized:
            return None
        for task in self._tasks.values():
            if task.worker_id != worker_id:
                continue
            if self._normalize_account_name(task.account_name) != normalized:
                continue
            if task.status in self._ACTIVE_STATUSES:
                return task
        return None

    @staticmethod
    def _infer_failure_reason(result, error) -> Optional[str]:
        """Infer a failure reason from a Worker's report.

        Rules:
        1) An explicitly reported ``error`` (anything other than ``None``,
           including an empty string) always means failure.
        2) A dict ``result`` carrying ``success=False``, ``ok=False`` or a
           ``status`` of ``failed`` / ``error`` / ``fail`` (case-insensitive)
           also means failure; its ``error`` entry supplies the message.

        Returns:
            The failure message, or ``None`` when the report is a success.
        """
        default_message = "任务执行失败"

        def _describe(raw) -> str:
            # A missing or None error must fall back to the default message
            # rather than being stringified into the literal text "None".
            text = "" if raw is None else str(raw).strip()
            return text or default_message

        if error is not None:
            return _describe(error)
        if not isinstance(result, dict):
            return None

        failed = (
            result.get("success") is False
            or result.get("ok") is False
            or str(result.get("status", "")).strip().lower()
            in {"failed", "error", "fail"}
        )
        return _describe(result.get("error")) if failed else None

    # ─── Create ───

    def create_task(self, req: TaskCreate) -> TaskInfo:
        """Register a new task built from ``req`` and return it.

        Raises:
            ValueError: if the same worker/account pair already has a task
                in an active status. No such check is made when either
                ``worker_id`` or ``account_name`` is empty.
        """
        with self._lock:
            existing = self._find_active_task_unlocked(
                req.worker_id or "", req.account_name or ""
            )
            if existing:
                status = existing.status
                raise ValueError(
                    "账号 '%s' 已有执行中的任务task_id=%s, status=%s"
                    % (
                        req.account_name or "",
                        existing.task_id,
                        status.value if hasattr(status, "value") else str(status),
                    )
                )
            task = TaskInfo(
                task_type=req.task_type,
                worker_id=req.worker_id,
                account_name=req.account_name,
                params=req.params,
            )
            self._tasks[task.task_id] = task
        logger.info(
            "任务创建: %s type=%s worker=%s account=%s",
            task.task_id, task.task_type, task.worker_id, task.account_name,
        )
        return task

    # ─── Dispatch ───

    async def dispatch(self, task: TaskInfo, ws_send) -> bool:
        """Send ``task`` to its target Worker over WebSocket.

        Args:
            task: the task to send.
            ws_send: async callable accepting a single dict argument.

        Returns:
            ``True`` if the send completed, ``False`` if it raised.
        """
        msg = make_msg(
            MsgType.TASK_ASSIGN,
            task_id=task.task_id,
            task_type=task.task_type.value,
            account_name=task.account_name or "",
            params=task.params,
        )
        # Remember the status before yielding control: while the send is
        # awaited, the task may be cancelled or the Worker may already report
        # progress/completion. Such a newer state must not be overwritten.
        with self._lock:
            status_before_send = task.status

        try:
            await ws_send(msg)
        except Exception as e:
            with self._lock:
                if task.status == status_before_send:
                    task.status = TaskStatus.FAILED
                    task.error = f"派发失败: {e}"
                    task.updated_at = time.time()
            logger.error("任务 %s 派发失败: %s", task.task_id, e)
            return False

        with self._lock:
            if task.status == status_before_send:
                task.status = TaskStatus.DISPATCHED
                task.updated_at = time.time()
        logger.info("任务 %s 已派发", task.task_id)
        return True

    # ─── Status updates ───

    def update_progress(self, task_id: str, progress: str) -> None:
        """Record a progress report and mark the task as running.

        Unknown task ids are ignored. Reports for tasks that are no longer
        active are ignored too, so a late message cannot move a finished or
        cancelled task back to RUNNING (which would block new tasks for the
        same account).
        """
        with self._lock:
            task = self._tasks.get(task_id)
            if task is None or task.status not in self._ACTIVE_STATUSES:
                return
            task.status = TaskStatus.RUNNING
            task.progress = progress
            task.updated_at = time.time()

    def complete_task(
        self, task_id: str, result=None, error: Optional[str] = None
    ) -> None:
        """Store the Worker's final report for ``task_id``.

        The task becomes FAILED when ``_infer_failure_reason`` finds a
        failure in ``result`` / ``error``; otherwise it becomes SUCCESS and
        ``result`` is stored. Unknown task ids are ignored. The previous
        status is overwritten regardless of what it was.
        """
        with self._lock:
            task = self._tasks.get(task_id)
            if not task:
                return
            failure_reason = self._infer_failure_reason(result, error)
            if failure_reason:
                task.status = TaskStatus.FAILED
                task.error = failure_reason
            else:
                task.status = TaskStatus.SUCCESS
                task.result = result
                task.error = None
            task.updated_at = time.time()
            logger.info("任务 %s 完成: status=%s", task_id, task.status)

    def cancel_task(self, task_id: str) -> None:
        """Mark the task CANCELLED if it exists and is still active."""
        with self._lock:
            task = self._tasks.get(task_id)
            if task and task.status in self._ACTIVE_STATUSES:
                task.status = TaskStatus.CANCELLED
                task.updated_at = time.time()

    # ─── Queries ───

    def get_task(self, task_id: str) -> Optional[TaskInfo]:
        """Return the task with ``task_id``, or ``None`` if it is unknown."""
        with self._lock:
            return self._tasks.get(task_id)

    def get_active_task_by_account(
        self, worker_id: str, account_name: str
    ) -> Optional[TaskInfo]:
        """Return the active task for ``(worker_id, account_name)``, if any."""
        with self._lock:
            return self._find_active_task_unlocked(worker_id, account_name)

    def list_tasks(
        self,
        worker_id: Optional[str] = None,
        status: Optional[TaskStatus] = None,
        limit: int = 50,
    ) -> List[TaskInfo]:
        """Return up to ``limit`` tasks, newest first.

        Args:
            worker_id: when non-empty, keep only tasks of this worker.
            status: when not ``None``, keep only tasks in this status.
            limit: maximum number of tasks returned; values below zero are
                treated as zero.
        """
        with self._lock:
            result = list(self._tasks.values())
            if worker_id:
                result = [t for t in result if t.worker_id == worker_id]
            # Compare against None so the filter never depends on the
            # truthiness of the enum member.
            if status is not None:
                result = [t for t in result if t.status == status]
            # Newest first, by creation time.
            result.sort(key=lambda t: t.created_at, reverse=True)
            # A negative limit would otherwise slice items off the tail.
            return result[:max(limit, 0)]
# Global singleton
task_dispatcher = TaskDispatcher()