Datasets
The PrunaDataModule is a class that allows you to load and process data for your model.
It can be used directly, through the collate functions, or through the SmashConfig.
Usage examples
Usage example of PrunaDataModule
from pruna.data import PrunaDataModule
datamodule = PrunaDataModule.from_datasets(
    datasets=["dataset1", "dataset2"],
    collate_fn="image_generation_collate",  # expects a specific format for the datasets
)
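Once constructed, the datamodule exposes the train, validation, and test data loaders documented below. A minimal sketch of consuming it, assuming the placeholder dataset names above have been replaced with real datasets:
# Iterate over batches produced by the chosen collate function
train_loader = datamodule.train_dataloader()
for batch in train_loader:
    ...  # feed the batch to your model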
Usage example of collate functions
from pruna import SmashConfig
from pruna.data.utils import split_train_into_train_val_test
from datasets import load_dataset
# Load custom datasets
train_ds = load_dataset("SamuelYang/bookcorpus")["train"]
train_ds, val_ds, test_ds = split_train_into_train_val_test(train_ds, seed=42)
# Add to SmashConfig
smash_config = SmashConfig()
smash_config.add_tokenizer("bert-base-uncased")
smash_config.add_data(
    (train_ds, val_ds, test_ds),
    collate_fn="text_generation_collate"
)
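The same splits can also be wrapped directly in a PrunaDataModule via from_datasets. A sketch under the same assumptions as above, with the tokenizer loaded explicitly through transformers; the collate_fn_args and dataloader_args values are illustrative:
from transformers import AutoTokenizer
from pruna.data import PrunaDataModule

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
datamodule = PrunaDataModule.from_datasets(
    (train_ds, val_ds, test_ds),
    collate_fn="text_generation_collate",
    tokenizer=tokenizer,
    collate_fn_args={"max_seq_len": 512},  # assumed override of the collate default
    dataloader_args={"batch_size": 8},     # assumed dataloader setting
)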
Usage example of SmashConfig
from pruna import SmashConfig
smash_config = SmashConfig()
smash_config.add_data("dataset1")
Function API collate
- image_generation_collate(data, img_size, output_format='int')
Custom collation function for text-to-image generation datasets.
Expects an "image" column containing PIL images and a "text" column containing the clear-text prompt for the image generation in the dataset.
- Parameters:
data (Any) – The data to collate.
img_size (int) – The size of the image to resize to.
output_format (str) – The output format, in [“int”, “float”, “normalized”]. With “int”, output tensors have integer values between 0 and 255. With “float”, they have float values between 0 and 1. With “normalized”, they have float values between -1 and 1.
- Returns:
The collated data, resized to img_size and scaled according to output_format.
- Return type:
Tuple[torch.Tensor, Any]
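A sketch of how img_size and output_format are typically supplied through collate_fn_args when building a datamodule; the dataset variables are placeholders with "image" and "text" columns:
from pruna.data import PrunaDataModule

datamodule = PrunaDataModule.from_datasets(
    (train_ds, val_ds, test_ds),
    collate_fn="image_generation_collate",
    collate_fn_args={"img_size": 512, "output_format": "normalized"},
)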
- audio_collate(data)
Custom collation function for audio datasets.
Expects an "audio"/"path" column containing the path to the audio samples and a "sentence" column containing the clear-text transcription of the audio samples in the dataset.
- Parameters:
data (Any) – The data to collate.
- Returns:
The collated data.
- Return type:
List[str]
- image_classification_collate(data, img_size, output_format='int')
Custom collation function for image classification datasets.
Expects an "image" column containing PIL images and a "label" column containing the class label in the dataset.
- Parameters:
data (Any) – The data to collate.
img_size (int) – The size of the image to resize to.
output_format (str) – The output format, in [“int”, “float”, “normalized”]. With “int”, output tensors have integer values between 0 and 255. With “float”, they have float values between 0 and 1. With “normalized”, they have float values between -1 and 1.
- Returns:
The collated data, resized to img_size and scaled according to output_format.
- Return type:
Tuple[List[str], torch.Tensor]
- text_generation_collate(data, max_seq_len, tokenizer)
Custom collation function for text generation datasets.
Expects a "text" column containing clear-text samples in the dataset.
- Parameters:
data (Any) – The data to collate.
max_seq_len (int | None) – The maximum sequence length.
tokenizer (AutoTokenizer) – The tokenizer to use.
- Returns:
The collated data.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
- question_answering_collate(data, max_seq_len, tokenizer)
Custom collation function for question answering datasets.
Expects a "question" and an "answer" column containing the clear-text question and answer in the dataset.
- Parameters:
data (Any) – The data to collate.
max_seq_len (int) – The maximum sequence length.
tokenizer (AutoTokenizer) – The tokenizer to use.
- Returns:
The collated data.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
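A sketch of wiring a question answering dataset into a SmashConfig, assuming placeholder splits with "question" and "answer" columns and a tokenizer added beforehand:
from pruna import SmashConfig

smash_config = SmashConfig()
smash_config.add_tokenizer("bert-base-uncased")
smash_config.add_data(
    (train_ds, val_ds, test_ds),  # placeholder splits
    collate_fn="question_answering_collate"
)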
Class API PrunaDataModule
- class PrunaDataModule(train_ds, val_ds, test_ds, collate_fn, dataloader_args)
A PrunaDataModule is a wrapper around a PyTorch Lightning DataModule that allows for easy loading of datasets.
- Parameters:
train_ds (Union[IterableDataset, Dataset, TorchDataset]) – The training dataset.
val_ds (Union[IterableDataset, Dataset, TorchDataset]) – The validation dataset.
test_ds (Union[IterableDataset, Dataset, TorchDataset]) – The test dataset.
collate_fn (Callable) – The collate function to use.
dataloader_args (dict) – The arguments for the dataloader.
- __init__(train_ds, val_ds, test_ds, collate_fn, dataloader_args)
- Parameters:
train_ds (IterableDataset | Dataset | TorchDataset)
val_ds (IterableDataset | Dataset | TorchDataset)
test_ds (IterableDataset | Dataset | TorchDataset)
collate_fn (Callable)
dataloader_args (dict)
- Return type:
None
- classmethod from_datasets(datasets, collate_fn, tokenizer=None, collate_fn_args={}, dataloader_args={})
Create a PrunaDataModule from the individual datasets.
- Parameters:
datasets (tuple | list) – The datasets.
collate_fn (str) – The Pruna collate function to use.
tokenizer (AutoTokenizer | None) – The tokenizer to use.
collate_fn_args (dict) – Any additional arguments for the collate function.
dataloader_args (dict) – Any additional arguments for the dataloader.
- Returns:
The PrunaDataModule.
- Return type:
PrunaDataModule
- classmethod from_string(dataset_name, tokenizer=None, collate_fn_args={}, dataloader_args={}, seed=42)
Create a PrunaDataModule from a dataset name with pre-implemented dataset loading.
- Parameters:
dataset_name (str) – The name of the dataset.
tokenizer (AutoTokenizer | None) – The tokenizer to use.
collate_fn_args (dict) – Any additional arguments for the collate function.
dataloader_args (dict) – Any additional arguments for the dataloader.
seed (int) – The seed to use.
- Returns:
The PrunaDataModule.
- Return type:
PrunaDataModule
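A sketch using the pre-implemented loading path; "dataset1" is the same placeholder name used earlier and would be replaced by a dataset name with pre-implemented loading:
from pruna.data import PrunaDataModule
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
datamodule = PrunaDataModule.from_string(
    "dataset1",                          # placeholder dataset name
    tokenizer=tokenizer,
    dataloader_args={"batch_size": 16},  # assumed dataloader setting
    seed=42,
)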
- limit_datasets(limit)
Limit the dataset to the given number of samples.
- Parameters:
limit (int | list[int] | tuple[int, int, int]) – The number of samples to limit the dataset to.
- Return type:
None
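A sketch of limiting sample counts; the per-split interpretation of the tuple form is an assumption based on the signature:
# Same limit applied across the datasets
datamodule.limit_datasets(100)

# One limit per split, assumed to be ordered as (train, val, test)
datamodule.limit_datasets((1000, 100, 100))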
- test_dataloader(**kwargs)
Return the test data loader.
- Parameters:
**kwargs (dict) – Any additional arguments used when loading data, overriding the default values provided in the constructor. Examples: img_size: int would override the collate function default for image generation, while batch_size: int, shuffle: bool, pin_memory: bool, … would override the dataloader defaults.
- Returns:
The test data loader.
- Return type:
DataLoader
- train_dataloader(**kwargs)
Return the training data loader.
- Parameters:
**kwargs (dict) – Any additional arguments used when loading data, overriding the default values provided in the constructor. Examples: img_size: int would override the collate function default for image generation, while batch_size: int, shuffle: bool, pin_memory: bool, … would override the dataloader defaults.
- Returns:
The training data loader.
- Return type:
DataLoader
- val_dataloader(**kwargs)
Return the validation data loader.
- Parameters:
**kwargs (dict) – Any additional arguments used when loading data, overriding the default values provided in the constructor. Examples: img_size: int would override the collate function default for image generation, while batch_size: int, shuffle: bool, pin_memory: bool, … would override the dataloader defaults.
- Returns:
The validation data loader.
- Return type:
DataLoader
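A sketch of the override behaviour described above, assuming an image datamodule for the img_size example:
# Override dataloader defaults at load time
val_loader = datamodule.val_dataloader(batch_size=1, shuffle=False)

# Override a collate-function default (image datasets only)
train_loader = datamodule.train_dataloader(img_size=256)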