Files
boss_dp/server/api/tasks.py
ddrwode b43e2b51ad haha
2026-03-06 13:29:10 +08:00

403 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
Task submission and query API (login required).
Unified task entry point: the frontend selects the task type via task_type (e.g. check_login, boss_recruit).
"""
import json
import logging
from datetime import datetime
from typing import Optional
from asgiref.sync import async_to_sync
from rest_framework import status as http_status
from rest_framework.decorators import api_view
from common.protocol import MsgType, TaskStatus, TaskType, make_msg
from server.core.response import api_success, api_error
from server.models import BossAccount, TaskCreate, TaskLog, Task
from server.serializers import TaskCreateSerializer
from server.core.worker_manager import worker_manager
from server.core.task_dispatcher import task_dispatcher
logger = logging.getLogger("server.api.tasks")
def _format_timestamp(ts: float) -> str:
"""将时间戳转换为指定格式的字符串2026-03-01T20:19:51"""
dt = datetime.fromtimestamp(ts)
return dt.strftime("%Y-%m-%dT%H:%M:%S")
def _task_to_dict(t) -> dict:
"""将任务实例Task ORM转为可序列化字典。"""
return {
"task_id": t.task_id,
"task_type": str(t.task_type),
"status": str(t.status),
"worker_id": t.worker_id,
"account_name": getattr(t, "account_name", None),
"params": t.params or {},
"progress": getattr(t, "progress", None),
"result": t.result,
"error": t.error,
"created_at": (
t.created_at.strftime("%Y-%m-%dT%H:%M:%S")
if hasattr(t.created_at, "strftime")
else _format_timestamp(t.created_at)
),
"updated_at": (
t.updated_at.strftime("%Y-%m-%dT%H:%M:%S")
if hasattr(t.updated_at, "strftime")
else _format_timestamp(t.updated_at)
),
}
def _parse_limit(raw_value, default: int = 50, max_value: int = 200) -> int:
"""解析分页上限,避免非法参数导致 500。"""
try:
value = int(raw_value)
except (TypeError, ValueError):
return default
if value <= 0:
return default
return min(value, max_value)
def _parse_positive_int(raw_value, default: int = 1) -> int:
"""解析正整数参数。"""
try:
value = int(raw_value)
except (TypeError, ValueError):
return default
if value <= 0:
return default
return value
def _parse_task_status(raw_status: Optional[str]) -> Optional[TaskStatus]:
    """Convert a raw status string into a ``TaskStatus``.

    Returns ``None`` both when ``raw_status`` is empty and when it is not a
    valid ``TaskStatus`` value; callers distinguish the two cases by checking
    whether the raw string was non-empty.
    """
    parsed = None
    if raw_status:
        try:
            parsed = TaskStatus(raw_status)
        except ValueError:
            parsed = None
    return parsed
def _task_log_account_name(task_log: TaskLog) -> Optional[str]:
    """Extract a recognisable environment/account name from a task log.

    Looks, in priority order, at ``result["browser_name"]``,
    ``result["account_name"]``, ``params["account_name"]`` and
    ``params["browser_name"]``. Non-dict ``result``/``params`` are treated as
    empty. Returns the first truthy candidate as a stripped string, or ``None``
    when nothing usable is found.
    """
    result_data = task_log.result if isinstance(task_log.result, dict) else {}
    param_data = task_log.params if isinstance(task_log.params, dict) else {}
    candidates = (
        result_data.get("browser_name"),
        result_data.get("account_name"),
        param_data.get("account_name"),
        param_data.get("browser_name"),
    )
    chosen = next((candidate for candidate in candidates if candidate), None)
    if not chosen:
        return None
    cleaned = str(chosen).strip()
    return cleaned if cleaned else None
def _task_log_to_dict(task_log: TaskLog, account_name: Optional[str] = None) -> dict:
    """Convert a ``TaskLog`` into the same response shape used for tasks.

    ``account_name`` overrides the name derived from the log when truthy.
    ``progress`` is always ``None`` and ``updated_at`` repeats ``created_at``,
    since only ``created_at`` is read from the log here.
    """
    created_text = task_log.created_at.strftime("%Y-%m-%dT%H:%M:%S")
    resolved_name = account_name or _task_log_account_name(task_log)
    params = task_log.params
    if not isinstance(params, dict):
        params = {}
    return {
        "task_id": task_log.task_id,
        "task_type": task_log.task_type,
        "status": task_log.status,
        "worker_id": task_log.worker_id,
        "account_name": resolved_name,
        "params": params,
        "progress": None,
        "result": task_log.result,
        "error": task_log.error,
        "created_at": created_text,
        "updated_at": created_text,
    }
def _is_task_log_for_account(task_log: TaskLog, account: BossAccount) -> bool:
    """Decide whether a task log belongs to ``account``.

    A log bound to a different worker never matches. Otherwise it matches when
    any of the following holds:
    - it is the account's current task;
    - the name extracted from the log equals the account's ``browser_name``;
    - ``params["id"]``, ``params["account_id"]`` or ``params["boss_id"]``
      equals the account primary key (compared as stripped strings).
    """
    log_worker = task_log.worker_id
    if log_worker and log_worker != account.worker_id:
        return False

    current_id = account.current_task_id
    if current_id and task_log.task_id == current_id:
        return True

    log_name = _task_log_account_name(task_log)
    if log_name and log_name == account.browser_name:
        return True

    param_data = task_log.params if isinstance(task_log.params, dict) else {}
    pk_text = str(account.pk)
    return any(
        param_data.get(key) is not None and str(param_data.get(key)).strip() == pk_text
        for key in ("id", "account_id", "boss_id")
    )
def _list_tasks_by_account(account: BossAccount, task_status: Optional[TaskStatus], limit: Optional[int] = 50) -> list:
    """List the tasks of one account, newest first, read from the Task table.

    Tasks are matched on the account's ``worker_id`` and ``browser_name``.
    ``task_status`` optionally narrows the result to one status; ``limit``
    caps the number of rows (``None`` means no cap). Each row is converted
    with ``_task_to_dict``.
    """
    queryset = Task.objects.filter(
        worker_id=account.worker_id,
        account_name=account.browser_name,
    ).order_by("-created_at")

    if task_status:
        if hasattr(task_status, "value"):
            status_value = task_status.value
        else:
            status_value = str(task_status)
        queryset = queryset.filter(status=status_value)

    if limit is not None:
        queryset = queryset[:limit]

    serialized = []
    for task in queryset:
        serialized.append(_task_to_dict(task))
    return serialized
def _normalize_task_params(raw_params) -> dict:
    """Coerce a submitted ``params`` value (dict or JSON string) into a dict."""
    if isinstance(raw_params, dict):
        return raw_params
    if not isinstance(raw_params, str) or not raw_params.strip():
        return {}
    try:
        parsed = json.loads(raw_params)
    except (json.JSONDecodeError, ValueError):
        return {}
    # A JSON list/number/string is not a usable params mapping either.
    return parsed if isinstance(parsed, dict) else {}


def _pop_account_pk(validated) -> Optional[int]:
    """Remove the id / account_id / boss_id aliases and return the account pk."""
    # Pop every alias unconditionally so none of them is left behind in
    # ``validated`` (an ``a or b or c`` chain stops popping at the first hit).
    candidates = [validated.pop(key, None) for key in ("id", "account_id", "boss_id")]
    raw_pk = next((value for value in candidates if value), None)
    if raw_pk is None:
        return None
    try:
        return int(raw_pk)
    except (TypeError, ValueError):
        return None


@api_view(["GET", "POST"])
def task_list(request):
    """
    GET  -> list tasks, newest first. Supported query parameters:
        - ?worker_id=   filter by worker id
        - ?account_id=  filter by account primary key (returns that account's tasks)
        - ?status=      filter by task status (400 if the value is not a TaskStatus)
        - ?limit=       maximum number of rows, default 50
    POST -> submit a new task (JSON or form-data). The target worker is
        resolved from id / account_id / boss_id (all account primary keys),
        or from worker_id, or from account_name. The task is created,
        dispatched to the worker and returned with HTTP 201.
    """
    if request.method == "GET":
        wid = request.query_params.get("worker_id")
        account_id = request.query_params.get("account_id")
        st = request.query_params.get("status")
        limit = _parse_limit(request.query_params.get("limit", 50))
        task_status = _parse_task_status(st)
        if st and task_status is None:
            return api_error(http_status.HTTP_400_BAD_REQUEST, f"不支持的任务状态: {st}")

        # With account_id, return that account's task list directly
        # (kept for the frontend's "view tasks" feature).
        if account_id:
            try:
                account_id = int(account_id)
                account = BossAccount.objects.get(pk=account_id)
            except (ValueError, BossAccount.DoesNotExist):
                return api_error(
                    http_status.HTTP_404_NOT_FOUND,
                    f"未找到 id={account_id} 的账号",
                )
            return api_success(_list_tasks_by_account(account, task_status=task_status, limit=limit))

        qs = Task.objects.all().order_by("-created_at")
        if wid:
            qs = qs.filter(worker_id=wid)
        if task_status:
            qs = qs.filter(status=task_status.value if hasattr(task_status, "value") else str(task_status))
        qs = qs[:limit]
        return api_success([_task_to_dict(t) for t in qs])

    # POST: submit a new task.
    data = request.data.copy()
    ser = TaskCreateSerializer(data=data)
    if not ser.is_valid():
        logger.warning("任务提交参数校验失败: data=%s, errors=%s", data, ser.errors)
        from rest_framework.exceptions import ValidationError
        raise ValidationError(ser.errors)
    validated = ser.validated_data.copy()

    # In form-data, params arrives as a JSON string and must be decoded here.
    validated["params"] = _normalize_task_params(validated.get("params", ""))

    # boss_id is the id returned by /api/accounts (the account primary key),
    # equivalent to id / account_id; all three are looked up by primary key.
    account_id = _pop_account_pk(validated)

    req = TaskCreate(**validated)
    target_worker_id = req.worker_id or ""
    account_name = req.account_name or ""

    # An explicit account primary key overrides worker_id / account_name.
    if account_id:
        try:
            account = BossAccount.objects.get(pk=account_id)
        except BossAccount.DoesNotExist:
            return api_error(
                http_status.HTTP_404_NOT_FOUND,
                f"未找到 id={account_id} 的账号",
            )
        target_worker_id = account.worker_id
        account_name = account.browser_name
        req.worker_id = target_worker_id
        req.account_name = account_name

    # Only an account name was given: find the online worker that owns it.
    if not target_worker_id and account_name:
        target_worker_id = worker_manager.find_worker_by_account(account_name)
        if not target_worker_id:
            return api_error(
                http_status.HTTP_404_NOT_FOUND,
                f"未找到拥有浏览器 '{account_name}' 的在线 Worker",
            )
        req.worker_id = target_worker_id

    if not target_worker_id:
        return api_error(http_status.HTTP_400_BAD_REQUEST, "请指定 id、boss_id、worker_id 或 account_name")
    if not worker_manager.is_online(target_worker_id):
        return api_error(http_status.HTTP_503_SERVICE_UNAVAILABLE, f"Worker {target_worker_id} 不在线")
    req.worker_id = target_worker_id

    try:
        task = task_dispatcher.create_task(req)
    except ValueError as e:
        return api_error(http_status.HTTP_409_CONFLICT, str(e))

    send_fn = worker_manager.get_send_fn(target_worker_id)
    if not send_fn:
        # NOTE(review): these assignments are not followed by a save() here;
        # confirm whether the failure state is persisted elsewhere.
        task.status = TaskStatus.FAILED
        task.error = "Worker WebSocket 连接不存在"
        return api_error(http_status.HTTP_503_SERVICE_UNAVAILABLE, "Worker WebSocket 连接不存在")

    success = async_to_sync(task_dispatcher.dispatch)(task, send_fn)
    if not success:
        return api_error(http_status.HTTP_503_SERVICE_UNAVAILABLE, f"任务派发失败: {task.error}")
    worker_manager.set_current_task(target_worker_id, task.task_id)

    # check_login tasks: record the task on the matching account.
    if req.task_type == TaskType.CHECK_LOGIN and req.account_name:
        # Best effort: a failure here must not fail an already dispatched task.
        try:
            account = BossAccount.objects.filter(
                browser_name=req.account_name, worker_id=target_worker_id,
            ).first()
            if account:
                account.current_task_id = task.task_id
                # Dispatch succeeded, so store DISPATCHED rather than
                # task.status, which may still read PENDING.
                account.current_task_status = TaskStatus.DISPATCHED.value
                account.save(update_fields=["current_task_id", "current_task_status"])
                logger.info("账号 %s 任务关联: task_id=%s, status=%s", req.account_name, task.task_id, account.current_task_status)
        except Exception as e:
            logger.warning("关联账号任务状态失败: %s", e)

    return api_success(_task_to_dict(task), http_status=http_status.HTTP_201_CREATED)
@api_view(["GET"])
def task_detail(request, task_id: str):
    """
    Return the details of a single task identified by task_id.

    The Task table is consulted first, then the TaskLog table; a 404 error is
    returned when neither contains the id.
    """
    active = Task.objects.filter(task_id=task_id).first()
    if active:
        return api_success(_task_to_dict(active))

    archived = TaskLog.objects.filter(task_id=task_id).first()
    if not archived:
        return api_error(http_status.HTTP_404_NOT_FOUND, f"任务 {task_id} 不存在")
    return api_success(_task_log_to_dict(archived))
@api_view(["POST"])
def task_cancel(request, task_id: str):
    """
    Cancel a task.

    - 404 when the task does not exist.
    - 409 when the task is not pending / dispatched / running, or when the
      dispatcher reports that it could not be cancelled.
    - Otherwise the task is marked cancelled first, then a TASK_CANCEL message
      is sent to the worker; a failure to send is only logged.
    """
    task = Task.objects.filter(task_id=task_id).first()
    if not task:
        return api_error(http_status.HTTP_404_NOT_FOUND, f"任务 {task_id} 不存在")

    cancellable = (
        TaskStatus.PENDING.value,
        TaskStatus.DISPATCHED.value,
        TaskStatus.RUNNING.value,
    )
    if str(task.status) not in cancellable:
        return api_error(http_status.HTTP_409_CONFLICT, f"任务当前状态为 {task.status},不可取消")

    cancelled_task = task_dispatcher.cancel_task(task_id, error="任务已取消")
    if not cancelled_task:
        return api_error(http_status.HTTP_409_CONFLICT, "任务已结束或已被取消")

    worker_id = cancelled_task.worker_id
    if worker_id:
        # Clear the task recorded as current for this worker.
        worker_manager.set_current_task(worker_id, None)

    if cancelled_task.account_name:
        # Only touch the account row that still points at this task.
        matching_accounts = BossAccount.objects.filter(
            worker_id=cancelled_task.worker_id,
            browser_name=cancelled_task.account_name,
            current_task_id=cancelled_task.task_id,
        )
        matching_accounts.update(current_task_status=TaskStatus.CANCELLED.value)

    sender = worker_manager.get_send_fn(cancelled_task.worker_id)
    if sender:
        message = make_msg(MsgType.TASK_CANCEL, task_id=cancelled_task.task_id)
        try:
            async_to_sync(sender)(message)
        except Exception as exc:
            logger.warning("向 Worker 下发任务取消失败 task_id=%s: %s", cancelled_task.task_id, exc)

    return api_success(_task_to_dict(cancelled_task), msg="任务已取消")
@api_view(["GET"])
def task_list_by_account(request, account_id: int):
    """
    List the tasks of one account by account id (lookup by task_id is not supported).

    - Compatibility mode: without page / page_size, returns a plain array
      capped by ?limit= (default 50).
    - Paging mode: with page or page_size, returns
      {total, page, page_size, results}.

    Responds 404 when the account does not exist and 400 when ?status= is not
    a valid task status.
    """
    account = BossAccount.objects.filter(pk=account_id).first()
    if not account:
        return api_error(http_status.HTTP_404_NOT_FOUND, f"账号 {account_id} 不存在")

    st = request.query_params.get("status")
    task_status = _parse_task_status(st)
    if st and task_status is None:
        return api_error(http_status.HTTP_400_BAD_REQUEST, f"不支持的任务状态: {st}")

    # Legacy behaviour: return an array by default, with limit capping the size.
    enable_paging = ("page" in request.query_params) or ("page_size" in request.query_params)
    if not enable_paging:
        limit = _parse_limit(request.query_params.get("limit", 50))
        return api_success(_list_tasks_by_account(account, task_status=task_status, limit=limit))

    page = _parse_positive_int(request.query_params.get("page"), default=1)
    page_size = _parse_limit(
        request.query_params.get("page_size", request.query_params.get("limit", 20)),
        default=20,
        max_value=200,
    )

    # Same filter and ordering as _list_tasks_by_account, but counted and
    # sliced in the database so only the requested page is loaded and
    # serialised instead of every task of the account.
    qs = Task.objects.filter(
        worker_id=account.worker_id,
        account_name=account.browser_name,
    ).order_by("-created_at")
    if task_status:
        qs = qs.filter(status=task_status.value if hasattr(task_status, "value") else str(task_status))

    total = qs.count()
    start = (page - 1) * page_size
    if start >= total:
        # Out-of-range page: skip the query (and avoid sending an arbitrarily
        # large OFFSET to the database).
        results = []
    else:
        results = [_task_to_dict(t) for t in qs[start:start + page_size]]

    return api_success({
        "total": total,
        "page": page,
        "page_size": page_size,
        "results": results,
    })