Saving and Loading Pruna Models
After smashing a model using pruna, you can save it to disk and load it later using the built-in save and load functionality.
Saving and Loading Models
To save a smashed model, use the PrunaModel.save_pretrained()
or PrunaModel.save_to_hub()
method. These methods save all necessary model files as well as the smash configuration to the specified directory or Hub repository:
from pruna import smash, SmashConfig
from diffusers import StableDiffusionPipeline
# prepare the base model
base_model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
# Create and smash your model
smash_config = SmashConfig()
smash_config["cacher"] = "deepcache"
smash_config["compiler"] = "diffusers2"
smashed_model = smash(model=base_model, smash_config=smash_config)
# Save the model
smashed_model.save_pretrained("saved_model/")
from pruna import smash, SmashConfig
from diffusers import StableDiffusionPipeline
# prepare the base model
base_model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
# Create and smash your model
smash_config = SmashConfig()
smash_config["cacher"] = "deepcache"
smash_config["compiler"] = "diffusers2"
smashed_model = smash(model=base_model, smash_config=smash_config)
# Save the model
smashed_model.save_to_hub("PrunaAI/smashed-stable-diffusion-v1-4")
The save operation will:
1. Save the model weights and architecture, including information on how to load the model later on
2. Save the smash_config (including the tokenizer and processor if present; data will be detached and not reloaded)
To load a previously saved PrunaModel, use the PrunaModel.from_pretrained() or PrunaModel.from_hub() class method:
from pruna import PrunaModel
loaded_model = PrunaModel.from_pretrained("saved_model/")
from pruna import PrunaModel
loaded_model = PrunaModel.from_hub("PrunaAI/smashed-stable-diffusion-v1-4")
The load operation will:
1. Load the model architecture and weights and cast them to the device specified in the SmashConfig
2. Restore the smash configuration
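Step 2 above can be pictured as reading the saved configuration back from the model directory. The sketch below illustrates the idea in plain Python; the file name smash_config.json is an assumption for illustration, not a statement about pruna's actual on-disk layout:

```python
import json
import tempfile
from pathlib import Path

def load_smash_config(model_path):
    # Restore the saved smash configuration (step 2 of loading).
    # "smash_config.json" is a hypothetical file name used only for this sketch.
    return json.loads((Path(model_path) / "smash_config.json").read_text())

# Simulate a saved model directory and read the configuration back.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "smash_config.json").write_text(
        json.dumps({"cacher": "deepcache", "compiler": "diffusers2"})
    )
    config = load_smash_config(tmp)
    print(config["cacher"])  # deepcache
```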
Special Considerations
Loading Keyword Arguments
We generally recommend loading the smashed model with the same configuration as the base model, particularly if the two are to be compared in terms of efficiency and quality. For example, if the base model was loaded with a specific precision:
import torch
from diffusers import StableDiffusionPipeline
base_model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
you should also load the smashed model as follows:
from pruna import PrunaModel
loaded_model = PrunaModel.from_pretrained("saved_model/", torch_dtype=torch.float16)
Depending on the saving function of the algorithm combination, not all keyword arguments are required for loading (e.g., some are set by the algorithm combination itself). In that case, unused keyword arguments are discarded and a warning is logged.
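The discard-and-warn behavior can be sketched in plain Python. This is a conceptual illustration of the pattern, not pruna's internal implementation; the function name and the set of supported keys are assumptions:

```python
import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("load-kwargs-sketch")

def filter_load_kwargs(supported, **kwargs):
    # Keep only keyword arguments the loading function understands;
    # log a warning naming the ones that are discarded.
    used = {k: v for k, v in kwargs.items() if k in supported}
    unused = sorted(set(kwargs) - set(used))
    if unused:
        logger.warning("Discarding unused keyword arguments: %s", unused)
    return used

# "torch_dtype" is supported here; "low_cpu_mem_usage" is discarded with a warning.
kwargs = filter_load_kwargs({"torch_dtype", "device_map"},
                            torch_dtype="float16", low_cpu_mem_usage=True)
print(kwargs)  # {'torch_dtype': 'float16'}
```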
Algorithm Reapplication
Some algorithms, particularly compilers and certain quantization methods, need to be reapplied after loading, since, for example, a compiled model can rarely be saved in its compiled state. Reapplication happens automatically during loading based on the saved configuration and does not add significant time overhead.
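Conceptually, the loader inspects the saved configuration and reapplies any algorithm that cannot be serialized. The sketch below illustrates that selection step; the idea that compilers fall into the reapply category follows the text above, but the set membership and function name are assumptions for illustration:

```python
# Assumption for this sketch: algorithm groups whose results cannot be
# serialized (e.g. compilers) must be reapplied after loading.
REAPPLY_ON_LOAD = {"compiler"}

def algorithms_to_reapply(smash_config):
    # Return the algorithm groups from the saved config that need reapplication.
    return sorted(group for group in smash_config if group in REAPPLY_ON_LOAD)

print(algorithms_to_reapply({"cacher": "deepcache", "compiler": "diffusers2"}))
# ['compiler']
```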
Warning Suppression
Set verbose=True
when loading if you want to see warning messages as well as logs (in particular about the reapplication of algorithms), which are suppressed by default:
from pruna import PrunaModel
loaded_model = PrunaModel.from_pretrained("saved_model/", verbose=True)
PrunaModel
Function Documentation
- class pruna.engine.pruna_model.PrunaModel(model, smash_config=None)
A pruna class wrapping any model.
- Parameters:
model (Any) – The model to be held by this class.
smash_config (SmashConfig | None) – Smash configuration.
- static from_hub(repo_id, revision=None, cache_dir=None, local_dir=None, library_name=None, library_version=None, user_agent=None, proxies=None, etag_timeout=10, force_download=False, token=None, local_files_only=False, allow_patterns=None, ignore_patterns=None, max_workers=8, tqdm_class=None, headers=None, endpoint=None, local_dir_use_symlinks='auto', resume_download=None, **kwargs)
Load a PrunaModel from the specified repository.
- Parameters:
repo_id (str) – The repository ID to load the model from.
revision (str | None, optional) – The revision of the model to load.
cache_dir (str | Path | None, optional) – The directory to cache the model in.
local_dir (str | Path | None, optional) – The local directory to save the model in.
library_name (str | None, optional) – The name of the library to use to load the model.
library_version (str | None, optional) – The version of the library to use to load the model.
user_agent (str | Dict | None, optional) – The user agent to use to load the model.
proxies (Dict | None, optional) – The proxies to use to load the model.
etag_timeout (float, optional) – The timeout for the etag.
force_download (bool, optional) – Whether to force the download of the model.
token (str | bool | None, optional) – The token to use to access the repository.
local_files_only (bool, optional) – Whether to only load the model from the local files.
allow_patterns (List[str] | str | None, optional) – The patterns to allow to load the model.
ignore_patterns (List[str] | str | None, optional) – The patterns to ignore to load the model.
max_workers (int, optional) – The maximum number of workers to use to load the model.
tqdm_class (tqdm | None, optional) – The tqdm class to use to load the model.
headers (Dict[str, str] | None, optional) – The headers to use to load the model.
endpoint (str | None, optional) – The endpoint to use to load the model.
local_dir_use_symlinks (bool | Literal["auto"], optional) – Whether to use symlinks to load the model.
resume_download (bool | None, optional) – Whether to resume the download of the model.
**kwargs (Any, optional) – Additional keyword arguments to pass to the model loading function.
- Returns:
The loaded PrunaModel instance.
- Return type:
PrunaModel
- static from_pretrained(model_path, verbose=False, **kwargs)
Load a PrunaModel from the specified model path.
- Parameters:
model_path (str) – The path to the model directory containing necessary configuration and model files.
verbose (bool, optional) – Whether to apply warning filters to suppress warnings. Defaults to False.
**kwargs (dict) – Additional keyword arguments to pass to the model loading function, such as specific settings or parameters.
- Returns:
The loaded PrunaModel instance.
- Return type:
PrunaModel
- save_pretrained(model_path)
Save the smashed model to the specified model path.
- Parameters:
model_path (str) – The path to the directory where the model will be saved.
- Return type:
None
- save_to_hub(repo_id, *, model_path=None, revision=None, private=False, allow_patterns=None, ignore_patterns=None, num_workers=None, print_report=False, print_report_every=0)
Save the model to the specified repository.
- Parameters:
repo_id (str) – The repository ID to save the model to.
model_path (str | None) – The path to the local directory where the model will be saved. If None, the model will only be saved to the Hugging Face Hub.
revision (str | None) – The revision to save the model to.
private (bool) – Whether to save the model as a private repository.
allow_patterns (List[str] | str | None) – The patterns to allow to save the model.
ignore_patterns (List[str] | str | None) – The patterns to ignore to save the model.
num_workers (int | None) – The number of workers to use to save the model.
print_report (bool) – Whether to print the report of the saved model.
print_report_every (int) – The number of steps to print the report of the saved model.
- Return type:
None