smash
The smash function is the main function in pruna for optimizing models and preparing them for inference. The following sections show how to use it.
Calling the smash Function
Before calling smash, we need to load a model and define a SmashConfig. As an example, we load the ViT-B/16 model from torchvision and move it to the GPU.
import torchvision
base_model = torchvision.models.vit_b_16(weights="ViT_B_16_Weights.DEFAULT").cuda()
Next, we will define a SmashConfig and activate the x-fast compiler.
from pruna import SmashConfig
smash_config = SmashConfig()
smash_config['compilers'] = ['x-fast']
We are now ready to call smash, passing it the model and the SmashConfig:
from pruna import smash
smashed_model = smash(
    model=base_model,
    token='<your-token>',  # replace <your-token> with your actual token or set to None if you do not have one yet
    smash_config=smash_config,
)
Don't forget to replace the placeholder token with your Pruna token! The resulting smashed model can be used in the same way as the original one.
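To avoid hardcoding the token in source files, one option is to read it from an environment variable and fall back to `None`, which the comment in the example above says is acceptable if you do not have a token yet. This is a minimal sketch: the variable name `PRUNA_TOKEN` and the helper `resolve_token` are hypothetical choices for illustration, not names defined by pruna.

```python
import os
from typing import Optional

# Hypothetical environment variable name; pruna does not define it.
TOKEN_ENV_VAR = "PRUNA_TOKEN"


def resolve_token(env_var: str = TOKEN_ENV_VAR) -> Optional[str]:
    """Return the token stored in `env_var`, or None if it is unset or blank."""
    value = os.environ.get(env_var, "").strip()
    # An empty string after stripping means "no token", so return None.
    return value or None


# Usage with the documented keyword arguments (requires pruna and a smash_config):
# smashed_model = smash(model=base_model, token=resolve_token(), smash_config=smash_config)
```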
The returned model also provides save and load methods, so you can save it to disk and later load it back in its smashed state:
from pruna.engine.PrunaModel import PrunaModel
smashed_model.save_model("saved_model/")
smashed_model_loaded = PrunaModel.load_model("saved_model/")
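The save/load round trip can be pictured with a toy stand-in. The sketch below is hypothetical: `ToyPrunaModel` only mirrors the call pattern shown above (an instance method `save_model(path)` and a class-level `load_model(path)` that rebuilds the object from a directory), and it uses `pickle` purely for illustration. It says nothing about how pruna actually serializes smashed models.

```python
import os
import pickle
from typing import Any


class ToyPrunaModel:
    """Hypothetical wrapper mirroring the save_model/load_model call pattern."""

    # Hypothetical file name used inside the save directory.
    _FILENAME = "model.pkl"

    def __init__(self, model: Any) -> None:
        self.model = model

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Forward calls so the wrapper is used like the wrapped model.
        return self.model(*args, **kwargs)

    def save_model(self, path: str) -> None:
        # Create the directory if needed and pickle the wrapped model into it.
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, self._FILENAME), "wb") as f:
            pickle.dump(self.model, f)

    @classmethod
    def load_model(cls, path: str) -> "ToyPrunaModel":
        # Rebuild the wrapper from the pickled model stored in `path`.
        with open(os.path.join(path, cls._FILENAME), "rb") as f:
            return cls(pickle.load(f))
```

Because `pickle.load` can execute arbitrary code, a loader built this way should only be pointed at directories you trust.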
smash Function Documentation
- pruna.smash.smash(model: Any, smash_config: SmashConfig, token: str | None = None, verbose: bool = False) → pruna.engine.PrunaModel.PrunaModel
Smashes an arbitrary model for inference.
- Parameters:
model (Any) – Base model to be smashed.
smash_config (SmashConfig) – Configuration settings for quantization and compilation.
token (str | None) – The API key used to log the request.
verbose (bool) – Whether to print the progress of the smashing process.
- Returns:
Smashed model wrapped in a PrunaModel object.
- Return type:
PrunaModel
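When writing tests for code that calls smash, it can help to substitute a stand-in with the documented signature so the tests run without a GPU or token. This is a hypothetical sketch: `fake_smash` and `PassThroughModel` are illustrative names, the config is typed as `Any` rather than `SmashConfig` to keep the example free of a pruna import, and the wrapper simply forwards calls to the base model instead of optimizing anything.

```python
from typing import Any, Optional


class PassThroughModel:
    """Hypothetical stand-in for PrunaModel that forwards calls unchanged."""

    def __init__(self, model: Any) -> None:
        self.model = model

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # No optimization happens; the base model does all the work.
        return self.model(*args, **kwargs)


def fake_smash(
    model: Any,
    smash_config: Any,
    token: Optional[str] = None,
    verbose: bool = False,
) -> PassThroughModel:
    """Mirror the documented smash signature without optimizing the model."""
    if verbose:
        print("fake_smash: wrapping model without optimization")
    return PassThroughModel(model)
```

One way to use it is to patch the name `smash` with `unittest.mock.patch` in the module where your code imports it, so the rest of the pipeline runs unchanged.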