File tree Expand file tree Collapse file tree 2 files changed +28
-2
lines changed
py/torch_tensorrt/dynamo/conversion
tests/py/dynamo/conversion Expand file tree Collapse file tree 2 files changed +28
-2
lines changed Original file line number Diff line number Diff line change @@ -392,7 +392,9 @@ def index_dtype_validator(
392
392
393
393
394
394
@dynamo_tensorrt_converter (
395
- torch .ops .aten .index .Tensor , capability_validator = index_dtype_validator
395
+ torch .ops .aten .index .Tensor ,
396
+ capability_validator = index_dtype_validator ,
397
+ supports_dynamic_shapes = True ,
396
398
)
397
399
@enforce_tensor_types (
398
400
{
Original file line number Diff line number Diff line change @@ -168,7 +168,31 @@ def forward(self, input):
168
168
dtype = torch .float32 ,
169
169
),
170
170
]
171
- self .run_test_with_dynamic_shape (TestModule (), input_specs )
171
+ self .run_test_with_dynamic_shape (
172
+ TestModule (), input_specs , use_dynamo_tracer = True
173
+ )
174
+
175
+
176
+ class TestIndexDynamicInputNonDynamicIndexConverter (DispatchTestCase ):
177
+ def test_index_input_non_dynamic_index_dynamic (self ):
178
+ class TestIndexWithRuntimeIndex (torch .nn .Module ):
179
+ def forward (self , x ):
180
+ mask = x > 0
181
+ idx = torch .nonzero (mask , as_tuple = True )
182
+ return torch .ops .aten .index .Tensor (x , idx )
183
+
184
+ input_specs = [
185
+ Input (
186
+ min_shape = (2 , 2 ),
187
+ opt_shape = (2 , 2 ),
188
+ max_shape = (8 , 8 ),
189
+ dtype = torch .float32 ,
190
+ ),
191
+ ]
192
+ # In this case the index args[1] gets itself converted to a List of TRTTensors with use_dynamo_tracer=True
193
+ self .run_test_with_dynamic_shape (
194
+ TestIndexWithRuntimeIndex (), input_specs , use_dynamo_tracer = True
195
+ )
172
196
173
197
174
198
if __name__ == "__main__" :
You can’t perform that action at this time.
0 commit comments