Flux generation in a heartbeat, literally (Pro)


This tutorial demonstrates how to use the pruna_pro package to optimize your Flux model for faster inference. Any execution times mentioned are measured on an A100 GPU.

The model is smashed and run on the GPU, so you will need an A100 or a comparable GPU to follow along.

1. Loading the Flux Model

First, load your Flux model.

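import torch
from diffusers import FluxPipeline

# FLUX.1-dev is a gated repository: accept its license on the Hugging Face Hub and log in (e.g. huggingface-cli login) first.
# Load the pipeline in bfloat16; pass cache_dir=... to from_pretrained if you want a custom Hugging Face cache location.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
# If you run out of GPU memory, drop .to("cuda") above and uncomment the next line to offload parts of the model to the CPU.
# pipe.enable_model_cpu_offload()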

2. Initializing the Smash Config

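Next, initialize the smash_config. Here we combine the periodic cacher, which reuses cached computations on some denoising steps instead of recomputing them, with the torch_compile compiler.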

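from pruna_pro import smash, SmashConfig

# Initialize the SmashConfig
smash_config = SmashConfig()
# Periodic caching: reuse cached computations on some denoising steps
smash_config['cacher'] = 'periodic'
smash_config['periodic_cache_interval'] = 2  # caching interval, in denoising steps
smash_config['periodic_start_step'] = 4  # denoising step at which caching starts
# Additionally compile the model with torch.compile
smash_config['compiler'] = 'torch_compile'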
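To build intuition for the two caching parameters, the sketch below simulates a step schedule under an assumed interpretation: every step before periodic_start_step runs the full model, and afterwards the full model runs once every periodic_cache_interval steps while the steps in between reuse the cache. The helper periodic_schedule is hypothetical and does not call pruna_pro; the exact semantics of pruna_pro's periodic cacher may differ.

```python
def periodic_schedule(num_steps: int, interval: int, start_step: int) -> list:
    """Return one flag per denoising step: True = full model run, False = cache reused.

    Hypothetical model of a periodic cacher (an assumption, not pruna_pro's code):
    all steps before `start_step` are computed in full; afterwards, a full
    computation happens every `interval` steps and the others reuse the cache.
    """
    if interval < 1:
        raise ValueError("interval must be >= 1")
    flags = []
    for step in range(num_steps):
        if step < start_step:
            flags.append(True)  # before caching starts: always run the full model
        else:
            # recompute on every `interval`-th step, counted from start_step
            flags.append((step - start_step) % interval == 0)
    return flags


# Same values as the smash_config above, with 50 inference steps
schedule = periodic_schedule(num_steps=50, interval=2, start_step=4)
full = sum(schedule)
print(f"full passes: {full}, cached steps: {len(schedule) - full}")
# → full passes: 27, cached steps: 23
```

Under this assumption, the settings above would skip the full model on nearly half of the 50 steps, which is where a caching speed-up would come from.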

3. Smashing the Model

Now smash the model. This can take up to 2 minutes. Don't forget to replace <your_pruna_token> with the token provided by PrunaAI.

pipe = smash(
    model=pipe,
    token="<your_pruna_token>",
    smash_config=smash_config,
)
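Hard-coding the token in a notebook makes it easy to leak. A common alternative, sketched below, is to keep the token in an environment variable and read it at runtime. The variable name PRUNA_TOKEN and the helper get_pruna_token are hypothetical choices for illustration; pruna_pro does not require either.

```python
import os


def get_pruna_token(var_name: str = "PRUNA_TOKEN") -> str:
    """Read the Pruna token from an environment variable.

    `PRUNA_TOKEN` is a hypothetical variable name; set it yourself,
    e.g. `export PRUNA_TOKEN=...` in your shell before starting Python.
    """
    token = os.environ.get(var_name, "").strip()
    if not token:
        raise RuntimeError(f"Set the {var_name} environment variable to your Pruna token.")
    return token


# Usage with the smash call above (requires the variable to be set):
# pipe = smash(model=pipe, token=get_pruna_token(), smash_config=smash_config)
```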

4. Running the Model

Because torch_compile compiles the model during the first inference calls, we run inference for a few iterations as a warm-up. If you prefer an immediate speed-up without warm-up iterations, remove the smash_config['compiler'] = 'torch_compile' line from the smash_config.

prompt = "An anime illustration of Sydney Opera House sitting next to Eiffel tower, under a blue night sky of roiling energy, exploding yellow stars, and radiating swirls of blue."

for _ in range(5):
    pipe(prompt, num_inference_steps=50).images[0]

Now run the model to generate an image with accelerated inference.

pipe(prompt, num_inference_steps=50).images[0]
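To check the speed-up on your own hardware, you can time the pipeline before and after smashing. The sketch below is a generic helper (time_call is a hypothetical name, not part of pruna_pro or diffusers) that runs a few untimed warm-up calls and then reports the median latency. CUDA kernels launch asynchronously, so pass torch.cuda.synchronize as sync when timing GPU work; otherwise the clock may stop before the GPU has finished.

```python
import statistics
import time
from typing import Callable, Optional


def time_call(
    fn: Callable[[], object],
    warmup: int = 3,
    repeats: int = 5,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Return the median wall-clock time of `fn()` in seconds.

    Runs `warmup` untimed calls first (e.g. to absorb torch.compile
    compilation), then `repeats` timed calls. `sync` is called before
    reading the clock so that asynchronous GPU work is included.
    """
    if repeats < 1:
        raise ValueError("repeats must be >= 1")
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # make sure warm-up work has finished before timing
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


# Example with the smashed pipeline (warm-up was already done above):
# import torch
# latency = time_call(lambda: pipe(prompt, num_inference_steps=50),
#                     warmup=0, sync=torch.cuda.synchronize)
# print(f"median latency: {latency:.2f} s")
```

Reporting the median rather than the mean keeps a single slow outlier call from skewing the comparison.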

Wrap Up

Congratulations! You have successfully smashed a Flux model. Enjoy the speed-up!