Skip to content

Commit 663fbe3

Browse files
committed
support for boolean indices
1 parent a14fc4c commit 663fbe3

File tree

2 files changed

+65
-3
lines changed

2 files changed

+65
-3
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,11 @@ def index_dtype_validator(
386386
for ind in index:
387387
if ind is not None:
388388
val = ind.meta.get("val")
389-
if val is not None and val.dtype not in (torch.int32, torch.int64):
389+
if val is not None and val.dtype not in (
390+
torch.int32,
391+
torch.int64,
392+
torch.bool,
393+
):
390394
return False
391395
return True
392396

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,65 @@ def select(
5353
return layer.get_output(0)
5454

5555

56+
def is_boolean_tensor(tensor: Union[TRTTensor, np.ndarray, torch.Tensor]) -> bool:
    """Return True when ``tensor`` carries boolean data.

    Args:
        tensor: A candidate index: a ``torch.Tensor``, a ``np.ndarray``, a
            ``TRTTensor``, or any object exposing a ``meta`` mapping whose
            ``"val"`` entry has a ``dtype``.

    Returns:
        ``True`` if the dtype is boolean, ``False`` otherwise (including for
        objects of an unrecognised type).
    """
    if isinstance(tensor, torch.Tensor):
        return bool(tensor.dtype == torch.bool)

    if isinstance(tensor, np.ndarray):
        # A NumPy dtype does not match ``torch.bool``, so boolean NumPy masks
        # must be checked against NumPy's own boolean type.
        return bool(tensor.dtype == np.bool_)

    # Some callers pass objects that record their traced value under
    # meta["val"]; read it defensively because not every type has ``meta``.
    meta = getattr(tensor, "meta", None)
    if meta is not None:
        val = meta.get("val")
        if val is not None and val.dtype is torch.bool:
            return True

    if isinstance(tensor, TRTTensor):
        # TensorRT's DataType enum has a BOOL member; matching on the member
        # name avoids requiring the tensorrt module to be in scope here.
        return getattr(tensor.dtype, "name", None) == "BOOL"

    return False
62+
63+
64+
def expand_boolean_indices(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    input: TRTTensor,
    indices: Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]],
) -> Sequence[Union[TRTTensor, np.ndarray, torch.Tensor]]:
    """Replace boolean mask indices with integer indices built from a nonzero layer.

    Entries of ``indices`` that are ``None`` or not boolean are passed through
    unchanged. Each boolean entry is converted to a TensorRT tensor, fed to a
    non-zero layer, and replaced by the resulting index tensor.

    Args:
        ctx: Conversion context; its ``net`` is used to add layers.
        target: Target passed to ``set_layer_name`` for each added layer.
        source_ir: Source IR passed to ``set_layer_name``.
        name: Base name used to derive the names of the added layers.
        input: The tensor being indexed (not read by this function).
        indices: The per-dimension indices, possibly containing ``None``.

    Returns:
        A new list with the same length and order as ``indices`` in which
        every boolean index has been replaced by its converted index tensor.
        The input sequence itself is not modified.
    """
    # Collect results in a fresh list: rebinding the loop variable alone would
    # discard the converted tensors and return the masks untouched.
    expanded = []
    for i, ind in enumerate(indices):
        if ind is None or not is_boolean_tensor(ind):
            expanded.append(ind)
            continue

        _LOGGER.debug(
            f"Boolean index detected at position {i}, converting with nonzero()"
        )

        mask_tensor = get_trt_tensor(ctx, ind, name + f"_bool_mask_{i}")

        nonzero_layer = ctx.net.add_non_zero(mask_tensor)
        set_layer_name(nonzero_layer, target, name + f"_bool_nonzero_{i}", source_ir)
        nonzero_indices = nonzero_layer.get_output(0)

        # The code below treats the nonzero output as [N, dims] and extracts
        # dim i from it.
        # NOTE(review): TensorRT documents the non-zero layer output as
        # [num_dims, num_nonzero] — confirm the gather axis against that.
        if len(indices) == 1:
            # x[mask] — 1D mask: flatten the nonzero output into a 1D index.
            squeeze_layer = ctx.net.add_shuffle(nonzero_indices)
            squeeze_layer.reshape_dims = (-1,)
            set_layer_name(
                squeeze_layer,
                target,
                name + f"_bool_nonzero_squeeze_{i}",
                source_ir,
            )
            expanded.append(squeeze_layer.get_output(0))
        else:
            # Advanced multi-axis mask: extract index i from shape [N, D]
            gather_axis = 1  # dim index
            gather_layer = ctx.net.add_gather(
                nonzero_indices,
                get_trt_tensor(ctx, i, name + f"_dim_index_{i}"),
                gather_axis,
            )
            set_layer_name(
                gather_layer, target, name + f"_bool_nonzero_extract_{i}", source_ir
            )
            expanded.append(gather_layer.get_output(0))
    return expanded
113+
114+
56115
def index(
57116
ctx: ConversionContext,
58117
target: Target,
@@ -63,8 +122,6 @@ def index(
63122
) -> TRTTensor:
64123
adv_indx_indices = []
65124
tensor_indices = []
66-
# check if the input is dynamic
67-
dynamic_shape = has_dynamic_shape(input.shape)
68125
# is_numpy is a flag to specify if all the indices are numpy or torchTensor.
69126
# If any is not this flag will be set to False
70127
_LOGGER.debug(
@@ -78,6 +135,7 @@ def index(
78135
# here we need to check if all the index are broadcastable
79136
# if no, then we need to broadcast
80137
last_index = None
138+
indices = expand_boolean_indices(ctx, target, source_ir, name, input, indices)
81139
for i, ind in enumerate(indices):
82140
if ind is not None:
83141
_LOGGER.debug(f"Shape of {i} index is {ind.shape}")

0 commit comments

Comments
 (0)