diff --git a/server/api/tasks.py b/server/api/tasks.py
index ae0cf95..485080d 100644
--- a/server/api/tasks.py
+++ b/server/api/tasks.py
@@ -14,7 +14,7 @@ from rest_framework.decorators import api_view
 
 from common.protocol import TaskStatus, TaskType
 from server.core.response import api_success, api_error
-from server.models import BossAccount, TaskCreate, TaskLog
+from server.models import BossAccount, TaskCreate, TaskLog, Task
 from server.serializers import TaskCreateSerializer
 from server.core.worker_manager import worker_manager
 from server.core.task_dispatcher import task_dispatcher
@@ -29,19 +29,27 @@ def _format_timestamp(ts: float) -> str:
 
 
 def _task_to_dict(t) -> dict:
-    """将 TaskInfo 转为可序列化字典。"""
+    """Convert a Task ORM instance into a JSON-serializable dict."""
     return {
         "task_id": t.task_id,
-        "task_type": t.task_type.value if hasattr(t.task_type, "value") else str(t.task_type),
-        "status": t.status.value if hasattr(t.status, "value") else str(t.status),
+        "task_type": str(t.task_type),
+        "status": str(t.status),
         "worker_id": t.worker_id,
-        "account_name": t.account_name,
-        "params": t.params,
-        "progress": t.progress,
+        "account_name": getattr(t, "account_name", None),
+        "params": t.params or {},
+        "progress": getattr(t, "progress", None),
         "result": t.result,
         "error": t.error,
-        "created_at": _format_timestamp(t.created_at),
-        "updated_at": _format_timestamp(t.updated_at),
+        "created_at": (
+            t.created_at.strftime("%Y-%m-%dT%H:%M:%S")
+            if hasattr(t.created_at, "strftime")
+            else _format_timestamp(t.created_at)
+        ),
+        "updated_at": (
+            t.updated_at.strftime("%Y-%m-%dT%H:%M:%S")
+            if hasattr(t.updated_at, "strftime")
+            else _format_timestamp(t.updated_at)
+        ),
     }
 
 
@@ -135,38 +143,20 @@ def _is_task_log_for_account(task_log: TaskLog, account: BossAccount) -> bool:
 
 def _list_tasks_by_account(account: BossAccount, task_status: Optional[TaskStatus], limit: Optional[int] = 50) -> list:
     """
-    聚合某账号的任务列表:
-    1) 内存任务(实时)
-    2) TaskLog 历史任务(重启后可查)
+    List the tasks of one account, backed entirely by the Task table.
+    - No longer depends on in-memory TaskInfo objects;
+    - editing Task.status / Task.result in the database changes what is returned here.
     """
-    items_by_task_id = {}
-
-    # limit=None 用于分页场景下先取全量,再切片;此处给内存任务查询一个足够大的上限
-    memory_limit = 10000 if limit is None else max(limit * 3, 100)
-    memory_tasks = task_dispatcher.list_tasks(worker_id=account.worker_id, status=task_status, limit=memory_limit)
-    for t in memory_tasks:
-        if t.account_name != account.browser_name:
-            continue
-        items_by_task_id[t.task_id] = _task_to_dict(t)
-
-    db_qs = TaskLog.objects.filter(worker_id=account.worker_id).order_by("-created_at")
+    qs = Task.objects.filter(
+        worker_id=account.worker_id,
+        account_name=account.browser_name,
+    ).order_by("-created_at")
     if task_status:
-        db_qs = db_qs.filter(status=task_status.value)
+        qs = qs.filter(status=task_status.value if hasattr(task_status, "value") else str(task_status))
 
