Files
boss_dp/server/main.py
Your Name 51ae0756e0 哈哈
2026-02-12 17:10:05 +08:00

209 lines
8.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
中央服务器入口。
启动方式: python -m server.main
"""
from __future__ import annotations
import asyncio
import json
import logging
from contextlib import asynccontextmanager
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from common.protocol import MsgType, TaskType, make_msg
from server import config
from server.api.workers import router as workers_router
from server.api.tasks import router as tasks_router
from server.api.accounts import router as accounts_router
from server.core.worker_manager import worker_manager
from server.core.task_dispatcher import task_dispatcher
from server import db
from tunnel.server import TunnelServer
_LOG_FORMAT = "%(asctime)s %(name)-28s %(levelname)-5s %(message)s"
_LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"

# Configure the root logger at INFO level (basicConfig is a no-op if the root
# logger already has handlers), then get this module's named logger.
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATEFMT)
logger = logging.getLogger("server.main")
# ────────────────────────── Lifespan ──────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    Startup: initialise the database (a failure is logged, not fatal), start
    the worker heartbeat-check loop, and start the tunnel server.
    Shutdown: stop the tunnel server and cancel the heartbeat-check loop.
    Shutdown cleanup runs even if an exception is thrown into the lifespan.
    """
    # ── startup ──
    # Create tables via the SQLAlchemy ORM. A failure is deliberately not
    # fatal: the service keeps running without database features.
    try:
        db.init_db()
        logger.info("数据库初始化完成")
    except Exception as e:
        logger.error("数据库初始化失败: %s(服务继续运行,但数据库功能不可用)", e)

    # Keep a strong reference to the task. The event loop only keeps weak
    # references, so an unreferenced task may be garbage-collected mid-run.
    # The reference also lets shutdown cancel the loop.
    heartbeat_task = asyncio.create_task(worker_manager.check_heartbeats_loop())
    try:
        tunnel_server = TunnelServer(
            control_port=config.TUNNEL_CONTROL_PORT,
            stream_port=config.TUNNEL_STREAM_PORT,
            proxy_base_port=config.TUNNEL_PROXY_BASE_PORT,
            host=config.HOST,
        )
        await tunnel_server.start()
        logger.info(
            "服务器启动: http://%s:%s | 隧道: 控制 %s, 流 %s, 代理起始 %s",
            config.HOST, config.PORT,
            config.TUNNEL_CONTROL_PORT, config.TUNNEL_STREAM_PORT, config.TUNNEL_PROXY_BASE_PORT,
        )
        try:
            yield
        finally:
            # ── shutdown ──
            await tunnel_server.stop()
            logger.info("服务器已关闭")
    finally:
        # Stop the background heartbeat check. It also runs if startup failed
        # after the task was created.
        heartbeat_task.cancel()
# ────────────────────────── FastAPI App ──────────────────────────
app = FastAPI(
    title="Browser Control Server",
    version="1.0.0",
    lifespan=lifespan,
)

# NOTE(review): wildcard CORS lets any origin call the HTTP API; confirm this
# is intended for the deployment environment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Mount the REST routers in a fixed order: workers, tasks, accounts.
for _router in (workers_router, tasks_router, accounts_router):
    app.include_router(_router)
del _router
# ────────────────────────── Health check ──────────────────────────
@app.get("/health")
async def health():
    """Health probe: report "ok" plus how many registered workers are online."""
    online_count = sum(1 for worker in worker_manager.get_all_workers() if worker.online)
    return {"status": "ok", "workers_online": online_count}
# ────────────────────────── WebSocket endpoint ──────────────────────────
def _enum_str(value) -> str:
    """Return ``value.value`` when *value* has that attribute (e.g. an Enum member), else ``str(value)``."""
    return value.value if hasattr(value, "value") else str(value)


async def _close_quietly(ws: WebSocket, code: int) -> None:
    """Close *ws* with *code*, logging instead of raising if closing fails.

    Used on error paths, where the socket may already be closed or broken.
    """
    try:
        await ws.close(code=code)
    except Exception as close_err:
        logger.debug("关闭 WebSocket 失败: %s", close_err)


def _persist_task_result(worker_id: str, task_id: str, result, error) -> None:
    """Write a finished task to the database.

    Saves a task-log row. For a successful ``check_login`` task it also upserts
    the account-status row from fields of *result*. Does nothing if the
    dispatcher returns no task for *task_id*. Exceptions propagate to the caller.
    """
    task_info = task_dispatcher.get_task(task_id)
    if not task_info:
        return
    task_type_val = _enum_str(task_info.task_type)
    db.save_task_log(
        task_id=task_id,
        task_type=task_type_val,
        worker_id=worker_id,
        status=_enum_str(task_info.status),
        params=task_info.params,
        result=result,
        error=error,
    )
    # check_login: mirror the reported login state into the account-status table.
    if task_type_val == TaskType.CHECK_LOGIN.value and result and not error:
        db.upsert_account_status(
            worker_id=worker_id,
            browser_id=result.get("browser_id", ""),
            browser_name=result.get("browser_name", ""),
            boss_username=result.get("boss_username", ""),
            is_logged_in=result.get("is_logged_in", False),
        )


