Recovering Quality after Quantizing Models to 4 Bits (Pro)

This tutorial demonstrates how to use the pruna_pro package’s experimental “recovery” feature to restore model quality after quantization. This option lets you push quantization and other compression techniques to the limit without compromising quality.

We will use PERP on the Sana model as an example, but depending on your device you can also use Stable Diffusion or Flux models. All execution times given below were measured on an A10G GPU.

Note that recovery is a pruna_pro feature, so you will need your token to run this tutorial.

1. Loading the Sana Model

First, load the Sana model.

[ ]:
import torch
from diffusers import SanaPipeline

# Load the Sana pipeline in bfloat16 to match the BF16 checkpoint
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
    torch_dtype=torch.bfloat16,
).to("cuda")
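
The rest of the tutorial works the same way for other diffusers pipelines. As a minimal sketch, you could load Flux instead; this assumes access to the gated black-forest-labs/FLUX.1-dev checkpoint and enough GPU memory:

[ ]:
# Alternative (sketch): load a Flux pipeline instead of Sana.
# Skip this cell to keep using Sana for the rest of the tutorial.
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")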

Then, generate an image to serve as a quality reference.

[ ]:
prompt = "A crow walking along a river near a foggy cliff, with cute yellow ducklings following it in a line, at sunset."
pipe(prompt).images[0]
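
If you want a side-by-side comparison after smashing, you can keep the reference image on disk. A minimal sketch (note that this runs one more generation, since the image above was not stored):

[ ]:
# Optional: save the reference image for a later side-by-side comparison
reference_image = pipe(prompt).images[0]
reference_image.save("sana_reference.png")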

2. Initializing the SmashConfig

Next, initialize the SmashConfig. We’ll quantize the model to 4 bits with bitsandbytes and recover quality by finetuning with PERP on a text-to-image dataset.

[ ]:
from pruna_pro import SmashConfig

smash_config = SmashConfig()
# Attach a text-to-image dataset, used for recovery
smash_config.add_data("COCO")

# Quantize the model to 4 bits
smash_config["quantizer"] = "diffusers_int8"
smash_config["diffusers_int8_weight_bits"] = 4

# Recover quality, allowing you to push quantization to lower bit widths without compromising quality
smash_config["recoverer"] = "text_to_image_perp"
smash_config["text_to_image_perp_num_epochs"] = 0.001  # use only 0.1% of the dataset
# You can increase or reduce 'batch_size' depending on your GPU memory, or combine it with gradient accumulation (see the sketch below)
smash_config["text_to_image_perp_batch_size"] = 4

3. Smashing the Model

Now, smash the model. This takes about 2 minutes on an A10G GPU, though the exact time depends on how many samples are used for recovery.

[ ]:
from pruna_pro import smash

smashed_model = smash(
    model=pipe,
    token='<your_pruna_token>',  # replace <your_pruna_token> with your actual token
    smash_config=smash_config,
)

4. Running the Model

Finally, run the model that has been quantized and recovered. Thanks to quantization, it has a lower memory footprint than the original.

[ ]:
smashed_model(prompt).images[0]
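
To put a number on the memory savings, you can check PyTorch’s peak-memory counter around a generation; a minimal sketch using standard torch.cuda utilities:

[ ]:
# Optional: measure peak GPU memory for a generation with the smashed model
torch.cuda.reset_peak_memory_stats()
smashed_model(prompt).images[0]
print(f"Peak GPU memory: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")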

Wrap up

Congratulations! You have successfully recovered quality on your compressed Sana model. You can now push the pruna_pro package to its limit by combining aggressive compression with recovery. To adapt this tutorial to your own use case, you only need to modify steps 1 and 4. You can also pair recovery with other quantization algorithms, such as the one in this tutorial: Run your Flux model without an A100.