Skip to content

Commit df2ed80

Browse files
committed
graclus done
1 parent dcd88f5 commit df2ed80

File tree

2 files changed

+40
-19
lines changed

2 files changed

+40
-19
lines changed

aten/cpu/cluster.cpp

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,54 @@
11
#include <torch/torch.h>
22

33

4-
inline std::tuple<at::Tensor, at::Tensor> randperm(at::Tensor row, at::Tensor col) {
5-
/* at::Tensor perm; */
6-
/* std::tie(row, perm) = row.sort(); */
7-
/* col = col.index_select(0, perm); */
8-
9-
/* TODO: randperm */
10-
/* TODO: randperm_sort_row */
11-
return { row, col };
4+
// Drop every edge whose source and target node coincide (self-loops).
// Returns the filtered (row, col) pair; inputs are taken by value and
// the originals are left untouched.
inline std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row, at::Tensor col) {
  auto keep = row != col;  // boolean mask: true for non-loop edges
  return {row.masked_select(keep), col.masked_select(keep)};
}
10+
11+
12+
// Randomize the edge list while keeping it sorted row-wise.
// Steps: shuffle edge order, temporarily relabel row endpoints with a random
// node permutation, sort by the relabeled ids (so equal-row runs land in a
// random node order), then map the row ids back to the originals.
inline std::tuple<at::Tensor, at::Tensor> randperm(at::Tensor row, at::Tensor col, int64_t num_nodes) {
  // Shuffle the edges themselves.
  auto edge_perm = at::randperm(torch::CPU(at::kLong), row.size(0));
  row = row.index_select(0, edge_perm);
  col = col.index_select(0, edge_perm);

  // Relabel row endpoints with random node ids.
  auto node_perm = at::randperm(torch::CPU(at::kLong), num_nodes);
  row = node_perm.index_select(0, row);

  // Sort edges by the relabeled row ids, carrying col along.
  at::Tensor order;
  std::tie(row, order) = row.sort();
  col = col.index_select(0, order);

  // Invert the node relabeling: sort()'s second result is the inverse map.
  row = std::get<1>(node_perm.sort()).index_select(0, row);

  return {row, col};
}
1331

1432

1533
// Histogram of node ids: counts[i] = number of times i appears in `index`,
// i.e. the degree of node i for this half of the edge list.
inline at::Tensor degree(at::Tensor index, int64_t num_nodes) {
  auto counts = at::zeros(torch::CPU(at::kLong), {num_nodes});
  counts.scatter_add_(0, index, at::ones_like(index));
  return counts;
}
1937

2038

2139
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
22-
std::tie(row, col) = randperm(row, col);
40+
std::tie(row, col) = remove_self_loops(row, col);
41+
std::tie(row, col) = randperm(row, col, num_nodes);
42+
2343
auto deg = degree(row, num_nodes);
24-
auto cluster = at::empty(torch::CPU(at::kLong), { num_nodes }).fill_(-1);
44+
auto cluster = at::empty(torch::CPU(at::kLong), {num_nodes}).fill_(-1);
2545

2646
auto *row_data = row.data<int64_t>();
2747
auto *col_data = col.data<int64_t>();
2848
auto *deg_data = deg.data<int64_t>();
2949
auto *cluster_data = cluster.data<int64_t>();
3050

31-
int64_t n_idx = 0, e_idx = 0, d_idx, r, c;
51+
int64_t e_idx = 0, d_idx, r, c;
3252
while (e_idx < row.size(0)) {
3353
r = row_data[e_idx];
3454
if (cluster_data[r] < 0) {
@@ -42,8 +62,7 @@ at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
4262
}
4363
}
4464
}
45-
e_idx += deg_data[n_idx];
46-
n_idx++;
65+
e_idx += deg_data[r];
4766
}
4867

4968
return cluster;
@@ -55,15 +74,15 @@ at::Tensor grid(at::Tensor pos, at::Tensor size, at::Tensor start, at::Tensor en
5574
start = start.toType(pos.type());
5675
end = end.toType(pos.type());
5776

58-
pos = pos - start.view({ 1, -1 });
77+
pos = pos - start.view({1, -1});
5978
auto num_voxels = ((end - start) / size).toType(at::kLong);
6079
num_voxels = (num_voxels + 1).cumsum(0);
6180
num_voxels -= num_voxels.data<int64_t>()[0];
6281
num_voxels.data<int64_t>()[0] = 1;
6382

64-
auto cluster = pos / size.view({ 1, -1 });
83+
auto cluster = pos / size.view({1, -1});
6584
cluster = cluster.toType(at::kLong);
66-
cluster *= num_voxels.view({ 1, -1 });
85+
cluster *= num_voxels.view({1, -1});
6786
cluster = cluster.sum(1);
6887

6988
return cluster;

aten/cpu/cluster.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,7 @@ def graclus_cluster(row, col, num_nodes):
2525
# Smoke test: run graclus on a small 4-node graph and show the edge
# index tensors followed by the resulting cluster assignment.
row = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3])
col = torch.tensor([1, 2, 0, 2, 3, 0, 1, 3, 1, 2])
for edge_index in (row, col):
    print(edge_index)
print('-----------------')
print(graclus_cluster(row, col, 4))

0 commit comments

Comments
 (0)