Run your Flux model without an A100

This tutorial demonstrates how to use the pruna package to reduce the memory consumption of your Flux model.

This tutorial smashes the Flux model on the CPU, which requires around 28 GB of RAM. The example inference then runs on the GPU with the smashed model, so a GPU with around 18 GB of VRAM is sufficient (around 15 GB with 4-bit quantization, and around 11 GB with 4-bit quantization plus the optional memory-saving settings in step 4).
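As a rough sanity check on these figures, the weight memory of a module is approximately its parameter count times the bytes per parameter. The sketch below applies that rule to the two modules smashed in step 3. It assumes about 12 billion parameters for the Flux transformer and about 4.7 billion for the T5-XXL text encoder, and it ignores activations, the CLIP encoder, the VAE, and quantization overhead. Treat its output as a lower bound on weight memory only, not as a prediction of peak VRAM.

```python
# Back-of-the-envelope weight memory: parameters x bytes per parameter.
# ASSUMPTION: ~12e9 parameters for the Flux transformer and ~4.7e9 for the
# T5-XXL encoder (text_encoder_2). Activations, CLIP, the VAE and
# quantization scales are ignored, so these are lower bounds.
BYTES_PER_PARAM = {
    "bfloat16": 2.0,
    "qfloat8": 1.0,
    "qint8": 1.0,
    "qint4": 0.5,
    "qint2": 0.25,
}

PARAM_COUNTS = {
    "transformer": 12e9,
    "text_encoder_2": 4.7e9,
}


def weight_gb(num_params: float, dtype: str) -> float:
    """Return the weight footprint in GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


def total_weight_gb(dtype: str) -> float:
    """Sum the weight footprint of the smashed modules at one precision."""
    return sum(weight_gb(n, dtype) for n in PARAM_COUNTS.values())


if __name__ == "__main__":
    for dtype in ("bfloat16", "qfloat8", "qint4"):
        print(f"{dtype:>8}: ~{total_weight_gb(dtype):.1f} GB of weights")
```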

1. Loading the Flux Model

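First, load the components of the Flux pipeline: the two tokenizers and text encoders (CLIP and T5), the scheduler, the transformer, and the VAE.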

import torch

model_id = "black-forest-labs/FLUX.1-schnell"
model_revision = "refs/pr/1"
text_model_id = "openai/clip-vit-large-patch14"
model_data_type = torch.bfloat16
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from diffusers import FlowMatchEulerDiscreteScheduler, AutoencoderKL

# CLIP tokenizer and text encoder (tokenizers hold no weights, so no torch_dtype)
tokenizer = CLIPTokenizer.from_pretrained(text_model_id)
text_encoder = CLIPTextModel.from_pretrained(
    text_model_id, torch_dtype=model_data_type)

# T5 tokenizer and text encoder (the second text encoder)
tokenizer_2 = T5TokenizerFast.from_pretrained(
    model_id, subfolder="tokenizer_2", revision=model_revision)
text_encoder_2 = T5EncoderModel.from_pretrained(
    model_id, subfolder="text_encoder_2", torch_dtype=model_data_type,
    revision=model_revision)

# Scheduler and Flux transformer (the denoising model)
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    model_id, subfolder="scheduler", revision=model_revision)
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=model_data_type,
    revision=model_revision)

# VAE
vae = AutoencoderKL.from_pretrained(
    model_id, subfolder="vae", torch_dtype=model_data_type,
    revision=model_revision)
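All modules above are loaded with `torch_dtype=torch.bfloat16`, which stores each weight in 2 bytes instead of the 4 bytes of float32. bfloat16 keeps float32's 8-bit exponent, and therefore its range, but keeps only 7 mantissa bits. The pure-Python sketch below emulates the float32-to-bfloat16 conversion with round-to-nearest-even to show what precision is lost. It is an illustration only (NaN is not special-cased) and is not how PyTorch implements the cast.

```python
import struct


def float32_bits(x: float) -> int:
    """Return the IEEE-754 float32 bit pattern of x as an unsigned int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even), returned as a float.

    bfloat16 is the upper 16 bits of a float32: 1 sign bit, 8 exponent bits
    and 7 mantissa bits. NaN and overflow to infinity are not special-cased.
    """
    bits = float32_bits(x)
    # Add 0x7FFF plus the lowest kept bit so that exact ties round to even.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    upper = ((bits + rounding_bias) >> 16) & 0xFFFF
    # Widen back to float32 by zero-filling the 16 dropped mantissa bits.
    return struct.unpack("<f", struct.pack("<I", upper << 16))[0]


if __name__ == "__main__":
    for value in (1.0, 3.14159, 1e38):
        print(f"{value!r:>10} -> {to_bfloat16(value)!r}")
```

The last value shows the design trade-off: a number as large as 1e38 survives the cast with under 1% relative error, because the exponent width is unchanged, while the short mantissa costs roughly two to three significant decimal digits.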

2. Initializing the Smash Config

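Next, initialize the smash_config. It tells pruna to quantize the weights with quanto; the weight precision chosen here determines how much VRAM inference needs.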

from pruna import SmashConfig

# Initialize the SmashConfig
smash_config = SmashConfig()
smash_config['quantizers'] = ['quanto']
smash_config['quant_quanto_calibrate'] = False
smash_config['quant_quanto_weight_bits'] = 'qfloat8'  # or 'qint2', 'qint4', 'qint8' ('qint4' for 4-bit)
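The `quant_quanto_weight_bits` setting controls how many bits each weight is stored in. To illustrate what integer weight quantization does in general, the sketch below maps weights to signed n-bit integers that share a single per-tensor scale and then dequantizes them. This simplified symmetric scheme is written for this tutorial and is not quanto's actual implementation, which also supports float8 and uses its own scaling strategy. The point it shows is the trade-off: fewer bits means less memory and a larger reconstruction error.

```python
def quantize_symmetric(weights: list[float], bits: int) -> tuple[list[int], float]:
    """Map floats to signed `bits`-bit integers sharing one per-tensor scale."""
    qmax = 2 ** (bits - 1) - 1  # e.g. 127 for 8 bits, 7 for 4 bits, 1 for 2 bits
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    # Round to the nearest integer level and clamp to [-qmax, qmax].
    q = [max(-qmax, min(qmax, round(w / scale))) for w in weights]
    return q, scale


def dequantize(q: list[int], scale: float) -> list[float]:
    """Recover approximate float weights from integers and their scale."""
    return [v * scale for v in q]


def max_error(weights: list[float], bits: int) -> float:
    """Largest absolute difference between original and restored weights."""
    q, scale = quantize_symmetric(weights, bits)
    restored = dequantize(q, scale)
    return max(abs(a - b) for a, b in zip(weights, restored))


if __name__ == "__main__":
    weights = [0.12, -0.5, 0.33, 0.01, -0.27, 0.49]
    for bits in (8, 4, 2):
        print(f"{bits}-bit: max abs error {max_error(weights, bits):.4f}")
```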

3. Smashing the Model

Now smash the transformer and the T5 text encoder (text_encoder_2), which are by far the largest modules of the pipeline. This takes around 4 minutes. Don't forget to replace the token with the one provided by PrunaAI.

from pruna import smash

transformer = smash(
    model=transformer,
    token="<your_token>",  # replace <your_token> with your actual token or set to None if you do not have one yet
    smash_config=smash_config,
)
text_encoder_2 = smash(
    model=text_encoder_2,
    token="<your_token>",  # replace <your_token> with your actual token or set to None if you do not have one yet
    smash_config=smash_config,
)
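Rather than pasting the token into the notebook, you may prefer to read it from an environment variable. The sketch below is one way to do that; the variable name `PRUNA_TOKEN` and the helper `get_pruna_token` are hypothetical choices for this tutorial, not something pruna looks up on its own. It returns `None` when the variable is unset or empty, matching the advice above to pass `None` if you do not have a token yet.

```python
import os
from typing import Optional


def get_pruna_token(env_var: str = "PRUNA_TOKEN") -> Optional[str]:
    """Return the token stored in `env_var`, or None if it is unset or empty.

    HYPOTHETICAL helper: the variable name is a convention chosen here,
    not one pruna itself reads.
    """
    token = os.environ.get(env_var, "").strip()
    return token or None


if __name__ == "__main__":
    print("token found" if get_pruna_token() else "no token set; passing None")
```

You would then pass `token=get_pruna_token()` to each `smash` call instead of the literal string.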

4. Running the Model

Finally, assemble the pipeline from the loaded and smashed modules, move it to the GPU, and run it to generate the image. Note that moving the modules to the GPU can take some time.

pipe = FluxPipeline(
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=text_encoder_2,
    tokenizer_2=tokenizer_2,
    vae=vae,
    transformer=transformer
)
pipe.text_encoder.to('cuda')
pipe.vae.to('cuda')
pipe.transformer.to('cuda')
pipe.text_encoder_2.to('cuda')
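# Optional: for added memory savings, also run the lines below; they trade speed for memory.
# Sequential CPU offload keeps the weights in CPU memory and moves each submodule to the GPU
# only while it runs, so the manual .to('cuda') calls above are not needed when you use it.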
vae.enable_tiling()
vae.enable_slicing()
pipe.enable_sequential_cpu_offload()
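To decide whether you need 4-bit weights or the optional memory-saving lines, you can compare your GPU's memory with the approximate figures from the introduction (about 18 GB for the default `qfloat8` setting, 15 GB for 4-bit, and 11 GB for 4-bit plus memory savings). The helper `choose_settings` below is a hypothetical planning aid written for this tutorial; it does not query the GPU, and the thresholds are the tutorial's rough numbers, not guarantees.

```python
from typing import Optional

# Approximate VRAM needs quoted in this tutorial's introduction, in GB.
# Rough figures only; leave some headroom in practice.
SETTINGS = [
    # (VRAM needed, quant_quanto_weight_bits, run the memory-saving lines?)
    (18, "qfloat8", False),
    (15, "qint4", False),
    (11, "qint4", True),
]


def choose_settings(vram_gb: float) -> Optional[tuple[str, bool]]:
    """Pick the least aggressive setting whose quoted VRAM need fits in `vram_gb`.

    HYPOTHETICAL helper for planning only. Returns
    (quant_quanto_weight_bits, use_memory_savings), or None if nothing fits.
    """
    for needed, bits, savings in SETTINGS:
        if vram_gb >= needed:
            return bits, savings
    return None


if __name__ == "__main__":
    for vram in (24, 16, 12, 8):
        print(f"{vram} GB -> {choose_settings(vram)}")
```

On a CUDA machine, `torch.cuda.get_device_properties(0).total_memory / 1e9` gives the total VRAM of the first GPU in GB, which you could pass to the helper. The table has no entry for `qfloat8` combined with the memory-saving lines because the tutorial does not quote a figure for that combination.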
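prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    guidance_scale=0.0,  # FLUX.1-schnell does not use guidance
    num_inference_steps=4,  # schnell is distilled for few-step sampling
    max_sequence_length=256,  # maximum T5 prompt length for schnell
    generator=torch.Generator("cpu").manual_seed(0),  # fixed seed for reproducibility
).images[0]
image.save("flux_cat.png")
image  # displays the image when run in a notebook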
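The `max_sequence_length` argument and the image resolution together determine how many tokens the transformer processes in each step, and longer sequences need more activation memory. The sketch below does that arithmetic under stated assumptions: the diffusers Flux pipeline's default 1024x1024 output (no `height` or `width` is passed above), a VAE that downsamples by 8, latents packed into 2x2 patches, and T5 prompt embeddings padded to `max_sequence_length`. `flux_sequence_length` is a hypothetical helper written for this tutorial, not a diffusers function.

```python
def flux_sequence_length(height: int = 1024, width: int = 1024,
                         max_sequence_length: int = 256,
                         vae_downsample: int = 8, patch: int = 2) -> tuple[int, int]:
    """Return (image_tokens, total_tokens) for one Flux transformer pass.

    ASSUMPTIONS: the VAE downsamples by `vae_downsample`, latents are packed
    into `patch` x `patch` patches, and text tokens come from the T5 encoder
    padded to `max_sequence_length`.
    """
    step = vae_downsample * patch  # pixels covered by one token along each side
    if height % step or width % step:
        raise ValueError(f"height and width must be multiples of {step}")
    image_tokens = (height // step) * (width // step)
    return image_tokens, image_tokens + max_sequence_length


if __name__ == "__main__":
    for side in (512, 1024):
        img, total = flux_sequence_length(side, side)
        print(f"{side}x{side}: {img} image tokens, {total} tokens in total")
```

Under these assumptions, halving the image side length cuts the image tokens by a factor of four, which is why lowering the resolution is another lever for reducing activation memory.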

Wrap Up

Congratulations! You have successfully smashed a Flux model. Enjoy the smaller memory footprint!