Files
boss_dp/server/main.py
ddrwode 4a520306b1 ha'ha
2026-02-12 16:46:01 +08:00

174 lines
5.9 KiB
Python

# -*- coding: utf-8 -*-
"""
中央服务器入口。
启动方式: python -m server.main
"""
from __future__ import annotations
import asyncio
import json
import logging
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from common.protocol import MsgType, make_msg
from server import config
from server.api.workers import router as workers_router
from server.api.tasks import router as tasks_router
from server.core.worker_manager import worker_manager
from server.core.task_dispatcher import task_dispatcher
from tunnel.server import TunnelServer
# Root logging: timestamp, logger name padded to 28 columns, level padded to 5.
_LOG_FORMAT = "%(asctime)s %(name)-28s %(levelname)-5s %(message)s"
_LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATEFMT)
# Fixed name (not __name__) so records read "server.main" however the module is launched.
logger = logging.getLogger("server.main")
# ────────────────────────── FastAPI App ──────────────────────────
app = FastAPI(title="Browser Control Server", version="1.0.0")
# NOTE(review): wildcard CORS lets any web origin call this API from a browser;
# confirm that is intended for a server that controls browsers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Mount the REST routers in a fixed order: workers first, then tasks.
for _router in (workers_router, tasks_router):
    app.include_router(_router)
del _router
# ────────────────────────── 健康检查 ──────────────────────────
@app.get("/health")
async def health():
    """Liveness probe.

    Returns ``{"status": "ok", "workers_online": N}``, where N counts the
    workers from ``worker_manager.get_all_workers()`` whose ``online`` is truthy.
    """
    workers = worker_manager.get_all_workers()
    online_count = sum(1 for worker in workers if worker.online)
    return {"status": "ok", "workers_online": online_count}
# ────────────────────────── WebSocket 端点 ──────────────────────────
async def _handle_worker_message(ws: WebSocket, worker_id: str, data: dict) -> None:
    """Dispatch one decoded message from an already-registered worker.

    Args:
        ws: The worker's socket; used to acknowledge heartbeats.
        worker_id: The ID the worker registered with.
        data: The decoded JSON object. Its ``type`` field selects the handler;
            unknown types are logged and otherwise ignored.
    """
    msg_type = data.get("type", "")
    if msg_type == MsgType.HEARTBEAT.value:
        worker_manager.heartbeat(worker_id)
        await ws.send_json(make_msg(MsgType.HEARTBEAT_ACK))
    elif msg_type == MsgType.BROWSER_LIST_UPDATE.value:
        worker_manager.update_browsers(worker_id, data.get("browsers", []))
    elif msg_type == MsgType.TASK_PROGRESS.value:
        task_id = data.get("task_id", "")
        progress = data.get("progress", "")
        task_dispatcher.update_progress(task_id, progress)
        logger.info("任务 %s 进度: %s", task_id, progress)
    elif msg_type == MsgType.TASK_RESULT.value:
        task_id = data.get("task_id", "")
        task_dispatcher.complete_task(
            task_id, result=data.get("result"), error=data.get("error")
        )
        # Release the worker's task slot now that it has reported a result.
        worker_manager.set_current_task(worker_id, None)
        logger.info("任务 %s 已完成", task_id)
    else:
        logger.warning("未知消息类型: %s (from %s)", msg_type, worker_id)


