diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index 2194c975b89f..5e19b0654228 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -209,6 +209,7 @@ from keras.src.ops.numpy import isnan as isnan from keras.src.ops.numpy import isneginf as isneginf from keras.src.ops.numpy import isposinf as isposinf +from keras.src.ops.numpy import isreal as isreal from keras.src.ops.numpy import kaiser as kaiser from keras.src.ops.numpy import kron as kron from keras.src.ops.numpy import lcm as lcm diff --git a/keras/api/_tf_keras/keras/ops/numpy/__init__.py b/keras/api/_tf_keras/keras/ops/numpy/__init__.py index ebeb384c181c..d8a24a28873b 100644 --- a/keras/api/_tf_keras/keras/ops/numpy/__init__.py +++ b/keras/api/_tf_keras/keras/ops/numpy/__init__.py @@ -95,6 +95,7 @@ from keras.src.ops.numpy import isnan as isnan from keras.src.ops.numpy import isneginf as isneginf from keras.src.ops.numpy import isposinf as isposinf +from keras.src.ops.numpy import isreal as isreal from keras.src.ops.numpy import kaiser as kaiser from keras.src.ops.numpy import kron as kron from keras.src.ops.numpy import lcm as lcm diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index 2194c975b89f..5e19b0654228 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -209,6 +209,7 @@ from keras.src.ops.numpy import isnan as isnan from keras.src.ops.numpy import isneginf as isneginf from keras.src.ops.numpy import isposinf as isposinf +from keras.src.ops.numpy import isreal as isreal from keras.src.ops.numpy import kaiser as kaiser from keras.src.ops.numpy import kron as kron from keras.src.ops.numpy import lcm as lcm diff --git a/keras/api/ops/numpy/__init__.py b/keras/api/ops/numpy/__init__.py index ebeb384c181c..d8a24a28873b 100644 --- a/keras/api/ops/numpy/__init__.py +++ b/keras/api/ops/numpy/__init__.py @@ -95,6 +95,7 @@ from keras.src.ops.numpy import isnan as isnan from keras.src.ops.numpy import isneginf as isneginf from keras.src.ops.numpy import isposinf as isposinf +from keras.src.ops.numpy import isreal as isreal from keras.src.ops.numpy import kaiser as kaiser from keras.src.ops.numpy import kron as kron from keras.src.ops.numpy import lcm as lcm diff --git a/keras/src/backend/jax/numpy.py b/keras/src/backend/jax/numpy.py index e9def4b255c9..e7fce13382fa 100644 --- a/keras/src/backend/jax/numpy.py +++ b/keras/src/backend/jax/numpy.py @@ -819,6 +819,11 @@ def isposinf(x): return jnp.isposinf(x) +def isreal(x): + x = convert_to_tensor(x) + return jnp.isreal(x) + + def kron(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) diff --git a/keras/src/backend/numpy/numpy.py b/keras/src/backend/numpy/numpy.py index d8d4b8930341..f37aac255aeb 100644 --- a/keras/src/backend/numpy/numpy.py +++ b/keras/src/backend/numpy/numpy.py @@ -745,6 +745,11 @@ def isposinf(x): return np.isposinf(x) +def isreal(x): + x = convert_to_tensor(x) + return np.isreal(x) + + def kron(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) diff --git a/keras/src/backend/openvino/excluded_concrete_tests.txt b/keras/src/backend/openvino/excluded_concrete_tests.txt index 13bae27343d5..df9470f9977a 100644 --- a/keras/src/backend/openvino/excluded_concrete_tests.txt +++ b/keras/src/backend/openvino/excluded_concrete_tests.txt @@ -37,6 +37,7 @@ NumpyDtypeTest::test_isin NumpyDtypeTest::test_isinf NumpyDtypeTest::test_isnan NumpyDtypeTest::test_isposinf +NumpyDtypeTest::test_isreal NumpyDtypeTest::test_kron NumpyDtypeTest::test_lcm NumpyDtypeTest::test_logaddexp2 @@ -92,6 +93,7 @@ NumpyOneInputOpsCorrectnessTest::test_imag NumpyOneInputOpsCorrectnessTest::test_isfinite NumpyOneInputOpsCorrectnessTest::test_isinf NumpyOneInputOpsCorrectnessTest::test_isposinf +NumpyOneInputOpsCorrectnessTest::test_isreal NumpyOneInputOpsCorrectnessTest::test_logaddexp2 NumpyOneInputOpsCorrectnessTest::test_max NumpyOneInputOpsCorrectnessTest::test_mean @@ -151,10 +153,12 @@ NumpyOneInputOpsDynamicShapeTest::test_corrcoef NumpyOneInputOpsDynamicShapeTest::test_hamming NumpyOneInputOpsDynamicShapeTest::test_hanning NumpyOneInputOpsDynamicShapeTest::test_isposinf +NumpyOneInputOpsDynamicShapeTest::test_isreal NumpyOneInputOpsDynamicShapeTest::test_kaiser NumpyOneInputOpsStaticShapeTest::test_angle NumpyOneInputOpsStaticShapeTest::test_cbrt NumpyOneInputOpsStaticShapeTest::test_isposinf +NumpyOneInputOpsStaticShapeTest::test_isreal NumpyTwoInputOpsDynamicShapeTest::test_gcd NumpyTwoInputOpsDynamicShapeTest::test_heaviside NumpyTwoInputOpsDynamicShapeTest::test_hypot diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index c750d409d4a0..d95e0e9f4f6c 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -1003,6 +1003,10 @@ def isposinf(x): ) +def isreal(x): + raise NotImplementedError("`isreal` is not supported with openvino backend") + + def kron(x1, x2): raise NotImplementedError("`kron` is not supported with openvino backend") diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py index ff146a41253b..9cda2d0d8679 100644 --- a/keras/src/backend/tensorflow/numpy.py +++ b/keras/src/backend/tensorflow/numpy.py @@ -1712,6 +1712,14 @@ def isposinf(x): return tf.math.equal(x, tf.constant(float("inf"), dtype=x.dtype)) +def isreal(x): + x = convert_to_tensor(x) + if x.dtype.is_complex: + return tf.equal(tf.math.imag(x), 0) + else: + return tf.ones_like(x, dtype=tf.bool) + + def kron(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) diff --git a/keras/src/backend/torch/numpy.py b/keras/src/backend/torch/numpy.py index 553faea4fd40..fa7ee53e3bd4 100644 --- a/keras/src/backend/torch/numpy.py +++ b/keras/src/backend/torch/numpy.py @@ -946,6 +946,11 @@ def isposinf(x): return torch.isposinf(x) +def isreal(x): + x = convert_to_tensor(x) + return torch.isreal(x) + + def kron(x1, x2): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) diff --git a/keras/src/ops/numpy.py b/keras/src/ops/numpy.py index cbc07c9c3e3c..2569492cef7f 100644 --- a/keras/src/ops/numpy.py +++ b/keras/src/ops/numpy.py @@ -3846,6 +3846,35 @@ def isposinf(x): return backend.numpy.isposinf(x) +class Isreal(Operation): + def call(self, x): + return backend.numpy.isreal(x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype="bool") + + +@keras_export(["keras.ops.isreal", "keras.ops.numpy.isreal"]) +def isreal(x): + """Test element-wise for real numbers. + + Args: + x: Input tensor. + + Returns: + Output boolean tensor. + + Example: + >>> from keras import ops + >>> x = ops.array([1+1j, 1+0j, 4.5, 3, 2, 2j], dtype="complex64") + >>> ops.isreal(x) + array([False, True, True, True, True, False]) + """ + if any_symbolic_tensors((x,)): + return Isreal().symbolic_call(x) + return backend.numpy.isreal(x) + + class Kron(Operation): def call(self, x1, x2): return backend.numpy.kron(x1, x2) diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index 998d18bd4b73..413607d16b5b 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -1588,6 +1588,10 @@ def test_isposinf(self): x = KerasTensor((None, 3)) self.assertEqual(knp.isposinf(x).shape, (None, 3)) + def test_isreal(self): + x = KerasTensor((None, 3)) + self.assertEqual(knp.isreal(x).shape, (None, 3)) + def test_log(self): x = KerasTensor((None, 3)) self.assertEqual(knp.log(x).shape, (None, 3)) @@ -2193,6 +2197,10 @@ def test_isposinf(self): x = KerasTensor((2, 3)) self.assertEqual(knp.isposinf(x).shape, (2, 3)) + def test_isreal(self): + x = KerasTensor((2, 3)) + self.assertEqual(knp.isreal(x).shape, (2, 3)) + def test_log(self): x = KerasTensor((2, 3)) self.assertEqual(knp.log(x).shape, (2, 3)) @@ -4389,6 +4397,15 @@ def test_isposinf(self): self.assertAllClose(knp.isposinf(x), np.isposinf(x)) self.assertAllClose(knp.Isposinf()(x), np.isposinf(x)) + def test_isreal(self): + x = np.array([1 + 1j, 1 + 0j, 4.5, 3, 2, 2j], dtype=complex) + self.assertAllClose(knp.isreal(x), np.isreal(x)) + self.assertAllClose(knp.Isreal()(x), np.isreal(x)) + + x = np.array([1.0, 2.0, 3.0]) + self.assertAllClose(knp.isreal(x), np.isreal(x)) + self.assertAllClose(knp.Isreal()(x), np.isreal(x)) + def test_log(self): x = np.array([[1, 2, 3], [3, 2, 1]]) self.assertAllClose(knp.log(x), np.log(x)) @@ -7766,6 +7783,20 @@ def test_isposinf(self, dtype): expected_dtype, ) + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_isreal(self, dtype): + import jax.numpy as jnp + + x = knp.ones((), dtype=dtype) + x_jax = jnp.ones((), dtype=dtype) + expected_dtype = standardize_dtype(jnp.isreal(x_jax).dtype) + + self.assertEqual(standardize_dtype(knp.isreal(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Isreal().symbolic_call(x).dtype), + expected_dtype, + ) + @parameterized.named_parameters( named_product(dtypes=itertools.combinations(INT_DTYPES, 2)) )