Customize a Metric
This guide will walk you through the process of adding a new metric to Pruna’s evaluation system.
If anything is unclear or you want to discuss your contribution before opening a PR, please reach out on Discord anytime!
If this is your first time contributing to pruna, please refer to the Setup guide for more information.
1. Choosing the right type of metric
pruna’s evaluation system supports two types of metrics, located under pruna/evaluation/metrics: BaseMetric and StatefulMetric.
These two types are designed to accommodate different use cases.
- BaseMetric: Inherit from BaseMetric and compute values directly without maintaining state. Used when isolated inference is required (e.g., latency, disk_memory, etc.)
- StatefulMetric: Inherit from StatefulMetric and accumulate state across multiple batches. Best suited for quality evaluations (e.g., accuracy, clip_score, etc.)
Note
In most cases, you should implement a StatefulMetric. BaseMetric is reserved for specialized performance measurements where shared inference would distort results.
2. Implement the metric class
Create a new file in pruna/evaluation/metrics with a descriptive name for your metric (e.g., your_new_metric.py).
We use snake_case for file names (e.g., your_new_metric.py), PascalCase for class names (e.g., YourNewMetric), and NumPy-style docstrings for documentation.
Both BaseMetric and StatefulMetric return a MetricResult object, which contains the metric name, the result value, and other metadata.
Implementing a BaseMetric
Create a new class that inherits from BaseMetric and implements the compute() method.
Your metric should have a metric_name attribute and a higher_is_better attribute, a boolean that indicates whether a higher metric value is better.
compute() takes two parameters: model and dataloader. Inside compute(), you are responsible for running inference manually.
Your method should return a MetricResult object with the metric name, the result value, and other metadata. The result value should be a float or an int.
from pruna.evaluation.metrics.metric_base import BaseMetric
from pruna.evaluation.metrics.metric_result import MetricResult

class YourNewMetric(BaseMetric):
    '''Your metric description.'''

    metric_name = "your_metric_name"
    higher_is_better = True  # or False

    def __init__(self):
        super().__init__()
        # Initialize any parameters your metric needs

    def compute(self, model, dataloader):
        '''Run inference on the model and compute the metric value.'''
        outputs = run_inference(model, dataloader)  # placeholder: run inference yourself
        result = some_calculation(outputs)          # placeholder: your metric logic
        params = self.__dict__.copy()  # or any metadata you prefer
        return MetricResult(self.metric_name, params, result)
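To sanity-check your metric before wiring it into the evaluation pipeline, you can call compute() directly. A minimal sketch, assuming you already have a model and a dataloader prepared:

metric = YourNewMetric()
result = metric.compute(model, dataloader)  # assumes model and dataloader are already defined
print(result)  # the MetricResult carrying the metric name, params and value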
Implementing a StatefulMetric
To implement a StatefulMetric, create a class that inherits from StatefulMetric. These metrics are designed to accumulate state across multiple batches and can share inference with other metrics.
Your metric should have a metric_name attribute and a higher_is_better attribute, a boolean that indicates whether a higher metric value is better.
Use the add_state() method to define internal state variables that accumulate data across batches. For example, you might track totals and counts to compute an average.
The update() method processes each batch of data, updating the state variables based on the current batch. It takes three parameters: inputs, ground_truths, and predictions.
The compute() method is called after all batches are processed and returns a MetricResult object containing the final metric value calculated from the accumulated state.
Metrics can operate in both single-model and pairwise modes, determined by the call_type parameter. Common call types include y_gt, gt_y, x_gt, gt_x, pairwise_y_gt, and pairwise_gt_y. For more details, see the Understanding Call Types section below.
Once you have implemented your metric, you can switch its mode regardless of its default call_type, simply by passing single or pairwise to the call_type parameter of the StatefulMetric constructor.
Here’s a complete example implementing a StatefulMetric with a single call_type, showing all required methods:
from pruna.evaluation.metrics.metric_stateful import StatefulMetric
from pruna.evaluation.metrics.result import MetricResult
# For pairwise metrics, swap these helpers for their pairwise counterparts
from pruna.evaluation.metrics.utils import SINGLE, get_call_type_for_single_metric, metric_data_processor
import torch

class YourNewStatefulMetric(StatefulMetric):
    '''Your metric description.'''

    default_call_type = "y_gt"
    metric_name = "your_metric_name"
    higher_is_better = True  # or False

    def __init__(self, param1='default1', param2='default2', call_type=SINGLE):
        # Since we picked a single call_type as the default, SINGLE is a sensible default mode
        super().__init__()
        self.param1 = param1
        self.param2 = param2
        # Call the correct helper function to resolve the effective call_type
        self.call_type = get_call_type_for_single_metric(call_type, self.default_call_type)
        # Initialize state variables
        self.add_state("total", torch.zeros(1))
        self.add_state("count", torch.zeros(1))

    def update(self, inputs, ground_truths, predictions):
        # Pass the inputs, ground_truths, predictions and the call_type to
        # metric_data_processor to get the data in the correct order
        metric_data = metric_data_processor(inputs, ground_truths, predictions, self.call_type)
        batch_result = some_calculation(*metric_data)  # placeholder: your metric logic
        self.total += batch_result
        self.count += 1

    def compute(self):
        # Compute the final metric value from the accumulated state
        if self.count == 0:
            return 0
        return MetricResult(self.metric_name, self.__dict__.copy(), self.total / self.count)
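Outside of the evaluation pipeline, the lifecycle of a stateful metric is update() once per batch followed by a single compute(). A minimal sketch, assuming a hypothetical iterable of batches; in practice the EvaluationAgent drives these calls for you:

metric = YourNewStatefulMetric()
for inputs, ground_truths, predictions in batches:  # hypothetical batch iterable
    metric.update(inputs, ground_truths, predictions)
result = metric.compute()  # MetricResult built from the accumulated state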
Understanding Call Types
pruna metrics can operate in both single-model and pairwise modes:
- Single-model mode: Each evaluation produces independent scores for the model being evaluated.
- Pairwise mode: Metrics compare a subsequent model against the first model evaluated by the agent and produce a single comparison score.
Call Type | Description
---|---
y_gt | Model’s output first, then ground truth
gt_y | Ground truth first, then model’s output
x_gt | Input data first, then ground truth
gt_x | Ground truth first, then input data
pairwise_y_gt | Base model’s output first, then subsequent model’s output
pairwise_gt_y | Subsequent model’s output first, then base model’s output
You need to decide on the default call_type based on the metric you are implementing.
For example, if you are implementing a metric that compares two models, you should use the pairwise_y_gt call type. Examples from pruna include psnr, ssim, and lpips.
If you are implementing an alignment metric that compares the model’s output with the input, you should use the x_gt or gt_x call type. Examples from pruna include clip_score.
If you are implementing a metric that compares the model’s output with the ground truth, you should use the y_gt or gt_y call type. Examples from pruna include fid, cmmd, accuracy, recall, and precision.
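To make the orderings concrete, here is a toy sketch of the argument order each single-model call type implies. This is not pruna’s actual metric_data_processor, just an illustration of the table above:

# Toy illustration (not pruna's implementation) of single-model call type ordering
def toy_order(inputs, ground_truths, predictions, call_type):
    orders = {
        "y_gt": (predictions, ground_truths),  # model's output first, then ground truth
        "gt_y": (ground_truths, predictions),  # ground truth first, then model's output
        "x_gt": (inputs, ground_truths),       # input data first, then ground truth
        "gt_x": (ground_truths, inputs),       # ground truth first, then input data
    }
    return orders[call_type]

first, second = toy_order(inputs=["a photo"], ground_truths=["cat"], predictions=["dog"], call_type="y_gt")
# first == ["dog"] (model output), second == ["cat"] (ground truth)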
You may also want to switch the mode of a metric away from its default call_type. For instance, you may want to use fid in pairwise mode to get a single comparison score for two models. In this case, pass pairwise to the call_type parameter of the StatefulMetric constructor.
from pruna.evaluation.metrics.your_metric_file import YourNewStatefulMetric

# Initialize your metric, overriding the default mode
YourNewStatefulMetric(param1='value1', param2='value2', call_type="pairwise")
If you have implemented your metric using the correct call-type helper (here, get_call_type_for_single_metric) and the metric_data_processor function, this will work as expected.
3. Register the metric
After implementing your metric, you need to register it with Pruna’s MetricRegistry system. The simplest way to do this is with the @MetricRegistry.register decorator:
from pruna.evaluation.metrics.registry import MetricRegistry
from pruna.evaluation.metrics.metric_stateful import StatefulMetric

@MetricRegistry.register("your_metric_name")
class YourNewMetric(StatefulMetric):
    def __init__(self, param1='default1', param2='default2'):  # Don't forget default values for your parameters!
        super().__init__()
        self.param1 = param1
        self.param2 = param2
        self.metric_name = "your_metric_name"
Thanks to this registry system, everyone using pruna can now refer to your metric by name without having to create instances directly!
from pruna.evaluation.metrics.your_metric_file import YourNewMetric

# Classic way: initialize your metric directly from the class
YourNewMetric(param1='value1', param2='value2')

# Registry way: create a task with your metric from the metric name
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.task import Task

metrics = ['your_metric_name']
task = Task(request=metrics, data_module=PrunaDataModule.from_string('LAION256'))
One important thing: the registration happens when your module is imported. To ensure your metric is always available, we suggest importing it in the pruna/evaluation/metrics/__init__.py file.
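For example, assuming your file is named your_new_metric.py, the import in pruna/evaluation/metrics/__init__.py would look like this:

# pruna/evaluation/metrics/__init__.py
from pruna.evaluation.metrics.your_new_metric import YourNewMetric  # registration happens on import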
4. Add tests and update the documentation
Create tests for your metric in pruna/tests/evaluation to ensure it works correctly.
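As a starting point, here is a minimal pytest sketch for the stateful example above; the module name and the toy batch are hypothetical, so adapt them to your metric:

import torch
from pruna.evaluation.metrics.your_new_metric import YourNewStatefulMetric  # hypothetical module name

def test_update_accumulates_state():
    metric = YourNewStatefulMetric()
    # Hypothetical toy batch; shapes and content depend on your metric
    inputs = torch.zeros(2)
    ground_truths = torch.ones(2)
    predictions = torch.ones(2)
    metric.update(inputs, ground_truths, predictions)
    assert metric.count == 1  # exactly one batch has been accumulated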
Add documentation for your new metric in the user manual (docs/user_manual/evaluation.rst), including examples of how to use it.
By following these steps, you’ll help expand Pruna’s capabilities and contribute to the project’s success.
Using your new metric
Once you’ve implemented your metric, everyone can use it in Pruna’s evaluation pipeline! Here’s how:
from diffusers import StableDiffusionPipeline

from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.task import Task
# Importing your metric's module ensures it is registered by name
from pruna.evaluation.metrics.your_metric_file import YourNewMetric

model_path = "CompVis/stable-diffusion-v1-4"
model = StableDiffusionPipeline.from_pretrained(model_path)

metrics = [
    'clip_score',
    'your_metric_name',
]

task = Task(request=metrics, data_module=PrunaDataModule.from_string('LAION256'))
eval_agent = EvaluationAgent(task=task)
results = eval_agent.evaluate(model)
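evaluate() returns one result per requested metric. A simple way to inspect them, assuming MetricResult has a readable repr:

for result in results:
    print(result)  # e.g., the metric name and its computed value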