# Source code for tiledbsoma_ml._csr

# Copyright (c) TileDB, Inc. and The Chan Zuckerberg Initiative Foundation
#
# Licensed under the MIT License.
"""CSR sparse matrix implementation, optimized for incrementally building from COO matrices.

Private module.
"""

from math import ceil
from typing import Any, List, Sequence, Tuple, Type

import numba
import numpy as np
import numpy.typing as npt
from scipy import sparse
from typing_extensions import Self

from tiledbsoma_ml._common import NDArrayNumber

# Type of the CSR index arrays (``indptr`` and ``indices``). They use unsigned
# integer dtypes whose width is picked per array (see ``smallest_uint_dtype``)
# to minimize memory.
_CSRIdxArray = npt.NDArray[np.unsignedinteger[Any]]


class CSR_IO_Buffer:
    """Implement a minimal CSR matrix with specific optimizations for use in this package.

    Operations supported are:

    - Incrementally build a CSR from COO, allowing overlapped I/O and CSR conversion for I/O
      batches, and a final "merge" step which combines the result.
    - Zero intermediate copy conversion of an arbitrary row slice to dense (i.e., mini-batch
      extraction).
    - Parallel processing, where possible (construction, merge, etc.).
    - Minimize memory use for index arrays.

    Overall is significantly faster, and uses less memory, than the equivalent ``scipy.sparse``
    operations.
    """

    __slots__ = ("indptr", "indices", "data", "shape")

    def __init__(
        self,
        indptr: _CSRIdxArray,
        indices: _CSRIdxArray,
        data: NDArrayNumber,
        shape: Tuple[int, int],
    ) -> None:
        """Construct from PJV format (row pointers, column indices, values).

        The arrays are stored as given; no copy is made.

        Args:
            indptr: Row pointer array. ``indptr[0]`` must be 0 and ``indptr[-1]`` must equal
                ``len(data)``. Row ``r`` occupies ``indices[indptr[r]:indptr[r + 1]]`` (and the
                same range of ``data``).
            indices: Column (minor axis) index of each stored element.
            data: Value of each stored element. Same length as ``indices``.
            shape: ``(n_rows, n_cols)`` of the matrix.

        Raises:
            AssertionError: If the arrays are mutually inconsistent or an index dtype is too
                narrow. Only raised when assertions are enabled.
        """
        assert len(data) == len(indices)
        # indptr values range up to nnz, so its dtype must be able to hold len(data).
        assert len(data) <= np.iinfo(indptr.dtype).max
        # Column indices are < shape[1]; requiring shape[1] itself to fit is conservative.
        assert shape[1] <= np.iinfo(indices.dtype).max
        assert indptr[-1] == len(data) and indptr[0] == 0
        self.indptr = indptr
        self.indices = indices
        self.data = data
        self.shape = shape

    @staticmethod
    def from_ijd(
        i: _CSRIdxArray, j: _CSRIdxArray, d: NDArrayNumber, shape: Tuple[int, int]
    ) -> "CSR_IO_Buffer":
        """Build a |CSR_IO_Buffer| from a COO sparse matrix representation.

        Args:
            i: Row index of each element.
            j: Column index of each element.
            d: Value of each element.
            shape: ``(n_rows, n_cols)`` of the matrix.

        Returns:
            A new buffer. Column indices within each row are NOT sorted; call
            :meth:`sort_indices` if ordered indices are required.
        """
        nnz = len(d)
        # Each index array gets the narrowest unsigned dtype able to hold its largest value.
        indptr: _CSRIdxArray = np.zeros((shape[0] + 1), dtype=smallest_uint_dtype(nnz))
        indices: _CSRIdxArray = np.empty((nnz,), dtype=smallest_uint_dtype(shape[1]))
        data = np.empty((nnz,), dtype=d.dtype)
        # indptr must start zero-filled: the conversion kernel accumulates row counts into it.
        _coo_to_csr_inner(shape[0], i, j, d, indptr, indices, data)
        return CSR_IO_Buffer(indptr, indices, data, shape)

    @staticmethod
    def from_pjd(
        p: _CSRIdxArray, j: _CSRIdxArray, d: NDArrayNumber, shape: Tuple[int, int]
    ) -> "CSR_IO_Buffer":
        """Build a |CSR_IO_Buffer| from a CSR sparse matrix representation.

        Args:
            p: Row pointer array (``indptr``).
            j: Column index of each stored element (``indices``).
            d: Value of each stored element (``data``).
            shape: ``(n_rows, n_cols)`` of the matrix.

        The arrays are adopted without copying.
        """
        return CSR_IO_Buffer(p, j, d, shape)

    @property
    def nnz(self) -> int:
        """Number of nonzero elements."""
        return len(self.indices)

    @property
    def nbytes(self) -> int:
        """Total bytes used by ``indptr``, ``indices``, and ``data`` arrays."""
        return int(self.indptr.nbytes + self.indices.nbytes + self.data.nbytes)

    @property
    def dtype(self) -> npt.DTypeLike:
        """Underlying Numpy dtype."""
        return self.data.dtype

    def slice_tonumpy(self, row_index: slice) -> NDArrayNumber:
        """Extract slice as a dense ndarray.

        Does not assume any particular ordering of minor axis.

        Args:
            row_index: Row slice. The step must be 1 or ``None``.

        Returns:
            A new dense array of shape ``(n_selected_rows, shape[1])`` with the buffer's dtype.
            Positions with no stored element are zero. If a row holds duplicate column indices,
            the last stored value wins; values are not summed.
        """
        assert isinstance(row_index, slice)
        assert row_index.step in (1, None)
        row_idx_start, row_idx_end, _ = row_index.indices(self.indptr.shape[0] - 1)
        n_rows = max(row_idx_end - row_idx_start, 0)
        out = np.zeros((n_rows, self.shape[1]), dtype=self.data.dtype)
        # An empty selection leaves ``out`` all-zero; skip the kernel call entirely.
        if n_rows > 0:
            _csr_to_dense_inner(
                row_idx_start, n_rows, self.indptr, self.indices, self.data, out
            )
        return out

    def slice_toscipy(self, row_index: slice) -> sparse.csr_matrix:
        """Extract slice as a ``sparse.csr_matrix``.

        Does not assume any particular ordering of minor axis, but will return a canonically
        ordered scipy sparse object (column indices sorted within each row).

        Args:
            row_index: Row slice. The step must be 1 or ``None``.

        Returns:
            A new ``csr_matrix`` of shape ``(n_selected_rows, shape[1])``. The index and value
            arrays are copies, so the result does not alias this buffer.
        """
        assert isinstance(row_index, slice)
        assert row_index.step in (1, None)
        row_idx_start, row_idx_end, _ = row_index.indices(self.indptr.shape[0] - 1)
        n_rows = max(row_idx_end - row_idx_start, 0)
        if n_rows == 0:
            return sparse.csr_matrix((0, self.shape[1]), dtype=self.dtype)
        indptr = self.indptr[row_idx_start : row_idx_end + 1].copy()
        indices = self.indices[indptr[0] : indptr[-1]].copy()
        data = self.data[indptr[0] : indptr[-1]].copy()
        # Rebase the row pointers so the slice starts at offset 0.
        indptr -= indptr[0]
        result = sparse.csr_matrix((data, indices, indptr), shape=(n_rows, self.shape[1]))
        # The (data, indices, indptr) constructor does not reorder indices. Sort the copied
        # arrays in place to honour the documented ordering. If they are already sorted,
        # scipy only pays an O(nnz) check.
        result.sort_indices()
        return result

    @staticmethod
    def merge(mtxs: Sequence["CSR_IO_Buffer"]) -> "CSR_IO_Buffer":
        r"""Merge |CSR_IO_Buffer|\ s.

        All inputs must share the same ``shape`` and ``indices`` dtype. Row ``r`` of the result
        contains the stored elements of row ``r`` from every input, concatenated in the order of
        ``mtxs``. Overlapping coordinates are kept as separate entries, not summed. The result's
        ``data`` dtype is that of ``mtxs[0]``.

        Raises:
            AssertionError: If ``mtxs`` is empty or the inputs are incompatible (only when
                assertions are enabled).
        """
        assert len(mtxs) > 0
        shape = mtxs[0].shape
        for m in mtxs[1:]:
            assert m.shape == shape
            assert m.indices.dtype == mtxs[0].indices.dtype
        nnz = sum(m.nnz for m in mtxs)
        # Merged row r starts after all elements of rows < r in every input. That offset is
        # the element-wise sum of the inputs' row pointers.
        indptr = np.sum(
            [m.indptr for m in mtxs], axis=0, dtype=smallest_uint_dtype(nnz)
        )
        indices = np.empty((nnz,), dtype=mtxs[0].indices.dtype)
        data = np.empty((nnz,), mtxs[0].data.dtype)
        # The kernel needs homogeneous pointer dtypes. copy=False avoids a needless copy
        # when an input's indptr already has the merged dtype; the kernel only reads it.
        _csr_merge_inner(
            tuple(
                (m.indptr.astype(indptr.dtype, copy=False), m.indices, m.data)
                for m in mtxs
            ),
            indptr,
            indices,
            data,
        )
        return CSR_IO_Buffer.from_pjd(indptr, indices, data, shape)

    def sort_indices(self) -> Self:
        """Sort indices (in place).

        Column indices are sorted within each row, and the values are permuted alongside them.

        Returns:
            ``self``, to allow call chaining.
        """
        _csr_sort_indices(self.indptr, self.indices, self.data)
        return self
