Smash your first model

This guide provides a quick introduction to optimizing AI models with pruna.

You’ll learn how to use Pruna’s core functionality to make your models faster, smaller, cheaper, and greener. For installation instructions, see Installation.

Basic Optimization Workflow

pruna follows a simple workflow for optimizing models:

        graph LR
   A[Load Model] --> B[Define SmashConfig]
   B --> C[Smash Model]
   C --> D[Evaluate Model]
   D --> E[Run Inference]
   style A fill:#bbf,stroke:#333,stroke-width:2px
   style B fill:#bbf,stroke:#333,stroke-width:2px
   style C fill:#bbf,stroke:#333,stroke-width:2px
   style D fill:#bbf,stroke:#333,stroke-width:2px
   style E fill:#bbf,stroke:#333,stroke-width:2px
    

Let’s see what that looks like in code.

from diffusers import DiffusionPipeline

from pruna import SmashConfig, smash
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.task import Task

# Load the model
model = DiffusionPipeline.from_pretrained("segmind/Segmind-Vega")

# Create and configure SmashConfig
smash_config = SmashConfig()
smash_config["quantizer"] = "hqq_diffusers"

# Smash the model
optimized_model = smash(model=model, smash_config=smash_config)

# Evaluate the model
metrics = ["clip_score", "psnr"]
datamodule = PrunaDataModule.from_string("LAION256")
datamodule.limit_datasets(10) # You can limit the number of samples.
task = Task(metrics, datamodule=datamodule)
eval_agent = EvaluationAgent(task)
eval_agent.evaluate(optimized_model)

# Run inference
optimized_model.set_progress_bar_config(disable=True)
optimized_model.to("cuda")
optimized_model("A serene landscape with mountains").images[0].save("output.png")

Step-by-Step Optimisation Workflow

Step 1: Load a pretrained model

First, load any model using its original library, like transformers or diffusers:

from diffusers import DiffusionPipeline

base_model = DiffusionPipeline.from_pretrained("segmind/Segmind-Vega")

Step 2: Define optimizations with a SmashConfig

After loading the model, we can define a SmashConfig to customize the optimizations we want to apply. This SmashConfig is a dictionary-like object that configures which optimizations to apply to your model. You can specify multiple optimization algorithms from different categories like batching, caching and quantization.

For now, let’s just use a quantizer to accelerate the model during inference.

from pruna import SmashConfig

smash_config = SmashConfig()
smash_config["quantizer"] = "hqq_diffusers"  # Accelerate the model with caching

Pruna support a wide range of algorithms for specific optimizations, all with different trade-offs. To understand how to configure the right one for your scenario, see Define a SmashConfig.

Step 3: Apply optimizations with smash

The smash() function is the core of Pruna. It takes your model and SmashConfig, applies the specified optimizations. Let’s use the smash() function to apply the configured optimizations:

from pruna import SmashConfig, smash

from diffusers import DiffusionPipeline

# Load the model
base_model = DiffusionPipeline.from_pretrained("segmind/Segmind-Vega")

# Create and configure SmashConfig
smash_config = SmashConfig()
smash_config["quantizer"] = "hqq_diffusers"

# Smash the model
optimized_model = smash(model=base_model, smash_config=smash_config)

# Save the optimized model
optimized_model.save_to_hub("PrunaAI/Segmind-Vega-smashed")

The smash() function returns a PrunaModel - a wrapper that provides a standardized interface for the optimized model. So, we can still use the model as we would use the original one.

Step 4: Evaluate the optimized model with the EvaluationAgent

To evaluate the optimized model, we can use the same interface as the original model.

from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.engine.pruna_model import PrunaModel
from pruna.evaluation.task import Task

# Load the optimized model
optimized_model = PrunaModel.from_hub("PrunaAI/Segmind-Vega-smashed")

# Define metrics
metrics = ['clip_score', 'psnr']

# Define task
task = Task(metrics, datamodule=PrunaDataModule.from_string('LAION256'))

# Evaluate the model
eval_agent = EvaluationAgent(task)
results = eval_agent.evaluate(optimized_model)
for result in results:
    print(result)

To understand how to run more complex evaluation workflows, see Evaluate a model.

Step 5: Run inference with the optimized model

To run inference with the optimized model, we can use the same interface as the original model.

from pruna.engine.pruna_model import PrunaModel

# Load the optimized model
optimized_model = PrunaModel.from_hub("PrunaAI/Segmind-Vega-smashed")

optimized_model.set_progress_bar_config(disable=True)

prompt = "A serene landscape with mountains"
optimized_model(prompt).images[0].save("output.png")

Example use cases

Let’s look at some specific examples for different model types.

Example 1: Diffusion Model Optimization

from diffusers import DiffusionPipeline

from pruna import SmashConfig, smash

# Load the model
model = DiffusionPipeline.from_pretrained("segmind/Segmind-Vega")

# Create and configure SmashConfig
smash_config = SmashConfig()
smash_config["quantizer"] = "hqq_diffusers"

# Optimize the model
optimized_model = smash(model=model, smash_config=smash_config)

# Generate an image
prompt = "A serene landscape with mountains"
optimized_model(prompt).images[0].save("output.png")

Example 2: Large Language Model Optimization

from transformers import pipeline

from pruna import SmashConfig, smash

# Load the model
model_id = "NousResearch/Llama-3.2-1B"
pipe = pipeline("text-generation", model=model_id)

# Create and configure SmashConfig
smash_config = SmashConfig()
smash_config["compiler"] = "torch_compile"
smash_config["quantizer"] = "hqq"

# Optimize the model
optimized_model = smash(model=pipe.model, smash_config=smash_config)

# Use the model for generation
pipe("The best way to learn programming is", max_new_tokens=100)

Example 3: Speech Recognition Optimization

import requests
import torch
from transformers import AutoModelForSpeechSeq2Seq

from pruna import SmashConfig, smash

# Load the model
model_id = "openai/whisper-tiny"
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True).to("cuda")

# Create and configure SmashConfig
smash_config = SmashConfig()
smash_config.add_processor(model_id)  # Required for Whisper
smash_config.add_tokenizer(model_id)
smash_config["compiler"] = "c_whisper"
smash_config["batcher"] = "whisper_s2t"

# Optimize the model
optimized_model = smash(model=model, smash_config=smash_config)

# Download and transcribe audio sample
audio_url = "https://huggingface.co/datasets/reach-vb/random-audios/resolve/main/sam_altman_lex_podcast_367.flac"
audio_file = "sam_altman_lex_podcast_367.flac"

# Download audio file
response = requests.get(audio_url)
response.raise_for_status()  # Raise exception for bad status codes

# Save audio file
with open(audio_file, "wb") as f:
    f.write(response.content)

# Transcribe audio
transcription = optimized_model(audio_file)