From 912efb1dd8246dd473940b006e46f01a8e13798b Mon Sep 17 00:00:00 2001 From: "daniel.cahall" Date: Sun, 12 Oct 2025 22:13:45 -0400 Subject: [PATCH 1/8] ensure eye behavior is consistent across backends --- keras/src/ops/numpy.py | 7 +++++++ keras/src/ops/numpy_test.py | 9 +++++++++ 2 files changed, 16 insertions(+) diff --git a/keras/src/ops/numpy.py b/keras/src/ops/numpy.py index cbc07c9c3e3c..5c7434eff135 100644 --- a/keras/src/ops/numpy.py +++ b/keras/src/ops/numpy.py @@ -7228,6 +7228,13 @@ def eye(N, M=None, k=0, dtype=None): Returns: Tensor with ones on the k-th diagonal and zeros elsewhere. """ + def is_float(v): + if isinstance(v, float) or getattr(v, "dtype", None) in ("float16", "float32", "float64"): + return True + if is_float(N): + raise ValueError("Argument `N` must be an integer or an integer tensor.") + if is_float(M): + raise ValueError("Argument `M` must be an integer, an integer tensor, or `None`.") return backend.numpy.eye(N, M=M, k=k, dtype=dtype) diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index 998d18bd4b73..47c4f00e115a 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -5230,6 +5230,15 @@ def test_eye(self): # Test k < 0 and M < N and M - k > N self.assertAllClose(knp.eye(4, 3, k=-2), np.eye(4, 3, k=-2)) + def test_eye_raises_error_with_floats(self): + with self.assertRaises(ValueError): + knp.eye(3.0) + with self.assertRaises(ValueError): + knp.eye(3.0, 2.0) + with self.assertRaises(ValueError): + v = knp.max(knp.arange(4.0)) + knp.eye(v) + def test_arange(self): self.assertAllClose(knp.arange(3), np.arange(3)) self.assertAllClose(knp.arange(3, 7), np.arange(3, 7)) From edf3612b3c062720b019a1056d2e47fcb1e0a42f Mon Sep 17 00:00:00 2001 From: Danny <33044223+danielenricocahall@users.noreply.github.com> Date: Sun, 12 Oct 2025 22:24:35 -0400 Subject: [PATCH 2/8] Update keras/src/ops/numpy_test.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- keras/src/ops/numpy_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index 47c4f00e115a..233d9f682791 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -5238,6 +5238,9 @@ def test_eye_raises_error_with_floats(self): with self.assertRaises(ValueError): v = knp.max(knp.arange(4.0)) knp.eye(v) + if backend.backend() != "numpy": + with self.assertRaises(ValueError): + knp.eye(knp.array(3, dtype="bfloat16")) def test_arange(self): self.assertAllClose(knp.arange(3), np.arange(3)) From d8eac1281e170baaf7c79f3aa8a4285603ab523d Mon Sep 17 00:00:00 2001 From: "daniel.cahall" Date: Sun, 12 Oct 2025 22:25:46 -0400 Subject: [PATCH 3/8] simplify per pr review --- keras/src/ops/numpy.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/keras/src/ops/numpy.py b/keras/src/ops/numpy.py index 5c7434eff135..3d418dd32e86 100644 --- a/keras/src/ops/numpy.py +++ b/keras/src/ops/numpy.py @@ -7229,8 +7229,7 @@ def eye(N, M=None, k=0, dtype=None): Tensor with ones on the k-th diagonal and zeros elsewhere. """ def is_float(v): - if isinstance(v, float) or getattr(v, "dtype", None) in ("float16", "float32", "float64"): - return True + return isinstance(v, float) or getattr(v, "dtype", None) in dtypes.FLOAT_TYPES if is_float(N): raise ValueError("Argument `N` must be an integer or an integer tensor.") if is_float(M): From d4369f45dd55c45406bf1416e99b4955741bc4d0 Mon Sep 17 00:00:00 2001 From: "daniel.cahall" Date: Sun, 12 Oct 2025 22:50:57 -0400 Subject: [PATCH 4/8] pre-commit --- keras/src/ops/numpy.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/keras/src/ops/numpy.py b/keras/src/ops/numpy.py index 3d418dd32e86..9906a0a28eb7 100644 --- a/keras/src/ops/numpy.py +++ b/keras/src/ops/numpy.py @@ -7228,12 +7228,21 @@ def eye(N, M=None, k=0, dtype=None): Returns: Tensor with ones on the k-th diagonal and zeros elsewhere. """ + def is_float(v): - return isinstance(v, float) or getattr(v, "dtype", None) in dtypes.FLOAT_TYPES + return ( + isinstance(v, float) + or getattr(v, "dtype", None) in dtypes.FLOAT_TYPES + ) + if is_float(N): - raise ValueError("Argument `N` must be an integer or an integer tensor.") + raise ValueError( + "Argument `N` must be an integer or an integer tensor." + ) if is_float(M): - raise ValueError("Argument `M` must be an integer, an integer tensor, or `None`.") + raise ValueError( + "Argument `M` must be an integer, an integer tensor, or `None`." + ) return backend.numpy.eye(N, M=M, k=k, dtype=dtype) From e4922c1b6c8792bd5795f119f66a7a4ea62f44f1 Mon Sep 17 00:00:00 2001 From: "daniel.cahall" Date: Sun, 12 Oct 2025 23:03:19 -0400 Subject: [PATCH 5/8] fix test for torch backend + add comments --- keras/src/ops/numpy_test.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index 233d9f682791..33541ae13510 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -5235,10 +5235,16 @@ def test_eye_raises_error_with_floats(self): knp.eye(3.0) with self.assertRaises(ValueError): knp.eye(3.0, 2.0) - with self.assertRaises(ValueError): - v = knp.max(knp.arange(4.0)) - knp.eye(v) - if backend.backend() != "numpy": + + # Note: Torch raises a TypeError here, as it does not permit Tensor args + # with torch.eye. However, np.eye and tf.eye do support these, and + # per the thread in https://github.com/keras-team/keras/issues/20616, + # we will use np.eye as the guide + if backend.backend() != "torch": + with self.assertRaises((ValueError, TypeError)): + v = knp.max(knp.arange(4.0)) + knp.eye(v) + if backend.backend() not in ("numpy", "torch"): with self.assertRaises(ValueError): knp.eye(knp.array(3, dtype="bfloat16")) From afab04113491924918f9a04b2476a8b40e560377 Mon Sep 17 00:00:00 2001 From: "daniel.cahall" Date: Mon, 13 Oct 2025 07:51:28 -0400 Subject: [PATCH 6/8] update implementation to raise TypeError for consistency --- keras/src/ops/numpy.py | 6 ++---- keras/src/ops/numpy_test.py | 21 +++++++-------------- 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/keras/src/ops/numpy.py b/keras/src/ops/numpy.py index 9906a0a28eb7..2b8c446ccb9b 100644 --- a/keras/src/ops/numpy.py +++ b/keras/src/ops/numpy.py @@ -7236,11 +7236,9 @@ def is_float(v): ) if is_float(N): - raise ValueError( - "Argument `N` must be an integer or an integer tensor." - ) + raise TypeError("Argument `N` must be an integer or an integer tensor.") if is_float(M): - raise ValueError( + raise TypeError( "Argument `M` must be an integer, an integer tensor, or `None`." ) return backend.numpy.eye(N, M=M, k=k, dtype=dtype) diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index 33541ae13510..e23e4bdc4054 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -5231,22 +5231,15 @@ def test_eye(self): self.assertAllClose(knp.eye(4, 3, k=-2), np.eye(4, 3, k=-2)) def test_eye_raises_error_with_floats(self): - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): knp.eye(3.0) - with self.assertRaises(ValueError): + with self.assertRaises(TypeError): knp.eye(3.0, 2.0) - - # Note: Torch raises a TypeError here, as it does not permit Tensor args - # with torch.eye. However, np.eye and tf.eye do support these, and - # per the thread in https://github.com/keras-team/keras/issues/20616, - # we will use np.eye as the guide - if backend.backend() != "torch": - with self.assertRaises((ValueError, TypeError)): - v = knp.max(knp.arange(4.0)) - knp.eye(v) - if backend.backend() not in ("numpy", "torch"): - with self.assertRaises(ValueError): - knp.eye(knp.array(3, dtype="bfloat16")) + with self.assertRaises(TypeError): + v = knp.max(knp.arange(4.0)) + knp.eye(v) + with self.assertRaises(TypeError): + knp.eye(knp.array(3, dtype="bfloat16")) def test_arange(self): self.assertAllClose(knp.arange(3), np.arange(3)) From c7453bdd07e140beb0870f7a282ad87d6abed108 Mon Sep 17 00:00:00 2001 From: "daniel.cahall" Date: Mon, 13 Oct 2025 07:54:24 -0400 Subject: [PATCH 7/8] add case for M being the onl float --- keras/src/ops/numpy_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index e23e4bdc4054..dee805a85361 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -5235,6 +5235,8 @@ def test_eye_raises_error_with_floats(self): knp.eye(3.0) with self.assertRaises(TypeError): knp.eye(3.0, 2.0) + with self.assertRaises(TypeError): + knp.eye(3, 2.0) with self.assertRaises(TypeError): v = knp.max(knp.arange(4.0)) knp.eye(v) From 392d05cb637a32115d3211675989334bcfd95f7f Mon Sep 17 00:00:00 2001 From: "daniel.cahall" Date: Mon, 13 Oct 2025 07:55:03 -0400 Subject: [PATCH 8/8] improve naming of inner function for type check --- keras/src/ops/numpy.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/keras/src/ops/numpy.py b/keras/src/ops/numpy.py index 2b8c446ccb9b..f732c1fcb161 100644 --- a/keras/src/ops/numpy.py +++ b/keras/src/ops/numpy.py @@ -7229,15 +7229,15 @@ def eye(N, M=None, k=0, dtype=None): Tensor with ones on the k-th diagonal and zeros elsewhere. """ - def is_float(v): + def is_floating_type(v): return ( isinstance(v, float) or getattr(v, "dtype", None) in dtypes.FLOAT_TYPES ) - if is_float(N): + if is_floating_type(N): raise TypeError("Argument `N` must be an integer or an integer tensor.") - if is_float(M): + if is_floating_type(M): raise TypeError( "Argument `M` must be an integer, an integer tensor, or `None`." )