Adding a Metric
This guide will walk you through the process of adding a new metric to Pruna’s evaluation system.
If anything is unclear or you want to discuss your contribution before opening a PR, please reach out on Discord anytime!
If this is your first time contributing to pruna, please refer to the Setup guide for more information.
Understanding Pruna’s Metric System
pruna has two main types of metrics, which live under pruna/evaluation/metrics:
- Base Metrics - Inherit from BaseMetric and compute values directly without maintaining state. These metrics usually require isolated inference computation. Examples: GPUMemoryMetric, ElapsedTimeMetric.
- Stateful Metrics - Inherit from StatefulMetric and maintain internal state across multiple computations. State here refers to the information that is accumulated across multiple batches. Examples: all metrics under TorchMetricWrapper, such as Accuracy and CLIPScore.
When adding a new metric to pruna, place your implementation in the pruna/evaluation/metrics directory to ensure it's properly integrated with the rest of the system. Use snake_case for the file name (e.g., your_new_metric.py).
In pruna, we evaluate metrics by sharing inference runs across multiple metrics whenever possible; pruna runs inference once for all compatible metrics.
- Stateful metrics are preferred for most use cases, especially quality metrics, because they can share inference results with other metrics.
- Base metrics are primarily used when isolated inference is required (e.g., for GPU memory metrics, where sharing inference would distort results).
Note
If you are unsure which type of metric to implement, you most likely want a stateful metric. Base metrics are typically reserved for specialized performance measurements that require isolated inference.
We use PascalCase for class names (e.g., YourNewMetric) and NumPy-style docstrings for documentation.
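For example, a new metric file might start like the sketch below; the class name, parameter, and description are placeholders for illustration only:

# pruna/evaluation/metrics/your_new_metric.py (hypothetical file)
from pruna.evaluation.metrics.metric_stateful import StatefulMetric


class YourNewMetric(StatefulMetric):
    """
    Compute <your metric> over an evaluation dataset.

    Parameters
    ----------
    threshold : float
        Hypothetical parameter illustrating NumPy-style parameter documentation.
    """

    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold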
Base Metrics
Base metrics inherit from the BaseMetric class and implement the compute() method. They are used when a metric requires isolated inference or cannot share computation with other metrics.
Pruna's EvaluationAgent (documentation) requires every BaseMetric to implement the compute method with two specific parameters: model and dataloader. Note that the EvaluationAgent does not handle inference for base metrics; you need to run inference yourself.
from pruna.evaluation.metrics.metric_base import BaseMetric


class YourNewMetric(BaseMetric):
    def __init__(self):
        super().__init__()
        # Initialize any parameters your metric needs

    def compute(self, model, dataloader):
        """Run inference on the model and compute the metric value."""
        # run_inference and some_calculation are placeholders for your own logic
        outputs = run_inference(model, dataloader)
        result = some_calculation(outputs)
        return result
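Since the EvaluationAgent simply hands the model and dataloader to compute, you can sanity-check a base metric on its own. A minimal sketch, assuming you already have a model and a dataloader at hand:

metric = YourNewMetric()
value = metric.compute(model, dataloader)  # the metric runs its own isolated inference
print(value)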
Stateful Metrics
Stateful metrics inherit from the StatefulMetric class and are the preferred approach for most metrics in pruna. They maintain internal state variables that accumulate information across multiple batches, which allows inference to be shared efficiently across different metrics.
Every stateful metric must implement the following methods:
- __init__(self, **kwargs): Initialize your metric and its parameters.
  - Call super().__init__()
  - Set self.metric_name
  - Initialize state variables using add_state()
  - Define any additional parameters
- add_state(self, name, default_value): Define persistent state variables.
  - Must be called in __init__
  - Creates variables that persist and accumulate across batches
  - Example states: totals, counts, running sums
- update(self, inputs, ground_truths, predictions): Process each batch.
  - Called automatically by the evaluation pipeline
  - Updates your state variables based on the current batch; your implementation can use any combination of these parameters it needs
  - No return value needed
- compute(self): Calculate the final metric value.
  - Uses the accumulated state to compute the final result
  - Called after all batches are processed
  - Must return the final metric value
- reset(self): Reset all state variables.
  - Must reset all states to their initial values
  - Called automatically between evaluation runs
Here’s a complete example showing all required methods:
from pruna.evaluation.metrics.metric_stateful import StatefulMetric


class YourNewStatefulMetric(StatefulMetric):
    def __init__(self, param1='default1', param2='default2'):
        super().__init__()
        self.param1 = param1
        self.param2 = param2
        self.metric_name = "your_metric_name"
        # Initialize state variables
        self.add_state("total", 0)
        self.add_state("count", 0)

    def add_state(self, name, default_value):
        """Add a state variable to the metric."""
        # Store the state as an attribute so it can be read and updated directly
        setattr(self, name, default_value)

    def update(self, inputs, ground_truths, predictions):
        # Update the state variables based on the current batch.
        # Choose the combination of inputs, ground_truths and predictions your metric needs.
        batch_result = some_calculation(predictions, ground_truths)
        self.total += batch_result
        self.count += 1

    def compute(self):
        # Compute the final metric value from the accumulated state
        if self.count == 0:
            return 0
        return self.total / self.count

    def reset(self):
        # Reset state variables to their initial values
        self.total = 0
        self.count = 0
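Conceptually, the evaluation pipeline drives a stateful metric as in the sketch below; the batch unpacking and model call are simplifying assumptions, and in practice the EvaluationAgent performs this loop (and shares the inference results across stateful metrics) for you:

metric = YourNewStatefulMetric()

for inputs, ground_truths in dataloader:  # hypothetical dataloader yielding (inputs, ground_truths)
    predictions = model(inputs)           # inference shared across all stateful metrics
    metric.update(inputs, ground_truths, predictions)

result = metric.compute()  # aggregate the accumulated state into the final value
metric.reset()             # clear the state before the next evaluation run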
When to Use Each Type
- Use Stateful Metrics when your metric can share inference with other metrics without affecting results (most quality metrics fall into this category).
- Use Base Metrics when your metric requires isolated inference or would produce incorrect results if inference were shared (e.g., performance metrics like GPU memory usage).
By using stateful metrics whenever possible, pruna can efficiently evaluate multiple metrics with just a single inference pass.
Registering Your Metric
After implementing your metric, you need to register it with Pruna's MetricRegistry system. The simplest way to do this is with the @MetricRegistry.register decorator:
from pruna.evaluation.metrics.registry import MetricRegistry
from pruna.evaluation.metrics.metric_stateful import StatefulMetric


@MetricRegistry.register("your_new_metric_name")
class YourNewMetric(StatefulMetric):
    def __init__(self, param1='default1', param2='default2'):  # Don't forget to add default values for your parameters!
        super().__init__()
        self.param1 = param1
        self.param2 = param2
        self.metric_name = "your_metric_name"
Thanks to this registry system, everyone using pruna can now refer to your metric by name without having to create instances directly!
One important thing: the registration happens when your module is imported. To ensure your metric is always available, we suggest importing it in the pruna/evaluation/metrics/__init__.py file, as shown below.
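For example, assuming the hypothetical file name your_new_metric.py and class YourNewMetric from above, the import line would look like this:

# in pruna/evaluation/metrics/__init__.py
from pruna.evaluation.metrics.your_new_metric import YourNewMetric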
Steps to Add a New Metric
1. Decide on the metric type: Determine whether your metric needs isolated inference (use BaseMetric) or can share inference (use StatefulMetric).
2. Create a new file: Create a new Python file in the pruna/evaluation/metrics/ directory with a descriptive name for your metric.
3. Implement your metric class: Inherit from the appropriate class and implement the required methods.
4. Register your metric: Use the MetricRegistry.register decorator to make your metric available throughout the system.
5. Add tests: Create tests in pruna/tests/evaluation for your metric to ensure it works correctly (see the sketch after this list).
6. Update documentation: Add documentation for your new metric in the user manual (docs/user_manual/evaluation.rst), including examples of how to use it.
7. Submit a pull request: Follow the standard contribution process to submit your new metric for review.
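For the testing step, a stateful metric can be exercised through its update/compute/reset cycle directly. Here is a minimal pytest-style sketch based on the YourNewStatefulMetric example above; the import path, dummy batches, and assertions are placeholders to adapt to your metric:

from pruna.evaluation.metrics.your_new_metric import YourNewStatefulMetric  # hypothetical import path


def test_metric_accumulates_and_resets():
    metric = YourNewStatefulMetric()

    # Feed a couple of dummy batches (replace with data your metric understands)
    metric.update(inputs=None, ground_truths=[0, 1], predictions=[0, 1])
    metric.update(inputs=None, ground_truths=[1, 1], predictions=[1, 0])

    # After all batches, the accumulated state should yield a final value
    assert metric.compute() is not None

    # reset() must return the metric to its initial state
    metric.reset()
    assert metric.count == 0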
By following these steps, you’ll help expand Pruna’s capabilities and contribute to the project’s success.
Using Your New Metric
Once you’ve implemented your metric, everyone can use it in Pruna’s evaluation pipeline! Here’s how:
# mock certain imports to make the code block runnable
import sys
import types

from diffusers import StableDiffusionPipeline

dummy_your_metric = types.ModuleType("pruna.evaluation.metrics.your_metric_file")
dummy_your_metric.YourNewMetric = "dummy_your_metric"
sys.modules["pruna.evaluation.metrics.your_metric_file"] = dummy_your_metric

model_path = "CompVis/stable-diffusion-v1-4"
model = StableDiffusionPipeline.from_pretrained(model_path)

# evaluation building blocks
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper
from pruna.evaluation.metrics.your_metric_file import YourNewMetric
from pruna.evaluation.task import Task

metrics = [
    'clip_score',
    'your_new_metric_name',
]

data_module = PrunaDataModule.from_string('LAION256')
test_dataloader = data_module.train_dataloader()

task = Task(request=metrics, dataloader=test_dataloader)
eval_agent = EvaluationAgent(task=task)
results = eval_agent.evaluate(model)