Reducing warm-up time for compilation


In this tutorial, we walk you through how to use the pruna package to compile your model in a way that significantly reduces warm-up time when the model is re-loaded on a new machine. Note that, for now, this only applies when the new machine has hardware identical to the machine the model was compiled on. The inference and compilation times reported here were measured on an NVIDIA H100.

0. Setup

As a first step, we do a brief setup to mimic loading a compiled model on a new machine. To do so, we explicitly set the TorchInductor cache directory so that we can delete it later.

[ ]:
import os

cache_dir = "temp_cache_dir/"
os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir

1. Load the model

We are now ready to load the model we want to compile. Here we use a Stable Diffusion pipeline and apply both caching and compilation to it, to show that portable compilation can be combined with other algorithms in pruna. Compiling with torch_compile alone is, of course, also supported.

[ ]:
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"

2. Smash the model

Next, we define the SmashConfig and smash the model.

[ ]:
from pruna import SmashConfig, smash

smash_config = SmashConfig()
smash_config["compiler"] = "torch_compile"
smash_config["cacher"] = "deepcache"
smash_config["torch_compile_make_portable"] = True

pipe = smash(pipe, smash_config=smash_config)

3. Run and save the compiled model

We now run the model twice and observe both the duration of the first (warm-up) inference, approximately 50 seconds in this example, and the runtime of the compiled model in the subsequent run. Finally, we save the smashed model.

[ ]:
import time

for _ in range(2):
    start = time.time()
    pipe(prompt)
    print(f"Time taken: {time.time() - start} seconds")

pipe.save_pretrained("smashed_model/")
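The timing loop above can be factored into a reusable sketch that separates the one-off warm-up cost from the steady-state runtime. It uses `time.perf_counter`, a high-resolution clock intended for measuring elapsed time. The helper names `time_calls` and `warmup_overhead` are hypothetical, and `fake_pipeline` is only a stand-in that simulates a slow first call with `time.sleep`; in the notebook you would pass `lambda: pipe(prompt)` instead.

```python
import time
from typing import Callable, List


def time_calls(fn: Callable[[], object], n_runs: int = 2) -> List[float]:
    """Call fn n_runs times and return the wall-clock duration of each call in seconds."""
    durations = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations


def warmup_overhead(durations: List[float]) -> float:
    """Extra time of the first (warm-up) call over the mean of the later calls."""
    if len(durations) < 2:
        raise ValueError("need at least two runs to separate warm-up from steady state")
    steady = sum(durations[1:]) / (len(durations) - 1)
    return durations[0] - steady


# Stand-in workload (hypothetical): the first call pays a one-off cost, later calls are fast.
_state = {"compiled": False}


def fake_pipeline() -> None:
    if not _state["compiled"]:
        time.sleep(0.2)  # simulated one-off compilation cost
        _state["compiled"] = True
    time.sleep(0.01)  # simulated steady-state inference


durations = time_calls(fake_pipeline, n_runs=3)
print([f"{d:.3f}s" for d in durations])
print(f"Warm-up overhead: {warmup_overhead(durations):.3f}s")
```

If you time a bare GPU module rather than a pipeline that returns images, call `torch.cuda.synchronize()` before reading the clock, since CUDA kernels run asynchronously.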

4. Simulate move to a new machine

Next, we delete the compilation cache directory to mimic moving to a new machine. After that, restart your kernel or process, so that no compiled state remains in memory, and then continue with the tutorial.

[ ]:
import shutil

shutil.rmtree(cache_dir)
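A slightly more defensive version of this step also checks that the cache is really gone while the saved model is still in place, which guards against mixing up the two paths. The helper `simulate_fresh_machine` below is hypothetical and not a pruna function; the demo runs it on throwaway directories created with `tempfile` rather than on the tutorial's real paths.

```python
import shutil
import tempfile
from pathlib import Path


def simulate_fresh_machine(cache_dir: str, saved_model_dir: str) -> None:
    """Delete the TorchInductor cache but keep the saved model (hypothetical helper)."""
    cache_path = Path(cache_dir)
    model_path = Path(saved_model_dir)
    # Refuse to run if both paths point at the same location.
    if cache_path.resolve() == model_path.resolve():
        raise ValueError("cache_dir and saved_model_dir must differ")
    shutil.rmtree(cache_path, ignore_errors=True)  # no error if the cache is already gone
    if cache_path.exists():
        raise RuntimeError(f"failed to remove {cache_path}")
    if not model_path.is_dir():
        raise FileNotFoundError(f"saved model directory {model_path} is missing")


# Demo on throwaway directories standing in for temp_cache_dir/ and smashed_model/.
with tempfile.TemporaryDirectory() as root:
    cache = Path(root, "temp_cache_dir")
    cache.mkdir()
    (cache / "artifact.bin").write_bytes(b"compiled")
    model = Path(root, "smashed_model")
    model.mkdir()
    simulate_fresh_machine(str(cache), str(model))
    print(cache.exists(), model.exists())  # prints: False True
```

In the tutorial itself, the call would be `simulate_fresh_machine(cache_dir, "smashed_model/")`.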

5. Load the smashed model

We can now load the saved model and check that the warm-up time has been significantly reduced!

[ ]:
import time

import torch

from pruna import PrunaModel

pipe = PrunaModel.from_pretrained("smashed_model", torch_dtype=torch.float16)
prompt = "a photo of an astronaut riding a horse on mars"

for _ in range(2):
    start = time.time()
    pipe(prompt)
    print(f"Time taken: {time.time() - start} seconds")
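Because the kernel restart in step 4 discards the timings printed in step 3, comparing the two warm-up times directly requires persisting them to disk. The sketch below stores each session's per-run durations in a small JSON file and computes how many times faster the first run became after reloading. The file name `warmup_timings.json`, the labels, and both helper functions are hypothetical and only illustrate the bookkeeping; they are not part of pruna.

```python
import json
from pathlib import Path
from typing import Dict, List


def record_timings(path: str, label: str, durations: List[float]) -> None:
    """Store one session's per-run durations (in seconds) in a JSON file under `label`."""
    file = Path(path)
    # Keep timings from earlier sessions, e.g. those recorded before the kernel restart.
    data: Dict[str, List[float]] = json.loads(file.read_text()) if file.exists() else {}
    data[label] = list(durations)
    file.write_text(json.dumps(data, indent=2))


def warmup_speedup(path: str, before: str, after: str) -> float:
    """Ratio of the first-run (warm-up) duration under `before` to the one under `after`."""
    data = json.loads(Path(path).read_text())
    return data[before][0] / data[after][0]


# Hypothetical usage across the two sessions of this tutorial:
#   step 3: record_timings("warmup_timings.json", "compiled", durations)
#   step 5: record_timings("warmup_timings.json", "reloaded", durations)
#           print(warmup_speedup("warmup_timings.json", "compiled", "reloaded"))
```

Here `durations` stands for the per-run times collected in each timing loop, for example by appending `time.time() - start` to a list inside the loop instead of only printing it.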

Wrap Up

Congratulations! You have successfully smashed a model with portable compilation. To adapt this tutorial to your own use case, you should only need to modify steps 1 and 2.