Skip to content

support 2D weights in permute kernel #4723

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ __global__ __launch_bounds__(kMaxThreads) void permute_2D_data_kernel(
int32_t B,
const indices_t* __restrict__ indices,
const weights_t* __restrict__ weights,
const int32_t weights_columns,
const int32_t* __restrict__ permute,
const offsets_t* __restrict__ input_offsets,
const offsets_t* __restrict__ output_offsets,
Expand All @@ -46,7 +47,10 @@ __global__ __launch_bounds__(kMaxThreads) void permute_2D_data_kernel(
for (auto i = threadIdx.x; i < segment_length; i += blockDim.x) {
permuted_indices[output_start + i] = indices[input_start + i];
if (has_weight) {
permuted_weights[output_start + i] = weights[input_start + i];
for (auto w_col = 0; w_col < weights_columns; ++w_col) {
permuted_weights[(output_start + i) * weights_columns + w_col] =
weights[(input_start + i) * weights_columns + w_col];
}
}
}
}
Expand Down Expand Up @@ -143,8 +147,16 @@ permute_2D_sparse_data_cuda(
if (weights.has_value()) {
const Tensor weights_value = weights.value();
const auto weights_value_contig = weights_value.contiguous();
permuted_weights =
at::empty(permuted_indices_size, weights_value.options());
int32_t weights_columns = 1;
if (weights_value.dense_dim() > 1) {
weights_columns = weights_value.size(1);
permuted_weights = at::empty(
{permuted_indices_size, weights_columns},
weights_value.options());
} else {
permuted_weights =
at::empty(permuted_indices_size, weights_value.options());
}
FBGEMM_DISPATCH_ALL_TYPES_AND_DOUBLE(
weights_value.scalar_type(),
"permute_2D_data_kernel_3",
Expand All @@ -164,6 +176,7 @@ permute_2D_sparse_data_cuda(
B,
indices_contig.data_ptr<indices_t>(),
weights_value_contig.data_ptr<weights_t>(),
weights_columns,
permute_contig.data_ptr<int32_t>(),
input_offsets.data_ptr<offsets_t>(),
output_offsets.data_ptr<offsets_t>(),
Expand All @@ -186,6 +199,7 @@ permute_2D_sparse_data_cuda(
B,
indices_contig.data_ptr<indices_t>(),
nullptr,
1,
permute_contig.data_ptr<int32_t>(),
input_offsets.data_ptr<offsets_t>(),
output_offsets.data_ptr<offsets_t>(),
Expand Down
33 changes: 33 additions & 0 deletions fbgemm_gpu/test/sparse/permute_sparse_features_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,39 @@ def test_permute_sparse_features_with_repeats(
assert permuted_weights_cpu is None


class Permute2DSparseFeaturesTest(unittest.TestCase):
def test_permute_2D_sparse_data(self) -> None:
lengths = torch.tensor(
[[0, 0, 1], [0, 1, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 1]],
dtype=torch.int32,
device="cuda",
)
indices = torch.tensor([500, 1000, 1999], dtype=torch.int32, device="cuda")
permute = torch.tensor([0, 3, 1, 4, 2, 5], dtype=torch.int32, device="cuda")
weights = torch.rand((3, 64), device="cuda")
print(f"expected: {weights.dtype} {weights}")
(
lengths_actual,
values_actual,
weights_actual,
) = torch.ops.fbgemm.permute_2D_sparse_data(
permute, lengths, indices, weights, indices.numel()
)
print(f"actual: {weights_actual.dtype} {weights_actual}")
self.assertTrue(
torch.equal(
lengths_actual,
torch.tensor(
[[0, 0, 1], [0, 0, 0], [0, 1, 0], [0, 0, 0], [0, 0, 0], [0, 0, 1]],
dtype=torch.int32,
device="cuda",
),
)
)
self.assertTrue(torch.equal(values_actual, indices))
self.assertTrue(torch.equal(weights_actual, weights))


extend_test_class(PermuteSparseFeaturesTest)

if __name__ == "__main__":
Expand Down
Loading