Skip to content

Extend dpnp.pad to support pad_width as a dictionary #2535

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
* Removed the use of class template argument deduction for alias template to conform to the C++17 standard [#2517](https://github.com/IntelPython/dpnp/pull/2517)
* Changed th order of individual FFTs over `axes` for `dpnp.fft.irfftn` to be in forward order [#2524](https://github.com/IntelPython/dpnp/pull/2524)
* Replaced the use of `numpy.testing.suppress_warnings` with appropriate calls from the warnings module [#2529](https://github.com/IntelPython/dpnp/pull/2529)
* Extended `dpnp.pad` to support `pad_width` keyword as a dictionary [#2535](https://github.com/IntelPython/dpnp/pull/2535)

### Deprecated

Expand Down
25 changes: 24 additions & 1 deletion dpnp/dpnp_iface_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2504,14 +2504,17 @@ def pad(array, pad_width, mode="constant", **kwargs):
----------
array : {dpnp.ndarray, usm_ndarray}
The array of rank ``N`` to pad.
pad_width : {sequence, array_like, int}
pad_width : {sequence, array_like, int, dict}
Number of values padded to the edges of each axis.
``((before_1, after_1), ... (before_N, after_N))`` unique pad widths
for each axis.
``(before, after)`` or ``((before, after),)`` yields same before
and after pad for each axis.
``(pad,)`` or ``int`` is a shortcut for ``before = after = pad`` width
for all axes.
If a dictionary, each key is an axis and its corresponding value is an
integer or a pair of integers describing the padding ``(before, after)``
or ``pad`` width for that axis.
mode : {str, function}, optional
One of the following string values or a user supplied function.

Expand Down Expand Up @@ -2694,6 +2697,26 @@ def pad(array, pad_width, mode="constant", **kwargs):
[100, 100, 100, 100, 100, 100, 100],
[100, 100, 100, 100, 100, 100, 100]])

>>> a = np.arange(1, 7).reshape(2, 3)
>>> np.pad(a, {1: (1, 2)})
array([[0, 1, 2, 3, 0, 0],
[0, 4, 5, 6, 0, 0]])
>>> np.pad(a, {-1: 2})
array([[0, 0, 1, 2, 3, 0, 0],
[0, 0, 4, 5, 6, 0, 0]])
>>> np.pad(a, {0: (3, 0)})
array([[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[1, 2, 3],
[4, 5, 6]])
>>> np.pad(a, {0: (3, 0), 1: 2})
array([[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 2, 3, 0, 0],
[0, 0, 4, 5, 6, 0, 0]])

"""

dpnp.check_supported_arrays_type(array)
Expand Down
53 changes: 49 additions & 4 deletions dpnp/dpnp_utils/dpnp_utils_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,47 @@ def _get_stats(padded, axis, width_pair, length_pair, stat_func):
return left_stat, right_stat


def _pad_normalize_dict_width(pad_width, ndim):
"""
Normalize pad width passed as a dictionary.

Parameters
----------
pad_width : dict
Padding specification. The keys must be integer axis indices, and
the values must be either:
- a single int (same padding before and after),
- a tuple of two ints (before, after).
ndim : int
Number of dimensions in the input array.

Returns
-------
seq : list
A (ndim, 2) list of padding widths for each axis.

Raises
------
TypeError
If the padding format for any axis is invalid.

"""

seq = [(0, 0)] * ndim
for axis, width in pad_width.items():
if isinstance(width, int):
seq[axis] = (width, width)
elif (
isinstance(width, tuple)
and len(width) == 2
and all(isinstance(w, int) for w in width)
):
seq[axis] = width
else:
raise TypeError(f"Invalid pad width for axis {axis}: {width}")
return seq


def _pad_simple(array, pad_width, fill_value=None):
"""
Copied from numpy/lib/_arraypad_impl.py
Expand Down Expand Up @@ -616,21 +657,25 @@ def _view_roi(array, original_area_slice, axis):
def dpnp_pad(array, pad_width, mode="constant", **kwargs):
"""Pad an array."""

nd = array.ndim

if isinstance(pad_width, int):
if pad_width < 0:
raise ValueError("index can't contain negative values")
pad_width = ((pad_width, pad_width),) * array.ndim
pad_width = ((pad_width, pad_width),) * nd
else:
if dpnp.is_supported_array_type(pad_width):
pad_width = dpnp.asnumpy(pad_width)
else:
if isinstance(pad_width, dict):
pad_width = _pad_normalize_dict_width(pad_width, nd)
pad_width = numpy.asarray(pad_width)

if not pad_width.dtype.kind == "i":
raise TypeError("`pad_width` must be of integral type.")

# Broadcast to shape (array.ndim, 2)
pad_width = _as_pairs(pad_width, array.ndim, as_index=True)
# Broadcast to shape (nd, 2)
pad_width = _as_pairs(pad_width, nd, as_index=True)

if callable(mode):
function = mode
Expand Down Expand Up @@ -683,7 +728,7 @@ def dpnp_pad(array, pad_width, mode="constant", **kwargs):
if (
dpnp.isscalar(values)
and values == 0
and (array.ndim == 1 or array.size < 3e7)
and (nd == 1 or array.size < 3e7)
):
# faster path for 1d arrays or small n-dimensional arrays
return _pad_simple(array, pad_width, 0)[0]
Expand Down
18 changes: 18 additions & 0 deletions dpnp/tests/test_arraypad.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,3 +539,21 @@ def test_as_pairs_exceptions(self):
dpnp_as_pairs([[1, 2], [3, 4]], 3)
with pytest.raises(ValueError, match="could not be broadcast"):
dpnp_as_pairs(dpnp.ones((2, 3)), 3)

@testing.with_requires("numpy>=2.4")
@pytest.mark.parametrize(
"sh, pad_width",
[
((3, 4, 5), {-2: (1, 3)}),
((3, 4, 5), {0: (5, 2)}),
((3, 4, 5), {0: (5, 2), -1: (3, 4)}),
((3, 4, 5), {1: 5}),
],
)
def test_dict_pad_width(self, sh, pad_width):
a = numpy.zeros(sh)
ia = dpnp.array(a)

result = dpnp.pad(ia, pad_width)
expected = numpy.pad(a, pad_width)
assert_equal(result, expected)
Loading