Skip to content

Commit 3381d14

Browse files
authored
Merge pull request #2128 from IntelPython/allow-numpy-arrays-indexing
Accept NumPy arrays in advanced indexing
2 parents a207271 + 532c11e commit 3381d14

File tree

3 files changed

+131
-69
lines changed

3 files changed

+131
-69
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 102 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -756,20 +756,28 @@ def _extract_impl(ary, ary_mask, axis=0):
756756
raise TypeError(
757757
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
758758
)
759-
if not isinstance(ary_mask, dpt.usm_ndarray):
760-
raise TypeError(
761-
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}"
759+
if isinstance(ary_mask, dpt.usm_ndarray):
760+
dst_usm_type = dpctl.utils.get_coerced_usm_type(
761+
(ary.usm_type, ary_mask.usm_type)
762762
)
763-
dst_usm_type = dpctl.utils.get_coerced_usm_type(
764-
(ary.usm_type, ary_mask.usm_type)
765-
)
766-
exec_q = dpctl.utils.get_execution_queue(
767-
(ary.sycl_queue, ary_mask.sycl_queue)
768-
)
769-
if exec_q is None:
770-
raise dpctl.utils.ExecutionPlacementError(
771-
"arrays have different associated queues. "
772-
"Use `y.to_device(x.device)` to migrate."
763+
exec_q = dpctl.utils.get_execution_queue(
764+
(ary.sycl_queue, ary_mask.sycl_queue)
765+
)
766+
if exec_q is None:
767+
raise dpctl.utils.ExecutionPlacementError(
768+
"arrays have different associated queues. "
769+
"Use `y.to_device(x.device)` to migrate."
770+
)
771+
elif isinstance(ary_mask, np.ndarray):
772+
dst_usm_type = ary.usm_type
773+
exec_q = ary.sycl_queue
774+
ary_mask = dpt.asarray(
775+
ary_mask, usm_type=dst_usm_type, sycl_queue=exec_q
776+
)
777+
else:
778+
raise TypeError(
779+
"Expecting type dpctl.tensor.usm_ndarray or numpy.ndarray, got "
780+
f"{type(ary_mask)}"
773781
)
774782
ary_nd = ary.ndim
775783
pp = normalize_axis_index(operator.index(axis), ary_nd)
@@ -837,35 +845,40 @@ def _nonzero_impl(ary):
837845
return res
838846

839847

840-
def _validate_indices(inds, queue_list, usm_type_list):
848+
def _get_indices_queue_usm_type(inds, queue, usm_type):
841849
"""
842-
Utility for validating indices are usm_ndarray of integral dtype or Python
843-
integers. At least one must be an array.
850+
Utility for validating that indices are NumPy ndarray or usm_ndarray of integral
851+
dtype or Python integers. At least one must be an array.
844852
845853
For each usm_ndarray index, the queue and usm type are collected along with
846854
`queue` and `usm_type`; the execution queue and coerced usm type are returned.
847855
"""
848-
any_usmarray = False
856+
queues = [queue]
857+
usm_types = [usm_type]
858+
any_array = False
849859
for ind in inds:
850-
if isinstance(ind, dpt.usm_ndarray):
851-
any_usmarray = True
860+
if isinstance(ind, (np.ndarray, dpt.usm_ndarray)):
861+
any_array = True
852862
if ind.dtype.kind not in "ui":
853863
raise IndexError(
854864
"arrays used as indices must be of integer (or boolean) "
855865
"type"
856866
)
857-
queue_list.append(ind.sycl_queue)
858-
usm_type_list.append(ind.usm_type)
867+
if isinstance(ind, dpt.usm_ndarray):
868+
queues.append(ind.sycl_queue)
869+
usm_types.append(ind.usm_type)
859870
elif not isinstance(ind, Integral):
860871
raise TypeError(
861-
"all elements of `ind` expected to be usm_ndarrays "
862-
f"or integers, found {type(ind)}"
872+
"all elements of `ind` expected to be usm_ndarrays, "
873+
f"NumPy arrays, or integers, found {type(ind)}"
863874
)
864-
if not any_usmarray:
875+
if not any_array:
865876
raise TypeError(
866-
"at least one element of `inds` expected to be a usm_ndarray"
877+
"at least one element of `inds` expected to be an array"
867878
)
868-
return inds
879+
usm_type = dpctl.utils.get_coerced_usm_type(usm_types)
880+
q = dpctl.utils.get_execution_queue(queues)
881+
return q, usm_type
869882

870883

871884
def _prepare_indices_arrays(inds, q, usm_type):
@@ -922,18 +935,12 @@ def _take_multi_index(ary, inds, p, mode=0):
922935
raise ValueError(
923936
"Invalid value for mode keyword, only 0 or 1 is supported"
924937
)
925-
queues_ = [
926-
ary.sycl_queue,
927-
]
928-
usm_types_ = [
929-
ary.usm_type,
930-
]
931938
if not isinstance(inds, (list, tuple)):
932939
inds = (inds,)
933940

934-
_validate_indices(inds, queues_, usm_types_)
935-
res_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
936-
exec_q = dpctl.utils.get_execution_queue(queues_)
941+
exec_q, res_usm_type = _get_indices_queue_usm_type(
942+
inds, ary.sycl_queue, ary.usm_type
943+
)
937944
if exec_q is None:
938945
raise dpctl.utils.ExecutionPlacementError(
939946
"Can not automatically determine where to allocate the "
@@ -942,8 +949,7 @@ def _take_multi_index(ary, inds, p, mode=0):
942949
"be associated with the same queue."
943950
)
944951

945-
if len(inds) > 1:
946-
inds = _prepare_indices_arrays(inds, exec_q, res_usm_type)
952+
inds = _prepare_indices_arrays(inds, exec_q, res_usm_type)
947953

948954
ind0 = inds[0]
949955
ary_sh = ary.shape
@@ -976,21 +982,51 @@ def _place_impl(ary, ary_mask, vals, axis=0):
976982
raise TypeError(
977983
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary)}"
978984
)
979-
if not isinstance(ary_mask, dpt.usm_ndarray):
980-
raise TypeError(
981-
f"Expecting type dpctl.tensor.usm_ndarray, got {type(ary_mask)}"
985+
if isinstance(ary_mask, dpt.usm_ndarray):
986+
exec_q = dpctl.utils.get_execution_queue(
987+
(
988+
ary.sycl_queue,
989+
ary_mask.sycl_queue,
990+
)
982991
)
983-
exec_q = dpctl.utils.get_execution_queue(
984-
(
985-
ary.sycl_queue,
986-
ary_mask.sycl_queue,
992+
coerced_usm_type = dpctl.utils.get_coerced_usm_type(
993+
(
994+
ary.usm_type,
995+
ary_mask.usm_type,
996+
)
997+
)
998+
if exec_q is None:
999+
raise dpctl.utils.ExecutionPlacementError(
1000+
"arrays have different associated queues. "
1001+
"Use `y.to_device(x.device)` to migrate."
1002+
)
1003+
elif isinstance(ary_mask, np.ndarray):
1004+
exec_q = ary.sycl_queue
1005+
coerced_usm_type = ary.usm_type
1006+
ary_mask = dpt.asarray(
1007+
ary_mask, usm_type=coerced_usm_type, sycl_queue=exec_q
1008+
)
1009+
else:
1010+
raise TypeError(
1011+
"Expecting type dpctl.tensor.usm_ndarray or numpy.ndarray, got "
1012+
f"{type(ary_mask)}"
9871013
)
988-
)
9891014
if exec_q is not None:
9901015
if not isinstance(vals, dpt.usm_ndarray):
991-
vals = dpt.asarray(vals, dtype=ary.dtype, sycl_queue=exec_q)
1016+
vals = dpt.asarray(
1017+
vals,
1018+
dtype=ary.dtype,
1019+
usm_type=coerced_usm_type,
1020+
sycl_queue=exec_q,
1021+
)
9921022
else:
9931023
exec_q = dpctl.utils.get_execution_queue((exec_q, vals.sycl_queue))
1024+
coerced_usm_type = dpctl.utils.get_coerced_usm_type(
1025+
(
1026+
coerced_usm_type,
1027+
vals.usm_type,
1028+
)
1029+
)
9941030
if exec_q is None:
9951031
raise dpctl.utils.ExecutionPlacementError(
9961032
"arrays have different associated queues. "
@@ -1005,7 +1041,12 @@ def _place_impl(ary, ary_mask, vals, axis=0):
10051041
)
10061042
mask_nelems = ary_mask.size
10071043
cumsum_dt = dpt.int32 if mask_nelems < int32_t_max else dpt.int64
1008-
cumsum = dpt.empty(mask_nelems, dtype=cumsum_dt, device=ary_mask.device)
1044+
cumsum = dpt.empty(
1045+
mask_nelems,
1046+
dtype=cumsum_dt,
1047+
usm_type=coerced_usm_type,
1048+
device=ary_mask.device,
1049+
)
10091050
exec_q = cumsum.sycl_queue
10101051
_manager = dpctl.utils.SequentialOrderManager[exec_q]
10111052
dep_ev = _manager.submitted_events
@@ -1048,30 +1089,29 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
10481089
raise ValueError(
10491090
"Invalid value for mode keyword, only 0 or 1 is supported"
10501091
)
1051-
if isinstance(vals, dpt.usm_ndarray):
1052-
queues_ = [ary.sycl_queue, vals.sycl_queue]
1053-
usm_types_ = [ary.usm_type, vals.usm_type]
1054-
else:
1055-
queues_ = [
1056-
ary.sycl_queue,
1057-
]
1058-
usm_types_ = [
1059-
ary.usm_type,
1060-
]
10611092
if not isinstance(inds, (list, tuple)):
10621093
inds = (inds,)
10631094

1064-
_validate_indices(inds, queues_, usm_types_)
1095+
exec_q, coerced_usm_type = _get_indices_queue_usm_type(
1096+
inds, ary.sycl_queue, ary.usm_type
1097+
)
10651098

1066-
vals_usm_type = dpctl.utils.get_coerced_usm_type(usm_types_)
1067-
exec_q = dpctl.utils.get_execution_queue(queues_)
10681099
if exec_q is not None:
10691100
if not isinstance(vals, dpt.usm_ndarray):
10701101
vals = dpt.asarray(
1071-
vals, dtype=ary.dtype, usm_type=vals_usm_type, sycl_queue=exec_q
1102+
vals,
1103+
dtype=ary.dtype,
1104+
usm_type=coerced_usm_type,
1105+
sycl_queue=exec_q,
10721106
)
10731107
else:
10741108
exec_q = dpctl.utils.get_execution_queue((exec_q, vals.sycl_queue))
1109+
coerced_usm_type = dpctl.utils.get_coerced_usm_type(
1110+
(
1111+
coerced_usm_type,
1112+
vals.usm_type,
1113+
)
1114+
)
10751115
if exec_q is None:
10761116
raise dpctl.utils.ExecutionPlacementError(
10771117
"Can not automatically determine where to allocate the "
@@ -1080,8 +1120,7 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
10801120
"be associated with the same queue."
10811121
)
10821122

1083-
if len(inds) > 1:
1084-
inds = _prepare_indices_arrays(inds, exec_q, vals_usm_type)
1123+
inds = _prepare_indices_arrays(inds, exec_q, coerced_usm_type)
10851124

10861125
ind0 = inds[0]
10871126
ary_sh = ary.shape

dpctl/tensor/_slicing.pxi

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numbers
1818
from operator import index
1919
from cpython.buffer cimport PyObject_CheckBuffer
20+
from numpy import ndarray
2021

2122

2223
cdef bint _is_buffer(object o):
@@ -46,7 +47,7 @@ cdef Py_ssize_t _slice_len(
4647

4748
cdef bint _is_integral(object x) except *:
4849
"""Gives True if x is an integral slice spec"""
49-
if isinstance(x, usm_ndarray):
50+
if isinstance(x, (ndarray, usm_ndarray)):
5051
if x.ndim > 0:
5152
return False
5253
if x.dtype.kind not in "ui":
@@ -74,7 +75,7 @@ cdef bint _is_integral(object x) except *:
7475

7576
cdef bint _is_boolean(object x) except *:
7677
"""Gives True if x is a boolean slice spec"""
77-
if isinstance(x, usm_ndarray):
78+
if isinstance(x, (ndarray, usm_ndarray)):
7879
if x.ndim > 0:
7980
return False
8081
if x.dtype.kind not in "b":
@@ -185,7 +186,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
185186
raise IndexError(
186187
"Index {0} is out of range for axes 0 with "
187188
"size {1}".format(ind, shape[0]))
188-
elif isinstance(ind, usm_ndarray):
189+
elif isinstance(ind, (ndarray, usm_ndarray)):
189190
return (shape, strides, offset, (ind,), 0)
190191
elif isinstance(ind, tuple):
191192
axes_referenced = 0
@@ -216,7 +217,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
216217
axes_referenced += 1
217218
if not array_streak_started and array_streak_interrupted:
218219
explicit_index += 1
219-
elif isinstance(i, usm_ndarray):
220+
elif isinstance(i, (ndarray, usm_ndarray)):
220221
if not seen_arrays_yet:
221222
seen_arrays_yet = True
222223
array_streak_started = True
@@ -302,7 +303,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
302303
array_streak = False
303304
elif _is_integral(ind_i):
304305
if array_streak:
305-
if not isinstance(ind_i, usm_ndarray):
306+
if not isinstance(ind_i, (ndarray, usm_ndarray)):
306307
ind_i = index(ind_i)
307308
# integer will be converted to an array,
308309
# still raise if OOB
@@ -337,7 +338,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
337338
"Index {0} is out of range for axes "
338339
"{1} with size {2}".format(ind_i, k, shape[k])
339340
)
340-
elif isinstance(ind_i, usm_ndarray):
341+
elif isinstance(ind_i, (ndarray, usm_ndarray)):
341342
if not array_streak:
342343
array_streak = True
343344
if not advanced_start_pos_set:

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,28 @@ def test_advanced_slice16():
440440
assert isinstance(y, dpt.usm_ndarray)
441441

442442

443+
def test_integer_indexing_numpy_array():
444+
q = get_queue_or_skip()
445+
ii = np.asarray([1, 2])
446+
x = dpt.arange(10, dtype="i4", sycl_queue=q)
447+
y = x[ii]
448+
assert isinstance(y, dpt.usm_ndarray)
449+
assert y.shape == ii.shape
450+
assert dpt.all(x[1:3] == y)
451+
452+
453+
def test_boolean_indexing_numpy_array():
454+
q = get_queue_or_skip()
455+
ii = np.asarray(
456+
[False, True, True, False, False, False, False, False, False, False]
457+
)
458+
x = dpt.arange(10, dtype="i4", sycl_queue=q)
459+
y = x[ii]
460+
assert isinstance(y, dpt.usm_ndarray)
461+
assert y.shape == (2,)
462+
assert dpt.all(x[1:3] == y)
463+
464+
443465
def test_boolean_indexing_validation():
444466
get_queue_or_skip()
445467
x = dpt.zeros(10, dtype="i4")

0 commit comments

Comments
 (0)