PrunaModel

The PrunaModel is the main class in pruna. It is used to load your model, apply optimization algorithms, and run inference.

Usage examples PrunaModel

This section explains how to create and use a PrunaModel.

from diffusers import StableDiffusionPipeline

from pruna import PrunaModel, SmashConfig, smash

# prepare the base model
base_model = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4"
)

# Create and smash your model
smash_config = SmashConfig()
smash_config["cacher"] = "deepcache"
smash_config["compiler"] = "diffusers2"
smashed_model = smash(model=base_model, smash_config=smash_config)

# Save the model
smashed_model.save_pretrained("saved_model/")

# Load the model
smashed_model = PrunaModel.from_pretrained("saved_model/")

# Run inference
prompt = "a fruit basket"
smashed_model(prompt).images[0]

Usage examples InferenceHandler

The InferenceHandler is a helper class attached to a PrunaModel; it handles inference arguments, inputs, and outputs when the model is run.

from pruna import PrunaModel

smashed_model = PrunaModel.from_pretrained("saved_model/")
inference_handler = smashed_model.inference_handler
inference_handler.model_args.update({"arg_1": 1, "arg_2": 0.0})
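
Arguments added to model_args in this way are passed on to the wrapped model on subsequent inference calls; arg_1 and arg_2 above are placeholder names, not pruna parameters.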

Class API PrunaModel

class PrunaModel(model, smash_config=None)

A pruna class wrapping any model.

Parameters:
  • model (Any) – The model to be held by this class.

  • smash_config (SmashConfig | None) – Smash configuration.

__call__(*args, **kwargs)

Call the smashed model.

Parameters:
  • *args (Any) – Arguments to pass to the model.

  • **kwargs (Any) – Additional keyword arguments to pass to the model.

Returns:

The output of the model’s prediction.

Return type:

Any
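
For instance, continuing the Stable Diffusion usage example above, a call forwards its arguments to the wrapped pipeline (num_inference_steps is a standard diffusers argument, shown for illustration):

# `smashed_model` wraps the Stable Diffusion pipeline from the usage
# example; arguments are forwarded to the pipeline's __call__.
image = smashed_model("a fruit basket", num_inference_steps=25).images[0]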

__init__(model, smash_config=None)

Parameters:
  • model (Any) – The model to be held by this class.

  • smash_config (SmashConfig | None) – Smash configuration.

Return type:

None

destroy()

Destroy the model.

Return type:

None

static from_pretrained(model_path, verbose=False, **kwargs)

Load a PrunaModel from the specified model path.

Parameters:
  • model_path (str) – The path to the model directory containing necessary configuration and model files.

  • verbose (bool, optional) – Whether to apply warning filters to suppress warnings. Defaults to False.

  • **kwargs (dict) – Additional keyword arguments to pass to the model loading function, such as specific settings or parameters.

Returns:

The loaded PrunaModel instance.

Return type:

PrunaModel
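
For example, to reload the model saved in the usage example while passing the verbose flag explicitly:

from pruna import PrunaModel

# Reload the smashed model; verbose toggles the warning filtering
# described above, and extra kwargs are forwarded to the underlying
# loading function.
smashed_model = PrunaModel.from_pretrained("saved_model/", verbose=True)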

get_nn_modules()

Get the nn.Module instances in the model.

Returns:

A dictionary of the nn.Module instances in the model.

Return type:

dict[str | None, torch.nn.Module]
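
A minimal sketch of how this can be used, for example to report a parameter count per submodule (the reporting logic is illustrative, not part of pruna):

# Iterate over the nn.Module instances held by the wrapped model
# and print a parameter count for each one.
for name, module in smashed_model.get_nn_modules().items():
    n_params = sum(p.numel() for p in module.parameters())
    print(name, type(module).__name__, f"{n_params:,} parameters")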

move_to_device(device)

Move the model to a specific device.

Parameters:

device (str | torch.device) – The device to move the model to.

Return type:

None
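
For example, to move the model to a GPU when one is available:

import torch

# Move the smashed model to the GPU if available, otherwise stay on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
smashed_model.move_to_device(device)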

run_inference(batch, device)

Run inference on the model.

Parameters:
  • batch (Tuple[List[str] | torch.Tensor, ...]) – The batch to run inference on.

  • device (torch.device | str) – The device to run inference on.

Returns:

The processed output.

Return type:

Any
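
A sketch under the standard handler contract described below, which expects a two-element batch whose first element is the model input (the second element stands in for targets and is not used by the forward pass):

# Illustrative batch layout: (inputs, targets); the attached
# InferenceHandler prepares the arguments and unpacks the output.
batch = (["a fruit basket"], None)
output = smashed_model.run_inference(batch, device="cuda")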

save_pretrained(model_path)

Save the smashed model to the specified model path.

Parameters:

model_path (str) – The path to the directory where the model will be saved.

Return type:

None

set_to_eval()

Set the model to evaluation mode.

Return type:

None
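
Taken together with destroy(), a typical end-of-session sequence might look like this:

# Switch the wrapped model to evaluation mode, run inference,
# then release the model.
smashed_model.set_to_eval()
result = smashed_model("a fruit basket")
smashed_model.destroy()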

Class API InferenceHandler

class StandardHandler(model_args=None)

Handle inference arguments, inputs and outputs for unhandled model types.

Standard handler expectations:
  • The model should accept 'x' as input, where 'x' is the first element of a two-element data batch.

  • Invoke the model using model(x) without additional parameters.

  • Outputs should be directly processable without further modification.

Parameters:

model_args (Dict[str, Any]) – The arguments to pass to the model.
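
As a sketch, a model satisfying these expectations could look as follows; TinyRegressor is a hypothetical example, not part of pruna:

import torch

# Hypothetical model meeting the standard handler contract: it accepts
# a single input `x` and returns an output usable without unpacking.
class TinyRegressor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.linear(x)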

class DiffuserHandler(call_signature, model_args=None)

Handle inference arguments, inputs and outputs for diffusers models.

A generator with a fixed seed (42) is passed to the model for reproducibility. The first element of the batch is passed as input to the model. The generated outputs are expected to have a .images attribute.

Parameters:
  • call_signature (inspect.Signature) – The signature of the call to the model.

  • model_args (Dict[str, Any]) – The arguments to pass to the model.

class TransformerHandler(model_args=None)

Handle inference arguments, inputs and outputs for transformer models.

The first element of the batch is passed as input to the model. The generated outputs are expected to have a .logits attribute.

Parameters:

model_args (Dict[str, Any]) – The arguments to pass to the model.
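
For illustration, the .logits expectation matches standard transformers model outputs; the checkpoint below is only an example:

from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Illustrative only: a transformers classifier whose output exposes the
# `.logits` attribute that the handler reads.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
inputs = tokenizer("a fruit basket", return_tensors="pt")
logits = model(**inputs).logits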