哈哈
This commit is contained in:
@@ -1,103 +1,121 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
任务提交与查询 API(需要登录)。
|
||||
统一任务入口:前端通过 task_type 指定任务类型(如 check_login、boss_recruit)。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from asgiref.sync import async_to_sync
|
||||
from rest_framework import status as http_status
|
||||
from rest_framework.decorators import api_view
|
||||
from rest_framework.response import Response
|
||||
|
||||
from common.protocol import TaskStatus
|
||||
from server.models import TaskCreate, TaskOut
|
||||
from common.protocol import TaskStatus, TaskType
|
||||
from server.models import BossAccount, TaskCreate
|
||||
from server.serializers import TaskCreateSerializer, TaskOutSerializer
|
||||
from server.core.worker_manager import worker_manager
|
||||
from server.core.task_dispatcher import task_dispatcher
|
||||
from server.api.deps import require_auth, parse_body
|
||||
|
||||
router = APIRouter(prefix="/api/tasks", tags=["tasks"], dependencies=[Depends(require_auth)])
|
||||
logger = logging.getLogger("server.api.tasks")
|
||||
|
||||
|
||||
@router.post("", response_model=TaskOut, status_code=201)
|
||||
async def create_task(request: Request):
|
||||
def _task_to_dict(t) -> dict:
|
||||
"""将 TaskInfo 转为可序列化字典。"""
|
||||
return {
|
||||
"task_id": t.task_id,
|
||||
"task_type": t.task_type.value if hasattr(t.task_type, "value") else str(t.task_type),
|
||||
"status": t.status.value if hasattr(t.status, "value") else str(t.status),
|
||||
"worker_id": t.worker_id,
|
||||
"account_name": t.account_name,
|
||||
"params": t.params,
|
||||
"progress": t.progress,
|
||||
"result": t.result,
|
||||
"error": t.error,
|
||||
"created_at": t.created_at,
|
||||
"updated_at": t.updated_at,
|
||||
}
|
||||
|
||||
|
||||
@api_view(["GET", "POST"])
|
||||
def task_list(request):
|
||||
"""
|
||||
提交一个新任务(支持 JSON 和 form-data)。
|
||||
路由规则:worker_id > account_name。
|
||||
GET -> 查询任务列表,支持 ?worker_id= / ?status= / ?limit= 过滤
|
||||
POST -> 提交新任务(支持 JSON 和 form-data)
|
||||
"""
|
||||
body = await parse_body(request)
|
||||
# form-data 中 params 可能是 JSON 字符串,需要解析
|
||||
params_raw = body.get("params", {})
|
||||
if request.method == "GET":
|
||||
wid = request.query_params.get("worker_id")
|
||||
st = request.query_params.get("status")
|
||||
limit = int(request.query_params.get("limit", 50))
|
||||
task_status = TaskStatus(st) if st else None
|
||||
tasks = task_dispatcher.list_tasks(worker_id=wid, status=task_status, limit=limit)
|
||||
return Response([_task_to_dict(t) for t in tasks])
|
||||
|
||||
# POST: 提交新任务
|
||||
data = request.data.copy()
|
||||
# form-data 中 params 可能是 JSON 字符串
|
||||
params_raw = data.get("params", {})
|
||||
if isinstance(params_raw, str):
|
||||
try:
|
||||
body["params"] = json.loads(params_raw) if params_raw.strip() else {}
|
||||
data["params"] = json.loads(params_raw) if params_raw.strip() else {}
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
body["params"] = {}
|
||||
req = TaskCreate(**body)
|
||||
target_worker_id = req.worker_id
|
||||
data["params"] = {}
|
||||
|
||||
ser = TaskCreateSerializer(data=data)
|
||||
ser.is_valid(raise_exception=True)
|
||||
|
||||
req = TaskCreate(**ser.validated_data)
|
||||
target_worker_id = req.worker_id or ""
|
||||
|
||||
if not target_worker_id and req.account_name:
|
||||
target_worker_id = worker_manager.find_worker_by_account(req.account_name)
|
||||
if not target_worker_id:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"未找到拥有浏览器 '{req.account_name}' 的在线 Worker",
|
||||
return Response(
|
||||
{"detail": f"未找到拥有浏览器 '{req.account_name}' 的在线 Worker"},
|
||||
status=http_status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
if not target_worker_id:
|
||||
raise HTTPException(status_code=400, detail="请指定 worker_id 或 account_name")
|
||||
return Response({"detail": "请指定 worker_id 或 account_name"}, status=http_status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
if not worker_manager.is_online(target_worker_id):
|
||||
raise HTTPException(status_code=503, detail=f"Worker {target_worker_id} 不在线")
|
||||
return Response({"detail": f"Worker {target_worker_id} 不在线"}, status=http_status.HTTP_503_SERVICE_UNAVAILABLE)
|
||||
|
||||
req.worker_id = target_worker_id
|
||||
task = task_dispatcher.create_task(req)
|
||||
|
||||
ws = worker_manager.get_ws(target_worker_id)
|
||||
if not ws:
|
||||
send_fn = worker_manager.get_send_fn(target_worker_id)
|
||||
if not send_fn:
|
||||
task.status = TaskStatus.FAILED
|
||||
task.error = "Worker WebSocket 连接不存在"
|
||||
raise HTTPException(status_code=503, detail="Worker WebSocket 连接不存在")
|
||||
return Response({"detail": "Worker WebSocket 连接不存在"}, status=http_status.HTTP_503_SERVICE_UNAVAILABLE)
|
||||
|
||||
success = await task_dispatcher.dispatch(task, ws.send_json)
|
||||
success = async_to_sync(task_dispatcher.dispatch)(task, send_fn)
|
||||
if not success:
|
||||
raise HTTPException(status_code=503, detail=f"任务派发失败: {task.error}")
|
||||
return Response({"detail": f"任务派发失败: {task.error}"}, status=http_status.HTTP_503_SERVICE_UNAVAILABLE)
|
||||
|
||||
worker_manager.set_current_task(target_worker_id, task.task_id)
|
||||
return _to_out(task)
|
||||
|
||||
# check_login 任务:关联账号的任务状态
|
||||
if req.task_type == TaskType.CHECK_LOGIN and req.account_name:
|
||||
try:
|
||||
account = BossAccount.objects.filter(
|
||||
browser_name=req.account_name, worker_id=target_worker_id,
|
||||
).first()
|
||||
if account:
|
||||
account.current_task_id = task.task_id
|
||||
account.current_task_status = task.status.value
|
||||
account.save(update_fields=["current_task_id", "current_task_status"])
|
||||
except Exception as e:
|
||||
logger.warning("关联账号任务状态失败: %s", e)
|
||||
|
||||
return Response(_task_to_dict(task), status=http_status.HTTP_201_CREATED)
|
||||
|
||||
|
||||
@router.get("", response_model=List[TaskOut])
|
||||
async def list_tasks(
|
||||
worker_id: Optional[str] = None,
|
||||
status: Optional[TaskStatus] = None,
|
||||
limit: int = 50,
|
||||
):
|
||||
"""查询任务列表,支持按 worker_id / status 过滤。"""
|
||||
tasks = task_dispatcher.list_tasks(worker_id=worker_id, status=status, limit=limit)
|
||||
return [_to_out(t) for t in tasks]
|
||||
|
||||
|
||||
@router.get("/{task_id}", response_model=TaskOut)
|
||||
async def get_task(task_id: str):
|
||||
@api_view(["GET"])
|
||||
def task_detail(request, task_id):
|
||||
"""查询指定任务的状态和结果。"""
|
||||
task = task_dispatcher.get_task(task_id)
|
||||
if not task:
|
||||
raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")
|
||||
return _to_out(task)
|
||||
|
||||
|
||||
def _to_out(t) -> TaskOut:
|
||||
return TaskOut(
|
||||
task_id=t.task_id,
|
||||
task_type=t.task_type,
|
||||
status=t.status,
|
||||
worker_id=t.worker_id,
|
||||
account_name=t.account_name,
|
||||
params=t.params,
|
||||
progress=t.progress,
|
||||
result=t.result,
|
||||
error=t.error,
|
||||
created_at=t.created_at,
|
||||
updated_at=t.updated_at,
|
||||
)
|
||||
return Response({"detail": f"任务 {task_id} 不存在"}, status=http_status.HTTP_404_NOT_FOUND)
|
||||
return Response(_task_to_dict(task))
|
||||
|
||||
Reference in New Issue
Block a user