From c955dbf85bce5fba8937c6a1cdaaae37f87a033f Mon Sep 17 00:00:00 2001
From: Justin Chiu
Date: Sun, 24 Feb 2019 20:37:18 -0500
Subject: [PATCH 1/2] Add broadcast + expand to gather

---
 namedtensor/core.py       | 26 ++++++++++++++++++++++++++
 namedtensor/test_core.py  |  5 ++++-
 namedtensor/torch_base.py | 19 ++++++++++++-------
 3 files changed, 42 insertions(+), 8 deletions(-)

diff --git a/namedtensor/core.py b/namedtensor/core.py
index b6b8584..959cf4f 100644
--- a/namedtensor/core.py
+++ b/namedtensor/core.py
@@ -207,6 +207,32 @@ def _broadcast_order(self, other):
             order.append(d)
         return order

+    def _broadcast_order_shape(self, other, indim, outdim):
+        """
+        Outputs two orders (list) that works for self and other,
+        as well as the shapes necessary to expand to a shared size.
+        Assumes update from indim to outdim.
+        Moves indim and outdim to the front to ensure the most spacing.
+        """
+        self_order = [indim]
+        other_order = [outdim]
+        self_shape = [self.shape[indim]]
+        other_shape = [other.shape[outdim]]
+        exclude = {indim, outdim}
+        for d, s in other.shape.items():
+            if d not in self._schema._names and d not in exclude:
+                self_order.append(d)
+                other_order.append(d)
+                self_shape.append(s)
+                other_shape.append(s)
+        for d, s in self.shape.items():
+            if d not in exclude:
+                self_order.append(d)
+                other_order.append(d)
+                self_shape.append(s)
+                other_shape.append(s)
+        return self_order, other_order, self_shape, other_shape
+
     def _mask_broadcast_order(self, main):
         """
         If broadcasting possible from self (mask) to main, outputs a shared order.
diff --git a/namedtensor/test_core.py b/namedtensor/test_core.py
index e8ccc4f..7bd02ab 100644
--- a/namedtensor/test_core.py
+++ b/namedtensor/test_core.py
@@ -284,7 +284,10 @@ def test_gather():
     t = ntorch.tensor(torch.Tensor([[1, 2], [3, 4]]), ("a", "b"))
     index = ntorch.tensor(torch.LongTensor([[0, 0], [1, 0]]), ("a", "c"))

-    ntensor = ntorch.gather(t, "b", index, "c")
+    # Gather will move "b" and "c" to the front for t and index respectively
+    # so we must force the order in order to compare to the original
+    # torch.gather.
+    ntensor = ntorch.gather(t, "b", index, "c")._force_order(("a", "c"))
     assert (ntensor.values == base).all()
     assert ntensor.shape == OrderedDict([("a", 2), ("c", 2)])

diff --git a/namedtensor/torch_base.py b/namedtensor/torch_base.py
index 8d13a25..3af0b4e 100644
--- a/namedtensor/torch_base.py
+++ b/namedtensor/torch_base.py
@@ -160,15 +160,20 @@ def unique(input, dim=None, names=("unique", "Indices"), **kwargs):

     @staticmethod
     def gather(input, dim, index, index_dim):
+        "Gathers elements using `index` from `input`."
         outdim = index_dim
         indim = dim
-        index_order = [
-            (n if n != indim else outdim) for n in input._schema._names
-        ]
-        b1 = index._force_order(index_order)
-        dim = input._schema.get(indim)
-        return input._new(
-            input.values.gather(dim, b1.values), updates={indim: outdim}
+        input_order, index_order, input_shape, index_shape = (
+            input._broadcast_order_shape(index, indim, outdim)
+        )
+        input1 = input._force_order(input_order)
+        index1 = index._force_order(index_order)
+        dim = input1._schema.get(indim)
+        return input1._new(
+            input1.values.expand(input_shape).gather(
+                dim, index1.values.expand(index_shape)
+            ),
+            updates={indim: outdim}
         )

     @staticmethod

From 09ae131f41daa37980b6c98de7a713c1ec9feada Mon Sep 17 00:00:00 2001
From: Justin Chiu
Date: Sun, 24 Feb 2019 20:40:27 -0500
Subject: [PATCH 2/2] fix grammar in comment

---
 namedtensor/core.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/namedtensor/core.py b/namedtensor/core.py
index 959cf4f..e54bd34 100644
--- a/namedtensor/core.py
+++ b/namedtensor/core.py
@@ -209,7 +209,7 @@ def _broadcast_order(self, other):

     def _broadcast_order_shape(self, other, indim, outdim):
         """
-        Outputs two orders (list) that works for self and other,
+        Outputs two orders (list) for self and other,
         as well as the shapes necessary to expand to a shared size.
         Assumes update from indim to outdim.
         Moves indim and outdim to the front to ensure the most spacing.
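Note (not part of the patches above): a minimal plain-PyTorch sketch of the
expand-then-gather pattern that the new `gather` builds on. The dimension
names and sizes below are illustrative only, and the singleton-dim alignment
done here by hand is what `_force_order` + `expand` handle inside the library.

    import torch

    # input carries dims ("b", "a"); index carries dims ("c", "d") and
    # indexes into "b". Align both tensors with the gathered dim in front,
    # pad with singleton dims for names the other tensor owns, then expand
    # to a shared size before calling torch.gather.
    inp = torch.arange(6.0).view(3, 2)        # ("b", "a") -> (3, 2)
    index = torch.randint(0, 3, (4, 5))       # ("c", "d") -> (4, 5)

    inp_aligned = inp.view(3, 1, 2).expand(3, 5, 2)    # ("b", "d", "a")
    idx_aligned = index.view(4, 5, 1).expand(4, 5, 2)  # ("c", "d", "a")

    out = inp_aligned.gather(0, idx_aligned)  # result dims ("c", "d", "a")
    assert out.shape == (4, 5, 2)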