@app.websocket(config.WS_PATH)
async def ws_endpoint(ws: WebSocket):
    """Serve one worker's WebSocket connection.

    Flow:
      1. Accept the connection.
      2. Wait up to 30 seconds for the first message. It must be a JSON object
         of type ``register`` with a non-empty ``worker_id``. The socket is
         closed with 4001 on a wrong first message, 4002 on an empty
         ``worker_id``, and 4003 on timeout.
      3. Process follow-up messages (heartbeat, browser list update, task
         progress, task result) until the socket disconnects. A frame that is
         not valid JSON, or not a JSON object, is logged and skipped. It no
         longer tears down the worker's session.
      4. Unregister the worker when the connection ends for any reason.
    """
    await ws.accept()
    worker_id: str | None = None
    try:
        # ── Registration ──
        raw = await asyncio.wait_for(ws.receive_json(), timeout=30)
        # A non-object first frame (e.g. a JSON array) cannot be a register
        # message. Reject it like any other wrong first message instead of
        # crashing on raw.get().
        if not isinstance(raw, dict) or raw.get("type") != MsgType.REGISTER.value:
            await ws.send_json(make_msg(MsgType.ERROR, detail="首条消息必须是 register"))
            await ws.close(code=4001)
            return
        worker_id = raw.get("worker_id", "")
        worker_name = raw.get("worker_name", worker_id)
        browsers = raw.get("browsers", [])
        if not worker_id:
            await ws.send_json(make_msg(MsgType.ERROR, detail="worker_id 不能为空"))
            await ws.close(code=4002)
            return
        worker_manager.register(ws, worker_id, worker_name, browsers)
        await ws.send_json(make_msg(MsgType.REGISTER_ACK, worker_id=worker_id))
        logger.info("Worker %s 已连接", worker_id)

        # ── Message loop ──
        while True:
            # Decode frames here, rather than with receive_json(). Then one
            # malformed frame is skipped instead of dropping the connection
            # (and, via `finally`, the worker's registration).
            text = await ws.receive_text()
            try:
                data = json.loads(text)
            except json.JSONDecodeError:
                logger.warning("忽略无法解析的消息 (from %s): %.200s", worker_id, text)
                continue
            if not isinstance(data, dict):
                logger.warning("忽略非对象消息 (from %s): %.200s", worker_id, text)
                continue
            await _handle_worker_message(ws, worker_id, data)
    except WebSocketDisconnect:
        logger.info("Worker %s WebSocket 断开", worker_id or "unknown")
    except asyncio.TimeoutError:
        # Only the registration wait is under wait_for, so a timeout here
        # means the client never registered.
        logger.warning("WebSocket 连接超时(未在 30 秒内注册)")
        await ws.close(code=4003)
    except Exception as e:
        logger.exception("WebSocket 处理异常: %s", e)
    finally:
        # NOTE(review): a worker may reconnect with the same worker_id before
        # this stale connection ends. In that case this call may also remove
        # the new registration. Confirm worker_manager.unregister guards
        # against that.
        if worker_id:
            worker_manager.unregister(worker_id)
# ────────────────────────── 生命周期 ──────────────────────────
# Tunnel server instance: created in startup(), stopped and cleared in shutdown().
_tunnel_server: TunnelServer | None = None
# Heartbeat-check background task. The event loop keeps only weak references
# to tasks, so an unreferenced task can be garbage-collected mid-run. Holding
# it here keeps it alive and lets shutdown() cancel it.
_heartbeat_task: asyncio.Task | None = None


@app.on_event("startup")
async def startup() -> None:
    """Start the heartbeat-check loop and the tunnel server, then log the listening addresses."""
    global _tunnel_server, _heartbeat_task
    _heartbeat_task = asyncio.create_task(worker_manager.check_heartbeats_loop())
    _tunnel_server = TunnelServer(
        control_port=config.TUNNEL_CONTROL_PORT,
        stream_port=config.TUNNEL_STREAM_PORT,
        proxy_base_port=config.TUNNEL_PROXY_BASE_PORT,
        host=config.HOST,
    )
    await _tunnel_server.start()
    logger.info(
        "服务器启动: http://%s:%s | 隧道: 控制 %s, 流 %s, 代理起始 %s",
        config.HOST, config.PORT,
        config.TUNNEL_CONTROL_PORT, config.TUNNEL_STREAM_PORT, config.TUNNEL_PROXY_BASE_PORT,
    )


@app.on_event("shutdown")
async def shutdown() -> None:
    """Stop the tunnel server (if started), then cancel the heartbeat-check task."""
    global _tunnel_server, _heartbeat_task
    if _tunnel_server:
        await _tunnel_server.stop()
        _tunnel_server = None
    task, _heartbeat_task = _heartbeat_task, None
    if task is not None:
        task.cancel()
        # return_exceptions=True collects the task's CancelledError (or any
        # earlier failure) as a result instead of raising it out of shutdown.
        await asyncio.gather(task, return_exceptions=True)
# ────────────────────────── 入口 ──────────────────────────
def main() -> None:
    """Serve ``server.main:app`` with uvicorn on ``config.HOST``:``config.PORT``."""
    # uvicorn imports the app from this import string itself. When launched
    # via `python -m server.main`, the module is therefore loaded a second
    # time as `server.main`.
    uvicorn.run("server.main:app", host=config.HOST, port=config.PORT, log_level="info")


if __name__ == "__main__":
    main()