118 lines
3.7 KiB
Python
118 lines
3.7 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
任务路由与派发。
|
||
根据请求中的 worker_id / account_name 找到目标 Worker,通过 WebSocket 下发任务。
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
import time
|
||
from typing import Dict, List, Optional
|
||
|
||
from common.protocol import MsgType, TaskStatus, make_msg
|
||
from server.models import TaskCreate, TaskInfo
|
||
|
||
# Logger for this module, registered under a fixed explicit name.
logger: logging.Logger = logging.getLogger(name="server.task_dispatcher")
|
||
|
||
|
||
class TaskDispatcher:
    """Manage task lifecycle and dispatch tasks to Workers.

    Tasks live in an in-memory dict keyed by ``task_id``. No method here
    removes entries, so finished tasks are retained for the lifetime of the
    instance.
    """

    def __init__(self) -> None:
        # task_id -> TaskInfo
        self._tasks: Dict[str, TaskInfo] = {}

    # ─── Create ───

    def create_task(self, req: TaskCreate) -> TaskInfo:
        """Build a TaskInfo from *req*, register it, and return it.

        The new task is stored under its ``task_id`` so later status updates
        and queries can find it. It is not sent anywhere; see ``dispatch``.
        """
        task = TaskInfo(
            task_type=req.task_type,
            worker_id=req.worker_id,
            account_name=req.account_name,
            params=req.params,
        )
        self._tasks[task.task_id] = task
        logger.info("任务创建: %s type=%s worker=%s account=%s",
                    task.task_id, task.task_type, task.worker_id, task.account_name)
        return task

    # ─── Dispatch ───

    async def dispatch(self, task: TaskInfo, ws_send) -> bool:
        """Send *task* to its target Worker over WebSocket.

        Args:
            task: Task to send; its status/error/updated_at are updated in place.
            ws_send: Async callable that accepts a single dict (the message
                built by ``make_msg``).

        Returns:
            True if ``ws_send`` completed without raising, False otherwise.
            On failure the task is marked FAILED with an error message.
        """
        msg = make_msg(
            MsgType.TASK_ASSIGN,
            task_id=task.task_id,
            task_type=task.task_type.value,
            account_name=task.account_name or "",
            params=task.params,
        )
        # Snapshot the status before awaiting: while the send is in flight other
        # coroutines may record progress, completion or cancellation for this
        # task, and those newer states must not be clobbered afterwards.
        status_before = task.status
        try:
            await ws_send(msg)
        except Exception as e:
            # Boundary with the transport: any send error is reported as False.
            task.status = TaskStatus.FAILED
            task.error = f"派发失败: {e}"
            task.updated_at = time.time()
            logger.error("任务 %s 派发失败: %s", task.task_id, e)
            return False

        if task.status == status_before:
            task.status = TaskStatus.DISPATCHED
            task.updated_at = time.time()
        logger.info("任务 %s 已派发", task.task_id)
        return True

    # ─── Status updates ───

    def update_progress(self, task_id: str, progress: str) -> None:
        """Record a progress report and mark the task RUNNING.

        Unknown task ids are ignored. Reports for tasks already in a terminal
        state (SUCCESS / FAILED / CANCELLED) are ignored as well, so a late
        progress message cannot revive a finished or cancelled task.
        """
        task = self._tasks.get(task_id)
        if task is None:
            return
        if task.status in (TaskStatus.SUCCESS, TaskStatus.FAILED, TaskStatus.CANCELLED):
            return
        task.status = TaskStatus.RUNNING
        task.progress = progress
        task.updated_at = time.time()

    def complete_task(self, task_id: str, result=None, error: Optional[str] = None) -> None:
        """Mark a task finished.

        A truthy *error* marks the task FAILED and stores the message;
        otherwise the task is marked SUCCESS and *result* is stored. Unknown
        task ids are ignored.

        NOTE(review): this overwrites whatever status the task had, including
        CANCELLED — confirm that a completion reported after cancellation
        should win.
        """
        task = self._tasks.get(task_id)
        if task is None:
            return
        if error:
            task.status = TaskStatus.FAILED
            task.error = error
        else:
            task.status = TaskStatus.SUCCESS
            task.result = result
        task.updated_at = time.time()
        logger.info("任务 %s 完成: status=%s", task_id, task.status)

    def cancel_task(self, task_id: str) -> bool:
        """Mark a not-yet-finished task as CANCELLED.

        Only PENDING, DISPATCHED or RUNNING tasks can be cancelled. This only
        changes the server-side record; no message is sent to the Worker here.

        Returns:
            True if the task existed and was cancelled, False otherwise.
        """
        task = self._tasks.get(task_id)
        if task is None or task.status not in (
            TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING,
        ):
            return False
        task.status = TaskStatus.CANCELLED
        task.updated_at = time.time()
        return True

    # ─── Queries ───

    def get_task(self, task_id: str) -> Optional[TaskInfo]:
        """Return the task with *task_id*, or None if it is unknown."""
        return self._tasks.get(task_id)

    def list_tasks(
        self,
        worker_id: Optional[str] = None,
        status: Optional[TaskStatus] = None,
        limit: int = 50,
    ) -> List[TaskInfo]:
        """List tasks, newest first (by ``created_at``).

        Args:
            worker_id: If non-empty, keep only tasks for this worker.
            status: If given, keep only tasks in this status.
            limit: Maximum number of tasks to return; values <= 0 yield [].
        """
        # A negative slice bound would drop items from the end instead of
        # limiting the count, so clamp to zero.
        limit = max(limit, 0)
        matched = [
            t for t in self._tasks.values()
            if (not worker_id or t.worker_id == worker_id)
            and (status is None or t.status == status)
        ]
        matched.sort(key=lambda t: t.created_at, reverse=True)
        return matched[:limit]
|
||
|
||
|
||
# Module-level shared instance (global singleton) of the dispatcher.
task_dispatcher = TaskDispatcher()
|