Flux generation in a heartbeat, literally
This tutorial demonstrates how to use the pruna package to smash a Flux model on GPU for faster inference. All execution times reported below were measured on an A100 GPU; running the tutorial requires an A100 or a comparable GPU.
1. Loading the Flux Model
First, load your Flux model.
[ ]:
import torch
from diffusers import FluxPipeline
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
# pipe.enable_model_cpu_offload()  # uncomment to save some VRAM by offloading the model to CPU; leave it commented out if you have enough GPU memory
pipe.to('cuda')
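Optionally, time a single generation with the unsmashed pipeline now so you can quantify the speed-up later. The snippet below is a minimal sketch, not part of the original recipe; it reuses the prompt and sampling settings from the rest of this tutorial and relies only on standard torch and Python timing utilities.
[ ]:
import time

prompt = "A cat holding a sign that says hello world"

# time one baseline generation before smashing (optional)
torch.cuda.synchronize()
start = time.perf_counter()
pipe(
    prompt,
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
torch.cuda.synchronize()
print(f"Baseline latency: {time.perf_counter() - start:.2f} s")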
2. Initializing the Smash Config
Next, initialize the smash_config.
[ ]:
from pruna import SmashConfig
# Initialize the SmashConfig
smash_config = SmashConfig()
smash_config['compilers'] = ['onediff']
3. Smashing the Model
Now you can smash the model; this can take up to 2 minutes. Don't forget to replace the token with the one provided by PrunaAI.
[ ]:
from pruna import smash
pipe.transformer = smash(
    model=pipe.transformer,
    token='<your_token>',  # replace <your_token> with your actual token or set to None if you do not have one yet
    smash_config=smash_config,
)
4. Running the Model
After the model has been compiled, run inference for a few iterations as a warm-up. This can take up to a minute.
[ ]:
prompt = "A cat holding a sign that says hello world"

# run some warm-up iterations
for _ in range(5):
    pipe(
        prompt,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        generator=torch.Generator("cpu").manual_seed(0),
    ).images[0]
Finally, run the model to generate images with accelerated inference.
[ ]:
pipe(
    prompt,
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
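If you timed the baseline earlier, you can time the smashed pipeline the same way and save the result. This is a minimal sketch: the pipeline returns standard PIL images, and the output filename below is only an example.
[ ]:
import time

# time one generation with the smashed pipeline and save the image
torch.cuda.synchronize()
start = time.perf_counter()
image = pipe(
    prompt,
    guidance_scale=0.0,
    num_inference_steps=4,
    max_sequence_length=256,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
torch.cuda.synchronize()
print(f"Smashed latency: {time.perf_counter() - start:.2f} s")

image.save("flux_cat.png")  # example filename; choose any path you like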
Wrap Up
Congratulations! You have successfully smashed a Flux model. Enjoy the speed-up!