Distributing Flux on Multiple GPUs (Pro)

In this tutorial, we walk you through how to use the pruna_pro package to optimize your Flux model for faster inference on multiple GPUs. All execution times below were measured on two H100 PCIe GPUs. Note that the pruna_pro distributers are also compatible with torchrun: simply convert this tutorial to a script and run it with torchrun --nproc_per_node=2 flux_tutorial.py (a consolidated script sketch is included at the end of this tutorial).

1. Loading the Flux Model

First, load your Flux model.

[ ]:
import torch
from diffusers import FluxPipeline

# Load the Flux pipeline in bfloat16 and move it to the GPU
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

2. Initializing the Smash Config

Next, initialize the smash_config. For this tutorial, we select our ring_attn distributer, the torch_compile compiler, and the taylor cacher. If this is not enough for you, you can additionally activate, for example, the quantizer, factorizer, and pruner shown below!

[ ]:
from pruna_pro import SmashConfig, smash

# Initialize the SmashConfig and configure the algorithms
smash_config = SmashConfig()
smash_config["distributer"] = "ring_attn"
smash_config["cacher"] = "auto"
smash_config["compiler"] = "torch_compile"

# Additionally configure suitable hyperparameters
smash_config["auto_cache_mode"] = "taylor"
smash_config["torch_compile_target"] = "module_list"

# You can choose to activate further algorithms compatible with the ring_attn distributer!
# smash_config["factorizer"] = "qkv_diffusers"
# smash_config["quantizer"] = "fp8"
# smash_config["pruner"] = "padding_pruning"

3. Smashing the Model

Now, you can smash the model, which can take up to one minute. Don’t forget to replace the token with the one provided by PrunaAI.

[ ]:
pipe = smash(
    model=pipe,
    token="<your_pruna_token>",
    smash_config=smash_config,
)

4. Running the Model

After the model has been distributed and compiled, we run inference for a few iterations as a warm-up. This brings the initial inference time of 10.4 seconds down to around 2.7 seconds!

[ ]:
prompt = (
    "An anime illustration of Sydney Opera House sitting next to Eiffel tower, under a blue night sky of "
    "roiling energy, exploding yellow stars, and radiating swirls of blue."
)

# Warm up the compiled and cached pipeline for a few iterations
for _ in range(5):
    output = pipe(prompt, num_inference_steps=50).images[0]
output
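
To verify the speedup on your own hardware, you can time a single generation after the warm-up, for example with the small sketch below (the exact numbers will depend on your GPUs and configuration):

[ ]:
import time

import torch

# Time one generation on the warmed-up, smashed pipeline
torch.cuda.synchronize()
start = time.perf_counter()
image = pipe(prompt, num_inference_steps=50).images[0]
torch.cuda.synchronize()
print(f"Inference time: {time.perf_counter() - start:.2f} s")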

5. Clean-Up

To properly clean up the distributed model, make sure to call the destroy method.

[ ]:
pipe.destroy()

Wrap Up

Congratulations! You have successfully distributed a Flux model on multiple GPUs and combined it with other Pruna algorithms. It is that easy.
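
If you prefer to launch the distributed run with torchrun, as mentioned at the beginning, the cells above can be collected into a single script. Below is a minimal sketch of what such a flux_tutorial.py could look like, started with torchrun --nproc_per_node=2 flux_tutorial.py; the rank-0 guard around saving the image is an assumption (torchrun sets the RANK environment variable) so that only one process writes the output file:

# flux_tutorial.py
import os

import torch
from diffusers import FluxPipeline
from pruna_pro import SmashConfig, smash


def main():
    # Load the Flux pipeline, as in the tutorial above
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    # Same configuration as above
    smash_config = SmashConfig()
    smash_config["distributer"] = "ring_attn"
    smash_config["cacher"] = "auto"
    smash_config["compiler"] = "torch_compile"
    smash_config["auto_cache_mode"] = "taylor"
    smash_config["torch_compile_target"] = "module_list"

    pipe = smash(model=pipe, token="<your_pruna_token>", smash_config=smash_config)

    prompt = "An anime illustration of Sydney Opera House sitting next to Eiffel tower."
    image = pipe(prompt, num_inference_steps=50).images[0]

    # Assumption: only rank 0 saves the image, so the two processes do not both write the file
    if int(os.environ.get("RANK", "0")) == 0:
        image.save("flux_output.png")

    # Clean up the distributed model
    pipe.destroy()


if __name__ == "__main__":
    main()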