smash
The smash function is the main function in pruna for optimizing models. The following sections show how to use it.
Usage example smash
Before calling smash, we have to load a model and define a SmashConfig.
import torchvision
from pruna import SmashConfig
# Load the model
base_model = torchvision.models.vit_b_16(
    weights="ViT_B_16_Weights.DEFAULT"
).cuda()
# Define the SmashConfig
smash_config = SmashConfig()
smash_config['compiler'] = 'torch_compile'
We are now ready to call the smash function. We pass the model and the SmashConfig to smash as follows:
from pruna import smash
smashed_model = smash(
    model=base_model,
    smash_config=smash_config,
)
The resulting smashed model can be used in the same way as the original one.
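Because smash returns the optimized model wrapped in a PrunaModel object, it may help to see how a wrapper can keep the original calling interface. The sketch below is a hypothetical illustration of the general delegation pattern only. It is not pruna's actual PrunaModel implementation, and the names WrappedModel and TinyModel are invented for this example.

```python
from typing import Any, Callable


class WrappedModel:
    """Hypothetical stand-in for a model wrapper (NOT pruna's PrunaModel).

    It only illustrates how a wrapper can expose the same interface
    as the model it wraps.
    """

    def __init__(self, model: Callable[..., Any]) -> None:
        # Store the (possibly optimized) model being wrapped.
        self._model = model

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Forward calls so that wrapped(x) behaves like model(x).
        return self._model(*args, **kwargs)

    def __getattr__(self, name: str) -> Any:
        # Only invoked when normal attribute lookup fails.
        # Guard against recursion if _model has not been set yet.
        if name == "_model":
            raise AttributeError(name)
        # Delegate every other attribute to the wrapped model.
        return getattr(self._model, name)


class TinyModel:
    """Toy model used only for this example."""

    def __init__(self, scale: int) -> None:
        self.scale = scale

    def __call__(self, x: int) -> int:
        return x * self.scale


base = TinyModel(scale=3)
wrapped = WrappedModel(base)
print(wrapped(4))     # 12, same result as base(4)
print(wrapped.scale)  # 3, attribute forwarded to the wrapped model
```

Delegating through `__call__` and `__getattr__` lets calling code stay unchanged whether it holds the original model or the wrapper.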
Function API smash
- smash(model, smash_config, verbose=False, experimental=False)
Smash an arbitrary model for inference.
- Parameters:
model (Any) – Base model to be smashed.
smash_config (SmashConfig) – Configuration settings for quantization and compilation.
verbose (bool) – Whether to print the progress of the smashing process.
experimental (bool) – Whether to use experimental algorithms, for example by skipping the model compatibility check. This can lead to undefined behavior or errors that are difficult to debug.
- Returns:
Smashed model wrapped in a PrunaModel object.
- Return type: