Shrink and accelerate Sana: x2 smaller and x2 faster (Pro)


This tutorial demonstrates how to use the pruna_pro package to quantize any transformer-based diffusion model from 16-bit to 8-bit precision. We will use the Sana_600M_512px model as an example. The tutorial was tested on an A10G GPU, but it can be adapted to any CPU/GPU and to other models that contain a transformer, such as Flux or Stable Diffusion, as long as the model fits on the device. Note that the algorithm presented in this tutorial is a pruna_pro feature, so you will need your token to run it.

1. Loading the Sana Diffusion Model

[ ]:
import torch
from diffusers import SanaPipeline

# Define the model ID
model_id = "Efficient-Large-Model/Sana_600M_512px_diffusers"

# Load the pre-trained model
pipe = SanaPipeline.from_pretrained(model_id, variant="fp16", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
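The same loading pattern applies to the other transformer-based pipelines mentioned above. As an illustrative sketch only (FluxPipeline is part of recent diffusers releases, but this model ID and dtype are assumptions, not part of the tested setup):

[ ]:
from diffusers import FluxPipeline

# Hypothetical alternative to step 1: swap in any transformer-based
# diffusion checkpoint that fits on your device.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
)
pipe = pipe.to("cuda")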

2. Initializing the Smash Config

Next, initialize the smash_config; here we use the torchao quantization algorithm. Only three lines of code? Yes!

[ ]:
from pruna_pro import smash, SmashConfig

# Initialize the SmashConfig
smash_config = SmashConfig()
smash_config['quantizer'] = 'torchao_autoquant'
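Before smashing, it can be useful to record a baseline latency and memory footprint of the unoptimized pipeline, so you can verify the gains afterwards. This optional snippet is a minimal sketch using only torch and the standard library; it is not part of the official tutorial flow:

[ ]:
import time

# Optional: baseline latency and memory of the *unsmashed* pipeline.
prompt = "a smiling cat dancing on a table. Miyazaki style"
pipe(prompt)  # one warm-up call so CUDA initialization is not timed

torch.cuda.synchronize()
start = time.perf_counter()
pipe(prompt)
torch.cuda.synchronize()
baseline_latency = time.perf_counter() - start
baseline_memory = torch.cuda.memory_allocated()

print(f"Baseline latency: {baseline_latency:.2f} s")
print(f"Baseline GPU memory: {baseline_memory / 1024**3:.2f} GB")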

3. Smashing the Model

Now, smash the model. This can take up to 40 seconds. Don’t forget to replace the token with the one provided by PrunaAI.

[ ]:
smashed_pipe = smash(
    model=pipe,
    token="<your_pruna_token>",
    smash_config=smash_config,
)

4. Running the Model

Run one step of the pipeline to warm up the model (we need to tune the kernels for your machine ;) ). This can take up to 10 minutes, but it only has to be done once.

[ ]:
# Define the prompt and run one warm-up generation
prompt = "a smiling cat dancing on a table. Miyazaki style"
smashed_pipe(prompt)
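With the kernels tuned, you can now compare against the baseline from step 2. This is a rough sketch: it assumes you ran the optional baseline snippet above (baseline_latency and baseline_memory), and allocated GPU memory is only a proxy for model size, since it also includes caches and activations.

[ ]:
import time

# Timed run on the smashed pipeline, after warm-up.
torch.cuda.synchronize()
start = time.perf_counter()
smashed_pipe(prompt)
torch.cuda.synchronize()
smashed_latency = time.perf_counter() - start
smashed_memory = torch.cuda.memory_allocated()

print(f"Latency: {smashed_latency:.2f} s "
      f"({baseline_latency / smashed_latency:.2f}x faster)")
print(f"GPU memory: {smashed_memory / 1024**3:.2f} GB "
      f"({baseline_memory / smashed_memory:.2f}x smaller)")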

Finally, run the model to generate the image with accelerated inference.

[ ]:
# Generate and display the image
smashed_pipe(prompt).images[0]
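The pipeline returns standard PIL images, so saving the result to disk is straightforward (the filename here is arbitrary):

[ ]:
# Generate once more and save the image to disk (PIL's Image.save)
image = smashed_pipe(prompt).images[0]
image.save("smashed_sana_cat.png")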

Wrap Up

Congratulations! You have successfully smashed a Sana model! You can now use the pruna_pro package to optimize any custom diffusion model. The only parts you need to modify to fit your use case are steps 1 and 4.