Skip to content

Commit 0a03833

Browse files
committed
update doc
1 parent 21208fc commit 0a03833

File tree

5 files changed

+110
-72
lines changed

5 files changed

+110
-72
lines changed

README.md

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,15 @@ tensor([0, 5, 3, 0, 1])
9090

9191
## FarthestPointSampling
9292

93-
A sampling algorithm, which iteratively samples the most distant point (in metric distance) with regard to the rest points.
93+
A sampling algorithm, which iteratively samples the most distant point with regard to the rest points.
9494

9595
```python
9696
import torch
9797
from torch_cluster import fps
9898

9999
x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
100100
batch = torch.tensor([0, 0, 0, 0])
101-
sample = fps(x, batch, ratio=0.5, random_start=False)
101+
index = fps(x, batch, ratio=0.5, random_start=False)
102102
```
103103

104104
```
@@ -108,7 +108,7 @@ tensor([0, 3])
108108

109109
## kNN-Graph
110110

111-
Computes graph edges to the nearest *k* points in metric space.
111+
Computes graph edges to the nearest *k* points.
112112

113113
```python
114114
import torch
@@ -127,7 +127,7 @@ tensor([[0, 0, 1, 1, 2, 2, 3, 3],
127127

128128
## Radius-Graph
129129

130-
Computes graph edges to all points within a given distance in metric space.
130+
Computes graph edges to all points within a given distance.
131131

132132
```python
133133
import torch
@@ -146,17 +146,17 @@ tensor([[0, 0, 1, 1, 2, 2, 3, 3],
146146

147147
## Nearest
148148

149-
Clusters points which are nearest to a given query point in metric space.
149+
Clusters points in *x* together which are nearest to a given query point in *y*.
150150

151151
```python
152152
import torch
153153
from torch_cluster import nearest
154154

155155
x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
156-
batch_x = torch.Tensor([0, 0, 0, 0])
157-
query_x = torch.Tensor([[-1, 0], [1, 0]])
158-
query_batch = torch.Tensor([0, 0])
159-
cluster = nearest(x, query_x, batch_x, query_batch)
156+
batch_x = torch.tensor([0, 0, 0, 0])
157+
y = torch.Tensor([[-1, 0], [1, 0]])
158+
batch_y = torch.tensor([0, 0])
159+
cluster = nearest(x, y, batch_x, batch_y)
160160
```
161161

162162
```

torch_cluster/fps.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,36 @@
22

33
if torch.cuda.is_available():
44
import fps_cuda
5-
""" """
65

76

87
def fps(x, batch=None, ratio=0.5, random_start=True):
9-
"""Iteratively samples the most distant point (in metric distance) with
10-
regard to the rest points.
8+
r""""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
9+
Learning on Point Sets in a Metric Space"
10+
<https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
11+
most distant point with regard to the rest points.
1112
1213
Args:
13-
x (Tensor): D-dimensional point features.
14-
batch (LongTensor, optional): Vector that maps each point to its
15-
example identifier. If :obj:`None`, all points belong to the same
16-
example. If not :obj:`None`, points in the same example need to
17-
have contiguous memory layout and :obj:`batch` needs to be
18-
ascending. (default: :obj:`None`)
14+
x (Tensor): Node feature matrix
15+
:math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
16+
batch (LongTensor, optional): Batch vector
17+
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
18+
node to a specific example. (default: :obj:`None`)
1919
ratio (float, optional): Sampling ratio. (default: :obj:`0.5`)
20-
random_start (bool, optional): Whether the starting node is
21-
sampled randomly. (default: :obj:`True`)
20+
random_start (bool, optional): If set to :obj:`False`, use the first
21+
node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
2222
2323
:rtype: :class:`LongTensor`
2424
25-
Examples::
25+
.. testsetup::
26+
27+
import torch
28+
from torch_cluster import fps
29+
30+
.. testcode::
2631
2732
>>> x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
2833
>>> batch = torch.tensor([0, 0, 0, 0])
29-
>>> sample = fps(x, batch, ratio=0.5)
34+
>>> index = fps(x, batch, ratio=0.5)
3035
"""
3136

3237
if batch is None:

torch_cluster/knn.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,30 @@
66

77

88
def knn(x, y, k, batch_x=None, batch_y=None):
9-
"""Finds for each element in `y` the `k` nearest points in `x`.
9+
r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in
10+
:obj:`x`.
1011
1112
Args:
12-
x (Tensor): D-dimensional point features.
13-
y (Tensor): D-dimensional point features.
13+
x (Tensor): Node feature matrix
14+
:math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
15+
y (Tensor): Node feature matrix
16+
:math:`\mathbf{X} \in \mathbb{R}^{M \times F}`.
1417
k (int): The number of neighbors.
15-
batch_x (LongTensor, optional): Vector that maps each point to its
16-
example identifier. If :obj:`None`, all points belong to the same
17-
example. If not :obj:`None`, points in the same example need to
18-
have contiguous memory layout and :obj:`batch` needs to be
19-
ascending. (default: :obj:`None`)
20-
batch_y (LongTensor, optional): See `batch_x` (default: :obj:`None`)
18+
batch_x (LongTensor, optional): Batch vector
19+
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
20+
node to a specific example. (default: :obj:`None`)
21+
batch_y (LongTensor, optional): Batch vector
22+
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
23+
node to a specific example. (default: :obj:`None`)
2124
2225
:rtype: :class:`LongTensor`
2326
24-
Examples::
27+
.. testsetup::
28+
29+
import torch
30+
from torch_cluster import knn
31+
32+
.. testcode::
2533
2634
>>> x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
2735
>>> batch_x = torch.tensor([0, 0, 0, 0])
@@ -70,22 +78,26 @@ def knn(x, y, k, batch_x=None, batch_y=None):
7078

7179

7280
def knn_graph(x, k, batch=None, loop=False):
73-
"""Finds for each element in `x` the `k` nearest points.
81+
r"""Computes graph edges to the nearest :obj:`k` points.
7482
7583
Args:
76-
x (Tensor): D-dimensional point features.
84+
x (Tensor): Node feature matrix
85+
:math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
7786
k (int): The number of neighbors.
78-
batch (LongTensor, optional): Vector that maps each point to its
79-
example identifier. If :obj:`None`, all points belong to the same
80-
example. If not :obj:`None`, points in the same example need to
81-
have contiguous memory layout and :obj:`batch` needs to be
82-
ascending. (default: :obj:`None`)
87+
batch (LongTensor, optional): Batch vector
88+
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
89+
node to a specific example. (default: :obj:`None`)
8390
loop (bool, optional): If :obj:`True`, the graph will contain
8491
self-loops. (default: :obj:`False`)
8592
8693
:rtype: :class:`LongTensor`
8794
88-
Examples::
95+
.. testsetup::
96+
97+
import torch
98+
from torch_cluster import knn_graph
99+
100+
.. testcode::
89101
90102
>>> x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
91103
>>> batch = torch.tensor([0, 0, 0, 0])

torch_cluster/nearest.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,32 @@
66

77

88
def nearest(x, y, batch_x=None, batch_y=None):
9-
"""Finds for each element in `x` its nearest point in `y`.
9+
"""Clusters points in :obj:`x` together which are nearest to a given query
10+
point in :obj:`y`.
1011
1112
Args:
12-
x (Tensor): D-dimensional point features.
13-
y (Tensor): D-dimensional point features.
14-
batch_x (LongTensor, optional): Vector that maps each point to its
15-
example identifier. If :obj:`None`, all points belong to the same
16-
example. If not :obj:`None`, points in the same example need to
17-
have contiguous memory layout and :obj:`batch` needs to be
18-
ascending. (default: :obj:`None`)
19-
batch_y (LongTensor, optional): See `batch_x` (default: :obj:`None`)
13+
x (Tensor): Node feature matrix
14+
:math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
15+
y (Tensor): Node feature matrix
16+
:math:`\mathbf{X} \in \mathbb{R}^{M \times F}`.
17+
batch_x (LongTensor, optional): Batch vector
18+
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
19+
node to a specific example. (default: :obj:`None`)
20+
batch_y (LongTensor, optional): Batch vector
21+
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
22+
node to a specific example. (default: :obj:`None`)
2023
21-
Examples::
24+
.. testsetup::
25+
26+
import torch
27+
from torch_cluster import nearest
28+
29+
.. testcode::
2230
2331
>>> x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
24-
>>> batch_x = torch.Tensor([0, 0, 0, 0])
32+
>>> batch_x = torch.tensor([0, 0, 0, 0])
2533
>>> y = torch.Tensor([[-1, 0], [1, 0]])
26-
>>> batch_x = torch.Tensor([0, 0])
34+
>>> batch_x = torch.tensor([0, 0])
2735
>>> cluster = nearest(x, y, batch_x, batch_y)
2836
"""
2937

torch_cluster/radius.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,33 @@
66

77

88
def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
9-
"""Finds for each element in `y` all points in `x` within distance `r`.
9+
r"""Finds for each element in :obj:`y` all points in :obj:`x` within
10+
distance :obj:`r`.
1011
1112
Args:
12-
x (Tensor): D-dimensional point features.
13-
y (Tensor): D-dimensional point features.
13+
x (Tensor): Node feature matrix
14+
:math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
15+
y (Tensor): Node feature matrix
16+
:math:`\mathbf{X} \in \mathbb{R}^{M \times F}`.
1417
r (float): The radius.
15-
batch_x (LongTensor, optional): Vector that maps each point to its
16-
example identifier. If :obj:`None`, all points belong to the same
17-
example. If not :obj:`None`, points in the same example need to
18-
have contiguous memory layout and :obj:`batch` needs to be
19-
ascending. (default: :obj:`None`)
20-
batch_y (LongTensor, optional): See `batch_x` (default: :obj:`None`)
18+
batch_x (LongTensor, optional): Batch vector
19+
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
20+
node to a specific example. (default: :obj:`None`)
21+
batch_y (LongTensor, optional): Batch vector
22+
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each
23+
node to a specific example. (default: :obj:`None`)
2124
max_num_neighbors (int, optional): The maximum number of neighbors to
22-
return for each element in `y`. (default: :obj:`32`)
25+
return for each element in :obj:`y`. (default: :obj:`32`)
2326
2427
:rtype: :class:`LongTensor`
2528
26-
Examples::
29+
.. testsetup::
30+
31+
import torch
32+
from torch_cluster import radius
33+
34+
.. testcode::
35+
2736
2837
>>> x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
2938
>>> batch_x = torch.tensor([0, 0, 0, 0])
@@ -63,24 +72,28 @@ def radius(x, y, r, batch_x=None, batch_y=None, max_num_neighbors=32):
6372

6473

6574
def radius_graph(x, r, batch=None, loop=False, max_num_neighbors=32):
66-
"""Finds for each element in `x` all points in `x` within distance `r`.
75+
r"""Computes graph edges to all points within a given distance.
6776
6877
Args:
69-
x (Tensor): D-dimensional point features.
78+
x (Tensor): Node feature matrix
79+
:math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
7080
r (float): The radius.
71-
batch (LongTensor, optional): Vector that maps each point to its
72-
example identifier. If :obj:`None`, all points belong to the same
73-
example. If not :obj:`None`, points in the same example need to
74-
have contiguous memory layout and :obj:`batch` needs to be
75-
ascending. (default: :obj:`None`)
81+
batch (LongTensor, optional): Batch vector
82+
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
83+
node to a specific example. (default: :obj:`None`)
7684
loop (bool, optional): If :obj:`True`, the graph will contain
7785
self-loops. (default: :obj:`False`)
7886
max_num_neighbors (int, optional): The maximum number of neighbors to
79-
return for each element in `y`. (default: :obj:`32`)
87+
return for each element in :obj:`y`. (default: :obj:`32`)
8088
8189
:rtype: :class:`LongTensor`
8290
83-
Examples::
91+
.. testsetup::
92+
93+
import torch
94+
from torch_cluster import radius_graph
95+
96+
.. testcode::
8497
8598
>>> x = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
8699
>>> batch = torch.tensor([0, 0, 0, 0])

0 commit comments

Comments
 (0)