211 lines
7.1 KiB
Python
211 lines
7.1 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
数据库模块:SQLAlchemy 引擎、会话管理、CRUD 操作。
|
||
"""
|
||
from __future__ import annotations

import logging
from datetime import datetime
from typing import Optional
from urllib.parse import quote

from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker

from server import config
from server.models import Base, BossAccount, TaskLog, AuthToken
|
||
|
||
logger = logging.getLogger("server.db")

# ────────────────────────── Engine and sessions ──────────────────────────

# MySQL connection URL for the PyMySQL driver. The user name and password are
# percent-encoded (``safe=""`` also encodes "/") so that characters such as
# "@", ":", "/", "#" or "%" in the credentials cannot be mistaken for URL
# delimiters; SQLAlchemy decodes them again when it parses the URL.
_db_url = (
    f"mysql+pymysql://{quote(str(config.DB_USER), safe='')}"
    f":{quote(str(config.DB_PASSWORD), safe='')}"
    f"@{config.DB_HOST}:{config.DB_PORT}/{config.DB_NAME}"
    "?charset=utf8mb4"
)

# pool_pre_ping: test a pooled connection before handing it out, so stale ones
# are replaced. pool_recycle=3600: recycle connections older than one hour
# (presumably to stay under the MySQL server's wait_timeout — confirm).
engine = create_engine(_db_url, pool_pre_ping=True, pool_recycle=3600, echo=False)

# autoflush=False: queries do not implicitly flush pending changes.
# expire_on_commit=False: committed objects keep their loaded attributes, so
# they remain readable after the session that loaded them is closed.
SessionLocal = sessionmaker(bind=engine, autoflush=False, expire_on_commit=False)
|
||
|
||
|
||
def get_session() -> Session:
    """Create and return a brand-new database session.

    The caller owns the session and is responsible for closing it, e.g. by
    using it as a context manager: ``with get_session() as session: ...``.
    """
    new_session = SessionLocal()
    return new_session
|
||
|
||
|
||
def init_db() -> None:
    """Create every table declared on the ORM ``Base`` metadata.

    ``create_all`` only creates tables that do not exist yet, so calling this
    more than once is harmless.
    """
    metadata = Base.metadata
    metadata.create_all(bind=engine)
    logger.info("数据库表初始化完成 (SQLAlchemy ORM)")
|
||
|
||
|
||
# ────────────────────────── BossAccount CRUD ──────────────────────────
|
||
|
||
def upsert_account_status(
    worker_id: str,
    browser_id: str,
    browser_name: str,
    boss_username: str,
    is_logged_in: bool,
) -> BossAccount:
    """Create or refresh the login-status record of a BOSS account.

    The existing row is located by trying, in order:

    1. ``(worker_id, browser_name)`` — the binding created from the front end;
    2. ``(worker_id, browser_id)`` — fallback when the name lookup finds nothing.

    A lookup is skipped when its key (``browser_name`` / ``browser_id``) is empty.

    When a row is found, ``boss_username``, ``is_logged_in`` and ``checked_at``
    are overwritten, while an empty ``browser_id`` / ``browser_name`` keeps the
    stored value. Otherwise a new row is inserted, using the placeholder
    ``"name:<browser_name>"`` when ``browser_id`` is empty. The change is
    committed and the refreshed row is returned.
    """
    with get_session() as session:
        # Preferred key first (front-end binding), then the browser_id fallback.
        lookup_order = (("browser_name", browser_name), ("browser_id", browser_id))

        account = None
        for column, value in lookup_order:
            if not value:
                continue
            account = (
                session.query(BossAccount)
                .filter_by(worker_id=worker_id, **{column: value})
                .first()
            )
            if account is not None:
                break

        now = datetime.now()
        if not account:
            account = BossAccount(
                worker_id=worker_id,
                browser_id=browser_id or f"name:{browser_name}",
                browser_name=browser_name,
                boss_username=boss_username,
                is_logged_in=is_logged_in,
                checked_at=now,
            )
            session.add(account)
        else:
            account.browser_id = browser_id or account.browser_id
            account.browser_name = browser_name or account.browser_name
            account.boss_username = boss_username
            account.is_logged_in = is_logged_in
            account.checked_at = now

        session.commit()
        session.refresh(account)
        logger.info(
            "账号状态更新: worker=%s, browser=%s(%s), username=%s, logged_in=%s",
            worker_id, browser_name, browser_id, boss_username, is_logged_in,
        )
        return account
|
||
|
||
|
||
def bind_account_to_worker(worker_id: str, browser_name: str) -> BossAccount:
    """Persist the binding "browser environment name -> worker machine".

    Called when an account is added from the front end. If a row for
    ``(worker_id, browser_name)`` already exists, it is returned unchanged.
    Otherwise a new row is created in the logged-out state (empty
    ``boss_username``, no ``checked_at``), to be refreshed later by check_login.
    """
    with get_session() as session:
        existing = (
            session.query(BossAccount)
            .filter_by(worker_id=worker_id, browser_name=browser_name)
            .first()
        )
        if existing:
            return existing

        created = BossAccount(
            worker_id=worker_id,
            # The real browser_id is unknown at bind time; a placeholder avoids
            # the composite-unique conflict an empty browser_id could cause.
            browser_id=f"name:{browser_name}",
            browser_name=browser_name,
            boss_username="",
            is_logged_in=False,
            checked_at=None,
        )
        session.add(created)
        session.commit()
        session.refresh(created)
        logger.info("账号绑定已保存: %s -> %s", browser_name, worker_id)
        return created
|
||
|
||
|
||
def get_all_accounts() -> list[dict]:
    """Return every account as a dict, most recently updated first."""
    with get_session() as session:
        newest_first = BossAccount.updated_at.desc()
        accounts = session.query(BossAccount).order_by(newest_first).all()
        # Serialize while the session is still open.
        serialized = [account.to_dict() for account in accounts]
    return serialized
|
||
|
||
|
||
def get_accounts_by_worker(worker_id: str) -> list[dict]:
    """Return the accounts bound to ``worker_id`` as dicts, most recently updated first."""
    with get_session() as session:
        query = session.query(BossAccount).filter_by(worker_id=worker_id)
        query = query.order_by(BossAccount.updated_at.desc())
        return [account.to_dict() for account in query.all()]
|
||
|
||
|
||
def get_account_by_name(browser_name: str, worker_id: Optional[str] = None) -> Optional[dict]:
    """Find an account by its browser environment name.

    Args:
        browser_name: Browser environment name to match.
        worker_id: When truthy, the search is restricted to this worker.

    Returns:
        The first matching account as a dict, or ``None`` if nothing matches.
        Without ``worker_id`` and no ORDER BY, which row is "first" is up to
        the database when several workers share the name.
    """
    with get_session() as session:
        query = session.query(BossAccount).filter_by(browser_name=browser_name)
        query = query.filter_by(worker_id=worker_id) if worker_id else query
        match = query.first()
        if not match:
            return None
        return match.to_dict()
|
||
|
||
|
||
# ────────────────────────── TaskLog CRUD ──────────────────────────
|
||
|
||
def save_task_log(
    task_id: str,
    task_type: str,
    worker_id: str,
    status: str,
    params: Optional[dict] = None,
    result=None,
    error: Optional[str] = None,
) -> TaskLog:
    """Insert a task execution record, or update the existing one for ``task_id``.

    Args:
        task_id: Key used to find an existing record.
        task_type: Stored only when the record is first created.
        worker_id: Stored only when the record is first created.
        status: Current task status; always written.
        params: Task parameters; stored only when the record is first created.
        result: Task result; always written, so omitting it clears a previous value.
        error: Error message; always written, so omitting it clears a previous value.

    Returns:
        The committed and refreshed ``TaskLog`` row.
    """
    with get_session() as session:
        log = session.query(TaskLog).filter_by(task_id=task_id).first()
        if log:
            # Existing record: only the mutable outcome fields are updated.
            log.status = status
            log.result = result
            log.error = error
        else:
            log = TaskLog(
                task_id=task_id,
                task_type=task_type,
                worker_id=worker_id,
                status=status,
                params=params,
                result=result,
                error=error,
            )
            session.add(log)
        session.commit()
        session.refresh(log)
        return log
|
||
|
||
|
||
# ────────────────────────── AuthToken CRUD ──────────────────────────
|
||
|
||
def set_auth_token(username: str, token: str) -> AuthToken:
    """Make ``token`` the current token for ``username``, replacing any previous one.

    When a row already exists, its ``created_at`` is reset to now. A new row
    does not set ``created_at`` here; it presumably relies on a model-side
    default (confirm in ``server.models``).
    """
    with get_session() as session:
        record = session.query(AuthToken).filter_by(username=username).first()
        if not record:
            record = AuthToken(username=username, token=token)
            session.add(record)
        else:
            record.token = token
            record.created_at = datetime.now()
        session.commit()
        session.refresh(record)
        return record
|
||
|
||
|
||
def get_user_by_token(token: str) -> Optional[dict]:
    """Look up the user that owns ``token``.

    Returns:
        ``{"username": ..., "created_at": ...}`` for a known token, or ``None``
        when the token is empty or matches no row.
    """
    if token:
        with get_session() as session:
            record = session.query(AuthToken).filter_by(token=token).first()
            if record:
                return {"username": record.username, "created_at": record.created_at}
    return None
|