Speed Up and Quantize Any Diffusion Model


This tutorial demonstrates how to use the pruna package to reduce both the latency and the memory footprint of any diffusion model from the diffusers package. We use the FLUX.1-schnell model as an example, but the tutorial works with any Stable Diffusion or Flux model.

# If you are not running the latest version of this tutorial, install the version of pruna that matches it.
# The following command installs the latest version of pruna.
%pip install pruna
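To check which pruna release is actually installed (for example, before comparing it with the version this tutorial was written for), a minimal standard-library sketch could look like this. The helper name `installed_version` is hypothetical; it simply wraps `importlib.metadata.version` and returns `None` when the package is missing.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


print("pruna version:", installed_version("pruna"))
```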

1. Loading the Diffusion Model

First, load your diffusion model.

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cuda")
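Passing `torch_dtype=torch.bfloat16` stores each weight in 16 bits instead of 32, halving memory at the cost of mantissa precision. The sketch below is a stdlib-only emulation, not PyTorch's implementation: it rounds a float to bfloat16 by keeping the top 16 bits of its float32 encoding with round-half-to-even. The function name `to_bfloat16` is hypothetical, inputs are assumed to fit in float32, and NaN is passed through unchanged.

```python
import math
import struct


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-half-to-even), returned as a float."""
    if math.isnan(x):
        return x
    # Reinterpret the float32 encoding of x as an unsigned 32-bit integer.
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    # bfloat16 keeps the upper 16 bits: sign, 8 exponent bits, 7 mantissa bits.
    # Adding 0x7FFF plus the lowest kept bit implements round-half-to-even.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits &= 0xFFFF0000
    return struct.unpack(">f", struct.pack(">I", bits))[0]


for value in (1.0, 1.00390625, 1.01171875, 3.14159265):
    print(value, "->", to_bfloat16(value))
```

The exponent range is the same as float32, so bfloat16 rarely overflows; what is lost is precision (roughly 2 to 3 significant decimal digits).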

2. Initializing the Smash Config

Next, initialize the smash_config. Here we use the hqq_diffusers quantizer together with the torch_compile compiler.

from pruna import SmashConfig

smash_config = SmashConfig()
smash_config["compiler"] = "torch_compile"
smash_config["quantizer"] = "hqq_diffusers"
# smash_config['torch_compile_mode'] = 'max-autotune' # Uncomment to enable extra speedups
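HQQ (Half-Quadratic Quantization) stores weights as low-bit integers plus per-group quantization parameters, and refines those parameters with a half-quadratic optimization. The pure-Python sketch below illustrates only the storage idea: plain min-max, round-to-nearest affine quantization with one scale and one offset per group, without HQQ's optimization step. `quantize_groups` and `dequantize_groups` are hypothetical names, not part of pruna.

```python
from typing import List, Tuple

Group = Tuple[List[int], float, float]  # (integer codes, scale, offset)


def quantize_groups(weights: List[float], bits: int = 4, group_size: int = 64) -> List[Group]:
    """Min-max affine quantization: each group gets its own scale and offset."""
    qmax = 2 ** bits - 1
    groups: List[Group] = []
    for start in range(0, len(weights), group_size):
        chunk = weights[start:start + group_size]
        lo, hi = min(chunk), max(chunk)
        # Avoid division by zero when every value in the group is identical.
        scale = (hi - lo) / qmax if hi > lo else 1.0
        codes = [min(qmax, max(0, round((w - lo) / scale))) for w in chunk]
        groups.append((codes, scale, lo))
    return groups


def dequantize_groups(groups: List[Group]) -> List[float]:
    """Reconstruct approximate float weights: w ~= code * scale + offset."""
    return [q * scale + offset for codes, scale, offset in groups for q in codes]


weights = [0.12, -0.5, 0.33, 0.9, -0.75, 0.05, 0.41, -0.2]
restored = dequantize_groups(quantize_groups(weights, bits=4, group_size=4))
print("max abs error:", max(abs(a - b) for a, b in zip(weights, restored)))
```

Smaller groups track the local value range more closely (lower error) but store more scales and offsets, which is the size/quality trade-off controlled by the group size.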

3. Smashing the Model

Now, smash the model. This can take up to 30 seconds.

from pruna import smash

# Smash the model
pipe = smash(
    model=pipe,
    smash_config=smash_config,
)

4. Running the Model

Finally, run the model to generate the image. Note that the first run is slower because it includes a warmup, for example while torch_compile compiles the model.

# Define the prompt
prompt = "a smiling cat dancing on a table. Miyazaki style"

# Generate the image; as the last expression in the cell, it is displayed inline
pipe(prompt).images[0]
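Because the first call includes the warmup, timing only that call overstates steady-state latency. A minimal, framework-agnostic sketch that times the warmup separately and reports the median of later runs follows; `benchmark` is a hypothetical helper. If the timed function can return before asynchronous GPU work finishes, you would also synchronize the device (e.g. `torch.cuda.synchronize()`) before reading the clock, which this stdlib-only sketch does not do.

```python
import statistics
import time
from typing import Any, Callable, Dict


def benchmark(fn: Callable[[], Any], runs: int = 3) -> Dict[str, float]:
    """Time one warmup call, then `runs` further calls; all results in seconds."""
    if runs < 1:
        raise ValueError("runs must be at least 1")

    start = time.perf_counter()
    fn()  # first call: includes compilation / warmup cost
    warmup = time.perf_counter() - start

    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return {"warmup_s": warmup, "median_s": statistics.median(timings)}


# Usage with the smashed pipeline from the cell above:
# stats = benchmark(lambda: pipe(prompt), runs=3)
# print(stats)
```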

Wrap Up

Congratulations! You’ve optimized a diffusion model using HQQ quantization and torch_compile! The quantized model should use less memory and run faster while keeping good image quality. You can try different settings, such as the number of weight bits and the group size, to find the best balance between size and quality.
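To reason about the size side of that trade-off, a back-of-the-envelope estimate helps: each weight costs `bits` bits, and each group of `group_size` weights adds its own scale and offset. The sketch assumes that metadata takes 32 bits per group (a 16-bit scale plus a 16-bit offset), which may differ from how the hqq_diffusers backend actually stores it, and it ignores any layers left unquantized. The function name `weight_memory_gib` is hypothetical; the 12e9 figure is the approximate parameter count of the FLUX.1 transformer alone, not of the whole pipeline.

```python
from typing import Optional


def weight_memory_gib(
    n_params: float,
    bits: int,
    group_size: Optional[int] = None,
    overhead_bits: int = 32,
) -> float:
    """Approximate weight storage in GiB.

    group_size=None means no per-group metadata (e.g. plain bfloat16 weights);
    otherwise every group of `group_size` weights adds `overhead_bits` of metadata.
    """
    bits_per_weight = bits + (overhead_bits / group_size if group_size else 0.0)
    return n_params * bits_per_weight / 8 / 2**30


N = 12e9  # approximate parameter count of the FLUX.1 transformer
print(f"bfloat16: {weight_memory_gib(N, 16):.1f} GiB")
for bits in (8, 4):
    for group_size in (64, 128):
        print(f"{bits}-bit, group {group_size}: {weight_memory_gib(N, bits, group_size):.1f} GiB")
```

Halving the group size doubles the per-group overhead (0.25 extra bits per weight at group 128 versus 0.5 at group 64 under this assumption), so smaller groups buy accuracy with a modest amount of extra memory.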

Want more optimization techniques? Check out our other tutorials!