PrunaModel

The PrunaModel is the main class in pruna. It is used to load your model, apply optimization algorithms to it, and run inference.

Usage examples PrunaModel

The following example shows how to create, smash, save, load and run a PrunaModel.

from diffusers import StableDiffusionPipeline

from pruna import PrunaModel, SmashConfig, smash

# Prepare the base model
base_model = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4"
)

# Create and smash your model
smash_config = SmashConfig()
smash_config["cacher"] = "deepcache"
smash_config["compiler"] = "diffusers2"
smashed_model = smash(model=base_model, smash_config=smash_config)

# Save the model
smashed_model.save_pretrained("saved_model/")

# Load the model
smashed_model = PrunaModel.from_pretrained("saved_model/")

# Run inference
prompt = "a fruit basket"
smashed_model(prompt).images[0]

Usage examples InferenceHandler

The InferenceHandler is a helper class attached to a PrunaModel; it prepares the arguments, inputs and outputs used when running inference on the model.

from pruna import PrunaModel

smashed_model = PrunaModel.from_pretrained("saved_model/")
# Update the arguments passed to the model at inference time
smashed_model.inference_handler.model_args.update(
    {"arg_1": 1, "arg_2": 0.0}
)

Class API PrunaModel

class PrunaModel(model, smash_config=None)

A pruna class wrapping any model.

Parameters:
  • model (Any) – The model to be held by this class.

  • smash_config (SmashConfig | None) – Smash configuration.

__call__(*args, **kwargs)

Call the smashed model.

Parameters:
  • *args (Any) – Arguments to pass to the model.

  • **kwargs (Any) – Additional keyword arguments to pass to the model.

Returns:

The output of the model’s prediction.

Return type:

Any
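
For example, with the smashed diffusion pipeline from the usage example above, calling the wrapper forwards all arguments and keyword arguments to the underlying pipeline:

image = smashed_model("a fruit basket", num_inference_steps=20).images[0]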

__init__(model, smash_config=None)

Initialize the PrunaModel.

Parameters:
  • model (Any) – The model to be held by this class.

  • smash_config (SmashConfig | None) – Smash configuration.

Return type:

None

destroy()

Destroy the model.

Return type:

None

static from_pretrained(pretrained_model_name_or_path=None, *, model_path=None, verbose=False, revision=None, cache_dir=None, local_dir=None, library_name=None, library_version=None, user_agent=None, proxies=None, etag_timeout=10, force_download=False, token=None, local_files_only=False, allow_patterns=None, ignore_patterns=None, max_workers=8, tqdm_class=None, headers=None, endpoint=None, local_dir_use_symlinks='auto', resume_download=None, **kwargs)

Load a PrunaModel from a local path or from the Hugging Face Hub.

Parameters:
  • pretrained_model_name_or_path (str, optional) – The path to the model directory or the repository ID on the Hugging Face Hub.

  • model_path (str, optional) – Deprecated. Use pretrained_model_name_or_path instead.

  • verbose (bool, optional) – Whether to apply warning filters to suppress warnings. Defaults to False.

  • revision (str | None, optional) – The revision of the model to load.

  • cache_dir (str | Path | None, optional) – The directory to cache the model in.

  • local_dir (str | Path | None, optional) – The local directory to save the model in.

  • library_name (str | None, optional) – The name of the library to which the model belongs.

  • library_version (str | None, optional) – The version of that library.

  • user_agent (str | Dict | None, optional) – The user agent to send with the download requests.

  • proxies (Dict | None, optional) – A dictionary of proxies to use for the download requests.

  • etag_timeout (float, optional) – The timeout, in seconds, when fetching the ETag of the files.

  • force_download (bool, optional) – Whether to force the download of the model.

  • token (str | bool | None, optional) – The token to use to access the repository.

  • local_files_only (bool, optional) – Whether to load the model from local files only, without contacting the Hub.

  • allow_patterns (List[str] | str | None, optional) – If provided, only files matching at least one of these patterns are downloaded.

  • ignore_patterns (List[str] | str | None, optional) – If provided, files matching any of these patterns are not downloaded.

  • max_workers (int, optional) – The maximum number of concurrent workers used to download files.

  • tqdm_class (tqdm | None, optional) – The tqdm class to use for the download progress bar.

  • headers (Dict[str, str] | None, optional) – Additional headers to send with the download requests.

  • endpoint (str | None, optional) – The Hugging Face Hub endpoint to download from.

  • local_dir_use_symlinks (bool | Literal["auto"], optional) – Whether to use symlinks when downloading the model to local_dir.

  • resume_download (bool | None, optional) – Whether to resume the download of the model.

  • **kwargs (Any, optional) – Additional keyword arguments to pass to the model loading function.

Returns:

The loaded PrunaModel instance.

Return type:

PrunaModel
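
For instance, a model saved with save_pretrained can be reloaded from disk, and Hub downloads accept the usual huggingface_hub options. A minimal sketch (the Hub repository ID below is a placeholder):

from pruna import PrunaModel

# Load from a local directory
model = PrunaModel.from_pretrained("saved_model/")

# Load from the Hugging Face Hub (placeholder repository ID)
model = PrunaModel.from_pretrained(
    "your-org/your-smashed-model",
    revision="main",
    force_download=False,
)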

get_nn_modules()

Get the nn.Module instances in the model.

Returns:

A dictionary of the nn.Module instances in the model.

Return type:

dict[str | None, torch.nn.Module]
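
This is useful for inspecting the wrapped model, for example to count the parameters of each submodule:

for name, module in smashed_model.get_nn_modules().items():
    n_params = sum(p.numel() for p in module.parameters())
    print(name, type(module).__name__, n_params)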

move_to_device(device)

Move the model to a specific device.

Parameters:

device (str) – The device to move the model to.

Return type:

None
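
For example, using the usual torch device strings:

smashed_model.move_to_device("cuda")  # move to the GPU
smashed_model.move_to_device("cpu")   # move back to the CPU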

run_inference(batch, device=None)

Run inference on the model.

Parameters:
  • batch (Any) – The batch to run inference on.

  • device (torch.device | str | None) – The device to run inference on. If None, the best available device will be used.

Returns:

The processed output.

Return type:

Any
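
A minimal sketch, assuming a PrunaModel wrapping an image classifier served by the StandardHandler (hypothetical setup), which expects a two-element batch whose first element is the model input:

import torch

batch = (torch.randn(1, 3, 224, 224), torch.tensor([0]))  # (inputs, labels)
output = smashed_model.run_inference(batch, device="cuda")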

save_pretrained(model_path)

Save the smashed model to the specified model path.

Parameters:

model_path (str) – The path to the directory where the model will be saved.

Return type:

None

set_to_eval()

Set the model to evaluation mode.

Return type:

None

Class API InferenceHandler

class StandardHandler(model_args=None)

Handle inference arguments, inputs and outputs for unhandled model types.

Standard handler expectations:

  • The model should accept x as input, where x is the first element of a two-element data batch.

  • The model is invoked as model(x) without additional parameters.

  • Outputs should be directly processable without further modification.

Parameters:

model_args (Dict[str, Any]) – The arguments to pass to the model.
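
Illustratively, the calling convention amounts to the following sketch (the contract described above, not the library's actual implementation):

x = batch[0]       # first element of a two-element batch
output = model(x)  # invoked without additional parameters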

class DiffuserHandler(call_signature, model_args=None)

Handle inference arguments, inputs and outputs for diffusers models.

A generator with a fixed seed (42) is passed as an argument to the model for reproducibility. The first element of the batch is passed as input to the model. The generated outputs are expected to have a .images attribute.

Parameters:
  • call_signature (inspect.Signature) – The signature of the call to the model.

  • model_args (Dict[str, Any]) – The arguments to pass to the model.
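
Illustratively, the handler invokes the pipeline roughly as in this sketch (not the library's actual implementation):

import torch

generator = torch.Generator().manual_seed(42)  # fixed seed for reproducibility
images = model(batch[0], generator=generator, **model_args).images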

class TransformerHandler(model_args=None)

Handle inference arguments, inputs and outputs for transformer models.

The first element of the batch is passed as input to the model. The generated outputs are expected to have a .logits attribute.

Parameters:

model_args (Dict[str, Any]) – The arguments to pass to the model.
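
Illustratively, as a sketch of the contract (not the library's actual implementation):

x = batch[0]                            # first element of the batch
logits = model(x, **model_args).logits  # outputs must expose a .logits attribute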