from __future__ import annotations
import os
from enum import Enum
from typing import Any, Sequence
import pandas as pd
import torch
from lightning import LightningDataModule
from sklearn.preprocessing import LabelEncoder
from tiledbsoma import ExperimentAxisQuery
from torch.utils.data import DataLoader
from tiledbsoma_ml import ExperimentDataset, experiment_dataloader
from tiledbsoma_ml._common import MiniBatch
from tiledbsoma_ml._query_ids import QueryIDs
from tiledbsoma_ml.x_locator import XLocator
# Baseline DataLoader options, overridden key-by-key by any user-supplied
# ``dataloader_kwargs``: pin host memory only when CUDA is available, keep
# worker processes alive between epochs, and use half of the CPUs (minimum 1).
DEFAULT_DATALOADER_KWARGS: dict[str, Any] = dict(
    pin_memory=torch.cuda.is_available(),
    persistent_workers=True,
    num_workers=max(1, (os.cpu_count() or 1) // 2),
)
class DatasetSplit(Enum):
    """Names the partition of the query results that a dataloader serves.

    ``TRAIN`` selects the training portion, ``VAL`` the held-out validation portion.
    """

    # Values are the lowercase split names.
    TRAIN = "train"
    VAL = "val"
class SCVIDataModule(LightningDataModule):  # type: ignore[misc]
    """PyTorch Lightning DataModule for training scVI models from SOMA data.

    Wraps a |ExperimentDataset| to stream the results of a SOMA |ExperimentAxisQuery|,
    exposing a |DataLoader| to generate tensors ready for scVI model training. Also handles deriving
    the scVI batch label as a tuple of obs columns.

    Lifecycle:
        Experimental.
    """

    def __init__(
        self,
        query: ExperimentAxisQuery,
        *args: Any,
        batch_column_names: Sequence[str] | None = None,
        batch_labels: Sequence[str] | None = None,
        dataloader_kwargs: dict[str, Any] | None = None,
        train_size: float = 1.0,
        seed: int = 42,
        **kwargs: Any,
    ):
        """Args:
        query: |ExperimentAxisQuery|
            Defines the desired result set from a SOMA |Experiment|.
        *args, **kwargs:
            Additional arguments for |ExperimentDataset|. Keyword arguments are forwarded to each
            dataset this module creates, except ``query`` and ``layer_name``: the module handles
            those itself (``layer_name``, default ``"raw"``, selects the X layer). Positional
            arguments are not forwarded, and passing any emits a warning.
        batch_column_names: Sequence[str], optional
            List of obs column names, the tuple of which defines the scVI batch label (not to be
            confused with a batch of training data). Defaults to
            ``["dataset_id", "assay", "suspension_type", "donor_id"]``.
        batch_labels: Sequence[str], optional
            List of possible values of the batch label, for mapping to label tensors. By default,
            this is derived from the unique labels in the given query results (given
            ``batch_column_names``), making the label mapping depend on the query. You can save the
            ``batch_labels`` attribute of the ``SCVIDataModule`` used for training and pass it here
            to another instance built for a different query. That keeps the label mapping correct
            for the trained model, even if the second query doesn't return examples of every
            training batch label.
        dataloader_kwargs: dict, optional
            Keyword arguments passed to ``tiledbsoma_ml.experiment_dataloader()``, e.g.
            ``num_workers``. They override ``DEFAULT_DATALOADER_KWARGS``. When the resulting
            ``num_workers`` is 0, ``persistent_workers`` defaults to ``False`` (torch's DataLoader
            rejects persistent workers without worker processes) unless you set it explicitly.
        train_size: float, optional
            Fraction of data to use for training, in the interval (0, 1]. Default is 1.0 (use all
            data for training). If less than 1.0, the remaining data is used for validation.
        seed: int, optional
            Random seed for the deterministic train/validation split. Default is 42.

        Raises:
            ValueError: If ``train_size`` is not in the interval (0, 1].
        """
        super().__init__()
        # Validate before the potentially expensive full-obs read below.
        if not 0.0 < train_size <= 1.0:
            raise ValueError(
                f"train_size must be in the interval (0, 1], got {train_size!r}"
            )
        if args:
            import warnings

            # The datasets are built with keyword arguments only (see _create_dataloader), so
            # positional extras would otherwise be dropped silently.
            warnings.warn(
                "SCVIDataModule does not forward positional arguments to ExperimentDataset; "
                "pass them as keyword arguments instead.",
                stacklevel=2,
            )
        self.query = query
        self.dataset_args = args
        self.dataset_kwargs = kwargs
        user_loader_kwargs = dict(dataloader_kwargs or {})
        self.dataloader_kwargs = {
            **DEFAULT_DATALOADER_KWARGS,
            **user_loader_kwargs,
        }
        # torch's DataLoader raises if persistent_workers=True while num_workers == 0; only
        # keep the default when worker processes are actually used or the user asked for it.
        if (
            not self.dataloader_kwargs.get("num_workers")
            and "persistent_workers" not in user_loader_kwargs
        ):
            self.dataloader_kwargs["persistent_workers"] = False
        self.batch_column_names = (
            batch_column_names
            if batch_column_names is not None
            else ["dataset_id", "assay", "suspension_type", "donor_id"]
        )
        self.batch_colsep = "//"
        self.batch_colname = "scvi_batch"
        # prepare LabelEncoder for the scVI batch label:
        #   1. read obs DataFrame for the whole query result set
        #   2. add scvi_batch column
        #   3. fit LabelEncoder to the scvi_batch column's unique values
        if batch_labels is None:
            obs_df = (
                self.query.obs(column_names=self.batch_column_names)
                .concat()
                .to_pandas()
            )
            self._add_batch_col(obs_df, inplace=True)
            batch_labels = obs_df[self.batch_colname].unique()
        self.batch_labels = batch_labels
        self.batch_encoder = LabelEncoder().fit(self.batch_labels)
        self.train_size = train_size
        self.seed = seed
        # Populated by setup().
        self.train_query_ids: QueryIDs | None = None
        self.val_query_ids: QueryIDs | None = None
        self.x_locator: XLocator | None = None
        self.layer_name = kwargs.get("layer_name", "raw")

    def setup(self, stage: str | None = None) -> None:
        """Build the X locator and the train/validation ID sets for the query.

        Args:
            stage: Lightning stage name; unused, and the same setup is done for every stage.
        """
        # Create QueryIDs and XLocator from the query
        query_ids = QueryIDs.create(self.query)
        self.x_locator = XLocator.create(
            self.query.experiment,
            measurement_name=self.query.measurement_name,
            layer_name=self.layer_name,
        )
        if self.train_size < 1.0:
            # Seeded split, so repeated setup() calls reproduce the same partition.
            val_size = 1.0 - self.train_size
            train_ids, val_ids = query_ids.random_split(
                self.train_size, val_size, seed=self.seed
            )
            self.train_query_ids = train_ids
            self.val_query_ids = val_ids
        else:
            # Use all data for training
            self.train_query_ids = query_ids
            self.val_query_ids = None

    def _create_dataloader(self, split: DatasetSplit) -> DataLoader | None:
        """Create a dataloader for the specified dataset split.

        Args:
            split: The dataset split (TRAIN or VAL).

        Returns:
            DataLoader for the specified split, or None if the split doesn't exist or
            setup() has not been called yet.
        """
        query_ids_map = {
            DatasetSplit.TRAIN: self.train_query_ids,
            DatasetSplit.VAL: self.val_query_ids,
        }
        query_ids = query_ids_map.get(split)
        if query_ids is None or self.x_locator is None:
            return None
        # query/layer_name are superseded by x_locator and query_ids built in setup().
        filtered_kwargs = {
            k: v
            for k, v in self.dataset_kwargs.items()
            if k not in ("query", "layer_name")
        }
        dataset = ExperimentDataset(
            x_locator=self.x_locator,
            query_ids=query_ids,
            obs_column_names=list(self.batch_column_names),
            **filtered_kwargs,
        )
        return experiment_dataloader(
            dataset,
            **self.dataloader_kwargs,
        )

    def train_dataloader(self) -> DataLoader:
        """Create the training dataloader.

        Returns:
            DataLoader for training data.

        Raises:
            AssertionError: If setup() hasn't been called.
        """
        loader = self._create_dataloader(DatasetSplit.TRAIN)
        assert loader is not None, "setup() must be called before train_dataloader()"
        return loader

    def val_dataloader(self) -> DataLoader | None:
        """Create the validation dataloader.

        Returns:
            DataLoader for validation data, or None if no validation split exists.
        """
        return self._create_dataloader(DatasetSplit.VAL)

    def _add_batch_col(
        self, obs_df: pd.DataFrame, inplace: bool = False
    ) -> pd.DataFrame:
        """Add the synthesized scVI batch-label column to ``obs_df``.

        The label is the string form of the ``self.batch_column_names`` columns, joined
        with ``self.batch_colsep``, stored in column ``self.batch_colname``.

        Args:
            obs_df: DataFrame containing (at least) the ``batch_column_names`` columns.
            inplace: If False (default), work on a copy and leave ``obs_df`` untouched.

        Returns:
            The DataFrame carrying the new column (``obs_df`` itself when ``inplace=True``).
        """
        if not inplace:
            obs_df = obs_df.copy()
        # list(...) so a tuple of names selects several columns instead of being treated
        # by pandas as a single (MultiIndex-style) column key.
        as_str = obs_df[list(self.batch_column_names)].astype(str)
        # Column-wise concatenation instead of a per-row agg(join, axis=1): it avoids building
        # a Series per row, and still produces a Series (not a DataFrame) for zero-row frames.
        parts = [col for _, col in as_str.items()]
        if parts:
            labels = parts[0]
            for part in parts[1:]:
                labels = labels + self.batch_colsep + part
        else:
            # No batch columns: every row gets the empty label (a join over nothing).
            labels = pd.Series("", index=obs_df.index, dtype=object)
        obs_df[self.batch_colname] = labels
        return obs_df

    def on_before_batch_transfer(
        self,
        batch: MiniBatch,
        dataloader_idx: int,
    ) -> dict[str, torch.Tensor | None]:
        """Convert an ExperimentDataset mini-batch into the tensor dict scVI consumes.

        Args:
            batch: ``(X, obs_df)`` pair; ``X`` must be a NumPy array (it goes through
                ``torch.from_numpy``), and ``obs_df`` must hold the batch columns. ``obs_df``
                gets the scvi_batch column added in place.
            dataloader_idx: Index of the originating dataloader; unused.

        Returns:
            ``{"X": float tensor, "batch": encoded labels with a trailing unit dim,
            "labels": empty tensor}``.
        """
        batch_X, batch_obs = batch
        self._add_batch_col(batch_obs, inplace=True)
        return {
            "X": torch.from_numpy(batch_X).float(),
            "batch": torch.from_numpy(
                self.batch_encoder.transform(batch_obs[self.batch_colname])
            ).unsqueeze(1),
            "labels": torch.empty(0),
        }

    # scVI expects these properties on the DataModule:
    @property
    def n_obs(self) -> int:
        """Number of obs (cells) in the full query result, across all splits."""
        return len(self.query.obs_joinids())

    @property
    def n_vars(self) -> int:
        """Number of vars (features) in the query result."""
        return len(self.query.var_joinids())

    @property
    def n_batch(self) -> int:
        """Number of distinct scVI batch labels known to the encoder."""
        return len(self.batch_encoder.classes_)