Skip to content

Commit b5f705e

Browse files
authored
Add: (Sparse)Top K Categorical Accuracy Metric (keras-team#61)
* chore: adding top k categorical accuracy * chore: adding top k and in top k * chore: fixing tests * chore: y true argmax * chore: adding sparse top k cat metric * review comments
1 parent e0fde58 commit b5f705e

File tree

6 files changed

+285
-0
lines changed

6 files changed

+285
-0
lines changed

keras_core/backend/jax/math.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,10 @@ def top_k(x, k, sorted=True):
1313
"Jax backend does not support `sorted=False` for `ops.top_k`"
1414
)
1515
return jax.lax.top_k(x, k)
16+
17+
18+
def in_top_k(targets, predictions, k):
    """Check whether each target index is among the top `k` predictions.

    Args:
        targets: Integer array of class indices. Its shape is expected to
            match `predictions` without the last axis.
        predictions: Array of scores; the top `k` entries are selected along
            the last axis.
        k: Number of top entries to consider.

    Returns:
        Boolean array with the same shape as `targets`, `True` where the
        target index is one of the `k` highest-scoring indices.
    """
    # `jax.lax.top_k` returns `(values, indices)` computed along the last
    # axis; only the indices are needed here.
    _, topk_indices = jax.lax.top_k(predictions, k)
    # Add a trailing axis so every target broadcasts against its `k` indices.
    mask = targets[..., None] == topk_indices
    # Reduce over the top-k axis (always the last one) instead of a
    # hard-coded axis 1, so inputs with extra leading dimensions work too.
    return jax.numpy.any(mask, axis=-1)

keras_core/backend/tensorflow/math.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,7 @@ def segment_sum(data, segment_ids, num_segments=None, sorted=False):
1010

1111
def top_k(x, k, sorted=True):
    """Select the `k` largest entries of `x` using `tf.math.top_k`.

    Args:
        x: Input tensor.
        k: Number of top entries to return.
        sorted: Forwarded unchanged to `tf.math.top_k`.

    Returns:
        The result of `tf.math.top_k(x, k, sorted=sorted)`.
    """
    top_k_result = tf.math.top_k(x, k, sorted=sorted)
    return top_k_result
13+
14+
15+
def in_top_k(targets, predictions, k):
    """Tell whether each target is within the top `k` predictions.

    Delegates directly to `tf.math.in_top_k`.

    Args:
        targets: Tensor of target class indices.
        predictions: Tensor of prediction scores.
        k: Number of top entries to consider.

    Returns:
        The result of `tf.math.in_top_k(targets, predictions, k)`.
    """
    in_top_k_mask = tf.math.in_top_k(targets, predictions, k)
    return in_top_k_mask

keras_core/metrics/accuracy_metrics.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,3 +279,154 @@ def __init__(self, name="sparse_categorical_accuracy", dtype=None):
279279

280280
def get_config(self):
281281
return {"name": self.name, "dtype": self.dtype}
282+
283+
284+
def top_k_categorical_accuracy(y_true, y_pred, k=5):
    """Compute per-sample top-`k` accuracy for one-hot encoded targets.

    Args:
        y_true: One-hot encoded targets; the target class index is taken as
            the argmax over the last axis. May be a tensor/array or a plain
            Python sequence.
        y_pred: Prediction scores with the classes on the last axis.
        k: Number of top predictions to look at. Defaults to 5.

    Returns:
        Tensor of dtype `backend.floatx()` shaped like
        `argmax(y_true, axis=-1)`, with 1.0 where the target class is among
        the top `k` predictions and 0.0 elsewhere.
    """
    y_pred = ops.convert_to_tensor(y_pred)
    # Plain Python sequences (as used in the metric's docstring examples)
    # have no `.dtype`; fall back to letting the backend infer the dtype.
    y_true = ops.convert_to_tensor(
        y_true, dtype=getattr(y_true, "dtype", None)
    )
    y_true = ops.argmax(y_true, axis=-1)
    y_true_org_shape = ops.shape(y_true)

    # Flatten y_pred to (num_samples, num_classes) and y_true to
    # (num_samples,), which is the layout `ops.in_top_k` is called with.
    if len(y_pred.shape) > 2:
        y_pred = ops.reshape(y_pred, [-1, y_pred.shape[-1]])
    reshape_matches = len(y_true.shape) > 1
    if reshape_matches:
        y_true = ops.reshape(y_true, [-1])

    matches = ops.cast(
        ops.in_top_k(ops.cast(y_true, "int32"), y_pred, k=k),
        dtype=backend.floatx(),
    )

    # The returned matches are expected to have the same shape as the
    # (argmax-reduced) y_true input, so undo the flattening if it happened.
    if reshape_matches:
        matches = ops.reshape(matches, y_true_org_shape)

    return matches
