Datasets

pruna provides a variety of pre-configured datasets for different tasks. This guide will help you understand how to use datasets in your pruna workflow.

Available Datasets

pruna currently supports the following datasets categorized by task:

Text Generation

WikiText: Wikipedia text dataset for language modeling
SmolTalk: Everyday conversation dataset
SmolSmolTalk: Lightweight version of SmolTalk
PubChem: Chemical compound dataset in SELFIES format
OpenAssistant: Instruction-following dataset
C4: Large-scale web text dataset

Image Classification

ImageNet: Large-scale image classification dataset
MNIST: Handwritten digit classification dataset

Text-to-Image

COCO: Image captioning dataset
LAION256: Subset of LAION artwork dataset
OpenImage: Image quality preferences dataset

Audio Processing

CommonVoice: Multi-language speech dataset
AIPodcast: AI-focused podcast audio dataset

Question Answering

Polyglot: Fact completion dataset

Using Datasets

There are two main ways to use datasets in pruna:

1. Using String Identifier

For the pre-configured datasets listed above, you can simply pass the dataset's string identifier to your SmashConfig:

from pruna import SmashConfig

smash_config = SmashConfig()
smash_config.add_tokenizer("bert-base-uncased")
smash_config.add_data("WikiText")

2. Using Custom Datasets

You can also pass your own datasets as a tuple of (train, validation, test) datasets:

from pruna import SmashConfig
from pruna.data.utils import split_train_into_train_val_test
from datasets import load_dataset

# Load custom datasets
train_ds = load_dataset("SamuelYang/bookcorpus")["train"]
train_ds, val_ds, test_ds = split_train_into_train_val_test(train_ds, seed=42)

# Add to SmashConfig
smash_config = SmashConfig()
smash_config.add_tokenizer("bert-base-uncased")
smash_config.add_data(
    (train_ds, val_ds, test_ds),
    collate_fn="text_generation_collate"
)

In this case, you need to specify the collate_fn to use for the dataset. A collate_fn is a function that takes a list of individual data samples and returns a batch of data in a unified format. Your dataset must adhere to the format expected by the collate_fn; this is verified with a quick compatibility check when you add the dataset to the SmashConfig.
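To make the idea concrete, here is a toy collate function in plain Python. This is an illustrative sketch only, not Pruna's implementation: a real collate_fn such as text_generation_collate tokenizes the texts and returns padded torch tensors rather than lists.

```python
def toy_text_collate(samples, max_seq_len):
    """Batch a list of {"text": ...} samples, truncating each to max_seq_len words."""
    texts = [s["text"] for s in samples]
    # A real collate_fn would tokenize here and return torch tensors.
    return [" ".join(t.split()[:max_seq_len]) for t in texts]

batch = toy_text_collate(
    [{"text": "the quick brown fox"}, {"text": "jumps over the lazy dog"}],
    max_seq_len=3,
)
# batch == ["the quick brown", "jumps over the"]
```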

pruna.data.collate.text_generation_collate(data, max_seq_len, tokenizer)

Custom collation function for text generation datasets.

Expects a text column containing clear-text samples in the dataset.

Parameters:
  • data (Any) – The data to collate.

  • max_seq_len (int | None) – The maximum sequence length.

  • tokenizer (AutoTokenizer) – The tokenizer to use.

Returns:

The collated data.

Return type:

Tuple[torch.Tensor, torch.Tensor]
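A dataset compatible with this collate function only needs a text column. A minimal sample list (the example sentences are placeholders) would look like:

```python
# Each sample provides a clear-text "text" field, as expected by
# text_generation_collate.
samples = [
    {"text": "Wikipedia is a free online encyclopedia."},
    {"text": "Language models predict the next token."},
]
assert all("text" in s for s in samples)
```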

pruna.data.collate.image_generation_collate(data, img_size, output_format='int')

Custom collation function for text-to-image generation datasets.

Expects an image column containing PIL images and a text column containing the clear-text prompt for the image generation in the dataset.

Parameters:
  • data (Any) – The data to collate.

  • img_size (int) – The size of the image to resize to.

  • output_format (str) – The output format, in [“int”, “float”, “normalized”]. With “int”, output tensors have integer values between 0 and 255. With “float”, they have float values between 0 and 1. With “normalized”, they have float values between -1 and 1.

Returns:

The collated data, resized to img_size and scaled according to output_format.

Return type:

Tuple[torch.Tensor, Any]
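The three output formats are simple rescalings of one another. Assuming the usual affine mapping between pixel ranges (an assumption for illustration, not taken from the Pruna source), a single 8-bit pixel converts as follows:

```python
# Relationship between the three output_format options, for one 8-bit pixel.
pixel_int = 200                         # "int": integers in [0, 255]
pixel_float = pixel_int / 255           # "float": floats in [0, 1]
pixel_normalized = pixel_float * 2 - 1  # "normalized": floats in [-1, 1]

print(round(pixel_float, 4), round(pixel_normalized, 4))
```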

pruna.data.collate.image_classification_collate(data, img_size, output_format='int')

Custom collation function for image classification datasets.

Expects an image column containing PIL images and a label column containing the class label in the dataset.

Parameters:
  • data (Any) – The data to collate.

  • img_size (int) – The size of the image to resize to.

  • output_format (str) – The output format, in [“int”, “float”, “normalized”]. With “int”, output tensors have integer values between 0 and 255. With “float”, they have float values between 0 and 1. With “normalized”, they have float values between -1 and 1.

Returns:

The collated data, resized to img_size and scaled according to output_format.

Return type:

Tuple[List[str], torch.Tensor]

pruna.data.collate.audio_collate(data)

Custom collation function for audio datasets.

Expects an audio/path column containing the path to the audio samples and a sentence column containing the clear-text transcription of the audio samples in the dataset.

Parameters:

data (Any) – The data to collate.

Returns:

The collated data.

Return type:

List[str]
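A minimal sample row for this collate function pairs an audio file path with its transcription (the file name below is a hypothetical placeholder):

```python
# Row format expected by audio_collate: a path to the audio file plus
# its clear-text transcription.
sample = {"path": "clips/sample_0001.mp3",
          "sentence": "Hello world."}
assert {"path", "sentence"} <= set(sample.keys())
```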

pruna.data.collate.question_answering_collate(data, max_seq_len, tokenizer)

Custom collation function for question answering datasets.

Expects question and answer columns containing the clear-text question and answer in the dataset.

Parameters:
  • data (Any) – The data to collate.

  • max_seq_len (int) – The maximum sequence length.

  • tokenizer (AutoTokenizer) – The tokenizer to use.

Returns:

The collated data.

Return type:

Tuple[torch.Tensor, torch.Tensor]
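A minimal sample row for this collate function provides both columns as clear text:

```python
# Row format expected by question_answering_collate.
sample = {"question": "What is the capital of France?",
          "answer": "Paris"}
assert {"question", "answer"} <= set(sample.keys())
```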

Accessing the PrunaDataModule directly

You can also create and access the PrunaDataModule directly and use it in your workflow, e.g., if you want to pass it to the evaluation agent.

class pruna.data.pruna_datamodule.PrunaDataModule(train_ds, val_ds, test_ds, collate_fn, dataloader_args)

A PrunaDataModule is a wrapper around a PyTorch Lightning DataModule that allows for easy loading of datasets.

Parameters:
  • train_ds (Union[IterableDataset, Dataset, TorchDataset]) – The training dataset.

  • val_ds (Union[IterableDataset, Dataset, TorchDataset]) – The validation dataset.

  • test_ds (Union[IterableDataset, Dataset, TorchDataset]) – The test dataset.

  • collate_fn (Callable) – The collate function to use.

  • dataloader_args (dict) – The arguments for the dataloader.

classmethod from_datasets(datasets, collate_fn, tokenizer=None, collate_fn_args={}, dataloader_args={})

Create a PrunaDataModule from the individual datasets.

Parameters:
  • datasets (tuple | list) – The datasets.

  • collate_fn (str) – The Pruna collate function to use.

  • tokenizer (AutoTokenizer | None) – The tokenizer to use.

  • collate_fn_args (dict) – Any additional arguments for the collate function.

  • dataloader_args (dict) – Any additional arguments for the dataloader.

Returns:

The PrunaDataModule.

Return type:

PrunaDataModule

classmethod from_string(dataset_name, tokenizer=None, collate_fn_args={}, dataloader_args={}, seed=42)

Create a PrunaDataModule from the dataset name, using the pre-implemented dataset loading.

Parameters:
  • dataset_name (str) – The name of the dataset.

  • tokenizer (AutoTokenizer | None) – The tokenizer to use.

  • collate_fn_args (dict) – Any additional arguments for the collate function.

  • dataloader_args (dict) – Any additional arguments for the dataloader.

  • seed (int) – The seed to use.

Returns:

The PrunaDataModule.

Return type:

PrunaDataModule