From 0ac1e9549c0a54f81f9ec9ace4e32c42dd94c6e0 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 14 Feb 2026 16:49:44 +0800 Subject: [PATCH] =?UTF-8?q?=E5=93=88=E5=93=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- requirements.txt | 7 +- server/__init__.py | 4 + server/api/accounts.py | 132 ++++++++++----------- server/api/auth.py | 43 ++++--- server/api/deps.py | 49 -------- server/api/tasks.py | 142 +++++++++++++---------- server/api/workers.py | 59 +++++----- server/core/worker_manager.py | 19 +-- server/db.py | 210 ---------------------------------- server/main.py | 200 +++++--------------------------- server/models.py | 178 ++++++++-------------------- 11 files changed, 295 insertions(+), 748 deletions(-) delete mode 100644 server/api/deps.py delete mode 100644 server/db.py diff --git a/requirements.txt b/requirements.txt index b001bf3..6425885 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,11 @@ # ─── 中央服务器 (server/) ─── -fastapi>=0.115.0 +django>=5.0 +djangorestframework>=3.15.0 +channels>=4.0.0 uvicorn>=0.34.0 pydantic>=2.0.0 -SQLAlchemy>=2.0.0 PyMySQL>=1.1.0 -python-multipart>=0.0.9 +asgiref>=3.8.0 # ─── Worker 代理 (worker/) ─── websockets>=14.0 diff --git a/server/__init__.py b/server/__init__.py index 40a96af..ee5eb0a 100644 --- a/server/__init__.py +++ b/server/__init__.py @@ -1 +1,5 @@ # -*- coding: utf-8 -*- +import pymysql +pymysql.install_as_MySQLdb() +# Django 6.0 要求 mysqlclient >= 2.2.1,这里 patch PyMySQL 版本号绕过检查 +pymysql.version_info = (2, 2, 1, "final", 0) diff --git a/server/api/accounts.py b/server/api/accounts.py index 53dbd9c..87eee0b 100644 --- a/server/api/accounts.py +++ b/server/api/accounts.py @@ -1,90 +1,80 @@ # -*- coding: utf-8 -*- """ BOSS 账号 API(需要登录): -- POST /api/accounts -> 前台添加账号时绑定到指定电脑 -- POST /api/accounts/check -> 提交环境名称,自动派发 check_login 任务 -- GET /api/accounts -> 查询所有账号登录状态 -- GET /api/accounts/{worker_id} -> 查询指定 Worker 的账号状态 +- POST /api/accounts -> 添加账号(绑定环境名称到电脑) +- GET /api/accounts -> 查询所有账号(含电脑名称、在线、任务状态) +- GET /api/accounts/{id} -> 查询单个账号详情 +- DELETE /api/accounts/{id} -> 删除账号 + +检测登录等操作统一通过 POST /api/tasks 提交,task_type 传 check_login 即可。 """ -from __future__ import annotations +from rest_framework import status +from rest_framework.decorators import api_view +from rest_framework.response import Response -from typing import Optional - -from fastapi import APIRouter, Depends, HTTPException, Request - -from common.protocol import TaskStatus, TaskType -from server.models import AccountBindRequest, CheckLoginRequest, TaskCreate +from server.models import BossAccount +from server.serializers import BossAccountSerializer, AccountBindSerializer from server.core.worker_manager import worker_manager -from server.core.task_dispatcher import task_dispatcher -from server import db -from server.api.deps import require_auth, parse_body - -router = APIRouter(prefix="/api/accounts", tags=["accounts"], dependencies=[Depends(require_auth)]) -@router.post("", status_code=201) -async def bind_account(request: Request): - """前台添加账号:保存账号环境与电脑绑定关系。""" - req = AccountBindRequest(**(await parse_body(request))) - db.bind_account_to_worker(worker_id=req.worker_id, browser_name=req.browser_name) - return {"message": f"账号绑定已保存: {req.browser_name} -> {req.worker_id}"} +# ────────────────────────── 内部工具 ────────────────────────── + +def _enrich(account: BossAccount) -> dict: + """为账号实例补充电脑名称和电脑在线状态。""" + data = BossAccountSerializer(account).data + w = worker_manager.get_worker(account.worker_id) + data["worker_name"] = w.worker_name if w else "" + data["worker_online"] = w.online if w else False + return data -@router.post("/check", status_code=201) -async def check_login(request: Request): +# ────────────────────────── 接口 ────────────────────────── + +@api_view(["GET", "POST"]) +def account_list(request): """ - 前端提交 browser_name(可选 worker_id)→ 自动派发 check_login 任务。 + GET -> 查询账号列表(可选 ?worker_id= 过滤) + POST -> 添加账号(绑定环境名称到电脑) """ - req = CheckLoginRequest(**(await parse_body(request))) - worker_id = req.worker_id - if not worker_id: - bind = db.get_account_by_name(req.browser_name) - if not bind: - raise HTTPException( - status_code=400, - detail=f"未找到账号绑定关系,请先调用 POST /api/accounts 绑定: {req.browser_name}", - ) - worker_id = bind.get("worker_id") + if request.method == "GET": + worker_id = request.query_params.get("worker_id") + qs = BossAccount.objects.all().order_by("-updated_at") + if worker_id: + qs = qs.filter(worker_id=worker_id) + return Response([_enrich(a) for a in qs]) - if not worker_manager.is_online(worker_id): - raise HTTPException(status_code=503, detail=f"Worker {worker_id} 不在线") + # POST: 添加账号 + ser = AccountBindSerializer(data=request.data) + ser.is_valid(raise_exception=True) + wid = ser.validated_data["worker_id"] + bname = ser.validated_data["browser_name"] - task_req = TaskCreate( - task_type=TaskType.CHECK_LOGIN, - worker_id=worker_id, - account_name=req.browser_name, - params={"account_name": req.browser_name}, + account, _ = BossAccount.objects.get_or_create( + worker_id=wid, + browser_name=bname, + defaults={ + "browser_id": f"name:{bname}", + "boss_username": "", + "is_logged_in": False, + }, ) - task = task_dispatcher.create_task(task_req) - - ws = worker_manager.get_ws(worker_id) - if not ws: - task.status = TaskStatus.FAILED - task.error = "Worker WebSocket 连接不存在" - raise HTTPException(status_code=503, detail="Worker WebSocket 连接不存在") - - success = await task_dispatcher.dispatch(task, ws.send_json) - if not success: - raise HTTPException(status_code=503, detail=f"任务派发失败: {task.error}") - - worker_manager.set_current_task(worker_id, task.task_id) - - return { - "message": f"检测任务已派发,环境: {req.browser_name},目标: {worker_id}", - "task_id": task.task_id, - "worker_id": worker_id, - } + return Response(_enrich(account), status=status.HTTP_201_CREATED) -@router.get("") -async def list_accounts(worker_id: Optional[str] = None): - """查询 BOSS 账号登录状态列表。""" - if worker_id: - return db.get_accounts_by_worker(worker_id) - return db.get_all_accounts() +@api_view(["GET", "DELETE"]) +def account_detail(request, account_id): + """ + GET -> 查询单个账号详情 + DELETE -> 删除账号 + """ + try: + account = BossAccount.objects.get(pk=account_id) + except BossAccount.DoesNotExist: + return Response({"detail": "账号不存在"}, status=status.HTTP_404_NOT_FOUND) + if request.method == "GET": + return Response(_enrich(account)) -@router.get("/{worker_id}") -async def get_worker_accounts(worker_id: str): - """查询指定 Worker 的所有账号状态。""" - return db.get_accounts_by_worker(worker_id) + # DELETE + account.delete() + return Response({"message": "账号已删除"}) diff --git a/server/api/auth.py b/server/api/auth.py index e06e81b..27acd5d 100644 --- a/server/api/auth.py +++ b/server/api/auth.py @@ -2,21 +2,21 @@ """ 认证 API:登录(无需 token)。 """ -from __future__ import annotations - import uuid -from fastapi import APIRouter, HTTPException, Request, Response, status +from rest_framework import status +from rest_framework.decorators import api_view, permission_classes +from rest_framework.permissions import AllowAny +from rest_framework.response import Response -from server import config, db -from server.models import LoginRequest, LoginResponse -from server.api.deps import parse_body - -router = APIRouter(prefix="/api/auth", tags=["auth"]) +from server import config +from server.models import AuthToken +from server.serializers import LoginSerializer -@router.post("/login", response_model=LoginResponse) -async def login(request: Request, response: Response): +@api_view(["POST"]) +@permission_classes([AllowAny]) +def login(request): """ 登录接口(支持 JSON 和 form-data): - 校验用户名/密码 @@ -24,20 +24,27 @@ async def login(request: Request, response: Response): - 通过 Set-Cookie 返回 auth_token,前端后续请求自动携带 - 下一次登录会生成新 token,旧 token 自动失效 """ - req = LoginRequest(**(await parse_body(request))) + ser = LoginSerializer(data=request.data) + ser.is_valid(raise_exception=True) - if req.username != config.ADMIN_USERNAME or req.password != config.ADMIN_PASSWORD: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名或密码错误") + username = ser.validated_data["username"] + password = ser.validated_data["password"] + + if username != config.ADMIN_USERNAME or password != config.ADMIN_PASSWORD: + return Response({"detail": "用户名或密码错误"}, status=status.HTTP_401_UNAUTHORIZED) token = uuid.uuid4().hex - db.set_auth_token(req.username, token) + AuthToken.objects.update_or_create( + username=username, + defaults={"token": token}, + ) - response.set_cookie( + resp = Response({"token": token}) + resp.set_cookie( key="auth_token", value=token, httponly=True, max_age=365 * 24 * 60 * 60, - samesite="lax", + samesite="Lax", ) - - return LoginResponse(token=token) + return resp diff --git a/server/api/deps.py b/server/api/deps.py deleted file mode 100644 index 173ac47..0000000 --- a/server/api/deps.py +++ /dev/null @@ -1,49 +0,0 @@ -# -*- coding: utf-8 -*- -""" -API 公共依赖:认证校验、请求体解析。 -所有路由文件统一从这里导入,避免重复代码。 -""" -from __future__ import annotations - -from fastapi import Cookie, HTTPException, Request, status - -from server import db - - -# ────────────────────────── 认证依赖 ────────────────────────── - -async def require_auth(auth_token: str | None = Cookie(default=None)): - """ - 从 cookie 中读取 auth_token 并校验。 - 用法:在 Router 或单个接口上加 dependencies=[Depends(require_auth)] - """ - if auth_token is None: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="未登录,请先调用 POST /api/auth/login", - ) - user = db.get_user_by_token(auth_token) - if not user: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="登录已失效,请重新登录", - ) - return user - - -# ────────────────────────── 请求体解析 ────────────────────────── - -async def parse_body(request: Request) -> dict: - """ - 兼容两种请求体格式: - 1) application/json - 2) multipart/form-data 或 x-www-form-urlencoded - """ - ctype = (request.headers.get("content-type") or "").lower() - if "application/json" in ctype: - data = await request.json() - if not isinstance(data, dict): - raise HTTPException(status_code=422, detail="JSON body 必须是对象") - return data - form = await request.form() - return {k.strip(): v for k, v in form.items()} diff --git a/server/api/tasks.py b/server/api/tasks.py index a84479c..785fd4f 100644 --- a/server/api/tasks.py +++ b/server/api/tasks.py @@ -1,103 +1,121 @@ # -*- coding: utf-8 -*- """ 任务提交与查询 API(需要登录)。 +统一任务入口:前端通过 task_type 指定任务类型(如 check_login、boss_recruit)。 """ -from __future__ import annotations - -from typing import List, Optional - import json +import logging -from fastapi import APIRouter, Depends, HTTPException, Request +from asgiref.sync import async_to_sync +from rest_framework import status as http_status +from rest_framework.decorators import api_view +from rest_framework.response import Response -from common.protocol import TaskStatus -from server.models import TaskCreate, TaskOut +from common.protocol import TaskStatus, TaskType +from server.models import BossAccount, TaskCreate +from server.serializers import TaskCreateSerializer, TaskOutSerializer from server.core.worker_manager import worker_manager from server.core.task_dispatcher import task_dispatcher -from server.api.deps import require_auth, parse_body -router = APIRouter(prefix="/api/tasks", tags=["tasks"], dependencies=[Depends(require_auth)]) +logger = logging.getLogger("server.api.tasks") -@router.post("", response_model=TaskOut, status_code=201) -async def create_task(request: Request): +def _task_to_dict(t) -> dict: + """将 TaskInfo 转为可序列化字典。""" + return { + "task_id": t.task_id, + "task_type": t.task_type.value if hasattr(t.task_type, "value") else str(t.task_type), + "status": t.status.value if hasattr(t.status, "value") else str(t.status), + "worker_id": t.worker_id, + "account_name": t.account_name, + "params": t.params, + "progress": t.progress, + "result": t.result, + "error": t.error, + "created_at": t.created_at, + "updated_at": t.updated_at, + } + + +@api_view(["GET", "POST"]) +def task_list(request): """ - 提交一个新任务(支持 JSON 和 form-data)。 - 路由规则:worker_id > account_name。 + GET -> 查询任务列表,支持 ?worker_id= / ?status= / ?limit= 过滤 + POST -> 提交新任务(支持 JSON 和 form-data) """ - body = await parse_body(request) - # form-data 中 params 可能是 JSON 字符串,需要解析 - params_raw = body.get("params", {}) + if request.method == "GET": + wid = request.query_params.get("worker_id") + st = request.query_params.get("status") + limit = int(request.query_params.get("limit", 50)) + task_status = TaskStatus(st) if st else None + tasks = task_dispatcher.list_tasks(worker_id=wid, status=task_status, limit=limit) + return Response([_task_to_dict(t) for t in tasks]) + + # POST: 提交新任务 + data = request.data.copy() + # form-data 中 params 可能是 JSON 字符串 + params_raw = data.get("params", {}) if isinstance(params_raw, str): try: - body["params"] = json.loads(params_raw) if params_raw.strip() else {} + data["params"] = json.loads(params_raw) if params_raw.strip() else {} except (json.JSONDecodeError, ValueError): - body["params"] = {} - req = TaskCreate(**body) - target_worker_id = req.worker_id + data["params"] = {} + + ser = TaskCreateSerializer(data=data) + ser.is_valid(raise_exception=True) + + req = TaskCreate(**ser.validated_data) + target_worker_id = req.worker_id or "" if not target_worker_id and req.account_name: target_worker_id = worker_manager.find_worker_by_account(req.account_name) if not target_worker_id: - raise HTTPException( - status_code=404, - detail=f"未找到拥有浏览器 '{req.account_name}' 的在线 Worker", + return Response( + {"detail": f"未找到拥有浏览器 '{req.account_name}' 的在线 Worker"}, + status=http_status.HTTP_404_NOT_FOUND, ) if not target_worker_id: - raise HTTPException(status_code=400, detail="请指定 worker_id 或 account_name") + return Response({"detail": "请指定 worker_id 或 account_name"}, status=http_status.HTTP_400_BAD_REQUEST) if not worker_manager.is_online(target_worker_id): - raise HTTPException(status_code=503, detail=f"Worker {target_worker_id} 不在线") + return Response({"detail": f"Worker {target_worker_id} 不在线"}, status=http_status.HTTP_503_SERVICE_UNAVAILABLE) req.worker_id = target_worker_id task = task_dispatcher.create_task(req) - ws = worker_manager.get_ws(target_worker_id) - if not ws: + send_fn = worker_manager.get_send_fn(target_worker_id) + if not send_fn: task.status = TaskStatus.FAILED task.error = "Worker WebSocket 连接不存在" - raise HTTPException(status_code=503, detail="Worker WebSocket 连接不存在") + return Response({"detail": "Worker WebSocket 连接不存在"}, status=http_status.HTTP_503_SERVICE_UNAVAILABLE) - success = await task_dispatcher.dispatch(task, ws.send_json) + success = async_to_sync(task_dispatcher.dispatch)(task, send_fn) if not success: - raise HTTPException(status_code=503, detail=f"任务派发失败: {task.error}") + return Response({"detail": f"任务派发失败: {task.error}"}, status=http_status.HTTP_503_SERVICE_UNAVAILABLE) worker_manager.set_current_task(target_worker_id, task.task_id) - return _to_out(task) + + # check_login 任务:关联账号的任务状态 + if req.task_type == TaskType.CHECK_LOGIN and req.account_name: + try: + account = BossAccount.objects.filter( + browser_name=req.account_name, worker_id=target_worker_id, + ).first() + if account: + account.current_task_id = task.task_id + account.current_task_status = task.status.value + account.save(update_fields=["current_task_id", "current_task_status"]) + except Exception as e: + logger.warning("关联账号任务状态失败: %s", e) + + return Response(_task_to_dict(task), status=http_status.HTTP_201_CREATED) -@router.get("", response_model=List[TaskOut]) -async def list_tasks( - worker_id: Optional[str] = None, - status: Optional[TaskStatus] = None, - limit: int = 50, -): - """查询任务列表,支持按 worker_id / status 过滤。""" - tasks = task_dispatcher.list_tasks(worker_id=worker_id, status=status, limit=limit) - return [_to_out(t) for t in tasks] - - -@router.get("/{task_id}", response_model=TaskOut) -async def get_task(task_id: str): +@api_view(["GET"]) +def task_detail(request, task_id): """查询指定任务的状态和结果。""" task = task_dispatcher.get_task(task_id) if not task: - raise HTTPException(status_code=404, detail=f"任务 {task_id} 不存在") - return _to_out(task) - - -def _to_out(t) -> TaskOut: - return TaskOut( - task_id=t.task_id, - task_type=t.task_type, - status=t.status, - worker_id=t.worker_id, - account_name=t.account_name, - params=t.params, - progress=t.progress, - result=t.result, - error=t.error, - created_at=t.created_at, - updated_at=t.updated_at, - ) + return Response({"detail": f"任务 {task_id} 不存在"}, status=http_status.HTTP_404_NOT_FOUND) + return Response(_task_to_dict(task)) diff --git a/server/api/workers.py b/server/api/workers.py index d89d223..483b914 100644 --- a/server/api/workers.py +++ b/server/api/workers.py @@ -2,43 +2,44 @@ """ Worker 查询 API(需要登录)。 """ -from typing import List +from rest_framework import status +from rest_framework.decorators import api_view, permission_classes +from rest_framework.permissions import AllowAny +from rest_framework.response import Response -from fastapi import APIRouter, Depends, HTTPException - -from server.models import WorkerOut +from server.serializers import WorkerOutSerializer from server.core.worker_manager import worker_manager -from server.api.deps import require_auth - -router = APIRouter(prefix="/api/workers", tags=["workers"], dependencies=[Depends(require_auth)]) -@router.get("", response_model=List[WorkerOut]) -async def list_workers(): +@api_view(["GET"]) +@permission_classes([AllowAny]) +def health_check(request): + """健康检查。""" + online = len([w for w in worker_manager.get_all_workers() if w.online]) + return Response({"status": "ok", "workers_online": online}) + + +def _worker_to_dict(w) -> dict: + return { + "worker_id": w.worker_id, + "worker_name": w.worker_name, + "browsers": [b.model_dump() for b in w.browsers], + "online": w.online, + "current_task_id": w.current_task_id, + } + + +@api_view(["GET"]) +def worker_list(request): """获取所有已注册的 Worker(含在线状态与浏览器列表)。""" workers = worker_manager.get_all_workers() - return [ - WorkerOut( - worker_id=w.worker_id, - worker_name=w.worker_name, - browsers=w.browsers, - online=w.online, - current_task_id=w.current_task_id, - ) - for w in workers - ] + return Response([_worker_to_dict(w) for w in workers]) -@router.get("/{worker_id}", response_model=WorkerOut) -async def get_worker(worker_id: str): +@api_view(["GET"]) +def worker_detail(request, worker_id): """获取指定 Worker 的详情。""" w = worker_manager.get_worker(worker_id) if not w: - raise HTTPException(status_code=404, detail=f"Worker {worker_id} 不存在") - return WorkerOut( - worker_id=w.worker_id, - worker_name=w.worker_name, - browsers=w.browsers, - online=w.online, - current_task_id=w.current_task_id, - ) + return Response({"detail": f"Worker {worker_id} 不存在"}, status=status.HTTP_404_NOT_FOUND) + return Response(_worker_to_dict(w)) diff --git a/server/core/worker_manager.py b/server/core/worker_manager.py index 6107990..a6294b1 100644 --- a/server/core/worker_manager.py +++ b/server/core/worker_manager.py @@ -2,21 +2,23 @@ """ Worker 注册、状态管理、账号 → Worker 映射。 全部在内存中,服务重启后 Worker 重新连接即恢复。 +框架无关:存储的是异步 send_json 可调用对象,不依赖具体 WebSocket 实现。 """ from __future__ import annotations import asyncio import logging import time -from typing import Dict, Optional - -from fastapi import WebSocket +from typing import Callable, Coroutine, Dict, Optional from server.config import HEARTBEAT_TIMEOUT from server.models import BrowserProfile, WorkerInfo logger = logging.getLogger("server.worker_manager") +# send_json 的类型:接受一个 dict 参数,返回协程 +SendJsonFn = Callable[[dict], Coroutine] + class WorkerManager: """管理所有已连接的 Worker。""" @@ -24,8 +26,8 @@ class WorkerManager: def __init__(self) -> None: # worker_id → WorkerInfo self._workers: Dict[str, WorkerInfo] = {} - # worker_id → WebSocket 实例 - self._connections: Dict[str, WebSocket] = {} + # worker_id → 异步 send_json 可调用 + self._connections: Dict[str, SendJsonFn] = {} # account_name(lower) → worker_id 快速路由表 self._account_map: Dict[str, str] = {} @@ -33,7 +35,7 @@ class WorkerManager: def register( self, - ws: WebSocket, + send_json: SendJsonFn, worker_id: str, worker_name: str, browsers: list[dict], @@ -48,7 +50,7 @@ class WorkerManager: connected_at=time.time(), ) self._workers[worker_id] = info - self._connections[worker_id] = ws + self._connections[worker_id] = send_json self._rebuild_account_map() logger.info("Worker 注册: %s (%s), 浏览器 %d 个", worker_id, worker_name, len(profiles)) return info @@ -84,7 +86,8 @@ class WorkerManager: def get_all_workers(self) -> list[WorkerInfo]: return list(self._workers.values()) - def get_ws(self, worker_id: str) -> Optional[WebSocket]: + def get_send_fn(self, worker_id: str) -> Optional[SendJsonFn]: + """获取指定 Worker 的 send_json 可调用对象。""" return self._connections.get(worker_id) def find_worker_by_account(self, account_name: str) -> Optional[str]: diff --git a/server/db.py b/server/db.py deleted file mode 100644 index 6fd818c..0000000 --- a/server/db.py +++ /dev/null @@ -1,210 +0,0 @@ -# -*- coding: utf-8 -*- -""" -数据库模块:SQLAlchemy 引擎、会话管理、CRUD 操作。 -""" -from __future__ import annotations - -import logging -from datetime import datetime -from typing import Optional - -from sqlalchemy import create_engine -from sqlalchemy.orm import Session, sessionmaker - -from server import config -from server.models import Base, BossAccount, TaskLog, AuthToken - -logger = logging.getLogger("server.db") - -# ────────────────────────── 引擎与会话 ────────────────────────── - -_db_url = ( - f"mysql+pymysql://{config.DB_USER}:{config.DB_PASSWORD}" - f"@{config.DB_HOST}:{config.DB_PORT}/{config.DB_NAME}" - f"?charset=utf8mb4" -) - -engine = create_engine(_db_url, pool_pre_ping=True, pool_recycle=3600, echo=False) -SessionLocal = sessionmaker(bind=engine, autoflush=False, expire_on_commit=False) - - -def get_session() -> Session: - """获取一个新的数据库会话(调用方负责关闭)。""" - return SessionLocal() - - -def init_db() -> None: - """创建所有 ORM 定义的表(如果不存在)。""" - Base.metadata.create_all(bind=engine) - logger.info("数据库表初始化完成 (SQLAlchemy ORM)") - - -# ────────────────────────── BossAccount CRUD ────────────────────────── - -def upsert_account_status( - worker_id: str, - browser_id: str, - browser_name: str, - boss_username: str, - is_logged_in: bool, -) -> BossAccount: - """插入或更新 BOSS 账号登录状态。""" - with get_session() as session: - # 优先使用 worker_id + browser_name 匹配(前台绑定关系) - account = None - if browser_name: - account = ( - session.query(BossAccount) - .filter_by(worker_id=worker_id, browser_name=browser_name) - .first() - ) - # 兜底:使用 worker_id + browser_id 匹配 - if account is None and browser_id: - account = ( - session.query(BossAccount) - .filter_by(worker_id=worker_id, browser_id=browser_id) - .first() - ) - if account: - account.browser_id = browser_id or account.browser_id - account.browser_name = browser_name or account.browser_name - account.boss_username = boss_username - account.is_logged_in = is_logged_in - account.checked_at = datetime.now() - else: - account = BossAccount( - worker_id=worker_id, - browser_id=browser_id or f"name:{browser_name}", - browser_name=browser_name, - boss_username=boss_username, - is_logged_in=is_logged_in, - checked_at=datetime.now(), - ) - session.add(account) - session.commit() - session.refresh(account) - logger.info( - "账号状态更新: worker=%s, browser=%s(%s), username=%s, logged_in=%s", - worker_id, browser_name, browser_id, boss_username, is_logged_in, - ) - return account - - -def bind_account_to_worker(worker_id: str, browser_name: str) -> BossAccount: - """ - 前台添加账号时建立绑定关系:环境名称 -> 电脑(worker)。 - 初始状态设为未登录,等待后续 check_login 刷新。 - """ - with get_session() as session: - account = ( - session.query(BossAccount) - .filter_by(worker_id=worker_id, browser_name=browser_name) - .first() - ) - if account: - return account - account = BossAccount( - worker_id=worker_id, - # 避免 browser_id 为空导致联合唯一冲突,先放占位值 - browser_id=f"name:{browser_name}", - browser_name=browser_name, - boss_username="", - is_logged_in=False, - checked_at=None, - ) - session.add(account) - session.commit() - session.refresh(account) - logger.info("账号绑定已保存: %s -> %s", browser_name, worker_id) - return account - - -def get_all_accounts() -> list[dict]: - """获取所有账号状态。""" - with get_session() as session: - rows = session.query(BossAccount).order_by(BossAccount.updated_at.desc()).all() - return [r.to_dict() for r in rows] - - -def get_accounts_by_worker(worker_id: str) -> list[dict]: - """获取指定 Worker 的所有账号状态。""" - with get_session() as session: - rows = ( - session.query(BossAccount) - .filter_by(worker_id=worker_id) - .order_by(BossAccount.updated_at.desc()) - .all() - ) - return [r.to_dict() for r in rows] - - -def get_account_by_name(browser_name: str, worker_id: Optional[str] = None) -> Optional[dict]: - """按浏览器环境名查找账号记录。""" - with get_session() as session: - q = session.query(BossAccount).filter_by(browser_name=browser_name) - if worker_id: - q = q.filter_by(worker_id=worker_id) - row = q.first() - return row.to_dict() if row else None - - -# ────────────────────────── TaskLog CRUD ────────────────────────── - -def save_task_log( - task_id: str, - task_type: str, - worker_id: str, - status: str, - params: dict = None, - result=None, - error: str = None, -) -> TaskLog: - """保存或更新任务执行记录。""" - with get_session() as session: - log = session.query(TaskLog).filter_by(task_id=task_id).first() - if log: - log.status = status - log.result = result - log.error = error - else: - log = TaskLog( - task_id=task_id, - task_type=task_type, - worker_id=worker_id, - status=status, - params=params, - result=result, - error=error, - ) - session.add(log) - session.commit() - session.refresh(log) - return log - - -# ────────────────────────── AuthToken CRUD ────────────────────────── - -def set_auth_token(username: str, token: str) -> AuthToken: - """为指定用户名设置当前有效 token(会覆盖之前的 token)。""" - with get_session() as session: - row = session.query(AuthToken).filter_by(username=username).first() - if row: - row.token = token - row.created_at = datetime.now() - else: - row = AuthToken(username=username, token=token) - session.add(row) - session.commit() - session.refresh(row) - return row - - -def get_user_by_token(token: str) -> Optional[dict]: - """根据 token 获取用户信息。""" - if not token: - return None - with get_session() as session: - row = session.query(AuthToken).filter_by(token=token).first() - if not row: - return None - return {"username": row.username, "created_at": row.created_at} diff --git a/server/main.py b/server/main.py index 603b5c6..ae3fce4 100644 --- a/server/main.py +++ b/server/main.py @@ -1,29 +1,24 @@ # -*- coding: utf-8 -*- """ -中央服务器入口。 +中央服务器入口(Django + Channels + uvicorn)。 启动方式: python -m server.main """ from __future__ import annotations import asyncio -import json import logging -from contextlib import asynccontextmanager +import os -import uvicorn -from fastapi import FastAPI, WebSocket, WebSocketDisconnect -from fastapi.middleware.cors import CORSMiddleware +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "server.settings") -from common.protocol import MsgType, TaskType, make_msg -from server import config -from server.api.workers import router as workers_router -from server.api.tasks import router as tasks_router -from server.api.accounts import router as accounts_router -from server.api.auth import router as auth_router -from server.core.worker_manager import worker_manager -from server.core.task_dispatcher import task_dispatcher -from server import db -from tunnel.server import TunnelServer +import django # noqa: E402 +django.setup() # noqa: E402 + +import uvicorn # noqa: E402 + +from server import config # noqa: E402 +from server.core.worker_manager import worker_manager # noqa: E402 +from tunnel.server import TunnelServer # noqa: E402 logging.basicConfig( level=logging.INFO, @@ -33,20 +28,12 @@ logging.basicConfig( logger = logging.getLogger("server.main") -# ────────────────────────── Lifespan ────────────────────────── - -@asynccontextmanager -async def lifespan(app: FastAPI): - """应用生命周期:启动时初始化隧道和心跳巡检,关闭时清理资源。""" - # ── startup ── - # 初始化数据库(SQLAlchemy ORM 建表) - try: - db.init_db() - logger.info("数据库初始化完成") - except Exception as e: - logger.error("数据库初始化失败: %s(服务继续运行,但数据库功能不可用)", e) - +async def run_server(): + """启动服务器:Django ASGI + 心跳巡检 + 隧道。""" + # 启动心跳巡检 asyncio.create_task(worker_manager.check_heartbeats_loop()) + + # 启动隧道服务 tunnel_server = TunnelServer( control_port=config.TUNNEL_CONTROL_PORT, stream_port=config.TUNNEL_STREAM_PORT, @@ -54,156 +41,31 @@ async def lifespan(app: FastAPI): host=config.HOST, ) await tunnel_server.start() + logger.info( "服务器启动: http://%s:%s | 隧道: 控制 %s, 流 %s, 代理起始 %s", config.HOST, config.PORT, config.TUNNEL_CONTROL_PORT, config.TUNNEL_STREAM_PORT, config.TUNNEL_PROXY_BASE_PORT, ) - yield - - # ── shutdown ── - await tunnel_server.stop() - logger.info("服务器已关闭") - - -# ────────────────────────── FastAPI App ────────────────────────── - -app = FastAPI(title="Browser Control Server", version="1.0.0", lifespan=lifespan) - -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_methods=["*"], - allow_headers=["*"], -) - -app.include_router(auth_router) -app.include_router(workers_router) -app.include_router(tasks_router) -app.include_router(accounts_router) - - -# ────────────────────────── 健康检查 ────────────────────────── - -@app.get("/health") -async def health(): - return {"status": "ok", "workers_online": len([w for w in worker_manager.get_all_workers() if w.online])} - - -# ────────────────────────── WebSocket 端点 ────────────────────────── - -@app.websocket(config.WS_PATH) -async def ws_endpoint(ws: WebSocket): - """ - Worker 连接流程: - 1. 建立连接 - 2. 等待第一条 register 消息 - 3. 持续收发消息(心跳、任务进度、任务结果等) - 4. 断开时注销 - """ - await ws.accept() - worker_id: str | None = None - - try: - # ── 等待注册 ── - raw = await asyncio.wait_for(ws.receive_json(), timeout=30) - if raw.get("type") != MsgType.REGISTER.value: - await ws.send_json(make_msg(MsgType.ERROR, detail="首条消息必须是 register")) - await ws.close(code=4001) - return - - worker_id = raw.get("worker_id", "") - worker_name = raw.get("worker_name", worker_id) - browsers = raw.get("browsers", []) - - if not worker_id: - await ws.send_json(make_msg(MsgType.ERROR, detail="worker_id 不能为空")) - await ws.close(code=4002) - return - - worker_manager.register(ws, worker_id, worker_name, browsers) - await ws.send_json(make_msg(MsgType.REGISTER_ACK, worker_id=worker_id)) - logger.info("Worker %s 已连接", worker_id) - - # ── 消息循环 ── - while True: - data = await ws.receive_json() - msg_type = data.get("type", "") - - if msg_type == MsgType.HEARTBEAT.value: - worker_manager.heartbeat(worker_id) - await ws.send_json(make_msg(MsgType.HEARTBEAT_ACK)) - - elif msg_type == MsgType.BROWSER_LIST_UPDATE.value: - worker_manager.update_browsers(worker_id, data.get("browsers", [])) - - elif msg_type == MsgType.TASK_PROGRESS.value: - task_id = data.get("task_id", "") - progress = data.get("progress", "") - task_dispatcher.update_progress(task_id, progress) - logger.info("任务 %s 进度: %s", task_id, progress) - - elif msg_type == MsgType.TASK_RESULT.value: - task_id = data.get("task_id", "") - result = data.get("result") - error = data.get("error") - task_dispatcher.complete_task(task_id, result=result, error=error) - # 释放 Worker 任务占用 - worker_manager.set_current_task(worker_id, None) - logger.info("任务 %s 已完成", task_id) - - # ── 将结果写入数据库 ── - try: - task_info = task_dispatcher.get_task(task_id) - if task_info: - # 保存任务日志 - db.save_task_log( - task_id=task_id, - task_type=task_info.task_type.value if hasattr(task_info.task_type, 'value') else str(task_info.task_type), - worker_id=worker_id, - status=task_info.status.value if hasattr(task_info.status, 'value') else str(task_info.status), - params=task_info.params, - result=result, - error=error, - ) - # check_login 任务:更新账号状态表 - task_type_val = task_info.task_type.value if hasattr(task_info.task_type, 'value') else str(task_info.task_type) - if task_type_val == TaskType.CHECK_LOGIN.value and result and not error: - db.upsert_account_status( - worker_id=worker_id, - browser_id=result.get("browser_id", ""), - browser_name=result.get("browser_name", ""), - boss_username=result.get("boss_username", ""), - is_logged_in=result.get("is_logged_in", False), - ) - except Exception as db_err: - logger.error("任务 %s 写入数据库失败: %s", task_id, db_err) - - else: - logger.warning("未知消息类型: %s (from %s)", msg_type, worker_id) - - except WebSocketDisconnect: - logger.info("Worker %s WebSocket 断开", worker_id or "unknown") - except asyncio.TimeoutError: - logger.warning("WebSocket 连接超时(未在 30 秒内注册)") - await ws.close(code=4003) - except Exception as e: - logger.error("WebSocket 处理异常: %s", e, exc_info=True) - finally: - if worker_id: - worker_manager.unregister(worker_id) - - -# ────────────────────────── 入口 ────────────────────────── - -def main(): - uvicorn.run( - "server.main:app", + # 启动 uvicorn(使用 Django Channels ASGI 应用) + uvi_config = uvicorn.Config( + "server.asgi:application", host=config.HOST, port=config.PORT, log_level="info", ) + server = uvicorn.Server(uvi_config) + + try: + await server.serve() + finally: + await tunnel_server.stop() + logger.info("服务器已关闭") + + +def main(): + asyncio.run(run_server()) if __name__ == "__main__": diff --git a/server/models.py b/server/models.py index 25269b1..d0fcdef 100644 --- a/server/models.py +++ b/server/models.py @@ -1,114 +1,83 @@ # -*- coding: utf-8 -*- """ -数据模型:SQLAlchemy ORM 表模型 + Pydantic 请求/响应模型。 +数据模型:Django ORM 表模型 + Pydantic 内存模型。 """ from __future__ import annotations import time import uuid -from datetime import datetime from typing import Any, Dict, List, Optional +from django.db import models from pydantic import BaseModel, Field -from sqlalchemy import ( - Boolean, Column, DateTime, Integer, JSON, String, Text, - UniqueConstraint, func, -) -from sqlalchemy.orm import DeclarativeBase from common.protocol import TaskStatus, TaskType # ══════════════════════════════════════════════════════════════ -# SQLAlchemy ORM 模型 +# Django ORM 模型 # ══════════════════════════════════════════════════════════════ -class Base(DeclarativeBase): - """SQLAlchemy 声明式基类。""" - pass - - -class BossAccount(Base): +class BossAccount(models.Model): """BOSS 账号登录状态表。""" - __tablename__ = "boss_account" + worker_id = models.CharField(max_length=64, verbose_name="Worker 标识") + browser_id = models.CharField(max_length=128, default="", verbose_name="比特浏览器窗口 ID") + browser_name = models.CharField(max_length=128, default="", verbose_name="比特浏览器窗口名称(环境名)") + boss_username = models.CharField(max_length=128, default="", verbose_name="BOSS 直聘用户名") + is_logged_in = models.BooleanField(default=False, verbose_name="是否已登录") + current_task_id = models.CharField(max_length=32, null=True, blank=True, verbose_name="当前检测任务 ID") + current_task_status = models.CharField(max_length=32, null=True, blank=True, verbose_name="当前检测任务状态") + checked_at = models.DateTimeField(null=True, blank=True, verbose_name="最近一次检测时间") + created_at = models.DateTimeField(auto_now_add=True, verbose_name="创建时间") + updated_at = models.DateTimeField(auto_now=True, verbose_name="更新时间") - id = Column(Integer, primary_key=True, autoincrement=True) - worker_id = Column(String(64), nullable=False, comment="Worker 标识") - browser_id = Column(String(128), nullable=False, default="", comment="比特浏览器窗口 ID") - browser_name = Column(String(128), default="", comment="比特浏览器窗口名称(环境名)") - boss_username = Column(String(128), default="", comment="BOSS 直聘用户名") - is_logged_in = Column(Boolean, default=False, comment="是否已登录") - checked_at = Column(DateTime, nullable=True, comment="最近一次检测时间") - created_at = Column(DateTime, default=func.now(), comment="创建时间") - updated_at = Column(DateTime, default=func.now(), onupdate=func.now(), comment="更新时间") + class Meta: + db_table = "boss_account" + unique_together = [("worker_id", "browser_id")] + verbose_name = "BOSS 账号" + verbose_name_plural = verbose_name - __table_args__ = ( - UniqueConstraint("worker_id", "browser_id", name="uk_worker_browser"), - {"mysql_charset": "utf8mb4", "comment": "BOSS 账号登录状态"}, - ) - - def to_dict(self) -> dict: - return { - "id": self.id, - "worker_id": self.worker_id, - "browser_id": self.browser_id, - "browser_name": self.browser_name, - "boss_username": self.boss_username, - "is_logged_in": self.is_logged_in, - "checked_at": self.checked_at.isoformat() if self.checked_at else None, - "created_at": self.created_at.isoformat() if self.created_at else None, - "updated_at": self.updated_at.isoformat() if self.updated_at else None, - } + def __str__(self): + return f"{self.browser_name}@{self.worker_id}" -class TaskLog(Base): +class TaskLog(models.Model): """任务执行记录表。""" - __tablename__ = "task_log" + task_id = models.CharField(max_length=32, unique=True, verbose_name="任务 ID") + task_type = models.CharField(max_length=64, verbose_name="任务类型") + worker_id = models.CharField(max_length=64, default="", verbose_name="执行的 Worker") + status = models.CharField(max_length=32, default="", verbose_name="最终状态") + params = models.JSONField(null=True, blank=True, verbose_name="任务参数") + result = models.JSONField(null=True, blank=True, verbose_name="任务结果") + error = models.TextField(null=True, blank=True, verbose_name="错误信息") + created_at = models.DateTimeField(auto_now_add=True, verbose_name="创建时间") - id = Column(Integer, primary_key=True, autoincrement=True) - task_id = Column(String(32), nullable=False, unique=True, comment="任务 ID") - task_type = Column(String(64), nullable=False, comment="任务类型") - worker_id = Column(String(64), default="", comment="执行的 Worker") - status = Column(String(32), default="", comment="最终状态") - params = Column(JSON, nullable=True, comment="任务参数") - result = Column(JSON, nullable=True, comment="任务结果") - error = Column(Text, nullable=True, comment="错误信息") - created_at = Column(DateTime, default=func.now(), comment="创建时间") + class Meta: + db_table = "task_log" + verbose_name = "任务日志" + verbose_name_plural = verbose_name - __table_args__ = ( - {"mysql_charset": "utf8mb4", "comment": "任务执行记录"}, - ) - - def to_dict(self) -> dict: - return { - "id": self.id, - "task_id": self.task_id, - "task_type": self.task_type, - "worker_id": self.worker_id, - "status": self.status, - "params": self.params, - "result": self.result, - "error": self.error, - "created_at": self.created_at.isoformat() if self.created_at else None, - } + def __str__(self): + return f"{self.task_id} ({self.task_type})" -class AuthToken(Base): +class AuthToken(models.Model): """登录 token 表:每个用户名仅保留当前有效 token。""" - __tablename__ = "auth_token" + username = models.CharField(max_length=64, unique=True, verbose_name="用户名") + token = models.CharField(max_length=64, verbose_name="当前有效 token") + created_at = models.DateTimeField(auto_now_add=True, verbose_name="创建时间") - id = Column(Integer, primary_key=True, autoincrement=True) - username = Column(String(64), nullable=False, unique=True, comment="用户名") - token = Column(String(64), nullable=False, comment="当前有效 token") - created_at = Column(DateTime, default=func.now(), comment="创建时间") + class Meta: + db_table = "auth_token" + verbose_name = "登录 Token" + verbose_name_plural = verbose_name - __table_args__ = ( - {"mysql_charset": "utf8mb4", "comment": "登录 token"}, - ) + def __str__(self): + return self.username # ══════════════════════════════════════════════════════════════ -# Pydantic 请求 / 响应模型(API 用) +# Pydantic 内存模型(非数据库,用于 Worker 运行时状态与任务调度) # ══════════════════════════════════════════════════════════════ # ─── Worker ─── @@ -131,19 +100,10 @@ class WorkerInfo(BaseModel): current_task_id: Optional[str] = None -class WorkerOut(BaseModel): - """返回给前端的 Worker 信息。""" - worker_id: str - worker_name: str - browsers: List[BrowserProfile] - online: bool - current_task_id: Optional[str] = None - - # ─── Task ─── class TaskCreate(BaseModel): - """前端提交任务的请求体。""" + """前端提交任务的请求体(也用于内部创建任务)。""" task_type: TaskType worker_id: Optional[str] = None account_name: Optional[str] = None @@ -151,7 +111,7 @@ class TaskCreate(BaseModel): class TaskInfo(BaseModel): - """任务完整信息(内存 / 返回前端)。""" + """任务完整信息(内存中保存)。""" task_id: str = Field(default_factory=lambda: uuid.uuid4().hex[:12]) task_type: TaskType status: TaskStatus = TaskStatus.PENDING @@ -163,43 +123,3 @@ class TaskInfo(BaseModel): error: Optional[str] = None created_at: float = Field(default_factory=time.time) updated_at: float = Field(default_factory=time.time) - - -class TaskOut(BaseModel): - """返回给前端的任务信息。""" - task_id: str - task_type: TaskType - status: TaskStatus - worker_id: Optional[str] = None - account_name: Optional[str] = None - params: Dict[str, Any] = {} - progress: Optional[str] = None - result: Any = None - error: Optional[str] = None - created_at: float - updated_at: float - - -# ─── 简化接口:前端添加环境名 ─── - -class CheckLoginRequest(BaseModel): - """前端提交检测登录请求。worker_id 可不传(走绑定关系)。""" - browser_name: str - worker_id: Optional[str] = None - - -class AccountBindRequest(BaseModel): - """前端添加账号时提交绑定:账号环境名 + 归属电脑。""" - browser_name: str - worker_id: str - - -class LoginRequest(BaseModel): - """登录请求:用户名 + 密码。""" - username: str - password: str - - -class LoginResponse(BaseModel): - """登录成功响应:返回 token。""" - token: str