Skip to content

Commit 6b75d23

Browse files
committed
grid cpu
1 parent 878e119 commit 6b75d23

File tree

4 files changed

+12
-15
lines changed

4 files changed

+12
-15
lines changed

cpu/grid.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,17 @@
22

33
at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start,
44
at::Tensor end) {
5-
size = size.toType(pos.type());
6-
start = start.toType(pos.type());
7-
end = end.toType(pos.type());
8-
95
pos = pos - start.view({1, -1});
10-
auto num_voxels = ((end - start) / size).toType(at::kLong);
11-
num_voxels = (num_voxels + 1).cumsum(0);
12-
num_voxels -= num_voxels.data<int64_t>()[0];
13-
num_voxels.data<int64_t>()[0] = 1;
146

15-
auto cluster = pos / size.view({1, -1});
16-
cluster = cluster.toType(at::kLong);
7+
auto num_voxels = ((end - start) / size).toType(at::kLong) + 1;
8+
num_voxels = num_voxels.cumprod(0);
9+
10+
num_voxels = at::cat({at::ones(1, num_voxels.options()), num_voxels}, 0);
11+
auto index = empty(size.size(0), num_voxels.options());
12+
arange_out(index, size.size(0));
13+
num_voxels = num_voxels.index_select(0, index);
14+
15+
auto cluster = (pos / size.view({1, -1})).toType(at::kLong);
1716
cluster *= num_voxels.view({1, -1});
1817
cluster = cluster.sum(1);
1918

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ description-file = README.md
55
test = pytest
66

77
[tool:pytest]
8-
addopts = --capture=no --cov
8+
addopts = --capture=no
File renamed without changes.

torch_cluster/grid.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@ def grid_cluster(pos, size, start=None, end=None):
2828
start = pos.t().min(dim=1)[0] if start is None else start
2929
end = pos.t().max(dim=1)[0] if end is None else end
3030

31-
if pos.is_cuda:
32-
cluster = grid_cuda.grid(pos, size, start, end)
33-
else:
34-
cluster = grid_cpu.grid(pos, size, start, end)
31+
op = grid_cuda.grid if pos.is_cuda else grid_cpu.grid
32+
cluster = op(pos, size, start, end)
3533

3634
return cluster

0 commit comments

Comments
 (0)