-    # 多取一些做过滤,避免因为条件匹配损耗导致结果太少
-    db_logs = db_qs if limit is None else db_qs[: max(limit * 8, 200)]
-    for task_log in db_logs:
-        if not _is_task_log_for_account(task_log, account):
-            continue
-        if task_log.task_id in items_by_task_id:
-            continue
-        items_by_task_id[task_log.task_id] = _task_log_to_dict(task_log, account_name=account.browser_name)
-
-    merged = list(items_by_task_id.values())
-    merged.sort(key=lambda item: item.get("created_at") or "", reverse=True)
-    if limit is None:
-        return merged
-    return merged[:limit]
+    if limit is not None:
+        qs = qs[:limit]
+    return [_task_to_dict(t) for t in qs]
 
 
 @api_view(["GET", "POST"])
@@ -200,8 +190,13 @@ def task_list(request):
                 )
             return api_success(_list_tasks_by_account(account, task_status=task_status, limit=limit))
 
-        tasks = task_dispatcher.list_tasks(worker_id=wid, status=task_status, limit=limit)
-        return api_success([_task_to_dict(t) for t in tasks])
+        qs = Task.objects.all().order_by("-created_at")
+        if wid:
+            qs = qs.filter(worker_id=wid)
+        if task_status:
+            qs = qs.filter(status=task_status.value if hasattr(task_status, "value") else str(task_status))
+        qs = qs[:limit]
+        return api_success([_task_to_dict(t) for t in qs])
 
     # POST: 提交新任务
     data = request.data.copy()
diff --git a/server/core/task_dispatcher.py b/server/core/task_dispatcher.py
index 8524900..95f29ba 100644
--- a/server/core/task_dispatcher.py
+++ b/server/core/task_dispatcher.py
@@ -7,40 +7,47 @@ from __future__ import annotations
 
 import logging
 import time
-import threading
-from typing import Dict, List, Optional
+from typing import List, Optional
 
 from common.protocol import MsgType, TaskStatus, make_msg
-from server.models import TaskCreate, TaskInfo
+from server.models import TaskCreate, Task
 
 logger = logging.getLogger("server.task_dispatcher")
 
 
 class TaskDispatcher:
-    """管理任务生命周期并派发给 Worker。"""
-    _ACTIVE_STATUSES = {TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING}
+    """
+    Manage the task lifecycle and dispatch tasks to workers (the database is the single source of truth).
 
-    def __init__(self) -> None:
-        # task_id → TaskInfo
-        self._tasks: Dict[str, TaskInfo] = {}
-        self._lock = threading.RLock()
+    Notes:
+    - task state is no longer kept in memory long-term;
+    - every query / state change goes through the Task ORM model;
+    - Python objects are only built transiently within a single request / dispatch.
+    """
+    _ACTIVE_STATUSES = {TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING}
 
     @staticmethod
     def _normalize_account_name(account_name: Optional[str]) -> str:
         return (account_name or "").strip()
 
-    def _find_active_task_unlocked(self, worker_id: str, account_name: str) -> Optional[TaskInfo]:
-        normalized = self._normalize_account_name(account_name)
+    @staticmethod
+    def _find_active_task(worker_id: str, account_name: str) -> Optional[Task]:
+        # No ``self`` in a staticmethod: go through the class, as done for _ACTIVE_STATUSES below.
+        normalized = TaskDispatcher._normalize_account_name(account_name)
         if not worker_id or not normalized:
             return None
-        for task in self._tasks.values():
-            if task.worker_id != worker_id:
-                continue
-            if self._normalize_account_name(task.account_name) != normalized:
-                continue
-            if task.status in self._ACTIVE_STATUSES:
-                return task
-        return None
+        from django.db.models import Q
+
+        active_values = [s.value for s in TaskDispatcher._ACTIVE_STATUSES]
+        return (
+            Task.objects.filter(
+                Q(worker_id=worker_id),
+                Q(account_name=normalized),
+                Q(status__in=active_values),
+            )
+            .order_by("-created_at")
+            .first()
+        )
 
     @staticmethod
     def _infer_failure_reason(result, error) -> Optional[str]:
