Datasets

The PrunaDataModule class loads and processes data for your model. You can use it directly, pair your own datasets with one of the collate functions, or add data through the SmashConfig.

Usage examples

Usage example of PrunaDataModule

from pruna.data import PrunaDataModule

datamodule = PrunaDataModule.from_datasets(
    datasets=["dataset1", "dataset2"],
    collate_fn="image_generation_collate",  # expects the datasets in a specific format
)

Usage example of collate functions

from pruna import SmashConfig
from pruna.data.utils import split_train_into_train_val_test
from datasets import load_dataset

# Load custom datasets
train_ds = load_dataset("SamuelYang/bookcorpus")["train"]
train_ds, val_ds, test_ds = split_train_into_train_val_test(train_ds, seed=42)

# Add to SmashConfig
smash_config = SmashConfig()
smash_config.add_tokenizer("bert-base-uncased")
smash_config.add_data(
    (train_ds, val_ds, test_ds),
    collate_fn="text_generation_collate"
)
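The example above relies on split_train_into_train_val_test to derive three splits from a single training set with a fixed seed. The following is a minimal, library-free sketch of that idea, not Pruna's implementation: the helper name split_three_ways and the 80/10/10 ratio are assumptions made for illustration only.

```python
import random


def split_three_ways(rows, seed, val_frac=0.1, test_frac=0.1):
    """Shuffle indices with a fixed seed and cut them into train/val/test.

    Hypothetical helper; the fractions are illustrative assumptions.
    """
    indices = list(range(len(rows)))
    random.Random(seed).shuffle(indices)  # same seed -> same permutation

    n_val = int(len(rows) * val_frac)
    n_test = int(len(rows) * test_frac)
    n_train = len(rows) - n_val - n_test

    train = [rows[i] for i in indices[:n_train]]
    val = [rows[i] for i in indices[n_train:n_train + n_val]]
    test = [rows[i] for i in indices[n_train + n_val:]]
    return train, val, test


rows = [{"text": f"sample {i}"} for i in range(20)]
train_rows, val_rows, test_rows = split_three_ways(rows, seed=42)
```

Because the shuffle is driven by a seeded generator, calling the helper again with the same seed reproduces the same three splits.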

Usage example of SmashConfig

from pruna import SmashConfig

smash_config = SmashConfig()
smash_config.add_data("dataset1")

Function API collate

image_generation_collate(data, img_size, output_format='int')

Custom collation function for text-to-image generation datasets.

Expects an image column containing PIL images and a text column containing the clear-text prompt for the image generation in the dataset.

Parameters:
  • data (Any) – The data to collate.

  • img_size (int) – The size of the image to resize to.

  • output_format (str) – The output format, in [“int”, “float”, “normalized”]. With “int”, output tensors have integer values between 0 and 255. With “float”, they have float values between 0 and 1. With “normalized”, they have float values between -1 and 1.

Returns:

The collated data, resized to img_size, with values in the range selected by output_format.

Return type:

Tuple[torch.Tensor, Any]
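To make the three output_format ranges concrete, here is a small sketch that maps a single 8-bit pixel value into each range. It is an illustration only: the function name scale_pixel is hypothetical, and the exact formula used for "normalized" (scale to [0, 1], then stretch to [-1, 1]) is an assumption consistent with the ranges documented above.

```python
def scale_pixel(value, output_format="int"):
    """Map one 8-bit pixel value (0..255) into the requested range.

    Hypothetical helper; the "normalized" formula is an assumption.
    """
    if output_format == "int":
        return int(value)                  # stays in 0..255
    if output_format == "float":
        return value / 255.0               # 0.0..1.0
    if output_format == "normalized":
        return value / 255.0 * 2.0 - 1.0   # -1.0..1.0
    raise ValueError(f"unknown output_format: {output_format!r}")


white_as_int = scale_pixel(255, "int")
white_as_float = scale_pixel(255, "float")
black_normalized = scale_pixel(0, "normalized")
```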

audio_collate(data)

Custom collation function for audio datasets.

Expects an audio/path column containing the path to the audio samples and a sentence column containing the clear-text transcription of the audio samples in the dataset.

Parameters:

data (Any) – The data to collate.

Returns:

The collated data.

Return type:

List[str]

image_classification_collate(data, img_size, output_format='int')

Custom collation function for image classification datasets.

Expects an image column containing PIL images and a label column containing the class label in the dataset.

Parameters:
  • data (Any) – The data to collate.

  • img_size (int) – The size of the image to resize to.

  • output_format (str) – The output format, in [“int”, “float”, “normalized”]. With “int”, output tensors have integer values between 0 and 255. With “float”, they have float values between 0 and 1. With “normalized”, they have float values between -1 and 1.

Returns:

The collated data, resized to img_size, with values in the range selected by output_format.

Return type:

Tuple[List[str], torch.Tensor]

text_generation_collate(data, max_seq_len, tokenizer)

Custom collation function for text generation datasets.

Expects a text column containing clear-text samples in the dataset.

Parameters:
  • data (Any) – The data to collate.

  • max_seq_len (int | None) – The maximum sequence length.

  • tokenizer (AutoTokenizer) – The tokenizer to use.

Returns:

The collated data.

Return type:

Tuple[torch.Tensor, torch.Tensor]
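The sketch below illustrates what a text collate function of this shape might do: tokenize the text column, truncate or pad every sample to max_seq_len, and return a pair of sequences. It uses plain Python lists and a toy whitespace vocabulary instead of a real tokenizer and tensors. Returning inputs and targets shifted by one token is a common next-token-prediction convention and is an assumption here; this page does not state what the two returned tensors contain.

```python
def toy_text_collate(data, max_seq_len, vocab, pad_id=0):
    """Tokenize, truncate/pad to max_seq_len, and return (inputs, targets).

    Hypothetical stand-in for a tokenizer-based collate function.
    """
    batch = []
    for row in data:
        ids = [vocab[word] for word in row["text"].split()]
        ids = ids[:max_seq_len]                          # truncate long samples
        ids = ids + [pad_id] * (max_seq_len - len(ids))  # pad short samples
        batch.append(ids)
    inputs = [ids[:-1] for ids in batch]   # every token except the last
    targets = [ids[1:] for ids in batch]   # the same tokens shifted by one
    return inputs, targets


vocab = {"the": 1, "cat": 2, "sat": 3, "down": 4}
data = [{"text": "the cat sat down"}, {"text": "the cat"}]
inputs, targets = toy_text_collate(data, max_seq_len=4, vocab=vocab)
```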

question_answering_collate(data, max_seq_len, tokenizer)

Custom collation function for question answering datasets.

Expects question and answer columns containing the clear-text question and answer in the dataset.

Parameters:
  • data (Any) – The data to collate.

  • max_seq_len (int) – The maximum sequence length.

  • tokenizer (AutoTokenizer) – The tokenizer to use.

Returns:

The collated data.

Return type:

Tuple[torch.Tensor, torch.Tensor]

Class API PrunaDataModule

class PrunaDataModule(train_ds, val_ds, test_ds, collate_fn, dataloader_args)

A PrunaDataModule is a wrapper around a PyTorch Lightning DataModule that allows for easy loading of datasets.

Parameters:
  • train_ds (Union[IterableDataset, Dataset, TorchDataset]) – The training dataset.

  • val_ds (Union[IterableDataset, Dataset, TorchDataset]) – The validation dataset.

  • test_ds (Union[IterableDataset, Dataset, TorchDataset]) – The test dataset.

  • collate_fn (Callable) – The collate function to use.

  • dataloader_args (dict) – The arguments for the dataloader.

__init__(train_ds, val_ds, test_ds, collate_fn, dataloader_args)
Parameters:
  • train_ds (IterableDataset | Dataset | TorchDataset)

  • val_ds (IterableDataset | Dataset | TorchDataset)

  • test_ds (IterableDataset | Dataset | TorchDataset)

  • collate_fn (Callable)

  • dataloader_args (dict)

Return type:

None

prepare_data_per_node

If True, each LOCAL_RANK=0 will call prepare data. Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.

allow_zero_length_dataloader_with_multiple_devices

If True, dataloader with zero length within local rank is allowed. Default value is False.

classmethod from_datasets(datasets, collate_fn, tokenizer=None, collate_fn_args={}, dataloader_args={})

Create a PrunaDataModule from the individual datasets.

Parameters:
  • datasets (tuple | list) – The datasets.

  • collate_fn (str) – The Pruna collate function to use.

  • tokenizer (AutoTokenizer | None) – The tokenizer to use.

  • collate_fn_args (dict) – Any additional arguments for the collate function.

  • dataloader_args (dict) – Any additional arguments for the dataloader.

Returns:

The PrunaDataModule.

Return type:

PrunaDataModule
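Since from_datasets accepts the collate function as a string plus a dictionary of extra arguments, something has to turn that name into a callable with those arguments bound. One plausible way to do this is sketched below. The registry COLLATE_REGISTRY, the helper resolve_collate, and the toy collate function are all hypothetical names introduced for illustration; they are not part of Pruna's API.

```python
from functools import partial


def toy_image_collate(data, img_size):
    """Hypothetical collate: record the target size next to each caption."""
    return [(row["text"], img_size) for row in data]


# Hypothetical registry mapping collate names to callables.
COLLATE_REGISTRY = {"toy_image_collate": toy_image_collate}


def resolve_collate(name, collate_fn_args=None):
    """Look up a collate function by name and bind its extra arguments."""
    if name not in COLLATE_REGISTRY:
        raise KeyError(f"unknown collate function: {name!r}")
    return partial(COLLATE_REGISTRY[name], **(collate_fn_args or {}))


collate = resolve_collate("toy_image_collate", {"img_size": 512})
batch = collate([{"text": "a red bicycle"}])
```

Binding the arguments up front leaves a callable that takes only the batch, which is the shape a dataloader expects from a collate function.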

classmethod from_string(dataset_name, tokenizer=None, collate_fn_args={}, dataloader_args={}, seed=42)

Create a PrunaDataModule from the dataset name with preimplemented dataset loading.

Parameters:
  • dataset_name (str) – The name of the dataset.

  • tokenizer (AutoTokenizer | None) – The tokenizer to use.

  • collate_fn_args (dict) – Any additional arguments for the collate function.

  • dataloader_args (dict) – Any additional arguments for the dataloader.

  • seed (int) – The seed to use.

Returns:

The PrunaDataModule.

Return type:

PrunaDataModule

limit_datasets(limit)

Limit the dataset to the given number of samples.

Parameters:

limit (int | list[int] | tuple[int, int, int]) – The number of samples to limit the dataset to.

Return type:

None
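The limit parameter accepts either a single integer or a sequence of integers. A reasonable reading, assumed here rather than confirmed by this page, is that a single integer applies to every split while a sequence gives one limit per split in train, val, test order. The sketch below shows that interpretation on plain lists; limit_splits is a hypothetical helper.

```python
def limit_splits(splits, limit):
    """Truncate (train, val, test) splits to at most `limit` samples each.

    Hypothetical helper illustrating one interpretation of `limit`.
    """
    if isinstance(limit, int):
        limits = [limit] * len(splits)   # one limit shared by every split
    else:
        limits = list(limit)             # one limit per split
        if len(limits) != len(splits):
            raise ValueError("need exactly one limit per split")
    return tuple(split[:n] for split, n in zip(splits, limits))


splits = (list(range(100)), list(range(20)), list(range(20)))
same_for_all = limit_splits(splits, 10)
per_split = limit_splits(splits, [50, 5, 8])
```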

test_dataloader(**kwargs)

Return the test data loader.

Parameters:

**kwargs (dict) – Any additional arguments used when loading data, overriding the default values provided in the constructor. Examples: img_size: int would override the collate function default for image generation, while batch_size: int, shuffle: bool, pin_memory: bool, … would override the dataloader defaults.

Returns:

The test data loader.

Return type:

DataLoader
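The **kwargs description implies that one set of keyword arguments is divided between the collate function and the dataloader, with call-time values taking precedence over the constructor defaults. The sketch below shows one way such routing could work, by inspecting the collate function's signature. This mechanism is an assumption made for illustration; split_overrides and toy_collate are hypothetical names.

```python
import inspect


def split_overrides(collate_fn, kwargs):
    """Separate keyword overrides into collate and dataloader arguments."""
    collate_params = set(inspect.signature(collate_fn).parameters)
    collate_kwargs = {k: v for k, v in kwargs.items() if k in collate_params}
    loader_kwargs = {k: v for k, v in kwargs.items() if k not in collate_params}
    return collate_kwargs, loader_kwargs


def toy_collate(data, img_size=512):
    """Hypothetical collate function with one overridable default."""
    return data


default_loader_args = {"batch_size": 8, "shuffle": False}
collate_kwargs, loader_kwargs = split_overrides(
    toy_collate, {"img_size": 256, "batch_size": 2}
)
loader_args = {**default_loader_args, **loader_kwargs}  # call-time values win
```

Here img_size is routed to the collate function because it appears in its signature, while batch_size replaces the dataloader default and shuffle keeps its original value.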

train_dataloader(**kwargs)

Return the training data loader.

Parameters:

**kwargs (dict) – Any additional arguments used when loading data, overriding the default values provided in the constructor. Examples: img_size: int would override the collate function default for image generation, while batch_size: int, shuffle: bool, pin_memory: bool, … would override the dataloader defaults.

Returns:

The training data loader.

Return type:

DataLoader

val_dataloader(**kwargs)

Return the validation data loader.

Parameters:

**kwargs (dict) – Any additional arguments used when loading data, overriding the default values provided in the constructor. Examples: img_size: int would override the collate function default for image generation, while batch_size: int, shuffle: bool, pin_memory: bool, … would override the dataloader defaults.

Returns:

The validation data loader.

Return type:

DataLoader