PrunaModel
The PrunaModel
is the main class in pruna. It is used to load your model, apply the optimization algorithms, and run inference.
Usage examples: PrunaModel
This section shows how to define and use a PrunaModel
.
from pruna import smash, SmashConfig
from diffusers import StableDiffusionPipeline
from pruna import PrunaModel
# prepare the base model
base_model = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4"
)
# Create and smash your model
smash_config = SmashConfig()
smash_config["cacher"] = "deepcache"
smash_config["compiler"] = "diffusers2"
smashed_model = smash(model=base_model, smash_config=smash_config)
# Save the model
smashed_model.save_pretrained("saved_model/")
# Load the model
smashed_model = PrunaModel.from_pretrained("saved_model/")
# Run inference
prompt = "a fruit basket"
smashed_model(prompt).images[0]
Usage examples: InferenceHandler
The InferenceHandler
is a helper class attached to a PrunaModel
that prepares the arguments, inputs, and outputs used when running inference.
from pruna import PrunaModel
smashed_model = PrunaModel.from_pretrained("saved_model/")
# Update the arguments the handler passes to the model at inference time.
# Note: dict.update() mutates in place and returns None, so its result
# should not be assigned.
smashed_model.model_args.update(
    {"arg_1": 1, "arg_2": 0.0}
)
Class API PrunaModel
- class PrunaModel(model, smash_config=None)
A pruna class wrapping any model.
- Parameters:
model (Any) – The model to be held by this class.
smash_config (SmashConfig | None) – Smash configuration.
- __call__(*args, **kwargs)
Call the smashed model.
- Parameters:
*args (Any) – Arguments to pass to the model.
**kwargs (Any) – Additional keyword arguments to pass to the model.
- Returns:
The output of the model's prediction.
- Return type:
Any
- __init__(model, smash_config=None)
- Parameters:
model (Any)
smash_config (SmashConfig | None)
- Return type:
None
- static from_pretrained(model_path, verbose=False, **kwargs)
Load a PrunaModel from the specified model path.
- Parameters:
model_path (str) – The path to the model directory containing the necessary configuration and model files.
verbose (bool, optional) – Whether to apply warning filters to suppress warnings. Defaults to False.
**kwargs (dict) – Additional keyword arguments to pass to the model loading function, such as specific settings or parameters.
- Returns:
The loaded PrunaModel instance.
- Return type:
PrunaModel
- get_nn_modules()
Get the nn.Module instances in the model.
- Returns:
A dictionary of the nn.Module instances in the model.
- Return type:
dict[str | None, torch.nn.Module]
- move_to_device(device)
Move the model to a specific device.
- Parameters:
device (str | torch.device) – The device to move the model to.
- Return type:
None
- run_inference(batch, device)
Run inference on the model.
- Parameters:
batch (Tuple[List[str] | torch.Tensor, ...]) – The batch to run inference on.
device (torch.device | str) – The device to run inference on.
- Returns:
The processed output.
- Return type:
Any
Class API InferenceHandler
- class StandardHandler(model_args=None)
Handle inference arguments, inputs and outputs for unhandled model types.
Standard handler expectations:
- The model accepts "x" as input, where "x" is the first element of a two-element data batch.
- The model is invoked as model(x) with no additional parameters.
- Outputs are directly processable without further modification.
- Parameters:
model_args (Dict[str, Any]) – The arguments to pass to the model.
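The standard-handler contract above can be illustrated with a minimal sketch. The DummyModel, standard_style_inference, and batch below are hypothetical stand-ins for illustration only, not part of pruna's API:

```python
# Sketch of the standard-handler contract, using hypothetical stand-ins.
# The handler takes the first element "x" of a two-element data batch and
# calls model(x) with no extra parameters; the output is used as-is.

class DummyModel:
    """Hypothetical model that doubles its inputs, for illustration only."""
    def __call__(self, x):
        return [value * 2 for value in x]

def standard_style_inference(model, batch):
    # The first element of a two-element (inputs, targets) batch is the input.
    x, _targets = batch
    # Invoke the model with "x" only, no additional parameters.
    return model(x)

batch = ([1, 2, 3], [2, 4, 6])  # hypothetical (inputs, targets) batch
output = standard_style_inference(DummyModel(), batch)
print(output)  # → [2, 4, 6]
```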
- class DiffuserHandler(call_signature, model_args=None)
Handle inference arguments, inputs and outputs for diffusers models.
A generator with a fixed seed (42) is passed as an argument to the model for reproducibility. The first element of the batch is passed as input to the model. The generated outputs are expected to have an .images attribute.
- Parameters:
call_signature (inspect.Signature) – The signature of the call to the model.
model_args (Dict[str, Any]) – The arguments to pass to the model.
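The diffusers-handler contract can be sketched with hypothetical stand-ins. Note that the real handler passes a torch.Generator; here random.Random plays that role only so the sketch stays self-contained, and DummyPipeline and diffuser_style_inference are illustrative names, not pruna API:

```python
# Sketch of the diffusers-handler contract with hypothetical stand-ins:
# a fixed-seed generator for reproducibility, the first batch element as
# input, and outputs exposing an .images attribute.
import random
from dataclasses import dataclass

@dataclass
class DummyOutput:
    images: list  # diffusers-style outputs expose an .images attribute

class DummyPipeline:
    """Hypothetical pipeline: 'generates' one value per prompt from the generator."""
    def __call__(self, prompts, generator=None):
        rng = generator if generator is not None else random.Random()
        return DummyOutput(images=[rng.random() for _ in prompts])

def diffuser_style_inference(model, batch):
    prompts = batch[0]             # first element of the batch is the input
    generator = random.Random(42)  # fixed seed (42) for reproducibility
    output = model(prompts, generator=generator)
    return output.images           # outputs are expected to expose .images

out_a = diffuser_style_inference(DummyPipeline(), (["a fruit basket"], None))
out_b = diffuser_style_inference(DummyPipeline(), (["a fruit basket"], None))
assert out_a == out_b  # the fixed seed makes repeated runs identical
```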
- class TransformerHandler(model_args=None)
Handle inference arguments, inputs and outputs for transformer models.
The first element of the batch is passed as input to the model. The generated outputs are expected to have a .logits attribute.
- Parameters:
model_args (Dict[str, Any]) – The arguments to pass to the model.
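The transformer-handler contract can be sketched the same way. DummyTransformer and transformer_style_inference are hypothetical stand-ins for illustration, not pruna API:

```python
# Sketch of the transformer-handler contract with hypothetical stand-ins:
# the first batch element is the input, and outputs expose a .logits attribute.
from dataclasses import dataclass

@dataclass
class DummyOutput:
    logits: list  # transformers-style outputs expose a .logits attribute

class DummyTransformer:
    """Hypothetical model producing one score per token id, for illustration."""
    def __call__(self, input_ids):
        return DummyOutput(logits=[float(i) for i in input_ids])

def transformer_style_inference(model, batch):
    input_ids = batch[0]  # first element of the batch is the model input
    output = model(input_ids)
    return output.logits  # outputs are expected to expose .logits

logits = transformer_style_inference(DummyTransformer(), ([1, 2, 3], None))
print(logits)  # → [1.0, 2.0, 3.0]
```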