Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@
from keras.src.ops.numpy import true_divide as true_divide
from keras.src.ops.numpy import trunc as trunc
from keras.src.ops.numpy import unravel_index as unravel_index
from keras.src.ops.numpy import vander as vander
from keras.src.ops.numpy import var as var
from keras.src.ops.numpy import vdot as vdot
from keras.src.ops.numpy import vectorize as vectorize
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@
from keras.src.ops.numpy import true_divide as true_divide
from keras.src.ops.numpy import trunc as trunc
from keras.src.ops.numpy import unravel_index as unravel_index
from keras.src.ops.numpy import vander as vander
from keras.src.ops.numpy import var as var
from keras.src.ops.numpy import vdot as vdot
from keras.src.ops.numpy import vectorize as vectorize
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@
from keras.src.ops.numpy import true_divide as true_divide
from keras.src.ops.numpy import trunc as trunc
from keras.src.ops.numpy import unravel_index as unravel_index
from keras.src.ops.numpy import vander as vander
from keras.src.ops.numpy import var as var
from keras.src.ops.numpy import vdot as vdot
from keras.src.ops.numpy import vectorize as vectorize
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@
from keras.src.ops.numpy import true_divide as true_divide
from keras.src.ops.numpy import trunc as trunc
from keras.src.ops.numpy import unravel_index as unravel_index
from keras.src.ops.numpy import vander as vander
from keras.src.ops.numpy import var as var
from keras.src.ops.numpy import vdot as vdot
from keras.src.ops.numpy import vectorize as vectorize
Expand Down
5 changes: 5 additions & 0 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1398,6 +1398,11 @@ def trapezoid(y, x=None, dx=1.0, axis=-1):
return jnp.trapezoid(y, x, dx=dx, axis=axis)


def vander(x, N=None, increasing=False):
x = convert_to_tensor(x)
return jnp.vander(x, N=N, increasing=increasing)


def var(x, axis=None, keepdims=False):
x = convert_to_tensor(x)
# `jnp.var` does not handle low precision (e.g., float16) overflow
Expand Down
8 changes: 8 additions & 0 deletions keras/src/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1372,6 +1372,14 @@ def trapezoid(y, x=None, dx=1.0, axis=-1):
return np.trapezoid(y, x, dx=dx, axis=axis).astype(result_dtype)


def vander(x, N=None, increasing=False):
x = convert_to_tensor(x)
result_dtype = dtypes.result_type(x.dtype)
compute_dtype = dtypes.result_type(x.dtype, config.floatx())
x = x.astype(compute_dtype)
return np.vander(x, N=N, increasing=increasing).astype(result_dtype)


def var(x, axis=None, keepdims=False):
axis = standardize_axis_for_numpy(axis)
x = convert_to_tensor(x)
Expand Down
2 changes: 2 additions & 0 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ NumpyDtypeTest::test_trace
NumpyDtypeTest::test_trapezoid
NumpyDtypeTest::test_trunc
NumpyDtypeTest::test_unravel
NumpyDtypeTest::test_vander
NumpyDtypeTest::test_var
NumpyDtypeTest::test_vdot
NumpyDtypeTest::test_view
Expand Down Expand Up @@ -92,6 +93,7 @@ NumpyOneInputOpsCorrectnessTest::test_trace
NumpyOneInputOpsCorrectnessTest::test_trapezoid
NumpyOneInputOpsCorrectnessTest::test_trunc
NumpyOneInputOpsCorrectnessTest::test_unravel_index
NumpyOneInputOpsCorrectnessTest::test_vander
NumpyOneInputOpsCorrectnessTest::test_vectorize
NumpyOneInputOpsCorrectnessTest::test_vstack
NumpyOneInputOpsCorrectnessTest::test_view
Expand Down
4 changes: 4 additions & 0 deletions keras/src/backend/openvino/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2461,6 +2461,10 @@ def trapezoid(y, x=None, dx=1.0, axis=-1):
)


