diff --git a/keras/api/_tf_keras/keras/ops/__init__.py b/keras/api/_tf_keras/keras/ops/__init__.py index 01e7d9f806b3..e78c48a785e4 100644 --- a/keras/api/_tf_keras/keras/ops/__init__.py +++ b/keras/api/_tf_keras/keras/ops/__init__.py @@ -296,6 +296,7 @@ from keras.src.ops.numpy import true_divide as true_divide from keras.src.ops.numpy import trunc as trunc from keras.src.ops.numpy import unravel_index as unravel_index +from keras.src.ops.numpy import vander as vander from keras.src.ops.numpy import var as var from keras.src.ops.numpy import vdot as vdot from keras.src.ops.numpy import vectorize as vectorize diff --git a/keras/api/_tf_keras/keras/ops/numpy/__init__.py b/keras/api/_tf_keras/keras/ops/numpy/__init__.py index ad912016ee44..ab845b4d6074 100644 --- a/keras/api/_tf_keras/keras/ops/numpy/__init__.py +++ b/keras/api/_tf_keras/keras/ops/numpy/__init__.py @@ -182,6 +182,7 @@ from keras.src.ops.numpy import true_divide as true_divide from keras.src.ops.numpy import trunc as trunc from keras.src.ops.numpy import unravel_index as unravel_index +from keras.src.ops.numpy import vander as vander from keras.src.ops.numpy import var as var from keras.src.ops.numpy import vdot as vdot from keras.src.ops.numpy import vectorize as vectorize diff --git a/keras/api/ops/__init__.py b/keras/api/ops/__init__.py index 01e7d9f806b3..e78c48a785e4 100644 --- a/keras/api/ops/__init__.py +++ b/keras/api/ops/__init__.py @@ -296,6 +296,7 @@ from keras.src.ops.numpy import true_divide as true_divide from keras.src.ops.numpy import trunc as trunc from keras.src.ops.numpy import unravel_index as unravel_index +from keras.src.ops.numpy import vander as vander from keras.src.ops.numpy import var as var from keras.src.ops.numpy import vdot as vdot from keras.src.ops.numpy import vectorize as vectorize diff --git a/keras/api/ops/numpy/__init__.py b/keras/api/ops/numpy/__init__.py index ad912016ee44..ab845b4d6074 100644 --- a/keras/api/ops/numpy/__init__.py +++ b/keras/api/ops/numpy/__init__.py @@ -182,6 +182,7 @@ from keras.src.ops.numpy import true_divide as true_divide from keras.src.ops.numpy import trunc as trunc from keras.src.ops.numpy import unravel_index as unravel_index +from keras.src.ops.numpy import vander as vander from keras.src.ops.numpy import var as var from keras.src.ops.numpy import vdot as vdot from keras.src.ops.numpy import vectorize as vectorize diff --git a/keras/src/backend/jax/numpy.py b/keras/src/backend/jax/numpy.py index 24eea3b402a7..774ab64c9bb4 100644 --- a/keras/src/backend/jax/numpy.py +++ b/keras/src/backend/jax/numpy.py @@ -1398,6 +1398,11 @@ def trapezoid(y, x=None, dx=1.0, axis=-1): return jnp.trapezoid(y, x, dx=dx, axis=axis) +def vander(x, N=None, increasing=False): + x = convert_to_tensor(x) + return jnp.vander(x, N=N, increasing=increasing) + + def var(x, axis=None, keepdims=False): x = convert_to_tensor(x) # `jnp.var` does not handle low precision (e.g., float16) overflow diff --git a/keras/src/backend/numpy/numpy.py b/keras/src/backend/numpy/numpy.py index cdff50137588..31d2dd44359d 100644 --- a/keras/src/backend/numpy/numpy.py +++ b/keras/src/backend/numpy/numpy.py @@ -1372,6 +1372,14 @@ def trapezoid(y, x=None, dx=1.0, axis=-1): return np.trapezoid(y, x, dx=dx, axis=axis).astype(result_dtype) +def vander(x, N=None, increasing=False): + x = convert_to_tensor(x) + result_dtype = dtypes.result_type(x.dtype) + compute_dtype = dtypes.result_type(x.dtype, config.floatx()) + x = x.astype(compute_dtype) + return np.vander(x, N=N, increasing=increasing).astype(result_dtype) + + def var(x, axis=None, keepdims=False): axis = standardize_axis_for_numpy(axis) x = convert_to_tensor(x) diff --git a/keras/src/backend/openvino/excluded_concrete_tests.txt b/keras/src/backend/openvino/excluded_concrete_tests.txt index b26482a13ba2..cbd6f4a0471b 100644 --- a/keras/src/backend/openvino/excluded_concrete_tests.txt +++ b/keras/src/backend/openvino/excluded_concrete_tests.txt @@ -50,6 +50,7 @@ NumpyDtypeTest::test_trace NumpyDtypeTest::test_trapezoid NumpyDtypeTest::test_trunc NumpyDtypeTest::test_unravel +NumpyDtypeTest::test_vander NumpyDtypeTest::test_var NumpyDtypeTest::test_vdot NumpyDtypeTest::test_view @@ -92,6 +93,7 @@ NumpyOneInputOpsCorrectnessTest::test_trace NumpyOneInputOpsCorrectnessTest::test_trapezoid NumpyOneInputOpsCorrectnessTest::test_trunc NumpyOneInputOpsCorrectnessTest::test_unravel_index +NumpyOneInputOpsCorrectnessTest::test_vander NumpyOneInputOpsCorrectnessTest::test_vectorize NumpyOneInputOpsCorrectnessTest::test_vstack NumpyOneInputOpsCorrectnessTest::test_view diff --git a/keras/src/backend/openvino/numpy.py b/keras/src/backend/openvino/numpy.py index 5e73bb3c840a..77af88266c21 100644 --- a/keras/src/backend/openvino/numpy.py +++ b/keras/src/backend/openvino/numpy.py @@ -2461,6 +2461,10 @@ def trapezoid(y, x=None, dx=1.0, axis=-1): ) +def vander(x, N=None, increasing=False): + raise NotImplementedError("`vander` is not supported with openvino backend") + + def var(x, axis=None, keepdims=False): x = get_ov_output(x) x_type = x.get_element_type() diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py index c35be3d35bc8..05faf7cc6058 100644 --- a/keras/src/backend/tensorflow/numpy.py +++ b/keras/src/backend/tensorflow/numpy.py @@ -3097,6 +3097,27 @@ def _move_axis_to_last(tensor, axis): return result +def vander(x, N=None, increasing=False): + x = convert_to_tensor(x) + result_dtype = dtypes.result_type(x.dtype) + + if N is None: + N = tf.shape(x)[0] + + if increasing: + powers = tf.range(N) + else: + powers = tf.range(N - 1, -1, -1) + + x_exp = tf.expand_dims(x, axis=-1) + + compute_dtype = dtypes.result_type(x.dtype, "float32") + vander = tf.math.pow( + tf.cast(x_exp, compute_dtype), tf.cast(powers, compute_dtype) + ) + return tf.cast(vander, result_dtype) + + def var(x, axis=None, keepdims=False): x = convert_to_tensor(x) compute_dtype = dtypes.result_type(x.dtype, "float32") diff --git a/keras/src/backend/torch/numpy.py b/keras/src/backend/torch/numpy.py index feb1ac181dbe..540eb8e15a5b 100644 --- a/keras/src/backend/torch/numpy.py +++ b/keras/src/backend/torch/numpy.py @@ -1833,6 +1833,12 @@ def trapezoid(y, x=None, dx=1.0, axis=-1): return torch.trapz(y, dx=dx, dim=axis) +def vander(x, N=None, increasing=False): + x = convert_to_tensor(x) + result_dtype = dtypes.result_type(x.dtype) + return cast(torch.vander(x, N=N, increasing=increasing), result_dtype) + + def var(x, axis=None, keepdims=False): x = convert_to_tensor(x) compute_dtype = dtypes.result_type(x.dtype, "float32") diff --git a/keras/src/ops/numpy.py b/keras/src/ops/numpy.py index 2231fee3dd4b..dd2518c3dc43 100644 --- a/keras/src/ops/numpy.py +++ b/keras/src/ops/numpy.py @@ -301,33 +301,6 @@ def all(x, axis=None, keepdims=False): return backend.numpy.all(x, axis=axis, keepdims=keepdims) -class Any(Operation): - def __init__(self, axis=None, keepdims=False, *, name=None): - super().__init__(name=name) - if isinstance(axis, int): - self.axis = [axis] - else: - self.axis = axis - self.keepdims = keepdims - - def call(self, x): - return backend.numpy.any( - x, - axis=self.axis, - keepdims=self.keepdims, - ) - - def compute_output_spec(self, x): - return KerasTensor( - reduce_shape( - x.shape, - axis=self.axis, - keepdims=self.keepdims, - ), - dtype="bool", - ) - - class Angle(Operation): def call(self, x): return backend.numpy.angle(x) @@ -363,6 +336,33 @@ def angle(x): return backend.numpy.angle(x) +class Any(Operation): + def __init__(self, axis=None, keepdims=False, *, name=None): + super().__init__(name=name) + if isinstance(axis, int): + self.axis = [axis] + else: + self.axis = axis + self.keepdims = keepdims + + def call(self, x): + return backend.numpy.any( + x, + axis=self.axis, + keepdims=self.keepdims, + ) + + def compute_output_spec(self, x): + return KerasTensor( + reduce_shape( + x.shape, + axis=self.axis, + keepdims=self.keepdims, + ), + dtype="bool", + ) + + @keras_export(["keras.ops.any", "keras.ops.numpy.any"]) def any(x, axis=None, keepdims=False): """Test whether any array element along a given axis evaluates to `True`. @@ -7319,6 +7319,77 @@ def mean(x, axis=None, keepdims=False): return backend.numpy.mean(x, axis=axis, keepdims=keepdims) +class Vander(Operation): + def __init__(self, N=None, increasing=False, *, name=None): + super().__init__(name=name) + self.N = N + self.increasing = increasing + + def call(self, x): + return backend.numpy.vander(x, self.N, self.increasing) + + def compute_output_spec(self, x): + if self.N is None: + N = x.shape[0] + else: + N = self.N + + out_shape = list(x.shape) + out_shape.append(N) + return KerasTensor(tuple(out_shape), dtype=x.dtype) + + +@keras_export(["keras.ops.vander", "keras.ops.numpy.vander"]) +def vander(x, N=None, increasing=False): + """Generate a Vandermonde matrix. + + Args: + x: 1D input tensor. + N: Number of columns. If None, `N` = `len(x)`. + increasing: Order of powers. If True, powers increase left to right. + + Returns: + Output tensor, vandermonde matrix of shape `(len(x), N)`. + + Example: + >>> import numpy as np + >>> import keras + >>> x = np.array([1, 2, 3, 5]) + >>> keras.ops.vander(x) + array([[ 1, 1, 1, 1], + [ 8, 4, 2, 1], + [ 27, 9, 3, 1], + [125, 25, 5, 1]]) + """ + + if len(x.shape) != 1: + raise ValueError( + "Input tensor must be 1-dimensional. " + f"Received: input.shape={x.shape}" + ) + + if N is not None: + if N < 0: + raise ValueError( + f"Argument 'N' must be nonnegative. Received: N={N}" + ) + + if not isinstance(N, int): + raise TypeError( + f"Argument `N` must be of type `int`. Received: dtype={type(N)}" + ) + + if not isinstance(increasing, bool): + raise TypeError( + f"Argument `increasing` must be of type `bool`. " + f"Received: dtype={type(increasing)}" + ) + + if any_symbolic_tensors((x,)): + return Vander(N=N, increasing=increasing).symbolic_call(x) + return backend.numpy.vander(x, N=N, increasing=increasing) + + class Var(Operation): def __init__(self, axis=None, keepdims=False, *, name=None): super().__init__(name=name) diff --git a/keras/src/ops/numpy_test.py b/keras/src/ops/numpy_test.py index 0bc35c36c9a4..20b27a4c37a4 100644 --- a/keras/src/ops/numpy_test.py +++ b/keras/src/ops/numpy_test.py @@ -1154,6 +1154,10 @@ def test_trapezoid(self): x = KerasTensor((None, 3, 3)) self.assertEqual(knp.trapezoid(x, axis=1).shape, (None, 3)) + def test_vander(self): + x = KerasTensor((None,)) + self.assertEqual(knp.vander(x).shape, (None, None)) + def test_var(self): x = KerasTensor((None, 3)) self.assertEqual(knp.var(x).shape, ()) @@ -1913,6 +1917,10 @@ def test_trapezoid(self): x = KerasTensor((2, 3)) self.assertEqual(knp.trapezoid(x).shape, (2,)) + def test_vander(self): + x = KerasTensor((2,)) + self.assertEqual(knp.vander(x).shape, (2, 2)) + def test_var(self): x = KerasTensor((2, 3)) self.assertEqual(knp.var(x).shape, ()) @@ -3737,6 +3745,25 @@ def test_trapezoid(self): np.trapezoid(y, x=x, axis=1), ) + def test_vander(self): + x = np.random.random((3,)) + N = 6 + increasing = True + + self.assertAllClose(knp.vander(x), np.vander(x)) + self.assertAllClose(knp.vander(x, N=N), np.vander(x, N=N)) + self.assertAllClose( + knp.vander(x, N=N, increasing=increasing), + np.vander(x, N=N, increasing=increasing), + ) + + self.assertAllClose(knp.Vander().call(x), np.vander(x)) + self.assertAllClose(knp.Vander(N=N).call(x), np.vander(x, N=N)) + self.assertAllClose( + knp.Vander(N=N, increasing=increasing).call(x), + np.vander(x, N=N, increasing=increasing), + ) + def test_var(self): x = np.array([[1, 2, 3], [3, 2, 1]]) self.assertAllClose(knp.var(x), np.var(x)) @@ -9203,6 +9230,24 @@ def test_trapezoid(self, dtype): expected_dtype, ) + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_vander(self, dtype): + import jax.numpy as jnp + + x = knp.ones((2,), dtype=dtype) + x_jax = jnp.ones((2,), dtype=dtype) + + if dtype == "bool": + self.skipTest("vander does not support bool") + + expected_dtype = standardize_dtype(jnp.vander(x_jax).dtype) + + self.assertEqual(standardize_dtype(knp.vander(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Vander().symbolic_call(x).dtype), + expected_dtype, + ) + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) def test_var(self, dtype): import jax.numpy as jnp