Skip to content

Commit d143629

Browse files
committed
added cpu extension
1 parent b1db864 commit d143629

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

cpu/fps.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
#include <torch/extension.h>
2+
3+
#include "utils.h"
4+
5+
at::Tensor get_dist(at::Tensor x, ptrdiff_t index) {
6+
return (x - x[index]).norm(2, 1);
7+
}
8+
9+
at::Tensor fps(at::Tensor x, at::Tensor batch, float ratio, bool random) {
10+
auto batch_size = batch[-1].data<int64_t>()[0] + 1;
11+
12+
auto deg = degree(batch, batch_size);
13+
auto cum_deg = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);
14+
auto k = (deg.toType(at::kFloat) * ratio).ceil().toType(at::kLong);
15+
auto cum_k = at::cat({at::zeros(1, k.options()), k.cumsum(0)}, 0);
16+
17+
auto out = at::empty(cum_k[-1].data<int64_t>()[0], batch.options());
18+
19+
auto cum_deg_d = cum_deg.data<int64_t>();
20+
auto k_d = k.data<int64_t>();
21+
auto cum_k_d = cum_k.data<int64_t>();
22+
auto out_d = out.data<int64_t>();
23+
24+
for (ptrdiff_t b = 0; b < batch_size; b++) {
25+
auto index = at::range(cum_deg_d[b], cum_deg_d[b + 1] - 1, out.options());
26+
auto y = x.index_select(0, index);
27+
28+
ptrdiff_t start = 0;
29+
if (random) {
30+
start = at::randperm(y.size(0), batch.options()).data<int64_t>()[0];
31+
}
32+
33+
out_d[cum_k_d[b]] = cum_deg_d[b] + start;
34+
auto dist = get_dist(y, start);
35+
36+
for (ptrdiff_t i = 1; i < k_d[b]; i++) {
37+
ptrdiff_t argmax = dist.argmax().data<int64_t>()[0];
38+
out_d[cum_k_d[b] + i] = cum_deg_d[b] + argmax;
39+
dist = at::min(dist, get_dist(y, argmax));
40+
}
41+
}
42+
43+
return out;
44+
}
45+
46+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
47+
m.def("fps", &fps, "Farthest Point Sampling (CPU)");
48+
}

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
ext_modules = [
66
CppExtension('graclus_cpu', ['cpu/graclus.cpp']),
77
CppExtension('grid_cpu', ['cpu/grid.cpp']),
8+
CppExtension('fps_cpu', ['cpu/fps.cpp']),
89
]
910
cmdclass = {'build_ext': torch.utils.cpp_extension.BuildExtension}
1011

0 commit comments

Comments
 (0)