Evaluate quality with the Evaluation Agent
This guide provides an introduction to evaluating models with pruna.
Evaluation helps you understand how compression affects your models across different dimensions - from output quality to resource requirements. This knowledge is essential for making informed decisions about which compression techniques work best for your specific needs.
Haven’t smashed a model yet? Check out the optimize guide to learn how to do that.
Basic Evaluation Workflow
pruna follows a simple workflow for evaluating model optimizations. You can use either the direct parameters approach or the Task-based approach:
Direct Parameters Workflow:
graph LR
    User -->|configures| Metrics
    User -->|configures| PrunaDataModule
    PrunaModel -->|provides predictions| EvaluationAgent
    EvaluationAgent -->|evaluates| PrunaModel
    EvaluationAgent -->|returns| D["Evaluation Results"]
    subgraph E["Evaluation Configuration"]
        PrunaDataModule
        Metrics
    end
    Metrics -->|is used by| EvaluationAgent
    PrunaDataModule -->|is used by| EvaluationAgent
    User -->|creates| EvaluationAgent
    style User fill:#bbf,stroke:#333,stroke-width:2px
    style EvaluationAgent fill:#bbf,stroke:#333,stroke-width:2px
    style PrunaDataModule fill:#bbf,stroke:#333,stroke-width:2px
    style PrunaModel fill:#bbf,stroke:#333,stroke-width:2px
    style D fill:#bbf,stroke:#333,stroke-width:2px
    style Metrics fill:#bbf,stroke:#333,stroke-width:2px
Task-based Workflow:
flowchart LR
    User -->|creates| Task
    User -->|creates| EvaluationAgent
    Task -->|defines| PrunaDataModule
    Task -->|defines| Metrics
    Task -->|is used by| EvaluationAgent
    Metrics -->|includes| B["Base Metrics"]
    Metrics -->|includes| C["Stateful Metrics"]
    PrunaModel -->|provides predictions| EvaluationAgent
    EvaluationAgent -->|evaluates| PrunaModel
    EvaluationAgent -->|returns| D["Evaluation Results"]
    User -->|configures| EvaluationAgent
    subgraph A["Metric Types"]
        B
        C
    end
    subgraph E["Task Definition"]
        Task
        PrunaDataModule
        Metrics
        A
    end
    style User fill:#bbf,stroke:#333,stroke-width:2px
    style Task fill:#bbf,stroke:#333,stroke-width:2px
    style EvaluationAgent fill:#bbf,stroke:#333,stroke-width:2px
    style PrunaDataModule fill:#bbf,stroke:#333,stroke-width:2px
    style PrunaModel fill:#bbf,stroke:#333,stroke-width:2px
    style D fill:#bbf,stroke:#333,stroke-width:2px
    style Metrics fill:#bbf,stroke:#333,stroke-width:2px
    style B fill:#f9f,stroke:#333,stroke-width:2px
    style C fill:#f9f,stroke:#333,stroke-width:2px
The implementation details and initialization options are covered in the sections below.
Evaluation Components
The pruna package provides a set of components for assessing your models. In this section, we introduce the building blocks you can use: the EvaluationAgent, the Task, the metrics, and the PrunaDataModule.
EvaluationAgent Initialization
The EvaluationAgent is the main class for evaluating model performance. It can be initialized using two approaches:
Pass request, datamodule, and device directly to the constructor:
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.data.pruna_datamodule import PrunaDataModule
eval_agent = EvaluationAgent(
request=["cmmd", "ssim"],
datamodule=PrunaDataModule.from_string('LAION256'),
device="cpu"
)
Create a Task object that encapsulates the configuration:
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.task import Task
from pruna.data.pruna_datamodule import PrunaDataModule
task = Task(
request=["cmmd", "ssim"],
datamodule=PrunaDataModule.from_string('LAION256'),
device="cpu"
)
eval_agent = EvaluationAgent(task)
Parameters
request: str | List[str | BaseMetric | StatefulMetric] - The metrics to evaluate
datamodule: PrunaDataModule - The data module containing the evaluation dataset
device: str | torch.device | None - The device to use for evaluation (defaults to the best available device)
Task
The Task class provides an alternative way to define evaluation configurations. It encapsulates the evaluation parameters and can be passed directly to the EvaluationAgent constructor.
from pruna.evaluation.task import Task
from pruna.data.pruna_datamodule import PrunaDataModule
task = Task(
request=["cmmd", "ssim"],
datamodule=PrunaDataModule.from_string('LAION256'),
device="cpu"
)
Metrics
Metrics are the core components that calculate specific performance indicators. There are two main types of metrics:
Base Metrics: These metrics compute values directly from inputs without maintaining state across batches.
Stateful Metrics: Metrics that maintain internal state and accumulate information across multiple batches. These are typically used for quality assessment.
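To make the distinction concrete, here is a minimal, purely illustrative sketch of the stateful pattern; it does not use pruna's actual metric classes, only standard Python:
# Illustrative sketch only, not pruna's metric classes: a stateful metric
# accumulates statistics across batches and produces its score at the end.
class MeanAbsoluteError:
    def __init__(self):
        self.total_error = 0.0
        self.num_samples = 0

    def update(self, predictions, targets):
        # Accumulate per-batch statistics instead of returning a score right away.
        self.total_error += sum(abs(p - t) for p, t in zip(predictions, targets))
        self.num_samples += len(predictions)

    def compute(self):
        # The final score is computed once, after all batches have been seen.
        return self.total_error / self.num_samples

metric = MeanAbsoluteError()
for predictions, targets in [([1.0, 2.0], [1.5, 2.0]), ([3.0], [2.0])]:
    metric.update(predictions, targets)
print(metric.compute())  # 0.5
A base metric, by contrast, would compute its value directly from a batch without carrying any state forward.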
The EvaluationAgent accepts metrics in three ways:
As a plain text request from predefined options (e.g., image_generation_quality)
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.data.pruna_datamodule import PrunaDataModule
eval_agent = EvaluationAgent(
request ="image_generation_quality",
datamodule=PrunaDataModule.from_string('LAION256'),
device="cpu"
)
As a list of metric names (e.g., ["clip_score", "psnr"])
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.data.pruna_datamodule import PrunaDataModule
eval_agent = EvaluationAgent(
request=["clip_score", "psnr"],
datamodule=PrunaDataModule.from_string('LAION256'),
device="cpu"
)
As a list of metric instances (e.g., CMMD()), which provides more flexibility in configuring the metrics.
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.metrics import CMMD, TorchMetricWrapper
eval_agent = EvaluationAgent(
request=[CMMD(call_type="pairwise"), TorchMetricWrapper(metric_name="clip_score")],
datamodule=PrunaDataModule.from_string('LAION256'),
device="cpu"
)
Note
You can find the full list of available metrics in the Metric Overview section.
Metric Call Types
pruna metrics can operate in both single-model and pairwise modes.
Single-Model mode: Each evaluation produces independent scores for the model being evaluated.
Pairwise mode: Metrics compare a subsequent model against the first model evaluated by the agent and produce a single comparison score.
Under the hood, the StatefulMetric class uses the call_type parameter to determine the order of the inputs. Each metric has a default call_type, but you can switch the metric to either mode regardless of that default.
from pruna.evaluation.metrics import CMMD
metric = CMMD(call_type="single") # or [CMMD() since single is the default call type]
from pruna.evaluation.metrics import CMMD
metric = CMMD(call_type="pairwise")
These high-level modes abstract away the underlying input ordering. Internally, each metric uses a more specific call_type to determine the exact order of inputs passed to the metric function.
Internal Call Types
The following table lists the supported internal call types and examples of metrics that use them. This is what’s happening under the hood when you pass call_type="single" or call_type="pairwise" to a metric.
| Call Type | Description | Example Metrics |
|---|---|---|
|  | Model’s output first, then ground truth |  |
|  | Ground truth first, then model’s output |  |
|  | Input data first, then ground truth |  |
|  | Ground truth first, then input data |  |
|  | Default ordering for pairwise mode |  |
|  | Base model’s output first, then subsequent model’s output |  |
|  | Subsequent model’s output first, then base model’s output |  |
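As a rough sketch of this dispatch, the ordering logic amounts to something like the following; the call-type strings here are hypothetical stand-ins and may not match the exact names pruna uses internally:
# Hypothetical sketch only: the call-type names are illustrative and may not
# match pruna's internal identifiers.
def order_inputs(call_type: str, model_output, ground_truth):
    """Return the two inputs in the order the underlying metric function expects."""
    if call_type == "y_gt":  # model's output first, then ground truth
        return model_output, ground_truth
    if call_type == "gt_y":  # ground truth first, then model's output
        return ground_truth, model_output
    raise ValueError(f"Unsupported call type: {call_type}")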
Metric Results
Each metric returns a MetricResult instance, which contains the outcome of a single evaluation. The MetricResult class stores the metric’s name, any associated parameters, and the computed result value:
from pruna.evaluation.metrics.result import MetricResult
# Example output
MetricResult(
name="clip_score",
params={"param1": "value1", "param2": "value2"},
result=28.0828
)
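As a hedged sketch, assuming EvaluationAgent.evaluate returns a collection of MetricResult objects (as the workflow above suggests), you can inspect the individual results like this; eval_agent and smashed_model are placeholders for objects created as in the examples in this guide:
# Sketch under the assumption that evaluate() returns MetricResult objects
# exposing the name and result attributes shown above.
results = eval_agent.evaluate(smashed_model)
for metric_result in results:
    print(f"{metric_result.name}: {metric_result.result}")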
PrunaDataModule
The PrunaDataModule is a class that defines the data you want to evaluate your model on.
Data modules are a core component of the evaluation framework, providing standardized access to datasets for evaluating model performance before and after optimization.
A more detailed overview of the PrunaDataModule, its datasets, and their corresponding collate functions can be found in the Data Module Overview section.
The EvaluationAgent accepts a PrunaDataModule in two different ways:
As a plain text request from predefined options (e.g., WikiText)
from transformers import AutoTokenizer
from pruna.data.pruna_datamodule import PrunaDataModule
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
tokenizer.pad_token = tokenizer.eos_token
# Create the data module
datamodule = PrunaDataModule.from_string(
dataset_name="WikiText",
tokenizer=tokenizer,
collate_fn_args={"max_seq_len": 512},
dataloader_args={"batch_size": 16, "num_workers": 4},
)
As a list of datasets, which provides more flexibility in configuring the data module.
from datasets import load_dataset
from transformers import AutoTokenizer
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.data.utils import split_train_into_train_val_test
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
tokenizer.pad_token = tokenizer.eos_token
# Load custom datasets
train_ds = load_dataset("SamuelYang/bookcorpus")["train"]
train_ds, val_ds, test_ds = split_train_into_train_val_test(train_ds, seed=42)
# Create the data module
datamodule = PrunaDataModule.from_datasets(
datasets=(train_ds, val_ds, test_ds),
collate_fn="text_generation_collate",
tokenizer=tokenizer,
collate_fn_args={"max_seq_len": 512},
dataloader_args={"batch_size": 16, "num_workers": 4},
)
Tip
You can find the full list of available datasets in the Dataset Overview section.
Lastly, you can limit the number of samples in each split with the PrunaDataModule.limit_datasets method.
from transformers import AutoTokenizer
from pruna.data.pruna_datamodule import PrunaDataModule
# Create the data module
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
tokenizer.pad_token = tokenizer.eos_token
datamodule = PrunaDataModule.from_string("WikiText", tokenizer=tokenizer)
# Limit all splits to 100 samples
datamodule.limit_datasets(100)
# Use different limits for each split
datamodule.limit_datasets([50, 10, 20]) # train, val, test
Evaluation Examples
The EvaluationAgent evaluates model performance and can work in both single-model and pairwise modes.
Single-Model mode: each model is evaluated independently, producing metrics that only pertain to that model’s performance. The metrics are computed from the model’s outputs without reference to any other model.
Pairwise mode: metrics compare the outputs of the current model against the first model evaluated by the agent. The first model’s outputs are cached by the EvaluationAgent and used as a reference for subsequent evaluations.
Let’s see how this works in code, first in single-model mode:
from diffusers import DiffusionPipeline
from pruna import SmashConfig, smash
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.metrics import CMMD
from pruna.evaluation.task import Task
# Set up the smash config
smash_config = SmashConfig()
smash_config["quantizer"] = "hqq_diffusers"
# Load the base model
model_path = "segmind/Segmind-Vega"
pipe = DiffusionPipeline.from_pretrained(model_path)
# Smash the model
smashed_pipe = smash(pipe, smash_config)
# Define the task and the evaluation agent
metrics = [CMMD()]
datamodule = PrunaDataModule.from_string("LAION256")
datamodule.limit_datasets(5)
task = Task(metrics, datamodule=datamodule)
eval_agent = EvaluationAgent(task)
# Optional: tweak model generation parameters for benchmarking
smashed_pipe.inference_handler.model_args.update(
{"num_inference_steps": 1, "guidance_scale": 0.0}
)
# Evaluate the smashed model; smash() already returns a PrunaModel, so it can be passed directly to the EvaluationAgent
results = eval_agent.evaluate(smashed_pipe)
print(results)
In pairwise mode, the base model is evaluated first and its cached outputs serve as the reference when the smashed model is evaluated:
import copy
from diffusers import DiffusionPipeline
from pruna import PrunaModel, SmashConfig, smash
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.metrics import CMMD
from pruna.evaluation.task import Task
# Set up the smash config
smash_config = SmashConfig()
smash_config["quantizer"] = "hqq_diffusers"
# Load the base model
model_path = "segmind/Segmind-Vega"
pipe = DiffusionPipeline.from_pretrained(model_path)
# Smash the model
copy_pipe = copy.deepcopy(pipe)
smashed_pipe = smash(copy_pipe, smash_config)
# Define the task and the evaluation agent
metrics = [CMMD(call_type="pairwise")]
datamodule = PrunaDataModule.from_string("LAION256")
datamodule.limit_datasets(5)
task = Task(metrics, datamodule=datamodule)
eval_agent = EvaluationAgent(task)
# wrap the model in a PrunaModel to use the EvaluationAgent
wrapped_pipe = PrunaModel(pipe, None)
# Optional: tweak model generation parameters for benchmarking
inference_arguments = {"num_inference_steps": 1, "guidance_scale": 0.0}
wrapped_pipe.inference_handler.model_args.update(inference_arguments)
wrapped_pipe.inference_handler.update_model(wrapped_pipe)
# Evaluate the base model first; its outputs are cached for comparison
first_results = eval_agent.evaluate(wrapped_pipe)
# Evaluate smashed model (compared against base model)
smashed_results = eval_agent.evaluate(smashed_pipe)
print(smashed_results)
EvaluationAgent Initialization Options
You can choose between the two initialization approaches shown above based on your preference and project requirements. Both approaches provide identical functionality and can be used interchangeably.
Best Practices
Start with a small dataset
When first setting up evaluation, limit the dataset size with datamodule.limit_datasets(n) to make debugging faster.
Use pairwise metrics for comparison
When comparing an optimized model against the baseline, use pairwise metrics to get direct comparison scores.
Choose your initialization style
Both direct parameters and Task-based initialization are valid approaches. Choose the one that best fits your project’s coding patterns and requirements.