smash

The smash function is pruna's main entry point for optimizing models. The following sections show how to use it.

Calling the smash Function

Before calling smash, we need to load a model and define a SmashConfig. As an example, we load the ViT-B/16 model from torchvision:

import torchvision

base_model = torchvision.models.vit_b_16(weights="ViT_B_16_Weights.DEFAULT").cuda()

Next, we will define a SmashConfig and activate the torch_compile compiler.

from pruna import SmashConfig
smash_config = SmashConfig()
smash_config['compiler'] = 'torch_compile'

We are now ready to call the smash function, passing it the model and the SmashConfig:

from pruna import smash

smashed_model = smash(
    model=base_model,
    smash_config=smash_config,
)

The resulting smashed model can be used in the same way as the original one.
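To illustrate why a wrapped model can stand in for the original, here is a minimal sketch of the delegation pattern that a wrapper such as PrunaModel could use. This is an assumption for illustration, not pruna's actual implementation: `ToyWrapper` and `ToyModel` are hypothetical names, and the wrapper simply forwards calls and attribute lookups to the model it holds.

```python
class ToyWrapper:
    """Hypothetical stand-in for a model wrapper; not pruna's PrunaModel."""

    def __init__(self, model, smash_config):
        self.model = model
        self.smash_config = smash_config

    def __call__(self, *args, **kwargs):
        # Forward inference calls unchanged to the wrapped model.
        return self.model(*args, **kwargs)

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails: delegate to the
        # wrapped model. Reading __dict__ directly avoids infinite recursion
        # if "model" has not been set yet.
        model = self.__dict__.get("model")
        if model is None:
            raise AttributeError(name)
        return getattr(model, name)


class ToyModel:
    """Hypothetical model: doubles every input value."""

    num_classes = 10

    def __call__(self, values):
        return [v * 2 for v in values]


wrapped = ToyWrapper(ToyModel(), smash_config={"compiler": "torch_compile"})
print(wrapped([1, 2, 3]))   # [2, 4, 6]
print(wrapped.num_classes)  # 10
```

Because only failed lookups reach `__getattr__`, the wrapper's own attributes (here `smash_config`) take precedence, while everything else behaves as it would on the original model.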

At the beginning of the smash process, pruna checks that the model is compatible with the algorithms selected in the SmashConfig. If you wish to skip these checks, set the experimental flag to True:

smashed_model = smash(
    model=base_model,
    smash_config=smash_config,
    experimental=True,
)

Please note that this can lead to undefined behavior or difficult-to-debug errors.
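The trade-off behind the experimental flag can be sketched as a check-or-skip gate. This is a minimal, hypothetical illustration rather than pruna's real check: `check_compatibility`, `IncompatibleModelError`, and the `SUPPORTED_MODEL_TYPES` registry are invented names, and model types are simplified to strings.

```python
class IncompatibleModelError(Exception):
    """Raised when a selected algorithm does not support the model."""


# Hypothetical registry: which model types each algorithm accepts.
SUPPORTED_MODEL_TYPES = {
    "torch_compile": {"torch.nn.Module"},
}


def check_compatibility(model_type, config, experimental=False):
    """Raise IncompatibleModelError unless every algorithm supports model_type."""
    if experimental:
        # Checks are skipped entirely; any mismatch surfaces later, if at all.
        return
    for group, algorithm in config.items():
        if model_type not in SUPPORTED_MODEL_TYPES.get(algorithm, set()):
            raise IncompatibleModelError(
                f"{algorithm!r} ({group}) does not support {model_type!r}"
            )


config = {"compiler": "torch_compile"}
check_compatibility("torch.nn.Module", config)  # passes silently
check_compatibility("sklearn.Pipeline", config, experimental=True)  # skipped
try:
    check_compatibility("sklearn.Pipeline", config)
except IncompatibleModelError as err:
    print(err)  # 'torch_compile' (compiler) does not support 'sklearn.Pipeline'
```

In this sketch, running the gate before any algorithm touches the model means an incompatible combination fails early with a clear message; skipping it trades that safety for flexibility, which is where undefined behavior can creep in.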

Importantly, the returned model offers save and load functionality, so you can save it and later load it back in its smashed state; see Saving and Loading Pruna Models.

smash Function Documentation

pruna.smash.smash(model, smash_config, verbose=False, experimental=False)

Smash an arbitrary model for inference.

Parameters:
  • model (Any) – Base model to be smashed.

  • smash_config (SmashConfig) – Configuration settings for quantization and compilation.

  • verbose (bool) – Whether to print the progress of the smashing process.

  • experimental (bool) – Whether to use experimental algorithms, e.g., to skip checking model compatibility. This can lead to undefined behavior or difficult-to-debug errors.

Returns:

Smashed model wrapped in a PrunaModel object.

Return type:

PrunaModel