# Copyright (c) TileDB, Inc. and The Chan Zuckerberg Initiative Foundation
#
# Licensed under the MIT License.
from __future__ import annotations
from typing import Any, TypeVar
from torch.utils.data import DataLoader
from tiledbsoma_ml._distributed import init_multiprocessing
from tiledbsoma_ml.dataset import ExperimentDataset
_T = TypeVar("_T")
[docs]
def experiment_dataloader(
    ds: ExperimentDataset,
    **dataloader_kwargs: Any,
) -> DataLoader:
    """Build a |DataLoader| that safely wraps an |ExperimentDataset|.

    The returned |DataLoader| is always constructed with ``batch_size=None`` and ``shuffle=False``, because batching
    and shuffling are handled by the upstream iterator. For that reason the ``shuffle``, ``batch_size``, ``sampler``
    and ``batch_sampler`` keyword arguments are rejected outright.

    If ``num_workers`` is given and greater than zero, ``init_multiprocessing()`` is invoked before the loader is
    created. If no ``collate_fn`` is supplied, a pass-through collation function is installed.

    See `the DataLoader docs <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_ for the
    meaning of the remaining |DataLoader| parameters.

    Args:
        ds:
            The |ExperimentDataset| to wrap.
        **dataloader_kwargs:
            Extra keyword arguments forwarded to the |DataLoader| constructor. ``shuffle``, ``batch_size``,
            ``sampler`` and ``batch_sampler`` are not accepted.

    Returns:
        |DataLoader|

    Raises:
        ValueError: if any of ``shuffle``, ``batch_size``, ``sampler`` or ``batch_sampler`` is present in
            ``dataloader_kwargs``.

    Lifecycle:
        experimental
    """
    forbidden = ("shuffle", "batch_size", "sampler", "batch_sampler")
    if any(name in dataloader_kwargs for name in forbidden):
        # The message always lists every forbidden name, not only the offending one.
        message = f"The {','.join(forbidden)} DataLoader parameters are not supported"
        raise ValueError(message)

    # Worker processes are only in play when the caller explicitly asks for them.
    worker_count = dataloader_kwargs.get("num_workers", 0)
    if worker_count > 0:
        init_multiprocessing()

    # Fall back to pass-through collation unless the caller supplied their own.
    if "collate_fn" not in dataloader_kwargs:
        dataloader_kwargs["collate_fn"] = _collate_noop

    # Batching and shuffling are both handled by the upstream iterator, so they
    # are pinned here rather than left to the caller.
    pinned = {"batch_size": None, "shuffle": False}
    return DataLoader(ds, **pinned, **dataloader_kwargs)
def _collate_noop(datum: _T) -> _T:
    """Noop collation used by |experiment_dataloader|.
    Private.
    """
    return datum