smash

The smash function is pruna's main entry point for optimizing models. The following sections show how to use it.

Usage example smash

Before calling smash, we need to load a model and define a SmashConfig.

import torchvision
from pruna import SmashConfig

# Load the model
base_model = torchvision.models.vit_b_16(
    weights="ViT_B_16_Weights.DEFAULT"
).cuda()

# Define the SmashConfig
smash_config = SmashConfig()
smash_config['compiler'] = 'torch_compile'

We are now ready to call the smash function by passing it the model and the SmashConfig:

from pruna import smash

smashed_model = smash(
    model=base_model,
    smash_config=smash_config,
)

The resulting smashed model can be used in the same way as the original one.

Function API smash

smash(model, smash_config, verbose=False, experimental=False)

Smash an arbitrary model for inference.

Parameters:
  • model (Any) – Base model to be smashed.

  • smash_config (SmashConfig) – Configuration settings for quantization and compilation.

  • verbose (bool) – Whether to print the progress of the smashing process.

  • experimental (bool) – Whether to use experimental algorithms, e.g. to avoid checking model compatibility. This can lead to undefined behavior or difficult-to-debug errors.

Returns:

Smashed model wrapped in a PrunaModel object.

Return type:

PrunaModel