From 8d5a0bb9f2cc461a3d1267575c9e4485f8f16d71 Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Mon, 13 Oct 2025 20:27:29 +0900 Subject: [PATCH 1/8] Add initial version --- keras/src/backend/jax/numpy.py | 5 +++ keras/src/backend/numpy/numpy.py | 5 +++ keras/src/backend/tensorflow/numpy.py | 45 +++++++++++++++++++++++++++ keras/src/backend/torch/numpy.py | 5 +++ 4 files changed, 60 insertions(+) diff --git a/keras/src/backend/jax/numpy.py b/keras/src/backend/jax/numpy.py index e9def4b255c9..e7fce13382fa 100644 --- a/keras/src/backend/jax/numpy.py +++ b/keras/src/backend/jax/numpy.py @@ -819,6 +819,11 @@ def isposinf(x): return jnp.isposinf(x) +def isreal(x): + x = convert_to_tensor(x) + return jnp.isreal(x) + + def kron(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) diff --git a/keras/src/backend/numpy/numpy.py b/keras/src/backend/numpy/numpy.py index d8d4b8930341..f37aac255aeb 100644 --- a/keras/src/backend/numpy/numpy.py +++ b/keras/src/backend/numpy/numpy.py @@ -745,6 +745,11 @@ def isposinf(x): return np.isposinf(x) +def isreal(x): + x = convert_to_tensor(x) + return np.isreal(x) + + def kron(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py index ff146a41253b..eafa666d1e97 100644 --- a/keras/src/backend/tensorflow/numpy.py +++ b/keras/src/backend/tensorflow/numpy.py @@ -2898,6 +2898,51 @@ def true_divide(x1, x2): return divide(x1, x2) +def poly(x): + x = convert_to_tensor(x) + if x.dtype.is_integer or x.dtype.is_bool: + x = tf.cast(x, tf.float32) + + # Get rank and shape + rank = tf.rank(x) + + # Handle square 2D matrix + def matrix_case(): + eigvals, _ = tf.linalg.eig(tf.cast(x, tf.complex64)) + return eigvals + + # Handle 1D vector + def vector_case(): + return tf.reshape(x, [-1]) + + # Safe check: is x a square 2D matrix? + is_square_matrix = tf.logical_and( + tf.equal(rank, 2), tf.equal(tf.shape(x)[0], tf.shape(x)[1]) + ) + + # Conditionally choose + x_vec = tf.cond(is_square_matrix, true_fn=matrix_case, false_fn=vector_case) + + # If empty, return [1] + if tf.size(x_vec) == 0: + return tf.ones((1,), dtype=x_vec.dtype) + + # Iteratively build polynomial coefficients via convolution + a = tf.ones((1,), dtype=x_vec.dtype) + for k in range(tf.shape(x_vec)[0]): + root = x_vec[k] + conv_kernel = tf.stack([1.0, -root], axis=0) + # 1D convolution requires 3D tensors + a = tf.nn.convolution( + tf.reshape(a, [1, -1, 1]), + tf.reshape(conv_kernel, [-1, 1, 1]), + padding="VALID", + ) + a = tf.reshape(a, [-1]) + + return a + + def power(x1, x2): if not isinstance(x1, (int, float)): x1 = convert_to_tensor(x1) diff --git a/keras/src/backend/torch/numpy.py b/keras/src/backend/torch/numpy.py index 553faea4fd40..fa7ee53e3bd4 100644 --- a/keras/src/backend/torch/numpy.py +++ b/keras/src/backend/torch/numpy.py @@ -946,6 +946,11 @@ def isposinf(x): return torch.isposinf(x) +def isreal(x): + x = convert_to_tensor(x) + return torch.isreal(x) + + def kron(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) From f94b4e74c9e862bc3e27de00fbf426ced7141b8b Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Mon, 13 Oct 2025 20:30:08 +0900 Subject: [PATCH 2/8] Add tensorflow version --- keras/src/backend/tensorflow/numpy.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py index eafa666d1e97..ffe43ce9d692 100644 --- a/keras/src/backend/tensorflow/numpy.py +++ b/keras/src/backend/tensorflow/numpy.py @@ -1712,6 +1712,14 @@ def isposinf(x): return tf.math.equal(x, tf.constant(float("inf"), dtype=x.dtype)) +def isreal(x): + x = convert_to_tensor(x) + if x.dtype.is_complex: + return tf.equal(tf.math.imag(x), 0) + else: + return tf.ones_like(x, dtype=tf.bool) + + def kron(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) From 5e3d2701e2d4de8cc0b7f06cae5628191e8045af Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Mon, 13 Oct 2025 20:32:22 +0900 Subject: [PATCH 3/8] Update numpy.py for ops --- keras/api/_tf_keras/keras/ops/__init__.py | 1 + .../api/_tf_keras/keras/ops/numpy/__init__.py | 1 + keras/api/ops/__init__.py | 1 + keras/api/ops/numpy/__init__.py | 1 + keras/src/ops/numpy.py | 23 +++++++++++++++++++ 5 files changed, 27 insertions(+) diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index 2194c975b89f..5e19b0654228 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -209,6 +209,7 @@ from keras.src.ops.numpy import isnan as isnan from keras.src.ops.numpy import isneginf as isneginf from keras.src.ops.numpy import isposinf as isposinf +from keras.src.ops.numpy import isreal as isreal from keras.src.ops.numpy import kaiser as kaiser from keras.src.ops.numpy import kron as kron from keras.src.ops.numpy import lcm as lcm diff --git a/keras/api/_tf_keras/keras/ops/numpy/__init__.py b/keras/api/_tf_keras/keras/ops/numpy/__init__.py index ebeb384c181c..d8a24a28873b 100644 --- a/keras/api/_tf_keras/keras/ops/numpy/__init__.py +++ b/keras/api/_tf_keras/keras/ops/numpy/__init__.py @@ -95,6 +95,7 @@ from keras.src.ops.numpy import isnan as isnan from keras.src.ops.numpy import isneginf as isneginf from keras.src.ops.numpy import isposinf as isposinf +from keras.src.ops.numpy import isreal as isreal from keras.src.ops.numpy import kaiser as kaiser from keras.src.ops.numpy import kron as kron from keras.src.ops.numpy import lcm as lcm diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index 2194c975b89f..5e19b0654228 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -209,6 +209,7 @@ from keras.src.ops.numpy import isnan as isnan from keras.src.ops.numpy import isneginf as isneginf from keras.src.ops.numpy import isposinf as isposinf +from keras.src.ops.numpy import isreal as isreal from keras.src.ops.numpy import kaiser as kaiser from keras.src.ops.numpy import kron as kron from keras.src.ops.numpy import lcm as lcm diff --git a/keras/api/ops/numpy/__init__.py b/keras/api/ops/numpy/__init__.py index ebeb384c181c..d8a24a28873b 100644 --- a/keras/api/ops/numpy/__init__.py +++ b/keras/api/ops/numpy/__init__.py @@ -95,6 +95,7 @@ from keras.src.ops.numpy import isnan as isnan from keras.src.ops.numpy import isneginf as isneginf from keras.src.ops.numpy import isposinf as isposinf +from keras.src.ops.numpy import isreal as isreal from keras.src.ops.numpy import kaiser as kaiser from keras.src.ops.numpy import kron as kron from keras.src.ops.numpy import lcm as lcm diff --git a/keras/src/ops/numpy.py b/keras/src/ops/numpy.py index cbc07c9c3e3c..cb6845bbfb00 100644 --- a/keras/src/ops/numpy.py +++ b/keras/src/ops/numpy.py @@ -3846,6 +3846,29 @@ def isposinf(x): return backend.numpy.isposinf(x) +class Isreal(Operation): + def call(self, x): + return backend.numpy.isreal(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype="bool") + + +@keras_export(["keras.ops.isreal", "keras.ops.numpy.isreal"]) +def isreal(x): + """Test element-wise for real numbers. + + Args: + x: Input tensor. + + Returns: + Output boolean tensor. + """ + if any_symbolic_tensors((x,)): + return Isreal().symbolic_call(x) + return backend.numpy.isreal(x) + + class Kron(Operation): def call(self, x1, x2): return backend.numpy.kron(x1, x2) From d3bf402aa367293ce12bc4e1274b55ec7613e9d2 Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Mon, 13 Oct 2025 20:39:30 +0900 Subject: [PATCH 4/8] Update numpy_test.py --- keras/src/ops/numpy_test.py | 27 +++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index 998d18bd4b73..079802576bff 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -1588,6 +1588,10 @@ def test_isposinf(self): x = KerasTensor((None, 3)) self.assertEqual(knp.isposinf(x).shape, (None, 3)) + def test_isreal(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.isreal(x).shape, (None, 3)) + def test_log(self): x = KerasTensor((None, 3)) self.assertEqual(knp.log(x).shape, (None, 3)) @@ -2189,9 +2193,9 @@ def test_isneginf(self): x = KerasTensor((2, 3)) self.assertEqual(knp.isneginf(x).shape, (2, 3)) - def test_isposinf(self): + def test_isreal(self): x = KerasTensor((2, 3)) - self.assertEqual(knp.isposinf(x).shape, (2, 3)) + self.assertEqual(knp.isreal(x).shape, (2, 3)) def test_log(self): x = KerasTensor((2, 3)) @@ -4389,6 +4393,11 @@ def test_isposinf(self): self.assertAllClose(knp.isposinf(x), np.isposinf(x)) self.assertAllClose(knp.Isposinf()(x), np.isposinf(x)) + def test_isreal(self): + x = np.array([1 + 1j, 1 + 0j, 4.5, 3, 2, 2j], dtype=complex) + self.assertAllClose(knp.isreal(x), np.isreal(x)) + self.assertAllClose(knp.Isreal()(x), np.isreal(x)) + def test_log(self): x = np.array([[1, 2, 3], [3, 2, 1]]) self.assertAllClose(knp.log(x), np.log(x)) @@ -7766,6 +7775,20 @@ def test_isposinf(self, dtype): expected_dtype, ) + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_isreal(self, dtype): + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnp.isreal(x_jax).dtype) + + self.assertEqual(standardize_dtype(knp.isreal(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Isreal().symbolic_call(x).dtype), + expected_dtype, + ) + @parameterized.named_parameters( named_product(dtypes=itertools.combinations(INT_DTYPES, 2)) ) From 0ba53b79071677bc8b3e5649e472ea298e4a22ff Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Mon, 13 Oct 2025 20:45:40 +0900 Subject: [PATCH 5/8] Add method for openvino --- keras/src/backend/openvino/excluded_concrete_tests.txt | 4 ++++ keras/src/backend/openvino/numpy.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/keras/src/backend/openvino/excluded_concrete_tests.txt b/keras/src/backend/openvino/excluded_concrete_tests.txt index 13bae27343d5..df9470f9977a 100644 --- a/keras/src/backend/openvino/excluded_concrete_tests.txt +++ b/keras/src/backend/openvino/excluded_concrete_tests.txt @@ -37,6 +37,7 @@ NumpyDtypeTest::test_isin NumpyDtypeTest::test_isinf NumpyDtypeTest::test_isnan NumpyDtypeTest::test_isposinf +NumpyDtypeTest::test_isreal NumpyDtypeTest::test_kron NumpyDtypeTest::test_lcm NumpyDtypeTest::test_logaddexp2 @@ -92,6 +93,7 @@ NumpyOneInputOpsCorrectnessTest::test_imag NumpyOneInputOpsCorrectnessTest::test_isfinite NumpyOneInputOpsCorrectnessTest::test_isinf NumpyOneInputOpsCorrectnessTest::test_isposinf +NumpyOneInputOpsCorrectnessTest::test_isreal NumpyOneInputOpsCorrectnessTest::test_logaddexp2 NumpyOneInputOpsCorrectnessTest::test_max NumpyOneInputOpsCorrectnessTest::test_mean @@ -151,10 +153,12 @@ NumpyOneInputOpsDynamicShapeTest::test_corrcoef NumpyOneInputOpsDynamicShapeTest::test_hamming NumpyOneInputOpsDynamicShapeTest::test_hanning NumpyOneInputOpsDynamicShapeTest::test_isposinf +NumpyOneInputOpsDynamicShapeTest::test_isreal NumpyOneInputOpsDynamicShapeTest::test_kaiser NumpyOneInputOpsStaticShapeTest::test_angle NumpyOneInputOpsStaticShapeTest::test_cbrt NumpyOneInputOpsStaticShapeTest::test_isposinf +NumpyOneInputOpsStaticShapeTest::test_isreal NumpyTwoInputOpsDynamicShapeTest::test_gcd NumpyTwoInputOpsDynamicShapeTest::test_heaviside NumpyTwoInputOpsDynamicShapeTest::test_hypot diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index c750d409d4a0..d95e0e9f4f6c 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -1003,6 +1003,10 @@ def isposinf(x): ) +def isreal(x): + raise NotImplementedError("`isreal` is not supported with openvino backend") + + def kron(x1, x2): raise NotImplementedError("`kron` is not supported with openvino backend") From 0f9ef917a19903c0f590ca981a768cb03cdd578c Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Mon, 13 Oct 2025 20:47:12 +0900 Subject: [PATCH 6/8] clean code --- keras/src/backend/tensorflow/numpy.py | 45 --------------------------- 1 file changed, 45 deletions(-) diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py index ffe43ce9d692..9cda2d0d8679 100644 --- a/keras/src/backend/tensorflow/numpy.py +++ b/keras/src/backend/tensorflow/numpy.py @@ -2906,51 +2906,6 @@ def true_divide(x1, x2): return divide(x1, x2) -def poly(x): - x = convert_to_tensor(x) - if x.dtype.is_integer or x.dtype.is_bool: - x = tf.cast(x, tf.float32) - - # Get rank and shape - rank = tf.rank(x) - - # Handle square 2D matrix - def matrix_case(): - eigvals, _ = tf.linalg.eig(tf.cast(x, tf.complex64)) - return eigvals - - # Handle 1D vector - def vector_case(): - return tf.reshape(x, [-1]) - - # Safe check: is x a square 2D matrix? - is_square_matrix = tf.logical_and( - tf.equal(rank, 2), tf.equal(tf.shape(x)[0], tf.shape(x)[1]) - ) - - # Conditionally choose - x_vec = tf.cond(is_square_matrix, true_fn=matrix_case, false_fn=vector_case) - - # If empty, return [1] - if tf.size(x_vec) == 0: - return tf.ones((1,), dtype=x_vec.dtype) - - # Iteratively build polynomial coefficients via convolution - a = tf.ones((1,), dtype=x_vec.dtype) - for k in range(tf.shape(x_vec)[0]): - root = x_vec[k] - conv_kernel = tf.stack([1.0, -root], axis=0) - # 1D convolution requires 3D tensors - a = tf.nn.convolution( - tf.reshape(a, [1, -1, 1]), - tf.reshape(conv_kernel, [-1, 1, 1]), - padding="VALID", - ) - a = tf.reshape(a, [-1]) - - return a - - def power(x1, x2): if not isinstance(x1, (int, float)): x1 = convert_to_tensor(x1) From 03dabf16a76ba9e376295906a0c277e9336aca50 Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Mon, 13 Oct 2025 20:51:31 +0900 Subject: [PATCH 7/8] update code by gemini review --- keras/src/ops/numpy.py | 6 ++++++ keras/src/ops/numpy_test.py | 4 ++++ 2 files changed, 10 insertions(+) diff --git a/keras/src/ops/numpy.py b/keras/src/ops/numpy.py index cb6845bbfb00..2569492cef7f 100644 --- a/keras/src/ops/numpy.py +++ b/keras/src/ops/numpy.py @@ -3863,6 +3863,12 @@ def isreal(x): Returns: Output boolean tensor. + + Example: + >>> from keras import ops + >>> x = ops.array([1+1j, 1+0j, 4.5, 3, 2, 2j], dtype="complex64") + >>> ops.isreal(x) + array([False, True, True, True, True, False]) """ if any_symbolic_tensors((x,)): return Isreal().symbolic_call(x) diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index 079802576bff..6832ae8106e1 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -2193,6 +2193,10 @@ def test_isneginf(self): x = KerasTensor((2, 3)) self.assertEqual(knp.isneginf(x).shape, (2, 3)) + def test_isposinf(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.isposinf(x).shape, (2, 3)) + def test_isreal(self): x = KerasTensor((2, 3)) self.assertEqual(knp.isreal(x).shape, (2, 3)) From 0e5ce378d7de816774f6635774612b295e03872d Mon Sep 17 00:00:00 2001 From: ugeunpark Date: Wed, 15 Oct 2025 18:47:36 +0900 Subject: [PATCH 8/8] update test case for non-complex type --- keras/src/ops/numpy_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index 6832ae8106e1..413607d16b5b 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -4402,6 +4402,10 @@ def test_isreal(self): self.assertAllClose(knp.isreal(x), np.isreal(x)) self.assertAllClose(knp.Isreal()(x), np.isreal(x)) + x = np.array([1.0, 2.0, 3.0]) + self.assertAllClose(knp.isreal(x), np.isreal(x)) + self.assertAllClose(knp.Isreal()(x), np.isreal(x)) + def test_log(self): x = np.array([[1, 2, 3], [3, 2, 1]]) self.assertAllClose(knp.log(x), np.log(x))