def smallest_uint_dtype(max_val: int) -> Type[np.unsignedinteger[Any]]:
    """Return the narrowest unsigned integer dtype whose range includes ``max_val``.

    ``uint16`` is tried first, then ``uint32``. Any value too large for both falls back
    to ``uint64``.
    """
    narrow_to_wide: List[Type[np.unsignedinteger[Any]]] = [np.uint16, np.uint32]
    fitting = (candidate for candidate in narrow_to_wide if max_val <= np.iinfo(candidate).max)
    return next(fitting, np.uint64)
@numba.njit(nogil=True, parallel=True)  # type: ignore[misc]
def _csr_merge_inner(
    As: Tuple[Tuple[_CSRIdxArray, _CSRIdxArray, NDArrayNumber], ...],  # P,J,D
    Bp: _CSRIdxArray,
    Bj: _CSRIdxArray,
    Bd: NDArrayNumber,
) -> None:
    """Copy the rows of every input CSR in ``As`` into the preallocated ``Bj``/``Bd``.

    ``Bp`` is expected to already hold the merged row pointers. It is read but not
    modified, because a copy serves as the per-row write cursor. Within each output row,
    elements are written input by input, in the order of ``As``.
    """
    n_rows = len(Bp) - 1
    # Per-row write cursor into Bj/Bd, initialised to each output row's start.
    offsets = Bp.copy()
    for Ap, Aj, Ad in As:
        # Number of stored elements in each row of this input.
        n_elmts = Ap[1:] - Ap[:-1]
        # Each row writes a distinct output range (given a consistent Bp), so rows
        # can be copied in parallel.
        for n in numba.prange(n_rows):
            Bj[offsets[n] : offsets[n] + n_elmts[n]] = Aj[Ap[n] : Ap[n] + n_elmts[n]]
            Bd[offsets[n] : offsets[n] + n_elmts[n]] = Ad[Ap[n] : Ap[n] + n_elmts[n]]
        # Advance every row's cursor past the elements just written from this input.
        offsets[:-1] += n_elmts


