diff --git a/API文档.md b/API文档.md index 4bafbe4..7628007 100644 --- a/API文档.md +++ b/API文档.md @@ -360,6 +360,7 @@ Set-Cookie: auth_token=a1b2c3d4e5f6...; HttpOnly; Max-Age=31536000; SameSite=Lax | 201 | 成功,任务已创建并派发 | | 400 | 未指定 id、boss_id、worker_id 或 account_name | | 401 | 未登录或 token 失效 | +| 409 | 同一账号已有执行中的任务 | | 404 | 未找到拥有该浏览器环境的在线 Worker | | 503 | Worker 不在线 / WebSocket 连接不存在 / 派发失败 | @@ -419,6 +420,7 @@ Set-Cookie: auth_token=a1b2c3d4e5f6...; HttpOnly; Max-Age=31536000; SameSite=Lax |--------|------| | 400 | 未指定 id、boss_id、worker_id 或 account_name | | 401 | 未登录或 token 失效 | +| 409 | 同一账号已有执行中的任务 | | 404 | 未找到拥有该浏览器环境的在线 Worker | | 503 | Worker 不在线 / WebSocket 连接不存在 / 派发失败 | diff --git a/server/api/accounts.py b/server/api/accounts.py index 8d864d6..78f5089 100644 --- a/server/api/accounts.py +++ b/server/api/accounts.py @@ -16,7 +16,7 @@ from asgiref.sync import async_to_sync from rest_framework import status from rest_framework.decorators import api_view -from common.protocol import TaskType +from common.protocol import TaskType, TaskStatus from server.core.response import api_success, api_error from server.models import BossAccount, TaskCreate from server.serializers import BossAccountSerializer, AccountBindSerializer @@ -72,7 +72,11 @@ def account_list(request): send_fn = worker_manager.get_send_fn(wid) if send_fn: req = TaskCreate(task_type=TaskType.CHECK_LOGIN, worker_id=wid, account_name=bname, params={}) - task = task_dispatcher.create_task(req) + try: + task = task_dispatcher.create_task(req) + except ValueError as e: + logger.info("绑定账号后自动触发 check_login 跳过: %s@%s, reason=%s", bname, wid, e) + return api_success(_enrich(account), http_status=status.HTTP_201_CREATED) if async_to_sync(task_dispatcher.dispatch)(task, send_fn): worker_manager.set_current_task(wid, task.task_id) account.current_task_id = task.task_id @@ -112,7 +116,12 @@ def fill_boss_ids(request): account_name=bname, params={}, ) - task = task_dispatcher.create_task(req) + try: + task = task_dispatcher.create_task(req) + except ValueError as e: + skipped += 1 + errors.append(f"{bname}@{wid}: {e}") + continue success = async_to_sync(task_dispatcher.dispatch)(task, send_fn) if success: worker_manager.set_current_task(wid, task.task_id) diff --git a/server/api/tasks.py b/server/api/tasks.py index 10a6ee0..ae0cf95 100644 --- a/server/api/tasks.py +++ b/server/api/tasks.py @@ -265,7 +265,10 @@ def task_list(request): return api_error(http_status.HTTP_503_SERVICE_UNAVAILABLE, f"Worker {target_worker_id} 不在线") req.worker_id = target_worker_id - task = task_dispatcher.create_task(req) + try: + task = task_dispatcher.create_task(req) + except ValueError as e: + return api_error(http_status.HTTP_409_CONFLICT, str(e)) send_fn = worker_manager.get_send_fn(target_worker_id) if not send_fn: diff --git a/server/core/task_dispatcher.py b/server/core/task_dispatcher.py index 950649d..537919d 100644 --- a/server/core/task_dispatcher.py +++ b/server/core/task_dispatcher.py @@ -7,6 +7,7 @@ from __future__ import annotations import logging import time +import threading from typing import Dict, List, Optional from common.protocol import MsgType, TaskStatus, make_msg @@ -17,24 +18,56 @@ logger = logging.getLogger("server.task_dispatcher") class TaskDispatcher: """管理任务生命周期并派发给 Worker。""" + _ACTIVE_STATUSES = {TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING} def __init__(self) -> None: # task_id → TaskInfo self._tasks: Dict[str, TaskInfo] = {} + self._lock = threading.RLock() + + @staticmethod + def _normalize_account_name(account_name: Optional[str]) -> str: + return (account_name or "").strip() + + def _find_active_task_unlocked(self, worker_id: str, account_name: str) -> Optional[TaskInfo]: + normalized = self._normalize_account_name(account_name) + if not worker_id or not normalized: + return None + for task in self._tasks.values(): + if task.worker_id != worker_id: + continue + if self._normalize_account_name(task.account_name) != normalized: + continue + if task.status in self._ACTIVE_STATUSES: + return task + return None # ─── 创建任务 ─── def create_task(self, req: TaskCreate) -> TaskInfo: - task = TaskInfo( - task_type=req.task_type, - worker_id=req.worker_id, - account_name=req.account_name, - params=req.params, - ) - self._tasks[task.task_id] = task - logger.info("任务创建: %s type=%s worker=%s account=%s", - task.task_id, task.task_type, task.worker_id, task.account_name) - return task + with self._lock: + existing = self._find_active_task_unlocked(req.worker_id or "", req.account_name or "") + if existing: + raise ValueError( + "账号 '%s' 已有执行中的任务(task_id=%s, status=%s)" + % ( + req.account_name or "", + existing.task_id, + existing.status.value if hasattr(existing.status, "value") else str(existing.status), + ) + ) + task = TaskInfo( + task_type=req.task_type, + worker_id=req.worker_id, + account_name=req.account_name, + params=req.params, + ) + self._tasks[task.task_id] = task + logger.info( + "任务创建: %s type=%s worker=%s account=%s", + task.task_id, task.task_type, task.worker_id, task.account_name, + ) + return task # ─── 派发 ─── @@ -53,49 +86,59 @@ class TaskDispatcher: ) try: await ws_send(msg) - task.status = TaskStatus.DISPATCHED - task.updated_at = time.time() + with self._lock: + task.status = TaskStatus.DISPATCHED + task.updated_at = time.time() logger.info("任务 %s 已派发", task.task_id) return True except Exception as e: - task.status = TaskStatus.FAILED - task.error = f"派发失败: {e}" - task.updated_at = time.time() + with self._lock: + task.status = TaskStatus.FAILED + task.error = f"派发失败: {e}" + task.updated_at = time.time() logger.error("任务 %s 派发失败: %s", task.task_id, e) return False # ─── 更新状态 ─── def update_progress(self, task_id: str, progress: str) -> None: - task = self._tasks.get(task_id) - if task: - task.status = TaskStatus.RUNNING - task.progress = progress - task.updated_at = time.time() + with self._lock: + task = self._tasks.get(task_id) + if task: + task.status = TaskStatus.RUNNING + task.progress = progress + task.updated_at = time.time() def complete_task(self, task_id: str, result=None, error: str = None) -> None: - task = self._tasks.get(task_id) - if not task: - return - if error: - task.status = TaskStatus.FAILED - task.error = error - else: - task.status = TaskStatus.SUCCESS - task.result = result - task.updated_at = time.time() - logger.info("任务 %s 完成: status=%s", task_id, task.status) + with self._lock: + task = self._tasks.get(task_id) + if not task: + return + if error: + task.status = TaskStatus.FAILED + task.error = error + else: + task.status = TaskStatus.SUCCESS + task.result = result + task.updated_at = time.time() + logger.info("任务 %s 完成: status=%s", task_id, task.status) def cancel_task(self, task_id: str) -> None: - task = self._tasks.get(task_id) - if task and task.status in (TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING): - task.status = TaskStatus.CANCELLED - task.updated_at = time.time() + with self._lock: + task = self._tasks.get(task_id) + if task and task.status in (TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING): + task.status = TaskStatus.CANCELLED + task.updated_at = time.time() # ─── 查询 ─── def get_task(self, task_id: str) -> Optional[TaskInfo]: - return self._tasks.get(task_id) + with self._lock: + return self._tasks.get(task_id) + + def get_active_task_by_account(self, worker_id: str, account_name: str) -> Optional[TaskInfo]: + with self._lock: + return self._find_active_task_unlocked(worker_id, account_name) def list_tasks( self, @@ -103,14 +146,15 @@ class TaskDispatcher: status: Optional[TaskStatus] = None, limit: int = 50, ) -> List[TaskInfo]: - result = list(self._tasks.values()) - if worker_id: - result = [t for t in result if t.worker_id == worker_id] - if status: - result = [t for t in result if t.status == status] - # 按创建时间倒序 - result.sort(key=lambda t: t.created_at, reverse=True) - return result[:limit] + with self._lock: + result = list(self._tasks.values()) + if worker_id: + result = [t for t in result if t.worker_id == worker_id] + if status: + result = [t for t in result if t.status == status] + # 按创建时间倒序 + result.sort(key=lambda t: t.created_at, reverse=True) + return result[:limit] # 全局单例