Run your Flux model with half the memory
This tutorial demonstrates how to use the pruna
package to reduce the memory consumption of your Flux model.
This tutorial smashes the Flux model on CPU, which requires around 28 GB of RAM. Since inference with the smashed model is then run on GPU, a GPU with around 24 GB of VRAM is sufficient when using 4-bit quantization.
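If you want to confirm that your GPU meets this requirement before starting, you can query its total memory with PyTorch. This is a minimal optional sketch; it assumes you want to use the first CUDA device.
[ ]:
import torch

# Optional sanity check: total VRAM on the first CUDA device (assumes device 0)
total_vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
print(f"Total VRAM on cuda:0: {total_vram_gb:.1f} GB")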
1. Loading the Flux Model
First, load your Flux model.
[ ]:
import torch
from diffusers import FluxPipeline

# Load the Flux Schnell pipeline in bfloat16
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
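To have a baseline to compare against after smashing, you can optionally estimate the size of the transformer's weights. The sketch below assumes the standard FluxPipeline layout, where the denoiser is exposed as pipe.transformer.
[ ]:
# Optional: estimate the unsmashed transformer's weight footprint
param_bytes = sum(p.numel() * p.element_size() for p in pipe.transformer.parameters())
print(f"Transformer weights: {param_bytes / 1e9:.1f} GB")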
2. Initializing the Smash Config
Next, initialize the smash_config. You can uncomment the torch_compile
line to additionally enable a roughly 50% speed-up.
[ ]:
from pruna import SmashConfig
# Initialize the SmashConfig
smash_config = SmashConfig()
# smash_config['compiler'] = 'torch_compile'
smash_config['quantizer'] = 'hqq_diffusers'
smash_config['hqq_diffusers_weight_bits'] = 4  # other options: 2, 8
3. Smashing the Model
Now, you can smash the model.
[ ]:
from pruna import smash

# Smash the model, then move it to the GPU for inference
pipe = smash(
    model=pipe,
    smash_config=smash_config,
).to("cuda")
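At this point you can optionally check how much GPU memory the smashed pipeline occupies. The snippet below is a minimal sketch using PyTorch's standard memory counter.
[ ]:
# Optional: report the VRAM currently allocated by the smashed pipeline
print(f"Allocated VRAM: {torch.cuda.memory_allocated() / 1e9:.1f} GB")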
4. Running the Model
Finally, run the model to generate an image.
[ ]:
prompt = "A cat holding a sign that says hello world"
pipe(
    prompt,
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
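Since the point of smashing is a smaller footprint, you may also want to check the peak memory used during generation. This optional check relies on PyTorch's built-in peak-memory counter.
[ ]:
# Optional: peak VRAM used so far on the current CUDA device
print(f"Peak VRAM during inference: {torch.cuda.max_memory_allocated() / 1e9:.1f} GB")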
Wrap Up
Congratulations! You have successfully smashed a Flux model. Enjoy the smaller memory footprint!