311+
312+
313+
@keras_core_export("keras_core.metrics.TopKCategoricalAccuracy")
class TopKCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
    """Tracks how frequently the targets land in the top `K` predictions.

    The per-sample values are produced by `top_k_categorical_accuracy` and
    averaged by the `MeanMetricWrapper` base class.

    Args:
        k: (Optional) Number of top elements to look at for computing
            accuracy. Defaults to 5.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Standalone usage:

    >>> m = keras_core.metrics.TopKCategoricalAccuracy(k=1)
    >>> m.update_state([[0, 0, 1], [0, 1, 0]],
    ...                [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
    >>> m.result()
    0.5

    >>> m.reset_state()
    >>> m.update_state([[0, 0, 1], [0, 1, 0]],
    ...                [[0.1, 0.9, 0.8], [0.05, 0.95, 0]],
    ...                sample_weight=[0.7, 0.3])
    >>> m.result()
    0.3

    Usage with `compile()` API:

    ```python
    model.compile(optimizer='sgd',
                  loss='mse',
                  metrics=[keras_core.metrics.TopKCategoricalAccuracy()])
    ```
    """

    def __init__(self, k=5, name="top_k_categorical_accuracy", dtype=None):
        # `k` is forwarded as an extra keyword argument for the wrapped
        # function and also kept on the instance for `get_config`.
        super().__init__(
            fn=top_k_categorical_accuracy, name=name, dtype=dtype, k=k
        )
        self.k = k

    def get_config(self):
        """Return the constructor arguments as a plain dict."""
        config = {"name": self.name, "dtype": self.dtype}
        config["k"] = self.k
        return config
358+
359+
360+
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
    """Compute per-sample top-`k` accuracy for integer class targets.

    Args:
        y_true: Target class indices. May be a tensor/array or a plain
            Python sequence; values are cast to `int32` before the lookup.
        y_pred: Prediction scores with the classes on the last axis.
        k: Number of top predictions to look at. Defaults to 5.

    Returns:
        Tensor of dtype `backend.floatx()` shaped like `y_true`, with 1.0
        where the target class is among the top `k` predictions and 0.0
        elsewhere.
    """
    y_pred = ops.convert_to_tensor(y_pred)
    # Plain Python sequences (as used in the metric's docstring examples)
    # have no `.dtype`; fall back to letting the backend infer the dtype.
    y_true = ops.convert_to_tensor(
        y_true, dtype=getattr(y_true, "dtype", None)
    )
    y_true_org_shape = ops.shape(y_true)

    # Flatten y_pred to (num_samples, num_classes) and y_true to
    # (num_samples,), which is the layout `ops.in_top_k` is called with.
    if len(y_pred.shape) > 2:
        y_pred = ops.reshape(y_pred, [-1, y_pred.shape[-1]])
    reshape_matches = len(y_true.shape) > 1
    if reshape_matches:
        y_true = ops.reshape(y_true, [-1])

    matches = ops.cast(
        ops.in_top_k(ops.cast(y_true, "int32"), y_pred, k=k),
        dtype=backend.floatx(),
    )

    # The returned matches are expected to have the same shape as the
    # y_true input, so undo the flattening if it happened.
    if reshape_matches:
        matches = ops.reshape(matches, y_true_org_shape)

    return matches
