This commit is contained in:
27942
2026-03-03 02:13:33 +08:00
parent 5526f2d048
commit 3788197358
4 changed files with 270 additions and 151 deletions

View File

@@ -14,7 +14,7 @@ from rest_framework.decorators import api_view
from common.protocol import TaskStatus, TaskType
from server.core.response import api_success, api_error
from server.models import BossAccount, TaskCreate, TaskLog
from server.models import BossAccount, TaskCreate, TaskLog, Task
from server.serializers import TaskCreateSerializer
from server.core.worker_manager import worker_manager
from server.core.task_dispatcher import task_dispatcher
@@ -29,19 +29,27 @@ def _format_timestamp(ts: float) -> str:
def _task_to_dict(t) -> dict:
""" TaskInfo 转为可序列化字典。"""
"""任务实例Task ORM转为可序列化字典。"""
return {
"task_id": t.task_id,
"task_type": t.task_type.value if hasattr(t.task_type, "value") else str(t.task_type),
"status": t.status.value if hasattr(t.status, "value") else str(t.status),
"task_type": str(t.task_type),
"status": str(t.status),
"worker_id": t.worker_id,
"account_name": t.account_name,
"params": t.params,
"progress": t.progress,
"account_name": getattr(t, "account_name", None),
"params": t.params or {},
"progress": getattr(t, "progress", None),
"result": t.result,
"error": t.error,
"created_at": _format_timestamp(t.created_at),
"updated_at": _format_timestamp(t.updated_at),
"created_at": (
t.created_at.strftime("%Y-%m-%dT%H:%M:%S")
if hasattr(t.created_at, "strftime")
else _format_timestamp(t.created_at)
),
"updated_at": (
t.updated_at.strftime("%Y-%m-%dT%H:%M:%S")
if hasattr(t.updated_at, "strftime")
else _format_timestamp(t.updated_at)
),
}
@@ -135,38 +143,20 @@ def _is_task_log_for_account(task_log: TaskLog, account: BossAccount) -> bool:
def _list_tasks_by_account(account: BossAccount, task_status: Optional[TaskStatus], limit: Optional[int] = 50) -> list:
"""
聚合某账号的任务列表:
1) 内存任务(实时)
2) TaskLog 历史任务(重启后可查)
按账号维度查询任务列表,完全基于 Task 表。
- 不再依赖内存中的 TaskInfo
- 你可以直接修改数据库中 Task.status/Task.result 等字段来影响这里的返回。
"""
items_by_task_id = {}
# limit=None 用于分页场景下先取全量,再切片;此处给内存任务查询一个足够大的上限
memory_limit = 10000 if limit is None else max(limit * 3, 100)
memory_tasks = task_dispatcher.list_tasks(worker_id=account.worker_id, status=task_status, limit=memory_limit)
for t in memory_tasks:
if t.account_name != account.browser_name:
continue
items_by_task_id[t.task_id] = _task_to_dict(t)
db_qs = TaskLog.objects.filter(worker_id=account.worker_id).order_by("-created_at")
qs = Task.objects.filter(
worker_id=account.worker_id,
account_name=account.browser_name,
).order_by("-created_at")
if task_status:
db_qs = db_qs.filter(status=task_status.value)
qs = qs.filter(status=task_status.value if hasattr(task_status, "value") else str(task_status))
# 多取一些做过滤,避免因为条件匹配损耗导致结果太少
db_logs = db_qs if limit is None else db_qs[: max(limit * 8, 200)]
for task_log in db_logs:
if not _is_task_log_for_account(task_log, account):
continue
if task_log.task_id in items_by_task_id:
continue
items_by_task_id[task_log.task_id] = _task_log_to_dict(task_log, account_name=account.browser_name)
merged = list(items_by_task_id.values())
merged.sort(key=lambda item: item.get("created_at") or "", reverse=True)
if limit is None:
return merged
return merged[:limit]
if limit is not None:
qs = qs[:limit]
return [_task_to_dict(t) for t in qs]
@api_view(["GET", "POST"])
@@ -200,8 +190,13 @@ def task_list(request):
)
return api_success(_list_tasks_by_account(account, task_status=task_status, limit=limit))
tasks = task_dispatcher.list_tasks(worker_id=wid, status=task_status, limit=limit)
return api_success([_task_to_dict(t) for t in tasks])
qs = Task.objects.all().order_by("-created_at")
if wid:
qs = qs.filter(worker_id=wid)
if task_status:
qs = qs.filter(status=task_status.value if hasattr(task_status, "value") else str(task_status))
qs = qs[:limit]
return api_success([_task_to_dict(t) for t in qs])
# POST: 提交新任务
data = request.data.copy()

