Skip to content

Commit 2f796e0

Browse files
authored
Align strides with numpy (#2747)
This PR changes the implementation of the `strides` property of `dpnp.ndarray` to align with NumPy and CuPy: it now returns the memory displacement in bytes (previously, as in dpctl, it returned the displacement in elements).
1 parent f4591e1 commit 2f796e0

File tree

18 files changed

+202
-112
lines changed

18 files changed

+202
-112
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ Also, that release drops support for Python 3.9, making Python 3.10 the minimum
4747
* Clarified behavior on repeated `axes` in `dpnp.tensordot` and `dpnp.linalg.tensordot` functions [#2733](https://github.com/IntelPython/dpnp/pull/2733)
4848
* Improved documentation of `file` argument in `dpnp.fromfile` [#2745](https://github.com/IntelPython/dpnp/pull/2745)
4949
* Aligned `dpnp.trim_zeros` with NumPy 2.4 to support a tuple of integers passed with `axis` keyword [#2746](https://github.com/IntelPython/dpnp/pull/2746)
50+
* Aligned `strides` property of `dpnp.ndarray` with NumPy and CuPy implementations [#2747](https://github.com/IntelPython/dpnp/pull/2747)
5051

5152
### Deprecated
5253

dpnp/dpnp_array.py

Lines changed: 66 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,16 @@ def __init__(
105105
else:
106106
buffer = usm_type
107107

108+
if strides is not None:
109+
# dpctl expects strides as elements displacement in memory,
110+
# while dpnp (and numpy as well) relies on bytes displacement
111+
if dtype is None:
112+
dtype = dpnp.default_float_type(
113+
device=device, sycl_queue=sycl_queue
114+
)
115+
it_sz = dpnp.dtype(dtype).itemsize
116+
strides = tuple(el // it_sz for el in strides)
117+
108118
sycl_queue_normalized = dpnp.get_normalized_queue_device(
109119
device=device, sycl_queue=sycl_queue
110120
)
@@ -1855,16 +1865,53 @@ def std(
18551865
@property
18561866
def strides(self):
18571867
"""
1858-
Return memory displacement in array elements, upon unit
1859-
change of respective index.
1868+
Tuple of bytes to step in each dimension when traversing an array.
18601869
1861-
For example, for strides ``(s1, s2, s3)`` and multi-index
1862-
``(i1, i2, i3)`` position of the respective element relative
1863-
to zero multi-index element is ``s1*s1 + s2*i2 + s3*i3``.
1870+
The byte offset of element ``(i[0], i[1], ..., i[n])`` in an array `a`
1871+
is::
18641872
1865-
"""
1873+
offset = sum(dpnp.array(i) * a.strides)
18661874
1867-
return self._array_obj.strides
1875+
For full documentation refer to :obj:`numpy.ndarray.strides`.
1876+
1877+
See Also
1878+
--------
1879+
:obj:`dpnp.lib.stride_tricks.as_strided` : Return a view into the array
1880+
with given shape and strides.
1881+
1882+
Examples
1883+
--------
1884+
>>> import dpnp as np
1885+
>>> y = np.reshape(np.arange(2 * 3 * 4, dtype=np.int32), (2, 3, 4))
1886+
>>> y
1887+
array([[[ 0, 1, 2, 3],
1888+
[ 4, 5, 6, 7],
1889+
[ 8, 9, 10, 11]],
1890+
[[12, 13, 14, 15],
1891+
[16, 17, 18, 19],
1892+
[20, 21, 22, 23]]], dtype=int32)
1893+
>>> y.strides
1894+
(48, 16, 4)
1895+
>>> y[1, 1, 1]
1896+
array(17, dtype=int32)
1897+
>>> offset = sum(i * s for i, s in zip((1, 1, 1), y.strides))
1898+
>>> offset // y.itemsize
1899+
17
1900+
1901+
>>> x = np.reshape(np.arange(5*6*7*8, dtype=np.int32), (5, 6, 7, 8))
1902+
>>> x = x.transpose(2, 3, 1, 0)
1903+
>>> x.strides
1904+
(32, 4, 224, 1344)
1905+
>>> offset = sum(i * s for i, s in zip((3, 5, 2, 2), x.strides))
1906+
>>> x[3, 5, 2, 2]
1907+
array(813, dtype=int32)
1908+
>>> offset // x.itemsize
1909+
813
1910+
1911+
"""
1912+
1913+
it_sz = self.itemsize
1914+
return tuple(el * it_sz for el in self._array_obj.strides)
18681915

18691916
def sum(
18701917
self,
@@ -2335,23 +2382,20 @@ def view(self, /, dtype=None, *, type=None):
23352382

23362383
# resize on last axis only
23372384
axis = ndim - 1
2338-
if old_sh[axis] != 1 and self.size != 0 and old_strides[axis] != 1:
2385+
if (
2386+
old_sh[axis] != 1
2387+
and self.size != 0
2388+
and old_strides[axis] != old_itemsz
2389+
):
23392390
raise ValueError(
23402391
"To change to a dtype of a different size, "
23412392
"the last axis must be contiguous"
23422393
)
23432394

23442395
# normalize strides whenever itemsize changes
2345-
if old_itemsz > new_itemsz:
2346-
new_strides = list(
2347-
el * (old_itemsz // new_itemsz) for el in old_strides
2348-
)
2349-
else:
2350-
new_strides = list(
2351-
el // (new_itemsz // old_itemsz) for el in old_strides
2352-
)
2353-
new_strides[axis] = 1
2354-
new_strides = tuple(new_strides)
2396+
new_strides = tuple(
2397+
old_strides[i] if i != axis else new_itemsz for i in range(ndim)
2398+
)
23552399

23562400
new_dim = old_sh[axis] * old_itemsz
23572401
if new_dim % new_itemsz != 0:
@@ -2361,9 +2405,10 @@ def view(self, /, dtype=None, *, type=None):
23612405
)
23622406

23632407
# normalize shape whenever itemsize changes
2364-
new_sh = list(old_sh)
2365-
new_sh[axis] = new_dim // new_itemsz
2366-
new_sh = tuple(new_sh)
2408+
new_sh = tuple(
2409+
old_sh[i] if i != axis else new_dim // new_itemsz
2410+
for i in range(ndim)
2411+
)
23672412

23682413
return dpnp_array(
23692414
new_sh,

dpnp/dpnp_iface_arraycreation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def _get_empty_array(
105105
elif a.flags.c_contiguous:
106106
order = "C"
107107
else:
108-
strides = _get_strides_for_order_k(a, _shape)
108+
strides = _get_strides_for_order_k(a, _dtype, shape=_shape)
109109
order = "C"
110110
elif order not in "cfCF":
111111
raise ValueError(
@@ -122,15 +122,15 @@ def _get_empty_array(
122122
)
123123

124124

125-
def _get_strides_for_order_k(x, shape=None):
125+
def _get_strides_for_order_k(x, dtype, shape=None):
126126
"""
127127
Calculate strides when order='K' for empty_like, ones_like, zeros_like,
128128
and full_like where `shape` is ``None`` or len(shape) == x.ndim.
129129
130130
"""
131131
stride_and_index = sorted([(abs(s), -i) for i, s in enumerate(x.strides)])
132132
strides = [0] * x.ndim
133-
stride = 1
133+
stride = dpnp.dtype(dtype).itemsize
134134
for _, i in stride_and_index:
135135
strides[-i] = stride
136136
stride *= shape[-i] if shape else x.shape[-i]

dpnp/dpnp_iface_indexing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -731,10 +731,10 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
731731
elif 0 < offset < m:
732732
out_shape = a_shape[:-2] + (min(n, m - offset),)
733733
out_strides = a_straides[:-2] + (st_n + st_m,)
734-
out_offset = st_m * offset
734+
out_offset = st_m // a.itemsize * offset
735735
else:
736736
out_shape = a_shape[:-2] + (0,)
737-
out_strides = a_straides[:-2] + (1,)
737+
out_strides = a_straides[:-2] + (a.itemsize,)
738738
out_offset = 0
739739

740740
return dpnp_array(

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def _define_contig_flag(x):
185185
"""
186186

187187
flag = False
188-
x_strides = x.strides
188+
x_strides = dpnp.get_usm_ndarray(x).strides
189189
x_shape = x.shape
190190
if x.ndim < 2:
191191
return True, True, True

dpnp/fft/dpnp_utils_fft.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,12 +193,13 @@ def _compute_result(dsc, a, out, forward, c2c, out_strides):
193193
)
194194
result = a
195195
else:
196+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
196197
if (
197198
out is not None
198-
and out.strides == tuple(out_strides)
199-
and not ti._array_overlap(a_usm, dpnp.get_usm_ndarray(out))
199+
and out_usm.strides == tuple(out_strides)
200+
and not ti._array_overlap(a_usm, out_usm)
200201
):
201-
res_usm = dpnp.get_usm_ndarray(out)
202+
res_usm = out_usm
202203
result = out
203204
else:
204205
# Result array that is used in oneMKL must have the exact same
@@ -223,6 +224,10 @@ def _compute_result(dsc, a, out, forward, c2c, out_strides):
223224
if a.dtype == dpnp.complex64
224225
else dpnp.float64
225226
)
227+
# cast to expected strides format
228+
out_strides = tuple(
229+
el * dpnp.dtype(out_dtype).itemsize for el in out_strides
230+
)
226231
result = dpnp_array(
227232
out_shape,
228233
dtype=out_dtype,
@@ -419,7 +424,8 @@ def _fft(a, norm, out, forward, in_place, c2c, axes, batch_fft=True):
419424
if cufft_wa: # pragma: no cover
420425
a = dpnp.moveaxis(a, -1, -2)
421426

422-
a_strides = _standardize_strides_to_nonzero(a.strides, a.shape)
427+
strides = dpnp.get_usm_ndarray(a).strides
428+
a_strides = _standardize_strides_to_nonzero(strides, a.shape)
423429
dsc, out_strides = _commit_descriptor(
424430
a, forward, in_place, c2c, a_strides, index, batch_fft
425431
)

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def _batched_inv(a, res_type):
215215
_manager.add_event_pair(ht_ev, copy_ev)
216216

217217
ipiv_stride = n
218-
a_stride = a_h.strides[0]
218+
a_stride = a_h.strides[0] // a_h.itemsize
219219

220220
# Call the LAPACK extension function _getrf_batch
221221
# to perform LU decomposition of a batch of general matrices
@@ -298,7 +298,7 @@ def _batched_lu_factor(a, res_type):
298298
dev_info_h = [0] * batch_size
299299

300300
ipiv_stride = n
301-
a_stride = a_h.strides[0]
301+
a_stride = a_h.strides[0] // a_h.itemsize
302302

303303
# Call the LAPACK extension function _getrf_batch
304304
# to perform LU decomposition of a batch of general matrices
@@ -471,8 +471,8 @@ def _batched_qr(a, mode="reduced"):
471471
dtype=res_type,
472472
)
473473

474-
a_stride = a_t.strides[0]
475-
tau_stride = tau_h.strides[0]
474+
a_stride = a_t.strides[0] // a_t.itemsize
475+
tau_stride = tau_h.strides[0] // tau_h.itemsize
476476

477477
# Call the LAPACK extension function _geqrf_batch to compute
478478
# the QR factorization of a general m x n matrix.
@@ -535,8 +535,8 @@ def _batched_qr(a, mode="reduced"):
535535
)
536536
_manager.add_event_pair(ht_ev, copy_ev)
537537

538-
q_stride = q.strides[0]
539-
tau_stride = tau_h.strides[0]
538+
q_stride = q.strides[0] // q.itemsize
539+
tau_stride = tau_h.strides[0] // tau_h.itemsize
540540

541541
# Get LAPACK function (_orgqr_batch for real or _ungqr_batch for complex
542542
# data types) for QR factorization
@@ -1818,7 +1818,7 @@ def dpnp_cholesky_batch(a, upper_lower, res_type):
18181818
)
18191819
_manager.add_event_pair(ht_ev, copy_ev)
18201820

1821-
a_stride = a_h.strides[0]
1821+
a_stride = a_h.strides[0] // a_h.itemsize
18221822

18231823
# Call the LAPACK extension function _potrf_batch
18241824
# to compute the Cholesky decomposition of a batch of

dpnp/scipy/linalg/_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
3838
"""
3939

40+
# pylint: disable=duplicate-code
4041
# pylint: disable=no-name-in-module
4142
# pylint: disable=protected-access
4243

@@ -144,7 +145,7 @@ def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
144145
dev_info_h = [0] * batch_size
145146

146147
ipiv_stride = k
147-
a_stride = a_h.strides[-1]
148+
a_stride = a_h.strides[-1] // a_h.itemsize
148149

149150
# Call the LAPACK extension function _getrf_batch
150151
# to perform LU decomposition of a batch of general matrices

dpnp/tests/test_arraycreation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -861,12 +861,12 @@ def test_full_order(order1, order2):
861861
def test_full_strides():
862862
a = numpy.full((3, 3), numpy.arange(3, dtype="i4"))
863863
ia = dpnp.full((3, 3), dpnp.arange(3, dtype="i4"))
864-
assert ia.strides == tuple(el // a.itemsize for el in a.strides)
864+
assert ia.strides == a.strides
865865
assert_array_equal(ia, a)
866866

867867
a = numpy.full((3, 3), numpy.arange(6, dtype="i4")[::2])
868868
ia = dpnp.full((3, 3), dpnp.arange(6, dtype="i4")[::2])
869-
assert ia.strides == tuple(el // a.itemsize for el in a.strides)
869+
assert ia.strides == a.strides
870870
assert_array_equal(ia, a)
871871

872872

dpnp/tests/test_ndarray.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,10 @@ def test_attributes(self):
6060
assert_equal(self.three.shape, (10, 3, 2))
6161
self.three.shape = (2, 5, 6)
6262

63-
assert_equal(self.one.strides, (self.one.itemsize / self.one.itemsize,))
64-
num = self.two.itemsize / self.two.itemsize
63+
assert_equal(self.one.strides, (self.one.itemsize,))
64+
num = self.two.itemsize
6565
assert_equal(self.two.strides, (5 * num, num))
66-
num = self.three.itemsize / self.three.itemsize
66+
num = self.three.itemsize
6767
assert_equal(self.three.strides, (30 * num, 6 * num, num))
6868

6969
assert_equal(self.one.ndim, 1)
@@ -290,7 +290,7 @@ def test_flags_strides(dtype, order, strides):
290290
(4, 4), dtype=dtype, order=order, strides=strides
291291
)
292292
a = numpy.ndarray((4, 4), dtype=dtype, order=order, strides=numpy_strides)
293-
ia = dpnp.ndarray((4, 4), dtype=dtype, order=order, strides=strides)
293+
ia = dpnp.ndarray((4, 4), dtype=dtype, order=order, strides=numpy_strides)
294294
assert usm_array.flags == ia.flags
295295
assert a.flags.c_contiguous == ia.flags.c_contiguous
296296
assert a.flags.f_contiguous == ia.flags.f_contiguous

0 commit comments

Comments
 (0)