386+
387+
388+
@keras_core_export("keras_core.metrics.SparseTopKCategoricalAccuracy")
class SparseTopKCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
    """Tracks how frequently integer targets land in the top `K` predictions.

    The per-sample values are produced by
    `sparse_top_k_categorical_accuracy` and averaged by the
    `MeanMetricWrapper` base class.

    Args:
        k: (Optional) Number of top elements to look at for computing
            accuracy. Defaults to 5.
        name: (Optional) string name of the metric instance.
        dtype: (Optional) data type of the metric result.

    Standalone usage:

    >>> m = keras_core.metrics.SparseTopKCategoricalAccuracy(k=1)
    >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])
    >>> m.result()
    0.5

    >>> m.reset_state()
    >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]],
    ...                sample_weight=[0.7, 0.3])
    >>> m.result()
    0.3

    Usage with `compile()` API:

    ```python
    model.compile(optimizer='sgd',
                  loss='mse',
                  metrics=[keras_core.metrics.SparseTopKCategoricalAccuracy()])
    ```
    """

    def __init__(
        self, k=5, name="sparse_top_k_categorical_accuracy", dtype=None
    ):
        # `k` is forwarded as an extra keyword argument for the wrapped
        # function and also kept on the instance for `get_config`.
        super().__init__(
            fn=sparse_top_k_categorical_accuracy, name=name, dtype=dtype, k=k
        )
        self.k = k

    def get_config(self):
        """Return the constructor arguments as a plain dict."""
        config = {"name": self.name, "dtype": self.dtype}
        config["k"] = self.k
        return config

keras_core/metrics/accuracy_metrics_test.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ def test_config(self):
1010
self.assertEqual(acc_obj.name, "accuracy")
1111
self.assertEqual(len(acc_obj.variables), 2)
1212
self.assertEqual(acc_obj._dtype, "float32")
13+
14+
# Test get_config
15+
acc_obj_config = acc_obj.get_config()
16+
self.assertEqual(acc_obj_config["name"], "accuracy")
17+
self.assertEqual(acc_obj_config["dtype"], "float32")
1318
# TODO: Check save and restore config
1419

1520
def test_unweighted(self):
@@ -38,6 +43,11 @@ def test_config(self):
3843
self.assertEqual(bin_acc_obj.name, "binary_accuracy")
3944
self.assertEqual(len(bin_acc_obj.variables), 2)
4045
self.assertEqual(bin_acc_obj._dtype, "float32")
46+
47+
# Test get_config
48+
bin_acc_obj_config = bin_acc_obj.get_config()
49+
self.assertEqual(bin_acc_obj_config["name"], "binary_accuracy")
50+
self.assertEqual(bin_acc_obj_config["dtype"], "float32")
4151
# TODO: Check save and restore config
4252

4353
def test_unweighted(self):
@@ -70,6 +80,11 @@ def test_config(self):
7080
self.assertEqual(cat_acc_obj.name, "categorical_accuracy")
7181
self.assertEqual(len(cat_acc_obj.variables), 2)
7282
self.assertEqual(cat_acc_obj._dtype, "float32")
83+
84+
# Test get_config
85+
cat_acc_obj_config = cat_acc_obj.get_config()
86+
self.assertEqual(cat_acc_obj_config["name"], "categorical_accuracy")
87+
self.assertEqual(cat_acc_obj_config["dtype"], "float32")
7388
# TODO: Check save and restore config
7489

7590
def test_unweighted(self):
@@ -102,6 +117,13 @@ def test_config(self):
102117
self.assertEqual(sp_cat_acc_obj.name, "sparse_categorical_accuracy")
103118
self.assertEqual(len(sp_cat_acc_obj.variables), 2)
104119
self.assertEqual(sp_cat_acc_obj._dtype, "float32")
120+
121+
# Test get_config
122+
sp_cat_acc_obj_config = sp_cat_acc_obj.get_config()
123+
self.assertEqual(
124+
sp_cat_acc_obj_config["name"], "sparse_categorical_accuracy"
125+
)
126+
self.assertEqual(sp_cat_acc_obj_config["dtype"], "float32")
105127
# TODO: Check save and restore config
106128

