ha'ha
This commit is contained in:
2
API文档.md
2
API文档.md
@@ -360,6 +360,7 @@ Set-Cookie: auth_token=a1b2c3d4e5f6...; HttpOnly; Max-Age=31536000; SameSite=Lax
|
||||
| 201 | 成功,任务已创建并派发 |
|
||||
| 400 | 未指定 id、boss_id、worker_id 或 account_name |
|
||||
| 401 | 未登录或 token 失效 |
|
||||
| 409 | 同一账号已有执行中的任务 |
|
||||
| 404 | 未找到拥有该浏览器环境的在线 Worker |
|
||||
| 503 | Worker 不在线 / WebSocket 连接不存在 / 派发失败 |
|
||||
|
||||
@@ -419,6 +420,7 @@ Set-Cookie: auth_token=a1b2c3d4e5f6...; HttpOnly; Max-Age=31536000; SameSite=Lax
|
||||
|--------|------|
|
||||
| 400 | 未指定 id、boss_id、worker_id 或 account_name |
|
||||
| 401 | 未登录或 token 失效 |
|
||||
| 409 | 同一账号已有执行中的任务 |
|
||||
| 404 | 未找到拥有该浏览器环境的在线 Worker |
|
||||
| 503 | Worker 不在线 / WebSocket 连接不存在 / 派发失败 |
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from asgiref.sync import async_to_sync
|
||||
from rest_framework import status
|
||||
from rest_framework.decorators import api_view
|
||||
|
||||
from common.protocol import TaskType
|
||||
from common.protocol import TaskType, TaskStatus
|
||||
from server.core.response import api_success, api_error
|
||||
from server.models import BossAccount, TaskCreate
|
||||
from server.serializers import BossAccountSerializer, AccountBindSerializer
|
||||
@@ -72,7 +72,11 @@ def account_list(request):
|
||||
send_fn = worker_manager.get_send_fn(wid)
|
||||
if send_fn:
|
||||
req = TaskCreate(task_type=TaskType.CHECK_LOGIN, worker_id=wid, account_name=bname, params={})
|
||||
task = task_dispatcher.create_task(req)
|
||||
try:
|
||||
task = task_dispatcher.create_task(req)
|
||||
except ValueError as e:
|
||||
logger.info("绑定账号后自动触发 check_login 跳过: %s@%s, reason=%s", bname, wid, e)
|
||||
return api_success(_enrich(account), http_status=status.HTTP_201_CREATED)
|
||||
if async_to_sync(task_dispatcher.dispatch)(task, send_fn):
|
||||
worker_manager.set_current_task(wid, task.task_id)
|
||||
account.current_task_id = task.task_id
|
||||
@@ -112,7 +116,12 @@ def fill_boss_ids(request):
|
||||
account_name=bname,
|
||||
params={},
|
||||
)
|
||||
task = task_dispatcher.create_task(req)
|
||||
try:
|
||||
task = task_dispatcher.create_task(req)
|
||||
except ValueError as e:
|
||||
skipped += 1
|
||||
errors.append(f"{bname}@{wid}: {e}")
|
||||
continue
|
||||
success = async_to_sync(task_dispatcher.dispatch)(task, send_fn)
|
||||
if success:
|
||||
worker_manager.set_current_task(wid, task.task_id)
|
||||
|
||||
@@ -265,7 +265,10 @@ def task_list(request):
|
||||
return api_error(http_status.HTTP_503_SERVICE_UNAVAILABLE, f"Worker {target_worker_id} 不在线")
|
||||
|
||||
req.worker_id = target_worker_id
|
||||
task = task_dispatcher.create_task(req)
|
||||
try:
|
||||
task = task_dispatcher.create_task(req)
|
||||
except ValueError as e:
|
||||
return api_error(http_status.HTTP_409_CONFLICT, str(e))
|
||||
|
||||
send_fn = worker_manager.get_send_fn(target_worker_id)
|
||||
if not send_fn:
|
||||
|
||||
@@ -7,6 +7,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
import threading
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from common.protocol import MsgType, TaskStatus, make_msg
|
||||
@@ -17,24 +18,56 @@ logger = logging.getLogger("server.task_dispatcher")
|
||||
|
||||
class TaskDispatcher:
|
||||
"""管理任务生命周期并派发给 Worker。"""
|
||||
_ACTIVE_STATUSES = {TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING}
|
||||
|
||||
def __init__(self) -> None:
|
||||
# task_id → TaskInfo
|
||||
self._tasks: Dict[str, TaskInfo] = {}
|
||||
self._lock = threading.RLock()
|
||||
|
||||
@staticmethod
|
||||
def _normalize_account_name(account_name: Optional[str]) -> str:
|
||||
return (account_name or "").strip()
|
||||
|
||||
def _find_active_task_unlocked(self, worker_id: str, account_name: str) -> Optional[TaskInfo]:
|
||||
normalized = self._normalize_account_name(account_name)
|
||||
if not worker_id or not normalized:
|
||||
return None
|
||||
for task in self._tasks.values():
|
||||
if task.worker_id != worker_id:
|
||||
continue
|
||||
if self._normalize_account_name(task.account_name) != normalized:
|
||||
continue
|
||||
if task.status in self._ACTIVE_STATUSES:
|
||||
return task
|
||||
return None
|
||||
|
||||
# ─── 创建任务 ───
|
||||
|
||||
def create_task(self, req: TaskCreate) -> TaskInfo:
|
||||
task = TaskInfo(
|
||||
task_type=req.task_type,
|
||||
worker_id=req.worker_id,
|
||||
account_name=req.account_name,
|
||||
params=req.params,
|
||||
)
|
||||
self._tasks[task.task_id] = task
|
||||
logger.info("任务创建: %s type=%s worker=%s account=%s",
|
||||
task.task_id, task.task_type, task.worker_id, task.account_name)
|
||||
return task
|
||||
with self._lock:
|
||||
existing = self._find_active_task_unlocked(req.worker_id or "", req.account_name or "")
|
||||
if existing:
|
||||
raise ValueError(
|
||||
"账号 '%s' 已有执行中的任务(task_id=%s, status=%s)"
|
||||
% (
|
||||
req.account_name or "",
|
||||
existing.task_id,
|
||||
existing.status.value if hasattr(existing.status, "value") else str(existing.status),
|
||||
)
|
||||
)
|
||||
task = TaskInfo(
|
||||
task_type=req.task_type,
|
||||
worker_id=req.worker_id,
|
||||
account_name=req.account_name,
|
||||
params=req.params,
|
||||
)
|
||||
self._tasks[task.task_id] = task
|
||||
logger.info(
|
||||
"任务创建: %s type=%s worker=%s account=%s",
|
||||
task.task_id, task.task_type, task.worker_id, task.account_name,
|
||||
)
|
||||
return task
|
||||
|
||||
# ─── 派发 ───
|
||||
|
||||
@@ -53,49 +86,59 @@ class TaskDispatcher:
|
||||
)
|
||||
try:
|
||||
await ws_send(msg)
|
||||
task.status = TaskStatus.DISPATCHED
|
||||
task.updated_at = time.time()
|
||||
with self._lock:
|
||||
task.status = TaskStatus.DISPATCHED
|
||||
task.updated_at = time.time()
|
||||
logger.info("任务 %s 已派发", task.task_id)
|
||||
return True
|
||||
except Exception as e:
|
||||
task.status = TaskStatus.FAILED
|
||||
task.error = f"派发失败: {e}"
|
||||
task.updated_at = time.time()
|
||||
with self._lock:
|
||||
task.status = TaskStatus.FAILED
|
||||
task.error = f"派发失败: {e}"
|
||||
task.updated_at = time.time()
|
||||
logger.error("任务 %s 派发失败: %s", task.task_id, e)
|
||||
return False
|
||||
|
||||
# ─── 更新状态 ───
|
||||
|
||||
def update_progress(self, task_id: str, progress: str) -> None:
|
||||
task = self._tasks.get(task_id)
|
||||
if task:
|
||||
task.status = TaskStatus.RUNNING
|
||||
task.progress = progress
|
||||
task.updated_at = time.time()
|
||||
with self._lock:
|
||||
task = self._tasks.get(task_id)
|
||||
if task:
|
||||
task.status = TaskStatus.RUNNING
|
||||
task.progress = progress
|
||||
task.updated_at = time.time()
|
||||
|
||||
def complete_task(self, task_id: str, result=None, error: str = None) -> None:
|
||||
task = self._tasks.get(task_id)
|
||||
if not task:
|
||||
return
|
||||
if error:
|
||||
task.status = TaskStatus.FAILED
|
||||
task.error = error
|
||||
else:
|
||||
task.status = TaskStatus.SUCCESS
|
||||
task.result = result
|
||||
task.updated_at = time.time()
|
||||
logger.info("任务 %s 完成: status=%s", task_id, task.status)
|
||||
with self._lock:
|
||||
task = self._tasks.get(task_id)
|
||||
if not task:
|
||||
return
|
||||
if error:
|
||||
task.status = TaskStatus.FAILED
|
||||
task.error = error
|
||||
else:
|
||||
task.status = TaskStatus.SUCCESS
|
||||
task.result = result
|
||||
task.updated_at = time.time()
|
||||
logger.info("任务 %s 完成: status=%s", task_id, task.status)
|
||||
|
||||
def cancel_task(self, task_id: str) -> None:
|
||||
task = self._tasks.get(task_id)
|
||||
if task and task.status in (TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING):
|
||||
task.status = TaskStatus.CANCELLED
|
||||
task.updated_at = time.time()
|
||||
with self._lock:
|
||||
task = self._tasks.get(task_id)
|
||||
if task and task.status in (TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING):
|
||||
task.status = TaskStatus.CANCELLED
|
||||
task.updated_at = time.time()
|
||||
|
||||
# ─── 查询 ───
|
||||
|
||||
def get_task(self, task_id: str) -> Optional[TaskInfo]:
|
||||
return self._tasks.get(task_id)
|
||||
with self._lock:
|
||||
return self._tasks.get(task_id)
|
||||
|
||||
def get_active_task_by_account(self, worker_id: str, account_name: str) -> Optional[TaskInfo]:
|
||||
with self._lock:
|
||||
return self._find_active_task_unlocked(worker_id, account_name)
|
||||
|
||||
def list_tasks(
|
||||
self,
|
||||
@@ -103,14 +146,15 @@ class TaskDispatcher:
|
||||
status: Optional[TaskStatus] = None,
|
||||
limit: int = 50,
|
||||
) -> List[TaskInfo]:
|
||||
result = list(self._tasks.values())
|
||||
if worker_id:
|
||||
result = [t for t in result if t.worker_id == worker_id]
|
||||
if status:
|
||||
result = [t for t in result if t.status == status]
|
||||
# 按创建时间倒序
|
||||
result.sort(key=lambda t: t.created_at, reverse=True)
|
||||
return result[:limit]
|
||||
with self._lock:
|
||||
result = list(self._tasks.values())
|
||||
if worker_id:
|
||||
result = [t for t in result if t.worker_id == worker_id]
|
||||
if status:
|
||||
result = [t for t in result if t.status == status]
|
||||
# 按创建时间倒序
|
||||
result.sort(key=lambda t: t.created_at, reverse=True)
|
||||
return result[:limit]
|
||||
|
||||
|
||||
# 全局单例
|
||||
|
||||
Reference in New Issue
Block a user