PrunaModel

PrunaModel is the main class in pruna. It is used to load your model, apply pruna's optimization algorithms to it, and run inference.

Usage examples PrunaModel

This section explains how to define and use a PrunaModel.

from diffusers import StableDiffusionPipeline

from pruna import PrunaModel, SmashConfig, smash

# Prepare the base model
base_model = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4"
)

# Create and smash your model
smash_config = SmashConfig()
smash_config["cacher"] = "deepcache"
smash_config["compiler"] = "diffusers2"
smashed_model = smash(model=base_model, smash_config=smash_config)

# Save the model
smashed_model.save_pretrained("saved_model/")

# Load the model
smashed_model = PrunaModel.from_pretrained("saved_model/")

# Run inference
prompt = "a fruit basket"
smashed_model(prompt).images[0]

Usage examples InferenceHandler

The InferenceHandler is a helper class attached to a PrunaModel. It handles the arguments, inputs, and outputs used when running inference on the model; its model_args dictionary holds the arguments passed to the model and can be updated as shown below.

from pruna import PrunaModel

smashed_model = PrunaModel.from_pretrained("saved_model/")
# Access the handler attached to the model and update the arguments it
# passes to the model at inference time (dict.update mutates in place).
inference_handler = smashed_model.inference_handler
inference_handler.model_args.update({"arg_1": 1, "arg_2": 0.0})

Class API PrunaModel

class PrunaModel(model, smash_config=None)

Bases: object

A pruna class wrapping any model.

Parameters:
  • model (Any) – The model to be held by this class.

  • smash_config (SmashConfig | None) – Smash configuration.

__init__(model, smash_config=None)

Initialize the PrunaModel.

Parameters:
  • model (Any) – The model to be held by this class.

  • smash_config (SmashConfig | None) – Smash configuration.

Return type:

None

__call__(*args, **kwargs)

Call the smashed model.

Parameters:
  • *args (Any) – Arguments to pass to the model.

  • **kwargs (Any) – Additional keyword arguments to pass to the model.

Returns:

The output of the model’s prediction.

Return type:

Any
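
For example, calling the smashed pipeline from the usage example above forwards the arguments to the wrapped diffusers pipeline (a minimal sketch; num_inference_steps is a standard diffusers pipeline argument, not pruna-specific):

# Positional and keyword arguments are forwarded to the wrapped model.
image = smashed_model("a fruit basket", num_inference_steps=25).images[0]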

run_inference(batch)

Run inference on the model.

Parameters:

batch (Any) – The batch to run inference on.

Returns:

The processed output.

Return type:

Any
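
A minimal sketch of run_inference with the smashed pipeline from above, assuming the two-element batch format described for the handlers below (the first element carries the inputs):

# The attached inference handler prepares the inputs and processes the outputs.
batch = (["a fruit basket"], None)  # (inputs, labels); the second element is unused here
images = smashed_model.run_inference(batch)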

is_instance(instance_type)

Compare the model to the given instance type.

Parameters:

instance_type (Any) – The type to compare the model to.

Returns:

True if the model is an instance of the given type, False otherwise.

Return type:

bool
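
For example, to branch on the type of the wrapped model (reusing the pipeline class from the usage example above):

from diffusers import StableDiffusionPipeline

if smashed_model.is_instance(StableDiffusionPipeline):
    print("The wrapped model is a Stable Diffusion pipeline.")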

get_nn_modules()

Get the nn.Module instances in the model.

Returns:

A dictionary of the nn.Module instances in the model.

Return type:

dict[str | None, torch.nn.Module]
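
A small sketch that uses the returned dictionary to count parameters per submodule:

for name, module in smashed_model.get_nn_modules().items():
    num_params = sum(p.numel() for p in module.parameters())
    print(f"{name}: {num_params} parameters")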

set_to_eval()

Set the model to evaluation mode.

Return type:

None

save_pretrained(model_path)

Save the smashed model to the specified model path.

Parameters:

model_path (str) – The path to the directory where the model will be saved.

Return type:

None

push_to_hub(repo_id, *, model_path=None, revision=None, private=False, allow_patterns=None, ignore_patterns=None, num_workers=None, print_report=False, print_report_every=0, hf_token=None)

Push the model to the specified repository.

Parameters:
  • repo_id (str) – The repository ID to push the model to.

  • model_path (str | None) – The path to the directory where the model will be saved. If None, the model will only be saved to the Hugging Face Hub.

  • revision (str | None) – The revision to push the model to.

  • private (bool) – Whether to push the model as a private repository.

  • allow_patterns (List[str] | str | None) – If provided, only files matching these patterns are pushed.

  • ignore_patterns (List[str] | str | None) – Files matching these patterns are not pushed.

  • num_workers (int | None) – The number of workers to use to push the model.

  • print_report (bool) – Whether to print the report of the pushed model.

  • print_report_every (int) – How often, in steps, to print the report while pushing.

  • hf_token (str | None) – The Hugging Face token to use for authentication to push models to the Hub.

Return type:

None
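
A sketch of pushing the smashed model from the usage example to the Hub; the repository ID and token below are placeholders, not real values:

smashed_model.push_to_hub(
    "your-username/sd-v1-4-smashed",  # placeholder repository ID
    private=True,
    hf_token="hf_...",  # placeholder token
)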

static from_pretrained(pretrained_model_name_or_path=None, *, model_path=None, verbose=False, revision=None, cache_dir=None, local_dir=None, library_name=None, library_version=None, user_agent=None, proxies=None, etag_timeout=10, force_download=False, token=None, local_files_only=False, allow_patterns=None, ignore_patterns=None, max_workers=8, tqdm_class=None, headers=None, endpoint=None, local_dir_use_symlinks='auto', resume_download=None, **kwargs)

Load a PrunaModel from a local path or from the Hugging Face Hub.

Parameters:
  • pretrained_model_name_or_path (str, optional) – The path to the model directory or the repository ID on the Hugging Face Hub.

  • model_path (str, optional) – Deprecated. Use pretrained_model_name_or_path instead.

  • verbose (bool, optional) – Whether to apply warning filters to suppress warnings. Defaults to False.

  • revision (str | None, optional) – The revision of the model to load.

  • cache_dir (str | Path | None, optional) – The directory to cache the model in.

  • local_dir (str | Path | None, optional) – The local directory to save the model in.

  • library_name (str | None, optional) – The name of the library to use to load the model.

  • library_version (str | None, optional) – The version of the library to use to load the model.

  • user_agent (str | Dict | None, optional) – The user agent to use to load the model.

  • proxies (Dict | None, optional) – The proxies to use to load the model.

  • etag_timeout (float, optional) – The timeout, in seconds, when fetching the ETag.

  • force_download (bool, optional) – Whether to force the download of the model.

  • token (str | bool | None, optional) – The token to use to access the repository.

  • local_files_only (bool, optional) – Whether to only load the model from the local files.

  • allow_patterns (List[str] | str | None, optional) – If provided, only files matching these patterns are downloaded.

  • ignore_patterns (List[str] | str | None, optional) – Files matching these patterns are not downloaded.

  • max_workers (int, optional) – The maximum number of workers to use to load the model.

  • tqdm_class (tqdm | None, optional) – The tqdm class to use for the download progress bar.

  • headers (Dict[str, str] | None, optional) – The headers to use to load the model.

  • endpoint (str | None, optional) – The endpoint to use to load the model.

  • local_dir_use_symlinks (bool | Literal["auto"], optional) – Whether to use symlinks when saving the downloaded files into local_dir.

  • resume_download (bool | None, optional) – Whether to resume the download of the model.

  • **kwargs (Any, optional) – Additional keyword arguments to pass to the model loading function.

Returns:

The loaded PrunaModel instance.

Return type:

PrunaModel
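
A sketch of loading from the Hub; the repository ID below is a placeholder (a local path such as "saved_model/" also works):

from pruna import PrunaModel

model = PrunaModel.from_pretrained(
    "your-username/sd-v1-4-smashed",  # placeholder repository ID
    revision="main",
)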

destroy()

Destroy the model.

Return type:

None

Class API InferenceHandler

class StandardHandler(model_args=None)

Bases: InferenceHandler

Handle inference arguments, inputs and outputs for unhandled model types.

Standard handler expectations:
  • The model should accept x as input, where x is the first element of a two-element data batch.

  • The model is invoked as model(x), without additional parameters.

  • Outputs should be directly processable without further modification.

Parameters:

model_args (Dict[str, Any]) – The arguments to pass to the model.
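
A minimal sketch of the standard contract with a toy torch module; the import path for StandardHandler is an assumption for illustration:

import torch
from pruna.engine.handler.handler_standard import StandardHandler  # assumed import path

model = torch.nn.Linear(16, 4)  # toy model that accepts x directly
handler = StandardHandler()

batch = (torch.randn(8, 16), torch.zeros(8))  # two-element batch: (x, labels)
x = handler.prepare_inputs(batch)  # extracts the first element, x
output = handler.process_output(model(x))  # returned without modification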

prepare_inputs(batch)

Prepare the inputs for the model.

Parameters:

batch (List[str] | torch.Tensor | Tuple[List[str] | torch.Tensor | dict[str, Any], ...] | dict[str, Any]) – The batch to prepare the inputs for.

Returns:

The prepared inputs.

Return type:

Any

process_output(output)

Handle the output of the model.

Parameters:

output (Any) – The output to process.

Returns:

The processed output.

Return type:

Any

log_model_info()

Log information about the inference handler.

Return type:

None

class DiffuserHandler(call_signature, model_args=None)

Bases: InferenceHandler

Handle inference arguments, inputs and outputs for diffusers models.

A generator with a fixed seed (42) is passed as an argument to the model for reproducibility. The first element of the batch is passed as input to the model. The generated outputs are expected to have an .images attribute.

Parameters:
  • call_signature (inspect.Signature) – The signature of the call to the model.

  • model_args (Dict[str, Any]) – The arguments to pass to the model.
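
Because the handler passes a generator seeded with 42 to the pipeline, repeated runs on the same batch are reproducible (a sketch reusing the smashed pipeline from the usage examples):

batch = (["a fruit basket"], None)  # (prompts, unused labels)
first = smashed_model.run_inference(batch)
second = smashed_model.run_inference(batch)
# With the fixed-seed generator, both runs should produce the same images.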

prepare_inputs(batch)

Prepare the inputs for the model.

Parameters:

batch (List[str] | torch.Tensor | Tuple[List[str] | torch.Tensor | dict[str, Any], ...] | dict[str, Any]) – The batch to prepare the inputs for.

Returns:

The prepared inputs.

Return type:

Any

process_output(output)

Handle the output of the model.

Parameters:

output (Any) – The output to process.

Returns:

The processed images.

Return type:

torch.Tensor

log_model_info()

Log information about the inference handler.

Return type:

None

class TransformerHandler(model_args=None)

Bases: InferenceHandler

Handle inference arguments, inputs and outputs for transformer models.

The first element of the batch is passed as input to the model. The generated outputs are expected to have a .logits attribute.

Parameters:

model_args (Dict[str, Any]) – The arguments to pass to the model.
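
A minimal sketch of the transformer contract; the import path for TransformerHandler and the model checkpoint are assumptions for illustration:

import torch
from transformers import AutoModelForSequenceClassification
from pruna.engine.handler.handler_transformer import TransformerHandler  # assumed import path

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
handler = TransformerHandler()

batch = (torch.tensor([[101, 2023, 2003, 1037, 3231, 102]]), None)  # (input_ids, labels)
input_ids = handler.prepare_inputs(batch)  # extracts the first element
logits = handler.process_output(model(input_ids))  # pulls the .logits attribute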

prepare_inputs(batch)

Prepare the inputs for the model.

Parameters:

batch (List[str] | torch.Tensor | Tuple[List[str] | torch.Tensor | dict[str, Any], ...] | dict[str, Any]) – The batch to prepare the inputs for.

Returns:

The prepared inputs.

Return type:

Any

process_output(output)

Handle the output of the model.

Parameters:

output (Any) – The output to process.

Returns:

The processed output.

Return type:

Any

log_model_info()

Log information about the inference handler.

Return type:

None