Skip to content

Commit 06a74cd

Browse files
authored
Extend nan_to_num with broadcast support of nan, posinf, and neginf (#2754)
The PR extends implementation of `dpnp.nan_to_num` function to align with NumPy and CuPy which supports `nan`, `posinf`, and `neginf` keywords as any array through broadcasting. This PR adds handling for a common path where at least one of the keywords has non-scalar value. The path does not assume a dedicated SYCL kernel, instead proposes to rely on implementation through existing python functions. That can be improved in the future if required.
1 parent 2f796e0 commit 06a74cd

File tree

4 files changed

+144
-77
lines changed

4 files changed

+144
-77
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ Also, that release drops support for Python 3.9, making Python 3.10 the minimum
4848
* Improved documentation of `file` argument in `dpnp.fromfile` [#2745](https://github.com/IntelPython/dpnp/pull/2745)
4949
* Aligned `dpnp.trim_zeros` with NumPy 2.4 to support a tuple of integers passed with `axis` keyword [#2746](https://github.com/IntelPython/dpnp/pull/2746)
5050
* Aligned `strides` property of `dpnp.ndarray` with NumPy and CuPy implementations [#2747](https://github.com/IntelPython/dpnp/pull/2747)
51+
* Extended `dpnp.nan_to_num` to support broadcasting of `nan`, `posinf`, and `neginf` keywords [#2754](https://github.com/IntelPython/dpnp/pull/2754)
5152

5253
### Deprecated
5354

dpnp/dpnp_iface_mathematical.py

Lines changed: 75 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3646,20 +3646,24 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
36463646
an array does not require a copy.
36473647
36483648
Default: ``True``.
3649-
nan : {int, float, bool}, optional
3650-
Value to be used to fill ``NaN`` values.
3649+
nan : {scalar, array_like}, optional
3650+
Values to be used to fill ``NaN`` values. If no values are passed then
3651+
``NaN`` values will be replaced with ``0.0``.
3652+
Expected to have a real-valued data type for the values.
36513653
36523654
Default: ``0.0``.
3653-
posinf : {int, float, bool, None}, optional
3654-
Value to be used to fill positive infinity values. If no value is
3655+
posinf : {None, scalar, array_like}, optional
3656+
Values to be used to fill positive infinity values. If no values are
36553657
passed then positive infinity values will be replaced with a very
36563658
large number.
3659+
Expected to have a real-valued data type for the values.
36573660
36583661
Default: ``None``.
3659-
neginf : {int, float, bool, None} optional
3660-
Value to be used to fill negative infinity values. If no value is
3662+
neginf : {None, scalar, array_like}, optional
3663+
Values to be used to fill negative infinity values. If no values are
36613664
passed then negative infinity values will be replaced with a very
36623665
small (or negative) number.
3666+
Expected to have a real-valued data type for the values.
36633667
36643668
Default: ``None``.
36653669
@@ -3687,13 +3691,22 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
36873691
array(-1.79769313e+308)
36883692
>>> np.nan_to_num(np.array(np.nan))
36893693
array(0.)
3694+
36903695
>>> x = np.array([np.inf, -np.inf, np.nan, -128, 128])
36913696
>>> np.nan_to_num(x)
36923697
array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000,
36933698
-1.28000000e+002, 1.28000000e+002])
36943699
>>> np.nan_to_num(x, nan=-9999, posinf=33333333, neginf=33333333)
36953700
array([ 3.3333333e+07, 3.3333333e+07, -9.9990000e+03, -1.2800000e+02,
36963701
1.2800000e+02])
3702+
3703+
>>> nan = np.array([11, 12, -9999, 13, 14])
3704+
>>> posinf = np.array([33333333, 11, 12, 13, 14])
3705+
>>> neginf = np.array([11, 33333333, 12, 13, 14])
3706+
>>> np.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)
3707+
array([ 3.3333333e+07, 3.3333333e+07, -9.9990000e+03, -1.2800000e+02,
3708+
1.2800000e+02])
3709+
36973710
>>> y = np.array([complex(np.inf, np.nan), np.nan, complex(np.nan, np.inf)])
36983711
>>> np.nan_to_num(y)
36993712
array([1.79769313e+308 +0.00000000e+000j, # may vary
@@ -3706,33 +3719,32 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
37063719

37073720
dpnp.check_supported_arrays_type(x)
37083721

3709-
# Python boolean is a subtype of an integer
3710-
# so additional check for bool is not needed.
3711-
if not isinstance(nan, (int, float)):
3712-
raise TypeError(
3713-
"nan must be a scalar of an integer, float, bool, "
3714-
f"but got {type(nan)}"
3715-
)
3716-
x_type = x.dtype.type
3722+
def _check_nan_inf(val, val_dt):
3723+
# Python boolean is a subtype of an integer
3724+
if not isinstance(val, (int, float)):
3725+
val = dpnp.asarray(
3726+
val, dtype=val_dt, sycl_queue=x.sycl_queue, usm_type=x.usm_type
3727+
)
3728+
return val
37173729

3718-
if not issubclass(x_type, dpnp.inexact):
3730+
x_type = x.dtype.type
3731+
if not dpnp.issubdtype(x_type, dpnp.inexact):
37193732
return dpnp.copy(x) if copy else dpnp.get_result_array(x)
37203733

37213734
max_f, min_f = _get_max_min(x.real.dtype)
3735+
3736+
# get dtype of nan and infs values if casting required
3737+
is_complex = dpnp.issubdtype(x_type, dpnp.complexfloating)
3738+
if is_complex:
3739+
val_dt = x.real.dtype
3740+
else:
3741+
val_dt = x.dtype
3742+
3743+
nan = _check_nan_inf(nan, val_dt)
37223744
if posinf is not None:
3723-
if not isinstance(posinf, (int, float)):
3724-
raise TypeError(
3725-
"posinf must be a scalar of an integer, float, bool, "
3726-
f"or be None, but got {type(posinf)}"
3727-
)
3728-
max_f = posinf
3745+
max_f = _check_nan_inf(posinf, val_dt)
37293746
if neginf is not None:
3730-
if not isinstance(neginf, (int, float)):
3731-
raise TypeError(
3732-
"neginf must be a scalar of an integer, float, bool, "
3733-
f"or be None, but got {type(neginf)}"
3734-
)
3735-
min_f = neginf
3747+
min_f = _check_nan_inf(neginf, val_dt)
37363748

37373749
if copy:
37383750
out = dpnp.empty_like(x)
@@ -3741,19 +3753,45 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None):
37413753
raise ValueError("copy is required for read-only array `x`")
37423754
out = x
37433755

3744-
x_ary = dpnp.get_usm_ndarray(x)
3745-
out_ary = dpnp.get_usm_ndarray(out)
3756+
# handle a special case when nan and infs are all scalars
3757+
if all(dpnp.isscalar(el) for el in (nan, max_f, min_f)):
3758+
x_ary = dpnp.get_usm_ndarray(x)
3759+
out_ary = dpnp.get_usm_ndarray(out)
3760+
3761+
q = x.sycl_queue
3762+
_manager = dpu.SequentialOrderManager[q]
3763+
3764+
h_ev, comp_ev = ufi._nan_to_num(
3765+
x_ary,
3766+
nan,
3767+
max_f,
3768+
min_f,
3769+
out_ary,
3770+
q,
3771+
depends=_manager.submitted_events,
3772+
)
37463773

3747-
q = x.sycl_queue
3748-
_manager = dpu.SequentialOrderManager[q]
3774+
_manager.add_event_pair(h_ev, comp_ev)
37493775

3750-
h_ev, comp_ev = ufi._nan_to_num(
3751-
x_ary, nan, max_f, min_f, out_ary, q, depends=_manager.submitted_events
3752-
)
3776+
return dpnp.get_result_array(out)
37533777

3754-
_manager.add_event_pair(h_ev, comp_ev)
3755-
3756-
return dpnp.get_result_array(out)
3778+
# handle a common case with broadcasting of input nan and infs
3779+
if is_complex:
3780+
parts = (x.real, x.imag)
3781+
parts_out = (out.real, out.imag)
3782+
else:
3783+
parts = (x,)
3784+
parts_out = (out,)
3785+
3786+
for part, part_out in zip(parts, parts_out):
3787+
nan_mask = dpnp.isnan(part)
3788+
posinf_mask = dpnp.isposinf(part)
3789+
neginf_mask = dpnp.isneginf(part)
3790+
3791+
part = dpnp.where(nan_mask, nan, part, out=part_out)
3792+
part = dpnp.where(posinf_mask, max_f, part, out=part_out)
3793+
part = dpnp.where(neginf_mask, min_f, part, out=part_out)
3794+
return out
37573795

37583796

37593797
_NEGATIVE_DOCSTRING = """

dpnp/tests/test_mathematical.py

Lines changed: 26 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,37 +1480,35 @@ def test_boolean_array(self):
14801480
expected = numpy.nan_to_num(a)
14811481
assert_allclose(result, expected)
14821482

1483-
def test_errors(self):
1484-
ia = dpnp.array([0, 1, dpnp.nan, dpnp.inf, -dpnp.inf])
1485-
1486-
# unsupported type `a`
1487-
a = dpnp.asnumpy(ia)
1488-
assert_raises(TypeError, dpnp.nan_to_num, a)
1489-
1490-
# unsupported type `nan`
1491-
i_nan = dpnp.array(1)
1492-
assert_raises(TypeError, dpnp.nan_to_num, ia, nan=i_nan)
1483+
@pytest.mark.parametrize("dt", get_float_complex_dtypes())
1484+
@pytest.mark.parametrize("kw_name", ["nan", "posinf", "neginf"])
1485+
@pytest.mark.parametrize("val", [[1, 2, -1, -2, 7], (7.0,), numpy.array(1)])
1486+
def test_nan_infs_array_like(self, dt, kw_name, val):
1487+
a = numpy.array([0, 1, dpnp.nan, dpnp.inf, -dpnp.inf], dtype=dt)
1488+
ia = dpnp.array(a)
14931489

1494-
# unsupported type `posinf`
1495-
i_posinf = dpnp.array(1)
1496-
assert_raises(TypeError, dpnp.nan_to_num, ia, posinf=i_posinf)
1490+
result = dpnp.nan_to_num(ia, **{kw_name: val})
1491+
expected = numpy.nan_to_num(a, **{kw_name: val})
1492+
assert_allclose(result, expected)
14971493

1498-
# unsupported type `neginf`
1499-
i_neginf = dpnp.array(1)
1500-
assert_raises(TypeError, dpnp.nan_to_num, ia, neginf=i_neginf)
1494+
@pytest.mark.parametrize("xp", [dpnp, numpy])
1495+
@pytest.mark.parametrize("kw_name", ["nan", "posinf", "neginf"])
1496+
def test_nan_infs_complex_dtype(self, xp, kw_name):
1497+
ia = xp.array([0, 1, xp.nan, xp.inf, -xp.inf])
1498+
with pytest.raises(TypeError, match="complex"):
1499+
xp.nan_to_num(ia, **{kw_name: 1j})
15011500

1502-
@pytest.mark.parametrize("kwarg", ["nan", "posinf", "neginf"])
1503-
@pytest.mark.parametrize("value", [1 - 0j, [1, 2], (1,)])
1504-
def test_errors_diff_types(self, kwarg, value):
1505-
ia = dpnp.array([0, 1, dpnp.nan, dpnp.inf, -dpnp.inf])
1506-
with pytest.raises(TypeError):
1507-
dpnp.nan_to_num(ia, **{kwarg: value})
1501+
def test_numpy_input_array(self):
1502+
a = numpy.array([0, 1, dpnp.nan, dpnp.inf, -dpnp.inf])
1503+
with pytest.raises(TypeError, match="must be any of supported type"):
1504+
dpnp.nan_to_num(a)
15081505

1509-
def test_error_readonly(self):
1510-
a = dpnp.array([0, 1, dpnp.nan, dpnp.inf, -dpnp.inf])
1511-
a.flags.writable = False
1512-
with pytest.raises(ValueError):
1513-
dpnp.nan_to_num(a, copy=False)
1506+
@pytest.mark.parametrize("xp", [dpnp, numpy])
1507+
def test_error_readonly(self, xp):
1508+
a = xp.array([0, 1, xp.nan, xp.inf, -xp.inf])
1509+
a.flags["W"] = False
1510+
with pytest.raises(ValueError, match="read-only"):
1511+
xp.nan_to_num(a, copy=False)
15141512

15151513
@pytest.mark.parametrize("copy", [True, False])
15161514
@pytest.mark.parametrize("dt", get_all_dtypes(no_bool=True, no_none=True))
@@ -1522,9 +1520,9 @@ def test_strided(self, copy, dt):
15221520
if dt.kind in "fc":
15231521
a[::4] = numpy.nan
15241522
ia[::4] = dpnp.nan
1523+
15251524
result = dpnp.nan_to_num(ia[::-2], copy=copy, nan=57.0)
15261525
expected = numpy.nan_to_num(a[::-2], copy=copy, nan=57.0)
1527-
15281526
assert_dtype_allclose(result, expected)
15291527

15301528

dpnp/tests/third_party/cupy/math_tests/test_misc.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1+
from __future__ import annotations
2+
13
import numpy
24
import pytest
35

46
import dpnp as cupy
5-
from dpnp.tests.helper import has_support_aspect64
7+
from dpnp.tests.helper import has_support_aspect64, numpy_version
68
from dpnp.tests.third_party.cupy import testing
79

810

@@ -155,10 +157,7 @@ def test_external_clip4(self, dtype):
155157
# (min or max) as a keyword argument according to Python Array API.
156158
# In older versions of numpy, both arguments must be positional;
157159
# passing only one raises a TypeError.
158-
if (
159-
xp is numpy
160-
and numpy.lib.NumpyVersion(numpy.__version__) < "2.1.0"
161-
):
160+
if xp is numpy and numpy_version() < "2.1.0":
162161
with pytest.raises(TypeError):
163162
xp.clip(a, 3)
164163
else:
@@ -257,9 +256,10 @@ def test_nan_to_num_inf(self):
257256
def test_nan_to_num_nan(self):
258257
self.check_unary_nan("nan_to_num")
259258

260-
@testing.numpy_cupy_allclose(atol=1e-5, type_check=has_support_aspect64())
259+
@pytest.mark.skip("no scalar support")
260+
@testing.numpy_cupy_allclose(atol=1e-5)
261261
def test_nan_to_num_scalar_nan(self, xp):
262-
return xp.nan_to_num(xp.array(xp.nan))
262+
return xp.nan_to_num(xp.nan)
263263

264264
@pytest.mark.filterwarnings("ignore::RuntimeWarning")
265265
def test_nan_to_num_inf_nan(self):
@@ -286,14 +286,44 @@ def test_nan_to_num_inplace(self, xp):
286286
return y
287287

288288
@pytest.mark.parametrize("kwarg", ["nan", "posinf", "neginf"])
289-
def test_nan_to_num_broadcast(self, kwarg):
289+
@testing.numpy_cupy_array_equal()
290+
def test_nan_to_num_broadcast_same_shapes(self, xp, kwarg):
291+
x = xp.asarray(
292+
[[0, 1, xp.nan, 4], [11, xp.inf, 12, 13]],
293+
dtype=cupy.default_float_type(),
294+
)
295+
y = xp.zeros((2, 4), dtype=x.dtype)
296+
return xp.nan_to_num(x, **{kwarg: y})
297+
298+
@pytest.mark.parametrize("kwarg", ["nan", "posinf", "neginf"])
299+
@testing.numpy_cupy_array_equal()
300+
def test_nan_to_num_broadcast_different_columns(self, xp, kwarg):
301+
x = xp.asarray(
302+
[[0, 1, xp.nan, 4], [11, xp.inf, 12, 13]],
303+
dtype=cupy.default_float_type(),
304+
)
305+
y = xp.zeros((2, 1), dtype=x.dtype)
306+
return xp.nan_to_num(x, **{kwarg: y})
307+
308+
@pytest.mark.parametrize("kwarg", ["nan", "posinf", "neginf"])
309+
@testing.numpy_cupy_array_equal()
310+
def test_nan_to_num_broadcast_different_rows(self, xp, kwarg):
311+
x = xp.asarray(
312+
[[0, 1, xp.nan, 4], [11, -xp.inf, 12, 13]],
313+
dtype=cupy.default_float_type(),
314+
)
315+
y = xp.zeros((1, 4), dtype=x.dtype)
316+
return xp.nan_to_num(x, **{kwarg: y})
317+
318+
@pytest.mark.parametrize("kwarg", ["nan", "posinf", "neginf"])
319+
def test_nan_to_num_broadcast_invalid_shapes(self, kwarg):
290320
for xp in (numpy, cupy):
291321
x = xp.asarray([0, 1, xp.nan, 4], dtype=cupy.default_float_type())
292-
y = xp.zeros((2, 4), dtype=cupy.default_float_type())
293-
with pytest.raises((ValueError, TypeError)):
322+
y = xp.zeros((2, 4), dtype=x.dtype)
323+
with pytest.raises(ValueError):
294324
xp.nan_to_num(x, **{kwarg: y})
295-
with pytest.raises((ValueError, TypeError)):
296-
xp.nan_to_num(0.0, **{kwarg: y})
325+
with pytest.raises(ValueError):
326+
xp.nan_to_num(xp.array(0.0), **{kwarg: y})
297327

298328
@testing.for_all_dtypes(no_bool=True, no_complex=True)
299329
@testing.numpy_cupy_array_equal()

0 commit comments

Comments
 (0)