# Copyright (c) TileDB, Inc. and The Chan Zuckerberg Initiative Foundation
#
# Licensed under the MIT License.

import gc
import logging
import time
from typing import Iterable, Iterator, Optional, Sequence, Tuple

import attrs
import numpy as np
import pandas as pd
import pyarrow as pa
from tiledbsoma import DataFrame, IntIndexer, SparseNDArray

from tiledbsoma_ml._common import NDArrayJoinId
from tiledbsoma_ml._csr import CSR_IO_Buffer
from tiledbsoma_ml._eager_iter import EagerIterator
from tiledbsoma_ml._query_ids import Chunks
from tiledbsoma_ml._utils import batched

logger = logging.getLogger("tiledbsoma_ml._io_batch_iterable")

IOBatch = Tuple[CSR_IO_Buffer, pd.DataFrame]
"""Tuple type emitted by |IOBatchIterable|, containing ``X`` rows (as a |CSR_IO_Buffer|) and ``obs`` rows
(|pd.DataFrame|)."""


@attrs.define(frozen=True)
class IOBatchIterable(Iterable[IOBatch]):
"""Given a list of ``obs_joinid`` |Chunks|, re-chunk them into (optionally shuffled) |IOBatch|'s".
An |IOBatch| is a tuple consisting of a batch of rows from the ``X`` |SparseNDArray|, as well as the corresponding
rows from the ``obs`` |DataFrame|. The ``X`` rows are returned in an optimized |CSR_IO_Buffer|.
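
    Example (illustrative sketch; obtaining ``chunks``, ``var_joinids``, and the SOMA
    handles is outside this module)::

        batches = IOBatchIterable(
            chunks=chunks,                # list of obs_joinid chunks
            io_batch_size=2**16,
            obs=exp.obs,                  # SOMA ``obs`` DataFrame
            var_joinids=var_joinids,      # NDArrayJoinId of var joinids
            X=exp.ms["RNA"].X["data"],    # SOMA ``X`` SparseNDArray
        )
        for X_buf, obs_df in batches:
            ...  # X_buf: CSR_IO_Buffer, obs_df: pd.DataFrame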
"""
    chunks: Chunks
    io_batch_size: int
    obs: DataFrame
    var_joinids: NDArrayJoinId
    X: SparseNDArray
    obs_column_names: Sequence[str] = ("soma_joinid",)
    seed: Optional[int] = None
    shuffle: bool = True
    use_eager_fetch: bool = True

    @property
    def io_batch_ids(self) -> Iterable[Tuple[int, ...]]:
        """Re-chunk ``obs_joinids`` according to the desired ``io_batch_size``."""
        return batched(
            (joinid for chunk in self.chunks for joinid in chunk),
            self.io_batch_size,
        )

    def __iter__(self) -> Iterator[IOBatch]:
        """Emit |IOBatch|'s."""
        # Because obs/var IDs have been partitioned/split/shuffled upstream of this class, this RNG does not need to be
        # identical across sub-processes, but seeding is supported anyway, for testing/reproducibility.
        shuffle_rng = np.random.default_rng(self.seed)
        X = self.X
        context = X.context
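        # Ensure "soma_joinid" is always fetched; it is used below to re-align obs rows
        # with the (possibly shuffled) X rows. E.g. (illustrative), ``("cell_type",)``
        # becomes ``["soma_joinid", "cell_type"]``.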
        obs_column_names = (
            list(self.obs_column_names)
            if "soma_joinid" in self.obs_column_names
            else ["soma_joinid", *self.obs_column_names]
        )
        # NOTE: `.astype("int64")` works around the `np.int64` singleton failing reference-equality after cross-process
        # SerDes.
        var_joinids = self.var_joinids.astype("int64")
        var_indexer = IntIndexer(var_joinids, context=context)

        for obs_coords in self.io_batch_ids:
            st_time = time.perf_counter()
            obs_shuffled_coords = (
                np.array(obs_coords)
                if not self.shuffle
                else shuffle_rng.permuted(obs_coords)
            )
            obs_indexer = IntIndexer(obs_shuffled_coords, context=context)
            logger.debug(
                f"Retrieving next SOMA IO batch of length {len(obs_coords)}..."
            )

            # To maximize opportunities for concurrency, when in eager-fetch mode,
            # create the X read iterator first, as the eager iterator will begin
            # the read-ahead immediately; then proceed to fetch the obs DataFrame.
            # This matters most on high-latency backing stores, e.g., S3.
            X_tbl_iter: Iterator[pa.Table] = X.read(
                coords=(obs_coords, self.var_joinids)
            ).tables()

            def make_io_buffer(
                X_tbl: pa.Table,
                obs_coords: NDArrayJoinId,
                var_coords: NDArrayJoinId,
                obs_indexer: IntIndexer,
            ) -> CSR_IO_Buffer:
                """Build a |CSR_IO_Buffer| from one ``X`` table, then GC the (large) intermediate garbage."""
                m = CSR_IO_Buffer.from_ijd(
                    obs_indexer.get_indexer(X_tbl["soma_dim_0"]),
                    var_indexer.get_indexer(X_tbl["soma_dim_1"]),
                    X_tbl["soma_data"].to_numpy(),
                    shape=(len(obs_coords), len(var_coords)),
                )
                gc.collect(generation=0)
                return m

            _io_buf_iter: Iterator[CSR_IO_Buffer] = (
                make_io_buffer(
                    X_tbl=X_tbl,
                    obs_coords=np.array(obs_coords),
                    var_coords=self.var_joinids,
                    obs_indexer=obs_indexer,
                )
                for X_tbl in X_tbl_iter
            )
            if self.use_eager_fetch:
                _io_buf_iter = EagerIterator(_io_buf_iter, pool=X.context.threadpool)

            # Now that X read is potentially in progress (in eager mode), go fetch obs data
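            # The set_index/reindex below reorders obs rows to match obs_shuffled_coords,
            # so obs rows line up 1:1 with the (shuffled) X rows built above.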
            # fmt: off
            obs_io_batch = (
                self.obs.read(coords=(obs_coords,), column_names=obs_column_names)
                .concat()
                .to_pandas()
                .set_index("soma_joinid")
                .reindex(obs_shuffled_coords, copy=False)
                .reset_index()  # demote "soma_joinid" to a column
                [self.obs_column_names]
            )
            # fmt: on
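
            # Drain the X-buffer iterator (already pre-fetching, in eager mode) and merge
            # the per-table buffers into a single CSR for this IO batch.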
            X_io_batch = CSR_IO_Buffer.merge(tuple(_io_buf_iter))

            del obs_indexer, obs_coords, obs_shuffled_coords, _io_buf_iter
            gc.collect()

            tm = time.perf_counter() - st_time
            logger.debug(
                f"Retrieved SOMA IO batch, took {tm:.2f}sec, {X_io_batch.shape[0]/tm:0.1f} samples/sec"
            )
            yield X_io_batch, obs_io_batch