Source code for tiledbsoma_ml._mini_batch_iterable

# Copyright (c) TileDB, Inc. and The Chan Zuckerberg Initiative Foundation
#
# Licensed under the MIT License.
from __future__ import annotations

import logging
import os
from typing import Iterable, Iterator

import attrs
import numpy as np
import pandas as pd
import torch
from scipy import sparse

from tiledbsoma_ml._common import MiniBatch
from tiledbsoma_ml._eager_iter import EagerIterator
from tiledbsoma_ml._io_batch_iterable import IOBatchIterable

logger = logging.getLogger("tiledbsoma_ml._mini_batch_iterable")


@attrs.define(frozen=True)
class MiniBatchIterable(Iterable[MiniBatch]):
    """Convert (possibly shuffled) |IOBatchIterable| into |MiniBatch|'s suitable for passing to PyTorch."""

    io_batch_iter: IOBatchIterable
    batch_size: int
    use_eager_fetch: bool = True
    return_sparse_X: bool = False
    gpu_shuffle: bool = False
    gpu_shuffle_mode: str = "iobatch"
    device: torch.device | None = None
    seed: int | None = None
    epoch: int = 0

    def _gpu_perm(self, n: int) -> torch.Tensor:
        """Deterministic permutation of range(n) seeded by (seed, epoch, pid)."""
        base = int(self.seed or 0)
        pid = os.getpid()
        # Mix seed, epoch, and process id into a 32-bit value so each worker
        # process and each epoch gets a distinct but reproducible permutation.
        mixed = (base * 1315423911 + self.epoch * 2654435761 + pid) & 0xFFFFFFFF
        gen_device = (
            self.device
            if (
                self.device is not None
                and getattr(self.device, "type", None) == "cuda"
            )
            else "cpu"
        )
        g = torch.Generator(device=gen_device)
        g.manual_seed(mixed)
        return torch.randperm(n, generator=g, device=gen_device)

    def _iter(self) -> Iterator[MiniBatch]:
        batch_size = self.batch_size
        result: MiniBatch | None = None
        for X_io_batch, obs_io_batch in self.io_batch_iter:
            assert X_io_batch.shape[0] == obs_io_batch.shape[0]
            iob_idx = 0  # current offset into io batch
            iob_len = X_io_batch.shape[0]

            # GPU within-IO-batch shuffle (dense only)
            if self.gpu_shuffle and self.gpu_shuffle_mode == "iobatch":
                if self.return_sparse_X:
                    logger.warning(
                        "GPU shuffle requested but return_sparse_X=True; leaving IO-batch order unchanged."
                    )
                else:
                    perm = self._gpu_perm(iob_len)
                    perm_cpu = perm.to("cpu", non_blocking=False).numpy()
                    X_full = X_io_batch.slice_tonumpy(slice(0, iob_len))
                    X_t = torch.from_numpy(X_full)
                    if (
                        self.device is not None
                        and getattr(self.device, "type", None) == "cuda"
                    ):
                        if not X_t.is_pinned():
                            X_t = X_t.pin_memory()  # faster H2D
                        X_t = X_t.to(self.device, non_blocking=True)
                    X_t = X_t.index_select(0, perm).contiguous()
                    X_cpu = X_t.to("cpu", non_blocking=False).numpy()
                    obs_perm = obs_io_batch.iloc[perm_cpu].reset_index(drop=True)

                    # Emit mini-batches from the permuted IO-batch. Note that a
                    # trailing partial mini-batch is emitted here as-is, rather
                    # than carried over into the next IO batch.
                    for start in range(0, iob_len, self.batch_size):
                        stop = min(start + self.batch_size, iob_len)
                        yield (
                            X_cpu[start:stop],
                            obs_perm.iloc[start:stop].reset_index(drop=True),
                        )
                    continue  # done with this IO-batch

            while iob_idx < iob_len:
                if result is None:
                    # perform zero copy slice where possible
                    X_datum = (
                        X_io_batch.slice_toscipy(slice(iob_idx, iob_idx + batch_size))
                        if self.return_sparse_X
                        else X_io_batch.slice_tonumpy(
                            slice(iob_idx, iob_idx + batch_size)
                        )
                    )
                    result = (
                        X_datum,
                        obs_io_batch.iloc[iob_idx : iob_idx + batch_size].reset_index(
                            drop=True
                        ),
                    )
                    iob_idx += len(result[1])
                else:
                    # Use any remnant from previous IO batch
                    to_take = min(batch_size - len(result[1]), iob_len - iob_idx)
                    X_datum = (
                        sparse.vstack(
                            [result[0], X_io_batch.slice_toscipy(slice(0, to_take))]
                        )
                        if self.return_sparse_X
                        else np.concatenate(
                            [result[0], X_io_batch.slice_tonumpy(slice(0, to_take))]
                        )
                    )
                    result = (
                        X_datum,
                        pd.concat(
                            [result[1], obs_io_batch.iloc[0:to_take]],
                            # Index `obs_batch` from 0 to N-1, instead of disjoint,
                            # concatenated pieces of IO batches' indices
                            ignore_index=True,
                        ),
                    )
                    iob_idx += to_take

                X, obs = result
                # GPU per-mini-batch shuffle (dense only)
                if (
                    self.gpu_shuffle
                    and self.gpu_shuffle_mode == "minibatch"
                    and not self.return_sparse_X
                ):
                    mb_n = X.shape[0]
                    perm = self._gpu_perm(mb_n)
                    perm_cpu = perm.to("cpu", non_blocking=False).numpy()
                    X_t = torch.from_numpy(X)
                    if (
                        self.device is not None
                        and getattr(self.device, "type", None) == "cuda"
                    ):
                        if not X_t.is_pinned():
                            X_t = X_t.pin_memory()
                        X_t = X_t.to(self.device, non_blocking=True)
                    X_t = X_t.index_select(0, perm).contiguous()
                    X = X_t.to("cpu", non_blocking=False).numpy()
                    obs = obs.iloc[perm_cpu].reset_index(drop=True)
                    # Propagate the permuted data; otherwise the un-shuffled
                    # `result` would be yielded below.
                    result = (X, obs)

                assert X.shape[0] == obs.shape[0]
                if X.shape[0] == batch_size:
                    yield result
                    result = None
        else:
            # yield the remnant, if any
            if result is not None:
                yield result

    def __iter__(self) -> Iterator[MiniBatch]:
        it = map(self.maybe_squeeze, self._iter())
        return EagerIterator(it) if self.use_eager_fetch else it

    def maybe_squeeze(self, mini_batch: MiniBatch) -> MiniBatch:
        X, obs = mini_batch
        if self.batch_size == 1:
            # This is a no-op for `csr_matrix`s
            return X[0], obs
        else:
            return mini_batch
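
The following usage sketch is not part of the library source; it illustrates how `MiniBatchIterable` cuts IO batches into fixed-size mini-batches and carries a remnant across an IO-batch boundary. `_FakeIOBatch` and `_fake_io_batches` are hypothetical stand-ins for the `(X, obs)` pairs normally produced by |IOBatchIterable|; they implement only the `shape` / `slice_tonumpy` / `slice_toscipy` surface that `_iter` uses, and the sketch relies on the imports at the top of this module.

class _FakeIOBatch:
    """Hypothetical dense stand-in for the X half of an IO batch (illustration only)."""

    def __init__(self, arr: np.ndarray) -> None:
        self._arr = arr
        self.shape = arr.shape

    def slice_tonumpy(self, s: slice) -> np.ndarray:
        return self._arr[s]

    def slice_toscipy(self, s: slice) -> sparse.csr_matrix:
        return sparse.csr_matrix(self._arr[s])


def _fake_io_batches() -> Iterator[tuple[_FakeIOBatch, pd.DataFrame]]:
    # Two IO batches of 5 and 3 rows: with batch_size=4, the second mini-batch
    # is assembled from the 1-row remnant of the first IO batch plus all 3 rows
    # of the second.
    for n, offset in [(5, 0), (3, 5)]:
        X = np.arange(offset, offset + n, dtype=np.float32).reshape(n, 1)
        obs = pd.DataFrame({"soma_joinid": np.arange(offset, offset + n)})
        yield _FakeIOBatch(X), obs


mini_batches = MiniBatchIterable(
    io_batch_iter=_fake_io_batches(),  # stand-in iterable, not a real IOBatchIterable
    batch_size=4,
    use_eager_fetch=False,  # keep iteration synchronous for the example
)
for X, obs in mini_batches:
    print(X.shape, obs["soma_joinid"].tolist())
# Prints: (4, 1) [0, 1, 2, 3]  then  (4, 1) [4, 5, 6, 7]

With the default `gpu_shuffle=False` no `torch` device is touched; the GPU shuffle paths only change row order within an IO batch (or within a mini-batch), not which rows end up in which mini-batch, except that the "iobatch" mode emits a trailing partial mini-batch instead of carrying it forward.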