@@ -756,20 +756,28 @@ def _extract_impl(ary, ary_mask, axis=0):
756
756
raise TypeError (
757
757
f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary )} "
758
758
)
759
- if not isinstance (ary_mask , dpt .usm_ndarray ):
760
- raise TypeError (
761
- f"Expecting type dpctl.tensor.usm_ndarray, got { type ( ary_mask ) } "
759
+ if isinstance (ary_mask , dpt .usm_ndarray ):
760
+ dst_usm_type = dpctl . utils . get_coerced_usm_type (
761
+ ( ary . usm_type , ary_mask . usm_type )
762
762
)
763
- dst_usm_type = dpctl .utils .get_coerced_usm_type (
764
- (ary .usm_type , ary_mask .usm_type )
765
- )
766
- exec_q = dpctl .utils .get_execution_queue (
767
- (ary .sycl_queue , ary_mask .sycl_queue )
768
- )
769
- if exec_q is None :
770
- raise dpctl .utils .ExecutionPlacementError (
771
- "arrays have different associated queues. "
772
- "Use `y.to_device(x.device)` to migrate."
763
+ exec_q = dpctl .utils .get_execution_queue (
764
+ (ary .sycl_queue , ary_mask .sycl_queue )
765
+ )
766
+ if exec_q is None :
767
+ raise dpctl .utils .ExecutionPlacementError (
768
+ "arrays have different associated queues. "
769
+ "Use `y.to_device(x.device)` to migrate."
770
+ )
771
+ elif isinstance (ary_mask , np .ndarray ):
772
+ dst_usm_type = ary .usm_type
773
+ exec_q = ary .sycl_queue
774
+ ary_mask = dpt .asarray (
775
+ ary_mask , usm_type = dst_usm_type , sycl_queue = exec_q
776
+ )
777
+ else :
778
+ raise TypeError (
779
+ "Expecting type dpctl.tensor.usm_ndarray or numpy.ndarray, got "
780
+ f"{ type (ary_mask )} "
773
781
)
774
782
ary_nd = ary .ndim
775
783
pp = normalize_axis_index (operator .index (axis ), ary_nd )
@@ -837,35 +845,40 @@ def _nonzero_impl(ary):
837
845
return res
838
846
839
847
840
def _get_indices_queue_usm_type(inds, queue, usm_type):
    """
    Utility for validating indices are NumPy ndarray or usm_ndarray of integral
    dtype or Python integers. At least one must be an array.

    The queue and USM type of every usm_ndarray index are collected together
    with the seed values `queue` and `usm_type`. NumPy arrays and Python
    integers are validated but contribute no queue or USM type.

    Args:
        inds: iterable of index objects (usm_ndarray, numpy.ndarray, or
            Python integers).
        queue: queue used as the first entry of the list passed to
            `dpctl.utils.get_execution_queue`.
        usm_type: USM type used as the first entry of the list passed to
            `dpctl.utils.get_coerced_usm_type`.

    Returns:
        tuple: `(q, usm_type)`, the results of
        `dpctl.utils.get_execution_queue` and
        `dpctl.utils.get_coerced_usm_type` applied to the collected queues
        and USM types. Callers in this file check `q` against `None`
        before using it.

    Raises:
        IndexError: if an array index has a dtype whose kind is neither
            "u" nor "i".
        TypeError: if an element is neither an array nor an integer, or if
            no element of `inds` is an array.
    """
    queues = [queue]
    usm_types = [usm_type]
    any_array = False
    for ind in inds:
        if isinstance(ind, (np.ndarray, dpt.usm_ndarray)):
            any_array = True
            # Only signed/unsigned integer dtypes are accepted here.
            if ind.dtype.kind not in "ui":
                raise IndexError(
                    "arrays used as indices must be of integer (or boolean) "
                    "type"
                )
            # NumPy arrays carry no queue or USM type, so only usm_ndarray
            # indices take part in queue/USM-type resolution.
            if isinstance(ind, dpt.usm_ndarray):
                queues.append(ind.sycl_queue)
                usm_types.append(ind.usm_type)
        elif not isinstance(ind, Integral):
            raise TypeError(
                "all elements of `ind` expected to be usm_ndarrays, "
                f"NumPy arrays, or integers, found {type(ind)}"
            )
    if not any_array:
        raise TypeError(
            "at least one element of `inds` expected to be an array "
        )
    usm_type = dpctl.utils.get_coerced_usm_type(usm_types)
    q = dpctl.utils.get_execution_queue(queues)
    return q, usm_type
869
882
870
883
871
884
def _prepare_indices_arrays (inds , q , usm_type ):
@@ -922,18 +935,12 @@ def _take_multi_index(ary, inds, p, mode=0):
922
935
raise ValueError (
923
936
"Invalid value for mode keyword, only 0 or 1 is supported"
924
937
)
925
- queues_ = [
926
- ary .sycl_queue ,
927
- ]
928
- usm_types_ = [
929
- ary .usm_type ,
930
- ]
931
938
if not isinstance (inds , (list , tuple )):
932
939
inds = (inds ,)
933
940
934
- _validate_indices ( inds , queues_ , usm_types_ )
935
- res_usm_type = dpctl . utils . get_coerced_usm_type ( usm_types_ )
936
- exec_q = dpctl . utils . get_execution_queue ( queues_ )
941
+ exec_q , res_usm_type = _get_indices_queue_usm_type (
942
+ inds , ary . sycl_queue , ary . usm_type
943
+ )
937
944
if exec_q is None :
938
945
raise dpctl .utils .ExecutionPlacementError (
939
946
"Can not automatically determine where to allocate the "
@@ -942,8 +949,7 @@ def _take_multi_index(ary, inds, p, mode=0):
942
949
"be associated with the same queue."
943
950
)
944
951
945
- if len (inds ) > 1 :
946
- inds = _prepare_indices_arrays (inds , exec_q , res_usm_type )
952
+ inds = _prepare_indices_arrays (inds , exec_q , res_usm_type )
947
953
948
954
ind0 = inds [0 ]
949
955
ary_sh = ary .shape
@@ -976,21 +982,51 @@ def _place_impl(ary, ary_mask, vals, axis=0):
976
982
raise TypeError (
977
983
f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary )} "
978
984
)
979
- if not isinstance (ary_mask , dpt .usm_ndarray ):
980
- raise TypeError (
981
- f"Expecting type dpctl.tensor.usm_ndarray, got { type (ary_mask )} "
985
+ if isinstance (ary_mask , dpt .usm_ndarray ):
986
+ exec_q = dpctl .utils .get_execution_queue (
987
+ (
988
+ ary .sycl_queue ,
989
+ ary_mask .sycl_queue ,
990
+ )
982
991
)
983
- exec_q = dpctl .utils .get_execution_queue (
984
- (
985
- ary .sycl_queue ,
986
- ary_mask .sycl_queue ,
992
+ coerced_usm_type = dpctl .utils .get_coerced_usm_type (
993
+ (
994
+ ary .usm_type ,
995
+ ary_mask .usm_type ,
996
+ )
997
+ )
998
+ if exec_q is None :
999
+ raise dpctl .utils .ExecutionPlacementError (
1000
+ "arrays have different associated queues. "
1001
+ "Use `y.to_device(x.device)` to migrate."
1002
+ )
1003
+ elif isinstance (ary_mask , np .ndarray ):
1004
+ exec_q = ary .sycl_queue
1005
+ coerced_usm_type = ary .usm_type
1006
+ ary_mask = dpt .asarray (
1007
+ ary_mask , usm_type = coerced_usm_type , sycl_queue = exec_q
1008
+ )
1009
+ else :
1010
+ raise TypeError (
1011
+ "Expecting type dpctl.tensor.usm_ndarray or numpy.ndarray, got "
1012
+ f"{ type (ary_mask )} "
987
1013
)
988
- )
989
1014
if exec_q is not None :
990
1015
if not isinstance (vals , dpt .usm_ndarray ):
991
- vals = dpt .asarray (vals , dtype = ary .dtype , sycl_queue = exec_q )
1016
+ vals = dpt .asarray (
1017
+ vals ,
1018
+ dtype = ary .dtype ,
1019
+ usm_type = coerced_usm_type ,
1020
+ sycl_queue = exec_q ,
1021
+ )
992
1022
else :
993
1023
exec_q = dpctl .utils .get_execution_queue ((exec_q , vals .sycl_queue ))
1024
+ coerced_usm_type = dpctl .utils .get_coerced_usm_type (
1025
+ (
1026
+ coerced_usm_type ,
1027
+ vals .usm_type ,
1028
+ )
1029
+ )
994
1030
if exec_q is None :
995
1031
raise dpctl .utils .ExecutionPlacementError (
996
1032
"arrays have different associated queues. "
@@ -1005,7 +1041,12 @@ def _place_impl(ary, ary_mask, vals, axis=0):
1005
1041
)
1006
1042
mask_nelems = ary_mask .size
1007
1043
cumsum_dt = dpt .int32 if mask_nelems < int32_t_max else dpt .int64
1008
- cumsum = dpt .empty (mask_nelems , dtype = cumsum_dt , device = ary_mask .device )
1044
+ cumsum = dpt .empty (
1045
+ mask_nelems ,
1046
+ dtype = cumsum_dt ,
1047
+ usm_type = coerced_usm_type ,
1048
+ device = ary_mask .device ,
1049
+ )
1009
1050
exec_q = cumsum .sycl_queue
1010
1051
_manager = dpctl .utils .SequentialOrderManager [exec_q ]
1011
1052
dep_ev = _manager .submitted_events
@@ -1048,30 +1089,29 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
1048
1089
raise ValueError (
1049
1090
"Invalid value for mode keyword, only 0 or 1 is supported"
1050
1091
)
1051
- if isinstance (vals , dpt .usm_ndarray ):
1052
- queues_ = [ary .sycl_queue , vals .sycl_queue ]
1053
- usm_types_ = [ary .usm_type , vals .usm_type ]
1054
- else :
1055
- queues_ = [
1056
- ary .sycl_queue ,
1057
- ]
1058
- usm_types_ = [
1059
- ary .usm_type ,
1060
- ]
1061
1092
if not isinstance (inds , (list , tuple )):
1062
1093
inds = (inds ,)
1063
1094
1064
- _validate_indices (inds , queues_ , usm_types_ )
1095
+ exec_q , coerced_usm_type = _get_indices_queue_usm_type (
1096
+ inds , ary .sycl_queue , ary .usm_type
1097
+ )
1065
1098
1066
- vals_usm_type = dpctl .utils .get_coerced_usm_type (usm_types_ )
1067
- exec_q = dpctl .utils .get_execution_queue (queues_ )
1068
1099
if exec_q is not None :
1069
1100
if not isinstance (vals , dpt .usm_ndarray ):
1070
1101
vals = dpt .asarray (
1071
- vals , dtype = ary .dtype , usm_type = vals_usm_type , sycl_queue = exec_q
1102
+ vals ,
1103
+ dtype = ary .dtype ,
1104
+ usm_type = coerced_usm_type ,
1105
+ sycl_queue = exec_q ,
1072
1106
)
1073
1107
else :
1074
1108
exec_q = dpctl .utils .get_execution_queue ((exec_q , vals .sycl_queue ))
1109
+ coerced_usm_type = dpctl .utils .get_coerced_usm_type (
1110
+ (
1111
+ coerced_usm_type ,
1112
+ vals .usm_type ,
1113
+ )
1114
+ )
1075
1115
if exec_q is None :
1076
1116
raise dpctl .utils .ExecutionPlacementError (
1077
1117
"Can not automatically determine where to allocate the "
@@ -1080,8 +1120,7 @@ def _put_multi_index(ary, inds, p, vals, mode=0):
1080
1120
"be associated with the same queue."
1081
1121
)
1082
1122
1083
- if len (inds ) > 1 :
1084
- inds = _prepare_indices_arrays (inds , exec_q , vals_usm_type )
1123
+ inds = _prepare_indices_arrays (inds , exec_q , coerced_usm_type )
1085
1124
1086
1125
ind0 = inds [0 ]
1087
1126
ary_sh = ary .shape
0 commit comments