Commit 7d9bdaf

minettekaum authored and Marius-Graml committed

Kid metric added (#435)

* feat/kid-metric and updated docs
* fixed typo
* fixed linting error
* moving KID logic to enum
* fixing linting error
* fixed another linting error
* fixing linting error
1 parent 3ae00ee commit 7d9bdaf

5 files changed: +1322 -16 lines


docs/user_manual/customize_metric.rst

Lines changed: 3 additions & 3 deletions
@@ -187,13 +187,13 @@ For example, if you are implementing a metric that compares two models, you shou
 
 If you are implementing an alignment metric comparing the model's output with the input, you should use the ``x_gt`` or ``gt_x`` call type. Examples from |pruna| include ``clip_score``.
 
-If you are implementing a metric that compares the model's output with the ground truth, you should use the ``y_gt`` or ``gt_y`` call type. Examples from |pruna| include ``fid``, ``cmmd``, ``accuracy``, ``recall``, ``precision``.
+If you are implementing a metric that compares the model's output with the ground truth, you should use the ``y_gt`` or ``gt_y`` call type. Examples from |pruna| include ``fid``, ``kid``, ``cmmd``, ``accuracy``, ``recall``, ``precision``.
 
 If you are wrapping an Image Quality Assessment (IQA) metric that has an internal dataset, you should use the ``y`` call type. Examples from |pruna| include ``arniqa``.
 
-You may want to switch the mode of the metric regardless of its default ``call_type``. For instance, you may want to use ``fid`` in pairwise mode to get a single comparison score for two models.
+You may want to switch the mode of the metric regardless of its default ``call_type``. For instance, you may want to use ``fid`` or ``kid`` in pairwise mode to get a single comparison score for two models.
 
 In this case, you can pass ``pairwise`` to the ``call_type`` parameter of the ``StatefulMetric`` constructor.
 
 .. container:: hidden_code
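A minimal sketch of the pairwise switch described above, using the ``TorchMetricWrapper`` signature visible later in this commit's diff; treat it as an illustration rather than the canonical API:

from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper

# Override KID's default "gt_y" call type with pairwise mode to get a
# single comparison score for two models, as the docs above describe.
kid_pairwise = TorchMetricWrapper("kid", call_type="pairwise")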

docs/user_manual/evaluate.rst

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,7 @@ Evaluation helps you understand how compression affects your models across diffe
 This knowledge is essential for making informed decisions about which compression techniques work best for your specific needs using two types of metrics:
 
 - **Efficiency Metrics:** Measure speed (total time, latency, throughput), memory (disk, inference, training), and energy usage (consumption, CO2 emissions).
-- **Quality Metrics:** Assess fidelity (FID, CMMD), alignment (Clip Score), diversity (PSNR, SSIM), accuracy (accuracy, precision, perplexity), and more. Custom metrics are supported.
+- **Quality Metrics:** Assess fidelity (FID, KID, CMMD), alignment (Clip Score), diversity (PSNR, SSIM), accuracy (accuracy, precision, perplexity), and more. Custom metrics are supported.
 
 .. image:: /_static/assets/images/evaluation_agent.png
    :alt: Evaluation Agent
@@ -275,11 +275,11 @@ This is what's happening under the hood when you pass ``call_type="single"`` or
 
    * - ``y_gt``
      - Model's output first, then ground truth
-     - ``fid``, ``cmmd``, ``accuracy``, ``recall``, ``precision``
+     - ``fid``, ``kid``, ``cmmd``, ``accuracy``, ``recall``, ``precision``
 
    * - ``gt_y``
      - Ground truth first, then model's output
-     - ``fid``, ``cmmd``, ``accuracy``, ``recall``, ``precision``
+     - ``fid``, ``kid``, ``cmmd``, ``accuracy``, ``recall``, ``precision``
 
    * - ``x_gt``
      - Input data first, then ground truth
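To make the ordering column concrete, here is a hedged sketch of a ``gt_y`` update for ``kid``, built on the ``kid_update`` handler added in this commit; the dummy uint8 batches and sizes are illustrative only:

import torch
from torchmetrics.image import KernelInceptionDistance

from pruna.evaluation.metrics.metric_torch import kid_update

# gt_y ordering: the ground-truth batch is passed before the model's output batch.
metric = KernelInceptionDistance(subset_size=4)  # subset_size must stay below the sample count
ground_truth = torch.randint(0, 255, (8, 3, 299, 299), dtype=torch.uint8)
outputs = torch.randint(0, 255, (8, 3, 299, 299), dtype=torch.uint8)
kid_update(metric, ground_truth, outputs)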

src/pruna/evaluation/metrics/metric_torch.py

Lines changed: 70 additions & 2 deletions
@@ -25,6 +25,7 @@
 from torchmetrics.classification import Accuracy, Precision, Recall
 from torchmetrics.image import (
     FrechetInceptionDistance,
+    KernelInceptionDistance,
     LearnedPerceptualImagePatchSimilarity,
     MultiScaleStructuralSimilarityIndexMeasure,
     PeakSignalNoiseRatio,
@@ -88,6 +89,46 @@ def fid_update(metric: FrechetInceptionDistance, reals: Any, fakes: Any) -> None
     metric.update(fakes, real=False)
 
 
+def kid_update(metric: KernelInceptionDistance, reals: Any, fakes: Any) -> None:
+    """
+    Update handler for the KID metric.
+
+    Parameters
+    ----------
+    metric : KernelInceptionDistance
+        The KID metric instance.
+    reals : Any
+        The ground truth images tensor.
+    fakes : Any
+        The generated images tensor.
+    """
+    metric.update(reals, real=True)
+    metric.update(fakes, real=False)
+
+
+def kid_compute(metric: KernelInceptionDistance) -> Any:
+    """
+    Compute handler for the KID metric.
+
+    KID normally returns (mean, std), but we only need the mean value.
+    The defensive check handles edge cases (e.g., insufficient data, version differences).
+
+    Parameters
+    ----------
+    metric : KernelInceptionDistance
+        The KID metric instance.
+
+    Returns
+    -------
+    Any
+        The computed metric value (mean from tuple, or result as-is if not a tuple).
+    """
+    result = metric.compute()  # type: ignore
+    if isinstance(result, tuple) and len(result) == 2:
+        return result[0]  # extract the mean (KID normally returns (mean, std))
+    return result
+
+
 def lpips_update(metric: LearnedPerceptualImagePatchSimilarity, preds: Any, target: Any) -> None:
     """
     Update handler for LPIPS metric.
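For context on why a dedicated compute handler is needed: torchmetrics' ``KernelInceptionDistance.compute()`` returns a ``(mean, std)`` tuple, which ``kid_compute`` reduces to the mean. A small self-contained sketch with dummy uint8 images (sizes are illustrative):

import torch
from torchmetrics.image import KernelInceptionDistance

from pruna.evaluation.metrics.metric_torch import kid_compute, kid_update

kid = KernelInceptionDistance(subset_size=8)  # subset_size must stay below the sample count
imgs = torch.randint(0, 255, (16, 3, 299, 299), dtype=torch.uint8)

kid_update(kid, reals=imgs, fakes=imgs)
print(kid.compute())     # raw torchmetrics result: a (mean, std) tuple
print(kid_compute(kid))  # the handler reduces it to just the mean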
@@ -153,7 +194,8 @@ class TorchMetrics(Enum):
     """
     Enum for available torchmetrics.
 
-    The enum contains triplets of the metric class, the update function and the call type.
+    The enum contains tuples of (metric class, update function, call type, modality), with an
+    optional compute function as a fifth element for metrics that need special compute handling.
 
     Parameters
     ----------
@@ -171,6 +213,7 @@ class TorchMetrics(Enum):
         The starting value for the enum.
     """
 
     fid = (partial(FrechetInceptionDistance), fid_update, "gt_y", {IMAGE})
+    kid = (partial(KernelInceptionDistance), kid_update, "gt_y", {IMAGE}, kid_compute)
     accuracy = (partial(Accuracy), None, "y_gt", MODALITIES)
     perplexity = (partial(Perplexity), None, "y_gt", {TEXT})
@@ -183,12 +226,31 @@ class TorchMetrics(Enum):
     lpips = (partial(LearnedPerceptualImagePatchSimilarity), lpips_update, "pairwise_y_gt", {IMAGE})
     arniqa = (partial(ARNIQA), arniqa_update, "y", {IMAGE})
     clipiqa = (partial(CLIPImageQualityAssessment), None, "y", {IMAGE})
 
     def __init__(self, *args, **kwargs) -> None:
         self.tm = self.value[0]
         self.update_fn = self.value[1] or default_update
         self.call_type = self.value[2]
         self.modality = self.value[3]
+        self.compute_fn = self.value[4] if len(self.value) > 4 else None
 
     def __call__(self, **kwargs) -> Metric:
         """
@@ -262,7 +324,12 @@ def __init__(self, metric_name: str, call_type: str = "", **kwargs) -> None:
 
             # Get the specific update function for the metric, or use the default if not found.
             self.update_fn = TorchMetrics[metric_name].update_fn
             self.modality = TorchMetrics[metric_name].modality
+            # Get the compute function if available (e.g., for KID), otherwise None
+            self.compute_fn = TorchMetrics[metric_name].compute_fn
         except KeyError:
             raise ValueError(f"Metric {metric_name} is not supported.")

@@ -342,7 +409,8 @@ def compute(self) -> Any:
         Any
             The computed metric value.
         """
-        result = self.metric.compute()
+        # Use metric-specific compute function if available (e.g., KID), otherwise use default
+        result = self.compute_fn(self.metric) if self.compute_fn is not None else self.metric.compute()
 
         # Normally we have a single score for each metric for the entire dataset.
         # For IQA metrics we have a single score per image, so we need to convert the tensor to a list.
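End to end, the dispatch added here means ``compute`` on a ``kid`` wrapper yields a scalar mean rather than torchmetrics' tuple. A hedged sketch mirroring the test below (the ``update(x, gt, outputs)`` ordering and the ``.result`` attribute are taken from the tests):

import torch
from pruna.evaluation.metrics.metric_torch import TorchMetricWrapper

imgs = torch.randint(0, 255, (8, 3, 299, 299), dtype=torch.uint8)
kid = TorchMetricWrapper("kid", subset_size=4)  # subset_size must stay below the sample count
kid.update(imgs, imgs, imgs)  # inputs, ground truth, model outputs
print(kid.compute().result)   # scalar KID mean via kid_compute, ~0 for identical batches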

tests/evaluation/test_torch_metrics.py

Lines changed: 40 additions & 0 deletions
@@ -58,6 +58,41 @@ def test_fid(datamodule_fixture: PrunaDataModule, device: str) -> None:
     assert metric.compute().result == pytest.approx(0.0, abs=1e-2)
 
 
+@pytest.mark.parametrize(
+    "datamodule_fixture, device",
+    [
+        pytest.param("LAION256", "cpu", marks=pytest.mark.cpu),
+        pytest.param("LAION256", "cuda", marks=pytest.mark.cuda),
+    ],
+    indirect=["datamodule_fixture"],
+)
+def test_kid(datamodule_fixture: PrunaDataModule, device: str) -> None:
+    """Test the KID metric."""
+    dataloader = datamodule_fixture.val_dataloader()
+    dataloader_iter = iter(dataloader)
+
+    # Collect several batches to ensure enough samples.
+    batches = []
+    for _ in range(4):
+        _, batch = next(dataloader_iter)
+        batches.append(batch)
+    gt = torch.cat(batches, dim=0)
+
+    total_samples = gt.shape[0]
+    # subset_size must be strictly smaller than the number of samples,
+    # so use a value that is safely smaller (at least 1 less).
+    subset_size = min(50, max(2, total_samples - 1))
+
+    metric = TorchMetricWrapper("kid", device=device, subset_size=subset_size)
+    metric.update(gt, gt, gt)
+    result = metric.compute()
+
+    # KID should be close to 0 when comparing identical images.
+    # Use the absolute value and a slightly larger tolerance due to numerical precision.
+    assert not torch.isnan(torch.tensor(result.result)), "KID returned NaN"
+    assert abs(result.result) < 0.25, f"KID should be close to 0 for identical images, got {result.result}"
+
+
 @pytest.mark.parametrize(
     "datamodule_fixture, device",
     [
@@ -77,6 +112,7 @@ def test_clip_score(datamodule_fixture: PrunaDataModule, device: str) -> None:
     score = metric.compute()
     assert score.result > 0.0 and score.result < 100.0
 
+
 @pytest.mark.cpu
 @pytest.mark.parametrize("datamodule_fixture", ["LAION256"], indirect=True)
 def test_clipiqa(datamodule_fixture: PrunaDataModule) -> None:
@@ -113,6 +149,7 @@ def test_torch_metrics(datamodule_fixture: PrunaDataModule, device: str, metric:
     metric.update(gt, gt, gt)
     assert metric.compute().result == 1.0
 
+
 @pytest.mark.cpu
 @pytest.mark.parametrize("datamodule_fixture", ["LAION256"], indirect=True)
 def test_arniqa(datamodule_fixture: PrunaDataModule) -> None:
@@ -123,6 +160,7 @@ def test_arniqa(datamodule_fixture: PrunaDataModule) -> None:
     x, gt = next(dataloader_iter)
     metric.update(x, gt, gt)
 
+
 @pytest.mark.cpu
 @pytest.mark.parametrize("metric", TorchMetrics.__members__.keys())
 @pytest.mark.parametrize("call_type", ["single", "pairwise"])
@@ -143,6 +181,7 @@ def test_check_call_type(metric: str, call_type: str):
     else:
         assert not metric.call_type.startswith("pairwise")
 
+
 @pytest.mark.cpu
 @pytest.mark.parametrize(
     'metric_name,metric_type',
@@ -155,6 +194,7 @@ def test_ssim_generalization_metric_type(metric_name, metric_type):
     wrapper = TorchMetricWrapper(metric_name=metric_name)
     assert isinstance(wrapper.metric, metric_type)
 
+
 @pytest.mark.cpu
 @pytest.mark.parametrize(
     'metric_name,invalid_param_args',
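The new test can be run in isolation with, for example, ``pytest tests/evaluation/test_torch_metrics.py -k test_kid -m cpu``, assuming the ``cpu``/``cuda`` markers used above are registered in the repo's pytest configuration.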
