11#include < torch/torch.h>
22
33
4- inline std::tuple<at::Tensor, at::Tensor> randperm (at::Tensor row, at::Tensor col) {
5- /* at::Tensor perm; */
6- /* std::tie(row, perm) = row.sort(); */
7- /* col = col.index_select(0, perm); */
8-
9- /* TODO: randperm */
10- /* TODO: randperm_sort_row */
11- return { row, col };
4+ inline std::tuple<at::Tensor, at::Tensor> remove_self_loops (at::Tensor row, at::Tensor col) {
5+ auto mask = row != col;
6+ row = row.masked_select (mask);
7+ col = col.masked_select (mask);
8+ return {row, col};
9+ }
10+
11+
12+ inline std::tuple<at::Tensor, at::Tensor> randperm (at::Tensor row, at::Tensor col, int64_t num_nodes) {
13+ // Randomly reorder row and column indices.
14+ auto perm = at::randperm (torch::CPU (at::kLong ), row.size (0 ));
15+ row = row.index_select (0 , perm);
16+ col = col.index_select (0 , perm);
17+
18+ // Randomly swap row values.
19+ auto node_rid = at::randperm (torch::CPU (at::kLong ), num_nodes);
20+ row = node_rid.index_select (0 , row);
21+
22+ // Sort row and column indices row-wise.
23+ std::tie (row, perm) = row.sort ();
24+ col = col.index_select (0 , perm);
25+
26+ // Revert row value swaps.
27+ row = std::get<1 >(node_rid.sort ()).index_select (0 , row);
28+
29+ return {row, col};
1230}
1331
1432
1533inline at::Tensor degree (at::Tensor index, int64_t num_nodes) {
16- auto zero = at::zeros (torch::CPU (at::kLong ), { num_nodes });
34+ auto zero = at::zeros (torch::CPU (at::kLong ), {num_nodes});
1735 return zero.scatter_add_ (0 , index, at::ones_like (index));
1836}
1937
2038
2139at::Tensor graclus (at::Tensor row, at::Tensor col, int64_t num_nodes) {
22- std::tie (row, col) = randperm (row, col);
40+ std::tie (row, col) = remove_self_loops (row, col);
41+ std::tie (row, col) = randperm (row, col, num_nodes);
42+
2343 auto deg = degree (row, num_nodes);
24- auto cluster = at::empty (torch::CPU (at::kLong ), { num_nodes }).fill_ (-1 );
44+ auto cluster = at::empty (torch::CPU (at::kLong ), {num_nodes}).fill_ (-1 );
2545
2646 auto *row_data = row.data <int64_t >();
2747 auto *col_data = col.data <int64_t >();
2848 auto *deg_data = deg.data <int64_t >();
2949 auto *cluster_data = cluster.data <int64_t >();
3050
31- int64_t n_idx = 0 , e_idx = 0 , d_idx, r, c;
51+ int64_t e_idx = 0 , d_idx, r, c;
3252 while (e_idx < row.size (0 )) {
3353 r = row_data[e_idx];
3454 if (cluster_data[r] < 0 ) {
@@ -42,8 +62,7 @@ at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
4262 }
4363 }
4464 }
45- e_idx += deg_data[n_idx];
46- n_idx++;
65+ e_idx += deg_data[r];
4766 }
4867
4968 return cluster;
@@ -55,15 +74,15 @@ at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor en
5574 start = start.toType (pos.type ());
5675 end = end.toType (pos.type ());
5776
58- pos = pos - start.view ({ 1 , -1 });
77+ pos = pos - start.view ({1 , -1 });
5978 auto num_voxels = ((end - start) / size).toType (at::kLong );
6079 num_voxels = (num_voxels + 1 ).cumsum (0 );
6180 num_voxels -= num_voxels.data <int64_t >()[0 ];
6281 num_voxels.data <int64_t >()[0 ] = 1 ;
6382
64- auto cluster = pos / size.view ({ 1 , -1 });
83+ auto cluster = pos / size.view ({1 , -1 });
6584 cluster = cluster.toType (at::kLong );
66- cluster *= num_voxels.view ({ 1 , -1 });
85+ cluster *= num_voxels.view ({1 , -1 });
6786 cluster = cluster.sum (1 );
6887
6988 return cluster;
0 commit comments