# Source code for tiledbsoma_ml.dataloader

# Copyright (c) TileDB, Inc. and The Chan Zuckerberg Initiative Foundation
#
# Licensed under the MIT License.

from __future__ import annotations

from typing import Any, TypeVar

from torch.utils.data import DataLoader

from tiledbsoma_ml._distributed import init_multiprocessing
from tiledbsoma_ml.dataset import ExperimentDataset

_T = TypeVar("_T")


def experiment_dataloader(
    ds: ExperimentDataset,
    **dataloader_kwargs: Any,
) -> DataLoader:
    """|DataLoader| factory method for safely wrapping an |ExperimentDataset|.

    Several |DataLoader| constructor parameters are not applicable, or are
    non-performant when using loaders from this module, including ``shuffle``,
    ``batch_size``, ``sampler``, and ``batch_sampler``. Specifying any of these
    parameters will result in an error.

    Refer to `the DataLoader docs <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_
    for more information on |DataLoader| parameters, and |ExperimentDataset|
    for info on corresponding parameters.

    Args:
        ds:
            The |ExperimentDataset| (an |IterableDataset|) to wrap. May include
            chained data pipes.
        **dataloader_kwargs:
            Additional keyword arguments to pass to the |DataLoader|
            constructor, except for ``shuffle``, ``batch_size``, ``sampler``,
            and ``batch_sampler``, which are not supported when using data
            loaders in this module. If ``collate_fn`` is not supplied, a no-op
            collate function is installed, so items yielded by ``ds`` reach
            the caller unchanged.

    Returns:
        |DataLoader|

    Raises:
        ValueError: if any of the ``shuffle``, ``batch_size``, ``sampler``, or
            ``batch_sampler`` params are passed as keyword arguments. The
            message names the offending parameter(s).

    Lifecycle:
        experimental
    """
    unsupported_dataloader_args = (
        "shuffle",
        "batch_size",
        "sampler",
        "batch_sampler",
    )
    # Name only the parameters the caller actually passed, in a stable order,
    # so the error points straight at the offending argument(s).
    offending = [arg for arg in unsupported_dataloader_args if arg in dataloader_kwargs]
    if offending:
        raise ValueError(
            f"The following DataLoader parameters are not supported: "
            f"{', '.join(offending)}. Batching and shuffling are handled by "
            f"the upstream iterator."
        )

    # Worker processes need multiprocessing set up before the DataLoader
    # spawns them; see init_multiprocessing in tiledbsoma_ml._distributed
    # (presumably it selects the start method — confirm there).
    if dataloader_kwargs.get("num_workers", 0) > 0:
        init_multiprocessing()

    # `dataloader_kwargs` is a fresh dict built from **kwargs, so filling in a
    # default here does not mutate anything the caller owns.
    if "collate_fn" not in dataloader_kwargs:
        dataloader_kwargs["collate_fn"] = _collate_noop

    return DataLoader(
        ds,
        batch_size=None,  # batching is handled by upstream iterator
        shuffle=False,  # shuffling is handled by upstream iterator
        **dataloader_kwargs,
    )
def _collate_noop(datum: _T) -> _T:
    """Identity ``collate_fn`` installed by |experiment_dataloader| when the
    caller does not supply one.

    Private.

    Args:
        datum: Whatever the |DataLoader| hands to its ``collate_fn``.

    Returns:
        ``datum`` itself — the very same object, not a copy.
    """
    passthrough = datum
    return passthrough