Skip to content

Commit d5c8a4d

Browse files
committed
100% codecov
1 parent 79b935c commit d5c8a4d

File tree

3 files changed

+19
-2
lines changed

3 files changed

+19
-2
lines changed

test/dense_grid.json

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,16 @@
3737
"batch": [0, 0, 0, 0, 0, 1, 1],
3838
"expected": [0, 5, 1, 0, 2, 6, 9],
3939
"expected_C": 6
40+
},
41+
{
42+
"name": "Batch with start/end parameter",
43+
"position": [[0, 0], [11, 9], [2, 8], [2, 2], [8, 3], [1, 1], [6, 6]],
44+
"size": [5, 5],
45+
"batch": [0, 0, 0, 0, 0, 1, 1],
46+
"start": 0,
47+
"end": 20,
48+
"expected": [0, 9, 1, 0, 4, 16, 21],
49+
"expected_C": 16
4050
}
4151
]
4252

test/sparse_grid.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,5 +31,11 @@
3131
"batch": [0, 0, 0, 0, 0, 1, 1],
3232
"expected": [0, 3, 1, 0, 2, 4, 5],
3333
"expected_batch": [0, 0, 0, 0, 1, 1]
34+
},
35+
{
36+
"name": "Position tensor",
37+
"position": [[[0, 0], [9, 9]], [[0, 0], [9, 0]]],
38+
"size": [5, 5],
39+
"expected": [[0, 2], [0, 1]]
3440
}
3541
]

torch_cluster/functions/grid.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,13 @@ def _fixed_cluster_size(position, size, batch=None, end=None):
4848
if end is None:
4949
return _minimal_cluster_size(position, size)
5050

51-
eps = 0.000001 # Model [start, end).
51+
eps = 0.000001 # Simulate [start, end) interval.
5252
if batch is None:
5353
cluster_size = ((end / size).float() - eps).long() + 1
5454
else:
5555
cluster_size = ((end / size[1:]).float() - eps).long() + 1
56-
cluster_size = torch.cat([batch.max() + 1, cluster_size], dim=0)
56+
max_batch = cluster_size.new(1).fill_(batch.max() + 1)
57+
cluster_size = torch.cat([max_batch, cluster_size], dim=0)
5758

5859
return cluster_size
5960

0 commit comments

Comments
 (0)