Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@
from keras.src.ops.numpy import isnan as isnan
from keras.src.ops.numpy import isneginf as isneginf
from keras.src.ops.numpy import isposinf as isposinf
from keras.src.ops.numpy import isreal as isreal
from keras.src.ops.numpy import kaiser as kaiser
from keras.src.ops.numpy import kron as kron
from keras.src.ops.numpy import lcm as lcm
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
from keras.src.ops.numpy import isnan as isnan
from keras.src.ops.numpy import isneginf as isneginf
from keras.src.ops.numpy import isposinf as isposinf
from keras.src.ops.numpy import isreal as isreal
from keras.src.ops.numpy import kaiser as kaiser
from keras.src.ops.numpy import kron as kron
from keras.src.ops.numpy import lcm as lcm
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@
from keras.src.ops.numpy import isnan as isnan
from keras.src.ops.numpy import isneginf as isneginf
from keras.src.ops.numpy import isposinf as isposinf
from keras.src.ops.numpy import isreal as isreal
from keras.src.ops.numpy import kaiser as kaiser
from keras.src.ops.numpy import kron as kron
from keras.src.ops.numpy import lcm as lcm
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
from keras.src.ops.numpy import isnan as isnan
from keras.src.ops.numpy import isneginf as isneginf
from keras.src.ops.numpy import isposinf as isposinf
from keras.src.ops.numpy import isreal as isreal
from keras.src.ops.numpy import kaiser as kaiser
from keras.src.ops.numpy import kron as kron
from keras.src.ops.numpy import lcm as lcm
Expand Down
5 changes: 5 additions & 0 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,11 @@ def isposinf(x):
return jnp.isposinf(x)


def isreal(x):
x = convert_to_tensor(x)
return jnp.isreal(x)


