This commit is contained in:
ddrwode
2026-03-02 00:40:21 +08:00
parent 0a7468799f
commit 57ee9a6cda
4 changed files with 106 additions and 48 deletions

View File

@@ -360,6 +360,7 @@ Set-Cookie: auth_token=a1b2c3d4e5f6...; HttpOnly; Max-Age=31536000; SameSite=Lax
| 201 | 成功,任务已创建并派发 |
| 400 | 未指定 id、boss_id、worker_id 或 account_name |
| 401 | 未登录或 token 失效 |
| 409 | 同一账号已有执行中的任务 |
| 404 | 未找到拥有该浏览器环境的在线 Worker |
| 503 | Worker 不在线 / WebSocket 连接不存在 / 派发失败 |
@@ -419,6 +420,7 @@ Set-Cookie: auth_token=a1b2c3d4e5f6...; HttpOnly; Max-Age=31536000; SameSite=Lax
|--------|------|
| 400 | 未指定 id、boss_id、worker_id 或 account_name |
| 401 | 未登录或 token 失效 |
| 409 | 同一账号已有执行中的任务 |
| 404 | 未找到拥有该浏览器环境的在线 Worker |
| 503 | Worker 不在线 / WebSocket 连接不存在 / 派发失败 |

View File

@@ -16,7 +16,7 @@ from asgiref.sync import async_to_sync
from rest_framework import status
from rest_framework.decorators import api_view
from common.protocol import TaskType
from common.protocol import TaskType, TaskStatus
from server.core.response import api_success, api_error
from server.models import BossAccount, TaskCreate
from server.serializers import BossAccountSerializer, AccountBindSerializer
@@ -72,7 +72,11 @@ def account_list(request):
send_fn = worker_manager.get_send_fn(wid)
if send_fn:
req = TaskCreate(task_type=TaskType.CHECK_LOGIN, worker_id=wid, account_name=bname, params={})
task = task_dispatcher.create_task(req)
try:
task = task_dispatcher.create_task(req)
except ValueError as e:
logger.info("绑定账号后自动触发 check_login 跳过: %s@%s, reason=%s", bname, wid, e)
return api_success(_enrich(account), http_status=status.HTTP_201_CREATED)
if async_to_sync(task_dispatcher.dispatch)(task, send_fn):
worker_manager.set_current_task(wid, task.task_id)
account.current_task_id = task.task_id
@@ -112,7 +116,12 @@ def fill_boss_ids(request):
account_name=bname,
params={},
)
task = task_dispatcher.create_task(req)
try:
task = task_dispatcher.create_task(req)
except ValueError as e:
skipped += 1
errors.append(f"{bname}@{wid}: {e}")
continue
success = async_to_sync(task_dispatcher.dispatch)(task, send_fn)
if success:
worker_manager.set_current_task(wid, task.task_id)

View File

@@ -265,7 +265,10 @@ def task_list(request):
return api_error(http_status.HTTP_503_SERVICE_UNAVAILABLE, f"Worker {target_worker_id} 不在线")
req.worker_id = target_worker_id
task = task_dispatcher.create_task(req)
try:
task = task_dispatcher.create_task(req)
except ValueError as e:
return api_error(http_status.HTTP_409_CONFLICT, str(e))
send_fn = worker_manager.get_send_fn(target_worker_id)
if not send_fn:

View File

