26 changes: 26 additions & 0 deletions namedtensor/core.py
@@ -207,6 +207,32 @@ def _broadcast_order(self, other):
             order.append(d)
         return order

+    def _broadcast_order_shape(self, other, indim, outdim):
+        """
+        Outputs two orders (lists) for self and other,
+        as well as the shapes necessary to expand both to a shared size.
+        Assumes an update from indim to outdim.
+        Moves indim and outdim to the front so the gather dim lines up in both.
+        """
+        self_order = [indim]
+        other_order = [outdim]
+        self_shape = [self.shape[indim]]
+        other_shape = [other.shape[outdim]]
+        exclude = {indim, outdim}
+        for d, s in other.shape.items():
+            if d not in self._schema._names and d not in exclude:
+                self_order.append(d)
+                other_order.append(d)
+                self_shape.append(s)
+                other_shape.append(s)
+        for d, s in self.shape.items():
+            if d not in exclude:
+                self_order.append(d)
+                other_order.append(d)
+                self_shape.append(s)
+                other_shape.append(s)
+        return self_order, other_order, self_shape, other_shape
+
     def _mask_broadcast_order(self, main):
         """
         If broadcasting possible from self (mask) to main, outputs a shared order.
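
For concreteness, here is a small hand trace of what the new helper returns for tensors shaped like the ones in test_gather below. This is an illustrative sketch, not part of the diff: the import path is assumed from the package layout, the private helper is called directly only to illustrate it, and the expected tuple is traced by hand from the code above (it relies on .shape iterating in schema order, as the OrderedDict assertion in test_core.py suggests).

import torch
from namedtensor import ntorch  # assumed import path

t = ntorch.tensor(torch.rand(2, 3), ("a", "b"))
index = ntorch.tensor(torch.zeros(2, 4).long(), ("a", "c"))

# With indim="b" and outdim="c": "b" and "c" go first, "a" is shared (taken
# from t in the second loop), and index has no extra dims, so the hand-traced
# result is (["b", "a"], ["c", "a"], [3, 2], [4, 2]).
print(t._broadcast_order_shape(index, "b", "c"))
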
5 changes: 4 additions & 1 deletion namedtensor/test_core.py
@@ -284,7 +284,10 @@ def test_gather():

     t = ntorch.tensor(torch.Tensor([[1, 2], [3, 4]]), ("a", "b"))
     index = ntorch.tensor(torch.LongTensor([[0, 0], [1, 0]]), ("a", "c"))
-    ntensor = ntorch.gather(t, "b", index, "c")
+    # Gather will move "b" and "c" to the front for t and index
+    # respectively, so we must force the order to compare against the
+    # original torch.gather.
+    ntensor = ntorch.gather(t, "b", index, "c")._force_order(("a", "c"))
Contributor

This isn't a good unit test. It shouldn't call any _ functions.

Contributor Author (@justinchiu, Feb 25, 2019)

High level question: Since we don't assume any ordering, is the right approach to try all permutations of the output ntensor and pass if any of them succeed (equal base)? Or should we try to keep the underlying order the same as torch.* (although this may be unclear for ntorch.gather, since broadcasting isn't defined in torch.gather)?

Contributor

I think the ntorch.equal function will now do this automatically. But either way, isn't the function you want just .transpose? More importantly, does this test prove to me that your change works?

Contributor Author

Great, thanks. I agree that the unit test doesn't test anything, but I wanted to ask about how to compare first. I can write a better test.

     assert (ntensor.values == base).all()
     assert ntensor.shape == OrderedDict([("a", 2), ("c", 2)])

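
Following the thread above, one possible shape for an improved test is sketched here. It is only a sketch: test_gather_broadcast is a hypothetical name, the import path is assumed, and it assumes .transpose accepts dimension names (as suggested in the thread) so the comparison goes through the public API instead of _force_order.

from collections import OrderedDict

import torch
from namedtensor import ntorch  # assumed import path


def test_gather_broadcast():
    t = ntorch.tensor(torch.Tensor([[1, 2], [3, 4]]), ("a", "b"))
    index = ntorch.tensor(torch.LongTensor([[0, 0], [1, 0]]), ("a", "c"))
    # Plain torch reference: gather along t's second dim ("b").
    base = t.values.gather(1, index.values)

    ntensor = ntorch.gather(t, "b", index, "c")
    # Fix the name order through the public API, then compare values.
    ntensor = ntensor.transpose("a", "c")
    assert ntensor.shape == OrderedDict([("a", 2), ("c", 2)])
    assert (ntensor.values == base).all()
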
19 changes: 12 additions & 7 deletions namedtensor/torch_base.py
@@ -160,15 +160,20 @@ def unique(input, dim=None, names=("unique", "Indices"), **kwargs):

     @staticmethod
     def gather(input, dim, index, index_dim):
         "Gathers elements using `index` from `input`."
         outdim = index_dim
         indim = dim
-        index_order = [
-            (n if n != indim else outdim) for n in input._schema._names
-        ]
-        b1 = index._force_order(index_order)
-        dim = input._schema.get(indim)
-        return input._new(
-            input.values.gather(dim, b1.values), updates={indim: outdim}
+        input_order, index_order, input_shape, index_shape = (
+            input._broadcast_order_shape(index, indim, outdim)
+        )
+        input1 = input._force_order(input_order)
+        index1 = index._force_order(index_order)
+        dim = input1._schema.get(indim)
+        return input1._new(
+            input1.values.expand(input_shape).gather(
+                dim, index1.values.expand(index_shape)
+            ),
+            updates={indim: outdim}
         )

     @staticmethod
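
To make the effect of the new expand calls concrete, here is a hypothetical usage sketch. The names and sizes are made up, the import path is assumed, and the expected output names are read off from _broadcast_order_shape and the rename in _new above; it also assumes _force_order inserts missing dims as size 1 so that expand can broadcast them, which is the behavior this change relies on.

import torch
from namedtensor import ntorch  # assumed import path

t = ntorch.tensor(torch.rand(2, 5), ("a", "b"))
# The index carries an extra "k" dim that t does not have; the old gather
# required the dims to line up, while the version above expands t over "k".
index = ntorch.tensor(torch.zeros(2, 3, 4).long(), ("a", "c", "k"))

out = ntorch.gather(t, "b", index, "c")
# Expected names (in some order): "c", "k", "a",
# with out[c, k, a] == t[a, index[a, c, k]].
print(out.shape)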