def kron(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
Expand Down
5 changes: 5 additions & 0 deletions keras/src/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,11 @@ def isposinf(x):
return np.isposinf(x)


def isreal(x):
x = convert_to_tensor(x)
return np.isreal(x)


def kron(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
Expand Down
4 changes: 4 additions & 0 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ NumpyDtypeTest::test_isin
NumpyDtypeTest::test_isinf
NumpyDtypeTest::test_isnan
NumpyDtypeTest::test_isposinf
NumpyDtypeTest::test_isreal
NumpyDtypeTest::test_kron
NumpyDtypeTest::test_lcm
NumpyDtypeTest::test_logaddexp2
Expand Down Expand Up @@ -92,6 +93,7 @@ NumpyOneInputOpsCorrectnessTest::test_imag
NumpyOneInputOpsCorrectnessTest::test_isfinite
NumpyOneInputOpsCorrectnessTest::test_isinf
NumpyOneInputOpsCorrectnessTest::test_isposinf
NumpyOneInputOpsCorrectnessTest::test_isreal
NumpyOneInputOpsCorrectnessTest::test_logaddexp2
NumpyOneInputOpsCorrectnessTest::test_max
NumpyOneInputOpsCorrectnessTest::test_mean
Expand Down Expand Up @@ -151,10 +153,12 @@ NumpyOneInputOpsDynamicShapeTest::test_corrcoef
NumpyOneInputOpsDynamicShapeTest::test_hamming
NumpyOneInputOpsDynamicShapeTest::test_hanning
NumpyOneInputOpsDynamicShapeTest::test_isposinf
NumpyOneInputOpsDynamicShapeTest::test_isreal
NumpyOneInputOpsDynamicShapeTest::test_kaiser
NumpyOneInputOpsStaticShapeTest::test_angle
NumpyOneInputOpsStaticShapeTest::test_cbrt
NumpyOneInputOpsStaticShapeTest::test_isposinf
NumpyOneInputOpsStaticShapeTest::test_isreal
NumpyTwoInputOpsDynamicShapeTest::test_gcd
NumpyTwoInputOpsDynamicShapeTest::test_heaviside
NumpyTwoInputOpsDynamicShapeTest::test_hypot
Expand Down
4 changes: 4 additions & 0 deletions keras/src/backend/openvino/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,6 +1003,10 @@ def isposinf(x):
)


def isreal(x):
raise NotImplementedError("`isreal` is not supported with openvino backend")


def kron(x1, x2):
raise NotImplementedError("`kron` is not supported with openvino backend")

Expand Down
53 changes: 53 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1712,6 +1712,14 @@ def isposinf(x):
return tf.math.equal(x, tf.constant(float("inf"), dtype=x.dtype))


def isreal(x):
x = convert_to_tensor(x)
if x.dtype.is_complex:
return tf.equal(tf.math.imag(x), 0)
else:
return tf.ones_like(x, dtype=tf.bool)


def kron(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
Expand Down Expand Up @@ -2898,6 +2906,51 @@ def true_divide(x1, x2):
return divide(x1, x2)


def poly(x):
x = convert_to_tensor(x)
if x.dtype.is_integer or x.dtype.is_bool:
x = tf.cast(x, tf.float32)

# Get rank and shape
rank = tf.rank(x)

# Handle square 2D matrix
def matrix_case():
eigvals, _ = tf.linalg.eig(tf.cast(x, tf.complex64))
return eigvals

# Handle 1D vector
def vector_case():
return tf.reshape(x, [-1])

# Safe check: is x a square 2D matrix?
is_square_matrix = tf.logical_and(
tf.equal(rank, 2), tf.equal(tf.shape(x)[0], tf.shape(x)[1])
)

# Conditionally choose
x_vec = tf.cond(is_square_matrix, true_fn=matrix_case, false_fn=vector_case)

# If empty, return [1]
if tf.size(x_vec) == 0:
return tf.ones((1,), dtype=x_vec.dtype)

# Iteratively build polynomial coefficients via convolution
a = tf.ones((1,), dtype=x_vec.dtype)
for k in range(tf.shape(x_vec)[0]):
root = x_vec[k]
conv_kernel = tf.stack([1.0, -root], axis=0)
# 1D convolution requires 3D tensors
a = tf.nn.convolution(
tf.reshape(a, [1, -1, 1]),
tf.reshape(conv_kernel, [-1, 1, 1]),
padding="VALID",
)
a = tf.reshape(a, [-1])

return a


def power(x1, x2):
if not isinstance(x1, (int, float)):
x1 = convert_to_tensor(x1)
Expand Down
5 changes: 5 additions & 0 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,6 +946,11 @@ def isposinf(x):
return torch.isposinf(x)


def isreal(x):
x = convert_to_tensor(x)
return torch.isreal(x)


def kron(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
Expand Down
23 changes: 23 additions & 0 deletions keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3846,6 +3846,29 @@ def isposinf(x):
return backend.numpy.isposinf(x)


class Isreal(Operation):
def call(self, x):
return backend.numpy.isreal(x)

def compute_output_spec(self, x):
return KerasTensor(x.shape, dtype="bool")


@keras_export(["keras.ops.isreal", "keras.ops.numpy.isreal"])
def isreal(x):
"""Test element-wise for real numbers.

Args:
x: Input tensor.

Returns:
Output boolean tensor.
"""
if any_symbolic_tensors((x,)):
return Isreal().symbolic_call(x)
return backend.numpy.isreal(x)


class Kron(Operation):
def call(self, x1, x2):
return backend.numpy.kron(x1, x2)
Expand Down
27 changes: 25 additions & 2 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1588,6 +1588,10 @@ def test_isposinf(self):
x = KerasTensor((None, 3))
self.assertEqual(knp.isposinf(x).shape, (None, 3))

def test_isreal(self):
x = KerasTensor((None, 3))
self.assertEqual(knp.isreal(x).shape, (None, 3))

def test_log(self):
x = KerasTensor((None, 3))
self.assertEqual(knp.log(x).shape, (None, 3))
Expand Down Expand Up @@ -2189,9 +2193,9 @@ def test_isneginf(self):
x = KerasTensor((2, 3))
self.assertEqual(knp.isneginf(x).shape, (2, 3))

def test_isposinf(self):
def test_isreal(self):
x = KerasTensor((2, 3))
self.assertEqual(knp.isposinf(x).shape, (2, 3))
self.assertEqual(knp.isreal(x).shape, (2, 3))

def test_log(self):
x = KerasTensor((2, 3))
Expand Down Expand Up @@ -4389,6 +4393,11 @@ def test_isposinf(self):
self.assertAllClose(knp.isposinf(x), np.isposinf(x))
self.assertAllClose(knp.Isposinf()(x), np.isposinf(x))

def test_isreal(self):
x = np.array([1 + 1j, 1 + 0j, 4.5, 3, 2, 2j], dtype=complex)
self.assertAllClose(knp.isreal(x), np.isreal(x))
self.assertAllClose(knp.Isreal()(x), np.isreal(x))

def test_log(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(knp.log(x), np.log(x))
Expand Down Expand Up @@ -7766,6 +7775,20 @@ def test_isposinf(self, dtype):
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
def test_isreal(self, dtype):
import jax.numpy as jnp

x = knp.ones((), dtype=dtype)
x_jax = jnp.ones((), dtype=dtype)
expected_dtype = standardize_dtype(jnp.isreal(x_jax).dtype)

self.assertEqual(standardize_dtype(knp.isreal(x).dtype), expected_dtype)
self.assertEqual(
standardize_dtype(knp.Isreal().symbolic_call(x).dtype),
expected_dtype,
)

@parameterized.named_parameters(
named_product(dtypes=itertools.combinations(INT_DTYPES, 2))
)
Expand Down