Customize a Metric

This guide will walk you through the process of adding a new metric to Pruna’s evaluation system.

If anything is unclear or you want to discuss your contribution before opening a PR, please reach out on Discord anytime!

If this is your first time contributing to pruna, please refer to the Setup guide for more information.

1. Choosing the right type of metric

pruna’s evaluation system supports two types of metrics, located under pruna/evaluation/metrics: BaseMetric and StatefulMetric.

These two types are designed to accommodate different use cases.

  • BaseMetric: Inherit from BaseMetric and compute values directly without maintaining state.
    • Used when isolated inference is required (e.g., latency, disk_memory, etc.)

  • StatefulMetric: Inherit from StatefulMetric and accumulate state across multiple batches.
    • Best suited for quality evaluations (e.g., accuracy, clip_score, etc.)

Note

In most cases, you should implement a StatefulMetric. BaseMetric is reserved for specialized performance measurements where shared inference would distort results.

2. Implement the metric class

Create a new file in pruna/evaluation/metrics with a descriptive name for your metric (e.g., your_new_metric.py).

We use snake_case for file names (e.g., your_new_metric.py), PascalCase for class names (e.g., YourNewMetric), and NumPy-style docstrings for documentation.

Both BaseMetric and StatefulMetric return a MetricResult object from compute(), which contains the metric name, the result value, and other metadata.
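
As a rough illustration, a MetricResult can also be constructed directly; the positional signature (name, parameters, value) mirrors the examples later in this guide, and the values below are made up:

from pruna.evaluation.metrics.result import MetricResult

# name, parameter metadata, and the computed value
result = MetricResult("your_metric_name", {"param1": "default1"}, 0.87)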

Implementing a BaseMetric

Create a new class that inherits from BaseMetric and implements the compute() method.

Your metric should define a metric_name attribute and a higher_is_better attribute; higher_is_better is a boolean indicating whether a higher metric value is better.

compute() takes two parameters: model and dataloader.

Inside compute(), you are responsible for running inference manually.

Your method should return a MetricResult object with the metric name, result value and other metadata. The result value should be a float or int.

from pruna.evaluation.metrics.metric_base import BaseMetric
from pruna.evaluation.metrics.result import MetricResult

class YourNewMetric(BaseMetric):
    '''Your metric description'''

    metric_name = "your_metric_name"
    higher_is_better = True # or False

    def __init__(self):
        super().__init__()
        # Initialize any parameters your metric needs

    def compute(self, model, dataloader):
        '''Run inference on the model and compute the metric value.'''

        # run_inference and some_calculation are placeholders for your own inference and metric logic
        outputs = run_inference(model, dataloader)
        result = some_calculation(outputs)
        params = self.__dict__.copy() # or any metadata you prefer
        return MetricResult(self.metric_name, params, result)
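
To make this concrete, here is a minimal, illustrative sketch of a BaseMetric that times inference per batch. The class name InferenceTimeMetric, the call model(batch), and the batch format are assumptions for illustration, not part of pruna’s API:

import time

from pruna.evaluation.metrics.metric_base import BaseMetric
from pruna.evaluation.metrics.result import MetricResult

class InferenceTimeMetric(BaseMetric):
    '''Average wall-clock time per batch (illustrative only).'''

    metric_name = "inference_time"
    higher_is_better = False # lower latency is better

    def compute(self, model, dataloader):
        '''Time the model on every batch and average the timings.'''
        timings = []
        for batch in dataloader:
            start = time.perf_counter()
            model(batch) # assumption: the model is callable on a raw batch
            timings.append(time.perf_counter() - start)
        value = sum(timings) / max(len(timings), 1)
        return MetricResult(self.metric_name, self.__dict__.copy(), value)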

Implementing a StatefulMetric

To implement a StatefulMetric, create a class that inherits from StatefulMetric. These metrics are designed to accumulate state across multiple batches and can share inference with other metrics.

Your metric should define a metric_name attribute and a higher_is_better attribute; higher_is_better is a boolean indicating whether a higher metric value is better.

Use the add_state() method to define internal state variables that will accumulate data across batches. For example, you might track totals and counts to compute an average.

The update() method processes each batch of data, updating the state variables based on the current batch. It takes three parameters: inputs, ground_truths and predictions.

The compute() method is called after all batches are processed and returns a MetricResult object, which contains the final metric value calculated from the accumulated state.

Metrics can operate in both single-model and pairwise modes, determined by the call_type parameter. Common call_types include y_gt, gt_y, x_gt, gt_x, pairwise_y_gt, and pairwise_gt_y. For more details, see the Understanding Call Types section.

Once you have implemented your metric, you can switch its mode regardless of the default call_type simply by passing single or pairwise to the call_type parameter of the StatefulMetric constructor.

Here’s a complete example implementing a StatefulMetric with a single call_type showing all required methods:

from pruna.evaluation.metrics.metric_stateful import StatefulMetric
from pruna.evaluation.metrics.result import MetricResult
# for pairwise metrics, you would need to change these imports to their pairwise counterparts
from pruna.evaluation.metrics.utils import SINGLE, get_call_type_for_single_metric, metric_data_processor
import torch

class YourNewStatefulMetric(StatefulMetric):
    '''Your metric description'''

    default_call_type = "y_gt"
    metric_name = "your_metric_name"
    higher_is_better = True # or False

    # Since we picked a single call_type as the default, we can use SINGLE as the default value
    def __init__(self, param1='default1', param2='default2', call_type=SINGLE):
        super().__init__()
        self.param1 = param1
        self.param2 = param2
        self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type) # Call the correct helper function to get the correct call_type

        # Initialize state variables
        self.add_state("total", torch.zeros(1))
        self.add_state("count", torch.zeros(1))

    def update(self, inputs, ground_truths, predictions):
        # Update the state variables based on the current batch
        # Pass the inputs, ground_truths and predictions and the call_type to the metric_data_processor to get the data in the correct format
        metric_data = metric_data_processor(inputs, ground_truths, predictions, self.call_type)
        batch_result = some_calculation(*metric_data)
        self.total += batch_result
        self.count += 1

    def compute(self):
        # Compute the final metric value from the accumulated state,
        # returning a MetricResult even when no batches were processed
        if self.count == 0:
            return MetricResult(self.metric_name, self.__dict__.copy(), 0.0)
        return MetricResult(self.metric_name, self.__dict__.copy(), (self.total / self.count).item())
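
For a quick sanity check during development, you can drive the metric by hand with placeholder tensors; normally the EvaluationAgent calls update() and compute() for you, and this sketch assumes some_calculation above has been replaced with real logic:

import torch

metric = YourNewStatefulMetric()

# feed a couple of placeholder batches, then read the aggregated result
for _ in range(2):
    metric.update(torch.randn(4, 3), torch.randn(4, 3), torch.randn(4, 3))
print(metric.compute())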

Understanding Call Types

pruna metrics can operate in both single-model and pairwise modes:

  • Single-model mode: Each evaluation produces independent scores for the model being evaluated.

  • Pairwise mode: Metrics compare a subsequent model against the first model evaluated by the agent and produce a single comparison score.

Call Type       Description
y_gt            Model’s output first, then ground truth
gt_y            Ground truth first, then model’s output
x_gt            Input data first, then ground truth
gt_x            Ground truth first, then input data
pairwise_y_gt   Base model’s output first, then subsequent model’s output
pairwise_gt_y   Subsequent model’s output first, then base model’s output

You need to decide on the default call_type based on the metric you are implementing.

For example, if you are implementing a metric that compares two models, you should use the pairwise_y_gt call type. Examples from pruna include psnr, ssim, lpips.

If you are implementing an alignment metric comparing the model’s output with the input, you should use the x_gt or gt_x call type. Examples from pruna include clip_score.

If you are implementing a metric that compares the model’s output with the ground truth, you should use the y_gt or gt_y call type. Examples from pruna include fid, cmmd, accuracy, recall, precision.
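
As an illustration of how the call_type controls argument ordering, the snippet below passes placeholder tensors through metric_data_processor (the exact return format is an assumption based on the update() example above):

import torch

from pruna.evaluation.metrics.utils import metric_data_processor

inputs = torch.randn(4, 3)
ground_truths = torch.randn(4, 3)
predictions = torch.randn(4, 3)

# with "gt_y", metric_data is expected to hold the ground truths before the model outputs
metric_data = metric_data_processor(inputs, ground_truths, predictions, "gt_y")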

You may want to switch the mode of the metric regardless of its default call_type. For instance, you may want to use fid in pairwise mode to get a single comparison score for two models.

In this case, you can pass pairwise to the call_type parameter of the StatefulMetric constructor.

from pruna.evaluation.metrics.your_metric_file import YourNewStatefulMetric

# Initialize your metric in pairwise mode
YourNewStatefulMetric(param1='value1', param2='value2', call_type="pairwise")

If you have implemented your metric using the correct call-type helper (get_call_type_for_single_metric in the example above) and the metric_data_processor function, this will work as expected.

3. Register the metric

After implementing your metric, you need to register it with Pruna’s MetricRegistry system.

The simplest way to do this is with the @MetricRegistry.register decorator:

from pruna.evaluation.metrics.registry import MetricRegistry
from pruna.evaluation.metrics.metric_stateful import StatefulMetric

@MetricRegistry.register("your_metric_name")
class YourNewMetric(StatefulMetric):
    def __init__(self, param1='default1', param2='default2'): # Don't forget to add default values for your parameters!
        super().__init__()
        self.param1 = param1
        self.param2 = param2
        self.metric_name = "your_metric_name"

Thanks to this registry system, everyone using pruna can now refer to your metric by name without having to create instances directly!

from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.metrics.your_metric_file import YourNewMetric
from pruna.evaluation.task import Task

# Classic way: initialize your metric directly from the class
YourNewMetric(param1='value1', param2='value2')

metrics = [
    'your_metric_name'
]

# Now you can create a task with your metric from the metric name.
task = Task(request=metrics, data_module=PrunaDataModule.from_string('LAION256'))

One important thing: the registration happens when your module is imported. To ensure your metric is always available, we suggest importing it in the pruna/evaluation/metrics/__init__.py file.
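
For example, a one-line import in that file is enough to trigger registration (your_metric_file is a placeholder for your actual module name):

# pruna/evaluation/metrics/__init__.py
from pruna.evaluation.metrics.your_metric_file import YourNewMetric  # noqa: F401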

4. Add tests and update the documentation

Create tests in pruna/tests/evaluation for your metric to ensure it works correctly.

Add documentation for your new metric in the user manual docs/user_manual/evaluation.rst, including examples of how to use it.

By following these steps, you’ll help expand Pruna’s capabilities and contribute to the project’s success.

Using your new metric

Once you’ve implemented your metric, everyone can use it in Pruna’s evaluation pipeline! Here’s how:

from diffusers import StableDiffusionPipeline

from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.metrics.your_metric_file import YourNewMetric # importing the module registers the metric
from pruna.evaluation.task import Task

model_path = "CompVis/stable-diffusion-v1-4"
model = StableDiffusionPipeline.from_pretrained(model_path)

metrics = [
    'clip_score',
    'your_metric_name'
]

task = Task(request=metrics, data_module=PrunaDataModule.from_string('LAION256'))

eval_agent = EvaluationAgent(task=task)

results = eval_agent.evaluate(model)
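
Assuming results holds MetricResult objects like those returned by compute() above, a simple loop is enough to inspect them:

# print each metric result returned by the agent
for metric_result in results:
    print(metric_result)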