209 lines
8.0 KiB
Python
209 lines
8.0 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
中央服务器入口。
|
||
启动方式: python -m server.main
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
from contextlib import asynccontextmanager
|
||
|
||
import uvicorn
|
||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
|
||
from common.protocol import MsgType, TaskType, make_msg
|
||
from server import config
|
||
from server.api.workers import router as workers_router
|
||
from server.api.tasks import router as tasks_router
|
||
from server.api.accounts import router as accounts_router
|
||
from server.core.worker_manager import worker_manager
|
||
from server.core.task_dispatcher import task_dispatcher
|
||
from server import db
|
||
from tunnel.server import TunnelServer
|
||
|
||
logging.basicConfig(
|
||
level=logging.INFO,
|
||
format="%(asctime)s %(name)-28s %(levelname)-5s %(message)s",
|
||
datefmt="%Y-%m-%d %H:%M:%S",
|
||
)
|
||
logger = logging.getLogger("server.main")
|
||
|
||
|
||
# ────────────────────────── Lifespan ──────────────────────────


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: bring up background services, tear them down on exit.

    Startup:
      1. ``db.init_db()`` (SQLAlchemy ORM table creation). A failure is logged
         with its traceback but does not abort startup; the server keeps
         running without database features.
      2. Spawn ``worker_manager.check_heartbeats_loop()`` as a background task.
      3. Start a ``TunnelServer`` on ``config.HOST`` with the configured
         control, stream and proxy-base ports.

    Shutdown:
      - The tunnel server is stopped even if the application exits with an
        exception (``try``/``finally`` around ``yield``).
      - The heartbeat task is cancelled and awaited. This also happens when
        the tunnel fails to start, so the task never outlives the application.

    Args:
        app: The FastAPI application (not used directly; passed in by FastAPI).
    """
    # ── startup ──
    # Initialise the database (SQLAlchemy ORM table creation).
    try:
        db.init_db()
        logger.info("数据库初始化完成")
    except Exception as e:
        # Deliberately best-effort: log and keep serving without the database.
        logger.error("数据库初始化失败: %s(服务继续运行,但数据库功能不可用)", e, exc_info=True)

    # Keep a strong reference to the task. The event loop holds only weak
    # references to tasks, so a fire-and-forget task can be garbage-collected
    # mid-run. The reference also lets shutdown cancel it.
    heartbeat_task = asyncio.create_task(worker_manager.check_heartbeats_loop())
    try:
        tunnel_server = TunnelServer(
            control_port=config.TUNNEL_CONTROL_PORT,
            stream_port=config.TUNNEL_STREAM_PORT,
            proxy_base_port=config.TUNNEL_PROXY_BASE_PORT,
            host=config.HOST,
        )
        await tunnel_server.start()
        logger.info(
            "服务器启动: http://%s:%s | 隧道: 控制 %s, 流 %s, 代理起始 %s",
            config.HOST, config.PORT,
            config.TUNNEL_CONTROL_PORT, config.TUNNEL_STREAM_PORT, config.TUNNEL_PROXY_BASE_PORT,
        )

        try:
            yield
        finally:
            # ── shutdown ──
            # Runs on normal shutdown and when the app exits with an error.
            await tunnel_server.stop()
    finally:
        # Reached even if the tunnel failed to start, so the heartbeat loop
        # is never left running on its own.
        heartbeat_task.cancel()
        try:
            await heartbeat_task
        except asyncio.CancelledError:
            pass
        except Exception:
            # The loop had already died with an error; surface it instead of
            # letting it become an unretrieved task exception.
            logger.exception("心跳巡检任务异常退出")

    logger.info("服务器已关闭")
|
||
|
||
|
||
# ────────────────────────── FastAPI App ──────────────────────────


app = FastAPI(
    lifespan=lifespan,
    title="Browser Control Server",
    version="1.0.0",
)

# NOTE(review): CORS is fully open (any origin, any method, any header).
# Acceptable for a trusted/internal deployment; consider restricting
# allow_origins if this API is reachable from untrusted networks.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Mount the REST routers; registration order is workers, tasks, accounts.
for _router in (workers_router, tasks_router, accounts_router):
    app.include_router(_router)
del _router
|
||
|
||
|
||
# ────────────────────────── Health check ──────────────────────────


@app.get("/health")
async def health():
    """Health probe.

    Returns a constant ``"ok"`` status together with the number of workers
    returned by ``worker_manager.get_all_workers()`` whose ``online`` flag is
    truthy.
    """
    online_count = sum(1 for worker in worker_manager.get_all_workers() if worker.online)
    return {"status": "ok", "workers_online": online_count}
|
||
|
||
|
||
# ────────────────────────── WebSocket endpoint ──────────────────────────


def _enum_value(value):
    """Return ``value.value`` if *value* has a ``value`` attribute, else ``str(value)``.

    Normalises task-type / status fields (enum members or plain values)
    before they are compared or written to the task log.
    """
    return value.value if hasattr(value, "value") else str(value)


def _persist_task_result(worker_id: str, task_id: str, result, error) -> None:
    """Best-effort database write for a finished task.

    Saves a task-log row for *task_id* via ``db.save_task_log``. For a
    successful ``check_login`` task whose result is a non-empty dict, it also
    upserts the account status via ``db.upsert_account_status``. Does nothing
    if ``task_dispatcher.get_task`` returns a falsy value.

    Any exception is logged and swallowed, so a database problem never drops
    the worker's WebSocket session.
    """
    try:
        task_info = task_dispatcher.get_task(task_id)
        if not task_info:
            return

        task_type_val = _enum_value(task_info.task_type)
        # Save the task log.
        db.save_task_log(
            task_id=task_id,
            task_type=task_type_val,
            worker_id=worker_id,
            status=_enum_value(task_info.status),
            params=task_info.params,
            result=result,
            error=error,
        )

        # check_login tasks also update the account status table. The
        # isinstance check avoids calling .get on a non-dict result.
        if (
            task_type_val == TaskType.CHECK_LOGIN.value
            and isinstance(result, dict)
            and result
            and not error
        ):
            db.upsert_account_status(
                worker_id=worker_id,
                browser_id=result.get("browser_id", ""),
                browser_name=result.get("browser_name", ""),
                boss_username=result.get("boss_username", ""),
                is_logged_in=result.get("is_logged_in", False),
            )
    except Exception as db_err:
        logger.error("任务 %s 写入数据库失败: %s", task_id, db_err, exc_info=True)


async def _handle_message(ws: WebSocket, worker_id: str, data: dict) -> None:
    """Dispatch one post-registration message from worker *worker_id*.

    Handles heartbeat (replies with HEARTBEAT_ACK), browser-list update,
    task progress and task result messages. Unknown types are logged as
    warnings and otherwise ignored.
    """
    msg_type = data.get("type", "")

    if msg_type == MsgType.HEARTBEAT.value:
        worker_manager.heartbeat(worker_id)
        await ws.send_json(make_msg(MsgType.HEARTBEAT_ACK))

    elif msg_type == MsgType.BROWSER_LIST_UPDATE.value:
        worker_manager.update_browsers(worker_id, data.get("browsers", []))

    elif msg_type == MsgType.TASK_PROGRESS.value:
        task_id = data.get("task_id", "")
        progress = data.get("progress", "")
        task_dispatcher.update_progress(task_id, progress)
        logger.info("任务 %s 进度: %s", task_id, progress)

    elif msg_type == MsgType.TASK_RESULT.value:
        task_id = data.get("task_id", "")
        result = data.get("result")
        error = data.get("error")
        task_dispatcher.complete_task(task_id, result=result, error=error)
        # Release the worker's current-task slot.
        worker_manager.set_current_task(worker_id, None)
        logger.info("任务 %s 已完成", task_id)
        # Write the result to the database (best effort).
        _persist_task_result(worker_id, task_id, result, error)

    else:
        logger.warning("未知消息类型: %s (from %s)", msg_type, worker_id)


@app.websocket(config.WS_PATH)
async def ws_endpoint(ws: WebSocket):
    """Handle one worker's WebSocket session.

    Protocol:
      1. Accept the connection.
      2. Expect a ``register`` message within 30 seconds, otherwise close:
         4001 if the first message is not a register object, 4002 if
         ``worker_id`` is empty, 4003 on timeout.
      3. Loop receiving messages (heartbeat, browser-list update, task
         progress, task result) until the socket disconnects. Frames that
         are not JSON objects are logged and skipped.
      4. Unregister the worker when the session ends.
    """
    await ws.accept()
    worker_id: str | None = None

    try:
        # ── wait for registration ──
        raw = await asyncio.wait_for(ws.receive_json(), timeout=30)
        # A non-object payload (e.g. a JSON list) is rejected like any other
        # wrong first message instead of raising AttributeError on .get.
        if not isinstance(raw, dict) or raw.get("type") != MsgType.REGISTER.value:
            await ws.send_json(make_msg(MsgType.ERROR, detail="首条消息必须是 register"))
            await ws.close(code=4001)
            return

        worker_id = raw.get("worker_id", "")
        worker_name = raw.get("worker_name", worker_id)
        browsers = raw.get("browsers", [])

        if not worker_id:
            await ws.send_json(make_msg(MsgType.ERROR, detail="worker_id 不能为空"))
            await ws.close(code=4002)
            return

        worker_manager.register(ws, worker_id, worker_name, browsers)
        await ws.send_json(make_msg(MsgType.REGISTER_ACK, worker_id=worker_id))
        logger.info("Worker %s 已连接", worker_id)

        # ── message loop ──
        while True:
            data = await ws.receive_json()
            if not isinstance(data, dict):
                # One malformed frame should not tear down a healthy session.
                logger.warning("忽略非对象消息 (type=%s, from %s)", type(data).__name__, worker_id)
                continue
            await _handle_message(ws, worker_id, data)

    except WebSocketDisconnect:
        logger.info("Worker %s WebSocket 断开", worker_id or "unknown")
    except asyncio.TimeoutError:
        logger.warning("WebSocket 连接超时(未在 30 秒内注册)")
        await ws.close(code=4003)
    except Exception as e:
        logger.error("WebSocket 处理异常: %s", e, exc_info=True)
    finally:
        if worker_id:
            # NOTE(review): unregister() is keyed only by worker_id. If the
            # same worker reconnects before this session's cleanup runs, this
            # may remove the *new* connection. Confirm whether
            # worker_manager.unregister guards against that.
            worker_manager.unregister(worker_id)
|
||
|
||
|
||
# ────────────────────────── Entry point ──────────────────────────


def main():
    """Serve the application with uvicorn on ``config.HOST``:``config.PORT``.

    The app is handed to uvicorn as the import string ``"server.main:app"``
    rather than as the object, so uvicorn imports the module itself.
    """
    app_ref = "server.main:app"
    uvicorn.run(
        app_ref,
        log_level="info",
        port=config.PORT,
        host=config.HOST,
    )


if __name__ == "__main__":
    main()
|