Datasets
pruna provides a variety of pre-configured datasets for different tasks. This guide will help you understand how to use datasets in your pruna workflow.
Available Datasets
pruna currently supports the following datasets categorized by task:
Text Generation
- WikiText: Wikipedia text dataset for language modeling
- SmolTalk: Everyday conversation dataset
- SmolSmolTalk: Lightweight version of SmolTalk
- PubChem: Chemical compound dataset in SELFIES format
- OpenAssistant: Instruction-following dataset
- C4: Large-scale web text dataset
Image Classification
- ImageNet: Large-scale image classification dataset
- MNIST: Handwritten digit classification dataset
Text-to-Image
- COCO: Image captioning dataset
- LAION256: Subset of LAION artwork dataset
- OpenImage: Image quality preferences dataset
Audio Processing
- CommonVoice: Multi-language speech dataset
- AIPodcast: AI-focused podcast audio dataset
Question Answering
- Polyglot: Fact completion dataset
Using Datasets
There are two main ways to use datasets in pruna:
1. Using String Identifier
The already implemented datasets are the easiest to use: simply pass the dataset's string identifier to your SmashConfig:
from pruna import SmashConfig
smash_config = SmashConfig()
smash_config.add_tokenizer("bert-base-uncased")
smash_config.add_data("WikiText")
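A tokenizer is only needed for text-based tasks; for an image dataset such as MNIST, adding the data alone should be enough. A minimal sketch, assuming the MNIST identifier from the list above and that image datasets need no tokenizer:
from pruna import SmashConfig
smash_config = SmashConfig()
# Image datasets are referenced by their string identifier as well;
# no tokenizer is added here (assumption: only text tasks require one).
smash_config.add_data("MNIST")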
2. Using Custom Datasets
You can also pass your own datasets as a tuple of (train, validation, test) datasets:
from pruna import SmashConfig
from pruna.data.utils import split_train_into_train_val_test
from datasets import load_dataset
# Load custom datasets
train_ds = load_dataset("SamuelYang/bookcorpus")["train"]
train_ds, val_ds, test_ds = split_train_into_train_val_test(train_ds, seed=42)
# Add to SmashConfig
smash_config = SmashConfig()
smash_config.add_tokenizer("bert-base-uncased")
smash_config.add_data(
(train_ds, val_ds, test_ds),
collate_fn="text_generation_collate"
)
In this case, you need to specify the collate_fn to use for the dataset. The collate_fn is a function that takes a list of individual data samples and returns a batch of data in a unified format. Your dataset has to adhere to the format expected by the collate_fn; this is verified with a quick compatibility check when the dataset is added to the smash_config.
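As a concrete illustration, a dataset intended for the text_generation_collate function documented below only needs a text column with raw strings. A minimal sketch with a tiny in-memory dataset; the same split is reused for train, validation and test purely for brevity:
from datasets import Dataset
from pruna import SmashConfig
# Illustrative dataset with the "text" column expected by text_generation_collate
toy_ds = Dataset.from_dict(
    {"text": ["pruna makes models faster.", "Collate functions batch raw samples."]}
)
smash_config = SmashConfig()
smash_config.add_tokenizer("bert-base-uncased")
# The compatibility check mentioned above runs when the data is added
smash_config.add_data((toy_ds, toy_ds, toy_ds), collate_fn="text_generation_collate")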
- pruna.data.collate.text_generation_collate(data, max_seq_len, tokenizer)
Custom collation function for text generation datasets.
Expects a text column containing clear-text samples in the dataset.
- Parameters:
data (Any) – The data to collate.
max_seq_len (int | None) – The maximum sequence length.
tokenizer (AutoTokenizer) – The tokenizer to use.
- Returns:
The collated data.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- pruna.data.collate.image_generation_collate(data, img_size, output_format='int')
Custom collation function for text-to-image generation datasets.
Expects an image column containing PIL images and a text column containing the clear-text prompt for the image generation in the dataset.
- Parameters:
data (Any) – The data to collate.
img_size (int) – The size of the image to resize to.
output_format (str) – The output format, in [“int”, “float”, “normalized”]. With “int”, output tensors have integer values between 0 and 255. With “float”, they have float values between 0 and 1. With “normalized”, they have float values between -1 and 1.
- Returns:
The collated data, resized to img_size and scaled according to output_format.
- Return type:
Tuple[torch.Tensor, Any]
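Following the same pattern as the text-generation example above, a custom text-to-image dataset would carry an image column of PIL images and a text column with prompts. A hedged sketch; the "image_generation_collate" string identifier is an assumption, mirroring the function name above:
from datasets import Dataset
from PIL import Image
from pruna import SmashConfig
# Illustrative dataset with the "image" and "text" columns described above
ds = Dataset.from_dict({
    "image": [Image.new("RGB", (256, 256), color="red") for _ in range(4)],
    "text": ["a plain red square"] * 4,
})
smash_config = SmashConfig()
smash_config.add_data((ds, ds, ds), collate_fn="image_generation_collate")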
- pruna.data.collate.image_classification_collate(data, img_size, output_format='int')
Custom collation function for image classification datasets.
Expects an image column containing PIL images and a label column containing the class label in the dataset.
- Parameters:
data (Any) – The data to collate.
img_size (int) – The size of the image to resize to.
output_format (str) – The output format, in [“int”, “float”, “normalized”]. With “int”, output tensors have integer values between 0 and 255. With “float”, they have float values between 0 and 1. With “normalized”, they have float values between -1 and 1.
- Returns:
The collated data, resized to img_size and scaled according to output_format.
- Return type:
Tuple[List[str], torch.Tensor]
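An image classification dataset looks the same, except that the prompt column is replaced by a label column with class indices. Again a sketch; the "image_classification_collate" identifier is assumed to match the function name above:
from datasets import Dataset
from PIL import Image
from pruna import SmashConfig
# Illustrative dataset with the "image" and "label" columns described above
ds = Dataset.from_dict({
    "image": [Image.new("RGB", (224, 224)) for _ in range(4)],
    "label": [0, 1, 0, 1],
})
smash_config = SmashConfig()
smash_config.add_data((ds, ds, ds), collate_fn="image_classification_collate")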
- pruna.data.collate.audio_collate(data)
Custom collation function for audio datasets.
Expects an audio/path column containing the path to the audio samples and a sentence column containing the clear-text transcription of the audio samples in the dataset.
- Parameters:
data (Any) – The data to collate.
- Returns:
The collated data.
- Return type:
List[str]
- pruna.data.collate.question_answering_collate(data, max_seq_len, tokenizer)
Custom collation function for question answering datasets.
Expects a question and an answer column containing the clear-text question and answer in the dataset.
- Parameters:
data (Any) – The data to collate.
max_seq_len (int) – The maximum sequence length.
tokenizer (AutoTokenizer) – The tokenizer to use.
- Returns:
The collated data.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
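A question-answering dataset needs question and answer columns and, like text generation, a tokenizer. A sketch under the assumption that the string identifier "question_answering_collate" matches the function name above:
from datasets import Dataset
from pruna import SmashConfig
# Illustrative dataset with the "question" and "answer" columns described above
qa_ds = Dataset.from_dict({
    "question": ["What does pruna optimize?", "Which column holds the answer?"],
    "answer": ["Machine learning models.", "The answer column."],
})
smash_config = SmashConfig()
smash_config.add_tokenizer("bert-base-uncased")
smash_config.add_data((qa_ds, qa_ds, qa_ds), collate_fn="question_answering_collate")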
Accessing the PrunaDataModule directly
You can also create and access the PrunaDataModule directly and use it in your workflow, e.g., if you want to pass it to the evaluation agent.
- class pruna.data.pruna_datamodule.PrunaDataModule(train_ds, val_ds, test_ds, collate_fn, dataloader_args)
A PrunaDataModule is a wrapper around a PyTorch Lightning DataModule that allows for easy loading of datasets.
- Parameters:
train_ds (Union[IterableDataset, Dataset, TorchDataset]) – The training dataset.
val_ds (Union[IterableDataset, Dataset, TorchDataset]) – The validation dataset.
test_ds (Union[IterableDataset, Dataset, TorchDataset]) – The test dataset.
collate_fn (Callable) – The collate function to use.
dataloader_args (dict) – The arguments for the dataloader.
- classmethod from_datasets(datasets, collate_fn, tokenizer=None, collate_fn_args={}, dataloader_args={})
Create a PrunaDataModule from the individual datasets.
- Parameters:
datasets (tuple | list) – The datasets.
collate_fn (str) – The Pruna collate function to use.
tokenizer (AutoTokenizer | None) – The tokenizer to use.
collate_fn_args (dict) – Any additional arguments for the collate function.
dataloader_args (dict) – Any additional arguments for the dataloader.
- Returns:
The PrunaDataModule.
- Return type:
PrunaDataModule
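For example, the custom splits from the bookcorpus snippet earlier could be wrapped directly. A sketch reusing the names defined above; the tokenizer and the batch_size dataloader argument are illustrative choices:
from transformers import AutoTokenizer
from pruna.data.pruna_datamodule import PrunaDataModule
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# train_ds, val_ds and test_ds as prepared in the custom-dataset example above
datamodule = PrunaDataModule.from_datasets(
    (train_ds, val_ds, test_ds),
    collate_fn="text_generation_collate",
    tokenizer=tokenizer,
    dataloader_args={"batch_size": 8},  # illustrative dataloader argument
)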
- classmethod from_string(dataset_name, tokenizer=None, collate_fn_args={}, dataloader_args={}, seed=42)
Create a PrunaDataModule from the dataset name, using the pre-implemented dataset loading.
- Parameters:
dataset_name (str) – The name of the dataset.
tokenizer (AutoTokenizer | None) – The tokenizer to use.
collate_fn_args (dict) – Any additional arguments for the collate function.
dataloader_args (dict) – Any additional arguments for the dataloader.
seed (int) – The seed to use.
- Returns:
The PrunaDataModule.
- Return type:
PrunaDataModule
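The pre-configured datasets can likewise be loaded in one line via their string identifier, e.g. for WikiText. A minimal sketch; the tokenizer choice is illustrative:
from transformers import AutoTokenizer
from pruna.data.pruna_datamodule import PrunaDataModule
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
datamodule = PrunaDataModule.from_string("WikiText", tokenizer=tokenizer)
# The resulting datamodule can then be used in your workflow,
# e.g. passed to the evaluation agent as mentioned above.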