Smash your first model

This guide provides a quick introduction to optimizing AI models with Pruna.

You’ll learn how to use Pruna’s core functionality to make your models faster, smaller, cheaper, and greener. For installation instructions, see Installation.

Basic Optimization Workflow

Pruna follows a simple workflow for optimizing models:

Load Model → Define SmashConfig → Smash Model → Evaluate Model → Run Inference

Let’s see what that looks like in code.

from pruna import smash, SmashConfig
from diffusers import StableDiffusionPipeline
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.task import Task

# Load the model
model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

# Create and configure SmashConfig
smash_config = SmashConfig()
smash_config["cacher"] = "deepcache"

# Smash the model
optimized_model = smash(model=model, smash_config=smash_config)

# Evaluate the model
metrics = ['clip_score', 'psnr']
task = Task(metrics, datamodule=PrunaDataModule.from_string('LAION256'))
eval_agent = EvaluationAgent(task)
eval_agent.evaluate(optimized_model)

# Run inference
optimized_model.set_progress_bar_config(disable=True)
optimized_model.inference_handler.model_args.update(
    {"num_inference_steps": 1, "guidance_scale": 0.0}
)
optimized_model("A serene landscape with mountains").images[0]

Step-by-Step Optimization Workflow

Step 1: Load a pretrained model

First, load any model using its original library, like transformers or diffusers:

from diffusers import StableDiffusionPipeline

base_model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
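
If your model comes from transformers instead, this first step looks the same. For example, the OPT checkpoint used in Example 2 below loads like this:

from transformers import AutoModelForCausalLM

base_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")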

Step 2: Define optimizations with a SmashConfig

After loading the model, we can define a SmashConfig to customize the optimizations we want to apply. A SmashConfig is a dictionary-like object, and it can combine multiple optimization algorithms from different categories, such as batching, caching, and quantization.

For now, let’s just use a cacher to accelerate the model during inference.

from pruna import SmashConfig

smash_config = SmashConfig()
smash_config["cacher"] = "deepcache"  # Accelerate the model with caching

Pruna supports a wide range of optimization algorithms, each with different trade-offs. To understand how to configure the right ones for your scenario, see Define a SmashConfig.
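
Because a SmashConfig can hold algorithms from several categories at once, you can combine them in a single config. As a small sketch, here is the cacher from above paired with a compiler, the same combination used in Example 1 below:

from pruna import SmashConfig

smash_config = SmashConfig()
smash_config["cacher"] = "deepcache"      # skip redundant computation across diffusion steps
smash_config["compiler"] = "stable_fast"  # compile the pipeline for faster inference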

Step 3: Apply optimizations with smash

The smash() function is the core of Pruna: it takes your model and your SmashConfig and applies the specified optimizations.

from pruna import smash

optimized_model = smash(model=base_model, smash_config=smash_config)

The smash() function returns a PrunaModel, a wrapper that provides a standardized interface for the optimized model, so we can still use it just as we would the original one.
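
As a quick sketch of what that standardized interface means in practice, the wrapper forwards calls and attribute access to the wrapped pipeline, so the familiar diffusers call still works (this assumes PrunaModel is exported at the package top level; the assertion just makes the wrapping explicit):

from pruna import PrunaModel

# smash() hands back a PrunaModel wrapping the original pipeline
assert isinstance(optimized_model, PrunaModel)

# Calls are forwarded to the wrapped StableDiffusionPipeline
image = optimized_model("A serene landscape with mountains").images[0]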

Step 4: Evaluate the optimized model with the EvaluationAgent

To evaluate the optimized model, we define a Task that combines the metrics we care about with a dataset, and pass it to an EvaluationAgent.

from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.task import Task

metrics = ['clip_score', 'psnr']
task = Task(metrics, datamodule=PrunaDataModule.from_string('LAION256'))
eval_agent = EvaluationAgent(task)
eval_agent.evaluate(optimized_model)

To understand how to run more complex evaluation workflows, see Evaluate a model.
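
One detail worth knowing: psnr compares the optimized model's outputs against reference outputs. A common pattern, sketched below under the assumption that the EvaluationAgent treats the first model it evaluates as that reference, is to evaluate the base model before the optimized one; see Evaluate a model for the authoritative workflow.

# Sketch (assumption): evaluate the base model first so pairwise
# metrics such as psnr have reference outputs to compare against
eval_agent.evaluate(base_model)
eval_agent.evaluate(optimized_model)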

Step 5: Run inference with the optimized model

To run inference with the optimized model, we can use the same interface as the original model.

optimized_model.set_progress_bar_config(disable=True)
optimized_model.inference_handler.model_args.update(
    {"num_inference_steps": 1, "guidance_scale": 0.0}
)
optimized_model("A serene landscape with mountains").images[0]

Example Use Cases

Let’s look at some specific examples for different model types.

Example 1: Diffusion Model Optimization

from diffusers import StableDiffusionPipeline
from pruna import smash, SmashConfig

# Load the model
model = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

# Create and configure SmashConfig
smash_config = SmashConfig()
smash_config["cacher"] = "deepcache"
smash_config["compiler"] = "stable_fast"

# Optimize the model
optimized_model = smash(model=model, smash_config=smash_config)

# Generate an image
optimized_model("A serene landscape with mountains").images[0]

Example 2: Large Language Model Optimization

from transformers import AutoModelForCausalLM
from pruna import smash, SmashConfig

# Load the model
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

# Create and configure SmashConfig
smash_config = SmashConfig()
smash_config.add_tokenizer("facebook/opt-125m")  # attach a tokenizer so the smashed model can take raw text
smash_config.add_data("WikiText")  # GPTQ calibrates on sample data; assumes Pruna's built-in WikiText dataset
smash_config["quantizer"] = "gptq"  # Apply GPTQ quantization

# Optimize the model
optimized_model = smash(model=model, smash_config=smash_config)

# Use the model for generation
input_text = "The best way to learn programming is"
optimized_model(input_text)

Example 3: Speech Recognition Optimization

from transformers import AutoModelForSpeechSeq2Seq
from pruna import smash, SmashConfig
import torch

# Load the model
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True
).to("cuda")

# Create and configure SmashConfig
smash_config = SmashConfig()
smash_config.add_processor(model_id)  # Required for Whisper
smash_config["compiler"] = "c_whisper"
smash_config["batcher"] = "whisper_s2t"

# Optimize the model
optimized_model = smash(model=model, smash_config=smash_config)

# Use the model for transcription (pass the path to your audio file)
optimized_model("audio_file.wav")