smash

The smash function is the main function in pruna for optimizing and running models. In the following sections we will show you how to use it.

Calling the smash Function

Before calling smash, we need to load a model and define a SmashConfig. As an example, we will use a simple model: ViT-B/16 from torchvision.

import torchvision

base_model = torchvision.models.vit_b_16(weights="ViT_B_16_Weights.DEFAULT").cuda()

Next, we will define a SmashConfig and activate the x-fast compiler.

from pruna import SmashConfig
smash_config = SmashConfig()
smash_config['compilers'] = ['x-fast']

We are now ready to call smash, passing it the model and the SmashConfig:

from pruna import smash

smashed_model = smash(
    model=base_model,
    token='<your-token>',  # replace <your-token> with your actual token or set to None if you do not have one yet
    smash_config=smash_config,
)

Don't forget to replace the placeholder with your Pruna token! The resulting smashed model can be used in the same way as the original one.
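Because the smashed model is called exactly like the original, a quick way to see the effect of smashing is to time both models on the same input. The helper below is a minimal sketch; measure_latency is a hypothetical name and not part of pruna. It assumes the model is a plain callable. For CUDA models, pass torch.cuda.synchronize as sync so that asynchronous GPU work is included in each measurement.

```python
import statistics
import time
from typing import Any, Callable, Optional


def measure_latency(
    model: Callable[..., Any],
    example_input: Any,
    warmup: int = 3,
    runs: int = 10,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Return the median wall-clock latency (in seconds) of model(example_input).

    Hypothetical helper, not part of pruna. Warm-up calls are excluded so that
    one-time costs, such as compilation on the first call, do not skew the result.
    """
    if runs < 1:
        raise ValueError("runs must be at least 1")
    # Warm-up calls: not timed.
    for _ in range(warmup):
        model(example_input)
    if sync is not None:
        sync()  # make sure warm-up work has finished before timing starts
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        model(example_input)
        if sync is not None:
            sync()  # wait for queued GPU kernels before stopping the clock
        timings.append(time.perf_counter() - start)
    # The median is less sensitive to occasional slow runs than the mean.
    return statistics.median(timings)
```

For the ViT example above, a suitable input would be torch.randn(1, 3, 224, 224, device="cuda"), with both the base_model and smashed_model measurements wrapped in torch.no_grad().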

At the beginning of the smash process, pruna checks that the model is compatible with the algorithms you have selected. To skip these checks, set the experimental flag to True:

smashed_model = smash(
    model=base_model,
    token='<your-token>',  # replace <your-token> with your actual token or set to None if you do not have one yet
    smash_config=smash_config,
    experimental=True,
)

Note that skipping these checks can lead to undefined behavior or difficult-to-debug errors.

The returned model also provides save and load methods, so you can store it and later reload it in its smashed state:

from pruna.engine.PrunaModel import PrunaModel

smashed_model.save_model("saved_model/")
smashed_model_loaded = PrunaModel.load_model("saved_model/")
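If you smash a model once and reuse it later, the smashing and saving steps can be wrapped in a single function. This is a minimal sketch, not part of pruna: smash_and_save is a hypothetical name, and it uses only the smash arguments and the save_model method documented on this page. It accepts an optional smash_fn (also our own addition) so that it can be exercised without a GPU or a token; by default it imports smash from pruna.

```python
from typing import Any, Callable, Optional


def smash_and_save(
    model: Any,
    smash_config: Any,
    save_dir: str,
    token: Optional[str] = None,
    verbose: bool = False,
    smash_fn: Optional[Callable[..., Any]] = None,
) -> Any:
    """Smash `model`, save the result to `save_dir`, and return the smashed model.

    Hypothetical convenience wrapper; `smash_fn` defaults to pruna's smash.
    """
    if smash_fn is None:
        # Imported lazily so that a stand-in can be injected without pruna installed.
        from pruna import smash as smash_fn
    smashed_model = smash_fn(
        model=model,
        token=token,
        smash_config=smash_config,
        verbose=verbose,
    )
    # Persist the smashed state so it can be reloaded with PrunaModel.load_model.
    smashed_model.save_model(save_dir)
    return smashed_model
```

A later session can then skip smashing entirely and call PrunaModel.load_model on the same directory.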

smash Function Documentation

pruna.smash.smash(model: Any, smash_config: SmashConfig, token: str | None = None, verbose: bool = False, experimental: bool = False) → pruna.engine.PrunaModel.PrunaModel

Smashes an arbitrary model for inference.

Parameters:
  • model (Any) – Base model to be smashed.

  • smash_config (SmashConfig) – Configuration settings for quantization and compilation.

  • token (str | None) – The API key used to log the request.

  • verbose (bool) – Whether to print the progress of the smashing process.

  • experimental (bool) – Whether to use experimental algorithms, e.g. to avoid checking model compatibility. This can lead to undefined behavior or difficult-to-debug errors.

Returns:

Smashed model wrapped in a PrunaModel object.

Return type:

PrunaModel
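Since the token parameter accepts None, one way to avoid hard-coding the key in scripts is to read it from an environment variable and fall back to None when it is unset. This sketch assumes a variable named PRUNA_TOKEN; that name and the read_pruna_token helper are our own conventions for this example, not something pruna is documented to read or provide.

```python
import os
from typing import Optional


def read_pruna_token(env_var: str = "PRUNA_TOKEN") -> Optional[str]:
    """Return the token stored in `env_var`, or None if it is unset or blank.

    PRUNA_TOKEN is a hypothetical variable name chosen for this example.
    """
    token = os.environ.get(env_var, "").strip()
    # An empty string after stripping means "no token", matching smash's None default.
    return token or None
```

The result can be passed directly as token=read_pruna_token() in the smash call.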