Source code for tiledbsoma_ml._distributed

# Copyright (c) TileDB, Inc. and The Chan Zuckerberg Initiative Foundation
#
# Licensed under the MIT License.
"""Utilities for multiprocess training: determine GPU "rank" / "world_size" and DataLoader worker ID / count."""

import logging
import os
from typing import Tuple

import torch

logger = logging.getLogger("tiledbsoma_ml.pytorch")


def get_distributed_rank_and_world_size() -> Tuple[int, int]:
    """Return tuple containing equivalent of |torch.distributed| rank and world size."""
    rank, world_size = 0, 1
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    elif "LOCAL_RANK" in os.environ and "WORLD_SIZE" in os.environ:
        # Lightning doesn't use RANK! LOCAL_RANK is only for the local node. There
        # is a NODE_RANK for the node's rank, but no way to tell the local node's
        # world. So computing a global rank is impossible(?). Using LOCAL_RANK as a
        # proxy, which works fine on a single-CPU box. TODO: could throw/error
        # if NODE_RANK != 0.
        rank = int(os.environ["LOCAL_RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    elif torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()

    return rank, world_size
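
# Illustrative usage sketch (added example, not part of the module): under a
# torchrun launch such as `torchrun --nproc_per_node=4 train.py`, each spawned
# process has RANK and WORLD_SIZE set in its environment, so the process with
# RANK=2 would see:
#
#     get_distributed_rank_and_world_size()  # -> (2, 4)
#
# The `train.py` script name is hypothetical. In a non-distributed run, with none
# of the variables set and torch.distributed uninitialized, the function returns (0, 1).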
def get_worker_id_and_num() -> Tuple[int, int]:
    """Return |DataLoader| worker ID, and the total number of |DataLoader| workers."""
    worker, num_workers = 0, 1
    if "WORKER" in os.environ and "NUM_WORKERS" in os.environ:
        worker = int(os.environ["WORKER"])
        num_workers = int(os.environ["NUM_WORKERS"])
    else:
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            worker = worker_info.id
            num_workers = worker_info.num_workers

    return worker, num_workers
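
# Hedged usage sketch (an assumption about typical usage, not library API): the two
# helpers above can be combined to derive a single global partition index so that
# each (GPU rank, DataLoader worker) pair reads a disjoint slice of the data. The
# helper name `_example_global_partition` is illustrative only.
def _example_global_partition() -> Tuple[int, int]:
    rank, world_size = get_distributed_rank_and_world_size()
    worker, num_workers = get_worker_id_and_num()
    # Flatten the (rank, worker) grid into one index out of world_size * num_workers.
    partition = rank * num_workers + worker
    total_partitions = world_size * num_workers
    return partition, total_partitions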
def init_multiprocessing() -> None:
    """Ensures use of "spawn" for starting child processes with multiprocessing.

    Note:
        - Forked processes are known to be problematic:
          `Avoiding and fighting deadlocks <https://pytorch.org/docs/stable/notes/multiprocessing.html#avoiding-and-fighting-deadlocks>`_.
        - CUDA does not support forked child processes:
          `CUDA in multiprocessing <https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing>`_.

    Private.
    """
    orig_start_method = torch.multiprocessing.get_start_method()
    if orig_start_method != "spawn":
        if orig_start_method:
            logger.warning(
                "switching torch multiprocessing start method from "
                f'"{torch.multiprocessing.get_start_method()}" to "spawn"'
            )
        torch.multiprocessing.set_start_method("spawn", force=True)
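
# Minimal demonstration (added example, not part of the library): switch the start
# method to "spawn" before any DataLoader workers are created, then report this
# process's view of the distributed topology.
if __name__ == "__main__":
    init_multiprocessing()
    rank, world_size = get_distributed_rank_and_world_size()
    worker, num_workers = get_worker_id_and_num()
    print(f"rank={rank}/{world_size} worker={worker}/{num_workers}")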