Distributing Flux on Multiple GPUs
In this tutorial, we walk you through how to use the pruna package to optimize your Flux model for faster inference on multiple GPUs. All execution times below were measured on two H100 PCIe GPUs. Note that the pruna distributers are also compatible with torchrun: simply convert this tutorial to a script and run it with torchrun --nproc_per_node=2 flux_tutorial.py.
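When the tutorial is run as a script under torchrun, every process executes the whole file, so side effects such as saving the generated image are usually best limited to a single process. The following is a minimal sketch that relies only on the RANK and WORLD_SIZE environment variables that torchrun sets for each worker. The helper names get_rank_info and is_main_process are hypothetical and are not part of pruna.

```python
import os


def get_rank_info(env=None):
    """Return (rank, world_size) from torchrun-style environment variables.

    Falls back to (0, 1) when the variables are absent, e.g. in a notebook.
    """
    env = os.environ if env is None else env
    rank = int(env.get("RANK", 0))
    world_size = int(env.get("WORLD_SIZE", 1))
    return rank, world_size


def is_main_process(env=None):
    """Return True only for the process with global rank 0."""
    return get_rank_info(env)[0] == 0
```

In the converted script, a call such as output.save("flux.png") could then be wrapped in an `if is_main_process():` guard so that the file is written once rather than once per GPU.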
1. Loading the Flux Model
First, load your Flux model.
import torch
from diffusers import FluxPipeline
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")
2. Initializing the Smash Config
Next, initialize the smash_config. For this tutorial, we select the ring_attn distributer together with torch_compile. If you want to go further, you can additionally activate other compatible algorithms, for example a quantizer, factorizer, or pruner, as in the commented-out line below.
from pruna import SmashConfig, smash
# Initialize the SmashConfig and configure the algorithms
smash_config = SmashConfig(["ring_attn", "torch_compile"])
# Additionally configure suitable hyperparameters
smash_config.add({
    "torch_compile_target": "module_list"
})
# You can choose to activate further algorithms compatible with the ring_attn distributer!
# smash_config.add(["qkv_diffusers", "padding_pruning"])
3. Smashing the Model
Now, you can smash the model, which can take up to one minute.
pipe = smash(
    model=pipe,
    smash_config=smash_config,
)
4. Running the Model
After the model has been distributed and compiled, we run inference for a few iterations as warm-up. The inference time has dropped from an initial 10.4 seconds to around 2.7 seconds!
prompt = (
    "An anime illustration of Sydney Opera House sitting next to Eiffel tower, under a blue night sky of "
    "roiling energy, exploding yellow stars, and radiating swirls of blue."
)
for _ in range(5):
    output = pipe(prompt, num_inference_steps=50).images[0]
output
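To reproduce latency figures like the ones quoted above, it helps to keep warm-up calls separate from timed calls. The sketch below uses only the Python standard library. The helpers time_calls and speedup are hypothetical names introduced for illustration, and the commented usage assumes pipe and prompt are defined as in this tutorial.

```python
import time


def time_calls(fn, warmup=2, iters=3):
    """Call fn() `warmup` times untimed, then return seconds per call for `iters` timed calls."""
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return timings


def speedup(baseline_s, optimized_s):
    """Return the ratio of baseline latency to optimized latency."""
    return baseline_s / optimized_s


# Usage with the tutorial's pipeline (assumed to be defined as above):
# timings = time_calls(lambda: pipe(prompt, num_inference_steps=50))
# print(f"mean: {sum(timings) / len(timings):.2f} s")

# The figures quoted in this tutorial correspond to roughly a 3.9x speedup.
print(f"{speedup(10.4, 2.7):.1f}x")
```

Timing the full pipeline call, rather than a single transformer step, measures the latency a user actually observes, since the call only returns once the images are available.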
5. Clean-Up
To properly clean up the distributed model, make sure to call the destroy method.
pipe.destroy()
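If inference raises an exception, the destroy call above is never reached. One way to guard against that is a small context manager that always calls destroy on exit. This is a sketch under the assumption that the smashed model exposes the destroy method used in this tutorial; the name destroying is hypothetical and is not part of pruna.

```python
from contextlib import contextmanager


@contextmanager
def destroying(model):
    """Yield `model` and call model.destroy() on exit, even if an exception is raised."""
    try:
        yield model
    finally:
        model.destroy()


# Usage with the tutorial's objects (assumed to be defined as above):
# with destroying(smash(model=pipe, smash_config=smash_config)) as pipe:
#     output = pipe(prompt, num_inference_steps=50).images[0]
```

The standard library's contextlib.closing is not a drop-in replacement here, because it calls close() rather than destroy().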
Wrap Up
Congratulations! You have successfully distributed a Flux model on multiple GPUs and combined it with other pruna algorithms. It is that easy.