Skip to content

Commit d2cc316

Browse files
committed
graclus cpu
1 parent 0a559a4 commit d2cc316

File tree

5 files changed

+157
-45
lines changed

5 files changed

+157
-45
lines changed

cpu/graclus.cpp

Lines changed: 74 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,91 @@
11
#include <torch/torch.h>
22

3-
// #include "../include/degree.cpp"
4-
// #include "../include/loop.cpp"
5-
// #include "../include/perm.cpp"
3+
#include "utils.h"
4+
5+
#define ITERATE_NEIGHBORS(NODE, NAME, ROW, COL, ...) \
6+
{ \
7+
for (int64_t e = ROW[NODE]; e < ROW[NODE + 1]; e++) { \
8+
auto NAME = COL[e]; \
9+
__VA_ARGS__; \
10+
} \
11+
}
612

713
at::Tensor graclus(at::Tensor row, at::Tensor col, int64_t num_nodes) {
8-
// std::tie(row, col) = remove_self_loops(row, col);
9-
// std::tie(row, col) = randperm(row, col, num_nodes);
10-
// auto deg = degree(row, num_nodes, row.type().scalarType());
14+
std::tie(row, col) = remove_self_loops(row, col);
15+
std::tie(row, col) = rand(row, col);
16+
std::tie(row, col) = to_csr(row, col);
17+
auto row_data = row.data<int64_t>(), col_data = col.data<int64_t>();
18+
19+
auto perm = randperm(num_nodes);
20+
auto perm_data = perm.data<int64_t>();
1121

1222
auto cluster = at::full(num_nodes, -1, row.options());
23+
auto cluster_data = cluster.data<int64_t>();
24+
25+
for (int64_t i = 0; i < num_nodes; i++) {
26+
auto u = perm_data[i];
27+
28+
if (cluster_data[u] >= 0)
29+
continue;
30+
31+
cluster_data[u] = u;
1332

14-
// auto *row_data = row.data<int64_t>();
15-
// auto *col_data = col.data<int64_t>();
16-
// auto *deg_data = deg.data<int64_t>();
17-
// auto *cluster_data = cluster.data<int64_t>();
18-
19-
// int64_t e_idx = 0, d_idx, r, c;
20-
// while (e_idx < row.size(0)) {
21-
// r = row_data[e_idx];
22-
// if (cluster_data[r] < 0) {
23-
// cluster_data[r] = r;
24-
// for (d_idx = 0; d_idx < deg_data[r]; d_idx++) {
25-
// c = col_data[e_idx + d_idx];
26-
// if (cluster_data[c] < 0) {
27-
// cluster_data[r] = std::min(r, c);
28-
// cluster_data[c] = std::min(r, c);
29-
// break;
30-
// }
31-
// }
32-
// }
33-
// e_idx += deg_data[r];
34-
// }
33+
ITERATE_NEIGHBORS(u, v, row_data, col_data, {
34+
if (cluster_data[v] >= 0)
35+
continue;
36+
37+
cluster_data[u] = std::min(u, v);
38+
cluster_data[v] = std::min(u, v);
39+
break;
40+
});
41+
}
3542

3643
return cluster;
3744
}
3845

3946
at::Tensor weighted_graclus(at::Tensor row, at::Tensor col, at::Tensor weight,
4047
int64_t num_nodes) {
48+
std::tie(row, col) = remove_self_loops(row, col, weight);
49+
std::tie(row, col, weight) = to_csr(row, col, weight);
50+
auto row_data = row.data<int64_t>(), col_data = col.data<int64_t>();
51+
52+
auto perm = randperm(num_nodes);
53+
auto perm_data = perm.data<int64_t>();
54+
4155
auto cluster = at::full(num_nodes, -1, row.options());
56+
auto cluster_data = cluster.data<int64_t>();
57+
58+
AT_DISPATCH_ALL_TYPES(weight.type(), "weighted_graclus", [&] {
59+
auto weight_data = weight.data<scalar_t>();
60+
auto weight_data = weight.data<scalar_t>();
61+
62+
for (int64_t i = 0; i < num_nodes; i++) {
63+
auto u = perm_data[i];
64+
65+
if (cluster_data[u] >= 0)
66+
continue;
67+
68+
cluster_data[u] = u;
69+
70+
int64_t v_max;
71+
scalar_t w_max = 0;
72+
73+
ITERATE_NEIGHBORS(u, v, row_data, col_data, {
74+
if (cluster_data[v] >= 0)
75+
continue;
76+
77+
auto w = weight_data[e];
78+
if (w >= w_max) {
79+
v_max = v;
80+
w_max = w;
81+
}
82+
});
83+
84+
cluster_data[u] = std::min(u, v_max);
85+
cluster_data[v_max] = std::min(u, v_max);
86+
}
87+
});
88+
4289
return cluster;
4390
}
4491

