11#include < torch/torch.h>
22
3-
4- inline std::tuple<at::Tensor, at::Tensor> remove_self_loops (at::Tensor row, at::Tensor col) {
3+ inline std::tuple<at::Tensor, at::Tensor> remove_self_loops (at::Tensor row,
4+ at::Tensor col) {
55 auto mask = row != col;
66 row = row.masked_select (mask);
77 col = col.masked_select (mask);
88 return {row, col};
99}
1010
11-
12- inline std::tuple<at::Tensor, at::Tensor> randperm (at::Tensor row, at::Tensor col, int64_t num_nodes) {
11+ inline std::tuple<at::Tensor, at::Tensor>
12+ randperm (at::Tensor row, at::Tensor col, int64_t num_nodes) {
1313 // Randomly reorder row and column indices.
1414 auto perm = at::randperm (torch::CPU (at::kLong ), row.size (0 ));
1515 row = row.index_select (0 , perm);
@@ -29,13 +29,11 @@ inline std::tuple<at::Tensor, at::Tensor> randperm(at::Tensor row, at::Tensor co
2929 return {row, col};
3030}
3131
32-
3332inline at::Tensor degree (at::Tensor index, int64_t num_nodes) {
3433 auto zero = at::zeros (torch::CPU (at::kLong ), {num_nodes});
3534 return zero.scatter_add_ (0 , index, at::ones_like (index));
3635}
3736
38-
3937at::Tensor graclus (at::Tensor row, at::Tensor col, int64_t num_nodes) {
4038 std::tie (row, col) = remove_self_loops (row, col);
4139 std::tie (row, col) = randperm (row, col, num_nodes);
@@ -68,8 +66,8 @@ at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
6866 return cluster;
6967}
7068
71-
72- at::Tensor grid (at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor end) {
69+ at::Tensor grid (at::Tensor pos, at::Tensor size, at::Tensor start,
70+ at::Tensor end) {
7371 size = size.toType (pos.type ());
7472 start = start.toType (pos.type ());
7573 end = end.toType (pos.type ());
@@ -88,7 +86,6 @@ at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor en
8886 return cluster;
8987}
9088
91-
9289PYBIND11_MODULE (TORCH_EXTENSION_NAME, m) {
9390 m.def (" graclus" , &graclus, " Graclus (CPU)" );
9491 m.def (" grid" , &grid, " Grid (CPU)" );
0 commit comments