@@ -7,6 +7,7 @@ from __future__ import annotations
import logging
import time
import threading
from typing import Dict, List, Optional
from common.protocol import MsgType, TaskStatus, make_msg
@@ -17,24 +18,56 @@ logger = logging.getLogger("server.task_dispatcher")
class TaskDispatcher:
"""管理任务生命周期并派发给 Worker。"""
_ACTIVE_STATUSES = {TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING}
def __init__(self) -> None:
# task_id → TaskInfo
self._tasks: Dict[str, TaskInfo] = {}
self._lock = threading.RLock()
@staticmethod
def _normalize_account_name(account_name: Optional[str]) -> str:
return (account_name or "").strip()
def _find_active_task_unlocked(self, worker_id: str, account_name: str) -> Optional[TaskInfo]:
normalized = self._normalize_account_name(account_name)
if not worker_id or not normalized:
return None
for task in self._tasks.values():
if task.worker_id != worker_id:
continue
if self._normalize_account_name(task.account_name) != normalized:
continue
if task.status in self._ACTIVE_STATUSES:
return task
return None
# ─── 创建任务 ───
def create_task(self, req: TaskCreate) -> TaskInfo:
task = TaskInfo(
task_type=req.task_type,
worker_id=req.worker_id,
account_name=req.account_name,
params=req.params,
)
self._tasks[task.task_id] = task
logger.info("任务创建: %s type=%s worker=%s account=%s",
task.task_id, task.task_type, task.worker_id, task.account_name)
return task
with self._lock:
existing = self._find_active_task_unlocked(req.worker_id or "", req.account_name or "")
if existing:
raise ValueError(
"账号 '%s' 已有执行中的任务task_id=%s, status=%s"
% (
req.account_name or "",
existing.task_id,
existing.status.value if hasattr(existing.status, "value") else str(existing.status),
)
)
task = TaskInfo(
task_type=req.task_type,
worker_id=req.worker_id,
account_name=req.account_name,
params=req.params,
)
self._tasks[task.task_id] = task
logger.info(
"任务创建: %s type=%s worker=%s account=%s",
task.task_id, task.task_type, task.worker_id, task.account_name,
)
return task
# ─── 派发 ───
@@ -53,49 +86,59 @@ class TaskDispatcher:
)
try:
await ws_send(msg)
task.status = TaskStatus.DISPATCHED
task.updated_at = time.time()
with self._lock:
task.status = TaskStatus.DISPATCHED
task.updated_at = time.time()
logger.info("任务 %s 已派发", task.task_id)
return True
except Exception as e:
task.status = TaskStatus.FAILED
task.error = f"派发失败: {e}"
task.updated_at = time.time()
with self._lock:
task.status = TaskStatus.FAILED
task.error = f"派发失败: {e}"
task.updated_at = time.time()
logger.error("任务 %s 派发失败: %s", task.task_id, e)
return False
# ─── 更新状态 ───
def update_progress(self, task_id: str, progress: str) -> None:
task = self._tasks.get(task_id)
if task:
task.status = TaskStatus.RUNNING
task.progress = progress
task.updated_at = time.time()
with self._lock:
task = self._tasks.get(task_id)
if task:
task.status = TaskStatus.RUNNING
task.progress = progress
task.updated_at = time.time()
def complete_task(self, task_id: str, result=None, error: str = None) -> None:
task = self._tasks.get(task_id)
if not task:
return
if error:
task.status = TaskStatus.FAILED
task.error = error
else:
task.status = TaskStatus.SUCCESS
task.result = result
task.updated_at = time.time()
logger.info("任务 %s 完成: status=%s", task_id, task.status)
with self._lock:
task = self._tasks.get(task_id)
if not task:
return
if error:
task.status = TaskStatus.FAILED
task.error = error
else:
task.status = TaskStatus.SUCCESS
task.result = result
task.updated_at = time.time()
logger.info("任务 %s 完成: status=%s", task_id, task.status)
def cancel_task(self, task_id: str) -> None:
task = self._tasks.get(task_id)
if task and task.status in (TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING):
task.status = TaskStatus.CANCELLED
task.updated_at = time.time()
with self._lock:
task = self._tasks.get(task_id)
if task and task.status in (TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING):
task.status = TaskStatus.CANCELLED
task.updated_at = time.time()
# ─── 查询 ───
def get_task(self, task_id: str) -> Optional[TaskInfo]:
return self._tasks.get(task_id)
with self._lock:
return self._tasks.get(task_id)
def get_active_task_by_account(self, worker_id: str, account_name: str) -> Optional[TaskInfo]:
with self._lock:
return self._find_active_task_unlocked(worker_id, account_name)
def list_tasks(
self,
@@ -103,14 +146,15 @@ class TaskDispatcher:
status: Optional[TaskStatus] = None,
limit: int = 50,
) -> List[TaskInfo]:
result = list(self._tasks.values())
if worker_id:
result = [t for t in result if t.worker_id == worker_id]
if status:
result = [t for t in result if t.status == status]
# 按创建时间倒序
result.sort(key=lambda t: t.created_at, reverse=True)
return result[:limit]
with self._lock:
result = list(self._tasks.values())
if worker_id:
result = [t for t in result if t.worker_id == worker_id]
if status:
result = [t for t in result if t.status == status]
# 按创建时间倒序
result.sort(key=lambda t: t.created_at, reverse=True)
return result[:limit]
# 全局单例