107129
def test_unweighted(self):
@@ -124,3 +146,90 @@ def test_weighted(self):
124146
sp_cat_acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight)
125147
result = sp_cat_acc_obj.result()
126148
self.assertAllClose(result, 0.3, atol=1e-3)
149+
150+
151+
class TopKCategoricalAccuracyTest(testing.TestCase):
    """Checks config and results of `TopKCategoricalAccuracy` with `k=1`."""

    def _build_metric(self):
        # Every test uses the same explicitly configured metric instance.
        return accuracy_metrics.TopKCategoricalAccuracy(
            k=1, name="top_k_categorical_accuracy", dtype="float32"
        )

    def test_config(self):
        metric = self._build_metric()
        self.assertEqual(metric.name, "top_k_categorical_accuracy")
        self.assertEqual(len(metric.variables), 2)
        self.assertEqual(metric._dtype, "float32")

        # Test get_config
        config = metric.get_config()
        self.assertEqual(config["name"], "top_k_categorical_accuracy")
        self.assertEqual(config["dtype"], "float32")
        self.assertEqual(config["k"], 1)
        # TODO: Check save and restore config

    def test_unweighted(self):
        metric = self._build_metric()
        y_true = np.array([[0, 0, 1], [0, 1, 0]])
        y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]], dtype="float32")
        metric.update_state(y_true, y_pred)
        # Only the second sample has its target class ranked first.
        self.assertAllClose(metric.result(), 0.5, atol=1e-3)

    def test_weighted(self):
        metric = self._build_metric()
        y_true = np.array([[0, 0, 1], [0, 1, 0]])
        y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]], dtype="float32")
        sample_weight = np.array([0.7, 0.3])
        metric.update_state(y_true, y_pred, sample_weight=sample_weight)
        # The single matching sample carries weight 0.3 out of a total of 1.
        self.assertAllClose(metric.result(), 0.3, atol=1e-3)
191+
192+
193+
class SparseTopKCategoricalAccuracyTest(testing.TestCase):
    """Checks config and results of `SparseTopKCategoricalAccuracy`, `k=1`."""

    def _build_metric(self):
        # Every test uses the same explicitly configured metric instance.
        return accuracy_metrics.SparseTopKCategoricalAccuracy(
            k=1, name="sparse_top_k_categorical_accuracy", dtype="float32"
        )

    def test_config(self):
        metric = self._build_metric()
        self.assertEqual(metric.name, "sparse_top_k_categorical_accuracy")
        self.assertEqual(len(metric.variables), 2)
        self.assertEqual(metric._dtype, "float32")

        # Test get_config
        config = metric.get_config()
        self.assertEqual(config["name"], "sparse_top_k_categorical_accuracy")
        self.assertEqual(config["dtype"], "float32")
        self.assertEqual(config["k"], 1)
        # TODO: Check save and restore config

    def test_unweighted(self):
        metric = self._build_metric()
        y_true = np.array([2, 1])
        y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]], dtype="float32")
        metric.update_state(y_true, y_pred)
        # Only the second sample has its target class ranked first.
        self.assertAllClose(metric.result(), 0.5, atol=1e-3)

    def test_weighted(self):
        metric = self._build_metric()
        y_true = np.array([2, 1])
        y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]], dtype="float32")
        sample_weight = np.array([0.7, 0.3])
        metric.update_state(y_true, y_pred, sample_weight=sample_weight)
        # The single matching sample carries weight 0.3 out of a total of 1.
        self.assertAllClose(metric.result(), 0.3, atol=1e-3)

keras_core/operations/math.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
segment_sum
33
top_k
4+
in_top_k
45
"""
56

67
from keras_core import backend
@@ -28,3 +29,14 @@ def top_k(x, k, sorted=True):
2829
if any_symbolic_tensors((x,)):
2930
return TopK().symbolic_call(x, k, sorted)
3031
return backend.math.top_k(x, k, sorted)
32+
33+
34+
class InTopK(Operation):
    """Operation that forwards to the active backend's `math.in_top_k`."""

    def call(self, targets, predictions, k):
        """Run the backend `in_top_k` on `targets` and `predictions`."""
        backend_in_top_k = backend.math.in_top_k
        return backend_in_top_k(targets, predictions, k)
37+
38+
39+
def in_top_k(targets, predictions, k):
    """Check whether `targets` are among the top `k` of `predictions`.

    Concrete inputs are passed straight to the backend implementation; if
    either input is a symbolic tensor, the call goes through the `InTopK`
    operation's `symbolic_call` instead.

    Args:
        targets: Tensor of target class indices.
        predictions: Tensor of prediction scores.
        k: Number of top entries to consider.

    Returns:
        The backend `in_top_k` result, or the symbolic output of `InTopK`.
    """
    if not any_symbolic_tensors((targets, predictions)):
        return backend.math.in_top_k(targets, predictions, k)
    return InTopK().symbolic_call(targets, predictions, k)

keras_core/operations/nn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
conv_transpose
2424
2525
one_hot
26+
top_k
27+
in_top_k
2628
2729
ctc ??
2830
"""

0 commit comments

Comments
 (0)