Datasets
The PrunaDataModule is a class that allows you to load and process data for your model. It can be used directly, through the collate functions, or through the SmashConfig.
Usage examples
Usage example of PrunaDataModule
from pruna.data import PrunaDataModule
datamodule = PrunaDataModule.from_datasets(
    datasets=["dataset1", "dataset2"],
    collate_fn="image_generation_collate",  # the collate function expects a specific format for the datasets
)
Usage example of collate functions
from pruna import SmashConfig
from pruna.data.utils import split_train_into_train_val_test
from datasets import load_dataset
# Load custom datasets
train_ds = load_dataset("SamuelYang/bookcorpus")["train"]
train_ds, val_ds, test_ds = split_train_into_train_val_test(train_ds, seed=42)
# Add to SmashConfig
smash_config = SmashConfig()
smash_config.add_tokenizer("bert-base-uncased")
smash_config.add_data(
    (train_ds, val_ds, test_ds),
    collate_fn="text_generation_collate"
)
Usage example of SmashConfig
from pruna import SmashConfig
smash_config = SmashConfig()
smash_config.add_data("dataset1")
Function API: collate functions
- image_generation_collate(data, img_size, output_format='int')
Custom collation function for text-to-image generation datasets.
Expects an image column containing PIL images and a text column containing the clear-text prompt for the image generation in the dataset.
- Parameters:
data (Any) – The data to collate.
img_size (int) – The size of the image to resize to.
output_format (str) – The output format, in [“int”, “float”, “normalized”]. With “int”, output tensors have integer values between 0 and 255. With “float”, they have float values between 0 and 1. With “normalized”, they have float values between -1 and 1.
- Returns:
The collated images resized to img_size, with values scaled according to output_format.
- Return type:
Tuple[torch.Tensor, Any]
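A minimal sketch of how this collate function is typically selected by name; the dataset name below is a placeholder for any dataset with image and text columns, and forwarding img_size and output_format through collate_fn_args is an assumption based on the from_datasets signature documented further down.
from datasets import load_dataset
from pruna.data import PrunaDataModule
from pruna.data.utils import split_train_into_train_val_test

# Placeholder dataset name: substitute any dataset with "image" and "text" columns.
train_ds = load_dataset("your/text-to-image-dataset")["train"]
train_ds, val_ds, test_ds = split_train_into_train_val_test(train_ds, seed=42)

datamodule = PrunaDataModule.from_datasets(
    datasets=(train_ds, val_ds, test_ds),
    collate_fn="image_generation_collate",
    collate_fn_args={"img_size": 512, "output_format": "float"},  # float values in [0, 1]
)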
- audio_collate(data)
Custom collation function for audio datasets.
Expects an audio/path column containing the paths to the audio samples and a sentence column containing the clear-text transcriptions of the audio samples in the dataset.
- Parameters:
data (Any) – The data to collate.
- Returns:
The collated data.
- Return type:
List[str]
- image_classification_collate(data, img_size, output_format='int')
Custom collation function for image classification datasets.
Expects an image column containing PIL images and a label column containing the class label in the dataset.
- Parameters:
data (Any) – The data to collate.
img_size (int) – The size of the image to resize to.
output_format (str) – The output format, in [“int”, “float”, “normalized”]. With “int”, output tensors have integer values between 0 and 255. With “float”, they have float values between 0 and 1. With “normalized”, they have float values between -1 and 1.
- Returns:
The collated images resized to img_size, with values scaled according to output_format.
- Return type:
Tuple[List[str], torch.Tensor]
- text_generation_collate(data, max_seq_len, tokenizer)
Custom collation function for text generation datasets.
Expects a text column containing clear-text samples in the dataset.
- Parameters:
data (Any) – The data to collate.
max_seq_len (int | None) – The maximum sequence length.
tokenizer (AutoTokenizer) – The tokenizer to use.
- Returns:
The collated data.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
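A sketch of selecting this collate function outside of a SmashConfig; it reuses the bookcorpus split from the example above, and passing max_seq_len through collate_fn_args is an assumption based on the from_datasets signature documented below.
from datasets import load_dataset
from transformers import AutoTokenizer
from pruna.data import PrunaDataModule
from pruna.data.utils import split_train_into_train_val_test

train_ds = load_dataset("SamuelYang/bookcorpus")["train"]
train_ds, val_ds, test_ds = split_train_into_train_val_test(train_ds, seed=42)

datamodule = PrunaDataModule.from_datasets(
    datasets=(train_ds, val_ds, test_ds),
    collate_fn="text_generation_collate",
    tokenizer=AutoTokenizer.from_pretrained("bert-base-uncased"),
    collate_fn_args={"max_seq_len": 512},  # assumed to be forwarded to the collate function
)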
- question_answering_collate(data, max_seq_len, tokenizer)
Custom collation function for question answering datasets.
Expects question and answer columns containing the clear-text question and answer in the dataset.
- Parameters:
data (Any) – The data to collate.
max_seq_len (int) – The maximum sequence length.
tokenizer (AutoTokenizer) – The tokenizer to use.
- Returns:
The collated data.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
Class API: PrunaDataModule
- class PrunaDataModule(train_ds, val_ds, test_ds, collate_fn, dataloader_args)
A PrunaDataModule is a wrapper around a PyTorch Lightning DataModule that allows for easy loading of datasets.
- Parameters:
train_ds (Union[IterableDataset, Dataset, TorchDataset]) – The training dataset.
val_ds (Union[IterableDataset, Dataset, TorchDataset]) – The validation dataset.
test_ds (Union[IterableDataset, Dataset, TorchDataset]) – The test dataset.
collate_fn (Callable) – The collate function to use.
dataloader_args (dict) – The arguments for the dataloader.
- __init__(train_ds, val_ds, test_ds, collate_fn, dataloader_args)
- Parameters:
train_ds (IterableDataset | Dataset | TorchDataset)
val_ds (IterableDataset | Dataset | TorchDataset)
test_ds (IterableDataset | Dataset | TorchDataset)
collate_fn (Callable)
dataloader_args (dict)
- Return type:
None
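A minimal sketch of direct construction with a hand-written collate function; the batch size and number of workers are illustrative only.
from datasets import load_dataset
from pruna.data import PrunaDataModule
from pruna.data.utils import split_train_into_train_val_test

train_ds = load_dataset("SamuelYang/bookcorpus")["train"]
train_ds, val_ds, test_ds = split_train_into_train_val_test(train_ds, seed=42)

# Hand-written collate function: return the raw text of each sample in the batch.
def my_collate_fn(batch):
    return [sample["text"] for sample in batch]

datamodule = PrunaDataModule(
    train_ds,
    val_ds,
    test_ds,
    collate_fn=my_collate_fn,
    dataloader_args={"batch_size": 16, "num_workers": 4},
)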
- classmethod from_datasets(datasets, collate_fn, tokenizer=None, collate_fn_args={}, dataloader_args={})
Create a PrunaDataModule from the individual datasets.
- Parameters:
datasets (tuple | list) – The datasets.
collate_fn (str) – The Pruna collate function to use.
tokenizer (AutoTokenizer | None) – The tokenizer to use.
collate_fn_args (dict) – Any additional arguments for the collate function.
dataloader_args (dict) – Any additional arguments for the dataloader.
- Returns:
The PrunaDataModule.
- Return type:
PrunaDataModule
- classmethod from_string(dataset_name, tokenizer=None, collate_fn_args={}, dataloader_args={}, seed=42)
Create a PrunaDataModule from the dataset name with preimplemented dataset loading.
- Parameters:
dataset_name (str) – The name of the dataset.
tokenizer (AutoTokenizer | None) – The tokenizer to use.
collate_fn_args (dict) – Any additional arguments for the collate function.
dataloader_args (dict) – Any additional arguments for the dataloader.
seed (int) – The seed to use.
- Returns:
The PrunaDataModule.
- Return type:
PrunaDataModule
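A minimal sketch; "dataset1" is the same placeholder used in the examples above and stands in for one of the preimplemented dataset names.
from transformers import AutoTokenizer
from pruna.data import PrunaDataModule

# "dataset1" is a placeholder for a preimplemented dataset name.
datamodule = PrunaDataModule.from_string(
    "dataset1",
    tokenizer=AutoTokenizer.from_pretrained("bert-base-uncased"),  # only needed for text datasets
    seed=42,
)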
- limit_datasets(limit)
Limit the dataset to the given number of samples.
- Parameters:
limit (int | list[int] | tuple[int, int, int]) – The number of samples to limit the dataset to.
- Return type:
None
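Based on the signature above, a single integer presumably caps all splits, while a triple caps the train/val/test splits individually; a short sketch:
# Limit every split to 100 samples ...
datamodule.limit_datasets(100)

# ... or limit the train/val/test splits individually.
datamodule.limit_datasets((1000, 100, 100))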
- test_dataloader(**kwargs)
Return the test data loader.
- Parameters:
**kwargs (dict) – Any additional arguments used when loading data, overriding the default values provided in the constructor. Examples: img_size: int would override the collate function default for image generation, while batch_size: int, shuffle: bool, pin_memory: bool, … would override the dataloader defaults.
- Returns:
The test data loader.
- Return type:
DataLoader
- train_dataloader(**kwargs)
Return the training data loader.
- Parameters:
**kwargs (dict) – Any additional arguments used when loading data, overriding the default values provided in the constructor. Examples: img_size: int would override the collate function default for image generation, while batch_size: int, shuffle: bool, pin_memory: bool, … would override the dataloader defaults.
- Returns:
The training data loader.
- Return type:
DataLoader
- val_dataloader(**kwargs)
Return the validation data loader.
- Parameters:
**kwargs (dict) – Any additional arguments used when loading data, overriding the default values provided in the constructor. Examples: img_size: int would override the collate function default for image generation, while batch_size: int, shuffle: bool, pin_memory: bool, … would override the dataloader defaults.
- Returns:
The validation data loader.
- Return type:
DataLoader
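The same override mechanism applies to all three loaders; a sketch combining a dataloader override (batch_size, shuffle) with a collate-function override (img_size, only meaningful for the image collate functions):
# Override dataloader defaults for this call only.
train_loader = datamodule.train_dataloader(batch_size=8, shuffle=True)

# Override a collate-function default, e.g. the image size for image generation datasets.
val_loader = datamodule.val_dataloader(img_size=256)

for batch in train_loader:
    ...  # process the batch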