diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index abf721198b..e9be9c9b89 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -414,13 +414,20 @@ def index_dtype_validator(
     for ind in index:
         if ind is not None:
             val = ind.meta.get("val")
-            if val is not None and val.dtype not in (torch.int32, torch.int64):
+            if val is not None and val.dtype not in (
+                torch.int32,
+                torch.int64,
+                torch.bool,
+            ):
                 return False
     return True
 
 
 @dynamo_tensorrt_converter(
-    torch.ops.aten.index.Tensor, capability_validator=index_dtype_validator
+    torch.ops.aten.index.Tensor,
+    capability_validator=index_dtype_validator,
+    supports_dynamic_shapes=True,
+    requires_output_allocator=True,
 )
 @enforce_tensor_types(
     {
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py
index c4d44a07ea..ded50519ad 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/select.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py
@@ -14,7 +14,6 @@
     cast_trt_tensor,
     get_positive_dim,
     get_trt_tensor,
-    has_dynamic_shape,
     set_layer_name,
     to_numpy,
 )
@@ -51,6 +50,71 @@ def select(
     return layer.get_output(0)
 
 
+def is_boolean_tensor(tensor: Union[TRTTensor, np.ndarray, torch.Tensor]) -> bool:
+    if isinstance(tensor, (torch.Tensor, np.ndarray, TRTTensor)):
+        return bool(tensor.dtype == torch.bool)
+    # when the index is an FX node, check its fake-tensor metadata instead
+    else:
+        val = tensor.meta.get("val")
+        if val is not None and val.dtype is torch.bool:
+            return True
+
+    return False
+
+
+def expand_boolean_indices(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+    indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
+) -> Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]]:
+    new_indices = []
+    for i, ind in enumerate(indices):
+        if ind is not None and is_boolean_tensor(ind):
+            _LOGGER.debug(
+                f"Boolean index detected at position {i}, converting with nonzero()"
+            )
+            mask_tensor = get_trt_tensor(ctx, ind, name + f"_bool_mask_{i}")
+
+            nonzero_layer = ctx.net.add_non_zero(mask_tensor)
+            set_layer_name(
+                nonzero_layer, target, name + f"_bool_nonzero_{i}", source_ir
+            )
+            nonzero_indices = nonzero_layer.get_output(0)
+
+            # nonzero returns shape [N, dims]; extract the index column for dim i
+            if len(indices) == 1:
+                # x[mask] with a single 1D mask
+                to_squeeze = nonzero_indices
+            else:
+                # advanced multi-axis mask: extract index column i from shape [N, D]
+                gather_axis = 1  # the dim-index axis of the nonzero output
+                gather_layer = ctx.net.add_gather(
+                    nonzero_indices,
+                    get_trt_tensor(ctx, i, name + f"_dim_index_{i}"),
+                    gather_axis,
+                )
+                set_layer_name(
+                    gather_layer, target, name + f"_bool_nonzero_extract_{i}", source_ir
+                )
+                to_squeeze = gather_layer.get_output(0)
+            squeeze_layer = ctx.net.add_shuffle(to_squeeze)
+            squeeze_layer.reshape_dims = (-1,)
+            set_layer_name(
+                squeeze_layer,
+                target,
+                name + f"_bool_mask_squeeze_{i}",
+                source_ir,
+            )
+            squeezed_index = squeeze_layer.get_output(0)
+            new_indices.append(squeezed_index)
+        else:
+            new_indices.append(ind)
+    return new_indices
+
+
 def index(
     ctx: ConversionContext,
     target: Target,
@@ -61,13 +125,12 @@
 ) -> TRTTensor:
     adv_indx_indices = []
     tensor_indices = []
-    # check if the input is dynamic
-    dynamic_shape = has_dynamic_shape(input.shape)
     # is_numpy is a flag to specify if all the indices are numpy or torchTensor.
     # If any is not this flag will be set to False
     _LOGGER.debug(
         "Determining whether aten.index constant-index optimization can be invoked"
     )
+    indices = expand_boolean_indices(ctx, target, source_ir, name, input, indices)
     is_numpy = all(
         isinstance(ind, (torch.Tensor, np.ndarray))
         for ind in indices
diff --git a/tests/py/dynamo/conversion/test_index_aten.py b/tests/py/dynamo/conversion/test_index_aten.py
index 8e21f945dc..f7278f84a6 100644
--- a/tests/py/dynamo/conversion/test_index_aten.py
+++ b/tests/py/dynamo/conversion/test_index_aten.py
@@ -71,6 +71,27 @@ class TestIndexConstantConverter(DispatchTestCase):
             [None, torch.tensor([0, 0, 1, 1]), None, torch.tensor([0, 0, 1, 1])],
             torch.randn(2, 4, 4, 2),
         ),
+        (
+            "mask_index_three_dim",
+            [None, torch.tensor([True, False]), None],
+            torch.randn(2, 2, 2),
+        ),
+        (
+            "mask_index_two_dim",
+            [torch.tensor([True, False])],
+            torch.randn(2, 2),
+        ),
+        (
+            # covers multi-axis and discontinuous indices
+            "mask_index_multi_axis",
+            [
+                None,
+                torch.tensor([[True, False, False, True]]),  # axis 1
+                None,
+                torch.tensor([True, False]),  # axis 3
+            ],
+            torch.randn(2, 4, 4, 2),
+        ),
     ]
 )
 def test_index_constant(self, _, index, input):
@@ -168,7 +189,31 @@ def forward(self, input):
                 dtype=torch.float32,
             ),
         ]
-        self.run_test_with_dynamic_shape(TestModule(), input_specs)
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, use_dynamo_tracer=True
+        )
+
+
+class TestIndexDynamicInputNonDynamicIndexConverter(DispatchTestCase):
+    def test_index_input_non_dynamic_index_dynamic(self):
+        class TestIndexWithRuntimeIndex(torch.nn.Module):
+            def forward(self, x):
+                mask = x > 0
+                idx = torch.nonzero(mask, as_tuple=True)
+                return torch.ops.aten.index.Tensor(x, idx)
+
+        input_specs = [
+            Input(
+                min_shape=(2, 2),
+                opt_shape=(2, 2),
+                max_shape=(8, 8),
+                dtype=torch.float32,
+            ),
+        ]
+        # Here, with use_dynamo_tracer=True, the index args[1] is itself converted to a list of TRTTensors
+        self.run_test_with_dynamic_shape(
+            TestIndexWithRuntimeIndex(), input_specs, use_dynamo_tracer=True
+        )
 
 
 if __name__ == "__main__":
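
Reviewer note: the TRT layer sequence in expand_boolean_indices (add_non_zero -> add_gather -> add_shuffle with reshape_dims=(-1,)) mirrors the eager-mode rewrite sketched below. This is a minimal illustrative sketch, not part of the patch, and only shows the single-axis 1D-mask case; the variable names are made up.

import torch

# Sketch: replace a boolean index with the integer indices of its True
# entries, the same rewrite expand_boolean_indices performs with TRT layers.
x = torch.randn(2, 4, 4, 2)
mask = torch.tensor([True, False, False, True])  # boolean index on axis 1

ref = x[:, mask]  # boolean advanced indexing

# torch.nonzero returns shape [N, mask.ndim]; take the dim-0 column and
# flatten to a 1D integer index tensor (cf. the gather + reshape(-1) above).
idx = torch.nonzero(mask)[:, 0].reshape(-1)
out = x[:, idx]

assert torch.equal(ref, out)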
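On requires_output_allocator=True: the output shape of a boolean index depends on the runtime values of the mask, so the engine cannot precompute it and the output must be allocated at execution time. A small eager-mode illustration (example values chosen here, not taken from the patch):

import torch

x = torch.tensor([1.0, -2.0, 3.0, -4.0, 5.0])
print(x[x > 0].shape)  # torch.Size([3]): three positive entries
print(x[x > 2].shape)  # torch.Size([2]): the shape changes with the data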