Evaluate quality with the Evaluation Agent
This guide provides an introduction to evaluating models with pruna.
Evaluation helps you understand how compression affects your models across different dimensions - from output quality to resource requirements. This knowledge is essential for making informed decisions about which compression techniques work best for your specific needs.
Haven’t smashed a model yet? Check out the optimize guide to learn how to do that.
Basic Evaluation Workflow
pruna follows a simple workflow for evaluating model optimizations. You can use either the direct parameters approach or the Task-based approach:
Direct Parameters Workflow:
graph LR
User -->|configures| Metrics
User -->|configures| PrunaDataModule
PrunaModel -->|provides predictions| EvaluationAgent
EvaluationAgent -->|evaluates| PrunaModel
EvaluationAgent -->|returns| D["Evaluation Results"]
subgraph E["Evaluation Configuration"]
PrunaDataModule
Metrics
end
Metrics-->|is used by| EvaluationAgent
PrunaDataModule -->|is used by| EvaluationAgent
User -->|creates| EvaluationAgent
style User fill:#bbf,stroke:#333,stroke-width:2px
style EvaluationAgent fill:#bbf,stroke:#333,stroke-width:2px
style PrunaDataModule fill:#bbf,stroke:#333,stroke-width:2px
style PrunaModel fill:#bbf,stroke:#333,stroke-width:2px
style D fill:#bbf,stroke:#333,stroke-width:2px
style Metrics fill:#bbf,stroke:#333,stroke-width:2px
Task-based Workflow:
flowchart LR
User -->|creates| Task
User -->|creates| EvaluationAgent
Task -->|defines| PrunaDataModule
Task -->|defines| Metrics
Task -->|is used by| EvaluationAgent
Metrics -->|includes| B["Base Metrics"]
Metrics -->|includes| C["Stateful Metrics"]
PrunaModel -->|provides predictions| EvaluationAgent
EvaluationAgent -->|evaluates| PrunaModel
EvaluationAgent -->|returns| D["Evaluation Results"]
User -->|configures| EvaluationAgent
subgraph A["Metric Types"]
B
C
end
subgraph E["Task Definition"]
Task
PrunaDataModule
Metrics
A
end
style User fill:#bbf,stroke:#333,stroke-width:2px
style Task fill:#bbf,stroke:#333,stroke-width:2px
style EvaluationAgent fill:#bbf,stroke:#333,stroke-width:2px
style PrunaDataModule fill:#bbf,stroke:#333,stroke-width:2px
style PrunaModel fill:#bbf,stroke:#333,stroke-width:2px
style D fill:#bbf,stroke:#333,stroke-width:2px
style Metrics fill:#bbf,stroke:#333,stroke-width:2px
style B fill:#f9f,stroke:#333,stroke-width:2px
style C fill:#f9f,stroke:#333,stroke-width:2px
The implementation details and initialization options are covered in the sections below.
Evaluation Components
The pruna package provides a set of components for evaluating your models. In this section, we introduce each of them: the EvaluationAgent, the Task, the available metrics, and the PrunaDataModule.
EvaluationAgent Initialization
The EvaluationAgent is the main class for evaluating model performance. It can be initialized using two approaches:
Pass request, datamodule, and device directly to the constructor:
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.data.pruna_datamodule import PrunaDataModule
eval_agent = EvaluationAgent(
request=["cmmd", "ssim"],
datamodule=PrunaDataModule.from_string('LAION256'),
device="cpu"
)
Create a Task object that encapsulates the configuration:
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.task import Task
from pruna.data.pruna_datamodule import PrunaDataModule
task = Task(
request=["cmmd", "ssim"],
datamodule=PrunaDataModule.from_string('LAION256'),
device="cpu"
)
eval_agent = EvaluationAgent(task)
Parameters
request: str | List[str | BaseMetric | StatefulMetric] - The metrics to evaluate
datamodule: PrunaDataModule - The data module containing the evaluation dataset
device: str | torch.device | None - The device to use for evaluation (defaults to the best available device)
Task
The Task class provides an alternative way to define evaluation configurations. It encapsulates the evaluation parameters and can be passed directly to the EvaluationAgent constructor.
from pruna.evaluation.task import Task
from pruna.data.pruna_datamodule import PrunaDataModule
task = Task(
request=["cmmd", "ssim"],
datamodule=PrunaDataModule.from_string('LAION256'),
device="cpu"
)
Metrics
Metrics are the core components that calculate specific performance indicators. There are two main types of metrics:
Base Metrics: These metrics compute values directly from inputs without maintaining state across batches.
Stateful Metrics: Metrics that maintain internal state and accumulate information across multiple batches. These are typically used for quality assessment.
The EvaluationAgent accepts Metrics in three ways:
As a plain text request from predefined options (e.g., image_generation_quality)
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.data.pruna_datamodule import PrunaDataModule
eval_agent = EvaluationAgent(
request="image_generation_quality",
datamodule=PrunaDataModule.from_string('LAION256'),
device="cpu"
)
As a list of metric names (e.g., ["clip_score", "psnr"])
from pruna.evaluation.task import Task
from pruna.data.pruna_datamodule import PrunaDataModule
task = Task(
request=["clip_score", "psnr"],
datamodule=PrunaDataModule.from_string('LAION256'),
device="cpu"
)
As a list of metric instances (e.g., CMMD()), which provides more flexibility in configuring the metrics.
from pruna.evaluation.task import Task
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.metrics import CMMD, TorchMetricWrapper
task = Task(
request=[CMMD(call_type="pairwise"), TorchMetricWrapper(metric_name="clip_score")],
datamodule=PrunaDataModule.from_string('LAION256'),
device="cpu"
)
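Since the request type is str | List[str | BaseMetric | StatefulMetric], you can in principle also mix metric names and configured metric instances in a single request. A minimal sketch, reusing only names from the examples above:
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.metrics import CMMD
# Mix a metric name (resolved internally) with a configured metric instance
eval_agent = EvaluationAgent(
    request=["clip_score", CMMD(call_type="pairwise")],
    datamodule=PrunaDataModule.from_string('LAION256'),
    device="cpu",
)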
Note
You can find the full list of available metrics in the Metric Overview section.
Metric Call Types
Stateful metrics can generally be evaluated in single-model and pairwise modes. Single-model mode compares a model against ground-truth data, while pairwise mode compares the fidelity of one model against another.
Single-Model mode: Each evaluation produces independent scores for the model being evaluated. IQA metrics are only supported in single-model mode.
Pairwise mode: Metrics compare a subsequent model against the first model evaluated by the agent and produce a single comparison score.
Under the hood, the StatefulMetric class uses the call_type parameter to determine the order of the inputs.
Each metric has a default call_type, but you can override it to switch the metric’s mode.
from pruna.evaluation.metrics import CMMD
metric = CMMD(call_type="single") # or [CMMD() since single is the default call type]
from pruna.evaluation.metrics import CMMD
metric = CMMD(call_type="pairwise")
These high-level modes abstract away the underlying input ordering. Internally, each metric uses a more specific call_type to determine the exact order of inputs passed to the metric function.
Internal Call Types
The following table lists the supported internal call types and examples of metrics using them.
This is what’s happening under the hood when you pass call_type="single" or call_type="pairwise" to a metric.
| Call Type | Description | Example Metrics |
|---|---|---|
|  | Model’s output first, then ground truth |  |
|  | Ground truth first, then model’s output |  |
|  | Input data first, then ground truth |  |
|  | Ground truth first, then input data |  |
|  | Default call type used in pairwise mode |  |
|  | Base model’s output first, then subsequent model’s output |  |
|  | Subsequent model’s output first, then base model’s output |  |
|  | Only the model’s output is used; the metric has an internal dataset |  |
Metric Results
Each metric evaluation returns a MetricResult instance, which contains the outcome of a single evaluation.
The MetricResult class stores the metric’s name, any associated parameters, and the computed result value:
from pruna.evaluation.metrics.result import MetricResult
# Example output
MetricResult(
name="clip_score",
params={"param1": "value1", "param2": "value2"},
result=28.0828
)
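Assuming evaluate() returns one MetricResult per requested metric, the scores can be read roughly like this (the attribute names follow the example above, and eval_agent and smashed_pipe are assumed to be set up as in the evaluation examples later in this guide):
# Sketch: inspecting the results returned by EvaluationAgent.evaluate()
results = eval_agent.evaluate(smashed_pipe)
for metric_result in results:
    print(f"{metric_result.name}: {metric_result.result}")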
PrunaDataModule
The PrunaDataModule is a class that defines the data you want to evaluate your model on.
Data modules are a core component of the evaluation framework, providing standardized access to datasets for evaluating model performance before and after optimization.
A more detailed overview of the PrunaDataModule, its datasets and their corresponding collate functions can be found in the Data Module Overview section.
The EvaluationAgent accepts PrunaDataModule in two different ways:
As a plain text request from predefined options (e.g., WikiText)
from transformers import AutoTokenizer
from pruna.data.pruna_datamodule import PrunaDataModule
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
tokenizer.pad_token = tokenizer.eos_token
# Create the data module
datamodule = PrunaDataModule.from_string(
dataset_name="WikiText",
tokenizer=tokenizer,
collate_fn_args={"max_seq_len": 512},
dataloader_args={"batch_size": 16, "num_workers": 4},
)
As a list of datasets, which provides more flexibility in configuring the data module.
from datasets import load_dataset
from transformers import AutoTokenizer
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.data.utils import split_train_into_train_val_test
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
tokenizer.pad_token = tokenizer.eos_token
# Load custom datasets
train_ds = load_dataset("SamuelYang/bookcorpus")["train"]
train_ds, val_ds, test_ds = split_train_into_train_val_test(train_ds, seed=42)
# Create the data module
datamodule = PrunaDataModule.from_datasets(
datasets=(train_ds, val_ds, test_ds),
collate_fn="text_generation_collate",
tokenizer=tokenizer,
collate_fn_args={"max_seq_len": 512},
dataloader_args={"batch_size": 16, "num_workers": 4},
)
Tip
You can find the full list of available datasets in the Dataset Overview section.
Lastly, you can limit the number of samples in the dataset by using the PrunaDataModule.limit_datasets method.
from transformers import AutoTokenizer
from pruna.data.pruna_datamodule import PrunaDataModule
# Create the data module
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
tokenizer.pad_token = tokenizer.eos_token
datamodule = PrunaDataModule.from_string("WikiText", tokenizer=tokenizer)
# Limit all splits to 100 samples
datamodule.limit_datasets(100)
# Use different limits for each split
datamodule.limit_datasets([50, 10, 20]) # train, val, test
Evaluation Examples
The EvaluationAgent evaluates model performance and can work in both single-model and pairwise modes.
Single-Model mode: each model is evaluated independently, producing metrics that only pertain to that model’s performance. The metrics are computed from the model’s outputs without reference to any other model.
Pairwise mode: metrics compare the outputs of the current model against the first model evaluated by the agent. The first model’s outputs are cached by the EvaluationAgent and used as a reference for subsequent evaluations.
Let’s see how this works in code.
from diffusers import DiffusionPipeline
from pruna import SmashConfig, smash
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.metrics import CMMD
from pruna.evaluation.task import Task
# Load data and set up smash config
smash_config = SmashConfig()
smash_config["quantizer"] = "hqq_diffusers"
# Load the base model
model_path = "segmind/Segmind-Vega"
pipe = DiffusionPipeline.from_pretrained(model_path)
# Smash the model
smashed_pipe = smash(pipe, smash_config)
# Define the task and the evaluation agent
metrics = [CMMD()]
datamodule = PrunaDataModule.from_string("LAION256")
datamodule.limit_datasets(5)
task = Task(metrics, datamodule=datamodule)
eval_agent = EvaluationAgent(task)
# Optional: tweak model generation parameters for benchmarking
smashed_pipe.inference_handler.model_args.update(
{"num_inference_steps": 1, "guidance_scale": 0.0}
)
# Evaluate the smashed model; smash() already returns the model wrapped in a PrunaModel,
# which is required before passing it to the EvaluationAgent
smashed_results = eval_agent.evaluate(smashed_pipe)
print(smashed_results)
import copy
from diffusers import DiffusionPipeline
from pruna import PrunaModel, SmashConfig, smash
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.metrics import CMMD
from pruna.evaluation.task import Task
# Load data and set up smash config
smash_config = SmashConfig()
smash_config["quantizer"] = "hqq_diffusers"
# Load the base model
model_path = "segmind/Segmind-Vega"
pipe = DiffusionPipeline.from_pretrained(model_path)
# Smash the model
copy_pipe = copy.deepcopy(pipe)
smashed_pipe = smash(copy_pipe, smash_config)
# Define the task and the evaluation agent
metrics = [CMMD(call_type="pairwise")]
datamodule = PrunaDataModule.from_string("LAION256")
datamodule.limit_datasets(5)
task = Task(metrics, datamodule=datamodule)
eval_agent = EvaluationAgent(task)
# wrap the model in a PrunaModel to use the EvaluationAgent
wrapped_pipe = PrunaModel(pipe, None)
# Optional: tweak model generation parameters for benchmarking
inference_arguments = {"num_inference_steps": 1, "guidance_scale": 0.0}
wrapped_pipe.inference_handler.model_args.update(inference_arguments)
# Evaluate base model first (cached for comparison)
first_results = eval_agent.evaluate(wrapped_pipe)
# Evaluate smashed model (compared against base model)
smashed_results = eval_agent.evaluate(smashed_pipe)
print(smashed_results)
EvaluationAgent Initialization Options
You can choose between the two initialization approaches shown above based on your preference and project requirements. Both approaches provide identical functionality and can be used interchangeably.
Best Practices
Start with a small dataset
When first setting up evaluation, limit the dataset size with datamodule.limit_datasets(n) to make debugging faster.
Use pairwise metrics for comparison
When comparing an optimized model against the baseline, use pairwise metrics to get direct comparison scores.
Choose your initialization style
Both direct parameters and Task-based initialization are valid approaches. Choose the one that best fits your project’s coding patterns and requirements.
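Putting these practices together, a minimal comparison setup might look like the following sketch (model loading and smashing are omitted; see the evaluation examples above):
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.metrics import CMMD
from pruna.evaluation.task import Task
# Small dataset for fast debugging, pairwise metric for a direct comparison score
datamodule = PrunaDataModule.from_string("LAION256")
datamodule.limit_datasets(5)
task = Task(request=[CMMD(call_type="pairwise")], datamodule=datamodule, device="cpu")
eval_agent = EvaluationAgent(task)
# Evaluate the baseline first (its outputs are cached), then the smashed model:
# base_results = eval_agent.evaluate(base_pipe)        # base_pipe: your unmodified model
# smashed_results = eval_agent.evaluate(smashed_pipe)  # smashed_pipe: returned by smash()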