@@ -70,34 +77,48 @@ class TaskDispatcher:
 
     # ─── 创建任务 ───
 
-    def create_task(self, req: TaskCreate) -> TaskInfo:
-        with self._lock:
-            existing = self._find_active_task_unlocked(req.worker_id or "", req.account_name or "")
-            if existing:
-                raise ValueError(
-                    "账号 '%s' 已有执行中的任务(task_id=%s, status=%s)"
-                    % (
-                        req.account_name or "",
-                        existing.task_id,
-                        existing.status.value if hasattr(existing.status, "value") else str(existing.status),
-                    )
+    def create_task(self, req: TaskCreate) -> Task:
+        """
+        Create a task record:
+        - raise ValueError if the same worker + account already has an active task (DB lookup);
+        - insert a new row into the Task table;
+        - return the ORM instance for the subsequent dispatch.
+        """
+        from django.utils import timezone as tz
+        import uuid
+
+        existing = self._find_active_task(req.worker_id or "", req.account_name or "")
+        if existing:
+            raise ValueError(
+                "账号 '%s' 已有执行中的任务(task_id=%s, status=%s)"
+                % (
+                    req.account_name or "",
+                    existing.task_id,
+                    existing.status,
                 )
-            task = TaskInfo(
-                task_type=req.task_type,
-                worker_id=req.worker_id,
-                account_name=req.account_name,
-                params=req.params,
             )
-            self._tasks[task.task_id] = task
-            logger.info(
-                "任务创建: %s type=%s worker=%s account=%s",
-                task.task_id, task.task_type, task.worker_id, task.account_name,
-            )
-            return task
+
+        task_id = uuid.uuid4().hex[:12]
+        now = tz.now()
+        task = Task.objects.create(
+            task_id=task_id,
+            task_type=req.task_type.value if hasattr(req.task_type, "value") else str(req.task_type),
+            worker_id=req.worker_id or "",
+            account_name=req.account_name or "",
+            status=TaskStatus.PENDING.value,
+            params=req.params,
+            created_at=now,
+            updated_at=now,
+        )
+        logger.info(
+            "任务创建: %s type=%s worker=%s account=%s",
+            task.task_id, task.task_type, task.worker_id, task.account_name,
+        )
+        return task
 
     # ─── 派发 ───
 
-    async def dispatch(self, task: TaskInfo, ws_send) -> bool:
+    async def dispatch(self, task: Task, ws_send) -> bool:
         """
         将任务通过 WebSocket 发给目标 Worker。
         ws_send: 异步可调用,接受一个 dict 参数。
@@ -106,84 +127,105 @@ class TaskDispatcher:
         msg = make_msg(
             MsgType.TASK_ASSIGN,
             task_id=task.task_id,
-            task_type=task.task_type.value,
+            task_type=task.task_type,
             account_name=task.account_name or "",
-            params=task.params,
+            params=task.params or {},
         )
+        # dispatch() runs on the event loop, so ORM writes must go through
+        # sync_to_async; a bare task.save() raises SynchronousOnlyOperation.
+        from asgiref.sync import sync_to_async
+        from django.utils import timezone as tz
+
         try:
             await ws_send(msg)
-            with self._lock:
-                task.status = TaskStatus.DISPATCHED
-                task.updated_at = time.time()
+            task.status = TaskStatus.DISPATCHED.value
+            task.updated_at = tz.now()
+            await sync_to_async(task.save)(update_fields=["status", "updated_at"])
             logger.info("任务 %s 已派发", task.task_id)
             return True
         except Exception as e:
-            with self._lock:
-                task.status = TaskStatus.FAILED
-                task.error = f"派发失败: {e}"
-                task.updated_at = time.time()
+            task.status = TaskStatus.FAILED.value
+            task.error = f"派发失败: {e}"
+            task.updated_at = tz.now()
+            await sync_to_async(task.save)(update_fields=["status", "error", "updated_at"])
             logger.error("任务 %s 派发失败: %s", task.task_id, e)
             return False
 
     # ─── 更新状态 ───
 
     def update_progress(self, task_id: str, progress: str) -> None:
-        with self._lock:
-            task = self._tasks.get(task_id)
-            if task:
-                task.status = TaskStatus.RUNNING
-                task.progress = progress
-                task.updated_at = time.time()
+        """
+        Store the task's progress text and mark it RUNNING.
+        """
+        from django.utils import timezone as tz
+
+        updated = Task.objects.filter(task_id=task_id).update(
+            status=TaskStatus.RUNNING.value,
+            progress=progress,
+            updated_at=tz.now(),
+        )
+        if updated:
+            logger.info("任务 %s 进度更新: %s", task_id, progress)
 
     def complete_task(self, task_id: str, result=None, error: str = None) -> None:
-        with self._lock:
-            task = self._tasks.get(task_id)
-            if not task:
-                return
-            failure_reason = self._infer_failure_reason(result, error)
-            if failure_reason:
-                task.status = TaskStatus.FAILED
-                task.error = failure_reason
-                task.result = result
-            else:
-                task.status = TaskStatus.SUCCESS
-                task.result = result
-                task.error = None
-            task.updated_at = time.time()
-            logger.info("任务 %s 完成: status=%s", task_id, task.status)
+        """
+        Complete a task: infer the final status from result/error and persist it.
+        """
+        from django.utils import timezone as tz
+
+        task = Task.objects.filter(task_id=task_id).first()
+        if not task:
+            return
+        failure_reason = self._infer_failure_reason(result, error)
+        if failure_reason:
+            task.status = TaskStatus.FAILED.value
+            task.error = failure_reason
+            task.result = result
+        else:
+            task.status = TaskStatus.SUCCESS.value
+            task.result = result
+            task.error = None
+        task.updated_at = tz.now()
+        task.save(update_fields=["status", "result", "error", "updated_at"])
+        logger.info("任务 %s 完成: status=%s", task_id, task.status)
 
     def cancel_task(self, task_id: str) -> None:
-        with self._lock:
-            task = self._tasks.get(task_id)
-            if task and task.status in (TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING):
-                task.status = TaskStatus.CANCELLED
-                task.updated_at = time.time()
+        """
+        Cancel a task; only takes effect while the task is in an active status.
+        """
+        from django.utils import timezone as tz
+
+        active_values = [s.value for s in self._ACTIVE_STATUSES]
+        task = Task.objects.filter(task_id=task_id, status__in=active_values).first()
+        if not task:
+            return
+        task.status = TaskStatus.CANCELLED.value
+        task.updated_at = tz.now()
+        task.save(update_fields=["status", "updated_at"])
 
     # ─── 查询 ───
 
-    def get_task(self, task_id: str) -> Optional[TaskInfo]:
-        with self._lock:
-            return self._tasks.get(task_id)
+    def get_task(self, task_id: str) -> Optional[Task]:
+        return Task.objects.filter(task_id=task_id).first()
 
-    def get_active_task_by_account(self, worker_id: str, account_name: str) -> Optional[TaskInfo]:
-        with self._lock:
-            return self._find_active_task_unlocked(worker_id, account_name)
+    def get_active_task_by_account(self, worker_id: str, account_name: str) -> Optional[Task]:
+        return self._find_active_task(worker_id, account_name)
 
     def list_tasks(
         self,
         worker_id: Optional[str] = None,
         status: Optional[TaskStatus] = None,
         limit: int = 50,
-    ) -> List[TaskInfo]:
-        with self._lock:
-            result = list(self._tasks.values())
-            if worker_id:
-                result = [t for t in result if t.worker_id == worker_id]
-            if status:
-                result = [t for t in result if t.status == status]
-            # 按创建时间倒序
-            result.sort(key=lambda t: t.created_at, reverse=True)
-            return result[:limit]
+    ) -> List[Task]:
+        """
+        Query the task list from the database, newest first.
+        """
+        qs = Task.objects.all().order_by("-created_at")
+        if worker_id:
+            qs = qs.filter(worker_id=worker_id)
+        if status:
+            qs = qs.filter(status=status.value if hasattr(status, "value") else str(status))
+        return list(qs[:limit])
 
 
 # 全局单例
diff --git a/server/models.py b/server/models.py
index 1e63695..077f236 100644
--- a/server/models.py
+++ b/server/models.py
@@ -62,6 +62,34 @@ class TaskLog(models.Model):
         return f"{self.task_id} ({self.task_type})"
 
 
+class Task(models.Model):
+    """
+    Task table (the database is the single source of truth).
+    - The lifecycle state of every task is stored in this table;
+    - task state is no longer kept in memory long-term, only as transient variables.
+    """
+
+    task_id = models.CharField(max_length=32, unique=True, verbose_name="任务 ID")
+    task_type = models.CharField(max_length=64, verbose_name="任务类型")
+    worker_id = models.CharField(max_length=64, default="", verbose_name="执行的 Worker")
+    account_name = models.CharField(max_length=128, default="", blank=True, verbose_name="账号(环境名称)")
+    status = models.CharField(max_length=32, default="", verbose_name="当前状态")
+    params = models.JSONField(null=True, blank=True, verbose_name="任务参数")
+    progress = models.TextField(null=True, blank=True, verbose_name="进度信息")
+    result = models.JSONField(null=True, blank=True, verbose_name="任务结果")
+    error = models.TextField(null=True, blank=True, verbose_name="错误信息")
+    created_at = models.DateTimeField(auto_now_add=True, verbose_name="创建时间")
+    updated_at = models.DateTimeField(auto_now=True, verbose_name="更新时间")
+
+    class Meta:
+        db_table = "task"
+        verbose_name = "任务"
+        verbose_name_plural = verbose_name
+
+    def __str__(self):
+        return f"{self.task_id} ({self.task_type})"
+
+
 class AuthToken(models.Model):
     """登录 token 表:每个用户名仅保留当前有效 token。"""
     username = models.CharField(max_length=64, unique=True, verbose_name="用户名")
diff --git a/server/ws/consumers.py b/server/ws/consumers.py
index 706d979..078d247 100644
--- a/server/ws/consumers.py
+++ b/server/ws/consumers.py
@@ -11,7 +11,6 @@ from channels.generic.websocket import AsyncWebsocketConsumer
 
 from common.protocol import MsgType, TaskStatus, TaskType, make_msg
 from server.core.worker_manager import worker_manager
-from server.core.task_dispatcher import task_dispatcher
 
 logger = logging.getLogger("server.ws")
 
@@ -102,8 +101,11 @@ class WorkerConsumer(AsyncWebsocketConsumer):
         elif msg_type == MsgType.TASK_PROGRESS.value:
             task_id = data.get("task_id", "")
             progress = data.get("progress", "")
-            task_dispatcher.update_progress(task_id, progress)
-            logger.info("任务 %s 进度: %s", task_id, progress)
+            try:
+                await self._update_task_progress(task_id, progress)
+                logger.info("任务 %s 进度: %s", task_id, progress)
+            except Exception as e:
+                logger.error("更新任务进度失败 (task_id=%s): %s", task_id, e)
             # 同步更新账号任务状态为 running
             try:
                 await self._update_account_task_status(task_id, TaskStatus.RUNNING.value)
@@ -114,21 +116,17 @@ class WorkerConsumer(AsyncWebsocketConsumer):
             task_id = data.get("task_id", "")
             result = data.get("result")
             error = data.get("error")
-            task_dispatcher.complete_task(task_id, result=result, error=error)
             worker_manager.set_current_task(self.worker_id, None)
             logger.info("任务 %s 已完成", task_id)
 
-            # ── 将结果写入数据库 ──
             try:
-                task_info = task_dispatcher.get_task(task_id)
-                if task_info:
-                    final_status = task_info.status.value if hasattr(task_info.status, "value") else str(task_info.status)
-                    final_error = task_info.error if getattr(task_info, "error", None) else error
-                    await self._save_task_log(task_id, task_info, result, final_error, final_status)
+                # The task type comes back with the status so the check_login gate below still works.
+                finalized = await self._finalize_task_and_log(task_id, result, error)
+                if finalized:
+                    final_status, task_type_val = finalized
                     await self._update_account_task_status(task_id, final_status)
                     logger.info("任务 %s 最终状态已更新: %s", task_id, final_status)
                     # check_login 任务:更新账号登录状态
-                    task_type_val = task_info.task_type.value if hasattr(task_info.task_type, "value") else str(task_info.task_type)
                     if task_type_val == TaskType.CHECK_LOGIN.value and result and final_status == TaskStatus.SUCCESS.value:
                         await self._upsert_account_status(result)
             except Exception as db_err:
@@ -141,19 +139,79 @@ class WorkerConsumer(AsyncWebsocketConsumer):
 
     @staticmethod
     @sync_to_async
-    def _save_task_log(task_id, task_info, result, error, final_status):
-        from server.models import TaskLog
+    def _update_task_progress(task_id, progress):
+        """Write the progress into the Task table and mark the task RUNNING."""
+        from django.utils import timezone as tz
+        from server.models import Task
+
+        Task.objects.filter(task_id=task_id).update(
+            status=TaskStatus.RUNNING.value,
+            progress=progress,
+            updated_at=tz.now(),
+        )
+
+    @staticmethod
+    def _infer_failure_reason(result, error):
+        """
+        Infer the failure reason; mirrors TaskDispatcher's logic (duplicated here to avoid a circular import).
+        """
+        if error is not None:
+            msg = str(error).strip()
+            return msg or "任务执行失败"
+
+        if isinstance(result, dict):
+            if result.get("success") is False:
+                msg = str(result.get("error", "")).strip()
+                return msg or "任务执行失败"
+            if result.get("ok") is False:
+                msg = str(result.get("error", "")).strip()
+                return msg or "任务执行失败"
+            status_text = str(result.get("status", "")).strip().lower()
+            if status_text in {"failed", "error", "fail"}:
+                msg = str(result.get("error", "")).strip()
+                return msg or "任务执行失败"
+
+        return None
+
+    @sync_to_async
+    def _finalize_task_and_log(self, task_id, result, error):
+        """
+        Mark the task finished and write both the Task row and its TaskLog.
+        Returns (final_status, task_type), or None when no Task row matches task_id.
+        """
+        from django.utils import timezone as tz
+        from server.models import Task, TaskLog
+
+        task = Task.objects.filter(task_id=task_id).first()
+        if not task:
+            return None
+
+        failure_reason = self._infer_failure_reason(result, error)
+        if failure_reason:
+            final_status = TaskStatus.FAILED.value
+            final_error = failure_reason
+        else:
+            final_status = TaskStatus.SUCCESS.value
+            final_error = None
+
+        task.status = final_status
+        task.result = result
+        task.error = final_error
+        task.updated_at = tz.now()
+        task.save(update_fields=["status", "result", "error", "updated_at"])
+
         TaskLog.objects.update_or_create(
             task_id=task_id,
             defaults={
-                "task_type": task_info.task_type.value if hasattr(task_info.task_type, "value") else str(task_info.task_type),
-                "worker_id": task_info.worker_id or "",
+                "task_type": task.task_type,
+                "worker_id": task.worker_id or "",
                 "status": final_status,
-                "params": task_info.params,
+                "params": task.params,
                 "result": result,
-                "error": error,
+                "error": final_error,
             },
         )
+        return final_status, task.task_type
 
     @staticmethod
     @sync_to_async