6 changes: 3 additions & 3 deletions docs/user_manual/customize_metric.rst
@@ -187,13 +187,13 @@ For example, if you are implementing a metric that compares two models, you shou

If you are implementing an alignment metric comparing the model's output with the input, you should use the ``x_gt`` or ``gt_x`` call type. Examples from |pruna| include ``clip_score``.

If you are implementing a metric that compares the model's output with the ground truth, you should use the ``y_gt`` or ``gt_y`` call type. Examples from |pruna| include ``fid``, ``cmmd``, ``accuracy``, ``recall``, ``precision``.
If you are implementing a metric that compares the model's output with the ground truth, you should use the ``y_gt`` or ``gt_y`` call type. Examples from |pruna| include ``fid``, ``kid``, ``cmmd``, ``accuracy``, ``recall``, ``precision``.

If you are wrapping an Image Quality Assessment (IQA) metric that has an internal dataset, you should use the ``y`` call type. Examples from |pruna| include ``arniqa``.

You may want to switch the mode of the metric regardless of its default ``call_type``. For instance, you may want to use ``fid`` in pairwise mode to get a single comparison score for two models.
You may want to switch the mode of the metric regardless of its default ``call_type``. For instance, you may want to use ``fid`` or ``kid`` in pairwise mode to get a single comparison score for two models.

In this case, you can pass ``pairwise`` to the ``call_type`` parameter of the ``StatefulMetric`` constructor.
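
For illustration, a minimal sketch of forcing pairwise mode (this assumes the ``TorchMetricWrapper`` changed later in this PR forwards ``call_type`` to its ``StatefulMetric`` base, and the import path is inferred from the module location):

.. code-block:: python

    from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper

    # fid and kid default to the "gt_y" call type; passing "pairwise" switches the mode.
    fid_pairwise = TorchMetricWrapper("fid", call_type="pairwise")
    kid_pairwise = TorchMetricWrapper("kid", call_type="pairwise")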

.. container:: hidden_code

6 changes: 3 additions & 3 deletions docs/user_manual/evaluate.rst
@@ -10,7 +10,7 @@ Evaluation helps you understand how compression affects your models across diffe
This knowledge is essential for making informed decisions about which compression techniques work best for your specific needs, using two types of metrics:

- **Efficiency Metrics:** Measure speed (total time, latency, throughput), memory (disk, inference, training), and energy usage (consumption, CO2 emissions).
- **Quality Metrics:** Assess fidelity (FID, CMMD), alignment (Clip Score), diversity (PSNR, SSIM), accuracy (accuracy, precision, perplexity), and more. Custom metrics are supported.
- **Quality Metrics:** Assess fidelity (FID, KID, CMMD), alignment (Clip Score), diversity (PSNR, SSIM), accuracy (accuracy, precision, perplexity), and more. Custom metrics are supported.
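
As a rough sketch of how one of the quality metrics above can be computed directly, here is a hypothetical snippet using the ``TorchMetricWrapper`` touched in this PR (the placeholder tensors and the three-argument ``update`` call, ordered as inputs, ground truth, then model outputs, are assumptions modelled on the tests added later in this diff):

.. code-block:: python

    import torch

    from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper

    generated = torch.randint(0, 255, (8, 3, 256, 256), dtype=torch.uint8)
    reference = torch.randint(0, 255, (8, 3, 256, 256), dtype=torch.uint8)

    # KID needs subset_size strictly smaller than the number of samples;
    # extra keyword arguments are forwarded to the underlying torchmetrics class.
    kid = TorchMetricWrapper("kid", subset_size=4)
    kid.update(reference, reference, generated)
    print(kid.compute().result)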

.. image:: /_static/assets/images/evaluation_agent.png
   :alt: Evaluation Agent
@@ -275,11 +275,11 @@ This is what's happening under the hood when you pass ``call_type="single"`` or

* - ``y_gt``
- Model's output first, then ground truth
- ``fid``, ``cmmd``, ``accuracy``, ``recall``, ``precision``
- ``fid``, ``kid``, ``cmmd``, ``accuracy``, ``recall``, ``precision``

* - ``gt_y``
- Ground truth first, then model's output
- ``fid``, ``cmmd``, ``accuracy``, ``recall``, ``precision``
- ``fid``, ``kid``, ``cmmd``, ``accuracy``, ``recall``, ``precision``

* - ``x_gt``
- Input data first, then ground truth
51 changes: 49 additions & 2 deletions src/pruna/evaluation/metrics/metric_torch.py
@@ -25,6 +25,7 @@
from torchmetrics.classification import Accuracy, Precision, Recall
from torchmetrics.image import (
    FrechetInceptionDistance,
    KernelInceptionDistance,
    LearnedPerceptualImagePatchSimilarity,
    MultiScaleStructuralSimilarityIndexMeasure,
    PeakSignalNoiseRatio,
@@ -85,6 +86,46 @@ def fid_update(metric: FrechetInceptionDistance, reals: Any, fakes: Any) -> None
    metric.update(fakes, real=False)


def kid_update(metric: KernelInceptionDistance, reals: Any, fakes: Any) -> None:
    """
    Update handler for KID metric.

    Parameters
    ----------
    metric : KernelInceptionDistance
        The KID metric instance.
    reals : Any
        The ground truth images tensor.
    fakes : Any
        The generated images tensor.
    """
    metric.update(reals, real=True)
    metric.update(fakes, real=False)


def kid_compute(metric: KernelInceptionDistance) -> Any:
    """
    Compute handler for KID metric.

    KID normally returns (mean, std), but we only need the mean value.
    The defensive check handles edge cases (e.g., insufficient data, version differences).

    Parameters
    ----------
    metric : KernelInceptionDistance
        The KID metric instance.

    Returns
    -------
    Any
        The computed metric value (mean from tuple, or result as-is if not a tuple).
    """
    result = metric.compute()  # type: ignore
    if isinstance(result, tuple) and len(result) == 2:
        return result[0]  # Extract the mean from the (mean, std) tuple
    return result


def lpips_update(metric: LearnedPerceptualImagePatchSimilarity, preds: Any, target: Any) -> None:
"""
Update handler for LPIPS metric.
Expand Down Expand Up @@ -152,7 +193,8 @@ class TorchMetrics(Enum):
"""
Enum for available torchmetrics.

The enum contains triplets of the metric class, the update function and the call type.
The enum contains tuples of (metric class, update function, call type) with an optional
compute function as a 4th element for metrics that need special compute handling.

Parameters
----------
@@ -171,6 +213,7 @@ class TorchMetrics(Enum):
"""

    fid = (partial(FrechetInceptionDistance), fid_update, "gt_y")
    kid = (partial(KernelInceptionDistance), kid_update, "gt_y", kid_compute)
    accuracy = (partial(Accuracy), None, "y_gt")
    perplexity = (partial(Perplexity), None, "y_gt")
    clip_score = (partial(CLIPScore), None, "y_x")
@@ -187,6 +230,7 @@ def __init__(self, *args, **kwargs) -> None:
        self.tm = self.value[0]
        self.update_fn = self.value[1] or default_update
        self.call_type = self.value[2]
        self.compute_fn = self.value[3] if len(self.value) > 3 else None

    def __call__(self, **kwargs) -> Metric:
        """
@@ -260,6 +304,8 @@ def __init__(self, metric_name: str, call_type: str = "", **kwargs) -> None:

            # Get the specific update function for the metric, or use the default if not found.
            self.update_fn = TorchMetrics[metric_name].update_fn
            # Get the compute function if available (e.g., for KID), otherwise None
            self.compute_fn = TorchMetrics[metric_name].compute_fn
        except KeyError:
            raise ValueError(f"Metric {metric_name} is not supported.")

@@ -339,7 +385,8 @@ def compute(self) -> Any:
        Any
            The computed metric value.
        """
        result = self.metric.compute()
        # Use metric-specific compute function if available (e.g., KID), otherwise use default
        result = self.compute_fn(self.metric) if self.compute_fn is not None else self.metric.compute()

        # Normally we have a single score for each metric for the entire dataset.
        # For IQA metrics we have a single score per image, so we need to convert the tensor to a list.
40 changes: 40 additions & 0 deletions tests/evaluation/test_torch_metrics.py
@@ -58,6 +58,41 @@ def test_fid(datamodule_fixture: PrunaDataModule, device: str) -> None:
    assert metric.compute().result == pytest.approx(0.0, abs=1e-2)


@pytest.mark.parametrize(
    "datamodule_fixture, device",
    [
        pytest.param("LAION256", "cpu", marks=pytest.mark.cpu),
        pytest.param("LAION256", "cuda", marks=pytest.mark.cuda),
    ],
    indirect=["datamodule_fixture"],
)
def test_kid(datamodule_fixture: PrunaDataModule, device: str) -> None:
    """Test the kid."""
    dataloader = datamodule_fixture.val_dataloader()
    dataloader_iter = iter(dataloader)

    # Get multiple batches to ensure enough samples
    batches = []
    for _ in range(4):  # Get 4 batches
        _, batch = next(dataloader_iter)
        batches.append(batch)
    gt = torch.cat(batches, dim=0)

    total_samples = gt.shape[0]
    # subset_size must be strictly smaller than the number of samples
    # Use a subset_size that's safely smaller (at least 1 less)
    subset_size = min(50, max(2, total_samples - 1))

    metric = TorchMetricWrapper("kid", device=device, subset_size=subset_size)
    metric.update(gt, gt, gt)
    result = metric.compute()

    # KID should be close to 0 when comparing identical images
    # Use absolute value and a slightly larger tolerance due to numerical precision
    assert not torch.isnan(torch.tensor(result.result)), "KID returned NaN"
    assert abs(result.result) < 0.25, f"KID should be close to 0 for identical images, got {result.result}"


@pytest.mark.parametrize(
"datamodule_fixture, device",
[
@@ -77,6 +112,7 @@ def test_clip_score(datamodule_fixture: PrunaDataModule, device: str) -> None:
    score = metric.compute()
    assert score.result > 0.0 and score.result < 100.0


@pytest.mark.cpu
@pytest.mark.parametrize("datamodule_fixture", ["LAION256"], indirect=True)
def test_clipiqa(datamodule_fixture: PrunaDataModule) -> None:
@@ -113,6 +149,7 @@ def test_torch_metrics(datamodule_fixture: PrunaDataModule, device: str, metric:
    metric.update(gt, gt, gt)
    assert metric.compute().result == 1.0


@pytest.mark.cpu
@pytest.mark.parametrize("datamodule_fixture", ["LAION256"], indirect=True)
def test_arniqa(datamodule_fixture: PrunaDataModule) -> None:
@@ -123,6 +160,7 @@ def test_arniqa(datamodule_fixture: PrunaDataModule) -> None:
    x, gt = next(dataloader_iter)
    metric.update(x, gt, gt)


@pytest.mark.cpu
@pytest.mark.parametrize("metric", TorchMetrics.__members__.keys())
@pytest.mark.parametrize("call_type", ["single", "pairwise"])
@@ -143,6 +181,7 @@ def test_check_call_type(metric: str, call_type: str):
    else:
        assert not metric.call_type.startswith("pairwise")


@pytest.mark.cpu
@pytest.mark.parametrize(
    'metric_name,metric_type',
@@ -155,6 +194,7 @@ def test_ssim_generalization_metric_type(metric_name, metric_type):
    wrapper = TorchMetricWrapper(metric_name=metric_name)
    assert isinstance(wrapper.metric, metric_type)


@pytest.mark.cpu
@pytest.mark.parametrize(
    'metric_name,invalid_param_args',