Skip to content

Commit e5aae73

Browse files
committed
Index converter dynamic cases fix
1 parent 75b7774 commit e5aae73

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,9 @@ def index_dtype_validator(
392392

393393

394394
@dynamo_tensorrt_converter(
395-
torch.ops.aten.index.Tensor, capability_validator=index_dtype_validator
395+
torch.ops.aten.index.Tensor,
396+
capability_validator=index_dtype_validator,
397+
supports_dynamic_shapes=True,
396398
)
397399
@enforce_tensor_types(
398400
{

tests/py/dynamo/conversion/test_index_aten.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,31 @@ def forward(self, input):
168168
dtype=torch.float32,
169169
),
170170
]
171-
self.run_test_with_dynamic_shape(TestModule(), input_specs)
171+
self.run_test_with_dynamic_shape(
172+
TestModule(), input_specs, use_dynamo_tracer=True
173+
)
174+
175+
176+
class TestIndexDynamicInputNonDynamicIndexConverter(DispatchTestCase):
177+
def test_index_input_non_dynamic_index_dynamic(self):
178+
class TestIndexWithRuntimeIndex(torch.nn.Module):
179+
def forward(self, x):
180+
mask = x > 0
181+
idx = torch.nonzero(mask, as_tuple=True)
182+
return torch.ops.aten.index.Tensor(x, idx)
183+
184+
input_specs = [
185+
Input(
186+
min_shape=(2, 2),
187+
opt_shape=(2, 2),
188+
max_shape=(8, 8),
189+
dtype=torch.float32,
190+
),
191+
]
192+
# In this case the index args[1] gets itself converted to a List of TRTTensors with use_dynamo_tracer=True
193+
self.run_test_with_dynamic_shape(
194+
TestIndexWithRuntimeIndex(), input_specs, use_dynamo_tracer=True
195+
)
172196

173197

174198
if __name__ == "__main__":

0 commit comments

Comments
 (0)