Distributing Flux on Multiple GPUs (Pro)
In this tutorial, we will walk you through how to use the pruna_pro package to optimize your Flux model for faster inference on multiple GPUs. Any execution times below are measured on a set of 2 H100 PCIe GPUs. Note that the pruna_pro distributers are also compatible with torchrun: simply convert this tutorial to a script and run it with torchrun --nproc_per_node=2 flux_tutorial.py.
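If you go the torchrun route, a minimal script skeleton might look like the sketch below; the file name flux_tutorial.py and the overall structure are illustrative, not part of the pruna_pro API. Collect the notebook cells from the following sections into it.
[ ]:
# flux_tutorial.py -- illustrative skeleton for launching with torchrun.
import torch
from diffusers import FluxPipeline
from pruna_pro import SmashConfig, smash

def main():
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    )
    pipe.to("cuda")
    # ... configure the SmashConfig and smash the model as shown below ...

if __name__ == "__main__":
    main()  # launch with: torchrun --nproc_per_node=2 flux_tutorial.py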
1. Loading the Flux Model
First, load your Flux model.
[ ]:
import torch
from diffusers import FluxPipeline

# Load the Flux pipeline in bfloat16 and move it to the GPU
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")
2. Initializing the Smash Config
Next, initialize the smash_config. For this tutorial, we will select our ring_attn distributer, the torch_compile compiler, and the taylor cacher. If this is not enough for you, you can additionally activate e.g. the quantizer, factorizer, and pruner shown in the commented-out lines below!
[ ]:
from pruna_pro import SmashConfig, smash
# Initialize the SmashConfig and configure the algorithms
smash_config = SmashConfig()
smash_config["distributer"] = "ring_attn"
smash_config["cacher"] = "auto"
smash_config["compiler"] = "torch_compile"
# Additionally configure suitable hyperparameters
smash_config["auto_cache_mode"] = "taylor"
smash_config["torch_compile_target"] = "module_list"
# You can choose to activate further algorithms compatible with the ring_attn distributer!
# smash_config["factorizer"] = "qkv_diffusers"
# smash_config["quantizer"] = "fp8"
# smash_config["pruner"] = "padding_pruning"
3. Smashing the Model
Now, you can smash the model, which can take up to one minute. Don’t forget to replace the token with the one provided by PrunaAI.
[ ]:
pipe = smash(
    model=pipe,
    token="<your_pruna_token>",
    smash_config=smash_config,
)
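If you prefer not to hard-code the token, you can read it from an environment variable instead and pass the resulting value as the token argument to smash. The variable name PRUNA_TOKEN below is our own convention, not something pruna_pro prescribes:
[ ]:
import os

# Illustrative: keep the token out of the notebook. "PRUNA_TOKEN" is our own
# naming convention; pass the resulting value to smash(token=...).
token = os.environ.get("PRUNA_TOKEN", "<your_pruna_token>")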
4. Running the Model
After the model has been distributed and compiled, we run inference for a few iterations as a warm-up. The initial inference time of 10.4 seconds is now reduced to around 2.7 seconds!
[ ]:
prompt = (
    "An anime illustration of Sydney Opera House sitting next to Eiffel tower, under a blue night sky of "
    "roiling energy, exploding yellow stars, and radiating swirls of blue."
)
for _ in range(5):
    output = pipe(prompt, num_inference_steps=50).images[0]
output
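To verify the speed-up on your own hardware, you can time a single generation after the warm-up loop. This measurement snippet is our own sketch, and the exact numbers will depend on your setup:
[ ]:
import time

# Time one steady-state generation after warm-up. torch.cuda.synchronize()
# ensures the GPU has finished before and after we read the clock.
torch.cuda.synchronize()
start = time.perf_counter()
output = pipe(prompt, num_inference_steps=50).images[0]
torch.cuda.synchronize()
print(f"Inference took {time.perf_counter() - start:.2f} seconds")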
5. Clean-Up
To properly clean up the distributed model, make sure to call the destroy method.
[ ]:
pipe.destroy()
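If you also want to release the GPU memory afterwards, for example to load another model in the same process, the usual PyTorch pattern applies. This is a general suggestion rather than a pruna_pro requirement:
[ ]:
# General PyTorch cleanup: drop the reference and empty the CUDA cache.
del pipe
torch.cuda.empty_cache()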
Wrap Up
Congratulations! You have successfully distributed a Flux model across multiple GPUs and combined it with other pruna algorithms. It is that easy.