def vander(x, N=None, increasing=False):
raise NotImplementedError("`vander` is not supported with openvino backend")


def var(x, axis=None, keepdims=False):
x = get_ov_output(x)
x_type = x.get_element_type()
Expand Down
21 changes: 21 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3097,6 +3097,27 @@ def _move_axis_to_last(tensor, axis):
return result


def vander(x, N=None, increasing=False):
x = convert_to_tensor(x)
result_dtype = dtypes.result_type(x.dtype)

if N is None:
N = tf.shape(x)[0]

if increasing:
powers = tf.range(N)
else:
powers = tf.range(N - 1, -1, -1)

x_exp = tf.expand_dims(x, axis=-1)

compute_dtype = dtypes.result_type(x.dtype, "float32")
vander = tf.math.pow(
tf.cast(x_exp, compute_dtype), tf.cast(powers, compute_dtype)
)
return tf.cast(vander, result_dtype)


def var(x, axis=None, keepdims=False):
x = convert_to_tensor(x)
compute_dtype = dtypes.result_type(x.dtype, "float32")
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1833,6 +1833,12 @@ def trapezoid(y, x=None, dx=1.0, axis=-1):
return torch.trapz(y, dx=dx, dim=axis)


def vander(x, N=None, increasing=False):
x = convert_to_tensor(x)
result_dtype = dtypes.result_type(x.dtype)
return cast(torch.vander(x, N=N, increasing=increasing), result_dtype)


def var(x, axis=None, keepdims=False):
x = convert_to_tensor(x)
compute_dtype = dtypes.result_type(x.dtype, "float32")
Expand Down
125 changes: 98 additions & 27 deletions keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,33 +301,6 @@ def all(x, axis=None, keepdims=False):
return backend.numpy.all(x, axis=axis, keepdims=keepdims)


class Any(Operation):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the Any class was misplaced, I moved it.

def __init__(self, axis=None, keepdims=False, *, name=None):
super().__init__(name=name)
if isinstance(axis, int):
self.axis = [axis]
else:
self.axis = axis
self.keepdims = keepdims

def call(self, x):
return backend.numpy.any(
x,
axis=self.axis,
keepdims=self.keepdims,
)

def compute_output_spec(self, x):
return KerasTensor(
reduce_shape(
x.shape,
axis=self.axis,
keepdims=self.keepdims,
),
dtype="bool",
)


class Angle(Operation):
def call(self, x):
return backend.numpy.angle(x)
Expand Down Expand Up @@ -363,6 +336,33 @@ def angle(x):
return backend.numpy.angle(x)


class Any(Operation):
def __init__(self, axis=None, keepdims=False, *, name=None):
super().__init__(name=name)
if isinstance(axis, int):
self.axis = [axis]
else:
self.axis = axis
self.keepdims = keepdims

def call(self, x):
return backend.numpy.any(
x,
axis=self.axis,
keepdims=self.keepdims,
)

def compute_output_spec(self, x):
return KerasTensor(
reduce_shape(
x.shape,
axis=self.axis,
keepdims=self.keepdims,
),
dtype="bool",
)


@keras_export(["keras.ops.any", "keras.ops.numpy.any"])
def any(x, axis=None, keepdims=False):
"""Test whether any array element along a given axis evaluates to `True`.
Expand Down Expand Up @@ -7319,6 +7319,77 @@ def mean(x, axis=None, keepdims=False):
return backend.numpy.mean(x, axis=axis, keepdims=keepdims)


class Vander(Operation):
def __init__(self, N=None, increasing=False, *, name=None):
super().__init__(name=name)
self.N = N
self.increasing = increasing

def call(self, x):
return backend.numpy.vander(x, self.N, self.increasing)

def compute_output_spec(self, x):
if self.N is None:
N = x.shape[0]
else:
N = self.N

out_shape = list(x.shape)
out_shape.append(N)
return KerasTensor(tuple(out_shape), dtype=x.dtype)


@keras_export(["keras.ops.vander", "keras.ops.numpy.vander"])
def vander(x, N=None, increasing=False):
"""Generate a Vandermonde matrix.

