Skip to content

Commit fefd2cb

Browse files
committed
add flow to knn call
1 parent 5221414 commit fefd2cb

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

test/test_knn.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,14 @@ def test_knn_graph(dtype, device):
4343
[+1, -1],
4444
], dtype, device)
4545

46-
row, col = knn_graph(x, k=2)
46+
row, col = knn_graph(x, k=2, flow='target_to_source')
4747
col = col.view(-1, 2).sort(dim=-1)[0].view(-1)
4848

4949
assert row.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
5050
assert col.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
51+
52+
row, col = knn_graph(x, k=2, flow='source_to_target')
53+
row = row.view(-1, 2).sort(dim=-1)[0].view(-1)
54+
55+
assert row.tolist() == [1, 3, 0, 2, 1, 3, 0, 2]
56+
assert col.tolist() == [0, 0, 1, 1, 2, 2, 3, 3]

torch_cluster/knn.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def knn(x, y, k, batch_x=None, batch_y=None):
7979
return torch.stack([row, col], dim=0)
8080

8181

82-
def knn_graph(x, k, batch=None, loop=False):
82+
def knn_graph(x, k, batch=None, loop=False, flow='source_to_target'):
8383
r"""Computes graph edges to the nearest :obj:`k` points.
8484
8585
Args:
@@ -91,6 +91,9 @@ def knn_graph(x, k, batch=None, loop=False):
9191
node to a specific example. (default: :obj:`None`)
9292
loop (bool, optional): If :obj:`True`, the graph will contain
9393
self-loops. (default: :obj:`False`)
94+
flow (string, optional): The flow direction when using in combination
95+
with message passing (:obj:`"source_to_target"` or
96+
:obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
9497
9598
:rtype: :class:`LongTensor`
9699
@@ -106,10 +109,10 @@ def knn_graph(x, k, batch=None, loop=False):
106109
>>> edge_index = knn_graph(x, k=2, batch=batch, loop=False)
107110
"""
108111

109-
edge_index = knn(x, x, k if loop else k + 1, batch, batch)
112+
assert flow in ['source_to_target', 'target_to_source']
113+
row, col = knn(x, x, k if loop else k + 1, batch, batch)
114+
row, col = (col, row) if flow == 'source_to_target' else (row, col)
110115
if not loop:
111-
row, col = edge_index
112116
mask = row != col
113117
row, col = row[mask], col[mask]
114-
edge_index = torch.stack([row, col], dim=0)
115-
return edge_index
118+
return torch.stack([row, col], dim=0)

0 commit comments

Comments
 (0)