@@ -989,15 +989,22 @@ def _place_impl(ary, ary_mask, vals, axis=0):
989
989
ary_mask .sycl_queue ,
990
990
)
991
991
)
992
+ coerced_usm_type = dpctl .utils .get_coerced_usm_type (
993
+ (
994
+ ary .usm_type ,
995
+ ary_mask .usm_type ,
996
+ )
997
+ )
992
998
if exec_q is None :
993
999
raise dpctl .utils .ExecutionPlacementError (
994
1000
"arrays have different associated queues. "
995
1001
"Use `y.to_device(x.device)` to migrate."
996
1002
)
997
1003
elif isinstance (ary_mask , np .ndarray ):
998
1004
exec_q = ary .sycl_queue
1005
+ coerced_usm_type = ary .usm_type
999
1006
ary_mask = dpt .asarray (
1000
- ary_mask , usm_type = ary . usm_type , sycl_queue = exec_q
1007
+ ary_mask , usm_type = coerced_usm_type , sycl_queue = exec_q
1001
1008
)
1002
1009
else :
1003
1010
raise TypeError (
@@ -1006,9 +1013,20 @@ def _place_impl(ary, ary_mask, vals, axis=0):
1006
1013
)
1007
1014
if exec_q is not None :
1008
1015
if not isinstance (vals , dpt .usm_ndarray ):
1009
- vals = dpt .asarray (vals , dtype = ary .dtype , sycl_queue = exec_q )
1016
+ vals = dpt .asarray (
1017
+ vals ,
1018
+ dtype = ary .dtype ,
1019
+ usm_type = coerced_usm_type ,
1020
+ sycl_queue = exec_q ,
1021
+ )
1010
1022
else :
1011
1023
exec_q = dpctl .utils .get_execution_queue ((exec_q , vals .sycl_queue ))
1024
+ coerced_usm_type = dpctl .utils .get_coerced_usm_type (
1025
+ (
1026
+ coerced_usm_type ,
1027
+ vals .usm_type ,
1028
+ )
1029
+ )
1012
1030
if exec_q is None :
1013
1031
raise dpctl .utils .ExecutionPlacementError (
1014
1032
"arrays have different associated queues. "
@@ -1023,7 +1041,12 @@ def _place_impl(ary, ary_mask, vals, axis=0):
1023
1041
)
1024
1042
mask_nelems = ary_mask .size
1025
1043
cumsum_dt = dpt .int32 if mask_nelems < int32_t_max else dpt .int64
1026
- cumsum = dpt .empty (mask_nelems , dtype = cumsum_dt , device = ary_mask .device )
1044
+ cumsum = dpt .empty (
1045
+ mask_nelems ,
1046
+ dtype = cumsum_dt ,
1047
+ usm_type = coerced_usm_type ,
1048
+ device = ary_mask .device ,
1049
+ )
1027
1050
exec_q = cumsum .sycl_queue
1028
1051
_manager = dpctl .utils .SequentialOrderManager [exec_q ]
1029
1052
dep_ev = _manager .submitted_events
@@ -1069,17 +1092,26 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
1069
1092
if not isinstance (inds , (list , tuple )):
1070
1093
inds = (inds ,)
1071
1094
1072
- exec_q , vals_usm_type = _get_indices_queue_usm_type (
1095
+ exec_q , coerced_usm_type = _get_indices_queue_usm_type (
1073
1096
inds , ary .sycl_queue , ary .usm_type
1074
1097
)
1075
1098
1076
1099
if exec_q is not None :
1077
1100
if not isinstance (vals , dpt .usm_ndarray ):
1078
1101
vals = dpt .asarray (
1079
- vals , dtype = ary .dtype , usm_type = vals_usm_type , sycl_queue = exec_q
1102
+ vals ,
1103
+ dtype = ary .dtype ,
1104
+ usm_type = coerced_usm_type ,
1105
+ sycl_queue = exec_q ,
1080
1106
)
1081
1107
else :
1082
1108
exec_q = dpctl .utils .get_execution_queue ((exec_q , vals .sycl_queue ))
1109
+ coerced_usm_type = dpctl .utils .get_coerced_usm_type (
1110
+ (
1111
+ coerced_usm_type ,
1112
+ vals .usm_type ,
1113
+ )
1114
+ )
1083
1115
if exec_q is None :
1084
1116
raise dpctl .utils .ExecutionPlacementError (
1085
1117
"Can not automatically determine where to allocate the "
@@ -1088,7 +1120,7 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
1088
1120
"be associated with the same queue."
1089
1121
)
1090
1122
1091
- inds = _prepare_indices_arrays (inds , exec_q , vals_usm_type )
1123
+ inds = _prepare_indices_arrays (inds , exec_q , coerced_usm_type )
1092
1124
1093
1125
ind0 = inds [0 ]
1094
1126
ary_sh = ary .shape
0 commit comments