View File

@@ -7,40 +7,46 @@ from __future__ import annotations
import logging
import time
import threading
from typing import Dict, List, Optional
from typing import List, Optional
from common.protocol import MsgType, TaskStatus, make_msg
from server.models import TaskCreate, TaskInfo
from server.models import TaskCreate, Task
logger = logging.getLogger("server.task_dispatcher")
class TaskDispatcher:
"""管理任务生命周期并派发给 Worker。"""
_ACTIVE_STATUSES = {TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING}
"""
管理任务生命周期并派发给 Worker以数据库为唯一真相
def __init__(self) -> None:
# task_id → TaskInfo
self._tasks: Dict[str, TaskInfo] = {}
self._lock = threading.RLock()
注意:
- 不再在内存中长期保存任务状态;
- 所有查询 / 状态变更均落在 ORM 模型 Task 上;
- 仅在单次请求/派发过程中临时构造 Python 对象。
"""
_ACTIVE_STATUSES = {TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING}
@staticmethod
def _normalize_account_name(account_name: Optional[str]) -> str:
return (account_name or "").strip()
def _find_active_task_unlocked(self, worker_id: str, account_name: str) -> Optional[TaskInfo]:
@staticmethod
def _find_active_task(worker_id: str, account_name: str) -> Optional[Task]:
normalized = self._normalize_account_name(account_name)
if not worker_id or not normalized:
return None
for task in self._tasks.values():
if task.worker_id != worker_id:
continue
if self._normalize_account_name(task.account_name) != normalized:
continue
if task.status in self._ACTIVE_STATUSES:
return task
return None
from django.db.models import Q
active_values = [s.value for s in TaskDispatcher._ACTIVE_STATUSES]
return (
Task.objects.filter(
Q(worker_id=worker_id),
Q(account_name=normalized),
Q(status__in=active_values),
)
.order_by("-created_at")
.first()
)
@staticmethod
def _infer_failure_reason(result, error) -> Optional[str]:
@@ -70,34 +76,48 @@ class TaskDispatcher:
# ─── 创建任务 ───
def create_task(self, req: TaskCreate) -> TaskInfo:
with self._lock:
existing = self._find_active_task_unlocked(req.worker_id or "", req.account_name or "")
if existing:
raise ValueError(
"账号 '%s' 已有执行中的任务task_id=%s, status=%s"
% (
req.account_name or "",
existing.task_id,
existing.status.value if hasattr(existing.status, "value") else str(existing.status),
)
def create_task(self, req: TaskCreate) -> Task:
"""
创建任务记录:
- 校验同一 worker + account 是否已有执行中的任务(查表);
- 在 Task 表中插入一条新记录;
- 返回 ORM 实例,供后续派发使用。
"""
from django.utils import timezone as tz
import uuid
existing = self._find_active_task(req.worker_id or "", req.account_name or "")
if existing:
raise ValueError(
"账号 '%s' 已有执行中的任务task_id=%s, status=%s"
% (
req.account_name or "",
existing.task_id,
existing.status,
)
task = TaskInfo(
task_type=req.task_type,
worker_id=req.worker_id,
account_name=req.account_name,
params=req.params,
)
self._tasks[task.task_id] = task
logger.info(
"任务创建: %s type=%s worker=%s account=%s",
task.task_id, task.task_type, task.worker_id, task.account_name,
)
return task
task_id = uuid.uuid4().hex[:12]
now = tz.now()
task = Task.objects.create(
task_id=task_id,
task_type=req.task_type.value if hasattr(req.task_type, "value") else str(req.task_type),
worker_id=req.worker_id or "",
account_name=req.account_name or "",
status=TaskStatus.PENDING.value,
params=req.params,
created_at=now,
updated_at=now,
)
logger.info(
"任务创建: %s type=%s worker=%s account=%s",
task.task_id, task.task_type, task.worker_id, task.account_name,
)
return task
# ─── 派发 ───
async def dispatch(self, task: TaskInfo, ws_send) -> bool:
async def dispatch(self, task: Task, ws_send) -> bool:
"""
将任务通过 WebSocket 发给目标 Worker。
ws_send: 异步可调用,接受一个 dict 参数。
@@ -106,84 +126,104 @@ class TaskDispatcher:
msg = make_msg(
MsgType.TASK_ASSIGN,
task_id=task.task_id,
task_type=task.task_type.value,
task_type=task.task_type,
account_name=task.account_name or "",
params=task.params,
params=task.params or {},
)
try:
await ws_send(msg)
with self._lock:
task.status = TaskStatus.DISPATCHED
task.updated_at = time.time()
from django.utils import timezone as tz
task.status = TaskStatus.DISPATCHED.value
task.updated_at = tz.now()
task.save(update_fields=["status", "updated_at"])
logger.info("任务 %s 已派发", task.task_id)
return True
except Exception as e:
with self._lock:
task.status = TaskStatus.FAILED
task.error = f"派发失败: {e}"
task.updated_at = time.time()
from django.utils import timezone as tz
task.status = TaskStatus.FAILED.value
task.error = f"派发失败: {e}"
task.updated_at = tz.now()
task.save(update_fields=["status", "error", "updated_at"])
logger.error("任务 %s 派发失败: %s", task.task_id, e)
return False
# ─── 更新状态 ───
def update_progress(self, task_id: str, progress: str) -> None:
with self._lock:
task = self._tasks.get(task_id)
if task:
task.status = TaskStatus.RUNNING
task.progress = progress
task.updated_at = time.time()
"""
更新任务进度和状态RUNNING
"""
from django.utils import timezone as tz
updated = Task.objects.filter(task_id=task_id).update(
status=TaskStatus.RUNNING.value,
progress=progress,
updated_at=tz.now(),
)
if updated:
logger.info("任务 %s 进度更新: %s", task_id, progress)
def complete_task(self, task_id: str, result=None, error: str = None) -> None:
with self._lock:
task = self._tasks.get(task_id)
if not task:
return
failure_reason = self._infer_failure_reason(result, error)
if failure_reason:
task.status = TaskStatus.FAILED
task.error = failure_reason
task.result = result
else:
task.status = TaskStatus.SUCCESS
task.result = result
task.error = None
task.updated_at = time.time()
logger.info("任务 %s 完成: status=%s", task_id, task.status)
"""
完成任务:根据 result/error 推断最终状态并写入数据库。
"""
from django.utils import timezone as tz
task = Task.objects.filter(task_id=task_id).first()
if not task:
return
failure_reason = self._infer_failure_reason(result, error)
if failure_reason:
task.status = TaskStatus.FAILED.value
task.error = failure_reason
task.result = result
else:
task.status = TaskStatus.SUCCESS.value
task.result = result
task.error = None
task.updated_at = tz.now()
task.save(update_fields=["status", "result", "error", "updated_at"])
logger.info("任务 %s 完成: status=%s", task_id, task.status)
def cancel_task(self, task_id: str) -> None:
with self._lock:
task = self._tasks.get(task_id)
if task and task.status in (TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING):
task.status = TaskStatus.CANCELLED
task.updated_at = time.time()
"""
取消任务:仅对活动状态的任务生效。
"""
from django.utils import timezone as tz
active_values = [s.value for s in self._ACTIVE_STATUSES]
task = Task.objects.filter(task_id=task_id, status__in=active_values).first()
if not task:
return
task.status = TaskStatus.CANCELLED.value
task.updated_at = tz.now()
task.save(update_fields=["status", "updated_at"])
# ─── 查询 ───
def get_task(self, task_id: str) -> Optional[TaskInfo]:
with self._lock:
return self._tasks.get(task_id)
def get_task(self, task_id: str) -> Optional[Task]:
return Task.objects.filter(task_id=task_id).first()
def get_active_task_by_account(self, worker_id: str, account_name: str) -> Optional[TaskInfo]:
with self._lock:
return self._find_active_task_unlocked(worker_id, account_name)
def get_active_task_by_account(self, worker_id: str, account_name: str) -> Optional[Task]:
return self._find_active_task(worker_id, account_name)
def list_tasks(
self,
worker_id: Optional[str] = None,
status: Optional[TaskStatus] = None,
limit: int = 50,
) -> List[TaskInfo]:
with self._lock:
result = list(self._tasks.values())
if worker_id:
result = [t for t in result if t.worker_id == worker_id]
if status:
result = [t for t in result if t.status == status]
# 按创建时间倒序
result.sort(key=lambda t: t.created_at, reverse=True)
return result[:limit]
) -> List[Task]:
"""
从数据库中查询任务列表。
"""
qs = Task.objects.all().order_by("-created_at")
if worker_id:
qs = qs.filter(worker_id=worker_id)
if status:
qs = qs.filter(status=status.value if hasattr(status, "value") else str(status))
return list(qs[:limit])
# 全局单例

View File

@@ -62,6 +62,34 @@ class TaskLog(models.Model):
return f"{self.task_id} ({self.task_type})"
class Task(models.Model):
"""
任务表(数据库为唯一真相)。
- 所有任务的生命周期状态均保存在此表中;
- 内存中不再长期保存任务状态,只作为必要的临时变量。
"""
task_id = models.CharField(max_length=32, unique=True, verbose_name="任务 ID")
task_type = models.CharField(max_length=64, verbose_name="任务类型")
worker_id = models.CharField(max_length=64, default="", verbose_name="执行的 Worker")
account_name = models.CharField(max_length=128, default="", blank=True, verbose_name="账号(环境名称)")
status = models.CharField(max_length=32, default="", verbose_name="当前状态")
params = models.JSONField(null=True, blank=True, verbose_name="任务参数")
progress = models.TextField(null=True, blank=True, verbose_name="进度信息")
result = models.JSONField(null=True, blank=True, verbose_name="任务结果")
error = models.TextField(null=True, blank=True, verbose_name="错误信息")
created_at = models.DateTimeField(auto_now_add=True, verbose_name="创建时间")
updated_at = models.DateTimeField(auto_now=True, verbose_name="更新时间")
class Meta:
db_table = "task"
verbose_name = "任务"
verbose_name_plural = verbose_name
def __str__(self):
return f"{self.task_id} ({self.task_type})"
class AuthToken(models.Model):
"""登录 token 表:每个用户名仅保留当前有效 token。"""
username = models.CharField(max_length=64, unique=True, verbose_name="用户名")

View File

@@ -11,7 +11,6 @@ from channels.generic.websocket import AsyncWebsocketConsumer
from common.protocol import MsgType, TaskStatus, TaskType, make_msg
from server.core.worker_manager import worker_manager
from server.core.task_dispatcher import task_dispatcher
logger = logging.getLogger("server.ws")
@@ -102,8 +101,11 @@ class WorkerConsumer(AsyncWebsocketConsumer):
elif msg_type == MsgType.TASK_PROGRESS.value:
task_id = data.get("task_id", "")
progress = data.get("progress", "")
task_dispatcher.update_progress(task_id, progress)
logger.info("任务 %s 进度: %s", task_id, progress)
try:
await self._update_task_progress(task_id, progress)
logger.info("任务 %s 进度: %s", task_id, progress)
except Exception as e:
logger.error("更新任务进度失败 (task_id=%s): %s", task_id, e)
# 同步更新账号任务状态为 running
try:
await self._update_account_task_status(task_id, TaskStatus.RUNNING.value)
@@ -114,22 +116,16 @@ class WorkerConsumer(AsyncWebsocketConsumer):
task_id = data.get("task_id", "")
result = data.get("result")
error = data.get("error")
task_dispatcher.complete_task(task_id, result=result, error=error)
worker_manager.set_current_task(self.worker_id, None)
logger.info("任务 %s 已完成", task_id)
# ── 将结果写入数据库 ──
try:
task_info = task_dispatcher.get_task(task_id)
if task_info:
final_status = task_info.status.value if hasattr(task_info.status, "value") else str(task_info.status)
final_error = task_info.error if getattr(task_info, "error", None) else error
await self._save_task_log(task_id, task_info, result, final_error, final_status)
final_status = await self._finalize_task_and_log(task_id, result, error)
if final_status:
await self._update_account_task_status(task_id, final_status)
logger.info("任务 %s 最终状态已更新: %s", task_id, final_status)
# check_login 任务:更新账号登录状态
task_type_val = task_info.task_type.value if hasattr(task_info.task_type, "value") else str(task_info.task_type)
if task_type_val == TaskType.CHECK_LOGIN.value and result and final_status == TaskStatus.SUCCESS.value:
# check_login 任务:更新账号登录状态(仅在成功时)
if result and final_status == TaskStatus.SUCCESS.value:
await self._upsert_account_status(result)
except Exception as db_err:
logger.error("任务 %s 写入数据库失败: %s", task_id, db_err)
@@ -141,19 +137,79 @@ class WorkerConsumer(AsyncWebsocketConsumer):
@staticmethod
@sync_to_async
def _save_task_log(task_id, task_info, result, error, final_status):
from server.models import TaskLog
def _update_task_progress(task_id, progress):
"""将进度写入 Task 表。"""
from django.utils import timezone as tz
from server.models import Task
Task.objects.filter(task_id=task_id).update(
status=TaskStatus.RUNNING.value,
progress=progress,
updated_at=tz.now(),
)
@staticmethod
def _infer_failure_reason(result, error):
"""
失败原因推断逻辑,与 TaskDispatcher 内保持一致(此处复制一份,避免循环依赖)。
"""
if error is not None:
msg = str(error).strip()
return msg or "任务执行失败"
if isinstance(result, dict):
if result.get("success") is False:
msg = str(result.get("error", "")).strip()
return msg or "任务执行失败"
if result.get("ok") is False:
msg = str(result.get("error", "")).strip()
return msg or "任务执行失败"
status_text = str(result.get("status", "")).strip().lower()
if status_text in {"failed", "error", "fail"}:
msg = str(result.get("error", "")).strip()
return msg or "任务执行失败"
return None
@sync_to_async
def _finalize_task_and_log(self, task_id, result, error):
"""
完成任务并写入 Task / TaskLog。
返回最终状态字符串(如 "success" / "failed"),供后续使用。
"""
from django.utils import timezone as tz
from server.models import Task, TaskLog
task = Task.objects.filter(task_id=task_id).first()
if not task:
return None
failure_reason = self._infer_failure_reason(result, error)
if failure_reason:
final_status = TaskStatus.FAILED.value
final_error = failure_reason
else:
final_status = TaskStatus.SUCCESS.value
final_error = None
task.status = final_status
task.result = result
task.error = final_error
task.updated_at = tz.now()
task.save(update_fields=["status", "result", "error", "updated_at"])
TaskLog.objects.update_or_create(
task_id=task_id,
defaults={
"task_type": task_info.task_type.value if hasattr(task_info.task_type, "value") else str(task_info.task_type),
"worker_id": task_info.worker_id or "",
"task_type": task.task_type,
"worker_id": task.worker_id or "",
"status": final_status,
"params": task_info.params,
"params": task.params,
"result": result,
"error": error,
"error": final_error,
},
)
return final_status
@staticmethod
@sync_to_async