@@ -79,7 +79,7 @@ def knn(x, y, k, batch_x=None, batch_y=None):
7979 return torch .stack ([row , col ], dim = 0 )
8080
8181
82- def knn_graph (x , k , batch = None , loop = False ):
82+ def knn_graph (x , k , batch = None , loop = False , flow = 'source_to_target' ):
8383 r"""Computes graph edges to the nearest :obj:`k` points.
8484
8585 Args:
@@ -91,6 +91,9 @@ def knn_graph(x, k, batch=None, loop=False):
9191 node to a specific example. (default: :obj:`None`)
9292 loop (bool, optional): If :obj:`True`, the graph will contain
9393 self-loops. (default: :obj:`False`)
94+ flow (string, optional): The flow direction when using in combination
95+ with message passing (:obj:`"source_to_target"` or
96+ :obj:`"target_to_source"`). (default: :obj:`"source_to_target"`)
9497
9598 :rtype: :class:`LongTensor`
9699
@@ -106,10 +109,10 @@ def knn_graph(x, k, batch=None, loop=False):
106109 >>> edge_index = knn_graph(x, k=2, batch=batch, loop=False)
107110 """
108111
109- edge_index = knn (x , x , k if loop else k + 1 , batch , batch )
112+ assert flow in ['source_to_target' , 'target_to_source' ]
113+ row , col = knn (x , x , k if loop else k + 1 , batch , batch )
114+ row , col = (col , row ) if flow == 'source_to_target' else (row , col )
110115 if not loop :
111- row , col = edge_index
112116 mask = row != col
113117 row , col = row [mask ], col [mask ]
114- edge_index = torch .stack ([row , col ], dim = 0 )
115- return edge_index
118+ return torch .stack ([row , col ], dim = 0 )
0 commit comments