Evaluate quality with the Evaluation Agent

This guide provides an introduction to evaluating models with pruna.

Evaluation helps you understand how compression affects your models across different dimensions, from output quality to resource requirements. This knowledge is essential for making informed decisions about which compression techniques work best for your specific needs.

Haven’t smashed a model yet? Check out the optimize guide to learn how to do that.

Basic Evaluation Workflow

pruna follows a simple workflow for evaluating model optimizations:

 graph LR
 User -->|creates| Task
 User -->|creates| EvaluationAgent
 Task -->|defines| PrunaDataModule
 Task -->|defines| Metrics
 Task -->|is used by| EvaluationAgent
 Metrics -->|includes| B["Base Metrics"]
 Metrics -->|includes| C["Stateful Metrics"]
 PrunaModel -->|provides predictions| EvaluationAgent
 EvaluationAgent -->|evaluates| PrunaModel
 EvaluationAgent -->|returns| D["Evaluation Results"]
 User -->|configures| EvaluationAgent

 subgraph A["Metric Types"]
     B
     C
 end

 subgraph E["Task Definition"]
     Task
     PrunaDataModule
     Metrics
     A
 end

 style User fill:#bbf,stroke:#333,stroke-width:2px
 style Task fill:#bbf,stroke:#333,stroke-width:2px
 style EvaluationAgent fill:#bbf,stroke:#333,stroke-width:2px
 style PrunaDataModule fill:#bbf,stroke:#333,stroke-width:2px
 style PrunaModel fill:#bbf,stroke:#333,stroke-width:2px
 style D fill:#bbf,stroke:#333,stroke-width:2px
 style Metrics fill:#bbf,stroke:#333,stroke-width:2px
 style B fill:#f9f,stroke:#333,stroke-width:2px
 style C fill:#f9f,stroke:#333,stroke-width:2px
    

Let’s see what that looks like in code.

from pruna import PrunaModel
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.task import Task
from pruna.data.pruna_datamodule import PrunaDataModule

# Load the optimized model
optimized_model = PrunaModel.from_pretrained("PrunaAI/opt-125m-smashed")

# Create and configure Task
task = Task(
    request=["accuracy"],
    datamodule=PrunaDataModule.from_string('WikiText'),
    device="cpu"
)

# Create and configure EvaluationAgent
eval_agent = EvaluationAgent(task)

# Evaluate the model
results = eval_agent.evaluate(optimized_model)
print(results)

Evaluation Components

The pruna package provides several components for evaluating your models. In this section, we introduce each of them.

Task

The Task is a class that defines what you want to evaluate your model on. It requires a set of Metrics and a PrunaDataModule to perform the evaluation.

Metrics

Metrics are the core components that calculate specific performance indicators. There are two main types of metrics:

  • Base Metrics: These metrics compute values directly from inputs without maintaining state across batches.

  • Stateful Metrics: Metrics that maintain internal state and accumulate information across multiple batches. These are typically used for quality assessment.

The Task accepts Metrics in three ways:

As a plain text request from predefined options (e.g., image_generation_quality)

from pruna.evaluation.task import Task
from pruna.data.pruna_datamodule import PrunaDataModule

task = Task(
    request="image_generation_quality",
    datamodule=PrunaDataModule.from_string('LAION256'),
    device="cpu"
)
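
Although only the plain text request is shown above, the Task also accepts a list of metric names (as in the accuracy example earlier) or a list of metric instances (as in the EvaluationAgent example later in this guide). Here is a minimal sketch of both, reusing the LAION256 data module:

from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.metrics import CMMD
from pruna.evaluation.task import Task

datamodule = PrunaDataModule.from_string('LAION256')

# As a list of metric names
task = Task(["cmmd"], datamodule=datamodule, device="cpu")

# As a list of metric instances
task = Task([CMMD()], datamodule=datamodule, device="cpu")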

Note

You can find the full list of available metrics in the Metric Overview section.

Metric Call Types

pruna metrics can operate in both single-model and pairwise modes.

  • Single-Model mode: Each evaluation produces independent scores for the model being evaluated.

  • Pairwise mode: Metrics compare a subsequent model against the first model evaluated by the agent and produce a single comparison score.

Under the hood, the StatefulMetric class uses the call_type parameter to determine the order of the inputs.

The following table shows the different call types supported by pruna metrics and the metrics that support each call type.

Call Type        Description                                                   Example Metrics
y_gt             Model’s output first, then ground truth                       fid, cmmd, accuracy, recall, precision
gt_y             Ground truth first, then model’s output                       fid, cmmd, accuracy, recall, precision
x_gt             Input data first, then ground truth                           clip_score
gt_x             Ground truth first, then input data                           clip_score
pairwise         Pairwise mode, defaulting to pairwise_y_gt or pairwise_gt_y   psnr, ssim, lpips, cmmd
pairwise_y_gt    Base model’s output first, then subsequent model’s output     psnr, ssim, lpips, cmmd
pairwise_gt_y    Subsequent model’s output first, then base model’s output     psnr, ssim, lpips, cmmd

Each metric has a default call_type, but you can switch a metric between single-model and pairwise mode regardless of its default.

from pruna.evaluation.metrics import CMMD

metric = CMMD()  # uses the metric's default call type
pairwise_metric = CMMD(call_type="pairwise")  # switch the metric to pairwise mode

PrunaDataModule

The PrunaDataModule is a class that defines the data you want to evaluate your model on. Data modules are a core component of the evaluation framework, providing standardized access to datasets for evaluating model performance before and after optimization.

A more detailed overview of the PrunaDataModule, its datasets and their corresponding collate functions can be found in the Data Module Overview section.

The Task accepts PrunaDataModule in two different ways:

As a plain text request from predefined options (e.g., WikiText)

from pruna.data.pruna_datamodule import PrunaDataModule
from transformers import AutoTokenizer

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Create the data Module
datamodule = PrunaDataModule.from_string(
    dataset_name='WikiText',
    tokenizer=tokenizer,
    collate_fn="text_generation_collate",
    collate_fn_args={"max_seq_len": 512},
    dataloader_args={"batch_size": 16, "num_workers": 4}
)
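
Once configured, the data module is passed to the Task just like the data modules in the earlier examples. A minimal sketch reusing the WikiText module created above (the accuracy request mirrors the basic workflow example):

from pruna.evaluation.task import Task

# Reuse the customized WikiText data module from above
task = Task(
    request=["accuracy"],
    datamodule=datamodule,
    device="cpu"
)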

Tip

You can find the full list of available datasets in the Dataset Overview section.

Lastly, you can limit the number of samples in each dataset split with the PrunaDataModule.limit_datasets method.

from pruna.data.pruna_datamodule import PrunaDataModule

# Create the data module
datamodule = PrunaDataModule.from_string('WikiText')

# Limit all splits to 100 samples
datamodule.limit_datasets(100)

# Use different limits for each split
datamodule.limit_datasets([500, 100, 200])  # train, val, test

EvaluationAgent

The EvaluationAgent is a class that evaluates the performance of your model.

To evaluate a model with the EvaluationAgent, you need to create a Task with Metrics and a PrunaDataModule. Then, initialize an EvaluationAgent with that task and call the evaluate() method with your model.

You can then choose to evaluate a single model or a pair of models.

  • Single-Model mode: each model is evaluated independently, producing metrics that only pertain to that model’s performance. The metrics are computed from the model’s outputs without reference to any other model.

  • Pairwise mode: metrics compare the outputs of the current model against the first model evaluated by the agent. The first model’s outputs are cached by the EvaluationAgent and used as a reference for subsequent evaluations.

Let’s see how this works in code.

import copy

from diffusers import StableDiffusionPipeline

from pruna import smash, SmashConfig
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.task import Task
from pruna.evaluation.metrics import CMMD

# Set up the smash config
smash_config = SmashConfig()
smash_config['cacher'] = 'deepcache'

# Load the base model
model_path = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(model_path)

# Smash the model
copy_pipe = copy.deepcopy(pipe)
smashed_pipe = smash(copy_pipe, smash_config)

# Define the task and the evaluation agent
metrics = [CMMD()]
task = Task(metrics, datamodule=PrunaDataModule.from_string('LAION256'))
eval_agent = EvaluationAgent(task)

# Evaluate the base model; all models need to be wrapped in a PrunaModel before being passed to the EvaluationAgent
first_results = eval_agent.evaluate(pipe)
print(first_results)
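
You can then evaluate the smashed model with the same agent. Because the agent caches the first model’s outputs, this second call produces the comparison scores when the metric operates in pairwise mode (a minimal continuation of the example above):

# Evaluate the smashed model against the cached outputs of the base model
smashed_results = eval_agent.evaluate(smashed_pipe)
print(smashed_results)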

Best Practices

Start with a small dataset

When first setting up evaluation, limit the dataset size with datamodule.limit_datasets(n) to make debugging faster.

Use pairwise metrics for comparison

When comparing an optimized model against the baseline, use pairwise metrics to get direct comparison scores.
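
A minimal sketch combining both tips, reusing the LAION256 data module and the CMMD metric from the examples above (the sample limit of 10 is arbitrary, and the call_type argument follows the metric call types described earlier):

from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.metrics import CMMD
from pruna.evaluation.task import Task

# Small data module for fast debugging runs
datamodule = PrunaDataModule.from_string('LAION256')
datamodule.limit_datasets(10)

# Pairwise metric: each subsequent model is compared against the first model evaluated by the agent
task = Task([CMMD(call_type="pairwise")], datamodule=datamodule, device="cpu")
eval_agent = EvaluationAgent(task)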