104 lines
3.4 KiB
Python
104 lines
3.4 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
Task submission and query API (authentication required).
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
from typing import List, Optional
|
||
|
||
import json
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||
|
||
from common.protocol import TaskStatus
|
||
from server.models import TaskCreate, TaskOut
|
||
from server.core.worker_manager import worker_manager
|
||
from server.core.task_dispatcher import task_dispatcher
|
||
from server.api.deps import require_auth, parse_body
|
||
|
||
# Every route registered on this router is mounted under /api/tasks and
# declares require_auth as a router-level dependency.
router = APIRouter(
    prefix="/api/tasks",
    tags=["tasks"],
    dependencies=[Depends(require_auth)],
)
|
||
|
||
|
||
@router.post("", response_model=TaskOut, status_code=201)
async def create_task(request: Request):
    """
    Submit a new task (the body may be JSON or form-data).

    Routing rule: an explicit ``worker_id`` wins over ``account_name``. When
    only ``account_name`` is given, the worker is looked up through
    ``worker_manager.find_worker_by_account``.

    Args:
        request: Incoming request; its body is decoded by ``parse_body``.

    Returns:
        The created task converted to ``TaskOut``.

    Raises:
        HTTPException: 422 when the body cannot be turned into a
            ``TaskCreate``; 404 when no worker is found for ``account_name``;
            400 when neither ``worker_id`` nor ``account_name`` is given;
            503 when the target worker is offline, has no WebSocket
            connection, or the dispatch fails.
    """
    body = await parse_body(request)

    # With form-data, params may arrive as a JSON string and must be decoded.
    params_raw = body.get("params", {})
    if isinstance(params_raw, str):
        try:
            body["params"] = json.loads(params_raw) if params_raw.strip() else {}
        except ValueError:
            # json.JSONDecodeError is a ValueError subclass. Undecodable
            # params deliberately fall back to an empty mapping (best effort).
            body["params"] = {}

    # The model is built by hand here rather than by FastAPI, so a failed
    # construction would otherwise escape as an HTTP 500.
    # NOTE(review): assumes TaskCreate signals invalid input with a
    # ValueError subclass (as pydantic's ValidationError does) - confirm.
    try:
        req = TaskCreate(**body)
    except (TypeError, ValueError) as exc:
        raise HTTPException(status_code=422, detail=f"请求参数无效: {exc}") from exc

    target_worker_id = req.worker_id

    # Fall back to resolving the worker from the account name.
    if not target_worker_id and req.account_name:
        target_worker_id = worker_manager.find_worker_by_account(req.account_name)
        if not target_worker_id:
            raise HTTPException(
                status_code=404,
                detail=f"未找到拥有浏览器 '{req.account_name}' 的在线 Worker",
            )

    if not target_worker_id:
        raise HTTPException(status_code=400, detail="请指定 worker_id 或 account_name")

    if not worker_manager.is_online(target_worker_id):
        raise HTTPException(status_code=503, detail=f"Worker {target_worker_id} 不在线")

    # Record the resolved worker on the request before the task is created.
    req.worker_id = target_worker_id
    task = task_dispatcher.create_task(req)

    ws = worker_manager.get_ws(target_worker_id)
    if not ws:
        # The task already exists at this point, so mark it failed
        # before reporting the error to the caller.
        task.status = TaskStatus.FAILED
        task.error = "Worker WebSocket 连接不存在"
        raise HTTPException(status_code=503, detail="Worker WebSocket 连接不存在")

    success = await task_dispatcher.dispatch(task, ws.send_json)
    if not success:
        raise HTTPException(status_code=503, detail=f"任务派发失败: {task.error}")

    worker_manager.set_current_task(target_worker_id, task.task_id)
    return _to_out(task)
|
||
|
||
|
||
@router.get("", response_model=List[TaskOut])
async def list_tasks(
    worker_id: Optional[str] = None,
    status: Optional[TaskStatus] = None,
    limit: int = 50,
):
    """
    List tasks, optionally filtered by ``worker_id`` and/or ``status``.

    All three query parameters are forwarded unchanged to
    ``task_dispatcher.list_tasks``; each returned task is converted to
    ``TaskOut``.
    """
    matched = task_dispatcher.list_tasks(
        worker_id=worker_id,
        status=status,
        limit=limit,
    )
    return list(map(_to_out, matched))
|
||
|
||
|
||
@router.get("/{task_id}", response_model=TaskOut)
async def get_task(task_id: str):
    """
    Return the status and result of one task.

    Raises:
        HTTPException: 404 when the dispatcher returns nothing for ``task_id``.
    """
    found = task_dispatcher.get_task(task_id)
    if found:
        return _to_out(found)
    raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在")
|
||
|
||
|
||
def _to_out(t) -> TaskOut:
    """Build a ``TaskOut`` by copying the listed attributes of ``t`` one-to-one."""
    # Attribute names read from the task; each is passed to TaskOut under
    # the same keyword, in this order.
    copied_fields = (
        "task_id",
        "task_type",
        "status",
        "worker_id",
        "account_name",
        "params",
        "progress",
        "result",
        "error",
        "created_at",
        "updated_at",
    )
    return TaskOut(**{name: getattr(t, name) for name in copied_fields})
|