151 lines
5.2 KiB
Python
151 lines
5.2 KiB
Python
# -*- coding: utf-8 -*-
|
|
"""
|
|
中央服务器入口。
|
|
启动方式: python -m server.main
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
|
|
import uvicorn
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
from common.protocol import MsgType, make_msg
|
|
from server import config
|
|
from server.api.workers import router as workers_router
|
|
from server.api.tasks import router as tasks_router
|
|
from server.core.worker_manager import worker_manager
|
|
from server.core.task_dispatcher import task_dispatcher
|
|
|
|
# Process-wide logging: timestamp, padded logger name and level, then the message.
logging.basicConfig(
    format="%(asctime)s %(name)-28s %(levelname)-5s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)

# The name is spelled out instead of using __name__ so records keep the
# "server.main" name even when this file is executed as __main__.
logger = logging.getLogger("server.main")
|
|
|
|
# ────────────────────────── FastAPI App ──────────────────────────
|
|
|
|
app = FastAPI(title="Browser Control Server", version="1.0.0")

# Fully permissive CORS: every origin, method and header is accepted
# (allow_credentials is left at its default).
# NOTE(review): confirm a wildcard origin is intended for this server.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Mount the REST routers in their original order: workers first, then tasks.
for _router in (workers_router, tasks_router):
    app.include_router(_router)
del _router
|
|
|
|
|
|
# ────────────────────────── 健康检查 ──────────────────────────
|
|
|
|
@app.get("/health")
async def health():
    """Health probe: report a static "ok" status plus the online-worker count."""
    online_count = sum(
        1 for worker in worker_manager.get_all_workers() if worker.online
    )
    return {"status": "ok", "workers_online": online_count}
|
|
|
|
|
|
# ────────────────────────── WebSocket 端点 ──────────────────────────
|
|
|
|
def _parse_json_object(text: str) -> dict | None:
    """Decode one WebSocket text frame as a JSON object.

    Returns the decoded dict, or ``None`` when *text* is not valid JSON or
    decodes to something other than an object (list, string, number, ...).
    Callers use ``None`` to reject or skip the frame instead of crashing on
    ``.get()``.
    """
    try:
        payload = json.loads(text)
    except json.JSONDecodeError:
        return None
    return payload if isinstance(payload, dict) else None


@app.websocket(config.WS_PATH)
async def ws_endpoint(ws: WebSocket):
    """Serve one worker's WebSocket connection.

    Connection flow:
      1. Accept the connection.
      2. Wait up to 30 seconds for the first message. It must be a JSON object
         of type ``register`` carrying a non-empty ``worker_id``. Otherwise an
         ERROR message is sent and the socket is closed with 4001 (not a
         register message) or 4002 (empty worker_id). A timeout closes with
         4003.
      3. Loop over incoming messages: heartbeat, browser-list update, task
         progress, task result. Frames that are not JSON objects are logged
         and skipped, so they do not drop the connection.
      4. When the loop ends (disconnect or error), unregister the worker.
    """
    await ws.accept()
    worker_id: str | None = None

    try:
        # -- Registration --
        raw = _parse_json_object(
            await asyncio.wait_for(ws.receive_text(), timeout=30)
        )
        if raw is None or raw.get("type") != MsgType.REGISTER.value:
            await ws.send_json(make_msg(MsgType.ERROR, detail="首条消息必须是 register"))
            await ws.close(code=4001)
            return

        worker_id = raw.get("worker_id", "")
        worker_name = raw.get("worker_name", worker_id)
        browsers = raw.get("browsers", [])

        if not worker_id:
            await ws.send_json(make_msg(MsgType.ERROR, detail="worker_id 不能为空"))
            await ws.close(code=4002)
            return

        worker_manager.register(ws, worker_id, worker_name, browsers)
        await ws.send_json(make_msg(MsgType.REGISTER_ACK, worker_id=worker_id))
        logger.info("Worker %s 已连接", worker_id)

        # -- Message loop --
        while True:
            data = _parse_json_object(await ws.receive_text())
            if data is None:
                # One bad frame should not disconnect (and unregister) a worker.
                logger.warning("忽略非 JSON 对象消息 (from %s)", worker_id)
                continue

            msg_type = data.get("type", "")

            if msg_type == MsgType.HEARTBEAT.value:
                worker_manager.heartbeat(worker_id)
                await ws.send_json(make_msg(MsgType.HEARTBEAT_ACK))

            elif msg_type == MsgType.BROWSER_LIST_UPDATE.value:
                worker_manager.update_browsers(worker_id, data.get("browsers", []))

            elif msg_type == MsgType.TASK_PROGRESS.value:
                task_id = data.get("task_id", "")
                progress = data.get("progress", "")
                task_dispatcher.update_progress(task_id, progress)
                logger.info("任务 %s 进度: %s", task_id, progress)

            elif msg_type == MsgType.TASK_RESULT.value:
                task_id = data.get("task_id", "")
                result = data.get("result")
                error = data.get("error")
                task_dispatcher.complete_task(task_id, result=result, error=error)
                # Clear the worker's current-task slot.
                worker_manager.set_current_task(worker_id, None)
                logger.info("任务 %s 已完成", task_id)

            else:
                logger.warning("未知消息类型: %s (from %s)", msg_type, worker_id)

    except WebSocketDisconnect:
        logger.info("Worker %s WebSocket 断开", worker_id or "unknown")
    except asyncio.TimeoutError:
        logger.warning("WebSocket 连接超时(未在 30 秒内注册)")
        await ws.close(code=4003)
    except Exception as e:
        logger.error("WebSocket 处理异常: %s", e, exc_info=True)
        # Best-effort close with 1011 (internal error). This gives the worker an
        # explicit close code instead of relying on how the ASGI server treats
        # a socket the handler left open. The socket may already be closed, in
        # which case close() raises; log that at debug level rather than
        # masking the original error.
        try:
            await ws.close(code=1011)
        except Exception:
            logger.debug("WebSocket 关闭失败", exc_info=True)
    finally:
        if worker_id:
            # NOTE(review): a worker that reconnects with the same worker_id
            # before this connection finishes may have its newer registration
            # removed here. Confirm how worker_manager handles duplicates.
            worker_manager.unregister(worker_id)
|
|
|
|
|
|
# ────────────────────────── 生命周期 ──────────────────────────
|
|
|
|
def _on_heartbeat_task_done(task: asyncio.Task) -> None:
    """Log the heartbeat-check task's exception if it stopped with one.

    Without this callback, a crash in the background loop stays unreported
    until the task object is garbage-collected.
    """
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("心跳巡检任务异常退出: %s", exc, exc_info=exc)


@app.on_event("startup")
async def startup():
    """Start background jobs when the application starts.

    Launches ``worker_manager.check_heartbeats_loop()`` as an asyncio task and
    logs the listen address from ``config``.
    """
    heartbeat_task = asyncio.create_task(worker_manager.check_heartbeats_loop())
    heartbeat_task.add_done_callback(_on_heartbeat_task_done)
    # The event loop holds only weak references to tasks. Keep a strong
    # reference on app.state so the heartbeat loop cannot be collected mid-run.
    app.state.heartbeat_task = heartbeat_task
    logger.info("服务器启动: http://%s:%s", config.HOST, config.PORT)
|
|
|
|
|
|
# ────────────────────────── 入口 ──────────────────────────
|
|
|
|
def main():
    """Run the server with uvicorn on ``config.HOST:config.PORT``.

    The already-built ``app`` object is passed directly instead of the
    ``"server.main:app"`` import string. Under ``python -m server.main`` this
    file runs as ``__main__``, so an import string would make uvicorn import it
    a second time as ``server.main`` and construct a second app. (If
    ``reload=True`` or ``workers=`` is ever needed, uvicorn requires the import
    string again.)
    """
    uvicorn.run(
        app,
        host=config.HOST,
        port=config.PORT,
        log_level="info",
    )


if __name__ == "__main__":
    main()
|