Files
boss_dp/server/db.py
Your Name 620149716d 哈哈
2026-02-12 18:17:15 +08:00

211 lines
7.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
Database module: SQLAlchemy engine, session management, and CRUD operations.
"""
from __future__ import annotations

import logging
from datetime import datetime
from typing import Optional
from urllib.parse import quote

from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker

from server import config
from server.models import Base, BossAccount, TaskLog, AuthToken
logger = logging.getLogger("server.db")
# ────────────────────────── Engine & sessions ──────────────────────────
# User name and password are percent-encoded (quote with safe="") so that
# characters such as "@", ":", "/", "?" or "%" inside the credentials cannot
# break the structure of the URL; SQLAlchemy decodes them when parsing it.
_db_url = (
    f"mysql+pymysql://{quote(str(config.DB_USER), safe='')}"
    f":{quote(str(config.DB_PASSWORD), safe='')}"
    f"@{config.DB_HOST}:{config.DB_PORT}/{config.DB_NAME}"
    f"?charset=utf8mb4"
)
# pool_pre_ping: check pooled connections before handing them out;
# pool_recycle: replace connections older than 3600 seconds.
engine = create_engine(_db_url, pool_pre_ping=True, pool_recycle=3600, echo=False)
# expire_on_commit=False keeps loaded attributes after commit, so the ORM
# objects returned by the CRUD helpers below stay readable after their
# session has been closed.
SessionLocal = sessionmaker(bind=engine, autoflush=False, expire_on_commit=False)
def get_session() -> Session:
    """Open a brand-new ORM session from ``SessionLocal``.

    Ownership passes to the caller, who must close it — typically by using it
    as a context manager: ``with get_session() as session: ...``.
    """
    session = SessionLocal()
    return session
def init_db() -> None:
    """Create every table declared on ``Base.metadata`` that is missing."""
    metadata = Base.metadata
    metadata.create_all(bind=engine)
    logger.info("数据库表初始化完成 (SQLAlchemy ORM)")
# ────────────────────────── BossAccount CRUD ──────────────────────────
def upsert_account_status(
    worker_id: str,
    browser_id: str,
    browser_name: str,
    boss_username: str,
    is_logged_in: bool,
) -> BossAccount:
    """Insert or update the login status of a BOSS account.

    An existing row for ``worker_id`` is searched first by ``browser_name``
    (the binding created by the front end) and, if none is found, by
    ``browser_id``; an empty key skips its lookup. A found row gets its
    username, login flag and ``checked_at`` overwritten, while ``browser_id``
    and ``browser_name`` are only replaced by non-empty values. Otherwise a
    new row is inserted, with ``"name:<browser_name>"`` as ``browser_id``
    when no id is supplied.

    Returns:
        The committed and refreshed ``BossAccount``.
    """
    with get_session() as session:
        accounts = session.query(BossAccount)

        # Preferred match: worker + browser name (front-end binding).
        account = (
            accounts.filter_by(worker_id=worker_id, browser_name=browser_name).first()
            if browser_name
            else None
        )
        # Fallback match: worker + browser id.
        if account is None and browser_id:
            account = accounts.filter_by(worker_id=worker_id, browser_id=browser_id).first()

        if not account:
            account = BossAccount(
                worker_id=worker_id,
                # Placeholder id when only the environment name is known.
                browser_id=browser_id or f"name:{browser_name}",
                browser_name=browser_name,
                boss_username=boss_username,
                is_logged_in=is_logged_in,
                checked_at=datetime.now(),
            )
            session.add(account)
        else:
            account.browser_id = browser_id or account.browser_id
            account.browser_name = browser_name or account.browser_name
            account.boss_username = boss_username
            account.is_logged_in = is_logged_in
            account.checked_at = datetime.now()

        session.commit()
        session.refresh(account)

    logger.info(
        "账号状态更新: worker=%s, browser=%s(%s), username=%s, logged_in=%s",
        worker_id, browser_name, browser_id, boss_username, is_logged_in,
    )
    return account
def bind_account_to_worker(worker_id: str, browser_name: str) -> BossAccount:
    """Record that browser environment ``browser_name`` belongs to ``worker_id``.

    Called when an account is added from the front end. If the
    (worker_id, browser_name) pair already exists, that row is returned
    unchanged. Otherwise a new row is created as "not logged in"
    (``checked_at`` left as ``None``), to be refreshed later by a login check.

    Returns:
        The existing ``BossAccount``, or the newly committed and refreshed one.
    """
    with get_session() as session:
        existing = (
            session.query(BossAccount)
            .filter_by(worker_id=worker_id, browser_name=browser_name)
            .first()
        )
        if existing:
            return existing

        # A placeholder browser_id avoids an empty value clashing on the
        # composite unique key until the real id is reported.
        placeholder_id = f"name:{browser_name}"
        new_account = BossAccount(
            worker_id=worker_id,
            browser_id=placeholder_id,
            browser_name=browser_name,
            boss_username="",
            is_logged_in=False,
            checked_at=None,
        )
        session.add(new_account)
        session.commit()
        session.refresh(new_account)

    logger.info("账号绑定已保存: %s -> %s", browser_name, worker_id)
    return new_account
def get_all_accounts() -> list[dict]:
    """Return every account (via ``to_dict()``), most recently updated first."""
    with get_session() as session:
        newest_first = BossAccount.updated_at.desc()
        accounts = session.query(BossAccount).order_by(newest_first).all()
        return [account.to_dict() for account in accounts]
def get_accounts_by_worker(worker_id: str) -> list[dict]:
    """Return the accounts of ``worker_id`` (via ``to_dict()``), newest update first."""
    with get_session() as session:
        query = (
            session.query(BossAccount)
            .filter_by(worker_id=worker_id)
            .order_by(BossAccount.updated_at.desc())
        )
        return [account.to_dict() for account in query.all()]
def get_account_by_name(browser_name: str, worker_id: Optional[str] = None) -> Optional[dict]:
    """Look up an account by browser environment name.

    Args:
        browser_name: Environment name to match.
        worker_id: When truthy, additionally restrict the match to this worker.

    Returns:
        The first matching row as a dict (via ``to_dict()``), or ``None``.
    """
    with get_session() as session:
        criteria = {"browser_name": browser_name}
        if worker_id:
            criteria["worker_id"] = worker_id
        match = session.query(BossAccount).filter_by(**criteria).first()
        if not match:
            return None
        return match.to_dict()
# ────────────────────────── TaskLog CRUD ──────────────────────────
def save_task_log(
    task_id: str,
    task_type: str,
    worker_id: str,
    status: str,
    params: Optional[dict] = None,
    result=None,
    error: Optional[str] = None,
) -> TaskLog:
    """Insert or update the execution record of a task.

    Args:
        task_id: Key used to find an existing ``TaskLog`` row.
        task_type: Stored only when the row is first created.
        worker_id: Stored only when the row is first created.
        status: Written on every call.
        params: Stored only when the row is first created.
        result: Written on every call (``None`` clears a previous result).
        error: Written on every call (``None`` clears a previous error).

    Returns:
        The committed and refreshed ``TaskLog``.
    """
    with get_session() as session:
        log = session.query(TaskLog).filter_by(task_id=task_id).first()
        if log:
            # Updates only touch the outcome fields; task_type, worker_id and
            # params keep the values written by the first save.
            log.status = status
            log.result = result
            log.error = error
        else:
            log = TaskLog(
                task_id=task_id,
                task_type=task_type,
                worker_id=worker_id,
                status=status,
                params=params,
                result=result,
                error=error,
            )
            session.add(log)
        session.commit()
        session.refresh(log)
        return log
# ────────────────────────── AuthToken CRUD ──────────────────────────
def set_auth_token(username: str, token: str) -> AuthToken:
    """Make ``token`` the current token of ``username``, replacing any previous one.

    If the user already has a row, its token is overwritten in place and
    ``created_at`` is reset to the current local time. Otherwise a new
    ``AuthToken`` row is added.

    Returns:
        The committed and refreshed ``AuthToken``.
    """
    with get_session() as session:
        record = session.query(AuthToken).filter_by(username=username).first()
        if not record:
            record = AuthToken(username=username, token=token)
            session.add(record)
        else:
            record.token = token
            record.created_at = datetime.now()
        session.commit()
        session.refresh(record)
        return record
def get_user_by_token(token: str) -> Optional[dict]:
    """Resolve a token to the user who owns it.

    Returns:
        ``{"username": ..., "created_at": ...}`` for a stored token, or
        ``None`` when ``token`` is empty or not found.
    """
    if token:
        with get_session() as session:
            record = session.query(AuthToken).filter_by(token=token).first()
            if record:
                return {"username": record.username, "created_at": record.created_at}
    return None