@numba.njit(nogil=True, parallel=True)  # type: ignore[misc]
def _csr_to_dense_inner(
    row_idx_start: int,
    n_rows: int,
    indptr: _CSRIdxArray,
    indices: _CSRIdxArray,
    data: NDArrayNumber,
    out: NDArrayNumber,
) -> None:
    """Scatter CSR rows ``[row_idx_start, row_idx_start + n_rows)`` into ``out`` in place.

    CSR row ``i`` is written to ``out[i - row_idx_start]``. Only stored positions are
    assigned; every other position of ``out`` is left untouched (the caller in this file
    passes a zero-filled array). Duplicate column indices within a row overwrite each
    other, so the last stored value wins.
    """
    for i in numba.prange(row_idx_start, row_idx_start + n_rows):
        for j in range(indptr[i], indptr[i + 1]):
            out[i - row_idx_start, indices[j]] = data[j]


@numba.njit(nogil=True, parallel=True, inline="always")  # type: ignore[misc]
def _count_rows(n_rows: int, Ai: NDArrayNumber, Bp: NDArrayNumber) -> NDArrayNumber:
    """Private: parallel row count.

    Stores the number of occurrences of each row index ``r`` in ``Ai`` into ``Bp[r]``
    and returns ``Bp``. ``Bp[n_rows]`` is not written.

    There are two paths:

    * At most ``partition_size`` elements: ``Bp`` is incremented in place, so it must
      start zero-filled.
    * Larger inputs: ``Ai`` is split into partitions of ``partition_size`` elements.
      Each partition counts into its own row of a scratch array, so threads never write
      the same slot. The partial counts are then summed into ``Bp[:-1]``.
    """
    nnz = len(Ai)
    partition_size = 32 * 1024**2
    n_partitions = ceil(nnz / partition_size)
    if n_partitions > 1:
        # One private count vector per partition avoids races between threads.
        counts = np.zeros((n_partitions, n_rows), dtype=Bp.dtype)
        for p in numba.prange(n_partitions):
            for n in range(p * partition_size, min(nnz, (p + 1) * partition_size)):
                row = Ai[n]
                counts[p, row] += 1
        Bp[:-1] = counts.sum(axis=0)
    else:
        for n in range(nnz):
            row = Ai[n]
            Bp[row] += 1
    return Bp


