Speed Up and Quantize Any Diffusion Model
This tutorial demonstrates how to use the pruna
package to optimize both the latency and the memory footprint of any diffusion model from the diffusers package.
We will use the FLUX.1 Schnell
model as an example, but this tutorial works with any Stable Diffusion or Flux model.
1. Loading the Diffusion Model
First, load your diffusion model.
[ ]:
import torch
from diffusers import FluxPipeline
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to("cuda")
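The same loading pattern works for other diffusers pipelines. As a rough sketch, assuming you want a Stable Diffusion model instead (the SDXL checkpoint below is only an illustrative example), you could load it via AutoPipelineForText2Image:
[ ]:
from diffusers import AutoPipelineForText2Image

# Illustrative example: any diffusers text-to-image checkpoint can be used in place of SDXL
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")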
2. Initializing the Smash Config
Next, initialize the smash_config. Here we use the hqq_diffusers quantizer together with the torch_compile compiler.
[ ]:
from pruna import SmashConfig
smash_config = SmashConfig()
smash_config['compiler'] = 'torch_compile'
smash_config['quantizer'] = 'hqq_diffusers'
# smash_config['torch_compile_mode'] = 'max-autotune' # Uncomment to enable extra speedups
3. Smashing the Model
Now, smash the model. This can take up to 30 seconds.
[ ]:
from pruna import smash
# Smash the model
pipe = smash(
model=pipe,
smash_config=smash_config,
)
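If you want to verify the reduced memory footprint, a quick check with plain PyTorch utilities (not a pruna-specific API) looks like this:
[ ]:
# Report current and peak GPU memory usage with standard PyTorch utilities
print(f"Allocated GPU memory: {torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"Peak GPU memory:      {torch.cuda.max_memory_allocated() / 1e9:.2f} GB")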
4. Running the Model
Finally, run the model to generate an image. Note that the first run includes a warm-up, so it will be slower than subsequent runs.
[ ]:
# Define the prompt
prompt = "a smiling cat dancing on a table. Miyazaki style"
# Display the result
pipe(prompt).images[0]
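Since the first call includes the warm-up (compilation), you can time a second call to get an idea of the steady-state latency. This is an illustrative sketch using only the standard library; the generated image is a PIL image and can be saved as usual:
[ ]:
import time

# Second call: the torch.compile warm-up is already done, so this reflects steady-state latency
start = time.perf_counter()
image = pipe(prompt).images[0]
print(f"Inference time: {time.perf_counter() - start:.2f} s")

# Save the generated image to disk
image.save("smiling_cat.png")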
Wrap Up
Congratulations! You’ve optimized a diffusion model using HQQ quantization and torch.compile! The quantized model uses less memory and runs faster while maintaining good quality. You can try different settings, such as the weight bits and group size, to find the best balance between size and quality.
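As a rough sketch of how such settings could be passed, assuming the hyperparameter names below (they are illustrative; check the pruna documentation for the exact options supported by your version):
[ ]:
smash_config = SmashConfig()
smash_config['compiler'] = 'torch_compile'
smash_config['quantizer'] = 'hqq_diffusers'
# Hypothetical hyperparameter names; verify them against the pruna docs for your version
smash_config['hqq_diffusers_weight_bits'] = 4   # e.g. 2, 4, or 8 bits per weight
smash_config['hqq_diffusers_group_size'] = 64   # smaller groups usually mean better quality but less compression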
Want more optimization techniques? Check out our other tutorials!