Source code for tiledbsoma_ml.scvi

from __future__ import annotations

import os
from enum import Enum
from typing import Any, Sequence

import pandas as pd
import torch
from lightning import LightningDataModule
from sklearn.preprocessing import LabelEncoder
from tiledbsoma import ExperimentAxisQuery
from torch.utils.data import DataLoader

from tiledbsoma_ml import ExperimentDataset, experiment_dataloader
from tiledbsoma_ml._common import MiniBatch
from tiledbsoma_ml._query_ids import QueryIDs
from tiledbsoma_ml.x_locator import XLocator

# Baseline DataLoader settings for this module. `SCVIDataModule` merges any
# caller-supplied `dataloader_kwargs` on top of these, so each can be overridden.
DEFAULT_DATALOADER_KWARGS: dict[str, Any] = dict(
    # Enabled only when a CUDA device is available.
    pin_memory=torch.cuda.is_available(),
    persistent_workers=True,
    # Half of the reported CPUs (treating an unknown count as 1), never below 1.
    num_workers=max(1, (os.cpu_count() or 1) // 2),
)


class DatasetSplit(Enum):
    """Identifies which partition of the data a dataloader should serve."""

    # Training partition.
    TRAIN = "train"
    # Validation partition.
    VAL = "val"

class SCVIDataModule(LightningDataModule):  # type: ignore[misc]
    """PyTorch Lightning DataModule for training scVI models from SOMA data.

    Wraps a |ExperimentDataset| to stream the results of a SOMA |ExperimentAxisQuery|, exposing a
    |DataLoader| to generate tensors ready for scVI model training. Also handles deriving the scVI
    batch label as a tuple of obs columns.

    Lifecycle:
        Experimental.
    """

    def __init__(
        self,
        query: ExperimentAxisQuery,
        *args: Any,
        batch_column_names: Sequence[str] | None = None,
        batch_labels: Sequence[str] | None = None,
        dataloader_kwargs: dict[str, Any] | None = None,
        train_size: float = 1.0,
        seed: int = 42,
        **kwargs: Any,
    ):
        """Configure the module and fit the scVI batch-label encoder.

        Args:
            query: |ExperimentAxisQuery|
                Defines the desired result set from a SOMA |Experiment|.
            *args, **kwargs:
                Additional arguments passed through to |ExperimentDataset|. A ``layer_name``
                keyword, if present, selects the X layer (default ``"raw"``).
            batch_column_names: Sequence[str], optional
                List of obs column names, the tuple of which defines the scVI batch label (not to
                be confused with a batch of training data). Defaults to
                `["dataset_id", "assay", "suspension_type", "donor_id"]`.
            batch_labels: Sequence[str], optional
                List of possible values of the batch label, for mapping to label tensors. By
                default, this will be derived from the unique labels in the given query results
                (given `batch_column_names`), making the label mapping depend on the query. The
                `batch_labels` attribute in the `SCVIDataModule` used for training may be saved and
                here restored in another instance for a different query. That ensures the label
                mapping will be correct for the trained model, even if the second query doesn't
                return examples of every training batch label.
            dataloader_kwargs: dict, optional
                Keyword arguments passed to `tiledbsoma_ml.experiment_dataloader()`, e.g.
                `num_workers`. These are merged over `DEFAULT_DATALOADER_KWARGS`.
            train_size: float, optional
                Fraction of data to use for training (between 0 and 1). Default is 1.0 (use all
                data for training). If less than 1.0, the remaining data will be used for
                validation.
            seed: int, optional
                Random seed for deterministic train/validation split. Default is 42.
        """
        super().__init__()
        self.query = query
        self.dataset_args = args
        self.dataset_kwargs = kwargs

        user_dataloader_kwargs = dict(dataloader_kwargs or {})
        merged_dataloader_kwargs = {
            **DEFAULT_DATALOADER_KWARGS,
            **user_dataloader_kwargs,
        }
        # torch's DataLoader raises ValueError for persistent_workers=True with
        # num_workers=0. If the caller asked for zero workers without saying anything
        # about persistence, don't let our default turn that into an error.
        # NOTE(review): assumes experiment_dataloader forwards these kwargs to DataLoader.
        if (
            merged_dataloader_kwargs.get("num_workers") == 0
            and "persistent_workers" not in user_dataloader_kwargs
        ):
            merged_dataloader_kwargs["persistent_workers"] = False
        self.dataloader_kwargs = merged_dataloader_kwargs

        self.batch_column_names = (
            batch_column_names
            if batch_column_names is not None
            else ["dataset_id", "assay", "suspension_type", "donor_id"]
        )
        self.batch_colsep = "//"
        self.batch_colname = "scvi_batch"

        # prepare LabelEncoder for the scVI batch label:
        #   1. read obs DataFrame for the whole query result set
        #   2. add scvi_batch column
        #   3. fit LabelEncoder to the scvi_batch column's unique values
        if batch_labels is None:
            obs_df = (
                self.query.obs(column_names=self.batch_column_names)
                .concat()
                .to_pandas()
            )
            self._add_batch_col(obs_df, inplace=True)
            batch_labels = obs_df[self.batch_colname].unique()
        self.batch_labels = batch_labels
        self.batch_encoder = LabelEncoder().fit(self.batch_labels)

        self.train_size = train_size
        self.seed = seed

        # Populated by setup(); None until then.
        self.train_query_ids: QueryIDs | None = None
        self.val_query_ids: QueryIDs | None = None
        self.x_locator: XLocator | None = None
        self.layer_name = kwargs.get("layer_name", "raw")

    def setup(self, stage: str | None = None) -> None:
        """Build the XLocator and the train/validation QueryIDs from the query.

        Args:
            stage: Accepted for the Lightning hook signature; not used here.
        """
        # Create QueryIDs and XLocator from the query
        query_ids = QueryIDs.create(self.query)
        self.x_locator = XLocator.create(
            self.query.experiment,
            measurement_name=self.query.measurement_name,
            layer_name=self.layer_name,
        )

        # Split data into train and validation sets if train_size < 1.0
        if self.train_size < 1.0:
            # Use QueryIDs.random_split() for efficient splitting
            val_size = 1.0 - self.train_size
            train_ids, val_ids = query_ids.random_split(
                self.train_size, val_size, seed=self.seed
            )
            self.train_query_ids = train_ids
            self.val_query_ids = val_ids
        else:
            # Use all data for training
            self.train_query_ids = query_ids
            self.val_query_ids = None

    def _create_dataloader(self, split: DatasetSplit) -> DataLoader | None:
        """Create a dataloader for the specified dataset split.

        Args:
            split: The dataset split (TRAIN or VAL)

        Returns:
            DataLoader for the specified split, or None if the split doesn't exist
            (including when setup() has not been called yet).
        """
        # Get the appropriate query_ids based on split
        query_ids_map = {
            DatasetSplit.TRAIN: self.train_query_ids,
            DatasetSplit.VAL: self.val_query_ids,
        }
        query_ids = query_ids_map.get(split)

        if query_ids is None or self.x_locator is None:
            return None

        # Filter out query and layer_name from dataset_kwargs since we're using x_locator and query_ids
        filtered_kwargs = {
            k: v
            for k, v in self.dataset_kwargs.items()
            if k not in ("query", "layer_name")
        }

        # Create dataset with appropriate query_ids
        dataset = ExperimentDataset(
            x_locator=self.x_locator,
            query_ids=query_ids,
            obs_column_names=list(self.batch_column_names),
            **filtered_kwargs,
        )

        return experiment_dataloader(
            dataset,
            **self.dataloader_kwargs,
        )

    def train_dataloader(self) -> DataLoader:
        """Create the training dataloader.

        Returns:
            DataLoader for training data

        Raises:
            AssertionError: If setup() hasn't been called
        """
        loader = self._create_dataloader(DatasetSplit.TRAIN)
        if loader is None:
            # Raised explicitly rather than via `assert` so the check is not stripped
            # under `python -O`; the exception type is unchanged.
            raise AssertionError("setup() must be called before train_dataloader()")
        return loader

    def val_dataloader(self) -> DataLoader | None:
        """Create the validation dataloader.

        Returns:
            DataLoader for validation data, or None if no validation split exists
        """
        return self._create_dataloader(DatasetSplit.VAL)

    def _add_batch_col(
        self, obs_df: pd.DataFrame, inplace: bool = False
    ) -> pd.DataFrame:
        """Add the synthesized scVI batch column to ``obs_df``.

        The new column (named ``self.batch_colname``) is the string form of each
        ``self.batch_column_names`` column, joined with ``self.batch_colsep``.

        Args:
            obs_df: DataFrame containing every column in ``self.batch_column_names``.
            inplace: If False (default), work on a copy and leave ``obs_df`` untouched.

        Returns:
            The DataFrame carrying the new column (``obs_df`` itself when ``inplace``).
        """
        # synthesize a new column for obs_df by concatenating the self.batch_column_names columns
        if not inplace:
            obs_df = obs_df.copy()
        # list(...) matters: pandas would treat a tuple as one column label rather
        # than a selection of several columns.
        obs_df[self.batch_colname] = (
            obs_df[list(self.batch_column_names)]
            .astype(str)
            .agg(self.batch_colsep.join, axis=1)
        )
        return obs_df

    def on_before_batch_transfer(
        self,
        batch: MiniBatch,
        dataloader_idx: int,
    ) -> dict[str, torch.Tensor | None]:
        """Convert an ExperimentDataset mini-batch into the tensor dict used for scVI.

        Note that the batch's obs DataFrame is modified in place (the scvi_batch
        column is added to it).

        Args:
            batch: Pair of (X ndarray, obs DataFrame).
            dataloader_idx: Part of the Lightning hook signature; not used here.

        Returns:
            Dict with ``"X"`` (float tensor of the X data), ``"batch"`` (encoded batch
            labels with a trailing dimension added via ``unsqueeze(1)``) and
            ``"labels"`` (an empty tensor).
        """
        # DataModule hook: transform the ExperimentDataset data batch (X: ndarray, obs_df: DataFrame)
        # into X & batch variable tensors for scVI (using batch_encoder on scvi_batch)
        batch_X, batch_obs = batch
        self._add_batch_col(batch_obs, inplace=True)
        return {
            "X": torch.from_numpy(batch_X).float(),
            "batch": torch.from_numpy(
                self.batch_encoder.transform(batch_obs[self.batch_colname])
            ).unsqueeze(1),
            "labels": torch.empty(0),
        }

    # scVI expects these properties on the DataModule:

    @property
    def n_obs(self) -> int:
        """Number of obs join IDs in the query (the whole query, not a single split)."""
        return len(self.query.obs_joinids())

    @property
    def n_vars(self) -> int:
        """Number of var join IDs in the query."""
        return len(self.query.var_joinids())

    @property
    def n_batch(self) -> int:
        """Number of distinct scVI batch labels known to the encoder."""
        return len(self.batch_encoder.classes_)