Args:
x: 1D input tensor.
N: Number of columns. If None, `N` = `len(x)`.
increasing: Order of powers. If True, powers increase left to right.

Returns:
Output tensor, vandermonde matrix of shape `(len(x), N)`.

Example:
>>> import numpy as np
>>> import keras
>>> x = np.array([1, 2, 3, 5])
>>> keras.ops.vander(x)
array([[ 1, 1, 1, 1],
[ 8, 4, 2, 1],
[ 27, 9, 3, 1],
[125, 25, 5, 1]])
"""

if len(x.shape) != 1:
raise ValueError(
"Input tensor must be 1-dimensional. "
f"Received: input.shape={x.shape}"
)

if N is not None:
if N < 0:
raise ValueError(
f"Argument 'N' must be nonnegative. Received: N={N}"
)

if not isinstance(N, int):
raise TypeError(
f"Argument `N` must be of type `int`. Received: dtype={type(N)}"
)

if not isinstance(increasing, bool):
raise TypeError(
f"Argument `increasing` must be of type `bool`. "
f"Received: dtype={type(increasing)}"
)

if any_symbolic_tensors((x,)):
return Vander(N=N, increasing=increasing).symbolic_call(x)
return backend.numpy.vander(x, N=N, increasing=increasing)


class Var(Operation):
def __init__(self, axis=None, keepdims=False, *, name=None):
super().__init__(name=name)
Expand Down
45 changes: 45 additions & 0 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1154,6 +1154,10 @@ def test_trapezoid(self):
x = KerasTensor((None, 3, 3))
self.assertEqual(knp.trapezoid(x, axis=1).shape, (None, 3))

def test_vander(self):
x = KerasTensor((None,))
self.assertEqual(knp.vander(x).shape, (None, None))

def test_var(self):
x = KerasTensor((None, 3))
self.assertEqual(knp.var(x).shape, ())
Expand Down Expand Up @@ -1913,6 +1917,10 @@ def test_trapezoid(self):
x = KerasTensor((2, 3))
self.assertEqual(knp.trapezoid(x).shape, (2,))

def test_vander(self):
x = KerasTensor((2,))
self.assertEqual(knp.vander(x).shape, (2, 2))

def test_var(self):
x = KerasTensor((2, 3))
self.assertEqual(knp.var(x).shape, ())
Expand Down Expand Up @@ -3737,6 +3745,25 @@ def test_trapezoid(self):
np.trapezoid(y, x=x, axis=1),
)

def test_vander(self):
x = np.random.random((3,))
N = 6
increasing = True

self.assertAllClose(knp.vander(x), np.vander(x))
self.assertAllClose(knp.vander(x, N=N), np.vander(x, N=N))
self.assertAllClose(
knp.vander(x, N=N, increasing=increasing),
np.vander(x, N=N, increasing=increasing),
)

self.assertAllClose(knp.Vander().call(x), np.vander(x))
self.assertAllClose(knp.Vander(N=N).call(x), np.vander(x, N=N))
self.assertAllClose(
knp.Vander(N=N, increasing=increasing).call(x),
np.vander(x, N=N, increasing=increasing),
)

def test_var(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(knp.var(x), np.var(x))
Expand Down Expand Up @@ -9203,6 +9230,24 @@ def test_trapezoid(self, dtype):
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
def test_vander(self, dtype):
import jax.numpy as jnp

x = knp.ones((2,), dtype=dtype)
x_jax = jnp.ones((2,), dtype=dtype)

if dtype == "bool":
self.skipTest("vander does not support bool")

expected_dtype = standardize_dtype(jnp.vander(x_jax).dtype)

self.assertEqual(standardize_dtype(knp.vander(x).dtype), expected_dtype)
self.assertEqual(
standardize_dtype(knp.Vander().symbolic_call(x).dtype),
expected_dtype,
)

@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
def test_var(self, dtype):
import jax.numpy as jnp
Expand Down