# Copyright (c) TileDB, Inc. and The Chan Zuckerberg Initiative Foundation
#
# Licensed under the MIT License.
from __future__ import annotations
import logging
import os
from typing import Iterable, Iterator
import attrs
import numpy as np
import pandas as pd
import torch
from scipy import sparse
from tiledbsoma_ml._common import MiniBatch
from tiledbsoma_ml._eager_iter import EagerIterator
from tiledbsoma_ml._io_batch_iterable import IOBatchIterable
logger = logging.getLogger("tiledbsoma_ml._mini_batch_iterable")
@attrs.define(frozen=True)
class MiniBatchIterable(Iterable[MiniBatch]):
    """Convert (possibly shuffled) |IOBatchIterable| into |MiniBatch|'s suitable for passing to PyTorch."""
    io_batch_iter: IOBatchIterable
    batch_size: int
    use_eager_fetch: bool = True
    return_sparse_X: bool = False
    gpu_shuffle: bool = False
    # "iobatch": permute each whole IO batch on the GPU before slicing it into
    # mini-batches; "minibatch": permute rows within each assembled mini-batch.
    gpu_shuffle_mode: str = "iobatch"
    device: torch.device | None = None
    # `seed` and `epoch` (together with the worker pid) seed the GPU-shuffle
    # permutations; see `_gpu_perm`.
    seed: int | None = None
    epoch: int = 0
    def _gpu_perm(self, n: int) -> torch.Tensor:
        """Deterministic permutation of range(n) seeded by (seed, epoch, pid)."""
        base = int(self.seed or 0)
        pid = os.getpid()
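        # Mix seed, epoch, and pid into a single 32-bit value with fixed
        # multiplicative constants, so each (epoch, worker process) pair draws a
        # distinct but reproducible permutation.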
        mixed = (base * 1315423911 + self.epoch * 2654435761 + pid) & 0xFFFFFFFF
        gen_device = (
            self.device
            if (
                self.device is not None and getattr(self.device, "type", None) == "cuda"
            )
            else "cpu"
        )
        g = torch.Generator(device=gen_device)
        g.manual_seed(mixed)
        return torch.randperm(n, generator=g, device=gen_device)
    def _iter(self) -> Iterator[MiniBatch]:
        batch_size = self.batch_size
        result: MiniBatch | None = None
        for X_io_batch, obs_io_batch in self.io_batch_iter:
            assert X_io_batch.shape[0] == obs_io_batch.shape[0]
            iob_idx = 0  # current offset into io batch
            iob_len = X_io_batch.shape[0]
            # GPU within-IO-batch shuffle (dense only)
            if self.gpu_shuffle and self.gpu_shuffle_mode == "iobatch":
                if self.return_sparse_X:
                    logger.warning(
                        "GPU shuffle requested but return_sparse_X=True; leaving IO-batch order unchanged."
                    )
                else:
                    perm = self._gpu_perm(iob_len)
                    perm_cpu = perm.to("cpu", non_blocking=False).numpy()
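                    # `perm` indexes the (possibly GPU-resident) X tensor below;
                    # `perm_cpu` reorders the obs DataFrame on the host.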
                    X_full = X_io_batch.slice_tonumpy(slice(0, iob_len))
                    X_t = torch.from_numpy(X_full)
                    if (
                        self.device is not None
                        and getattr(self.device, "type", None) == "cuda"
                    ):
                        if not X_t.is_pinned():
                            X_t = X_t.pin_memory()  # faster H2D
                        X_t = X_t.to(self.device, non_blocking=True)
                    X_t = X_t.index_select(0, perm).contiguous()
                    X_cpu = X_t.to("cpu", non_blocking=False).numpy()
                    obs_perm = obs_io_batch.iloc[perm_cpu].reset_index(drop=True)
                    # Emit mini-batches from the permuted IO-batch
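                    # (Any trailing slice shorter than `batch_size` is emitted here
                    # rather than carried over into the next IO batch.)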
                    for start in range(0, iob_len, self.batch_size):
                        stop = min(start + self.batch_size, iob_len)
                        yield (
                            X_cpu[start:stop],
                            obs_perm.iloc[start:stop].reset_index(drop=True),
                        )
                    continue  # done with this IO-batch
            while iob_idx < iob_len:
                if result is None:
                    # Perform a zero-copy slice where possible
                    X_datum = (
                        X_io_batch.slice_toscipy(slice(iob_idx, iob_idx + batch_size))
                        if self.return_sparse_X
                        else X_io_batch.slice_tonumpy(
                            slice(iob_idx, iob_idx + batch_size)
                        )
                    )
                    result = (
                        X_datum,
                        obs_io_batch.iloc[iob_idx : iob_idx + batch_size].reset_index(
                            drop=True
                        ),
                    )
                    iob_idx += len(result[1])
                else:
                    # Use any remnant from previous IO batch
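                    # (This branch is only reached at the start of an IO batch, so
                    # `iob_idx == 0` here and slicing from 0 below is correct.)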
                    to_take = min(batch_size - len(result[1]), iob_len - iob_idx)
                    X_datum = (
                        sparse.vstack(
                            [result[0], X_io_batch.slice_toscipy(slice(0, to_take))]
                        )
                        if self.return_sparse_X
                        else np.concatenate(
                            [result[0], X_io_batch.slice_tonumpy(slice(0, to_take))]
                        )
                    )
                    result = (
                        X_datum,
                        pd.concat(
                            [result[1], obs_io_batch.iloc[0:to_take]],
                            # Re-index the combined obs from 0 to N-1, instead of keeping the disjoint,
                            # concatenated pieces of the IO batches' indices
                            ignore_index=True,
                        ),
                    )
                    iob_idx += to_take
                X, obs = result
                if (
                    self.gpu_shuffle
                    and self.gpu_shuffle_mode == "minibatch"
                    and not self.return_sparse_X
                ):
                    mb_n = X.shape[0]
                    perm = self._gpu_perm(mb_n)
                    perm_cpu = perm.to("cpu", non_blocking=False).numpy()
                    X_t = torch.from_numpy(X)
                    if (
                        self.device is not None
                        and getattr(self.device, "type", None) == "cuda"
                    ):
                        if not X_t.is_pinned():
                            X_t = X_t.pin_memory()
                        X_t = X_t.to(self.device, non_blocking=True)
                    X_t = X_t.index_select(0, perm).contiguous()
                    X = X_t.to("cpu", non_blocking=False).numpy()
                    obs = obs.iloc[perm_cpu].reset_index(drop=True)
                    # Store the permuted mini-batch back into `result`, so it is what
                    # gets yielded below (and carried over as a remnant if not full).
                    result = (X, obs)
                assert X.shape[0] == obs.shape[0]
                if X.shape[0] == batch_size:
                    yield result
                    result = None
        else:
            # yield the remnant, if any
            if result is not None:
                yield result
    def __iter__(self) -> Iterator[MiniBatch]:
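        # Squeeze singleton batches, then optionally wrap in EagerIterator so
        # upcoming mini-batches are fetched eagerly in the background while the
        # caller consumes the current one.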
        it = map(self.maybe_squeeze, self._iter())
        return EagerIterator(it) if self.use_eager_fetch else it
    def maybe_squeeze(self, mini_batch: MiniBatch) -> MiniBatch:
        X, obs = mini_batch
        if self.batch_size == 1:
            # This is a no-op for `csr_matrix`s
            return X[0], obs
        else:
            return mini_batch
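# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library): a minimal stand-in for an IO
# batch, exposing only what this module actually touches (`shape[0]` and
# `slice_tonumpy`), wired through MiniBatchIterable. `_FakeIOBatch` and the
# literal data are hypothetical; real IO batches come from an IOBatchIterable.
#
#     import numpy as np
#     import pandas as pd
#
#     class _FakeIOBatch:
#         def __init__(self, arr: np.ndarray) -> None:
#             self._arr = arr
#             self.shape = arr.shape
#
#         def slice_tonumpy(self, s: slice) -> np.ndarray:
#             return self._arr[s]
#
#     X = np.arange(20, dtype=np.float32).reshape(10, 2)
#     obs = pd.DataFrame({"soma_joinid": np.arange(10)})
#     mb_iter = MiniBatchIterable(
#         io_batch_iter=[(_FakeIOBatch(X), obs)],
#         batch_size=4,
#         use_eager_fetch=False,  # keep iteration synchronous for the example
#     )
#     for X_mb, obs_mb in mb_iter:
#         print(X_mb.shape, len(obs_mb))  # (4, 2) 4, (4, 2) 4, then (2, 2) 2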