diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 03b4004..6d604c3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1,6 +1,6 @@ # ATen XPU sources -file(GLOB device_cpp "sycl/*.cpp" "sycl/*.sycl") +file(GLOB device_cpp "sycl/*.cpp" "sycl/*.sycl" "sycl/kernels/moe/*.cpp") file(GLOB host_cpp "./*.cpp" "./*.cc") list(APPEND ATen_XPU_CPP_SRCS ${host_cpp}) diff --git a/src/sycl/TripleOps.cpp b/src/sycl/kernels/moe/activations.cpp similarity index 98% rename from src/sycl/TripleOps.cpp rename to src/sycl/kernels/moe/activations.cpp index 3549f6b..ba81d6e 100644 --- a/src/sycl/TripleOps.cpp +++ b/src/sycl/kernels/moe/activations.cpp @@ -10,9 +10,9 @@ #include #include -#include "MemoryAccess.h" -#include "SYCLHelpers.h" -#include "Utils.h" +#include "../../MemoryAccess.h" +#include "../../SYCLHelpers.h" +#include "../../Utils.h" #define DPCPP_CONSTANT __attribute__((opencl_constant)) @@ -232,3 +232,4 @@ void gelu_and_mul(at::Tensor& out, at::Tensor& input) { } return; } + diff --git a/src/sycl/kernels/moe/prepare_inputs.cpp b/src/sycl/kernels/moe/prepare_inputs.cpp new file mode 100644 index 0000000..4c46595 --- /dev/null +++ b/src/sycl/kernels/moe/prepare_inputs.cpp @@ -0,0 +1,518 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../../MemoryAccess.h" +#include "../../SYCLHelpers.h" +#include "../../Utils.h" + +constexpr int64_t THREADS_PER_EXPERT = 512; +constexpr int block_size = 128; + +struct compute_problem_sizes_sycl_K { + compute_problem_sizes_sycl_K( + const int* topk_ids, + int32_t* problem_sizes1, + int32_t* problem_sizes2, + int32_t* atomic_buffer, + const uint32_t num_experts, + const uint32_t topk_length, + const uint32_t n, + const uint32_t k, + const uint32_t max_tokens_per_expert) + : topk_ids_(topk_ids), + problem_sizes1_(problem_sizes1), + problem_sizes2_(problem_sizes2), + atomic_buffer_(atomic_buffer), + num_experts_(num_experts), + topk_length_(topk_length), + n_(n), + k_(k), + max_tokens_per_expert_(max_tokens_per_expert) {} + + void operator()(sycl::nd_item<1> item) const { + int thread_id = item.get_local_linear_id(); + if (thread_id < topk_length_) { + int expert_id = item.get_group(0); + + int occurrences = 0; + for (int i = thread_id; i < topk_length_; i += max_tokens_per_expert_) { + occurrences += (topk_ids_[thread_id] == expert_id); + } + + sycl::atomic_ref< + int32_t, + sycl::memory_order::relaxed, + sycl::memory_scope::work_group, + sycl::access::address_space::generic_space + > atomic_counter(atomic_buffer_[expert_id]); + + atomic_counter.fetch_add(occurrences); + + item.barrier(sycl::access::fence_space::local_space); + + if (thread_id == 0) { + int final_occurrences = atomic_buffer_[expert_id]; + problem_sizes1_[expert_id * 3] = final_occurrences; + problem_sizes1_[expert_id * 3 + 1] = static_cast(2 * n_); + problem_sizes1_[expert_id * 3 + 2] = static_cast(k_); + problem_sizes2_[expert_id * 3] = final_occurrences; + problem_sizes2_[expert_id * 3 + 1] = static_cast(k_); + problem_sizes2_[expert_id * 3 + 2] = static_cast(n_); + } + } + } + + const int* topk_ids_; + int32_t* problem_sizes1_; + int32_t* problem_sizes2_; + int32_t* atomic_buffer_; + const uint32_t num_experts_; + const uint32_t topk_length_; + const uint32_t n_; + const uint32_t k_; + const uint32_t max_tokens_per_expert_; +}; + + +void compute_problem_sizes_sycl( + const int* topk_ids, + int32_t* problem_sizes1, + int32_t* problem_sizes2, + int32_t* atomic_buffer, + const uint32_t num_experts, + const 
uint32_t topk_length, + const uint32_t n, + const uint32_t k) { + + auto stream = at::xpu::getCurrentXPUStream(); + auto queue = stream.queue(); + + using Kernel = compute_problem_sizes_sycl_K; + + auto dev_id = dpcppGetDeviceIdOfCurrentQueue(); + uint32_t max_wg_size = dpcppMaxWorkGroupSize(dev_id); + uint32_t max_tokens_per_expert = static_cast(sycl::min(max_wg_size, topk_length)); + + sycl::range<1> global_range{ num_experts * max_tokens_per_expert }; + sycl::range<1> local_range{ max_tokens_per_expert }; + + Kernel task(topk_ids, problem_sizes1, problem_sizes2, atomic_buffer, num_experts, topk_length, n, k, max_tokens_per_expert); + + sycl_kernel_submit(global_range, local_range, queue, task); + return; +} + + +struct compute_expert_offsets_sycl_k { + compute_expert_offsets_sycl_k( + const int32_t* problem_sizes1, + int32_t* expert_offsets, + int32_t* atomic_buffer, + const uint32_t num_experts) + : problem_sizes1_(problem_sizes1), + expert_offsets_(expert_offsets), + atomic_buffer_(atomic_buffer), + num_experts_(num_experts) {} + + void operator()(sycl::nd_item<1> item) const { + uint32_t tot_offset = 0; + expert_offsets_[0] = 0; + for (int i = 0; i < num_experts_; ++i) { + atomic_buffer_[i] = tot_offset; + tot_offset += problem_sizes1_[i * 3]; + expert_offsets_[i + 1] = tot_offset; + } + } + + const int32_t* problem_sizes1_; + int32_t* expert_offsets_; + int32_t* atomic_buffer_; + const uint32_t num_experts_; +}; + +void compute_expert_offsets_sycl( + const int32_t* problem_sizes1, + int32_t* expert_offsets, + int32_t* atomic_buffer, + const uint32_t num_experts) { + + auto stream = at::xpu::getCurrentXPUStream(); + auto queue = stream.queue(); + + using Kernel = compute_expert_offsets_sycl_k; + + Kernel task(problem_sizes1, expert_offsets, atomic_buffer, num_experts); + + sycl_kernel_submit(1, 1, queue, task); + return; +} + +struct compute_expert_blockscale_offsets_sycl_K { + compute_expert_blockscale_offsets_sycl_K( + const int32_t* problem_sizes1, + int32_t* expert_offsets, + int32_t* blockscale_offsets, + int32_t* atomic_buffer, + const uint32_t num_experts) + : problem_sizes1_(problem_sizes1), + expert_offsets_(expert_offsets), + blockscale_offsets_(blockscale_offsets), + atomic_buffer_(atomic_buffer), + num_experts_(num_experts) {} + + void operator()(sycl::nd_item<1> item) const { + int32_t tot_offset = 0; + int32_t tot_rounded_offset = 0; + expert_offsets_[0] = 0; + blockscale_offsets_[0] = 0; + for (int i = 0; i < num_experts_; ++i) { + atomic_buffer_[i] = tot_offset; + int num_tokens = problem_sizes1_[i * 3]; + int rounded_num_tokens = (num_tokens + (block_size - 1)) / block_size * block_size; + tot_offset += num_tokens; + tot_rounded_offset += rounded_num_tokens; + expert_offsets_[i + 1] = tot_offset; + blockscale_offsets_[i + 1] = tot_rounded_offset; + } + } + + const int32_t* problem_sizes1_; + int32_t* expert_offsets_; + int32_t* blockscale_offsets_; + int32_t* atomic_buffer_; + const uint32_t num_experts_; +}; + +void compute_expert_blockscale_offsets_sycl( + const int32_t* problem_sizes1, + int32_t* expert_offsets, + int32_t* blockscale_offsets, + int32_t* atomic_buffer, + const int64_t num_experts) { + + auto stream = at::xpu::getCurrentXPUStream(); + auto queue = stream.queue(); + + using Kernel = compute_expert_blockscale_offsets_sycl_K; + + Kernel task(problem_sizes1, expert_offsets, blockscale_offsets, atomic_buffer, num_experts); + + sycl_kernel_submit(1, 1, queue, task); + return; +} + + +struct compute_arg_sorts_sycl_K { + compute_arg_sorts_sycl_K( + const 
int32_t* topk_ids, + int32_t* input_permutation, + int32_t* output_permutation, + int32_t* atomic_buffer, + const uint32_t topk_length, + const uint32_t topk, + const uint32_t num_experts, + const uint32_t max_tokens_per_expert) + : topk_ids_(topk_ids), + input_permutation_(input_permutation), + output_permutation_(output_permutation), + atomic_buffer_(atomic_buffer), + topk_length_(topk_length), + topk_(topk), + num_experts_(num_experts), + max_tokens_per_expert_(max_tokens_per_expert) {} + + void operator()(sycl::nd_item<1> item) const { + int expert_id = item.get_group(0); + + sycl::atomic_ref< + int32_t, + sycl::memory_order::relaxed, + sycl::memory_scope::work_group, + sycl::access::address_space::generic_space + > atomic_counter(atomic_buffer_[expert_id]); + + for (int i = item.get_local_id(0); i < topk_length_; i += max_tokens_per_expert_) { + if (topk_ids_[i] == expert_id) { + int start = atomic_counter.fetch_add(1); + input_permutation_[start] = i / topk_; + output_permutation_[i] = start; + } + } + } + + const int32_t* topk_ids_; + int32_t* input_permutation_; + int32_t* output_permutation_; + int32_t* atomic_buffer_; + const uint32_t topk_length_; + const uint32_t topk_; + const uint32_t num_experts_; + const uint32_t max_tokens_per_expert_; +}; + +void compute_arg_sorts_sycl( + const int32_t* topk_ids, + int32_t* input_permutation, + int32_t* output_permutation, + int32_t* atomic_buffer, + const uint32_t topk_length, + const uint32_t topk, + const uint32_t num_experts) { + + + auto stream = at::xpu::getCurrentXPUStream(); + auto queue = stream.queue(); + + using Kernel = compute_arg_sorts_sycl_K; + + auto dev_id = dpcppGetDeviceIdOfCurrentQueue(); + uint32_t max_wg_size = dpcppMaxWorkGroupSize(dev_id); + uint32_t max_tokens_per_expert = static_cast(sycl::min(max_wg_size, topk_length)); + + sycl::range<1> global_range{ num_experts * max_tokens_per_expert }; + sycl::range<1> local_range{ max_tokens_per_expert }; + + Kernel task(topk_ids, input_permutation, output_permutation, atomic_buffer, topk_length, topk, num_experts, max_tokens_per_expert); + + sycl_kernel_submit(global_range, local_range, queue, task); + return; + +} + +void prepare_moe_input( + const torch::Tensor& topk_ids, + torch::Tensor& expert_offsets, + const std::optional& blockscale_offsets, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, + torch::Tensor& output_permutation, + const int64_t num_experts, + const int64_t n, + const int64_t k) { + TORCH_CHECK(topk_ids.dtype() == torch::kInt32); + + auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device()); + torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32); + + compute_problem_sizes_sycl( + static_cast(topk_ids.data_ptr()), + static_cast(problem_sizes1.data_ptr()), + static_cast(problem_sizes2.data_ptr()), + static_cast(atomic_buffer.data_ptr()), + num_experts, + topk_ids.numel(), + n, + k); + + if (blockscale_offsets.has_value()) { + compute_expert_blockscale_offsets_sycl( + static_cast(problem_sizes1.data_ptr()), + static_cast(expert_offsets.data_ptr()), + static_cast(blockscale_offsets.value().data_ptr()), + static_cast(atomic_buffer.data_ptr()), + num_experts); + } else { + compute_expert_offsets_sycl( + static_cast(problem_sizes1.data_ptr()), + static_cast(expert_offsets.data_ptr()), + static_cast(atomic_buffer.data_ptr()), + num_experts); + } + + compute_arg_sorts_sycl( + static_cast(topk_ids.data_ptr()), + static_cast(input_permutation.data_ptr()), + 
static_cast(output_permutation.data_ptr()), + static_cast(atomic_buffer.data_ptr()), + topk_ids.numel(), + topk_ids.size(1), + num_experts); + + return; +} + +template +struct ShuffleRows { + ShuffleRows( + const T* input, + const int32_t* dst2src_map, + T* output, + int64_t num_src_rows, + int64_t num_dest_rows, + int64_t num_cols) + : input_(input), + dst2src_map_(dst2src_map), + output_(output), + num_src_rows_(num_src_rows), + num_dest_rows_(num_dest_rows), + num_cols_(num_cols) {} + + void operator()(sycl::nd_item<1> item) const { + int gid = item.get_global_linear_id(); + int tid = item.get_local_linear_id(); + // Leave it to compiler for simd sub-group + if (gid < num_dest_rows_ * num_cols_) { + int64_t dest_token_idx = item.get_group(0); + int64_t const source_token_idx = dst2src_map_[dest_token_idx]; + + auto source_val = input_[source_token_idx * num_cols_ + tid]; + output_[dest_token_idx * num_cols_ + tid] = source_val; + } + } + const T* input_; + const int32_t* dst2src_map_; + T* output_; + int64_t num_src_rows_; + int64_t num_dest_rows_; + int64_t num_cols_; +}; + +template +void shuffle_rows_kernel_impl(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor) { + auto input = reinterpret_cast(input_tensor.data_ptr()); + auto dst2srcmap = reinterpret_cast(dst2src_map.data_ptr()); + auto output = reinterpret_cast(output_tensor.data_ptr()); + + int64_t num_src_rows = input_tensor.size(0); + unsigned long num_dest_rows = output_tensor.size(0); + unsigned long num_cols = input_tensor.size(1); + + auto stream = at::xpu::getCurrentXPUStream(); + auto queue = stream.queue(); + + using Kernel = ShuffleRows; + sycl::range<1> global_range{ num_dest_rows * num_cols }; + sycl::range<1> local_range{ num_cols }; + + Kernel task(input, dst2srcmap, output, num_src_rows, num_dest_rows, num_cols); + + sycl_kernel_submit(global_range, local_range, queue, task); + return; + +} + +void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor) { + TORCH_CHECK( + input_tensor.scalar_type() == output_tensor.scalar_type(), + "Input and output tensors must have the same data type"); + SYCL_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::BFloat16, at::ScalarType::Half, input_tensor.scalar_type(), "shuffle_rows_kernel_impl", [&]() { + shuffle_rows_kernel_impl(input_tensor, dst2src_map, output_tensor); + }); + return; +} + +template +struct ApplyShuffleMulSum { + ApplyShuffleMulSum( + const T* input, + T* output, + const int32_t* dst2src_map, + const T1* factors, + const int64_t topk, + const int64_t hidden_dim) + : input_(input), + output_(output), + dst2src_map_(dst2src_map), + factors_(factors), + topk_(topk), + hidden_dim_(hidden_dim) {} + + void operator()(sycl::nd_item<1> item) const { + + int out_tkn_id = item.get_group(0); + float sum_val = 0; + + for (int k = 0; k < topk_; ++k) { + int src_perm_offset = out_tkn_id * topk_ + k; + int src_index = dst2src_map_[src_perm_offset]; + T src_val = input_[src_index * hidden_dim_ + item.get_local_id(0)]; + T1 weight = 0; + if (factors_ != nullptr) { + weight = factors_[out_tkn_id * topk_ + k]; + } + sum_val += weight * src_val; + } + output_[out_tkn_id * hidden_dim_ + item.get_local_id(0)] = sum_val; + } + const T* input_; + const int32_t* dst2src_map_; + T* output_; + const T1* factors_; + const unsigned long topk_; + const unsigned long hidden_dim_; +}; + +template +void apply_shuffle_mul_sum_impl( + const T* input, + T* output, + const int32_t* dst2src_map, + 
const T1* factors, + const unsigned long out_tkns, + const unsigned long out_hidden_dims, + const int topk) { + + auto stream = at::xpu::getCurrentXPUStream(); + auto queue = stream.queue(); + + using Kernel = ApplyShuffleMulSum; + sycl::range<1> global_range{ out_tkns * out_hidden_dims }; + sycl::range<1> local_range{ out_hidden_dims }; + + Kernel task(input, output, dst2src_map, factors, topk, out_hidden_dims); + + sycl_kernel_submit(global_range, local_range, queue, task); + return; + +} + +void apply_shuffle_mul_sum( + const torch::Tensor& input, + torch::Tensor& output, + const torch::Tensor& permutation, + const std::optional& factors) { + int m = output.size(0); + int topk = int(permutation.size(0) / m); + SYCL_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::BFloat16, at::ScalarType::Half, input.scalar_type(), "apply_shuffle_mul_sum", [&]() { + using input_t = scalar_t; + if (factors.has_value()) { + SYCL_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::BFloat16, at::ScalarType::Half, factors.value().scalar_type(), "factors dispatch", [&]() { + using factors_t = scalar_t; + apply_shuffle_mul_sum_impl( + reinterpret_cast(input.data_ptr()), + reinterpret_cast(output.data_ptr()), + reinterpret_cast(permutation.data_ptr()), + reinterpret_cast(factors->data_ptr()), + output.size(0), + output.size(1), + topk + ); + }); + } else { + apply_shuffle_mul_sum_impl( + reinterpret_cast(input.data_ptr()), + reinterpret_cast(output.data_ptr()), + reinterpret_cast(permutation.data_ptr()), + nullptr, + output.size(0), + output.size(1), + topk + ); + } + }); + return; +} diff --git a/src/sycl/kernels/moe/topk_softmax.cpp b/src/sycl/kernels/moe/topk_softmax.cpp new file mode 100644 index 0000000..2de00a0 --- /dev/null +++ b/src/sycl/kernels/moe/topk_softmax.cpp @@ -0,0 +1,267 @@ +#include +#include +#include + +#include + +#include "../../SYCLHelpers.h" +#include "../../Utils.h" + +namespace at::native::xpu { + +namespace TopKSoftmaxImpl { + +template +struct FusedTopkSoftmax { + static constexpr int sub_group_size = 32; + static constexpr int max_group_size = 1024; + static constexpr int malloc_per_item = 8; + static constexpr float kNegInfinity = -std::numeric_limits::infinity(); + + FusedTopkSoftmax( + float* topk_weights, + int* topk_ids, + const T* gating_output, + const bool renormalize, + const int tokens, + const int experts, + const int top_k) + : topk_weights(topk_weights), + topk_ids(topk_ids), + gating_output(gating_output), + renormalize(renormalize), + tokens(tokens), + experts(experts), + top_k(top_k) {} + + static inline sycl::nd_range<3> get_nd_range(const int tokens, const int experts) { + int calc_per_item = div_up(experts, sub_group_size); + int group_size = div_up(experts, calc_per_item); + group_size = group_size < sub_group_size ? sub_group_size : group_size; + group_size = group_size < max_group_size ? 
group_size : max_group_size; + int sub_groups_per_group = div_up(group_size, sub_group_size); + group_size = sub_groups_per_group * sub_group_size; + int global_size = div_up(tokens, sub_groups_per_group); + + sycl::range<3> local(1, 1, group_size); + sycl::range<3> global(1, 1, global_size); + return sycl::nd_range<3>(global * local, local); + } + + static inline T Sigmoid(T x) { + float sycl_x = static_cast(x); + float result = 1.0f / (1.0f + sycl::exp(-sycl_x)); + return static_cast(result); + } + + [[sycl::reqd_sub_group_size(sub_group_size)]] void operator()(sycl::nd_item<3> item) const { + int group_id = item.get_group_linear_id(); + int local_range = item.get_local_range(2); + int sub_groups_per_group = local_range / sub_group_size; + int calc_per_item = (experts + sub_group_size - 1) / sub_group_size; + + sycl::sub_group sg = item.get_sub_group(); + int sg_id = sg.get_group_id(); + int sg_local_id = sg.get_local_id(); + + int tid = group_id * sub_groups_per_group + sg_id; + + if (tid >= tokens) { + return; // Out of bounds + } + + T local_elems[malloc_per_item]; + int local_idx[malloc_per_item]; + + int start_offset = sg_local_id * calc_per_item; + int local_num = calc_per_item; + + if (start_offset + local_num >= experts) { + local_num = experts - start_offset; + if (local_num < 0) { + local_num = 0; // No elements to process + } + } + + for (int e = 0; e < calc_per_item; ++e) { + local_elems[e] = kNegInfinity; + local_idx[e] = -1; + } + + for (int e = 0; e < local_num; ++e) { + local_elems[e] = gating_output[tid * experts + start_offset + e]; + local_idx[e] = start_offset + e; + } + + // Perform top-k selection + T topk_weights_local[malloc_per_item]; + int topk_ids_local[malloc_per_item]; + + for (int k = 0; k < top_k; ++k) { + T k_max = kNegInfinity; + int k_max_idx = -1; + int remove_ix = -1; + for (int e = 0; e < calc_per_item; ++e) { + T my_val = local_elems[e]; + int my_idx = local_idx[e]; + for (int offset = sub_group_size / 2; offset > 0; offset /= 2) { + T other_val = sycl::permute_group_by_xor(sg, my_val, offset); + int other_idx = sycl::permute_group_by_xor(sg, my_idx, offset); + if (other_val > my_val || (other_val == my_val && other_idx < my_idx)) { + my_val = other_val; + my_idx = other_idx; + } + } + if (my_val > k_max || (my_val == k_max && my_idx < k_max_idx)) { + k_max = my_val; + k_max_idx = my_idx; + + if (k_max_idx == local_idx[e]) { + remove_ix = e; // Mark this index for removal + } else + remove_ix = -1; + } + } + topk_weights_local[k] = k_max; + topk_ids_local[k] = k_max_idx; + if (remove_ix != -1) { + // Reset the score to avoid re-selection + local_elems[remove_ix] = kNegInfinity; + local_idx[remove_ix] = -1; + remove_ix = -1; + } + } + + float max_score = topk_weights_local[0]; + float sum_exp = 0; + + for (int i = 0; i < top_k; ++i) { + float score = topk_weights_local[i]; + sum_exp += sycl::exp(score - max_score); + } + + for (int e = 0; e < calc_per_item; ++e) { + float score = local_elems[e]; + float my_val = sycl::exp(score - max_score); + for (int offset = sub_group_size / 2; offset > 0; offset /= 2) { + float other_val = sycl::permute_group_by_xor(sg, my_val, offset); + my_val += other_val; + } + sum_exp += my_val; + } + + for (int i = 0; i < top_k; ++i) { + float score = topk_weights_local[i]; + topk_weights_local[i] = sycl::exp(score - max_score) / sum_exp; + } + + if (renormalize) { + // Renormalize the top-k weights + float sum = 0; + for (int i = 0; i < top_k; ++i) { + sum += topk_weights_local[i]; + } + if (sum > 0) { + for (int i = 0; i < 
top_k; ++i) { + topk_weights_local[i] /= sum; + } + } + } + + if (sg_local_id == 0) { + int offset = tid * top_k; + for (int i = 0; i < top_k; ++i) { + topk_weights[offset + i] = topk_weights_local[i]; + if (topk_ids_local[i] < 0 || topk_ids_local[i] >= experts) { + // Ensure valid index + topk_ids[offset + i] = 0; + continue; + } + topk_ids[offset + i] = topk_ids_local[i]; + } + } + } + float* topk_weights; + int* topk_ids; + const T* gating_output; + const bool renormalize; + const int tokens; + const int experts; + const int top_k; +}; + +template +void launch_fused_topk_softmax( + sycl::queue& queue, + const T* gating_output, + float* topk_weights, + int* topk_indices, + const bool renormalize, + const int top_k, + const int num_tokens, + const int num_experts) { + using Kernel = FusedTopkSoftmax; + auto range = Kernel::get_nd_range(num_tokens, num_experts); + + auto global_range = range.get_global_range(); + auto local_range = range.get_local_range(); + + Kernel task(topk_weights, topk_indices, gating_output, renormalize, num_tokens, num_experts, top_k); + + sycl_kernel_submit(global_range, local_range, queue, task); + return; +} + +template +void fused_topk_softmax( + const T* gating_output, + float* topk_weights, + int* topk_indices, + const bool renormalize, + const int num_tokens, + const int num_experts, + const int topk) { + auto stream = at::xpu::getCurrentXPUStream(); + auto queue = stream.queue(); + + launch_fused_topk_softmax( + queue, gating_output, topk_weights, topk_indices, renormalize, topk, num_tokens, num_experts); +}; +}; // namespace TopKSoftmaxImpl + +/** + * @brief Perform topk after softmax on gating_output. + * @param topk_weights The topk_weights tensor of shape [n_tokens, n_topk]. + * @param topk_indices The topk_indices tensor of shape [n_tokens, n_topk]. + * @param gating_output The gating output tensor of shape [n_tokens, n_experts]. + * @param renormalize The renormalize bool whether the topk_weights needs to be renormalized. + * @return void. + */ +void topk_softmax(at::Tensor& topk_weights, at::Tensor& topk_indices, at::Tensor& gating_output, bool renormalize) { + auto shape = gating_output.sizes().vec(); + TORCH_CHECK(shape.size() == 2, "gating_output must be 2D tensor, but got ", shape.size(), "D"); + int64_t n_tokens = shape[0]; + int64_t n_experts = shape[1]; + + TORCH_CHECK(n_experts <= 128, "n_experts only support up to 128, but got ", n_experts); + + TORCH_CHECK(topk_weights.scalar_type() == at::kFloat, "topk_weights should be Float"); + TORCH_CHECK(topk_indices.scalar_type() == at::kInt, "topk_indices should be Int"); + + constexpr int64_t alignment = 8; + int64_t n_experts_aligned = div_up(n_experts, alignment) * alignment; // align to 8 + + int64_t n_topk = topk_weights.size(1); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(gating_output.scalar_type(), "fused_topk_softmax_kernel", [&]() { + TopKSoftmaxImpl::fused_topk_softmax( + gating_output.data_ptr(), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + renormalize, + n_tokens, + n_experts, + n_topk); + }); +} +} // namespace at::native::xpu diff --git a/src/torch_extension_sycl.cc b/src/torch_extension_sycl.cc index e826b35..aab34c6 100644 --- a/src/torch_extension_sycl.cc +++ b/src/torch_extension_sycl.cc @@ -54,6 +54,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { m.impl("rotary_embedding", torch::kXPU, &at::native::xpu::rotary_embedding); m.def( + "prepare_moe_input(Tensor topk_ids, Tensor expert_offsets, Tensor? 
blockscale_offsets, Tensor problem_sizes1,"
+      " Tensor problem_sizes2, Tensor input_permutation, Tensor output_permutation, int num_experts, int n, int k) -> "
+      "()");
+  m.impl("prepare_moe_input", torch::kXPU, &prepare_moe_input);
+  m.def("shuffle_rows(Tensor input, Tensor dst2src_map, Tensor output) -> ()");
+  m.impl("shuffle_rows", torch::kXPU, &shuffle_rows);
+  m.def("apply_shuffle_mul_sum(Tensor input, Tensor output, Tensor permutation, Tensor? factors) -> ()");
+  m.impl("apply_shuffle_mul_sum", torch::kXPU, &apply_shuffle_mul_sum);
+  m.def(
      "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
      "experts_ids, Tensor! num_tokens_post_pad, Tensor! cumsum_buffer, bool "
      "pad_sorted_token_ids) -> ()");
diff --git a/tests/test_moe_prepare_input.py b/tests/test_moe_prepare_input.py
new file mode 100644
index 0000000..3fce79b
--- /dev/null
+++ b/tests/test_moe_prepare_input.py
@@ -0,0 +1,96 @@
+import itertools
+
+import pytest
+import torch
+from sgl_kernel import prepare_moe_input, shuffle_rows, apply_shuffle_mul_sum
+
+@pytest.mark.parametrize("num_tokens", [5, 16, 128])
+@pytest.mark.parametrize("num_experts", [4, 8, 32])
+@pytest.mark.parametrize("top_k", [2])
+@pytest.mark.parametrize("hidden_dims", [16, 32, 64])
+def test_prepare_input_moe(num_tokens, num_experts, top_k, hidden_dims):
+    torch.manual_seed(41)
+    # Generate unique top-k expert ids for each token
+    def generate_unique_topk_ids(tokens, top_k, num_experts):
+        topk_ids = torch.empty((tokens, top_k), dtype=torch.int32)
+        # avoid duplicate expert ids within a token
+        for T in range(tokens):
+            topk_ids[T] = torch.randperm(num_experts, dtype=torch.int32)[:top_k]
+        return topk_ids
+
+    def prepare_input_moe_ref(topk_ids, expert_offsets, blockscale_offsets, problem_sizes1, problem_sizes2, input_permutation, output_permutation, num_experts, hidden_dim, top_k):
+        tokens, top_k = topk_ids.shape
+        expert_cnt = torch.zeros(num_experts, dtype=torch.int32)
+        for e in range(num_experts):
+            expert_cnt[e] = (topk_ids == e).sum()
+
+        for e in range(num_experts):
+            r = expert_cnt[e].item()
+            c = hidden_dim
+            problem_sizes1[e * 3 + 0] = r
+            problem_sizes1[e * 3 + 1] = c * 2
+            problem_sizes1[e * 3 + 2] = top_k
+            problem_sizes2[e * 3 + 0] = r
+            problem_sizes2[e * 3 + 1] = top_k
+            problem_sizes2[e * 3 + 2] = c
+
+        # compute expert offsets
+        atomic_buffer = torch.zeros(num_experts, dtype=torch.int32)
+        tot_offset = 0
+        expert_offsets[0] = 0
+        for i in range(num_experts):
+            atomic_buffer[i] = tot_offset
+            tot_offset += problem_sizes1[i * 3].item()
+            expert_offsets[i + 1] = tot_offset
+
+        # compute input/output permutations
+        num_tokens = topk_ids.size(0)
+        flat_topk = topk_ids.flatten()
+        topk_length = num_tokens * top_k
+
+        for i in range(topk_length):
+            expert_id = int(flat_topk[i])
+            start = int(atomic_buffer[expert_id].item())
+            atomic_buffer[expert_id] += 1
+
+            input_permutation[start] = i // top_k
+            output_permutation[i] = start
+
+    # routing that assigns each token to top_k distinct experts
+    topk_ids = generate_unique_topk_ids(num_tokens, top_k, num_experts)
+    expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int32)
+    problem_sizes1 = torch.zeros(num_experts * 3, dtype=torch.int32)
+    problem_sizes2 = torch.zeros(num_experts * 3, dtype=torch.int32)
+
+    flat_topk = topk_ids.flatten()
+    input_permutation = torch.empty_like(flat_topk)
+    output_permutation = torch.empty_like(flat_topk)
+    blockscale_offset = None
+
+    device = "xpu"
+    topk_ids_xpu = topk_ids.clone().to(device)
+    expert_offsets_xpu = expert_offsets.clone().to(device)
+    problem_sizes1_xpu = problem_sizes1.clone().to(device)
+    problem_sizes2_xpu = problem_sizes2.clone().to(device)
+    input_permutation_xpu = torch.empty_like(flat_topk).to(device)
+    output_permutation_xpu = torch.empty_like(flat_topk).to(device)
+    # generate reference offsets and permutations on CPU
+    prepare_input_moe_ref(topk_ids, expert_offsets, blockscale_offset, problem_sizes1, problem_sizes2, input_permutation, output_permutation, num_experts, hidden_dims, top_k)
+    # prepare MoE inputs on XPU
+    prepare_moe_input(topk_ids_xpu, expert_offsets_xpu, problem_sizes1_xpu, problem_sizes2_xpu, input_permutation_xpu, output_permutation_xpu, num_experts, hidden_dims, top_k, blockscale_offset)
+    # validate expert offsets and per-expert problem sizes
+    torch.testing.assert_close(expert_offsets, expert_offsets_xpu.to("cpu"))
+    torch.testing.assert_close(problem_sizes1, problem_sizes1_xpu.to("cpu"))
+    torch.testing.assert_close(problem_sizes2, problem_sizes2_xpu.to("cpu"))
+
+    input_tensor = torch.randn(num_tokens, hidden_dims, dtype=torch.float32)
+    input_tensor_xpu = input_tensor.clone().to(device)
+    output_tensor_xpu = shuffle_rows(input_tensor_xpu, input_permutation_xpu, (num_tokens * top_k, hidden_dims))
+
+    input_merge_xpu = torch.empty((num_tokens, hidden_dims), dtype=torch.float32, device=device)
+    # apply a weight of 0.5 to every top-k entry
+    factors = torch.ones(top_k * num_tokens, dtype=torch.float32, device=device).fill_(0.5)
+    apply_shuffle_mul_sum(output_tensor_xpu, input_merge_xpu, output_permutation_xpu, factors)
+    # with factors = 0.5 and top_k = 2, the weighted sum restores the input rows in their original order
+    torch.testing.assert_close(input_merge_xpu.to("cpu"), input_tensor)
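
Note for reviewers: the semantics of the two new data-movement ops can be expressed in a few lines of plain PyTorch. This is only a reference sketch mirroring the ShuffleRows and ApplyShuffleMulSum kernels above; the helper names ref_shuffle_rows and ref_apply_shuffle_mul_sum are illustrative and are not part of sgl_kernel.

import torch

def ref_shuffle_rows(src: torch.Tensor, dst2src_map: torch.Tensor) -> torch.Tensor:
    # Each destination row copies the source row selected by dst2src_map,
    # i.e. output[dst, :] = input[dst2src_map[dst], :].
    return src[dst2src_map.long()]

def ref_apply_shuffle_mul_sum(src: torch.Tensor, permutation: torch.Tensor,
                              factors: torch.Tensor, topk: int) -> torch.Tensor:
    # For output token t: sum over k of factors[t * topk + k] * src[permutation[t * topk + k], :].
    num_tokens = permutation.numel() // topk
    gathered = src[permutation.long()].view(num_tokens, topk, -1)
    weights = factors.view(num_tokens, topk, 1).to(gathered.dtype)
    return (weights * gathered).sum(dim=1)

For the test above, ref_apply_shuffle_mul_sum(output_tensor_xpu.cpu(), output_permutation_xpu.cpu(), factors.cpu(), top_k) is expected to match input_merge_xpu.cpu(), assuming the Python wrappers forward the tensors to the kernels unchanged.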