Run your Flux model with half the memory

This tutorial demonstrates how to use the pruna package to reduce the memory consumption of your Flux model.

This tutorial smashes the Flux model on the CPU, which requires around 28 GB of RAM. Since inference with the smashed model then runs on the GPU, a GPU with around 24 GB of VRAM is sufficient when using 4-bit quantization.
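If you want to check your hardware before starting, the following sketch (not part of the original workflow) reports the total VRAM of the first CUDA device; the 24 GB figure above is an approximate guideline, not a hard limit.

[ ]:
import torch

# Report the total VRAM of the first CUDA device
if torch.cuda.is_available():
    total_vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    print(f"GPU 0 total VRAM: {total_vram_gb:.1f} GB")
else:
    print("No CUDA device found; GPU inference will not be possible.")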

1. Loading the Flux Model

First, load your Flux model.

[ ]:
import torch
from diffusers import FluxPipeline

# Load the Flux pipeline in bfloat16; the weights stay on the CPU for now
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
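To quantify the savings later, you can record the pipeline's footprint before smashing. This is a minimal sketch (an optional extra step) that sums the parameter sizes of every torch module in the pipeline:

[ ]:
import torch

# Sum parameter sizes (in bytes) across all torch modules in the pipeline
total_bytes = sum(
    p.numel() * p.element_size()
    for component in pipe.components.values()
    if isinstance(component, torch.nn.Module)
    for p in component.parameters()
)
print(f"Unsmashed pipeline weights: {total_bytes / 1024**3:.1f} GB")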

2. Initializing the Smash Config

Next, initialize the SmashConfig. You can uncomment the torch_compile line to additionally enable a roughly 50% speed-up.

[ ]:
from pruna import SmashConfig

# Initialize the SmashConfig
smash_config = SmashConfig()
# smash_config['compiler'] = 'torch_compile'
smash_config['quantizer'] = 'hqq_diffusers'
smash_config['hqq_diffusers_weight_bits'] = 4  # supported values: 2, 4, or 8

3. Smashing the Model

Now, you can smash the model.

[ ]:
from pruna import smash

# Quantize the model and move the result to the GPU for inference
pipe = smash(
    model=pipe,
    smash_config=smash_config,
).to("cuda")

4. Running the Model

Finally, run the smashed model to generate an image.

[ ]:
prompt = "A cat holding a sign that says hello world"
pipe(
    prompt,
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]
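If you run this outside a notebook, the image is not displayed automatically; you can save it to disk instead. A minimal sketch (the filename is arbitrary):

[ ]:
# Generate the image and save it; the pipeline returns PIL images by default
image = pipe(
    prompt,
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]
image.save("flux_smashed_cat.png")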

Wrap Up

Congratulations! You have successfully smashed a Flux model. Enjoy the smaller memory footprint!