Smashing Stable Diffusion Models

This tutorial demonstrates how to use the pruna package to optimize any custom Stable Diffusion model. We will use the Stable Diffusion v1.4 model (CompVis/stable-diffusion-v1-4) as an example.

Loading the Stable Diffusion Model

First, load your stable diffusion model.

from diffusers import StableDiffusionPipeline
import torch

# Define the model ID
model_id = "CompVis/stable-diffusion-v1-4"

# Load the pre-trained model
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)

# Move the model to GPU
pipe = pipe.to("cuda")

# Define the prompt
prompt = "a photo of an astronaut riding a horse on mars"

Initializing the Smash Config

Next, initialize the smash_config.

from pruna_engine.SmashConfig import SmashConfig

# Initialize the SmashConfig
smash_config = SmashConfig()
smash_config['task'] = 'text_image_generation'
smash_config['compilers'] = ['diffusers2']

Smashing the Model

Now, smash the model.

from pruna.smash import smash

# Smash the model
smashed_model = smash(
    model=pipe,
    api_key='<your-api-key>',  # replace <your-api-key> with your actual API key
    smash_config=smash_config,
)

Don’t forget to replace the api_key placeholder with the key provided by PrunaAI.
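
Hardcoding the key in a script makes it easy to leak through version control. Below is a minimal sketch of reading it from an environment variable instead. The variable name PRUNA_API_KEY and the helper get_api_key are assumptions made for illustration; neither is defined by the pruna package.

```python
import os


def get_api_key(env_var: str = "PRUNA_API_KEY") -> str:
    """Return the API key stored in an environment variable.

    The variable name is a hypothetical convention, not one defined by pruna.
    """
    key = os.environ.get(env_var, "").strip()
    # Reject a missing key and the unreplaced tutorial placeholder alike.
    if not key or key == "<your-api-key>":
        raise RuntimeError(
            f"Set the {env_var} environment variable to the key provided by PrunaAI."
        )
    return key
```

The result could then be passed as api_key=get_api_key() in the smash call above.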

Running the Model

Finally, run the model to generate the image.

# Generate an image from the prompt and display it
smashed_model(prompt).images[0].show()
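
Since the point of smashing is to make the model faster, it can be useful to measure the effect. The helper below is a rough sketch using only the standard library. The name time_call is hypothetical, and it assumes that the call returns only once the generated image is available, so that wall-clock time reflects the full generation.

```python
import statistics
import time
from typing import Any, Callable


def time_call(fn: Callable[..., Any], *args: Any, warmup: int = 1, runs: int = 3) -> float:
    """Return the median wall-clock seconds of fn(*args) over `runs` timed calls."""
    if runs < 1:
        raise ValueError("runs must be at least 1")
    # Warm-up calls absorb one-time costs such as compilation or caching.
    for _ in range(warmup):
        fn(*args)
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn(*args)
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)
```

For example, time_call(pipe, prompt) measured before smashing and time_call(smashed_model, prompt) measured afterwards give two numbers to compare. Timing the original pipeline first is the safer order, in case smashing modifies it in place.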

Wrap Up

Congratulations! You have successfully smashed a Stable Diffusion model. You can now use the pruna package to optimize any custom Stable Diffusion model. The only parts you should need to modify are the first step (loading the model) and the last step (running the model) to fit your use case.