Smashing Stable Diffusion Models
This tutorial demonstrates how to use the pruna package to optimize a custom Stable Diffusion model. We will use Stable Diffusion v1.4 (CompVis/stable-diffusion-v1-4) as an example.
Loading the Stable Diffusion Model
First, load your Stable Diffusion pipeline and move it to the GPU.
from diffusers import StableDiffusionPipeline
import torch
# Define the model ID
model_id = "CompVis/stable-diffusion-v1-4"
# Load the pre-trained model
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
# Move the model to GPU
pipe = pipe.to("cuda")
# Define the prompt
prompt = "a photo of an astronaut riding a horse on mars"
Initializing the Smash Config
Next, initialize a SmashConfig that tells pruna which task the model performs and which compiler to apply.
from pruna_engine.SmashConfig import SmashConfig
# Initialize the SmashConfig
smash_config = SmashConfig()
smash_config['task'] = 'text_image_generation'
smash_config['compilers'] = ['diffusers2']
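If you want to try several configurations, it can help to keep the settings in a plain dictionary and copy them onto the config. The sketch below relies only on the item assignment shown above (smash_config['task'] = ...); the helper name apply_settings and the SETTINGS dictionary are illustrative and not part of pruna.

```python
from typing import Any

# Hypothetical settings dictionary mirroring the values used in this tutorial.
SETTINGS: dict[str, Any] = {
    "task": "text_image_generation",
    "compilers": ["diffusers2"],
}


def apply_settings(config: Any, settings: dict[str, Any]) -> Any:
    """Copy each setting onto a config object that supports item assignment."""
    for key, value in settings.items():
        # Copy lists so later edits to SETTINGS do not leak into the config.
        config[key] = list(value) if isinstance(value, list) else value
    return config


if __name__ == "__main__":
    # A plain dict stands in for SmashConfig so the sketch runs without pruna.
    print(apply_settings({}, SETTINGS))
```

With pruna installed, you would call apply_settings(SmashConfig(), SETTINGS) in place of the individual assignments.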
Smashing the Model
Now, smash the model.
from pruna.smash import smash
# Smash the model
smashed_model = smash(
    model=pipe,
    api_key='<your-api-key>',  # replace <your-api-key> with your actual API key
    smash_config=smash_config,
)
Don’t forget to replace the api_key with the one provided by PrunaAI.
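Rather than pasting the key into the script, you can read it from an environment variable. This is a minimal sketch: the variable name PRUNA_API_KEY and the helper load_api_key are our own choices, not something the pruna package requires.

```python
import os


def load_api_key(var_name: str = "PRUNA_API_KEY") -> str:
    """Return the API key stored in the environment variable `var_name`.

    Raises RuntimeError with a readable message if the variable is unset or empty.
    """
    # Strip surrounding whitespace, which often sneaks in when pasting keys.
    key = os.environ.get(var_name, "").strip()
    if not key:
        raise RuntimeError(
            f"Set the {var_name} environment variable to your PrunaAI API key."
        )
    return key
```

You would then pass api_key=load_api_key() to smash instead of a hard-coded string.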
Running the Model
Finally, run the model to generate the image.
# Generate an image from the prompt and display the first result
smashed_model(prompt).images[0].show()
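To check whether smashing pays off on your hardware, you can time the original and the smashed pipelines on the same prompt. The helper below is a generic, standard-library sketch; time_call and its warmup/runs parameters are illustrative names. Warm-up calls are discarded because the first calls of a compiled model are often much slower than later ones.

```python
import statistics
import time
from typing import Callable


def time_call(fn: Callable[[], object], warmup: int = 1, runs: int = 3) -> float:
    """Return the median wall-clock time of fn() in seconds, after `warmup` discarded calls."""
    if runs < 1:
        raise ValueError("runs must be at least 1")
    # Warm-up calls absorb one-off costs such as compilation and caching.
    for _ in range(warmup):
        fn()
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    # The median is less sensitive to a single slow outlier than the mean.
    return statistics.median(timings)
```

For example, measure base = time_call(lambda: pipe(prompt)) before smashing (we have not verified whether smash modifies pipe in place), then fast = time_call(lambda: smashed_model(prompt)) afterwards, and compare base / fast.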
Wrap Up
Congratulations! You have successfully smashed a Stable Diffusion model. The same workflow applies to other custom Stable Diffusion models: the only parts you should need to modify are the model-loading step and the step that runs the model, to fit your use case.
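Because only the loading and inference steps change between use cases, you can wrap the workflow in a function that takes those two steps as arguments. This is a sketch under our own assumptions: smash_with and its parameters are hypothetical names, and smash_fn is meant to be pruna.smash.smash, passed in as an argument so the sketch runs without pruna installed.

```python
from typing import Any, Callable


def smash_with(
    load_model: Callable[[], Any],
    run_model: Callable[[Any], Any],
    smash_fn: Callable[..., Any],
    smash_config: Any,
    api_key: str,
) -> Any:
    """Load a model, smash it, and run it; only load_model and run_model vary per use case."""
    model = load_model()  # your own loading code
    # Call the smash function with the same keyword arguments used in this tutorial.
    smashed = smash_fn(model=model, api_key=api_key, smash_config=smash_config)
    return run_model(smashed)  # your own inference code
```

For this tutorial, load_model would wrap StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda"), run_model could be lambda m: m(prompt).images[0], and smash_fn would be smash.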