From be36308332bdad1cb734f27a05ce7ebd44007b45 Mon Sep 17 00:00:00 2001 From: Benson Ma Date: Mon, 18 Aug 2025 17:13:03 -0700 Subject: [PATCH] Migrate sparse ops kernels to `FBGEMM_LAUNCH_KERNEL`, pt 5B Reviewed By: cthi Differential Revision: D80432131 --- .../src/sparse_ops/sparse_permute_2d.cu | 88 +++++++++---------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu b/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu index f6a3669da3..64f8fa725f 100644 --- a/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu +++ b/fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu @@ -105,7 +105,7 @@ permute_2D_sparse_data_cuda( permuted_lengths = at::empty({T, B}, lengths.options()); constexpr int32_t threads_1 = 256; - const auto blocks_1 = cuda_calc_xblock_count(B * T, threads_1); + const auto blocks_1 = cuda_calc_block_count(B * T, threads_1); AT_DISPATCH_INDEX_TYPES( lengths.scalar_type(), "permute_2D_lengths_kernel", [&] { FBGEMM_LAUNCH_KERNEL( @@ -134,7 +134,7 @@ permute_2D_sparse_data_cuda( constexpr int32_t BT_blocks = 32; dim3 threads_2(32, BT_blocks); - const auto blocks_2 = cuda_calc_xblock_count(B * T, BT_blocks); + const auto blocks_2 = cuda_calc_block_count(B * T, BT_blocks); permuted_indices = at::empty(permuted_indices_size, indices.options()); AT_DISPATCH_INDEX_TYPES( @@ -153,48 +153,48 @@ permute_2D_sparse_data_cuda( "permute_2D_data_kernel_3", [&] { using weights_t = scalar_t; - permute_2D_data_kernel< - true, - offsets_t, - indices_t, - weights_t> - <<>>( - permuted_indices_size, - T, - B, - indices_contig.data_ptr(), - weights_value_contig.data_ptr(), - permute_contig.data_ptr(), - input_offsets.data_ptr(), - output_offsets.data_ptr(), - permuted_indices.data_ptr(), - permuted_weights.data_ptr()); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + FBGEMM_LAUNCH_KERNEL( + (permute_2D_data_kernel< + true, + offsets_t, + indices_t, + weights_t>), + blocks_2, + threads_2, + 0, + at::cuda::getCurrentCUDAStream(), + permuted_indices_size, + T, + B, + indices_contig.data_ptr(), + weights_value_contig.data_ptr(), + permute_contig.data_ptr(), + input_offsets.data_ptr(), + output_offsets.data_ptr(), + permuted_indices.data_ptr(), + permuted_weights.data_ptr()); }); // for each weights_t } else { - permute_2D_data_kernel< - false, - offsets_t, - indices_t, - std::nullptr_t> - <<>>( - permuted_indices_size, - T, - B, - indices_contig.data_ptr(), - nullptr, - permute_contig.data_ptr(), - input_offsets.data_ptr(), - output_offsets.data_ptr(), - permuted_indices.data_ptr(), - nullptr); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + FBGEMM_LAUNCH_KERNEL( + (permute_2D_data_kernel< + false, + offsets_t, + indices_t, + std::nullptr_t>), + blocks_2, + threads_2, + 0, + at::cuda::getCurrentCUDAStream(), + permuted_indices_size, + T, + B, + indices_contig.data_ptr(), + nullptr, + permute_contig.data_ptr(), + input_offsets.data_ptr(), + output_offsets.data_ptr(), + permuted_indices.data_ptr(), + nullptr); } }); // for each indices_t }); // for each offsets_t @@ -268,7 +268,7 @@ permute_sparse_features_cuda( constexpr int32_t threads_1 = 256; const auto blocks_1 = - cuda_calc_xblock_count(B * num_output_features, threads_1); + cuda_calc_block_count(B * num_output_features, threads_1); AT_DISPATCH_INDEX_TYPES( lengths.scalar_type(), "permute_2D_lengths_kernel", [&] { FBGEMM_LAUNCH_KERNEL( @@ -305,7 +305,7 @@ permute_sparse_features_cuda( constexpr int32_t BT_blocks = 32; dim3 threads_2(32, BT_blocks); const auto blocks_2 = - cuda_calc_xblock_count(B * num_output_features, BT_blocks); + cuda_calc_block_count(B * num_output_features, BT_blocks); permuted_indices = at::empty(permuted_lengths_sum, indices.options()); if (weights.has_value()) { const Tensor weights_value = weights.value();