|
1 | 1 | #include <torch/extension.h> |
2 | 2 |
|
| 3 | +#include "compat.h" |
3 | 4 | #include "utils.h" |
4 | 5 |
|
5 | 6 | at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) { |
6 | 7 | std::tie(row, col) = remove_self_loops(row, col); |
7 | 8 | std::tie(row, col) = rand(row, col); |
8 | 9 | std::tie(row, col) = to_csr(row, col, num_nodes); |
9 | | - auto row_data = row.data<int64_t>(), col_data = col.data<int64_t>(); |
| 10 | + auto row_data = row.DATA_PTR<int64_t>(), col_data = col.DATA_PTR<int64_t>(); |
10 | 11 |
|
11 | 12 | auto perm = at::randperm(num_nodes, row.options()); |
12 | | - auto perm_data = perm.data<int64_t>(); |
| 13 | + auto perm_data = perm.DATA_PTR<int64_t>(); |
13 | 14 |
|
14 | 15 | auto cluster = at::full(num_nodes, -1, row.options()); |
15 | | - auto cluster_data = cluster.data<int64_t>(); |
| 16 | + auto cluster_data = cluster.DATA_PTR<int64_t>(); |
16 | 17 |
|
17 | 18 | for (int64_t i = 0; i < num_nodes; i++) { |
18 | 19 | auto u = perm_data[i]; |
@@ -41,16 +42,16 @@ at::Tensor weighted_graclus(at::Tensor row, at::Tensor col, at::Tensor weight, |
41 | 42 | int64_t num_nodes) { |
42 | 43 | std::tie(row, col, weight) = remove_self_loops(row, col, weight); |
43 | 44 | std::tie(row, col, weight) = to_csr(row, col, weight, num_nodes); |
44 | | - auto row_data = row.data<int64_t>(), col_data = col.data<int64_t>(); |
| 45 | + auto row_data = row.DATA_PTR<int64_t>(), col_data = col.DATA_PTR<int64_t>(); |
45 | 46 |
|
46 | 47 | auto perm = at::randperm(num_nodes, row.options()); |
47 | | - auto perm_data = perm.data<int64_t>(); |
| 48 | + auto perm_data = perm.DATA_PTR<int64_t>(); |
48 | 49 |
|
49 | 50 | auto cluster = at::full(num_nodes, -1, row.options()); |
50 | | - auto cluster_data = cluster.data<int64_t>(); |
| 51 | + auto cluster_data = cluster.DATA_PTR<int64_t>(); |
51 | 52 |
|
52 | 53 | AT_DISPATCH_ALL_TYPES(weight.scalar_type(), "weighted_graclus", [&] { |
53 | | - auto weight_data = weight.data<scalar_t>(); |
| 54 | + auto weight_data = weight.DATA_PTR<scalar_t>(); |
54 | 55 |
|
55 | 56 | for (int64_t i = 0; i < num_nodes; i++) { |
56 | 57 | auto u = perm_data[i]; |
|
0 commit comments