Skip to content

Commit 0de2fff

Browse files
committed
add usm_type coercion to advanced indexing __setitem__ routines
1 parent b90fc90 commit 0de2fff

File tree

1 file changed

+38
-6
lines changed

1 file changed

+38
-6
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -989,15 +989,22 @@ def _place_impl(ary, ary_mask, vals, axis=0):
989989
ary_mask.sycl_queue,
990990
)
991991
)
992+
coerced_usm_type = dpctl.utils.get_coerced_usm_type(
993+
(
994+
ary.usm_type,
995+
ary_mask.usm_type,
996+
)
997+
)
992998
if exec_q is None:
993999
raise dpctl.utils.ExecutionPlacementError(
9941000
"arrays have different associated queues. "
9951001
"Use `y.to_device(x.device)` to migrate."
9961002
)
9971003
elif isinstance(ary_mask, np.ndarray):
9981004
exec_q = ary.sycl_queue
1005+
coerced_usm_type = ary.usm_type
9991006
ary_mask = dpt.asarray(
1000-
ary_mask, usm_type=ary.usm_type, sycl_queue=exec_q
1007+
ary_mask, usm_type=coerced_usm_type, sycl_queue=exec_q
10011008
)
10021009
else:
10031010
raise TypeError(
@@ -1006,9 +1013,20 @@ def _place_impl(ary, ary_mask, vals, axis=0):
10061013
)
10071014
if exec_q is not None:
10081015
if not isinstance(vals, dpt.usm_ndarray):
1009-
vals = dpt.asarray(vals, dtype=ary.dtype, sycl_queue=exec_q)
1016+
vals = dpt.asarray(
1017+
vals,
1018+
dtype=ary.dtype,
1019+
usm_type=coerced_usm_type,
1020+
sycl_queue=exec_q,
1021+
)
10101022
else:
10111023
exec_q = dpctl.utils.get_execution_queue((exec_q, vals.sycl_queue))
1024+
coerced_usm_type = dpctl.utils.get_coerced_usm_type(
1025+
(
1026+
coerced_usm_type,
1027+
vals.usm_type,
1028+
)
1029+
)
10121030
if exec_q is None:
10131031
raise dpctl.utils.ExecutionPlacementError(
10141032
"arrays have different associated queues. "
@@ -1023,7 +1041,12 @@ def _place_impl(ary, ary_mask, vals, axis=0):
10231041
)
10241042
mask_nelems = ary_mask.size
10251043
cumsum_dt = dpt.int32 if mask_nelems < int32_t_max else dpt.int64
1026-
cumsum = dpt.empty(mask_nelems, dtype=cumsum_dt, device=ary_mask.device)
1044+
cumsum = dpt.empty(
1045+
mask_nelems,
1046+
dtype=cumsum_dt,
1047+
usm_type=coerced_usm_type,
1048+
device=ary_mask.device,
1049+
)
10271050
exec_q = cumsum.sycl_queue
10281051
_manager = dpctl.utils.SequentialOrderManager[exec_q]
10291052
dep_ev = _manager.submitted_events
@@ -1069,17 +1092,26 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
10691092
if not isinstance(inds, (list, tuple)):
10701093
inds = (inds,)
10711094

1072-
exec_q, vals_usm_type = _get_indices_queue_usm_type(
1095+
exec_q, coerced_usm_type = _get_indices_queue_usm_type(
10731096
inds, ary.sycl_queue, ary.usm_type
10741097
)
10751098

10761099
if exec_q is not None:
10771100
if not isinstance(vals, dpt.usm_ndarray):
10781101
vals = dpt.asarray(
1079-
vals, dtype=ary.dtype, usm_type=vals_usm_type, sycl_queue=exec_q
1102+
vals,
1103+
dtype=ary.dtype,
1104+
usm_type=coerced_usm_type,
1105+
sycl_queue=exec_q,
10801106
)
10811107
else:
10821108
exec_q = dpctl.utils.get_execution_queue((exec_q, vals.sycl_queue))
1109+
coerced_usm_type = dpctl.utils.get_coerced_usm_type(
1110+
(
1111+
coerced_usm_type,
1112+
vals.usm_type,
1113+
)
1114+
)
10831115
if exec_q is None:
10841116
raise dpctl.utils.ExecutionPlacementError(
10851117
"Can not automatically determine where to allocate the "
@@ -1088,7 +1120,7 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
10881120
"be associated with the same queue."
10891121
)
10901122

1091-
inds = _prepare_indices_arrays(inds, exec_q, vals_usm_type)
1123+
inds = _prepare_indices_arrays(inds, exec_q, coerced_usm_type)
10921124

10931125
ind0 = inds[0]
10941126
ary_sh = ary.shape

0 commit comments

Comments
 (0)