From f4e5ea9d7aacf95b82d36193af47cf698c826e9b Mon Sep 17 00:00:00 2001
From: maronuu
Date: Sun, 22 Aug 2021 07:33:15 +0000
Subject: [PATCH] add converter for torch.Tensor.expand_as() method

---
 torch2trt/converters/__init__.py  |  1 +
 torch2trt/converters/expand_as.py | 61 +++++++++++++++++++++++++++++++
 2 files changed, 62 insertions(+)
 create mode 100644 torch2trt/converters/expand_as.py

diff --git a/torch2trt/converters/__init__.py b/torch2trt/converters/__init__.py
index 83465152..10f7b60f 100644
--- a/torch2trt/converters/__init__.py
+++ b/torch2trt/converters/__init__.py
@@ -29,6 +29,7 @@
 from .compare import *
 from .div import *
 from .expand import *
+from .expand_as import *
 from .floordiv import *
 from .gelu import *
 from .getitem import *
diff --git a/torch2trt/converters/expand_as.py b/torch2trt/converters/expand_as.py
new file mode 100644
index 00000000..fea5bd06
--- /dev/null
+++ b/torch2trt/converters/expand_as.py
@@ -0,0 +1,61 @@
+from torch2trt.torch2trt import *
+from torch2trt.module_test import add_module_test
+
+
+@tensorrt_converter('torch.Tensor.expand_as')
+def convert_expand_as(ctx):
+    """Convert torch.Tensor.expand_as to a TensorRT slice layer.
+
+    Broadcasting is expressed as a slice over the input tensor whose
+    stride is 0 along every expanded dimension (repeating the single
+    element) and 1 along dimensions that already match the output.
+    """
+    input = ctx.method_args[0]
+    output = ctx.method_return
+
+    inshape = tuple(input.shape)[1:]  # exclude batch dim (implicit batch mode)
+    shape = tuple(output.shape)[1:]
+    ndim = len(shape)
+    start = tuple([0]*ndim)
+    # stride == 1 if dimensions match, 0 otherwise
+    stride = tuple([int(i == o) for i, o in zip(inshape, shape)])
+
+    layer = ctx.network.add_slice(input._trt, start, shape, stride)
+
+    output._trt = layer.get_output(0)
+
+
+class ExpandAsModule(torch.nn.Module):
+    """Wraps Tensor.expand_as(other) so add_module_test can exercise it."""
+
+    def __init__(self, other: torch.Tensor):
+        super(ExpandAsModule, self).__init__()
+        self.other = other
+
+    def forward(self, x: torch.Tensor):
+        return x.expand_as(self.other)
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1,)])
+def test_tensor_expand_as_scalar():
+    return ExpandAsModule(torch.randn(3))
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 1, 3, 3)])
+def test_tensor_expand_as_singledim():
+    return ExpandAsModule(torch.randn((1, 3, 3, 3)))
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 1, 1, 3)])
+def test_tensor_expand_as_multidim():
+    return ExpandAsModule(torch.randn((1, 3, 3, 3)))
+
+
+@add_module_test(torch.float16, torch.device('cuda'), [(1, 1, 3, 3)])
+def test_tensor_expand_as_singledim_half():
+    return ExpandAsModule(torch.randn((1, 3, 3, 3)))
+
+
+@add_module_test(torch.float16, torch.device('cuda'), [(1, 1, 1, 3)])
+def test_tensor_expand_as_multidim_half():
+    return ExpandAsModule(torch.randn((1, 3, 3, 3)))