Reducing warm-up time for compilation
In this tutorial, we walk you through using the pruna
package to compile your model in a way that significantly reduces warm-up time when re-loading the model on a new machine. Please be aware that, as of now, this only applies to re-loading the model on a new machine with hardware identical to the machine it was compiled on. The inference and compilation times reported below were measured on an
NVIDIA H100.
0. Setup
As a first step, we do a brief setup to mimic loading a compiled model on a new machine. Specifically, we set the torch inductor cache directory explicitly so that we can delete it later.
[ ]:
import os

# Point the inductor cache to a known location so that we can delete it later.
cache_dir = "temp_cache_dir/"
os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir
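As an optional sanity check, you can confirm that the environment variable is set; note that the cache directory itself typically only appears on disk once the first compilation runs.
[ ]:
import os

# Should print "temp_cache_dir/".
print(os.environ["TORCHINDUCTOR_CACHE_DIR"])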
1. Load the model
We are now ready to load the model we want to compile. In this case, we use a stable diffusion pipeline and apply both caching and compilation to it, to showcase that portable compilation in pruna
composes with other algorithms. Of course, compiling with torch_compile
alone is also supported.
[ ]:
import torch
from diffusers import StableDiffusionPipeline

# Load the pipeline in half precision and move it to the GPU.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
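If you would like a baseline to compare the later timings against, you can optionally time a single eager (uncompiled) inference before smashing. This is a minimal sketch and not required for the rest of the tutorial.
[ ]:
import time

# Time one eager inference as a point of reference.
start = time.time()
pipe(prompt)
print(f"Eager baseline: {time.time() - start:.1f} seconds")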
2. Smash the model
Next, we define the SmashConfig and smash the model.
[ ]:
from pruna import SmashConfig, smash

smash_config = SmashConfig()
smash_config["compiler"] = "torch_compile"  # compile with torch.compile
smash_config["cacher"] = "deepcache"  # combine compilation with caching
smash_config["torch_compile_make_portable"] = True  # make the compilation portable across identical machines
pipe = smash(pipe, smash_config=smash_config)
3. Run and save the compiled model
We now run the model twice: the first run shows the warm-up inference time, in this example approximately 50 seconds, while the second run shows the runtime of the compiled model. Afterwards, we save the smashed model.
[ ]:
import time

# The first iteration includes the compilation warm-up; the second reflects
# the steady-state runtime of the compiled model.
for _ in range(2):
    start = time.time()
    pipe(prompt)
    print(f"Time taken: {time.time() - start} seconds")

pipe.save_pretrained("smashed_model/")
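For a more stable estimate of the compiled runtime, you can average several post-warm-up runs. Below is a small helper sketch; the name time_inference and the n_runs parameter are illustrative and not part of pruna.
[ ]:
import time

def time_inference(pipe, prompt, n_runs=3):
    # Average the inference time over several runs (after warm-up).
    times = []
    for _ in range(n_runs):
        start = time.time()
        pipe(prompt)
        times.append(time.time() - start)
    return sum(times) / len(times)

print(f"Average compiled runtime: {time_inference(pipe, prompt):.2f} seconds")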
4. Simulate move to a new machine
Next, we delete the compilation cache directory to mimic moving to a new machine. After that, please restart your kernel or process and then continue with the tutorial.
[ ]:
import shutil

# Remove the local inductor cache to simulate a fresh machine.
shutil.rmtree(cache_dir)
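Before restarting, you can optionally confirm that the cache directory is gone; a fresh process will then start in the same state as a new machine with identical hardware.
[ ]:
import os

assert not os.path.exists(cache_dir)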
5. Re-load the model
We can now load the model and check that the warm-up time is significantly reduced!
[ ]:
import time

import torch
from pruna import PrunaModel

# Load the smashed model saved in step 3.
pipe = PrunaModel.from_pretrained("smashed_model", torch_dtype=torch.float16)
prompt = "a photo of an astronaut riding a horse on mars"

# The first run is the warm-up; it should now be much faster than the
# roughly 50 seconds observed in step 3.
for _ in range(2):
    start = time.time()
    pipe(prompt)
    print(f"Time taken: {time.time() - start} seconds")
Wrap Up
Congratulations! You have successfully smashed a model with portable compilation. To fit your own use case, you only need to modify steps 1 and 2.