PrunaModel
The PrunaModel is the main class in pruna. It is used to load your model, apply the optimization algorithms, and run inference.
Usage examples PrunaModel
This manual explains how to define and use a PrunaModel.
from pruna import smash, SmashConfig
from diffusers import StableDiffusionPipeline
from pruna import PrunaModel
# prepare the base model
base_model = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4"
)
# Create and smash your model
smash_config = SmashConfig()
smash_config["cacher"] = "deepcache"
smash_config["compiler"] = "diffusers2"
smashed_model = smash(model=base_model, smash_config=smash_config)
# Save the model
smashed_model.save_pretrained("saved_model/")
# Load the model
smashed_model = PrunaModel.from_pretrained("saved_model/")
# Run inference
prompt = "a fruit basket"
smashed_model(prompt).images[0]
Usage examples InferenceHandler
The InferenceHandler is a helper class that runs inference on a model and is attached to a PrunaModel.
from pruna import PrunaModel
smashed_model = PrunaModel.from_pretrained("saved_model/")
# Update the arguments passed to the model at inference time
inference_handler = smashed_model.inference_handler
inference_handler.model_args.update({"arg_1": 1, "arg_2": 0.0})
Class API PrunaModel
- class PrunaModel(model, smash_config=None)
A pruna class wrapping any model.
- Parameters:
model (Any) – The model to be held by this class.
smash_config (SmashConfig | None) – Smash configuration.
- __call__(*args, **kwargs)
Call the smashed model.
- Parameters:
*args (Any) – Arguments to pass to the model.
**kwargs (Any) – Additional keyword arguments to pass to the model.
- Returns:
The output of the model's prediction.
- Return type:
Any
- __init__(model, smash_config=None)
- Parameters:
model (Any)
smash_config (SmashConfig | None)
- Return type:
None
- static from_pretrained(pretrained_model_name_or_path=None, *, model_path=None, verbose=False, revision=None, cache_dir=None, local_dir=None, library_name=None, library_version=None, user_agent=None, proxies=None, etag_timeout=10, force_download=False, token=None, local_files_only=False, allow_patterns=None, ignore_patterns=None, max_workers=8, tqdm_class=None, headers=None, endpoint=None, local_dir_use_symlinks='auto', resume_download=None, **kwargs)
Load a PrunaModel from a local path or from the Hugging Face Hub.
- Parameters:
pretrained_model_name_or_path (str, optional) – The path to the model directory or the repository ID on the Hugging Face Hub.
model_path (str, optional) – Deprecated. Use pretrained_model_name_or_path instead.
verbose (bool, optional) – Whether to apply warning filters to suppress warnings. Defaults to False.
revision (str | None, optional) – The revision of the model to load.
cache_dir (str | Path | None, optional) – The directory to cache the model in.
local_dir (str | Path | None, optional) – The local directory to save the model in.
library_name (str | None, optional) – The name of the library to use to load the model.
library_version (str | None, optional) – The version of the library to use to load the model.
user_agent (str | Dict | None, optional) – The user agent to use to load the model.
proxies (Dict | None, optional) – The proxies to use to load the model.
etag_timeout (float, optional) – The timeout, in seconds, for the ETag request.
force_download (bool, optional) – Whether to force the download of the model.
token (str | bool | None, optional) – The token to use to access the repository.
local_files_only (bool, optional) – Whether to only load the model from local files.
allow_patterns (List[str] | str | None, optional) – The file patterns to include when downloading the model.
ignore_patterns (List[str] | str | None, optional) – The file patterns to exclude when downloading the model.
max_workers (int, optional) – The maximum number of download workers.
tqdm_class (tqdm | None, optional) – The tqdm class to use for the progress bar.
headers (Dict[str, str] | None, optional) – The headers to send when downloading the model.
endpoint (str | None, optional) – The Hub endpoint to download the model from.
local_dir_use_symlinks (bool | Literal["auto"], optional) – Whether to use symlinks when downloading to local_dir.
resume_download (bool | None, optional) – Whether to resume an interrupted download of the model.
**kwargs (Any, optional) – Additional keyword arguments to pass to the model loading function.
- Returns:
The loaded PrunaModel instance.
- Return type:
PrunaModel
- get_nn_modules()
Get the nn.Module instances in the model.
- Returns:
A dictionary of the nn.Module instances in the model.
- Return type:
dict[str | None, torch.nn.Module]
- move_to_device(device)
Move the model to a specific device.
- Parameters:
device (str) – The device to move the model to.
- Return type:
None
- run_inference(batch, device=None)
Run inference on the model.
- Parameters:
batch (Any) – The batch to run inference on.
device (torch.device | str | None) – The device to run inference on. If None, the best available device will be used.
- Returns:
The processed output.
- Return type:
Any
Class API InferenceHandler
- class StandardHandler(model_args=None)
Handle inference arguments, inputs and outputs for unhandled model types.
Standard handler expectations:
- The model should accept "x" as input, where "x" is the first element of a two-element data batch.
- Invoke the model using model(x) without additional parameters.
- Outputs should be directly processable without further modification.
- Parameters:
model_args (Dict[str, Any]) – The arguments to pass to the model.
- class DiffuserHandler(call_signature, model_args=None)
Handle inference arguments, inputs and outputs for diffusers models.
A generator with a fixed seed (42) is passed to the model for reproducibility. The first element of the batch is passed as input to the model. The generated outputs are expected to have an .images attribute.
- Parameters:
call_signature (inspect.Signature) – The signature of the call to the model.
model_args (Dict[str, Any]) – The arguments to pass to the model.
- class TransformerHandler(model_args=None)
Handle inference arguments, inputs and outputs for transformer models.
The first element of the batch is passed as input to the model. The generated outputs are expected to have a .logits attribute.
- Parameters:
model_args (Dict[str, Any]) – The arguments to pass to the model.