|
| 1 | +from os import path as osp |
| 2 | +from itertools import product |
| 3 | + |
| 4 | +import pytest |
| 5 | +import json |
| 6 | +import torch |
| 7 | +from torch_cluster import sparse_grid_cluster |
| 8 | + |
| 9 | +from .utils import tensors, Tensor |
| 10 | + |
| 11 | +f = open(osp.join(osp.dirname(__file__), 'sparse_grid.json'), 'r') |
| 12 | +data = json.load(f) |
| 13 | +f.close() |
| 14 | + |
| 15 | + |
| 16 | +@pytest.mark.parametrize('tensor,i', product(tensors, range(len(data)))) |
| 17 | +def test_sparse_grid_cluster_cpu(tensor, i): |
| 18 | + position = Tensor(tensor, data[i]['position']) |
| 19 | + size = torch.LongTensor(data[i]['size']) |
| 20 | + batch = data[i].get('batch') |
| 21 | + start = data[i].get('start') |
| 22 | + expected = torch.LongTensor(data[i]['expected']) |
| 23 | + |
| 24 | + if batch is None: |
| 25 | + output = sparse_grid_cluster(position, size, batch, start) |
| 26 | + assert output.tolist() == expected.tolist() |
| 27 | + else: |
| 28 | + batch = torch.LongTensor(batch) |
| 29 | + expected_batch = torch.LongTensor(data[i]['expected_batch']) |
| 30 | + output = sparse_grid_cluster(position, size, batch, start) |
| 31 | + assert output[0].tolist() == expected.tolist() |
| 32 | + assert output[1].tolist() == expected_batch.tolist() |
| 33 | + |
| 34 | + |
| 35 | +@pytest.mark.skipif(not torch.cuda.is_available(), reason='no CUDA') |
| 36 | +@pytest.mark.parametrize('tensor,i', product(tensors, range(len(data)))) |
| 37 | +def test_sparse_grid_cluster_gpu(tensor, i): # pragma: no cover |
| 38 | + position = Tensor(tensor, data[i]['position']).cuda() |
| 39 | + size = torch.cuda.LongTensor(data[i]['size']) |
| 40 | + batch = data[i].get('batch') |
| 41 | + start = data[i].get('start') |
| 42 | + expected = torch.LongTensor(data[i]['expected']) |
| 43 | + |
| 44 | + if batch is None: |
| 45 | + output = sparse_grid_cluster(position, size, batch, start) |
| 46 | + assert output.cpu().tolist() == expected.tolist() |
| 47 | + else: |
| 48 | + batch = torch.cuda.LongTensor(batch) |
| 49 | + expected_batch = torch.LongTensor(data[i]['expected_batch']) |
| 50 | + output = sparse_grid_cluster(position, size, batch, start) |
| 51 | + assert output[0].cpu().tolist() == expected.tolist() |
| 52 | + assert output[1].cpu().tolist() == expected_batch.tolist() |
0 commit comments