def _handle_task_result(worker_id: str, data: dict) -> None:
    """Complete the task in the dispatcher, free the worker, then persist the result.

    Database failures are logged with a traceback and swallowed, so they do not
    end the worker's connection.
    """
    task_id = data.get("task_id", "")
    result = data.get("result")
    error = data.get("error")
    task_dispatcher.complete_task(task_id, result=result, error=error)
    # Release the worker's current-task slot.
    worker_manager.set_current_task(worker_id, None)
    logger.info("任务 %s 已完成", task_id)
    try:
        _persist_task_result(worker_id, task_id, result, error)
    except Exception as db_err:
        logger.error("任务 %s 写入数据库失败: %s", task_id, db_err, exc_info=True)


@app.websocket(config.WS_PATH)
async def ws_endpoint(ws: WebSocket):
    """Serve one worker's WebSocket connection.

    Flow:
      1. Accept the connection.
      2. Wait up to 30 s for the first message, which must be a ``register``
         object carrying a non-empty ``worker_id``.
      3. Loop on messages: heartbeat, browser-list update, task progress and
         task result. Malformed or unknown messages are logged and skipped.
      4. Unregister the worker when the connection ends.

    Close codes: 4001 first message is not register, 4002 empty worker_id,
    4003 registration timeout, 1011 unexpected server-side error.
    """
    await ws.accept()
    worker_id: str | None = None
    try:
        # ── wait for registration ──
        raw = await asyncio.wait_for(ws.receive_json(), timeout=30)
        if not isinstance(raw, dict) or raw.get("type") != MsgType.REGISTER.value:
            await ws.send_json(make_msg(MsgType.ERROR, detail="首条消息必须是 register"))
            await ws.close(code=4001)
            return
        worker_id = raw.get("worker_id", "")
        worker_name = raw.get("worker_name", worker_id)
        browsers = raw.get("browsers", [])
        if not worker_id:
            await ws.send_json(make_msg(MsgType.ERROR, detail="worker_id 不能为空"))
            await ws.close(code=4002)
            return
        worker_manager.register(ws, worker_id, worker_name, browsers)
        await ws.send_json(make_msg(MsgType.REGISTER_ACK, worker_id=worker_id))
        logger.info("Worker %s 已连接", worker_id)

        # ── message loop ──
        while True:
            try:
                data = await ws.receive_json()
            except json.JSONDecodeError as decode_err:
                # One malformed frame should not drop (and unregister) the worker.
                logger.warning("无法解析的消息 (from %s): %s", worker_id, decode_err)
                continue
            # Non-object payloads get an empty type and fall through to the
            # unknown-message warning instead of raising AttributeError.
            msg_type = data.get("type", "") if isinstance(data, dict) else ""
            if msg_type == MsgType.HEARTBEAT.value:
                worker_manager.heartbeat(worker_id)
                await ws.send_json(make_msg(MsgType.HEARTBEAT_ACK))
            elif msg_type == MsgType.BROWSER_LIST_UPDATE.value:
                worker_manager.update_browsers(worker_id, data.get("browsers", []))
            elif msg_type == MsgType.TASK_PROGRESS.value:
                task_id = data.get("task_id", "")
                progress = data.get("progress", "")
                task_dispatcher.update_progress(task_id, progress)
                logger.info("任务 %s 进度: %s", task_id, progress)
            elif msg_type == MsgType.TASK_RESULT.value:
                _handle_task_result(worker_id, data)
            else:
                logger.warning("未知消息类型: %s (from %s)", msg_type, worker_id)
    except WebSocketDisconnect:
        logger.info("Worker %s WebSocket 断开", worker_id or "unknown")
    except asyncio.TimeoutError:
        # Only the registration wait has a timeout.
        logger.warning("WebSocket 连接超时(未在 30 秒内注册)")
        await _close_quietly(ws, 4003)
    except Exception as e:
        logger.error("WebSocket 处理异常: %s", e, exc_info=True)
        # 1011: the server hit an unexpected condition. Tell the peer so it
        # can reconnect instead of waiting on a dead session.
        await _close_quietly(ws, 1011)
    finally:
        if worker_id:
            # NOTE(review): unregister is keyed only by worker_id. If the same
            # worker reconnects before this connection's end is detected, this
            # may remove the newer registration. Confirm that
            # worker_manager.unregister guards against that.
            worker_manager.unregister(worker_id)
# ────────────────────────── Entry point ──────────────────────────
def main():
    """Run the server under uvicorn on config.HOST:config.PORT.

    The app is given to uvicorn as the import string "server.main:app", so
    uvicorn imports it by that module path.
    """
    server_options = {
        "host": config.HOST,
        "port": config.PORT,
        "log_level": "info",
    }
    uvicorn.run("server.main:app", **server_options)


if __name__ == "__main__":
    main()