Skip to content

Migrate sparse ops kernels to FBGEMM_LAUNCH_KERNEL, pt 5B #4726

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 44 additions & 44 deletions fbgemm_gpu/src/sparse_ops/sparse_permute_2d.cu
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ permute_2D_sparse_data_cuda(
permuted_lengths = at::empty({T, B}, lengths.options());

constexpr int32_t threads_1 = 256;
const auto blocks_1 = cuda_calc_xblock_count(B * T, threads_1);
const auto blocks_1 = cuda_calc_block_count(B * T, threads_1);
AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "permute_2D_lengths_kernel", [&] {
FBGEMM_LAUNCH_KERNEL(
Expand Down Expand Up @@ -134,7 +134,7 @@ permute_2D_sparse_data_cuda(

constexpr int32_t BT_blocks = 32;
dim3 threads_2(32, BT_blocks);
const auto blocks_2 = cuda_calc_xblock_count(B * T, BT_blocks);
const auto blocks_2 = cuda_calc_block_count(B * T, BT_blocks);
permuted_indices = at::empty(permuted_indices_size, indices.options());

AT_DISPATCH_INDEX_TYPES(
Expand All @@ -153,48 +153,48 @@ permute_2D_sparse_data_cuda(
"permute_2D_data_kernel_3",
[&] {
using weights_t = scalar_t;
permute_2D_data_kernel<
true,
offsets_t,
indices_t,
weights_t>
<<<blocks_2,
threads_2,
0,
at::cuda::getCurrentCUDAStream()>>>(
permuted_indices_size,
T,
B,
indices_contig.data_ptr<indices_t>(),
weights_value_contig.data_ptr<weights_t>(),
permute_contig.data_ptr<int32_t>(),
input_offsets.data_ptr<offsets_t>(),
output_offsets.data_ptr<offsets_t>(),
permuted_indices.data_ptr<indices_t>(),
permuted_weights.data_ptr<weights_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
FBGEMM_LAUNCH_KERNEL(
(permute_2D_data_kernel<
true,
offsets_t,
indices_t,
weights_t>),
blocks_2,
threads_2,
0,
at::cuda::getCurrentCUDAStream(),
permuted_indices_size,
T,
B,
indices_contig.data_ptr<indices_t>(),
weights_value_contig.data_ptr<weights_t>(),
permute_contig.data_ptr<int32_t>(),
input_offsets.data_ptr<offsets_t>(),
output_offsets.data_ptr<offsets_t>(),
permuted_indices.data_ptr<indices_t>(),
permuted_weights.data_ptr<weights_t>());
}); // for each weights_t
} else {
permute_2D_data_kernel<
false,
offsets_t,
indices_t,
std::nullptr_t>
<<<blocks_2,
threads_2,
0,
at::cuda::getCurrentCUDAStream()>>>(
permuted_indices_size,
T,
B,
indices_contig.data_ptr<indices_t>(),
nullptr,
permute_contig.data_ptr<int32_t>(),
input_offsets.data_ptr<offsets_t>(),
output_offsets.data_ptr<offsets_t>(),
permuted_indices.data_ptr<indices_t>(),
nullptr);
C10_CUDA_KERNEL_LAUNCH_CHECK();
FBGEMM_LAUNCH_KERNEL(
(permute_2D_data_kernel<
false,
offsets_t,
indices_t,
std::nullptr_t>),
blocks_2,
threads_2,
0,
at::cuda::getCurrentCUDAStream(),
permuted_indices_size,
T,
B,
indices_contig.data_ptr<indices_t>(),
nullptr,
permute_contig.data_ptr<int32_t>(),
input_offsets.data_ptr<offsets_t>(),
output_offsets.data_ptr<offsets_t>(),
permuted_indices.data_ptr<indices_t>(),
nullptr);
}
}); // for each indices_t
}); // for each offsets_t
Expand Down Expand Up @@ -268,7 +268,7 @@ permute_sparse_features_cuda(

constexpr int32_t threads_1 = 256;
const auto blocks_1 =
cuda_calc_xblock_count(B * num_output_features, threads_1);
cuda_calc_block_count(B * num_output_features, threads_1);
AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "permute_2D_lengths_kernel", [&] {
FBGEMM_LAUNCH_KERNEL(
Expand Down Expand Up @@ -305,7 +305,7 @@ permute_sparse_features_cuda(
constexpr int32_t BT_blocks = 32;
dim3 threads_2(32, BT_blocks);
const auto blocks_2 =
cuda_calc_xblock_count(B * num_output_features, BT_blocks);
cuda_calc_block_count(B * num_output_features, BT_blocks);
permuted_indices = at::empty(permuted_lengths_sum, indices.options());
if (weights.has_value()) {
const Tensor weights_value = weights.value();
Expand Down
Loading