@@ -53,6 +53,65 @@ def select(
53
53
return layer .get_output (0 )
54
54
55
55
56
+ def is_boolean_tensor (tensor : Union [TRTTensor , np .ndarray , torch .Tensor ]) -> bool :
57
+ if isinstance (tensor , (TRTTensor )):
58
+ val = tensor .meta .get ("val" )
59
+ if val is not None and val .dtype is torch .bool :
60
+ return True
61
+ return isinstance (tensor , (torch .Tensor , np .ndarray )) and tensor .dtype == torch .bool
62
+
63
+
64
+ def expand_boolean_indices (
65
+ ctx : ConversionContext ,
66
+ target : Target ,
67
+ source_ir : Optional [SourceIR ],
68
+ name : str ,
69
+ input : TRTTensor ,
70
+ indices : Sequence [Union [TRTTensor , np .ndarray , torch .Tensor ]],
71
+ ) -> Sequence [Union [TRTTensor , np .ndarray , torch .Tensor ]]:
72
+ for i , ind in enumerate (indices ):
73
+ if ind is not None and is_boolean_tensor (ind ):
74
+ _LOGGER .debug (
75
+ f"Boolean index detected at position { i } , converting with nonzero()"
76
+ )
77
+
78
+ mask_tensor = get_trt_tensor (ctx , ind , name + f"_bool_mask_{ i } " )
79
+
80
+ nonzero_layer = ctx .net .add_non_zero (mask_tensor )
81
+ set_layer_name (
82
+ nonzero_layer , target , name + f"_bool_nonzero_{ i } " , source_ir
83
+ )
84
+ nonzero_indices = nonzero_layer .get_output (0 )
85
+
86
+ # nonzero returns shape [N, dims], we need to extract dim i
87
+ if len (indices ) == 1 :
88
+ # x[mask] — 1D mask
89
+ squeeze_layer = ctx .net .add_shuffle (nonzero_indices )
90
+ squeeze_layer .reshape_dims = (- 1 ,)
91
+ set_layer_name (
92
+ squeeze_layer ,
93
+ target ,
94
+ name + f"_bool_nonzero_squeeze_{ i } " ,
95
+ source_ir ,
96
+ )
97
+ squeezed_index = squeeze_layer .get_output (0 )
98
+ ind = squeezed_index
99
+ else :
100
+ # Advanced multi-axis mask: extract index i from shape [N, D]
101
+ gather_axis = 1 # dim index
102
+ gather_layer = ctx .net .add_gather (
103
+ nonzero_indices ,
104
+ get_trt_tensor (ctx , i , name + f"_dim_index_{ i } " ),
105
+ gather_axis ,
106
+ )
107
+ set_layer_name (
108
+ gather_layer , target , name + f"_bool_nonzero_extract_{ i } " , source_ir
109
+ )
110
+ extracted_index = gather_layer .get_output (0 )
111
+ ind = extracted_index
112
+ return indices
113
+
114
+
56
115
def index (
57
116
ctx : ConversionContext ,
58
117
target : Target ,
@@ -63,8 +122,6 @@ def index(
63
122
) -> TRTTensor :
64
123
adv_indx_indices = []
65
124
tensor_indices = []
66
- # check if the input is dynamic
67
- dynamic_shape = has_dynamic_shape (input .shape )
68
125
# is_numpy is a flag to specify if all the indices are numpy or torchTensor.
69
126
# If any is not this flag will be set to False
70
127
_LOGGER .debug (
@@ -78,6 +135,7 @@ def index(
78
135
# here we need to check if all the index are broadcastable
79
136
# if no, then we need to broadcast
80
137
last_index = None
138
+ indices = expand_boolean_indices (ctx , target , source_ir , name , input , indices )
81
139
for i , ind in enumerate (indices ):
82
140
if ind is not None :
83
141
_LOGGER .debug (f"Shape of { i } index is { ind .shape } " )
0 commit comments