Run your Flux model without an A100
This tutorial demonstrates how to use the pruna package to reduce the memory footprint of your Flux model.
This tutorial smashes the Flux model on the CPU, which requires around 28 GB of RAM. Because the example inference then runs on the GPU with the smashed model, a GPU with around 18 GB of VRAM is sufficient (around 15 GB with 4-bit quantization, and around 11 GB with 4-bit quantization plus the optional memory-saving settings shown in step 4).
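To see why lower-precision weights shrink these requirements, a rough estimate of weight memory is parameters × bits per weight / 8. The sketch below applies it to the two modules that are smashed in step 3, assuming approximate parameter counts of 12 billion for the FLUX.1 transformer and 4.7 billion for the T5-XXL encoder (treat both as assumptions). It ignores activations, the CLIP encoder, the VAE and framework overhead, so it will not reproduce the figures above exactly.

```python
def weight_memory_gb(num_params: float, bits_per_weight: int) -> float:
    """Memory needed to store the weights alone, in decimal GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_weight / 8 / 1e9


# Approximate parameter counts; assumptions for illustration only.
approx_params = {
    "transformer (FLUX.1)": 12e9,
    "text_encoder_2 (T5-XXL encoder)": 4.7e9,
}

for name, num_params in approx_params.items():
    for bits in (16, 8, 4):  # bfloat16, 8-bit and 4-bit weights
        print(f"{name}, {bits}-bit: {weight_memory_gb(num_params, bits):.1f} GB")
```

Halving the bits per weight halves the weight memory, which is why 8-bit and 4-bit quantization move the model within reach of GPUs smaller than an A100.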
1. Loading the Flux Model
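First, load the components of your Flux model: the CLIP and T5 text encoders with their tokenizers, the scheduler, the transformer, and the VAE. Loading them separately lets you smash the transformer and the T5 text encoder individually in step 3.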
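import torch

model_id = "black-forest-labs/FLUX.1-schnell"  # Flux checkpoint on the Hugging Face Hub
model_revision = "refs/pr/1"  # repository revision to load the Flux components from
text_model_id = "openai/clip-vit-large-patch14"  # CLIP text encoder and tokenizer
model_data_type = torch.bfloat16  # load all model weights in bfloat16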
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

# First text encoder: CLIP (tokenizers hold no weights, so they need no dtype)
tokenizer = CLIPTokenizer.from_pretrained(text_model_id)
text_encoder = CLIPTextModel.from_pretrained(
    text_model_id, torch_dtype=model_data_type)

# Second text encoder: T5
tokenizer_2 = T5TokenizerFast.from_pretrained(
    model_id, subfolder="tokenizer_2", revision=model_revision)
text_encoder_2 = T5EncoderModel.from_pretrained(
    model_id, subfolder="text_encoder_2", torch_dtype=model_data_type,
    revision=model_revision)

# Scheduler and transformer (the denoising backbone)
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    model_id, subfolder="scheduler", revision=model_revision)
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=model_data_type,
    revision=model_revision)

# VAE (decodes latents into images)
vae = AutoencoderKL.from_pretrained(
    model_id, subfolder="vae", torch_dtype=model_data_type,
    revision=model_revision)
2. Initializing the Smash Config
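Next, initialize the SmashConfig. This configuration selects the quanto quantizer, disables calibration, and stores the weights in 8-bit floating point (qfloat8).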
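from pruna import SmashConfig

# Initialize the SmashConfig and select the quanto quantizer
smash_config = SmashConfig()
smash_config['quantizers'] = ['quanto']
smash_config['quant_quanto_calibrate'] = False
# Weight precision: 'qfloat8' (used here), or 'qint2', 'qint4', 'qint8';
# 'qint4' is the 4-bit option mentioned in the introduction
smash_config['quant_quanto_weight_bits'] = 'qfloat8'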
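The quant_quanto_weight_bits option sets how many bits each stored weight uses. The sketch below illustrates the general idea behind the integer options with a simple symmetric, per-tensor scheme: each weight is divided by one shared scale, rounded to a signed integer, and multiplied back by the scale when it is used. This is a simplified illustration, not quanto's implementation (its scale granularity and rounding may differ), and the weight values are made up.

```python
from typing import List, Tuple


def quantize_symmetric(weights: List[float], bits: int) -> Tuple[List[int], float]:
    """Map floats to signed integers in [-qmax, qmax] using a single shared scale."""
    qmax = 2 ** (bits - 1) - 1
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    # Round to the nearest integer level and clamp to the representable range.
    q = [max(-qmax, min(qmax, round(w / scale))) for w in weights]
    return q, scale


def dequantize(q: List[int], scale: float) -> List[float]:
    """Recover approximate float weights from the integer codes."""
    return [v * scale for v in q]


weights = [0.42, -1.27, 0.05, 0.88, -0.31]  # hypothetical weight values
for bits in (8, 4, 2):
    q, scale = quantize_symmetric(weights, bits)
    restored = dequantize(q, scale)
    max_err = max(abs(a - b) for a, b in zip(weights, restored))
    print(f"{bits}-bit: codes={q}, max abs error={max_err:.4f}")
```

Fewer bits mean coarser integer levels and therefore a larger reconstruction error, which is the quality-versus-memory trade-off behind choosing between 'qint8', 'qint4' and 'qint2'.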
3. Smashing the Model
Now smash the transformer and the T5 text encoder; this takes around 4 minutes. Don't forget to replace the token with the one provided by PrunaAI.
from pruna import smash

# Smash the transformer
transformer = smash(
    model=transformer,
    token="<your_token>",  # replace <your_token> with your actual token, or set to None if you do not have one yet
    smash_config=smash_config,
)
# Smash the T5 text encoder with the same configuration
text_encoder_2 = smash(
    model=text_encoder_2,
    token="<your_token>",  # replace <your_token> with your actual token, or set to None if you do not have one yet
    smash_config=smash_config,
)
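Pasting the token into every smash call is easy to get wrong. A minimal sketch of an alternative, assuming you keep the token in an environment variable; the variable name PRUNA_TOKEN and the helper get_pruna_token are hypothetical and not part of pruna. The helper returns None when the variable is unset or empty, matching the fallback suggested in the comments above.

```python
import os
from typing import Optional


def get_pruna_token(env_var: str = "PRUNA_TOKEN") -> Optional[str]:
    """Return the token stored in env_var (hypothetical name), or None if unset or empty."""
    token = os.environ.get(env_var, "").strip()
    return token or None


print(get_pruna_token() is None)  # True unless PRUNA_TOKEN is set in your environment
```

You would then pass token=get_pruna_token() to both smash calls instead of a hard-coded string.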
4. Running the Model
Finally, assemble the pipeline from the loaded and smashed components, move the modules to the GPU, and generate an image. Note that moving the modules to the GPU can take some time.
pipe = FluxPipeline(
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    text_encoder_2=text_encoder_2,
    tokenizer_2=tokenizer_2,
    vae=vae,
    transformer=transformer,
)
pipe.text_encoder.to('cuda')
pipe.vae.to('cuda')
pipe.transformer.to('cuda')
pipe.text_encoder_2.to('cuda')
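# Optional: run this block for additional memory savings, at the cost of speed.
# Tiling decodes the image in tiles, slicing decodes one image of a batch at a time,
# and sequential CPU offloading keeps the weights in CPU memory and moves each
# submodule to the GPU only while it runs.
vae.enable_tiling()
vae.enable_slicing()
pipe.enable_sequential_cpu_offload()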
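Why sequential CPU offloading saves memory but costs speed can be seen with a toy model of weight residency. The sketch below is a simplified illustration, not how diffusers implements offloading: with everything on the GPU, peak weight memory is the sum of all submodules and nothing is copied during a forward pass; with sequential offloading, only the running submodule is resident, so the peak drops to the largest submodule, but every submodule's weights are copied to the GPU on every pass. The submodule sizes are hypothetical.

```python
from typing import List


def peak_resident_gb(module_sizes_gb: List[float], sequential_offload: bool) -> float:
    """Peak GPU memory used by weights in this simplified model."""
    if sequential_offload:
        # Only the submodule that is currently executing lives on the GPU.
        return max(module_sizes_gb)
    # All submodules are moved to the GPU up front.
    return sum(module_sizes_gb)


def transferred_per_pass_gb(module_sizes_gb: List[float], sequential_offload: bool) -> float:
    """Weight data copied from CPU to GPU during one forward pass."""
    return sum(module_sizes_gb) if sequential_offload else 0.0


sizes = [0.5, 1.0, 0.25, 2.0]  # hypothetical submodule sizes in GB
for offload in (False, True):
    print(f"offload={offload}: peak={peak_resident_gb(sizes, offload)} GB, "
          f"copied per pass={transferred_per_pass_gb(sizes, offload)} GB")
```

The extra copies on every forward pass, repeated for every denoising step, are the speed cost mentioned in the comment above.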
prompt = "A cat holding a sign that says hello world"
pipe(
    prompt,
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
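The pipeline call above runs only num_inference_steps=4 denoising steps. The FlowMatchEulerDiscreteScheduler loaded in step 1 advances the latents with Euler updates of the form x <- x + (sigma_next - sigma) * v, where v is the transformer's velocity prediction. The sketch below illustrates that update on a toy problem. It is a simplified illustration rather than the diffusers implementation: the noise levels are linearly spaced (Flux's scheduler additionally shifts them), and a hypothetical constant velocity function stands in for the transformer.

```python
from typing import Callable, List

Vector = List[float]


def euler_flow_sample(noise: Vector,
                      velocity_fn: Callable[[Vector, float], Vector],
                      num_steps: int) -> Vector:
    """Integrate dx/dsigma = v(x, sigma) from sigma = 1 (pure noise) to sigma = 0 with Euler steps."""
    # Linearly spaced noise levels 1.0, ..., 0.0 (real schedulers may shift them).
    sigmas = [1.0 - i / num_steps for i in range(num_steps + 1)]
    x = list(noise)
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        v = velocity_fn(x, sigma)
        # Euler update: x <- x + (sigma_next - sigma) * v
        x = [xi + (sigma_next - sigma) * vi for xi, vi in zip(x, v)]
    return x


# Toy data point and noise sample (hypothetical values). On the straight path
# x_sigma = (1 - sigma) * x0 + sigma * noise, the velocity is constant: noise - x0.
x0 = [0.5, -1.0, 2.0]
noise = [1.2, 0.3, -0.7]


def constant_velocity(x: Vector, sigma: float) -> Vector:
    # Stand-in for the transformer's velocity prediction.
    return [n - d for n, d in zip(noise, x0)]


print(euler_flow_sample(noise, constant_velocity, num_steps=4))  # approximately x0
```

Because the toy velocity is constant, a handful of Euler steps recovers x0 exactly up to floating-point error; with a learned velocity field, few steps work well only when the path is close to straight.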
Wrap Up
Congratulations! You have successfully smashed a Flux model. Enjoy the smaller memory footprint!