Skip to content

Commit de0216d

Browse files
committed
pytorch 1.3 support
1 parent bd3ae68 commit de0216d

File tree

17 files changed

+121
-70
lines changed

17 files changed

+121
-70
lines changed

cpu/compat.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// PyTorch >= 1.3 deprecated Tensor::data<T>() in favor of Tensor::data_ptr<T>().
// DATA_PTR expands to whichever accessor the installed PyTorch provides;
// VERSION_GE_1_3 is presumably defined by the build script — confirm in setup.py.
#ifdef VERSION_GE_1_3
2+
#define DATA_PTR data_ptr
3+
#else
4+
#define DATA_PTR data
5+
#endif

cpu/fps.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,41 @@
11
#include <torch/extension.h>
22

3+
#include "compat.h"
34
#include "utils.h"
45

56
// Returns the L2 (Euclidean) distance from each row of `x` to the row at
// `index`, computed along dim 1 — a 1-D tensor with one entry per row of `x`.
at::Tensor get_dist(at::Tensor x, ptrdiff_t index) {
67
return (x - x[index]).norm(2, 1);
78
}
89

910
at::Tensor fps(at::Tensor x, at::Tensor batch, float ratio, bool random) {
10-
auto batch_size = batch[-1].data<int64_t>()[0] + 1;
11+
auto batch_size = batch[-1].DATA_PTR<int64_t>()[0] + 1;
1112

1213
auto deg = degree(batch, batch_size);
1314
auto cum_deg = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);
1415
auto k = (deg.toType(at::kFloat) * ratio).ceil().toType(at::kLong);
1516
auto cum_k = at::cat({at::zeros(1, k.options()), k.cumsum(0)}, 0);
1617

17-
auto out = at::empty(cum_k[-1].data<int64_t>()[0], batch.options());
18+
auto out = at::empty(cum_k[-1].DATA_PTR<int64_t>()[0], batch.options());
1819

19-
auto cum_deg_d = cum_deg.data<int64_t>();
20-
auto k_d = k.data<int64_t>();
21-
auto cum_k_d = cum_k.data<int64_t>();
22-
auto out_d = out.data<int64_t>();
20+
auto cum_deg_d = cum_deg.DATA_PTR<int64_t>();
21+
auto k_d = k.DATA_PTR<int64_t>();
22+
auto cum_k_d = cum_k.DATA_PTR<int64_t>();
23+
auto out_d = out.DATA_PTR<int64_t>();
2324