@numba.njit(nogil=True, parallel=True)  # type: ignore[misc]
def _coo_to_csr_inner(
    n_rows: int,
    Ai: _CSRIdxArray,
    Aj: _CSRIdxArray,
    Ad: NDArrayNumber,
    Bp: _CSRIdxArray,
    Bj: _CSRIdxArray,
    Bd: NDArrayNumber,
) -> None:
    """Convert COO (``Ai``, ``Aj``, ``Ad``) to CSR, writing into preallocated outputs.

    Requirements on the outputs:

    * ``Bp`` (length ``n_rows + 1``) must be zero-filled on entry, because row counts
      are accumulated into it.
    * ``Bj`` and ``Bd`` must have length ``len(Ai)``.

    Within each row, elements keep their COO input order; column indices are not sorted.
    """
    nnz = len(Ai)
    # Per-row element counts into Bp[0:n_rows].
    _count_rows(n_rows, Ai, Bp)

    # cum sum to get the row index pointers (NOTE: starting with zero)
    # This is an exclusive prefix sum: afterwards Bp[r] is the start offset of row r.
    cumsum = 0
    for n in range(n_rows):
        tmp = Bp[n]
        Bp[n] = cumsum
        cumsum += tmp
    Bp[n_rows] = nnz

    # Reorganize all the data. Side effect: pointers shifted (reversed in the
    # subsequent section).
    #
    # Method is concurrent (partitioned by rows) if number of rows is greater
    # than 2**partition_bits. This partitioning scheme leverages the fact
    # that reads are much cheaper than writes.
    #
    # The code is equivalent to:
    #   for n in range(nnz):
    #       row = Ai[n]
    #       dst_row = Bp[row]
    #       Bj[dst_row] = Aj[n]
    #       Bd[dst_row] = Ad[n]
    #       Bp[row] += 1
    #
    # (``dst_row`` is the destination element slot in Bj/Bd, not a row number.)
    partition_bits = 13
    n_partitions = (n_rows + 2**partition_bits - 1) >> partition_bits
    for p in numba.prange(n_partitions):
        # Every partition scans all of Ai, but only handles rows whose index shifted
        # right by partition_bits equals p. Each Bp/Bj/Bd slot therefore has exactly
        # one writer.
        for n in range(nnz):
            row = Ai[n]
            if (row >> partition_bits) != p:
                continue
            dst_row = Bp[row]
            Bj[dst_row] = Aj[n]
            Bd[dst_row] = Ad[n]
            Bp[row] += 1

    # Shift the pointers by one slot (i.e., start at zero)
    # After the scatter, Bp[r] holds the END of row r (the start of row r + 1), so
    # moving every value one slot to the right restores the row-start pointers.
    prev_ptr = 0
    for n in range(n_rows + 1):
        tmp = Bp[n]
        Bp[n] = prev_ptr
        prev_ptr = tmp


@numba.njit(nogil=True, parallel=True)  # type: ignore[misc]
def _csr_sort_indices(Bp: _CSRIdxArray, Bj: _CSRIdxArray, Bd: NDArrayNumber) -> None:
    """In-place sort of minor axis indices.

    Sorts ``Bj`` within each row ``[Bp[r], Bp[r + 1])`` and permutes ``Bd`` to match.
    ``np.argsort`` is called without ``kind=``, so the relative order of equal column
    indices within a row may not be preserved.
    """
    n_rows = len(Bp) - 1
    # Rows occupy disjoint ranges of Bj/Bd, so they can be sorted independently.
    for r in numba.prange(n_rows):
        row_start = Bp[r]
        row_end = Bp[r + 1]
        order = np.argsort(Bj[row_start:row_end])
        Bj[row_start:row_end] = Bj[row_start:row_end][order]
        Bd[row_start:row_end] = Bd[row_start:row_end][order]