haha
This commit is contained in:
@@ -14,7 +14,7 @@ from rest_framework.decorators import api_view
|
||||
|
||||
from common.protocol import TaskStatus, TaskType
|
||||
from server.core.response import api_success, api_error
|
||||
from server.models import BossAccount, TaskCreate, TaskLog
|
||||
from server.models import BossAccount, TaskCreate, TaskLog, Task
|
||||
from server.serializers import TaskCreateSerializer
|
||||
from server.core.worker_manager import worker_manager
|
||||
from server.core.task_dispatcher import task_dispatcher
|
||||
@@ -29,19 +29,27 @@ def _format_timestamp(ts: float) -> str:
|
||||
|
||||
|
||||
def _task_to_dict(t) -> dict:
|
||||
"""将 TaskInfo 转为可序列化字典。"""
|
||||
"""将任务实例(Task ORM)转为可序列化字典。"""
|
||||
return {
|
||||
"task_id": t.task_id,
|
||||
"task_type": t.task_type.value if hasattr(t.task_type, "value") else str(t.task_type),
|
||||
"status": t.status.value if hasattr(t.status, "value") else str(t.status),
|
||||
"task_type": str(t.task_type),
|
||||
"status": str(t.status),
|
||||
"worker_id": t.worker_id,
|
||||
"account_name": t.account_name,
|
||||
"params": t.params,
|
||||
"progress": t.progress,
|
||||
"account_name": getattr(t, "account_name", None),
|
||||
"params": t.params or {},
|
||||
"progress": getattr(t, "progress", None),
|
||||
"result": t.result,
|
||||
"error": t.error,
|
||||
"created_at": _format_timestamp(t.created_at),
|
||||
"updated_at": _format_timestamp(t.updated_at),
|
||||
"created_at": (
|
||||
t.created_at.strftime("%Y-%m-%dT%H:%M:%S")
|
||||
if hasattr(t.created_at, "strftime")
|
||||
else _format_timestamp(t.created_at)
|
||||
),
|
||||
"updated_at": (
|
||||
t.updated_at.strftime("%Y-%m-%dT%H:%M:%S")
|
||||
if hasattr(t.updated_at, "strftime")
|
||||
else _format_timestamp(t.updated_at)
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -135,38 +143,20 @@ def _is_task_log_for_account(task_log: TaskLog, account: BossAccount) -> bool:
|
||||
|
||||
def _list_tasks_by_account(account: BossAccount, task_status: Optional[TaskStatus], limit: Optional[int] = 50) -> list:
|
||||
"""
|
||||
聚合某账号的任务列表:
|
||||
1) 内存任务(实时)
|
||||
2) TaskLog 历史任务(重启后可查)
|
||||
按账号维度查询任务列表,完全基于 Task 表。
|
||||
- 不再依赖内存中的 TaskInfo;
|
||||
- 你可以直接修改数据库中 Task.status/Task.result 等字段来影响这里的返回。
|
||||
"""
|
||||
items_by_task_id = {}
|
||||
|
||||
# limit=None 用于分页场景下先取全量,再切片;此处给内存任务查询一个足够大的上限
|
||||
memory_limit = 10000 if limit is None else max(limit * 3, 100)
|
||||
memory_tasks = task_dispatcher.list_tasks(worker_id=account.worker_id, status=task_status, limit=memory_limit)
|
||||
for t in memory_tasks:
|
||||
if t.account_name != account.browser_name:
|
||||
continue
|
||||
items_by_task_id[t.task_id] = _task_to_dict(t)
|
||||
|
||||
db_qs = TaskLog.objects.filter(worker_id=account.worker_id).order_by("-created_at")
|
||||
qs = Task.objects.filter(
|
||||
worker_id=account.worker_id,
|
||||
account_name=account.browser_name,
|
||||
).order_by("-created_at")
|
||||
if task_status:
|
||||
db_qs = db_qs.filter(status=task_status.value)
|
||||
qs = qs.filter(status=task_status.value if hasattr(task_status, "value") else str(task_status))
|
||||
|
||||
# 多取一些做过滤,避免因为条件匹配损耗导致结果太少
|
||||
db_logs = db_qs if limit is None else db_qs[: max(limit * 8, 200)]
|
||||
for task_log in db_logs:
|
||||
if not _is_task_log_for_account(task_log, account):
|
||||
continue
|
||||
if task_log.task_id in items_by_task_id:
|
||||
continue
|
||||
items_by_task_id[task_log.task_id] = _task_log_to_dict(task_log, account_name=account.browser_name)
|
||||
|
||||
merged = list(items_by_task_id.values())
|
||||
merged.sort(key=lambda item: item.get("created_at") or "", reverse=True)
|
||||
if limit is None:
|
||||
return merged
|
||||
return merged[:limit]
|
||||
if limit is not None:
|
||||
qs = qs[:limit]
|
||||
return [_task_to_dict(t) for t in qs]
|
||||
|
||||
|
||||
@api_view(["GET", "POST"])
|
||||
@@ -200,8 +190,13 @@ def task_list(request):
|
||||
)
|
||||
return api_success(_list_tasks_by_account(account, task_status=task_status, limit=limit))
|
||||
|
||||
tasks = task_dispatcher.list_tasks(worker_id=wid, status=task_status, limit=limit)
|
||||
return api_success([_task_to_dict(t) for t in tasks])
|
||||
qs = Task.objects.all().order_by("-created_at")
|
||||
if wid:
|
||||
qs = qs.filter(worker_id=wid)
|
||||
if task_status:
|
||||
qs = qs.filter(status=task_status.value if hasattr(task_status, "value") else str(task_status))
|
||||
qs = qs[:limit]
|
||||
return api_success([_task_to_dict(t) for t in qs])
|
||||
|
||||
# POST: 提交新任务
|
||||
data = request.data.copy()
|
||||
|
||||
@@ -7,40 +7,46 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
import threading
|
||||
from typing import Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from common.protocol import MsgType, TaskStatus, make_msg
|
||||
from server.models import TaskCreate, TaskInfo
|
||||
from server.models import TaskCreate, Task
|
||||
|
||||
logger = logging.getLogger("server.task_dispatcher")
|
||||
|
||||
|
||||
class TaskDispatcher:
|
||||
"""管理任务生命周期并派发给 Worker。"""
|
||||
_ACTIVE_STATUSES = {TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING}
|
||||
"""
|
||||
管理任务生命周期并派发给 Worker(以数据库为唯一真相)。
|
||||
|
||||
def __init__(self) -> None:
|
||||
# task_id → TaskInfo
|
||||
self._tasks: Dict[str, TaskInfo] = {}
|
||||
self._lock = threading.RLock()
|
||||
注意:
|
||||
- 不再在内存中长期保存任务状态;
|
||||
- 所有查询 / 状态变更均落在 ORM 模型 Task 上;
|
||||
- 仅在单次请求/派发过程中临时构造 Python 对象。
|
||||
"""
|
||||
_ACTIVE_STATUSES = {TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_account_name(account_name: Optional[str]) -> str:
|
||||
return (account_name or "").strip()
|
||||
|
||||
def _find_active_task_unlocked(self, worker_id: str, account_name: str) -> Optional[TaskInfo]:
|
||||
@staticmethod
|
||||
def _find_active_task(worker_id: str, account_name: str) -> Optional[Task]:
|
||||
normalized = self._normalize_account_name(account_name)
|
||||
if not worker_id or not normalized:
|
||||
return None
|
||||
for task in self._tasks.values():
|
||||
if task.worker_id != worker_id:
|
||||
continue
|
||||
if self._normalize_account_name(task.account_name) != normalized:
|
||||
continue
|
||||
if task.status in self._ACTIVE_STATUSES:
|
||||
return task
|
||||
return None
|
||||
from django.db.models import Q
|
||||
|
||||
active_values = [s.value for s in TaskDispatcher._ACTIVE_STATUSES]
|
||||
return (
|
||||
Task.objects.filter(
|
||||
Q(worker_id=worker_id),
|
||||
Q(account_name=normalized),
|
||||
Q(status__in=active_values),
|
||||
)
|
||||
.order_by("-created_at")
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _infer_failure_reason(result, error) -> Optional[str]:
|
||||
@@ -70,34 +76,48 @@ class TaskDispatcher:
|
||||
|
||||
# ─── 创建任务 ───
|
||||
|
||||
def create_task(self, req: TaskCreate) -> TaskInfo:
|
||||
with self._lock:
|
||||
existing = self._find_active_task_unlocked(req.worker_id or "", req.account_name or "")
|
||||
if existing:
|
||||
raise ValueError(
|
||||
"账号 '%s' 已有执行中的任务(task_id=%s, status=%s)"
|
||||
% (
|
||||
req.account_name or "",
|
||||
existing.task_id,
|
||||
existing.status.value if hasattr(existing.status, "value") else str(existing.status),
|
||||
)
|
||||
def create_task(self, req: TaskCreate) -> Task:
|
||||
"""
|
||||
创建任务记录:
|
||||
- 校验同一 worker + account 是否已有执行中的任务(查表);
|
||||
- 在 Task 表中插入一条新记录;
|
||||
- 返回 ORM 实例,供后续派发使用。
|
||||
"""
|
||||
from django.utils import timezone as tz
|
||||
import uuid
|
||||
|
||||
existing = self._find_active_task(req.worker_id or "", req.account_name or "")
|
||||
if existing:
|
||||
raise ValueError(
|
||||
"账号 '%s' 已有执行中的任务(task_id=%s, status=%s)"
|
||||
% (
|
||||
req.account_name or "",
|
||||
existing.task_id,
|
||||
existing.status,
|
||||
)
|
||||
task = TaskInfo(
|
||||
task_type=req.task_type,
|
||||
worker_id=req.worker_id,
|
||||
account_name=req.account_name,
|
||||
params=req.params,
|
||||
)
|
||||
self._tasks[task.task_id] = task
|
||||
logger.info(
|
||||
"任务创建: %s type=%s worker=%s account=%s",
|
||||
task.task_id, task.task_type, task.worker_id, task.account_name,
|
||||
)
|
||||
return task
|
||||
|
||||
task_id = uuid.uuid4().hex[:12]
|
||||
now = tz.now()
|
||||
task = Task.objects.create(
|
||||
task_id=task_id,
|
||||
task_type=req.task_type.value if hasattr(req.task_type, "value") else str(req.task_type),
|
||||
worker_id=req.worker_id or "",
|
||||
account_name=req.account_name or "",
|
||||
status=TaskStatus.PENDING.value,
|
||||
params=req.params,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
logger.info(
|
||||
"任务创建: %s type=%s worker=%s account=%s",
|
||||
task.task_id, task.task_type, task.worker_id, task.account_name,
|
||||
)
|
||||
return task
|
||||
|
||||
# ─── 派发 ───
|
||||
|
||||
async def dispatch(self, task: TaskInfo, ws_send) -> bool:
|
||||
async def dispatch(self, task: Task, ws_send) -> bool:
|
||||
"""
|
||||
将任务通过 WebSocket 发给目标 Worker。
|
||||
ws_send: 异步可调用,接受一个 dict 参数。
|
||||
@@ -106,84 +126,104 @@ class TaskDispatcher:
|
||||
msg = make_msg(
|
||||
MsgType.TASK_ASSIGN,
|
||||
task_id=task.task_id,
|
||||
task_type=task.task_type.value,
|
||||
task_type=task.task_type,
|
||||
account_name=task.account_name or "",
|
||||
params=task.params,
|
||||
params=task.params or {},
|
||||
)
|
||||
try:
|
||||
await ws_send(msg)
|
||||
with self._lock:
|
||||
task.status = TaskStatus.DISPATCHED
|
||||
task.updated_at = time.time()
|
||||
from django.utils import timezone as tz
|
||||
|
||||
task.status = TaskStatus.DISPATCHED.value
|
||||
task.updated_at = tz.now()
|
||||
task.save(update_fields=["status", "updated_at"])
|
||||
logger.info("任务 %s 已派发", task.task_id)
|
||||
return True
|
||||
except Exception as e:
|
||||
with self._lock:
|
||||
task.status = TaskStatus.FAILED
|
||||
task.error = f"派发失败: {e}"
|
||||
task.updated_at = time.time()
|
||||
from django.utils import timezone as tz
|
||||
|
||||
task.status = TaskStatus.FAILED.value
|
||||
task.error = f"派发失败: {e}"
|
||||
task.updated_at = tz.now()
|
||||
task.save(update_fields=["status", "error", "updated_at"])
|
||||
logger.error("任务 %s 派发失败: %s", task.task_id, e)
|
||||
return False
|
||||
|
||||
# ─── 更新状态 ───
|
||||
|
||||
def update_progress(self, task_id: str, progress: str) -> None:
|
||||
with self._lock:
|
||||
task = self._tasks.get(task_id)
|
||||
if task:
|
||||
task.status = TaskStatus.RUNNING
|
||||
task.progress = progress
|
||||
task.updated_at = time.time()
|
||||
"""
|
||||
更新任务进度和状态(RUNNING)。
|
||||
"""
|
||||
from django.utils import timezone as tz
|
||||
|
||||
updated = Task.objects.filter(task_id=task_id).update(
|
||||
status=TaskStatus.RUNNING.value,
|
||||
progress=progress,
|
||||
updated_at=tz.now(),
|
||||
)
|
||||
if updated:
|
||||
logger.info("任务 %s 进度更新: %s", task_id, progress)
|
||||
|
||||
def complete_task(self, task_id: str, result=None, error: str = None) -> None:
|
||||
with self._lock:
|
||||
task = self._tasks.get(task_id)
|
||||
if not task:
|
||||
return
|
||||
failure_reason = self._infer_failure_reason(result, error)
|
||||
if failure_reason:
|
||||
task.status = TaskStatus.FAILED
|
||||
task.error = failure_reason
|
||||
task.result = result
|
||||
else:
|
||||
task.status = TaskStatus.SUCCESS
|
||||
task.result = result
|
||||
task.error = None
|
||||
task.updated_at = time.time()
|
||||
logger.info("任务 %s 完成: status=%s", task_id, task.status)
|
||||
"""
|
||||
完成任务:根据 result/error 推断最终状态并写入数据库。
|
||||
"""
|
||||
from django.utils import timezone as tz
|
||||
|
||||
task = Task.objects.filter(task_id=task_id).first()
|
||||
if not task:
|
||||
return
|
||||
failure_reason = self._infer_failure_reason(result, error)
|
||||
if failure_reason:
|
||||
task.status = TaskStatus.FAILED.value
|
||||
task.error = failure_reason
|
||||
task.result = result
|
||||
else:
|
||||
task.status = TaskStatus.SUCCESS.value
|
||||
task.result = result
|
||||
task.error = None
|
||||
task.updated_at = tz.now()
|
||||
task.save(update_fields=["status", "result", "error", "updated_at"])
|
||||
logger.info("任务 %s 完成: status=%s", task_id, task.status)
|
||||
|
||||
def cancel_task(self, task_id: str) -> None:
|
||||
with self._lock:
|
||||
task = self._tasks.get(task_id)
|
||||
if task and task.status in (TaskStatus.PENDING, TaskStatus.DISPATCHED, TaskStatus.RUNNING):
|
||||
task.status = TaskStatus.CANCELLED
|
||||
task.updated_at = time.time()
|
||||
"""
|
||||
取消任务:仅对活动状态的任务生效。
|
||||
"""
|
||||
from django.utils import timezone as tz
|
||||
|
||||
active_values = [s.value for s in self._ACTIVE_STATUSES]
|
||||
task = Task.objects.filter(task_id=task_id, status__in=active_values).first()
|
||||
if not task:
|
||||
return
|
||||
task.status = TaskStatus.CANCELLED.value
|
||||
task.updated_at = tz.now()
|
||||
task.save(update_fields=["status", "updated_at"])
|
||||
|
||||
# ─── 查询 ───
|
||||
|
||||
def get_task(self, task_id: str) -> Optional[TaskInfo]:
|
||||
with self._lock:
|
||||
return self._tasks.get(task_id)
|
||||
def get_task(self, task_id: str) -> Optional[Task]:
|
||||
return Task.objects.filter(task_id=task_id).first()
|
||||
|
||||
def get_active_task_by_account(self, worker_id: str, account_name: str) -> Optional[TaskInfo]:
|
||||
with self._lock:
|
||||
return self._find_active_task_unlocked(worker_id, account_name)
|
||||
def get_active_task_by_account(self, worker_id: str, account_name: str) -> Optional[Task]:
|
||||
return self._find_active_task(worker_id, account_name)
|
||||
|
||||
def list_tasks(
|
||||
self,
|
||||
worker_id: Optional[str] = None,
|
||||
status: Optional[TaskStatus] = None,
|
||||
limit: int = 50,
|
||||
) -> List[TaskInfo]:
|
||||
with self._lock:
|
||||
result = list(self._tasks.values())
|
||||
if worker_id:
|
||||
result = [t for t in result if t.worker_id == worker_id]
|
||||
if status:
|
||||
result = [t for t in result if t.status == status]
|
||||
# 按创建时间倒序
|
||||
result.sort(key=lambda t: t.created_at, reverse=True)
|
||||
return result[:limit]
|
||||
) -> List[Task]:
|
||||
"""
|
||||
从数据库中查询任务列表。
|
||||
"""
|
||||
qs = Task.objects.all().order_by("-created_at")
|
||||
if worker_id:
|
||||
qs = qs.filter(worker_id=worker_id)
|
||||
if status:
|
||||
qs = qs.filter(status=status.value if hasattr(status, "value") else str(status))
|
||||
return list(qs[:limit])
|
||||
|
||||
|
||||
# 全局单例
|
||||
|
||||
@@ -62,6 +62,34 @@ class TaskLog(models.Model):
|
||||
return f"{self.task_id} ({self.task_type})"
|
||||
|
||||
|
||||
class Task(models.Model):
|
||||
"""
|
||||
任务表(数据库为唯一真相)。
|
||||
- 所有任务的生命周期状态均保存在此表中;
|
||||
- 内存中不再长期保存任务状态,只作为必要的临时变量。
|
||||
"""
|
||||
|
||||
task_id = models.CharField(max_length=32, unique=True, verbose_name="任务 ID")
|
||||
task_type = models.CharField(max_length=64, verbose_name="任务类型")
|
||||
worker_id = models.CharField(max_length=64, default="", verbose_name="执行的 Worker")
|
||||
account_name = models.CharField(max_length=128, default="", blank=True, verbose_name="账号(环境名称)")
|
||||
status = models.CharField(max_length=32, default="", verbose_name="当前状态")
|
||||
params = models.JSONField(null=True, blank=True, verbose_name="任务参数")
|
||||
progress = models.TextField(null=True, blank=True, verbose_name="进度信息")
|
||||
result = models.JSONField(null=True, blank=True, verbose_name="任务结果")
|
||||
error = models.TextField(null=True, blank=True, verbose_name="错误信息")
|
||||
created_at = models.DateTimeField(auto_now_add=True, verbose_name="创建时间")
|
||||
updated_at = models.DateTimeField(auto_now=True, verbose_name="更新时间")
|
||||
|
||||
class Meta:
|
||||
db_table = "task"
|
||||
verbose_name = "任务"
|
||||
verbose_name_plural = verbose_name
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.task_id} ({self.task_type})"
|
||||
|
||||
|
||||
class AuthToken(models.Model):
|
||||
"""登录 token 表:每个用户名仅保留当前有效 token。"""
|
||||
username = models.CharField(max_length=64, unique=True, verbose_name="用户名")
|
||||
|
||||
@@ -11,7 +11,6 @@ from channels.generic.websocket import AsyncWebsocketConsumer
|
||||
|
||||
from common.protocol import MsgType, TaskStatus, TaskType, make_msg
|
||||
from server.core.worker_manager import worker_manager
|
||||
from server.core.task_dispatcher import task_dispatcher
|
||||
|
||||
logger = logging.getLogger("server.ws")
|
||||
|
||||
@@ -102,8 +101,11 @@ class WorkerConsumer(AsyncWebsocketConsumer):
|
||||
elif msg_type == MsgType.TASK_PROGRESS.value:
|
||||
task_id = data.get("task_id", "")
|
||||
progress = data.get("progress", "")
|
||||
task_dispatcher.update_progress(task_id, progress)
|
||||
logger.info("任务 %s 进度: %s", task_id, progress)
|
||||
try:
|
||||
await self._update_task_progress(task_id, progress)
|
||||
logger.info("任务 %s 进度: %s", task_id, progress)
|
||||
except Exception as e:
|
||||
logger.error("更新任务进度失败 (task_id=%s): %s", task_id, e)
|
||||
# 同步更新账号任务状态为 running
|
||||
try:
|
||||
await self._update_account_task_status(task_id, TaskStatus.RUNNING.value)
|
||||
@@ -114,22 +116,16 @@ class WorkerConsumer(AsyncWebsocketConsumer):
|
||||
task_id = data.get("task_id", "")
|
||||
result = data.get("result")
|
||||
error = data.get("error")
|
||||
task_dispatcher.complete_task(task_id, result=result, error=error)
|
||||
worker_manager.set_current_task(self.worker_id, None)
|
||||
logger.info("任务 %s 已完成", task_id)
|
||||
|
||||
# ── 将结果写入数据库 ──
|
||||
try:
|
||||
task_info = task_dispatcher.get_task(task_id)
|
||||
if task_info:
|
||||
final_status = task_info.status.value if hasattr(task_info.status, "value") else str(task_info.status)
|
||||
final_error = task_info.error if getattr(task_info, "error", None) else error
|
||||
await self._save_task_log(task_id, task_info, result, final_error, final_status)
|
||||
final_status = await self._finalize_task_and_log(task_id, result, error)
|
||||
if final_status:
|
||||
await self._update_account_task_status(task_id, final_status)
|
||||
logger.info("任务 %s 最终状态已更新: %s", task_id, final_status)
|
||||
# check_login 任务:更新账号登录状态
|
||||
task_type_val = task_info.task_type.value if hasattr(task_info.task_type, "value") else str(task_info.task_type)
|
||||
if task_type_val == TaskType.CHECK_LOGIN.value and result and final_status == TaskStatus.SUCCESS.value:
|
||||
# check_login 任务:更新账号登录状态(仅在成功时)
|
||||
if result and final_status == TaskStatus.SUCCESS.value:
|
||||
await self._upsert_account_status(result)
|
||||
except Exception as db_err:
|
||||
logger.error("任务 %s 写入数据库失败: %s", task_id, db_err)
|
||||
@@ -141,19 +137,79 @@ class WorkerConsumer(AsyncWebsocketConsumer):
|
||||
|
||||
@staticmethod
|
||||
@sync_to_async
|
||||
def _save_task_log(task_id, task_info, result, error, final_status):
|
||||
from server.models import TaskLog
|
||||
def _update_task_progress(task_id, progress):
|
||||
"""将进度写入 Task 表。"""
|
||||
from django.utils import timezone as tz
|
||||
from server.models import Task
|
||||
|
||||
Task.objects.filter(task_id=task_id).update(
|
||||
status=TaskStatus.RUNNING.value,
|
||||
progress=progress,
|
||||
updated_at=tz.now(),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _infer_failure_reason(result, error):
|
||||
"""
|
||||
失败原因推断逻辑,与 TaskDispatcher 内保持一致(此处复制一份,避免循环依赖)。
|
||||
"""
|
||||
if error is not None:
|
||||
msg = str(error).strip()
|
||||
return msg or "任务执行失败"
|
||||
|
||||
if isinstance(result, dict):
|
||||
if result.get("success") is False:
|
||||
msg = str(result.get("error", "")).strip()
|
||||
return msg or "任务执行失败"
|
||||
if result.get("ok") is False:
|
||||
msg = str(result.get("error", "")).strip()
|
||||
return msg or "任务执行失败"
|
||||
status_text = str(result.get("status", "")).strip().lower()
|
||||
if status_text in {"failed", "error", "fail"}:
|
||||
msg = str(result.get("error", "")).strip()
|
||||
return msg or "任务执行失败"
|
||||
|
||||
return None
|
||||
|
||||
@sync_to_async
|
||||
def _finalize_task_and_log(self, task_id, result, error):
|
||||
"""
|
||||
完成任务并写入 Task / TaskLog。
|
||||
返回最终状态字符串(如 "success" / "failed"),供后续使用。
|
||||
"""
|
||||
from django.utils import timezone as tz
|
||||
from server.models import Task, TaskLog
|
||||
|
||||
task = Task.objects.filter(task_id=task_id).first()
|
||||
if not task:
|
||||
return None
|
||||
|
||||
failure_reason = self._infer_failure_reason(result, error)
|
||||
if failure_reason:
|
||||
final_status = TaskStatus.FAILED.value
|
||||
final_error = failure_reason
|
||||
else:
|
||||
final_status = TaskStatus.SUCCESS.value
|
||||
final_error = None
|
||||
|
||||
task.status = final_status
|
||||
task.result = result
|
||||
task.error = final_error
|
||||
task.updated_at = tz.now()
|
||||
task.save(update_fields=["status", "result", "error", "updated_at"])
|
||||
|
||||
TaskLog.objects.update_or_create(
|
||||
task_id=task_id,
|
||||
defaults={
|
||||
"task_type": task_info.task_type.value if hasattr(task_info.task_type, "value") else str(task_info.task_type),
|
||||
"worker_id": task_info.worker_id or "",
|
||||
"task_type": task.task_type,
|
||||
"worker_id": task.worker_id or "",
|
||||
"status": final_status,
|
||||
"params": task_info.params,
|
||||
"params": task.params,
|
||||
"result": result,
|
||||
"error": error,
|
||||
"error": final_error,
|
||||
},
|
||||
)
|
||||
return final_status
|
||||
|
||||
@staticmethod
|
||||
@sync_to_async
|
||||
|
||||
Reference in New Issue
Block a user