diff --git a/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu b/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu index 2e8b89a823..b7be79d3af 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu @@ -25,6 +25,7 @@ __global__ __launch_bounds__(kMaxThreads) void permute_2D_data_kernel( int32_t B, const indices_t* __restrict__ indices, const weights_t* __restrict__ weights, + const int32_t weights_columns, const int32_t* __restrict__ permute, const offsets_t* __restrict__ input_offsets, const offsets_t* __restrict__ output_offsets, @@ -46,7 +47,10 @@ __global__ __launch_bounds__(kMaxThreads) void permute_2D_data_kernel( for (auto i = threadIdx.x; i < segment_length; i += blockDim.x) { permuted_indices[output_start + i] = indices[input_start + i]; if (has_weight) { - permuted_weights[output_start + i] = weights[input_start + i]; + for (auto w_col = 0; w_col < weights_columns; ++w_col) { + permuted_weights[(output_start + i) * weights_columns + w_col] = + weights[(input_start + i) * weights_columns + w_col]; + } } } } @@ -143,8 +147,16 @@ permute_2D_sparse_data_cuda( if (weights.has_value()) { const Tensor weights_value = weights.value(); const auto weights_value_contig = weights_value.contiguous(); - permuted_weights = - at::empty(permuted_indices_size, weights_value.options()); + int32_t weights_columns = 1; + if (weights_value.dense_dim() > 1) { + weights_columns = weights_value.size(1); + permuted_weights = at::empty( + {permuted_indices_size, weights_columns}, + weights_value.options()); + } else { + permuted_weights = + at::empty(permuted_indices_size, weights_value.options()); + } FBGEMM_DISPATCH_ALL_TYPES_AND_DOUBLE( weights_value.scalar_type(), "permute_2D_data_kernel_3", @@ -164,6 +176,7 @@ permute_2D_sparse_data_cuda( B, indices_contig.data_ptr(), weights_value_contig.data_ptr(), + weights_columns, permute_contig.data_ptr(), input_offsets.data_ptr(), output_offsets.data_ptr(), @@ -186,6 +199,7 @@ permute_2D_sparse_data_cuda( B, indices_contig.data_ptr(), nullptr, + 1, permute_contig.data_ptr(), input_offsets.data_ptr(), output_offsets.data_ptr(), diff --git a/fbgemm_gpu/test/sparse/permute_sparse_features_test.py b/fbgemm_gpu/test/sparse/permute_sparse_features_test.py index 1f52268110..c8bec29f17 100644 --- a/fbgemm_gpu/test/sparse/permute_sparse_features_test.py +++ b/fbgemm_gpu/test/sparse/permute_sparse_features_test.py @@ -192,6 +192,39 @@ def test_permute_sparse_features_with_repeats( assert permuted_weights_cpu is None +class Permute2DSparseFeaturesTest(unittest.TestCase): + def test_permute_2D_sparse_data(self) -> None: + lengths = torch.tensor( + [[0, 0, 1], [0, 1, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 1]], + dtype=torch.int32, + device="cuda", + ) + indices = torch.tensor([500, 1000, 1999], dtype=torch.int32, device="cuda") + permute = torch.tensor([0, 3, 1, 4, 2, 5], dtype=torch.int32, device="cuda") + weights = torch.rand((3, 64), device="cuda") + print(f"expected: {weights.dtype} {weights}") + ( + lengths_actual, + values_actual, + weights_actual, + ) = torch.ops.fbgemm.permute_2D_sparse_data( + permute, lengths, indices, weights, indices.numel() + ) + print(f"actual: {weights_actual.dtype} {weights_actual}") + self.assertTrue( + torch.equal( + lengths_actual, + torch.tensor( + [[0, 0, 1], [0, 0, 0], [0, 1, 0], [0, 0, 0], [0, 0, 0], [0, 0, 1]], + dtype=torch.int32, + device="cuda", + ), + ) + ) + self.assertTrue(torch.equal(values_actual, indices)) + self.assertTrue(torch.equal(weights_actual, weights)) + + extend_test_class(PermuteSparseFeaturesTest) if __name__ == "__main__":