cpu/utils.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
#pragma once
2+
3+
#include <torch/torch.h>
4+
5+
std::tuple<at::Tensor, at::Tensor> remove_self_loops(at::Tensor row,
6+
at::Tensor col) {
7+
auto mask = row != col;
8+
return make_tuple(row.masked_select(mask), col.masked_select(mask));
9+
}
10+
11+
std::tuple<at::Tensor, at::Tensor, at::Tensor>
12+
remove_self_loops(at::Tensor row, at::Tensor col, at::Tensor weight) {
13+
auto mask = row != col;
14+
return make_tuple(row.masked_select(mask), col.masked_select(mask),
15+
weight.masked_select(mask));
16+
}
17+
18+
at::Tensor randperm(int64_t n) {
19+
auto out = at::empty(n, torch::CPU(at::kLong));
20+
at::randperm_out(out, n);
21+
return out;
22+
}
23+
24+
std::tuple<at::Tensor, at::Tensor> rand(at::Tensor row, at::Tensor col) {
25+
auto perm = randperm(row.size(0));
26+
return make_tuple(row.index_select(perm), col.index_select(perm));
27+
}
28+
29+
std::tuple<at::Tensor, at::Tensor> sort_by_row(at::Tensor row, at::Tensor col) {
30+
Tensor perm;
31+
tie(row, perm) = row.sort();
32+
col = col.index_select(0, perm);
33+
return stack({row, col}, 0);
34+
}
35+
36+
inline Tensor degree(Tensor row, int64_t num_nodes) {
37+
auto zero = zeros(num_nodes, row.type());
38+
auto one = ones(row.size(0), row.type());
39+
return zero.scatter_add_(0, row, one);
40+
}
41+
42+
inline tuple<Tensor, Tensor> to_csr(Tensor index, int64_t num_nodes) {
43+
index = sort_by_row(index);
44+
auto row = degree(index[0], num_nodes).cumsum(0);
45+
row = cat({zeros(1, row.type()), row}, 0); // Prepend zero.
46+
return make_tuple(row, index[1]);
47+
}
48+
49+
// std::tie(row, col) = randperm(row, col);
50+
// std::tie(row, col) = to_csr(row, col);

test/graclus.py renamed to test/test_graclus.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@
1515
'weight': [1, 2, 1, 3, 2, 2, 3, 1, 2, 1],
1616
}]
1717

18+
devices = [torch.device('cpu')]
19+
dtypes = [torch.float]
20+
tests = [tests[0]]
1821

19-
def assert_correct_graclus(row, col, cluster):
22+
23+
def assert_correct(row, col, cluster):
2024
row, col, cluster = row.to('cpu'), col.to('cpu'), cluster.to('cpu')
2125
n = cluster.size(0)
2226

@@ -47,4 +51,5 @@ def test_graclus_cluster(test, dtype, device):
4751
weight = tensor(test.get('weight'), dtype, device)
4852

4953
cluster = graclus_cluster(row, col, weight)
50-
assert_correct_graclus(row, col, cluster)
54+
print(cluster)
55+
# assert_correct(row, col, cluster)

torch_cluster/graclus.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
1-
from .utils.loop import remove_self_loops
2-
from .utils.perm import randperm, sort_row, randperm_sort_row
3-
from .utils.ffi import graclus
1+
# from .utils.loop import remove_self_loops
2+
# from .utils.perm import randperm, sort_row, randperm_sort_row
3+
# from .utils.ffi import graclus
4+
5+
import torch
6+
import graclus_cpu
7+
8+
if torch.cuda.is_available():
9+
import graclus_cuda
410

511

612
def graclus_cluster(row, col, weight=None, num_nodes=None):
@@ -15,22 +21,26 @@ def graclus_cluster(row, col, weight=None, num_nodes=None):
1521
1622
Examples::
1723
18-
>>> row = torch.LongTensor([0, 1, 1, 2])
19-
>>> col = torch.LongTensor([1, 0, 2, 1])
24+
>>> row = torch.tensor([0, 1, 1, 2])
25+
>>> col = torch.tensor([1, 0, 2, 1])
2026
>>> weight = torch.Tensor([1, 1, 1, 1])
2127
>>> cluster = graclus_cluster(row, col, weight)
2228
"""
2329

24-
num_nodes = row.max().item() + 1 if num_nodes is None else num_nodes
30+
if num_nodes is None:
31+
num_nodes = max(row.max().item(), col.max().item()) + 1
2532

26-
if row.is_cuda:
27-
row, col = sort_row(row, col)
28-
else:
29-
row, col = randperm(row, col)
30-
row, col = randperm_sort_row(row, col, num_nodes)
33+
op = graclus_cuda if row.is_cuda else graclus_cpu
3134

32-
row, col = remove_self_loops(row, col)
33-
cluster = row.new_empty((num_nodes, ))
34-
graclus(cluster, row, col, weight)
35+
if weight is None:
36+
cluster = op.graclus(row, col, num_nodes)
37+
else:
38+
cluster = op.weighted_graclus(row, col, weight, num_nodes)
3539

3640
return cluster
41+
42+
# if row.is_cuda:
43+
# row, col = sort_row(row, col)
44+
# else:
45+
# row, col = randperm(row, col)
46+
# row, col = randperm_sort_row(row, col, num_nodes)

torch_cluster/grid.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def grid_cluster(pos, size, start=None, end=None):
2828
start = pos.t().min(dim=1)[0] if start is None else start
2929
end = pos.t().max(dim=1)[0] if end is None else end
3030

31-
op = grid_cuda.grid if pos.is_cuda else grid_cpu.grid
32-
cluster = op(pos, size, start, end)
31+
op = grid_cuda if pos.is_cuda else grid_cpu
32+
cluster = op.grid(pos, size, start, end)
3333

3434
return cluster

0 commit comments

Comments
 (0)