Skip to content

Commit 65374fb

Browse files
committed
multi gpu update
1 parent 3a4c67c commit 65374fb

File tree

9 files changed

+10
-3
lines changed

9 files changed

+10
-3
lines changed

cuda/fps_kernel.cu

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -162,6 +162,7 @@ fps_kernel(const scalar_t *__restrict__ x, const int64_t *__restrict__ cum_deg,
162162
}()
163163

164164
at::Tensor fps_cuda(at::Tensor x, at::Tensor batch, float ratio, bool random) {
165+
cudaSetDevice(x.get_device());
165166
auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
166167
cudaMemcpy(batch_sizes, batch[-1].data<int64_t>(), sizeof(int64_t),
167168
cudaMemcpyDeviceToHost);

cuda/graclus_kernel.cu

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
#include "utils.cuh"
77

88
at::Tensor graclus_cuda(at::Tensor row, at::Tensor col, int64_t num_nodes) {
9+
cudaSetDevice(row.get_device());
910
std::tie(row, col) = remove_self_loops(row, col);
1011
std::tie(row, col) = rand(row, col);
1112
std::tie(row, col) = to_csr(row, col, num_nodes);
@@ -23,6 +24,7 @@ at::Tensor graclus_cuda(at::Tensor row, at::Tensor col, int64_t num_nodes) {
2324

2425
at::Tensor weighted_graclus_cuda(at::Tensor row, at::Tensor col,
2526
at::Tensor weight, int64_t num_nodes) {
27+
cudaSetDevice(row.get_device());
2628
std::tie(row, col, weight) = remove_self_loops(row, col, weight);
2729
std::tie(row, col, weight) = to_csr(row, col, weight, num_nodes);
2830

cuda/grid_kernel.cu

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@ __global__ void grid_kernel(int64_t *cluster,
2626

2727
at::Tensor grid_cuda(at::Tensor pos, at::Tensor size, at::Tensor start,
2828
at::Tensor end) {
29+
cudaSetDevice(pos.get_device());
2930
auto cluster = at::empty(pos.size(0), pos.options().dtype(at::kLong));
3031

3132
AT_DISPATCH_ALL_TYPES(pos.type(), "grid_kernel", [&] {

cuda/knn_kernel.cu

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,7 @@ knn_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
5252

5353
at::Tensor knn_cuda(at::Tensor x, at::Tensor y, size_t k, at::Tensor batch_x,
5454
at::Tensor batch_y) {
55+
cudaSetDevice(x.get_device());
5556
auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
5657
cudaMemcpy(batch_sizes, batch_x[-1].data<int64_t>(), sizeof(int64_t),
5758
cudaMemcpyDeviceToHost);

cuda/nearest_kernel.cu

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,7 @@ __global__ void nearest_kernel(const scalar_t *__restrict__ x,
6060

6161
at::Tensor nearest_cuda(at::Tensor x, at::Tensor y, at::Tensor batch_x,
6262
at::Tensor batch_y) {
63+
cudaSetDevice(x.get_device());
6364
auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
6465
cudaMemcpy(batch_sizes, batch_x[-1].data<int64_t>(), sizeof(int64_t),
6566
cudaMemcpyDeviceToHost);

cuda/radius_kernel.cu

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -48,6 +48,7 @@ radius_kernel(const scalar_t *__restrict__ x, const scalar_t *__restrict__ y,
4848
at::Tensor radius_cuda(at::Tensor x, at::Tensor y, float radius,
4949
at::Tensor batch_x, at::Tensor batch_y,
5050
size_t max_num_neighbors) {
51+
cudaSetDevice(x.get_device());
5152
auto batch_sizes = (int64_t *)malloc(sizeof(int64_t));
5253
cudaMemcpy(batch_sizes, batch_x[-1].data<int64_t>(), sizeof(int64_t),
5354
cudaMemcpyDeviceToHost);

cuda/rw_kernel.cu

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ __global__ void uniform_rw_kernel(
2727

2828
at::Tensor rw_cuda(at::Tensor row, at::Tensor col, at::Tensor start,
2929
size_t walk_length, float p, float q, size_t num_nodes) {
30-
30+
cudaSetDevice(row.get_device());
3131
auto deg = degree(row, num_nodes);
3232
row = at::cat({at::zeros(1, deg.options()), deg.cumsum(0)}, 0);
3333

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727
['cuda/rw.cpp', 'cuda/rw_kernel.cu']),
2828
]
2929

30-
__version__ = '1.2.3'
30+
__version__ = '1.2.4'
3131
url = 'https://github.com/rusty1s/pytorch_cluster'
3232

3333
install_requires = ['scipy']

torch_cluster/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
from .radius import radius, radius_graph
77
from .rw import random_walk
88

9-
__version__ = '1.2.3'
9+
__version__ = '1.2.4'
1010

1111
__all__ = [
1212
'graclus_cluster',

0 commit comments

Comments
 (0)