Flux generation in a heartbeat, literally


This tutorial demonstrates how to use the pruna package to optimize your Flux model for faster inference.

This tutorial smashes the Flux model on a GPU, so running it requires an A100 or a comparable GPU.

1. Loading the Flux Model

First, load your Flux model.

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
# To save VRAM, uncomment the next line and skip pipe.to('cuda') below;
# enable_model_cpu_offload() moves each submodule to the GPU only while it is needed.
# pipe.enable_model_cpu_offload()
pipe.to('cuda')
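If you are unsure whether your GPU has enough memory for the full pipeline, a small helper can make the choice between pipe.to('cuda') and CPU offloading explicit. This is a minimal sketch: choose_placement and the 35 GiB threshold are hypothetical (not part of diffusers or pruna), and you should tune the threshold for your hardware and resolution.

```python
# Hypothetical helper (not part of diffusers or pruna): decide whether to keep
# the whole pipeline on the GPU or fall back to CPU offloading.
GIB = 1024 ** 3

# Hypothetical threshold in GiB; tune it for your GPU and image resolution.
MIN_GPU_GIB_FOR_FULL_LOAD = 35.0


def choose_placement(total_gpu_bytes: int, min_gib: float = MIN_GPU_GIB_FOR_FULL_LOAD) -> str:
    """Return "cuda" to move the whole pipeline to the GPU, or "offload" for CPU offloading."""
    if total_gpu_bytes < 0:
        raise ValueError("total_gpu_bytes must be non-negative")
    return "cuda" if total_gpu_bytes / GIB >= min_gib else "offload"


if __name__ == "__main__":
    # On a real machine you would query the device instead, e.g.
    # torch.cuda.get_device_properties(0).total_memory
    for gib in (24, 40, 80):
        print(gib, "GiB ->", choose_placement(gib * GIB))
```

On a CUDA machine you could pass torch.cuda.get_device_properties(0).total_memory to the helper and then call either pipe.to('cuda') or pipe.enable_model_cpu_offload(), but not both.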

2. Initializing the Smash Config

Next, initialize the smash_config and select onediff as the compiler.

from pruna import SmashConfig

# Initialize the SmashConfig
smash_config = SmashConfig()
smash_config['compilers'] = ['onediff']

3. Smashing the Model

Now, smash the pipeline's transformer. Don't forget to replace the placeholder token with the one provided by PrunaAI.

from pruna import smash

pipe.transformer = smash(
    model=pipe.transformer,
    token='<your_token>',  # replace <your_token> with your actual token
    smash_config=smash_config,
)
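Rather than pasting the token into the notebook, you may prefer to read it from an environment variable so it does not end up in saved or shared copies. A minimal sketch, assuming a hypothetical PRUNA_TOKEN variable and a hypothetical load_token helper (neither is defined by pruna):

```python
import os

# Hypothetical environment variable name; pick whatever suits your setup.
TOKEN_ENV_VAR = "PRUNA_TOKEN"


def load_token(env_var: str = TOKEN_ENV_VAR) -> str:
    """Read the PrunaAI token from the environment instead of hard-coding it."""
    token = os.environ.get(env_var, "").strip()
    if not token:
        raise RuntimeError(f"Set the {env_var} environment variable to your PrunaAI token.")
    return token
```

You would then pass token=load_token() to smash in place of the literal string.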

4. Running the Model

Finally, run the model to generate the image. The first run can take up to 10 minutes while the model is compiled; subsequent runs are fast.

prompt = "A cat holding a sign that says hello world"
pipe(
    prompt,
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]
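Because the first run includes the compilation overhead, it is worth timing it separately from later runs when you measure the speed-up. The benchmark helper below is a hypothetical, standard-library-only sketch, not a pruna utility:

```python
import statistics
import time
from typing import Callable, Dict


def benchmark(fn: Callable[[], object], warmup: int = 1, runs: int = 3) -> Dict[str, float]:
    """Time fn, reporting the warm-up run(s) separately from steady-state runs.

    If fn only launches asynchronous GPU work, it should synchronize the device
    before returning (e.g. torch.cuda.synchronize()); otherwise the timings
    measure kernel launches rather than the actual computation.
    """
    if warmup < 1 or runs < 1:
        raise ValueError("warmup and runs must both be at least 1")

    # Warm-up: for a compiled model this includes the one-off compilation cost.
    start = time.perf_counter()
    for _ in range(warmup):
        fn()
    warmup_s = time.perf_counter() - start

    # Steady state: time each call individually.
    timings = []
    for _ in range(runs):
        t0 = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - t0)

    return {
        "warmup_s": warmup_s,
        "median_s": statistics.median(timings),
        "min_s": min(timings),
    }
```

For example, benchmark(lambda: pipe(prompt, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256), warmup=1, runs=3) would report the compile-heavy first call in warmup_s and the steady-state latency in median_s. Running the same call before smashing gives a baseline to compare against.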

Wrap Up

Congratulations! You have successfully smashed a Flux model. Enjoy the speed-up!