2425
for (ptrdiff_t b = 0; b < batch_size; b++) {
2526
auto index = at::range(cum_deg_d[b], cum_deg_d[b + 1] - 1, out.options());
2627
auto y = x.index_select(0, index);
2728

2829
ptrdiff_t start = 0;
2930
if (random) {
30-
start = at::randperm(y.size(0), batch.options()).data<int64_t>()[0];
31+
start = at::randperm(y.size(0), batch.options()).DATA_PTR<int64_t>()[0];
3132
}
3233

3334
out_d[cum_k_d[b]] = cum_deg_d[b] + start;
3435
auto dist = get_dist(y, start);
3536

3637
for (ptrdiff_t i = 1; i < k_d[b]; i++) {
37-
ptrdiff_t argmax = dist.argmax().data<int64_t>()[0];
38+
ptrdiff_t argmax = dist.argmax().DATA_PTR<int64_t>()[0];
3839
out_d[cum_k_d[b] + i] = cum_deg_d[b] + argmax;
3940
dist = at::min(dist, get_dist(y, argmax));
4041
}

cpu/graclus.cpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
#include <torch/extension.h>
22

3+
#include "compat.h"
34
#include "utils.h"
45

56
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
67
std::tie(row, col) = remove_self_loops(row, col);
78
std::tie(row, col) = rand(row, col);
89
std::tie(row, col) = to_csr(row, col, num_nodes);
9-
auto row_data = row.data<int64_t>(), col_data = col.data<int64_t>();
10+
auto row_data = row.DATA_PTR<int64_t>(), col_data = col.DATA_PTR<int64_t>();
1011

1112
auto perm = at::randperm(num_nodes, row.options());
12-
auto perm_data = perm.data<int64_t>();
13+
auto perm_data = perm.DATA_PTR<int64_t>();
1314

1415
auto cluster = at::full(num_nodes, -1, row.options());
15-
auto cluster_data = cluster.data<int64_t>();
16+
auto cluster_data = cluster.DATA_PTR<int64_t>();
1617

1718
for (int64_t i = 0; i < num_nodes; i++) {
1819
auto u = perm_data[i];
@@ -41,16 +42,16 @@ at::Tensor weighted_graclus(at::Tensor row, at::Tensor col, at::Tensor weight,
4142
int64_t num_nodes) {
4243
std::tie(row, col, weight) = remove_self_loops(row, col, weight);
4344
std::tie(row, col, weight) = to_csr(row, col, weight, num_nodes);
44-
auto row_data = row.data<int64_t>(), col_data = col.data<int64_t>();
45+
auto row_data = row.DATA_PTR<int64_t>(), col_data = col.DATA_PTR<int64_t>();
4546

4647
auto perm = at::randperm(num_nodes, row.options());
47-
auto perm_data = perm.data<int64_t>();
48+
auto perm_data = perm.DATA_PTR<int64_t>();
4849

4950
auto cluster = at::full(num_nodes, -1, row.options());
50-
auto cluster_data = cluster.data<int64_t>();
51+
auto cluster_data = cluster.DATA_PTR<int64_t>();
5152

5253
AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "weighted_graclus", [&] {
53-
auto weight_data = weight.data<scalar_t>();
54+
auto weight_data = weight.DATA_PTR<scalar_t>();
5455

5556
for (int64_t i = 0; i < num_nodes; i++) {
5657
auto u = perm_data[i];

cpu/rw.cpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <torch/extension.h>
22

3+
#include "compat.h"
34
#include "utils.h"
45

56
at::Tensor rw(at::Tensor row, at::Tensor col, at::Tensor start,
@@ -12,12 +13,12 @@ at::Tensor rw(at::Tensor row, at::Tensor col, at::Tensor start,
1213
auto out =
1314
at::full({start.size(0), (int64_t)walk_length + 1}, -1, start.options());
1415

15-
auto deg_d = deg.data<int64_t>();
16-
auto cum_deg_d = cum_deg.data<int64_t>();
17-
auto col_d = col.data<int64_t>();
18-
auto start_d = start.data<int64_t>();
19-
auto rand_d = rand.data<float>();
20-
auto out_d = out.data<int64_t>();
16+
auto deg_d = deg.DATA_PTR<int64_t>();
17+
auto cum_deg_d = cum_deg.DATA_PTR<int64_t>();
18+
auto col_d = col.DATA_PTR<int64_t>();
19+
auto start_d = start.DATA_PTR<int64_t>();
20+
auto rand_d = rand.DATA_PTR<float>();
21+
auto out_d = out.DATA_PTR<int64_t>();
2122

2223
for (ptrdiff_t n = 0; n < start.size(0); n++) {
2324
int64_t cur = start_d[n];

cpu/sampler.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
#include <torch/extension.h>
22

3+
#include "compat.h"
4+
35
at::Tensor neighbor_sampler(at::Tensor start, at::Tensor cumdeg, size_t size,
46
float factor) {
57

6-
auto start_ptr = start.data<int64_t>();
7-
auto cumdeg_ptr = cumdeg.data<int64_t>();
8+
auto start_ptr = start.DATA_PTR<int64_t>();
9+
auto cumdeg_ptr = cumdeg.DATA_PTR<int64_t>();
810

911
std::vector<int64_t> e_ids;
1012
for (ptrdiff_t i = 0; i < start.size(0); i++) {
@@ -29,7 +31,7 @@ at::Tensor neighbor_sampler(at::Tensor start, at::Tensor cumdeg, size_t size,
2931
e_ids.insert(e_ids.end(), v.begin(), v.end());
3032
} else {
3133
auto sample = at::randperm(num_neighbors, start.options());
32-
auto sample_ptr = sample.data<int64_t>();
34+
auto sample_ptr = sample.DATA_PTR<int64_t>();
3335
for (size_t j = 0; j < size_i; j++) {
3436
e_ids.push_back(sample_ptr[j] + low);
3537
}

cuda/coloring.cuh

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
#include <ATen/ATen.h>
44

5+
#include "compat.cuh"
6+
57
#define THREADS 1024
68
#define BLOCKS(N) (N + THREADS - 1) / THREADS
79

@@ -30,8 +32,8 @@ int64_t colorize(at::Tensor cluster) {
3032
auto props = at::full(numel, BLUE_PROB, cluster.options().dtype(at::kFloat));
3133
auto bernoulli = props.bernoulli();
3234

33-
colorize_kernel<<<BLOCKS(numel), THREADS>>>(cluster.data<int64_t>(),
34-
bernoulli.data<float>(), numel);
35+
colorize_kernel<<<BLOCKS(numel), THREADS>>>(
36+
cluster.DATA_PTR<int64_t>(), bernoulli.DATA_PTR<float>(), numel);
3537

3638
int64_t out;
3739
cudaMemcpyFromSymbol(&out, done, sizeof(out), 0, cudaMemcpyDeviceToHost);

cuda/compat.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
// CUDA counterpart of cpu/compat.h: PyTorch >= 1.3 deprecated
// Tensor::data<T>() in favor of Tensor::data_ptr<T>(); DATA_PTR selects
// the accessor matching the installed version (VERSION_GE_1_3 presumably
// set by the build script — confirm in setup.py).
#ifdef VERSION_GE_1_3
2+
#define DATA_PTR data_ptr
3+
#else
4+
#define DATA_PTR data
5+
#endif

cuda/fps_kernel.cu

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <ATen/ATen.h>
22

33
#include "atomics.cuh"
4+
#include "compat.cuh"
45
#include "utils.cuh"
56

67
#define THREADS 1024
@@ -164,7 +165,7 @@ fps_kernel(const scalar_t *__restrict__ x, const int64_t *__restrict__ cum_deg,
164165
at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) {
165166
cudaSetDevice(x.get_device());
166167
auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
167-
cudaMemcpy(batch_sizes, batch[-1].data<int64_t>(), sizeof(int64_t),
168+
cudaMemcpy(batch_sizes, batch[-1].DATA_PTR<int64_t>(), sizeof(int64_t),
168169
cudaMemcpyDeviceToHost);
169170
auto batch_size = batch_sizes[0] + 1;
170171

@@ -185,15 +186,15 @@ at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) {
185186
auto tmp_dist = at::empty(x.size(0), x.options());
186187

187188
auto k_sum = (int64_t *)malloc(sizeof(int64_t));
188-
cudaMemcpy(k_sum, cum_k[-1].data<int64_t>(), sizeof(int64_t),
189+
cudaMemcpy(k_sum, cum_k[-1].DATA_PTR<int64_t>(), sizeof(int64_t),
189190
cudaMemcpyDeviceToHost);
190191
auto out = at::empty(k_sum[0], k.options());
191192

192193
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "fps_kernel", [&] {
193-
FPS_KERNEL(x.size(1), x.data<scalar_t>(), cum_deg.data<int64_t>(),
194-
cum_k.data<int64_t>(), start.data<int64_t>(),
195-
dist.data<scalar_t>(), tmp_dist.data<scalar_t>(),
196-
out.data<int64_t>());
194+
FPS_KERNEL(x.size(1), x.DATA_PTR<scalar_t>(), cum_deg.DATA_PTR<int64_t>(),
195+
cum_k.DATA_PTR<int64_t>(), start.DATA_PTR<int64_t>(),
196+
dist.DATA_PTR<scalar_t>(), tmp_dist.DATA_PTR<scalar_t>(),
197+
out.DATA_PTR<int64_t>());
197198
});
198199

199200
return out;

cuda/grid_kernel.cu

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#include <ATen/cuda/detail/IndexUtils.cuh>
33
#include <ATen/cuda/detail/TensorInfo.cuh>
44

5+
#include "compat.cuh"
6+
57
#define THREADS 1024
68
#define BLOCKS(N) (N + THREADS - 1) / THREADS
79

@@ -31,10 +33,10 @@ at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
3133

3234
AT_DISPATCH_ALL_TYPES(pos.scalar_type(), "grid_kernel", [&] {
3335
grid_kernel<scalar_t><<<BLOCKS(cluster.numel()), THREADS>>>(
34-
cluster.data<int64_t>(),
36+
cluster.DATA_PTR<int64_t>(),
3537
at::cuda::detail::getTensorInfo<scalar_t, int64_t>(pos),
36-
size.data<scalar_t>(), start.data<scalar_t>(), end.data<scalar_t>(),
37-
cluster.numel());
38+
size.DATA_PTR<scalar_t>(), start.DATA_PTR<scalar_t>(),
39+
end.DATA_PTR<scalar_t>(), cluster.numel());
3840
});
3941

4042
return cluster;

cuda/knn_kernel.cu

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <ATen/ATen.h>
22

3+
#include "compat.cuh"
34
#include "utils.cuh"
45

56
#define THREADS 1024
@@ -79,7 +80,7 @@ at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
7980
at::Tensor batch_y, bool cosine) {
8081
cudaSetDevice(x.get_device());
8182
auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
82-
cudaMemcpy(batch_sizes, batch_x[-1].data<int64_t>(), sizeof(int64_t),
83+
cudaMemcpy(batch_sizes, batch_x[-1].DATA_PTR<int64_t>(), sizeof(int64_t),
8384
cudaMemcpyDeviceToHost);
8485
auto batch_size = batch_sizes[0] + 1;
8586

@@ -94,9 +95,10 @@ at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
9495

9596
AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "knn_kernel", [&] {
9697
knn_kernel<scalar_t><<<batch_size, THREADS>>>(
97-
x.data<scalar_t>(), y.data<scalar_t>(), batch_x.data<int64_t>(),
98-
batch_y.data<int64_t>(), dist.data<scalar_t>(), row.data<int64_t>(),
99-
col.data<int64_t>(), k, x.size(1), cosine);
98+
x.DATA_PTR<scalar_t>(), y.DATA_PTR<scalar_t>(),
99+
batch_x.DATA_PTR<int64_t>(), batch_y.DATA_PTR<int64_t>(),
100+
dist.DATA_PTR<scalar_t>(), row.DATA_PTR<int64_t>(),
101+
col.DATA_PTR<int64_t>(), k, x.size(1), cosine);
100102
});
101103

102104
auto mask = col != -1;

0 commit comments

Comments (0)