rna_code.utils package

Submodules

rna_code.utils.dataset_merger module

Merge datasets

class rna_code.utils.dataset_merger.DatasetMerger

Bases: object

Merge pandas datasets using either feature intersection or union.

Returns:

New merged dataset

Return type:

pd.DataFrame

static intersect(dataset1: DataFrame, dataset2: DataFrame) DataFrame

Merge two datasets by keeping only the features they have in common. Features missing from either dataset are discarded.

Parameters:
  • dataset1 (pd.DataFrame) – First dataset

  • dataset2 (pd.DataFrame) – Second dataset

Returns:

Merged dataset

Return type:

pd.DataFrame
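As a rough illustration of the intersection behavior described above, the sketch below reproduces it with plain pandas. It assumes that features are stored as columns and that the rows (samples) of both datasets are stacked; the helper name intersect_sketch and the gene/sample labels are hypothetical, and this is not the package's implementation.

```python
import pandas as pd


def intersect_sketch(dataset1: pd.DataFrame, dataset2: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical stand-in: stack rows, keep only columns present in both."""
    # join="inner" keeps the intersection of the column labels.
    return pd.concat([dataset1, dataset2], join="inner")


# Toy data: only "geneB" is shared by the two datasets.
df1 = pd.DataFrame({"geneA": [1.0, 2.0], "geneB": [3.0, 4.0]}, index=["s1", "s2"])
df2 = pd.DataFrame({"geneB": [5.0], "geneC": [6.0]}, index=["s3"])

merged_intersect = intersect_sketch(df1, df2)
```

With this toy input, the merged frame keeps the three samples but only the shared geneB column.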

static union(dataset1: DataFrame, dataset2: DataFrame) DataFrame

Merge two datasets by taking the union of their features. Features missing from one dataset are zero-padded.

Parameters:
  • dataset1 (pd.DataFrame) – First dataset

  • dataset2 (pd.DataFrame) – Second dataset

Returns:

Merged dataset

Return type:

pd.DataFrame
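The union behavior can be approximated in the same way. This is a minimal sketch under the assumption that features are columns and samples are rows; union_sketch and the toy labels are hypothetical names, not the package's code.

```python
import pandas as pd


def union_sketch(dataset1: pd.DataFrame, dataset2: pd.DataFrame) -> pd.DataFrame:
    """Hypothetical stand-in: stack rows, keep every column, pad gaps with 0."""
    # join="outer" keeps the union of the column labels; features absent
    # from one dataset come out as NaN and are then replaced by 0.
    return pd.concat([dataset1, dataset2], join="outer").fillna(0)


# Toy data: "geneA" exists only in df1, "geneC" only in df2.
df1 = pd.DataFrame({"geneA": [1.0, 2.0], "geneB": [3.0, 4.0]}, index=["s1", "s2"])
df2 = pd.DataFrame({"geneB": [5.0], "geneC": [6.0]}, index=["s3"])

merged_union = union_sketch(df1, df2)
```

Note that zero padding makes a missing feature indistinguishable from a measured value of 0, which may matter for sparse expression data.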

rna_code.utils.experiment module

rna_code.utils.helpers module

rna_code.utils.helpers.encode_recon_dataset(dataloader, model, DEVICE)

Encodes and reconstructs a dataset using a provided model.

Parameters:
  • dataloader – The DataLoader containing the dataset to be processed.

  • model – The model to be used for encoding and reconstruction.

  • DEVICE – The device (CPU/GPU) on which the model is running.

Returns:

A tuple containing the encoded and reconstructed outputs of the dataset.

Return type:

Tuple
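The following sketch shows one way such an encode-and-reconstruct pass might look in PyTorch. The model interface is not documented here, so the sketch assumes a model exposing encoder and decoder submodules and a DataLoader yielding single-tensor batches; TinyAutoencoder and encode_recon_sketch are hypothetical names, not part of rna_code.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class TinyAutoencoder(nn.Module):
    """Hypothetical model exposing `encoder` and `decoder` submodules."""

    def __init__(self, n_features: int = 8, latent_dim: int = 2):
        super().__init__()
        self.encoder = nn.Linear(n_features, latent_dim)
        self.decoder = nn.Linear(latent_dim, n_features)

    def forward(self, x):
        return self.decoder(self.encoder(x))


def encode_recon_sketch(dataloader, model, device):
    """Hypothetical stand-in: collect latent codes and reconstructions per batch."""
    model = model.to(device)
    model.eval()
    encoded, reconstructed = [], []
    with torch.no_grad():  # inference only, no gradients needed
        for (batch,) in dataloader:
            batch = batch.to(device)
            latent = model.encoder(batch)
            encoded.append(latent.cpu())
            reconstructed.append(model.decoder(latent).cpu())
    # Concatenate the per-batch results along the sample dimension.
    return torch.cat(encoded), torch.cat(reconstructed)


device = torch.device("cpu")
loader = DataLoader(TensorDataset(torch.randn(10, 8)), batch_size=4)
encoded, reconstructed = encode_recon_sketch(loader, TinyAutoencoder(), device)
```

Here the encoded output has one row per sample in the latent dimension, and the reconstruction has the same shape as the input data.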

rna_code.utils.helpers.generate_config(static_params, dynamic_params)

Generate a list of configurations for ML experiments, combining static and dynamic parameters.

Parameters:
  • static_params (dict) – Parameters that stay the same for each configuration.

  • dynamic_params (dict) – Dictionary of parameter names to lists of possible values. This includes both single-value parameters and coupled parameters (as tuples).

Returns:

A list of configuration dictionaries.

Return type:

list
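To make the combination of static and dynamic parameters concrete, here is a minimal sketch built on a Cartesian product. It assumes that coupled parameters are given as a tuple key mapped to a list of value tuples; that convention, the helper name generate_config_sketch, and the parameter names (n_epochs, lr, latent_dim, dropout) are illustrative assumptions rather than the package's actual behavior.

```python
from itertools import product


def generate_config_sketch(static_params: dict, dynamic_params: dict) -> list:
    """Hypothetical re-implementation: Cartesian product over dynamic values."""
    keys = list(dynamic_params)
    configs = []
    for combo in product(*(dynamic_params[k] for k in keys)):
        config = dict(static_params)  # copy so each configuration is independent
        for key, value in zip(keys, combo):
            if isinstance(key, tuple):
                # Assumed convention: a tuple key couples several parameters,
                # and each value is a tuple with one entry per coupled name.
                config.update(zip(key, value))
            else:
                config[key] = value
        configs.append(config)
    return configs


configs = generate_config_sketch(
    {"n_epochs": 10},
    {"lr": [1e-3, 1e-4], ("latent_dim", "dropout"): [(16, 0.1), (32, 0.2)]},
)
```

Under these assumptions, two learning rates combined with two coupled (latent_dim, dropout) pairs give four configurations, each carrying the static n_epochs value.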

rna_code.utils.monitor_callback module

class rna_code.utils.monitor_callback.MetricsComputer

Bases: object

static compute_hopkins(X: ndarray) float
static compute_metrics(encoded_data: ndarray, true_labels: List[int], n_clusters: int) Dict[str, float]
class rna_code.utils.monitor_callback.MonitorCallback(dataloader: DataLoader, labels: List[int], n_clusters: int, evaluation_intervals: List[int] | None = None, compute_on: str = 'epoch', verbose: int = 0)

Bases: Callback

get_encoded_data(trainer, pl_module)
on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

Called when the train batch ends.

Note

The value outputs["loss"] here is the loss returned from training_step, normalized with respect to accumulate_grad_batches.

on_train_epoch_end(trainer, pl_module)

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the pytorch_lightning.core.LightningModule and access them in this hook:

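import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        # cache for the per-step losses of the current epoch
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss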


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

rna_code.utils.search module

rna_code.utils.transfert_leanring_experiment module

rna_code.utils.visualization module

rna_code.utils.visualization.dataset_plot(data)

Plots the entire dataset as a heatmap and provides a density plot of the total gene expression.

Parameters:

data (numpy.ndarray) – Dataset to be visualized.

Displays:

A figure with two subplots: a heatmap of gene expression across cells and a KDE plot showing the density of total gene expression.

rna_code.utils.visualization.post_training_animation(monitor, metadata)

Creates a smooth animation showing the evolution of PCA results over training epochs using Plotly.

Parameters:
  • monitor (Monitor) – Monitor object containing PCA results for each epoch.

  • metadata (dict) – Metadata containing labels for the data points.

rna_code.utils.visualization.post_training_viz(data, dataloader, model, DEVICE, loss_hist, labels)

Generates a 2x3 grid of visualizations after model training, including PCA, heatmaps, and loss plots.

Parameters:
  • data (numpy.ndarray) – Original dataset.

  • dataloader (DataLoader) – DataLoader object for the dataset.

  • model (torch.nn.Module) – Trained model for encoding and reconstruction.

  • DEVICE (torch.device) – Device on which the model is running.

  • loss_hist (list) – History of training loss values.

  • labels (list) – Labels for data points, used in PCA scatter plot.

Displays:

A figure with six subplots arranged in two rows, including a training loss plot, heatmaps of the original dataset, the encoded space, and the reconstruction, and a PCA scatter plot.

Module contents