diff --git a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp index 3688a4536ae89a..bd6738d37f4ac9 100644 --- a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp @@ -291,10 +291,9 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr REGISTER_PASS(manager, ConstantFolding) REGISTER_PASS(manager, SymbolicOptimizations) REGISTER_PASS(manager, ResolveNameCollisions, true); - // todo: enable after plugin support for MoE - // Remove pytestmark to enable e2e test: + // TODO: Remove pytestmark to enable e2e test: // tests/model_hub_tests/transformation_tests/test_moe_transformation.py - // REGISTER_PASS(manager, FuseMOE) + REGISTER_PASS(manager, FuseMOE) REGISTER_PASS(manager, VectorizedMOE2GEMMTransposeWeights) manager.run_passes(f); diff --git a/src/plugins/intel_gpu/include/intel_gpu/op/moe_3gemm_fused_compressed.hpp b/src/plugins/intel_gpu/include/intel_gpu/op/moe_3gemm_fused_compressed.hpp new file mode 100644 index 00000000000000..d24ea8ca0d8118 --- /dev/null +++ b/src/plugins/intel_gpu/include/intel_gpu/op/moe_3gemm_fused_compressed.hpp @@ -0,0 +1,46 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "intel_gpu/op/moe_compressed.hpp" + +namespace ov::intel_gpu::op { + +/// \brief MOE3GemmFusedCompressed that support compressed and fused MOE for GEMM3_SWIGLU. +class MOE3GemmFusedCompressed : public MOECompressed { +public: + OPENVINO_OP("MOE3GemmFusedCompressed", "gpu_opset", MOECompressed); + + MOE3GemmFusedCompressed() = default; + + /// \brief Constructs a MOE3GemmFusedCompressed operation with config only + /// \param args The input tensors, in the following order: + /// 0: hidden_states - input tensor with hidden representations + /// 1: routing_weights - [num_seq, num_experts] routing weights for all experts + /// 2: w0_weight - expert weights for first projection, + /// shape [num_experts, inter_size, group_num, group_size] + /// 3: w0_scale - expert scale for first projection for compressed experts, + /// shape [num_experts, inter_size, group_num, 1] + /// 4: w0_zp - expert zp for first projection for compressed experts, + /// shape [num_experts, inter_size, group_num, 1] + /// 5: w1_weight - expert weights for second projection, + /// shape [num_experts, inter_size, group_num, group_size] + /// 6: w1_scale - expert scale for second projection for compressed experts, + /// shape [num_experts, inter_size, group_num, 1] + /// 7: w1_zp - expert zp for second projection for compressed experts, + /// shape [num_experts, inter_size, group_num, 1] + /// 8: w2_weight - expert weights for final projection, + /// shape [num_experts, hidden_size, group_num, group_size] + /// 9: w2_scale - expert scale for final projection for compressed experts, + /// shape [num_experts, hidden_size, group_num, 1] + /// 10: w2_zp - expert zp for final projection for compressed experts, + /// shape [num_experts, hidden_size, group_num, 1] + /// \param config Configuration for the MOE 3GEMM SWIGLU fused operation + MOE3GemmFusedCompressed(const OutputVector& args, const MOECompressed::Config config); + + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; +}; + +} // namespace ov::intel_gpu::op diff --git 
a/src/plugins/intel_gpu/include/intel_gpu/op/moe_compressed.hpp b/src/plugins/intel_gpu/include/intel_gpu/op/moe_compressed.hpp index 5e08e3b52f5271..264c93dc441963 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/op/moe_compressed.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/op/moe_compressed.hpp @@ -15,6 +15,7 @@ class MOECompressed : public ov::op::internal::MOE { OPENVINO_OP("MOECompressed", "gpu_opset", ov::op::internal::MOE); MOECompressed() = default; + MOECompressed(const OutputVector& args) : MOE(args) {} struct Config : public MOE::Config { size_t hidden_size = 0; diff --git a/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp b/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp index f36e22c707411f..d2f0d755a77bec 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp @@ -312,4 +312,5 @@ REGISTER_FACTORY(internal, PagedAttentionExtension); REGISTER_FACTORY(internal, LoraSubgraph); REGISTER_FACTORY(internal, LoraSubgraphFused); REGISTER_FACTORY(internal, VLSDPA); +REGISTER_FACTORY(internal, MOE3GemmFusedCompressed); REGISTER_FACTORY(internal, MOECompressed); diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/moe_3gemm_fused_compressed.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/moe_3gemm_fused_compressed.hpp new file mode 100644 index 00000000000000..80ee073a61c008 --- /dev/null +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/moe_3gemm_fused_compressed.hpp @@ -0,0 +1,72 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once +#include + +#include "intel_gpu/op/moe_3gemm_fused_compressed.hpp" +#include "intel_gpu/runtime/engine.hpp" +#include "primitive.hpp" + +namespace cldnn { +using MOE3GemmFusedCompressed = ov::intel_gpu::op::MOE3GemmFusedCompressed; + +/// @brief moe compressed primitive +/// @details Performs moe compressed +struct moe_3gemm_fused_compressed : public primitive_base { + CLDNN_DECLARE_PRIMITIVE(moe_3gemm_fused_compressed) + + moe_3gemm_fused_compressed() : primitive_base("", {}) {} + + // @brief Constructs moe primitive / layer. + // + // @param id An identifier of new primitive. + // @param inputs A list of Input primitive ids (inputs). 
+ // 0: hidden_states - input tensor with hidden representations + // 1: routing_weights - [num_seq, num_experts] routing weights for all experts + // 2: w0_weight - expert weights for first projection, + // shape [num_experts, inter_size, group_num, group_size] + // 3: w0_scale - expert scale for first projection for compressed experts, + // shape [num_experts, inter_size, group_num, 1] + // 4: w0_zp - expert zp for first projection for compressed experts, + // shape [num_experts, inter_size, group_num, 1] + // 5: w1_weight - expert weights for second projection, + // shape [num_experts, inter_size, group_num, group_size] + // 6: w1_scale - expert scale for second projection for compressed experts, + // shape [num_experts, inter_size, group_num, 1] + // 7: w1_zp - expert zp for second projection for compressed experts, + // shape [num_experts, inter_size, group_num, 1] + // 8: w2_weight - expert weights for final projection, + // shape [num_experts, hidden_size, group_num, group_size] + // 9: w2_scale - expert scale for final projection for compressed experts, + // shape [num_experts, hidden_size, group_num, 1] + // 10: w2_zp - expert zp for final projection for compressed experts, + // + moe_3gemm_fused_compressed(const primitive_id& id, const std::vector& inputs, const MOE3GemmFusedCompressed::Config& config) + : primitive_base(id, inputs, 1, {optional_data_type()}), + _config(config) {} + + MOE3GemmFusedCompressed::Config _config; + + bool operator==(const primitive& rhs) const override { + if (!compare_common_params(rhs)) + return false; + + auto rhs_casted = downcast(rhs); + + return std::memcmp(&_config, &rhs_casted._config, sizeof(_config)) == 0; + } + + void save(BinaryOutputBuffer& ob) const override { + primitive_base::save(ob); + ob << make_data(&_config, sizeof(_config)); + } + + void load(BinaryInputBuffer& ib) override { + primitive_base::load(ib); + ib >> make_data(&_config, sizeof(_config)); + } +}; + +} // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe/moe_3gemm_swiglu_opt.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe/moe_3gemm_swiglu_opt.cpp new file mode 100644 index 00000000000000..a68f18353965cd --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe/moe_3gemm_swiglu_opt.cpp @@ -0,0 +1,1137 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "moe_3gemm_swiglu_opt.hpp" + +#ifdef ENABLE_ONEDNN_FOR_GPU +# include +# include +# include +# include +# include +# include +# include + +# include "../primitive_ocl_base.hpp" +# include "../utils/kernel_generator.hpp" +# include "common_utils/jitter.hpp" +# include "debug_helper.hpp" +# include "intel_gpu/graph/kernel_impl_params.hpp" +# include "intel_gpu/primitives/moe_3gemm_fused_compressed.hpp" +# include "intel_gpu/runtime/lru_cache.hpp" +# include "intel_gpu/runtime/stream.hpp" +# include "intel_gpu/runtime/utils.hpp" +# include "moe_3gemm_fused_inst.h" +# include "ocl_v2/utils/fused_ops_jitter.hpp" +# include "ocl_v2/utils/jitter.hpp" +# include "primitive_inst.h" + +namespace ov::intel_gpu::ocl { + +namespace { + +using namespace ov::intel_gpu::ocl; + +dnnl::memory::data_type convert_data_type(cldnn::data_types dt) { + switch (dt) { + case cldnn::data_types::f32: + return dnnl::memory::data_type::f32; + case cldnn::data_types::f16: + return dnnl::memory::data_type::f16; + case cldnn::data_types::i8: + return dnnl::memory::data_type::s8; + case cldnn::data_types::u8: + return dnnl::memory::data_type::u8; + case 
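The byte-wise operator==, save() and load() of moe_3gemm_fused_compressed above are only safe while MOE3GemmFusedCompressed::Config remains trivially copyable. A compile-time guard one could keep next to the primitive (an illustrative addition, not part of the patch):

#include <type_traits>

#include "intel_gpu/op/moe_3gemm_fused_compressed.hpp"

// The primitive compares and (de)serializes _config with memcmp/make_data,
// which silently breaks if Config ever grows a non-trivial member.
static_assert(std::is_trivially_copyable<ov::intel_gpu::op::MOE3GemmFusedCompressed::Config>::value,
              "moe_3gemm_fused_compressed relies on byte-wise compare/serialize of Config");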
cldnn::data_types::i32: + return dnnl::memory::data_type::s32; + case cldnn::data_types::i4: + return dnnl::memory::data_type::s4; + case cldnn::data_types::u4: + return dnnl::memory::data_type::u4; + default: + throw std::invalid_argument("[clDNN] Unsupported conversion from cldnn to onednn type"); + } +} + +struct onednn_matmul { + dnnl::matmul m_prim; + dnnl::memory::desc m_wei_md; + dnnl::memory::data_type m_w_type; + dnnl::memory::data_type m_a_type; // activation dtype + dnnl::memory::dim m_K; + dnnl::memory::dim m_N; + dnnl::memory::dim m_M; + dnnl::memory::dim m_K_groups; + + dnnl::primitive_attr attr; + dnnl::post_ops postops; + + onednn_matmul(dnnl::memory::data_type act_dtype, dnnl::memory::data_type weight_dtype, int batch_size, int ic, int oc, int ic_group_size = -1) { + m_a_type = act_dtype; + m_w_type = weight_dtype; + m_K_groups = 0; + m_K = ic; + m_N = oc; + m_M = DNNL_RUNTIME_DIM_VAL; + if (batch_size > 0) { + // jit-gemm kernel only support static batch size + m_M = batch_size; + } + if (ic_group_size >= 0) { + w_scale(ic_group_size).w_zp(ic_group_size).fpmath_f16(); + } + } + + onednn_matmul& w_scale(int k_group_size) { + if (k_group_size <= 0) { + m_K_groups = 1; + // per-OC, no grouping in K dimension + attr.set_scales(DNNL_ARG_WEIGHTS, (0 << 0) + (1 << 1), {1}, dnnl::memory::data_type::f16); + } else { + OPENVINO_ASSERT((k_group_size % 32) == 0); + OPENVINO_ASSERT((m_K % k_group_size) == 0); + m_K_groups = m_K / k_group_size; + attr.set_scales(DNNL_ARG_WEIGHTS, (1 << 0) + (1 << 1), {k_group_size, 1}, dnnl::memory::data_type::f16); + } + return *this; + } + + onednn_matmul& w_zp(int k_group_size) { + if (k_group_size <= 0) { + OPENVINO_ASSERT(m_K_groups == 1); + attr.set_zero_points(DNNL_ARG_WEIGHTS, (0 << 0) + (1 << 1), {1}, m_w_type); + } else { + OPENVINO_ASSERT((m_K % k_group_size) == 0); + m_K_groups = (m_K / k_group_size); + attr.set_zero_points(DNNL_ARG_WEIGHTS, (1 << 0) + (1 << 1), {k_group_size, 1}, m_w_type); + } + return *this; + } + + onednn_matmul& fpmath_f16() { + attr.set_fpmath_mode(dnnl::fpmath_mode::f16, true); + return *this; + } + onednn_matmul& post_op_silu() { + float alpha = 1.0f; + float beta = 0.0f; + postops.append_eltwise(dnnl::algorithm::eltwise_swish, alpha, beta); + return *this; + } + onednn_matmul& post_op_bin_mul(bool per_oc = true) { + dnnl::memory::dim batch_size = m_M; + if (batch_size == DNNL_RUNTIME_DIM_VAL) + batch_size = 1024 * 1024; // big enough fake static batch + + dnnl::memory::desc bin_mul_md = dnnl::memory::desc(dnnl::memory::dims({batch_size, per_oc ? m_N : 1}), m_a_type, dnnl::memory::format_tag::ab); + postops.append_binary(dnnl::algorithm::binary_mul, bin_mul_md); + return *this; + } + + onednn_matmul& post_op_sum(float scale = 1.f, int32_t zero_point = 0) { + postops.append_sum(scale, zero_point, dnnl::memory::data_type::undef); + return *this; + } + + void create(dnnl::engine eng) { + if (postops.len() > 0) { + attr.set_post_ops(postops); + } + + dnnl::memory::desc src_md = dnnl::memory::desc(dnnl::memory::dims({m_M, m_K}), m_a_type, dnnl::memory::format_tag::ab); + dnnl::memory::desc dst_md = dnnl::memory::desc(dnnl::memory::dims({m_M, m_N}), m_a_type, dnnl::memory::format_tag::ab); + + // use fixed weight-layout to prevent shape-dependent weight-layout changes + dnnl::memory::desc wei_md = dnnl::memory::desc(dnnl::memory::dims({m_K, m_N}), m_w_type, dnnl::memory::format_tag::ba); + + // Create primitive descriptor. 
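The grouped scale/zero-point attributes above follow oneDNN's weights-decompression recipe for int4 GEMMs: u4 weights stay packed, and per-group f16 scales plus u4 zero points are attached to DNNL_ARG_WEIGHTS. A minimal standalone sketch of the same setup, mirroring the attribute calls made by onednn_matmul and assuming illustrative sizes (K = 2048, N = 4096, group size 128):

#include <oneapi/dnnl/dnnl.hpp>

dnnl::matmul make_u4_decompression_matmul(dnnl::engine eng) {
    using namespace dnnl;
    const memory::dim M = DNNL_RUNTIME_DIM_VAL, K = 2048, N = 4096, G = 128;

    primitive_attr attr;
    // scales: grouped along K (mask bit 0) and per output channel (mask bit 1), f16
    attr.set_scales(DNNL_ARG_WEIGHTS, (1 << 0) + (1 << 1), {G, 1}, memory::data_type::f16);
    // zero points: same grouping, stored as u4 like the weights
    attr.set_zero_points(DNNL_ARG_WEIGHTS, (1 << 0) + (1 << 1), {G, 1}, memory::data_type::u4);
    // dequantize and accumulate in f16
    attr.set_fpmath_mode(fpmath_mode::f16, /*apply_to_int=*/true);

    memory::desc src_md({M, K}, memory::data_type::f16, memory::format_tag::ab);
    memory::desc wei_md({K, N}, memory::data_type::u4, memory::format_tag::ba);
    memory::desc dst_md({M, N}, memory::data_type::f16, memory::format_tag::ab);
    return matmul(matmul::primitive_desc(eng, src_md, wei_md, dst_md, attr));
}

DNNL_RUNTIME_DIM_VAL keeps the token dimension dynamic, exactly as the wrapper above does when no static batch size is given.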
+ auto matmul_pd = dnnl::matmul::primitive_desc(eng, src_md, wei_md, dst_md, attr); + + // Pre-packed weights stored as int8_t + m_wei_md = matmul_pd.weights_desc(); + + // Create the primitive. + m_prim = dnnl::matmul(matmul_pd); + } + + // this creator is for predefined matmul primitive types + enum class type { + none, + with_bin_mul, + with_bin_mul_per_row, + with_bin_mul_per_row_sum, + with_silu, + with_silu_bin_mul, + }; + int bin_post_id = -1; + bool bin_per_row = false; + onednn_matmul(dnnl::engine eng, + dnnl::memory::data_type act_dtype, + dnnl::memory::data_type weight_dtype, + int batch, + int ic, + int oc, + int ic_group_size, + type t) + : onednn_matmul(act_dtype, weight_dtype, batch, ic, oc, ic_group_size) { + if (t == type::with_bin_mul) { + bin_post_id = 0; + post_op_bin_mul(true); + } + if (t == type::with_bin_mul_per_row) { + bin_post_id = 0; + bin_per_row = true; + post_op_bin_mul(false); + } + if (t == type::with_bin_mul_per_row_sum) { + bin_post_id = 0; + bin_per_row = true; + post_op_bin_mul(false); + post_op_sum(); + } + if (t == type::with_silu) + post_op_silu(); + if (t == type::with_silu_bin_mul) { + bin_post_id = 1; + post_op_silu(); + post_op_bin_mul(true); + } + + create(eng); + } +}; + +// all jit-based/performance-aware function should be a functor/callable because: +// - it needs to hold reference to kernel (to save build time & resources) +// - it needs to do other compile time preparation work and hold the relevant +// runtime-data-struct (to make runtime faster) +// to optimize compile-time-workload itself, the functor instance itself should be +// cached with compile-time parameter as the key. +// +// because it's a functor, which supposed to have no states, so cache-factory should +// always return shared_ptr to constant object, so it won't behave differently when being +// called by different caller, and this also ensure it's multi-threading safe since it +// won't modify it's content. +// +template +class tuple_hasher { +private: + typedef std::tuple Tuple; + template + size_t hash(Tuple& value) const { + return 0; + } + template + size_t hash(Tuple& value) const { + constexpr int Index = N - sizeof...(TTail) - 1; + return std::hash()(std::get(value)) ^ hash(value); + } + +public: + size_t operator()(Tuple value) const { + auto hv = hash(value); + return hv; + } +}; + +// create const object with internal cache with constructor-args as the key +// this helps reduces construction time overhead, and perfectly suitable +// for caching functor/callable. +template +std::shared_ptr make_cacheable(dnnl::engine eng, CArgs... 
cargs) { + std::shared_ptr sptr; + auto key = std::make_tuple(cargs...); + static std::unordered_map, tuple_hasher> cache; + static std::mutex mutex; + std::lock_guard guard(mutex); + auto it = cache.find(key); + if (it != cache.end()) { + auto& wptr = it->second; + sptr = wptr.lock(); + if (!sptr) { + sptr = std::make_shared(eng, cargs...); + // ECOUT("make_cacheable re-constructed: ", typeid(T).name(), "(", cargs..., ")"); + wptr = sptr; + } + } else { + sptr = std::make_shared(eng, cargs...); + // ECOUT("make_cacheable constructed: ", typeid(T).name(), "(", cargs..., ")"); + cache.emplace(std::make_pair(key, std::weak_ptr(sptr))); + } + return sptr; +} + +struct onednn_linear { + std::shared_ptr mm; + dnnl::memory weight; + dnnl::memory scale; + dnnl::memory zp; + dnnl::matmul m_prim; + dnnl::memory::dim m_K; + dnnl::memory::dim m_N; + dnnl::memory::dim m_batch; + dnnl::memory::data_type m_a_type; + int bin_post_id; + + static onednn_linear create(dnnl::engine eng, + dnnl::memory::data_type act_dtype, + dnnl::memory::data_type weight_dtype, + int batch, + int ic, + int oc, + int ic_group_size, + onednn_matmul::type t, + dnnl::memory weight, // external weight + dnnl::memory scale, + dnnl::memory zp) { + OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("onednn_linear::create()")); + auto mm = make_cacheable(eng, act_dtype, weight_dtype, batch, ic, oc, ic_group_size, t); + onednn_linear linear; + linear.mm = mm; + linear.bin_post_id = mm->bin_post_id; + linear.m_prim = mm->m_prim; + linear.m_K = mm->m_K; + linear.m_N = mm->m_N; + linear.m_batch = batch; + linear.m_a_type = mm->m_a_type; + linear.weight = weight; + + if (scale) { + // https://uxlfoundation.github.io/oneDNN/page_weights_decompression_matmul_cpp.html + // Quantization Group size for scales. Must be divisible by 32. + linear.scale = scale; + if (zp) { + linear.zp = zp; + } + } + return linear; + } + + void forward(dnnl::stream& stream, int m, dnnl::memory src_mem, dnnl::memory dst_mem, dnnl::memory bin_mem) { + OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("onednn_linear::forward()")); + dnnl::memory::dim M = m; + + if (!(m_batch == 0 || m_batch == M)) { + OPENVINO_THROW("onednn_linear::forward(): invalid batch size m_batch=", m_batch, " M=", M); + } + + std::unordered_map args; + args.insert({DNNL_ARG_SRC, src_mem}); + args.insert({DNNL_ARG_WEIGHTS, weight}); + // args.insert({DNNL_ARG_BIAS, bias_mem}); + args.insert({DNNL_ARG_DST, dst_mem}); + + if (scale) { + args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, scale}); + } + if (zp) { + args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, zp}); + } + if (bin_mem) { + args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(bin_post_id) | DNNL_ARG_SRC_1, bin_mem}); + } + m_prim.execute(stream, args); + } +}; + +class MoE3GemmSwigluSoftMaxTopK : public KernelGenerator { +public: + MoE3GemmSwigluSoftMaxTopK() : KernelGenerator("moe_3gemm_swiglu_fuse", "softmax_topk") {} + +protected: + [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { + auto jit = KernelGenerator::get_jit_constants(params); + auto desc = params.typed_desc(); + jit.make("SOFTMAX_TOPK_ENABLE", 1); + jit.make("TOP_K", desc->_config.top_k); + jit.make("VALUE_NUM", desc->_config.num_expert); + jit.make("MOE_DTYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float"); + jit.make("MOE_DTYPE_SIZE", params.get_input_layout(0).data_type == ov::element::f16 ? 
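make_cacheable() above interns heavyweight, immutable objects by their constructor arguments, so repeated requests for the same (dtype, shape, group size, post-op) combination reuse one compiled matmul. The idea distilled into a self-contained helper (illustrative names; std::map stands in for the hashed map for brevity):

#include <map>
#include <memory>
#include <mutex>
#include <tuple>

template <class T, class... Args>
std::shared_ptr<const T> get_cached(Args... args) {
    // weak_ptr entries let cached objects die once every user releases them
    static std::map<std::tuple<Args...>, std::weak_ptr<const T>> cache;
    static std::mutex mtx;
    std::lock_guard<std::mutex> guard(mtx);
    auto key = std::make_tuple(args...);
    if (auto alive = cache[key].lock())
        return alive;                                   // reuse the existing instance
    auto fresh = std::make_shared<const T>(args...);    // construct (or re-construct) once
    cache[key] = fresh;
    return fresh;
}

Returning shared_ptr-to-const keeps the cached functors effectively stateless from the callers' point of view, which is what makes sharing them across threads safe.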
2 : 4); + return jit; + } + + [[nodiscard]] Arguments get_arguments_desc(const RuntimeParams& params) const override { + Arguments args; + + return args; + } + + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override { + return DispatchDataFunc{nullptr}; + } +}; + +class MoE3GemmSwigluGather : public KernelGenerator { +public: + MoE3GemmSwigluGather() : KernelGenerator("moe_3gemm_swiglu_fuse", "gather") {} + +protected: + [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { + auto jit = KernelGenerator::get_jit_constants(params); + auto desc = params.typed_desc(); + auto& engine = params.prog->get_engine(); + const auto& info = engine.get_device_info(); + jit.make("GATHER_ENABLE", 1); + jit.make("HIDDEN_SIZE", desc->_config.hidden_size); + jit.make("MOE_DTYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float"); + jit.make("MOE_DTYPE_SIZE", params.get_input_layout(0).data_type == ov::element::f16 ? 2 : 4); + jit.make("SUBGROUP_SIZE", info.arch >= gpu_arch::xe2 ? 32 : 16); + return jit; + } + + [[nodiscard]] Arguments get_arguments_desc(const RuntimeParams& params) const override { + Arguments args; + + return args; + } + + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override { + return DispatchDataFunc{nullptr}; + } +}; + +class MoE3GemmSwigluScatter : public KernelGenerator { +public: + MoE3GemmSwigluScatter() : KernelGenerator("moe_3gemm_swiglu_fuse", "index_add") {} + +protected: + [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { + auto jit = KernelGenerator::get_jit_constants(params); + auto desc = params.typed_desc(); + jit.make("SCATTER_ENABLE", 1); + jit.make("HIDDEN_SIZE", desc->_config.hidden_size); + jit.make("MOE_DTYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float"); + jit.make("MOE_DTYPE_SIZE", params.get_input_layout(0).data_type == ov::element::f16 ? 2 : 4); + return jit; + } + + [[nodiscard]] Arguments get_arguments_desc(const RuntimeParams& params) const override { + Arguments args; + return args; + } + + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override { + return DispatchDataFunc{nullptr}; + } +}; + +// Performance tuning parameters +# define N_BLOCK 4 +# define SUBGROUP_NUM 8 + +static void add_common_consts(const RuntimeParams& params, JitConstants& jit) { + auto desc = params.typed_desc(); + auto& engine = params.prog->get_engine(); + const auto& info = engine.get_device_info(); + jit.make("MAX_TOPK", desc->_config.top_k); + jit.make("EXPERT_NUM", desc->_config.num_expert); + jit.make("HIDDEN_SIZE", desc->_config.hidden_size); + jit.make("INTERMEDIATE_SIZE", desc->_config.inter_size); + jit.make("N_BLOCK", N_BLOCK); + jit.make("SUBGROUP_SIZE", info.arch >= gpu_arch::xe2 ? 32 : 16); + jit.make("SUBGROUP_NUM", SUBGROUP_NUM); + jit.make("GROUP_SIZE", desc->_config.group_size); + jit.make("MOE_DTYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float"); + jit.make("MOE_DTYPE_SIZE", params.get_input_layout(0).data_type == ov::element::f16 ? 
2 : 4); +} + +class MoE3GemmSwigluMLPGateUp : public KernelGenerator { +public: + MoE3GemmSwigluMLPGateUp() : KernelGenerator("moe_3gemm_swiglu_mlp", "gate_up") {} + +protected: + [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { + auto jit = KernelGenerator::get_jit_constants(params); + auto desc = params.typed_desc(); + add_common_consts(params, jit); + jit.make("GATE_UP_ENABLE", 1); + return jit; + } + + [[nodiscard]] Arguments get_arguments_desc(const RuntimeParams& params) const override { + Arguments args; + return args; + } + + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override { + return DispatchDataFunc{nullptr}; + } +}; + +class MoE3GemmSwigluMLPDown : public KernelGenerator { +public: + MoE3GemmSwigluMLPDown() : KernelGenerator("moe_3gemm_swiglu_mlp", "down") {} + +protected: + [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { + auto jit = KernelGenerator::get_jit_constants(params); + auto desc = params.typed_desc(); + add_common_consts(params, jit); + jit.make("DOWN_ENABLE", 1); + return jit; + } + + [[nodiscard]] Arguments get_arguments_desc(const RuntimeParams& params) const override { + Arguments args; + return args; + } + + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override { + return DispatchDataFunc{nullptr}; + } +}; + +class MoE3GemmSwigluMLPReduce : public KernelGenerator { +public: + MoE3GemmSwigluMLPReduce() : KernelGenerator("moe_3gemm_swiglu_mlp", "reduce") {} + +protected: + [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { + auto jit = KernelGenerator::get_jit_constants(params); + auto desc = params.typed_desc(); + add_common_consts(params, jit); + jit.make("REDUCE_ENABLE", 1); + return jit; + } + + [[nodiscard]] Arguments get_arguments_desc(const RuntimeParams& params) const override { + Arguments args; + return args; + } + + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override { + return DispatchDataFunc{nullptr}; + } +}; + +dnnl::memory convert2dnnl(const memory::ptr& ptr, const std::vector& dim, dnnl::memory::format_tag tag, int64_t offset = 0) { + OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("convert2dnnl")); + return ptr->get_onednn_memory(dnnl::memory::desc(dnnl::memory::dims(dim), convert_data_type(ptr->get_layout().data_type), tag), offset); +} + +class moe_3gemm_swiglu_opt_impl : public PrimitiveImplOCL { +public: + DECLARE_OBJECT_TYPE_SERIALIZATION(ov::intel_gpu::ocl::MoE3GemmSwigluImpl) + Stage::Ptr softmax_topk = make_stage(); + Stage::Ptr gather = make_stage(); + Stage::Ptr scatter = make_stage(); + Stage::Ptr mlp_gate_up = make_stage(); + Stage::Ptr mlp_down = make_stage(); + Stage::Ptr mlp_reduce = make_stage(); + + struct dnnl_weights { + dnnl::memory weight; + dnnl::memory scale; + dnnl::memory zp; + int ic, oc, ic_group_size; + }; + + // expert_mask result in cpu side + struct expert_mask_cpu { + std::vector pred_flag; + // shape: [expert_num, batch_no] + std::vector> batch; + // shape: [expert_num, topk_no] + std::vector> topk; + }; + + // store expert_mask for gpu kernel + struct expert_mask_gpu { + memory::ptr batch; + memory::ptr topk; + }; + + struct moe_fusion_weights_base_addr { + memory::ptr weight[3]; // gate/up/down weights, experts fusion + memory::ptr scale[3]; + memory::ptr zp[3]; + memory::ptr bias[3]; + } moe_fusion_weights; + + struct scratch_buffers { + // softmax+topk + memory::ptr topk_id; + memory::ptr 
topk_weights; + + // fast single batch: scratch.up = up(x) * silu(gate(x)) + // scratch.y = down(scratch.up) * routing_weights + memory::ptr up; + memory::ptr y; + // onednn: scratch.x, scratch.routing_weights = gather(x, ...) + // scratch.up = up(scratch.x) + // scratch.gate = gate(scratch.x) * scratch.up + // scratch.y = down(scratch.gate) * routing_weights + memory::ptr x; + memory::ptr routing_weights; + memory::ptr gate; + // buffers for batch and topk from cpu, each expert has one + std::vector expert_masks; + + moe_fusion_weights_base_addr moe_fusion_wei_addr; + memory::ptr input_routing_weights; + memory::ptr input_router_topk_idx; + }; + + std::vector> _dnnl_weights; + int _hidden_size; + int _intermediate_size; + int _group_size; + + moe_3gemm_swiglu_opt_impl() : PrimitiveImplOCL(moe_3gemm_swiglu_opt::get_type_info_static()) {} + moe_3gemm_swiglu_opt_impl(const program_node& node, const RuntimeParams& params) : moe_3gemm_swiglu_opt_impl() { + init(node.as().get_primitive()); + + add_stage(softmax_topk, params); + add_stage(gather, params); + add_stage(scatter, params); + add_stage(mlp_gate_up, params); + add_stage(mlp_down, params); + add_stage(mlp_reduce, params); + } + + void init(const std::shared_ptr& cur_moe) { + _hidden_size = static_cast(cur_moe->_config.hidden_size); + _intermediate_size = static_cast(cur_moe->_config.inter_size); + _group_size = static_cast(cur_moe->_config.group_size); + } + + void init_dnnl_weights(const std::shared_ptr& cur_moe, + cldnn::engine& engine, + const struct moe_fusion_weights_base_addr& moe_fusion_wei_addr) { + if (_dnnl_weights.size() == cur_moe->_config.num_expert) + return; + init(cur_moe); + + _dnnl_weights.resize(cur_moe->_config.num_expert); + for (size_t j = 0; j < cur_moe->_config.num_expert; j++) { + auto& dnnl_weights = _dnnl_weights[j]; + dnnl_weights.resize(3); + dnnl_weights[0].ic = _hidden_size; + dnnl_weights[0].ic_group_size = _group_size; + dnnl_weights[0].oc = _intermediate_size; + dnnl_weights[1].ic = _hidden_size; + dnnl_weights[1].ic_group_size = _group_size; + dnnl_weights[1].oc = _intermediate_size; + dnnl_weights[2].ic = _intermediate_size; + dnnl_weights[2].ic_group_size = _group_size; + dnnl_weights[2].oc = _hidden_size; + for (int i = 0; i < 3; i++) { + // weight shape: [ic, oc], type: u4 + int64_t wei_offset = j * dnnl_weights[i].ic * dnnl_weights[i].oc / 2; + dnnl_weights[i].weight = + convert2dnnl(moe_fusion_wei_addr.weight[i], {dnnl_weights[i].ic, dnnl_weights[i].oc}, dnnl::memory::format_tag::ba, wei_offset); + + // scale shape: [ic / ic_group_size, oc], type: f16 + int64_t scale_offset = j * dnnl_weights[i].ic * dnnl_weights[i].oc / dnnl_weights[i].ic_group_size * 2; + dnnl_weights[i].scale = convert2dnnl(moe_fusion_wei_addr.scale[i], + {dnnl_weights[i].ic / dnnl_weights[i].ic_group_size, dnnl_weights[i].oc}, + dnnl::memory::format_tag::ab, + scale_offset); + + // zp shape: [ic / ic_group_size, oc], type: u4 + int64_t zp_offset = j * dnnl_weights[i].ic * dnnl_weights[i].oc / dnnl_weights[i].ic_group_size / 2; + dnnl_weights[i].zp = convert2dnnl(moe_fusion_wei_addr.zp[i], + {dnnl_weights[i].ic / dnnl_weights[i].ic_group_size, dnnl_weights[i].oc}, + dnnl::memory::format_tag::ab, + zp_offset); + } + } + } + + void load(BinaryInputBuffer& ib) override { + PrimitiveImplOCL::load(ib); + const kernel_impl_params* impl_params = reinterpret_cast(ib.getKernelImplParams()); + init(impl_params->typed_desc()); + } + + [[nodiscard]] std::unique_ptr clone() const override { + auto cur_moe = make_deep_copy(this); + 
cur_moe->_dnnl_weights = _dnnl_weights; + cur_moe->_hidden_size = _hidden_size; + cur_moe->_intermediate_size = _intermediate_size; + cur_moe->_group_size = _group_size; + return cur_moe; + } + + std::vector get_internal_buffer_descs(const kernel_impl_params& params) const override { + auto cur_moe = params.typed_desc(); + const auto& config = cur_moe->_config; + int max_topk = static_cast(config.top_k); + int expert_num = static_cast(config.num_expert); + + auto hidden_states_layout = params.input_layouts[0]; + auto batch = static_cast(hidden_states_layout.get_shape()[0]); + auto data_type = hidden_states_layout.data_type; + + std::vector internal_buffers; + // softmax+topk + layout layout_topk_id(ov::PartialShape{batch, max_topk}, data_types::u32, cldnn::format::bfyx); + layout layout_topk_weights(ov::PartialShape{batch, max_topk}, data_type, cldnn::format::bfyx); + internal_buffers.emplace_back(layout_topk_id, true); // 0: topk_id + internal_buffers.emplace_back(layout_topk_weights, true); // 1: topk_weights + // fast single batch: scratch.up = up(x) * silu(gate(x)); scratch.y = down(scratch.up) * weight[expert_no] + layout layout_gateup_out(ov::PartialShape{batch, static_cast(config.inter_size)}, data_type, cldnn::format::bfyx); + layout layout_down_out(ov::PartialShape{batch, static_cast(config.hidden_size)}, data_type, cldnn::format::bfyx); + internal_buffers.emplace_back(layout_gateup_out, true); // 2: up + internal_buffers.emplace_back(layout_down_out, true); // 3: y + // onednn: scratch.x, scratch.routing_weights = gather(x, ...) + // scratch.up = up(scratch.x) + // scratch.gate = gate(scratch.x) * scratch.up + // scratch.y = down(scratch.gate) * routing_weights + internal_buffers.emplace_back(layout_down_out, true); // 4: x, scratch.x has same layout with down output + layout routing_layout(ov::PartialShape{batch * max_topk}, data_type, cldnn::format::bfyx); + internal_buffers.emplace_back(routing_layout, true); // 5: routing_weights + internal_buffers.emplace_back(layout_gateup_out, true); // 6: gate, scratch.gate has same layout with up + // expert masks for gpu + layout index_layout(ov::PartialShape{expert_num, batch}, ov::element::i32, cldnn::format::bfyx); + internal_buffers.emplace_back(index_layout, true); // 7: batch + internal_buffers.emplace_back(index_layout, true); // 8: topk + + return internal_buffers; + } + + void prepare_internal_buffers(typed_primitive_inst& instance, scratch_buffers& scratch, size_t batch) { + const auto& intermediates_memories = instance.get_intermediates_memories(); + auto& engine = instance.get_network().get_engine(); + scratch.topk_id = intermediates_memories[0]; + scratch.topk_weights = intermediates_memories[1]; + scratch.up = intermediates_memories[2]; + scratch.y = intermediates_memories[3]; + if (batch > 1) { + scratch.x = intermediates_memories[4]; + scratch.routing_weights = intermediates_memories[5]; + scratch.gate = intermediates_memories[6]; + const auto& config = instance.get_typed_desc()->_config; + int expert_num = static_cast(config.num_expert); + scratch.expert_masks.resize(expert_num); + for (int i = 0; i < expert_num; i++) { + auto mask_layout = cldnn::layout({static_cast(batch)}, cldnn::data_types::i32, cldnn::format::get_default_format(1)); + scratch.expert_masks[i].batch = engine.create_subbuffer(*intermediates_memories[7], mask_layout, i * batch * sizeof(int32_t)); + scratch.expert_masks[i].topk = engine.create_subbuffer(*intermediates_memories[8], mask_layout, i * batch * sizeof(int32_t)); + } + } + + // gate + 
scratch.moe_fusion_wei_addr.weight[0] = instance.input_memory_ptr(static_cast(MOEInputIndex::WEIGHT_0)); + scratch.moe_fusion_wei_addr.scale[0] = instance.input_memory_ptr(static_cast(MOEInputIndex::SCALE_0)); + scratch.moe_fusion_wei_addr.zp[0] = instance.input_memory_ptr(static_cast(MOEInputIndex::ZP_0)); + + // up + scratch.moe_fusion_wei_addr.weight[1] = instance.input_memory_ptr(static_cast(MOEInputIndex::WEIGHT_1)); + scratch.moe_fusion_wei_addr.scale[1] = instance.input_memory_ptr(static_cast(MOEInputIndex::SCALE_1)); + scratch.moe_fusion_wei_addr.zp[1] = instance.input_memory_ptr(static_cast(MOEInputIndex::ZP_1)); + + // down + scratch.moe_fusion_wei_addr.weight[2] = instance.input_memory_ptr(static_cast(MOEInputIndex::WEIGHT_2)); + scratch.moe_fusion_wei_addr.scale[2] = instance.input_memory_ptr(static_cast(MOEInputIndex::SCALE_2)); + scratch.moe_fusion_wei_addr.zp[2] = instance.input_memory_ptr(static_cast(MOEInputIndex::ZP_2)); + } + + void get_expert_mask_from_gpu(const MOE3GemmFusedCompressed::Config& config, memory::ptr mem, stream& stream, expert_mask_cpu& expert_mask) { + // shape: [batch, topk] + auto layout = mem->get_layout(); + const auto& shape = layout.get_shape(); + + int max_expert_num = static_cast(config.num_expert); + int max_topk = static_cast(config.top_k); + int max_tokens = static_cast(shape[0]); + + expert_mask.pred_flag.resize(max_expert_num, 0); + expert_mask.batch.resize(max_expert_num, {}); + expert_mask.topk.resize(max_expert_num, {}); + + if (layout.data_padding) { + OPENVINO_THROW("get_expert_mask_from_memory not support padding"); + } + + std::vector buf(max_topk * max_tokens); + mem->copy_to(stream, buf.data(), 0, 0, buf.size() * sizeof(int32_t), true); + + for (int b = 0; b < max_tokens; b++) { + auto* tok_p = &buf[b * max_topk]; + for (int t = 0; t < max_topk; t++) { + auto expert_no = tok_p[t]; + if (expert_no >= max_expert_num) { + OPENVINO_THROW("expert_no ", expert_no, " exceed max_expert_num ", max_expert_num); + } + + expert_mask.batch[expert_no].push_back(b); + expert_mask.topk[expert_no].push_back(t + b * max_topk); + expert_mask.pred_flag[expert_no] = 1; + } + } + { + // check if the result is ok + int count = 0; + for (int no = 0; no < max_expert_num; no++) { + count += static_cast(expert_mask.batch[no].size()); + } + OPENVINO_ASSERT(count == max_topk * max_tokens, + "With max_expert_num=", + max_expert_num, + ",max_topk=", + max_topk, + ",max_tokens=", + max_tokens, + " should have ", + max_topk * max_tokens, + " tokens, but current is ", + count, + ". 
layout=", + layout); + } + } + + void copy_expert_mask_to_gpu(stream& stream, const expert_mask_cpu& expert_mask, size_t expert_no, expert_mask_gpu& expert_mask_mem) { + auto size = expert_mask.batch[expert_no].size() * sizeof(int); + + { + mem_lock lock_data{expert_mask_mem.batch, stream}; + memcpy(lock_data.data(), expert_mask.batch[expert_no].data(), size); + } + { + mem_lock lock_data{expert_mask_mem.topk, stream}; + memcpy(lock_data.data(), expert_mask.topk[expert_no].data(), size); + } + } + + cldnn::event::ptr execute_stage(const std::vector& events, + cldnn::primitive_inst& instance, + Stage& stage, + std::vector inputs, + std::vector outputs, + const std::vector& global, + const std::vector& local, + bool needs_completion_event = false) const { + OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("moe_3gemm_swiglu_opt_impl::execute_stage")); + cldnn::stream& stream = instance.get_network().get_stream(); + cldnn::kernel_arguments_data args; + cldnn::kernel_arguments_desc desc; + for (uint32_t i = 0; i < inputs.size(); i++) { + desc.arguments.push_back({ArgumentDescriptor::Types::INPUT, i}); + args.inputs.push_back(inputs[i]); + } + + for (uint32_t i = 0; i < outputs.size(); i++) { + desc.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, i}); + args.outputs.push_back(outputs[i]); + } + + stream.set_arguments(*stage.kernel, desc, args); + desc.workGroups.global = global; + desc.workGroups.local = local; + + return stream.enqueue_kernel(*stage.kernel, desc, {}, events, needs_completion_event); + } + + auto get_input_info(typed_primitive_inst& instance, int idx) { + auto mem = instance.input_memory_ptr(idx); + auto dep = instance.dependencies()[idx]; + auto layout = dep.first->get_impl_params()->get_output_layout(dep.second); + return std::make_tuple(mem, layout); + } + + cldnn::event::ptr exec_single_batch(typed_primitive_inst& instance, scratch_buffers& scratch) { + auto cur_moe = instance.get_typed_desc(); + int max_topk = static_cast(cur_moe->_config.top_k); + + auto final_hidden_states_mem_ptr = instance.output_memory_ptr(0); + auto batch_mem_ptr = scratch.topk_id; + auto [hidden_states_mem_ptr, hidden_states_layout] = get_input_info(instance, static_cast(MOEInputIndex::HIDDEN_STATES)); + auto routing_mem_ptr = scratch.topk_weights; + + _hidden_size = static_cast(cur_moe->_config.hidden_size); + _intermediate_size = static_cast(cur_moe->_config.inter_size); + + const size_t subgroup_size = instance.get_impl_params()->get_device_info().arch >= gpu_arch::xe2 ? 
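get_expert_mask_from_gpu() above is the routing bookkeeping of the multi-token path: every token's top-k expert ids are scattered into per-expert index lists, so each active expert can later gather a dense [n_token] slice of hidden states and routing weights. The same construction on plain containers (a hypothetical reference helper, not the plugin code):

#include <cstdint>
#include <vector>

struct ExpertMaskRef {
    std::vector<std::vector<int32_t>> batch;  // [expert][j] = token index routed to this expert
    std::vector<std::vector<int32_t>> topk;   // [expert][j] = flat offset into the [token, top_k] weights
};

ExpertMaskRef build_expert_mask(const std::vector<int32_t>& topk_ids,  // [n_tokens * top_k]
                                int n_tokens, int top_k, int n_experts) {
    ExpertMaskRef m;
    m.batch.resize(n_experts);
    m.topk.resize(n_experts);
    for (int b = 0; b < n_tokens; ++b) {
        for (int t = 0; t < top_k; ++t) {
            const int e = topk_ids[b * top_k + t];
            m.batch[e].push_back(b);             // which token this expert must process
            m.topk[e].push_back(b * top_k + t);  // where its routing weight lives
        }
    }
    return m;
}

Experts whose lists stay empty are skipped entirely, which is what pred_flag encodes above.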
32 : 16; + const size_t max_work_group_size = instance.get_impl_params()->get_device_info().max_work_group_size; + + // gate + const auto& mlp_gate_wei_mem = scratch.moe_fusion_wei_addr.weight[0]; + const auto& mlp_gate_scale_mem = scratch.moe_fusion_wei_addr.scale[0]; + const auto& mlp_gate_zp_mem = scratch.moe_fusion_wei_addr.zp[0]; + + // up + const auto& mlp_up_wei_mem = scratch.moe_fusion_wei_addr.weight[1]; + const auto& mlp_up_scale_mem = scratch.moe_fusion_wei_addr.scale[1]; + const auto& mlp_up_zp_mem = scratch.moe_fusion_wei_addr.zp[1]; + + // down + const auto& mlp_down_wei_mem = scratch.moe_fusion_wei_addr.weight[2]; + const auto& mlp_down_scale_mem = scratch.moe_fusion_wei_addr.scale[2]; + const auto& mlp_down_zp_mem = scratch.moe_fusion_wei_addr.zp[2]; + event::ptr ret; + + { + // scratch.up = up(x) * silu(gate(x)) + execute_stage( + {}, + instance, + *mlp_gate_up, + {batch_mem_ptr, mlp_gate_wei_mem, mlp_gate_scale_mem, mlp_gate_zp_mem, mlp_up_wei_mem, mlp_up_scale_mem, mlp_up_zp_mem, hidden_states_mem_ptr}, + {scratch.up}, + {static_cast(max_topk), subgroup_size, static_cast(_intermediate_size / N_BLOCK)}, + {1, subgroup_size, SUBGROUP_NUM}); + + // scratch.y = down(scratch.up) * weight[expert_no] + execute_stage({}, + instance, + *mlp_down, + {batch_mem_ptr, mlp_down_wei_mem, mlp_down_scale_mem, mlp_down_zp_mem, scratch.up, routing_mem_ptr}, + {scratch.y}, + {static_cast(max_topk), subgroup_size, static_cast(_hidden_size / N_BLOCK)}, + {1, subgroup_size, SUBGROUP_NUM}); + + // final = sum(scratch.y) + ret = execute_stage({}, + instance, + *mlp_reduce, + {scratch.y}, + {final_hidden_states_mem_ptr}, + {static_cast(1), static_cast(_hidden_size)}, + {1, std::min(max_work_group_size, size_t{1024})}, + instance.needs_completion_event()); + } + return ret; + } + + struct onednn_kernel { + onednn_linear up; + onednn_linear gate; + onednn_linear down; + }; + struct PairHash { + template + size_t operator()(const std::pair& p) const { + // Combine hash values of the pair elements + return std::hash()(p.first) ^ std::hash()(p.second); + } + }; + + using lru_cache_hash = LruCache, std::shared_ptr, PairHash>; + lru_cache_hash _kernels = lru_cache_hash(1024); + onednn_kernel& get_kernel(int n_token, int expert_no, typed_primitive_inst& instance) { + auto key = std::make_pair(n_token, expert_no); + if (_kernels.has(key)) { + return *_kernels.get(key); + } + + auto& cur_net = instance.get_network(); + auto& stream = cur_net.get_stream(); + auto& dnn_stream = stream.get_onednn_stream(); + auto hidden_states_layout_dt = convert_data_type(instance.input_memory_ptr(static_cast(MOEInputIndex::HIDDEN_STATES))->get_layout().data_type); + + auto& dnnl_weights = _dnnl_weights[expert_no]; + auto kernel = std::make_shared(); + + // gate + auto gate_weight_layout_dt = convert_data_type(instance.input_memory_ptr(static_cast(MOEInputIndex::WEIGHT_0))->get_layout().data_type); + kernel->gate = onednn_linear::create(dnn_stream.get_engine(), + hidden_states_layout_dt, + gate_weight_layout_dt, + n_token, + dnnl_weights[0].ic, + dnnl_weights[0].oc, + dnnl_weights[0].ic_group_size, + onednn_matmul::type::with_silu_bin_mul, + dnnl_weights[0].weight, + dnnl_weights[0].scale, + dnnl_weights[0].zp); + + // up + auto up_weight_layout_dt = convert_data_type(instance.input_memory_ptr(static_cast(MOEInputIndex::WEIGHT_1))->get_layout().data_type); + kernel->up = onednn_linear::create(dnn_stream.get_engine(), + hidden_states_layout_dt, + up_weight_layout_dt, + n_token, + dnnl_weights[1].ic, + dnnl_weights[1].oc, + 
dnnl_weights[1].ic_group_size, + onednn_matmul::type::none, + dnnl_weights[1].weight, + dnnl_weights[1].scale, + dnnl_weights[1].zp); + + // down + auto down_weight_layout_dt = convert_data_type(instance.input_memory_ptr(static_cast(MOEInputIndex::WEIGHT_2))->get_layout().data_type); + kernel->down = onednn_linear::create(dnn_stream.get_engine(), + hidden_states_layout_dt, + down_weight_layout_dt, + n_token, + dnnl_weights[2].ic, + dnnl_weights[2].oc, + dnnl_weights[2].ic_group_size, + onednn_matmul::type::with_bin_mul_per_row, + dnnl_weights[2].weight, + dnnl_weights[2].scale, + dnnl_weights[2].zp); + _kernels.add(key, kernel); + return *_kernels.get(key); + } + + // inputs 0 is hidden_states, inputs 1 is router_logits[num_tokens, NUM_EXPERTS=128] + // extra step Softmax_TopK is fused to give topk-id & router_weights + // + // scratch.topk_id, scratch.full_router_weights = Softmax_TopK(router_logits) + // + // generate expert_mask from topk-id + // expert_mask.batch[i][j] : j'th token index for i'th expert + // expert_mask.topk[i][j] : topk-output offset for j'th token for i'th expert, used to get weights + // expert_mask.pred_flag[i]: bool, if expert i can be skipped + // + // + // scratch.x, scratch.routing_weights = gather(hidden_states, scratch.full_router_weights, expert_mask.batch, expert_mask.topk) + // scratch.y = MLP(scratch.x, .gate/up/down) * scratch.routing_weights + // scatter(final_hidden, scratch.y, expert_mask.batch) + // + cldnn::event::ptr execute(const std::vector& events, cldnn::primitive_inst& ins) override { + OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("moe_3gemm_swiglu_opt_impl::execute")); + auto& instance = reinterpret_cast&>(ins); + auto cur_moe = instance.get_typed_desc(); + const auto& config = cur_moe->_config; + int max_topk = static_cast(config.top_k); + auto& cur_net = instance.get_network(); + auto& stream = cur_net.get_stream(); + + auto [hidden_states_mem_ptr, hidden_states_layout] = get_input_info(instance, static_cast(MOEInputIndex::HIDDEN_STATES)); + auto batch = static_cast(hidden_states_layout.get_shape()[0]); + + scratch_buffers scratch; + prepare_internal_buffers(instance, scratch, batch); + + // softmax+topk + auto lws_size = cur_moe->_config.num_expert; + auto topk_event = execute_stage(events, + instance, + *softmax_topk, + {instance.input_memory_ptr(static_cast(MOEInputIndex::ROUTING_WEIGHTS))}, + {scratch.topk_id, scratch.topk_weights}, + {static_cast(batch), lws_size}, + {1, lws_size}); + + // Single batch is a special case, we don't need to do gather/scatter, + // and we can apply optimal kernels against memory bound to improve performance. + // It is very important for MoE's second token performance. 
+ if (batch == 1) { + return exec_single_batch(instance, scratch); + } + + auto& engine = instance.get_network().get_engine(); + init_dnnl_weights(cur_moe, engine, scratch.moe_fusion_wei_addr); + auto final_hidden_states_mem_ptr = instance.output_memory_ptr(0); + auto final_hidden_states_layout = instance.get_output_layout(0); + + // onednn path will accumulate to the output + final_hidden_states_mem_ptr->fill(stream, false); + + // Wait for topk is ready + topk_event->wait(); + // [batch, max_topk] + auto topk_id_mem = scratch.topk_id; + + expert_mask_cpu expert_mask; + get_expert_mask_from_gpu(config, topk_id_mem, stream, expert_mask); + + auto& dnn_stream = stream.get_onednn_stream(); + cldnn::event::ptr result_event; + + auto routing_mem_ptr = scratch.topk_weights; + auto get_best_lws = [](size_t hidden_size) { + const size_t candidate[] = {128, 64, 32, 16, 8}; + for (size_t i = 0; i < sizeof(candidate) / sizeof(size_t); i++) { + if (hidden_size % candidate[i] == 0) { + return candidate[i]; + } + } + OPENVINO_THROW("hidden_size=", hidden_size, " is not divisible by any of ", sizeof(candidate) / sizeof(size_t), " candidates"); + }; + lws_size = get_best_lws(_hidden_size); + + if (batch <= 1) { + OPENVINO_THROW("batch size should be > 1 for this path!"); + } + for (size_t expert_no = 0; expert_no < config.num_expert; expert_no++) { + if (expert_no >= expert_mask.pred_flag.size()) { + OPENVINO_THROW("expert_no=", expert_no, " is out of bounds"); + } + auto can_skip_subgraph = !expert_mask.pred_flag[expert_no]; + if (can_skip_subgraph) { + continue; + } + auto& dnnl_weights = _dnnl_weights[expert_no]; + + // expert_mask + expert_mask_gpu& expert_mask_mem = scratch.expert_masks[expert_no]; + copy_expert_mask_to_gpu(stream, expert_mask, expert_no, expert_mask_mem); + + auto n_token = static_cast(expert_mask.batch[expert_no].size()); + onednn_kernel& kernel = get_kernel(n_token, static_cast(expert_no), instance); + + // gather + execute_stage(events, + instance, + *gather, + {hidden_states_mem_ptr, routing_mem_ptr, expert_mask_mem.batch, expert_mask_mem.topk}, + {scratch.x, scratch.routing_weights}, + {static_cast(n_token), static_cast(_hidden_size)}, + {1, lws_size}); + + // up + kernel.up.forward(dnn_stream, + n_token, + convert2dnnl(scratch.x, {static_cast(n_token), dnnl_weights[1].ic}, dnnl::memory::format_tag::ab), + convert2dnnl(scratch.up, {static_cast(n_token), _intermediate_size}, dnnl::memory::format_tag::ab), + dnnl::memory()); + + // gate + kernel.gate.forward(dnn_stream, + n_token, + convert2dnnl(scratch.x, {static_cast(n_token), dnnl_weights[0].ic}, dnnl::memory::format_tag::ab), + convert2dnnl(scratch.gate, {static_cast(n_token), _intermediate_size}, dnnl::memory::format_tag::ab), + convert2dnnl(scratch.up, {static_cast(n_token), _intermediate_size}, dnnl::memory::format_tag::ab)); + + // down + kernel.down.forward(dnn_stream, + n_token, + convert2dnnl(scratch.gate, {static_cast(n_token), _intermediate_size}, dnnl::memory::format_tag::ab), + convert2dnnl(scratch.y, {static_cast(n_token), _hidden_size}, dnnl::memory::format_tag::ab), + convert2dnnl(scratch.routing_weights, {n_token * max_topk}, dnnl::memory::format_tag::a)); + // index_add + result_event = execute_stage(events, + instance, + *scatter, + {scratch.y, expert_mask_mem.batch}, + {final_hidden_states_mem_ptr}, + {static_cast(n_token), static_cast(_hidden_size)}, + {1, lws_size}, + instance.needs_completion_event()); + } + + return result_event; + } +}; + +} // namespace + +std::unique_ptr 
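Functionally, each expert in the loop above computes a SwiGLU MLP on its gathered tokens: y = down(silu(gate(x)) * up(x)), scaled by the token's routing weight and accumulated into the output. A dense scalar reference of that math (for clarity only; the real path keeps the weights in grouped int4):

#include <cmath>
#include <vector>

using Vec = std::vector<float>;
using Mat = std::vector<Vec>;  // [out][in], dense for readability

static Vec matvec(const Mat& W, const Vec& x) {
    Vec y(W.size(), 0.f);
    for (size_t o = 0; o < W.size(); ++o)
        for (size_t i = 0; i < x.size(); ++i)
            y[o] += W[o][i] * x[i];
    return y;
}

// One expert: down( silu(gate(x)) .* up(x) ), then scaled by the routing weight.
Vec expert_swiglu_mlp(const Mat& Wg, const Mat& Wu, const Mat& Wd, const Vec& x, float routing_w) {
    Vec g = matvec(Wg, x), u = matvec(Wu, x);
    for (size_t i = 0; i < g.size(); ++i)
        g[i] = (g[i] / (1.f + std::exp(-g[i]))) * u[i];  // silu(gate) * up
    Vec y = matvec(Wd, g);
    for (float& v : y) v *= routing_w;
    return y;
}

In the oneDNN path this maps onto the with_silu_bin_mul, none and with_bin_mul_per_row matmul variants chosen in get_kernel().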
moe_3gemm_swiglu_opt::create_impl(const program_node& node, const RuntimeParams& params) const { + assert(node.is_type()); + return std::make_unique(node, params); +} + +} // namespace ov::intel_gpu::ocl + +BIND_BINARY_BUFFER_WITH_TYPE(cldnn::moe_3gemm_fused_compressed) +BIND_BINARY_BUFFER_WITH_TYPE(ov::intel_gpu::ocl::moe_3gemm_swiglu_opt_impl) + +#else + +namespace ov::intel_gpu::ocl { + +std::unique_ptr moe_3gemm_swiglu_opt::create_impl(const program_node& node, const RuntimeParams& params) const { + OPENVINO_THROW("moe_3gemm_swiglu_opt depends on onednn."); +} + +} // namespace ov::intel_gpu::ocl + +#endif diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe/moe_3gemm_swiglu_opt.hpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe/moe_3gemm_swiglu_opt.hpp new file mode 100644 index 00000000000000..8b3fa3d8c9c548 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe/moe_3gemm_swiglu_opt.hpp @@ -0,0 +1,90 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "intel_gpu/primitives/activation.hpp" +#include "intel_gpu/primitives/eltwise.hpp" +#include "program_node.h" +#include "registry/implementation_manager.hpp" + +using namespace cldnn; // TODO: Remove once namespaces are aligned +namespace ov::intel_gpu::ocl { + +// mlp_gate: 0 +// mlp_up: 1 +// mlp_down: 2 +enum class MOEInputIndex : uint8_t { + HIDDEN_STATES = 0, + ROUTING_WEIGHTS = 1, + WEIGHT_0 = 2, + SCALE_0 = 3, + ZP_0 = 4, + WEIGHT_1 = 5, + SCALE_1 = 6, + ZP_1 = 7, + WEIGHT_2 = 8, + SCALE_2 = 9, + ZP_2 = 10 +}; + +struct moe_3gemm_swiglu_opt : public ImplementationManager { + OV_GPU_PRIMITIVE_IMPL("ocl::moe::moe_3gemm_swiglu_opt") + explicit moe_3gemm_swiglu_opt(shape_types shape_type, ValidateFunc vf = nullptr) : ImplementationManager(impl_types::ocl, shape_type, std::move(vf)) {} + [[nodiscard]] std::unique_ptr create_impl(const program_node& node, const RuntimeParams& params) const override; + [[nodiscard]] bool validate_impl(const program_node& node) const override { + static constexpr std::array supported_fmts = { + format::bfyx, + }; + + // TODO(MOE): support more precision + static constexpr std::array supported_types = { + ov::element::f16, + }; + + const auto& in0_layout = node.get_input_layout(static_cast(MOEInputIndex::HIDDEN_STATES)); + const auto& out_layout = node.get_output_layout(0); + if (!one_of(in0_layout.format, supported_fmts) || !one_of(out_layout.format, supported_fmts)) { + return false; + } + + if (!one_of(in0_layout.data_type, supported_types) || !one_of(out_layout.data_type, supported_types)) { + return false; + } + + // Only support weight: u4 + static constexpr std::array supported_wei_type = { + ov::element::u4, + }; + const auto& wei_layout = node.get_input_layout(static_cast(MOEInputIndex::WEIGHT_0)); + if (!one_of(wei_layout.data_type, supported_wei_type)) { + return false; + } + + // Only support scale: f16 + static constexpr std::array supported_scale_type = { + ov::element::f16, + }; + const auto& scale_layout = node.get_input_layout(static_cast(MOEInputIndex::SCALE_0)); + if (!one_of(scale_layout.data_type, supported_scale_type)) { + return false; + } + + // Only support zp: u4 + static constexpr std::array supported_zp_type = { + ov::element::u4, + }; + const auto& zp_layout = node.get_input_layout(static_cast(MOEInputIndex::ZP_0)); + if (!one_of(zp_layout.data_type, supported_zp_type)) { + return false; + } + + return true; + } +}; + +} // namespace ov::intel_gpu::ocl diff --git 
a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_3gemm_swiglu_fuse.cl b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_3gemm_swiglu_fuse.cl new file mode 100644 index 00000000000000..b321525ec1b7ff --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_3gemm_swiglu_fuse.cl @@ -0,0 +1,132 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#if SOFTMAX_TOPK_ENABLE + +KERNEL(softmax_topk)( + const __global MOE_DTYPE* input, // [input_batch, sort_in_num] + __global uint* output_index, // [input_batch, TOP_K] + __global MOE_DTYPE* output // [input_batch, TOP_K] +) { + // gws [batch, sort_in_num] + const uint batch = (uint)get_global_id(0); + const uint sort_index = (uint)get_global_id(1); + const uint sort_cnt = (uint)get_global_size(1); + + input += batch * sort_cnt + sort_index; + + uint sort_position = 0; + + __local MOE_DTYPE local_input[VALUE_NUM]; + __local MOE_DTYPE local_output[TOP_K]; + __local uint local_index[TOP_K]; + + MOE_DTYPE in_value = as_half(intel_sub_group_block_read_us((const __global ushort*)(input))); + local_input[sort_index] = in_value; + barrier(CLK_LOCAL_MEM_FENCE); + + __attribute__((opencl_unroll_hint(8))) + for(uint i = 0; i < sort_index; i++) { + MOE_DTYPE value = local_input[i]; + if(value >= in_value) { + sort_position++; + } + } + + __attribute__((opencl_unroll_hint(8))) + for(uint i = sort_index; i < sort_cnt; i++) { + MOE_DTYPE value = local_input[i]; + if(value > in_value) { + sort_position++; + } + } + if (sort_position < TOP_K) { + local_output[sort_position] = in_value; + local_index[sort_position] = sort_index; + } + barrier(CLK_LOCAL_MEM_FENCE); + + if(sort_position == 0) { + float softmax_total = 1.0; + MOE_DTYPE max_v = local_output[0]; + local_output[0] = 1; + for(uint i = 1; i < TOP_K; i++) { + local_output[i] = native_exp(local_output[i] - max_v); + softmax_total += local_output[i]; + } + output_index += batch * TOP_K; + output += batch * TOP_K; + + for(uint i = 0; i < TOP_K; i++) { + output[i] = local_output[i]/softmax_total; + output_index[i] = local_index[i]; + } + } +} + +#elif GATHER_ENABLE +__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) +KERNEL (gather_2d_ref)( + const __global MOE_DTYPE* src_tok, // input tokens [total_token, hidden_size] - hidden_states_mem_ptr + const __global MOE_DTYPE* src_rweight, // topk_weights [total_token, topk_experts] + __global int * tok_index, // token index [expert_idx][] = [actual_token_num] - expert_mask_mem.batch + __global int * top_index, // topk index [expert_idx][] = [actual_token_num] - expert_mask_mem.topk + __global MOE_DTYPE* dst_tok, // output tokens [batch_size, hidden_size] - scratch.x + __global MOE_DTYPE* dst_rweight) { // output topk_weights [batch_size] - scratch.routing_weights + + int k = get_global_id(0); // token_idx + int off = get_global_id(1); // hidden_size offset + int tok_idx = tok_index[k]; + + src_tok += tok_idx * HIDDEN_SIZE; + dst_tok += k * HIDDEN_SIZE; + + if (off >= HIDDEN_SIZE) { + printf("Warning off >= HIDDEN_SIZE: k = %d, off = %d, HIDDEN_SIZE = %d\n", k, off, HIDDEN_SIZE); + return; + } + + #if MOE_DTYPE_SIZE == 2 + ushort value = intel_sub_group_block_read_us((const __global ushort *)(src_tok + off)); + intel_sub_group_block_write_us((__global ushort *)(dst_tok + off), value); + #elif MOE_DTYPE_SIZE == 4 + uint value = intel_sub_group_block_read((const __global uint *)(src_tok + off)); + intel_sub_group_block_write((__global uint *)(dst_tok + off), value); + #else + dst_tok[off] = src_tok[off]; + #endif 
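The softmax_topk kernel above ranks each token's router logits inside a work-group and then applies a softmax restricted to the TOP_K winners (max-subtracted for numerical stability). A scalar reference of the same semantics (hypothetical helper, useful for comparison or testing):

#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

void softmax_topk_ref(const std::vector<float>& logits, int top_k,
                      std::vector<int>& out_idx, std::vector<float>& out_w) {
    // indices of the top_k largest logits, best first
    std::vector<int> order(logits.size());
    std::iota(order.begin(), order.end(), 0);
    std::partial_sort(order.begin(), order.begin() + top_k, order.end(),
                      [&](int a, int b) { return logits[a] > logits[b]; });
    out_idx.assign(order.begin(), order.begin() + top_k);

    // softmax over the selected logits only
    out_w.resize(top_k);
    const float mx = logits[out_idx[0]];
    float sum = 0.f;
    for (int i = 0; i < top_k; ++i) {
        out_w[i] = std::exp(logits[out_idx[i]] - mx);
        sum += out_w[i];
    }
    for (float& w : out_w) w /= sum;
}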
+ + if (off == 0) { + int top_idx = top_index[k]; + dst_rweight[k] = src_rweight[top_idx]; + } +} + +#elif SCATTER_ENABLE +KERNEL (index_add_)(const __global MOE_DTYPE* src_tok, + __global int * tok_index, + __global MOE_DTYPE* dst_tok) { + + int k = get_global_id(0); + int off = get_global_id(1); + int tok_idx = tok_index[k]; + + src_tok += k * HIDDEN_SIZE; + dst_tok += tok_idx * HIDDEN_SIZE; + + #if MOE_DTYPE_SIZE == 2 + half src_value = as_half(intel_sub_group_block_read_us((const __global ushort *)(src_tok + off))); + half dst_value = as_half(intel_sub_group_block_read_us((const __global ushort *)(dst_tok + off))); + half value = dst_value + src_value; + intel_sub_group_block_write_us((__global ushort *)(dst_tok + off), as_ushort(value)); + #elif MOE_DTYPE_SIZE == 4 + float src_value = as_float(intel_sub_group_block_read((const __global uint *)(src_tok + off))); + float dst_value = as_float(intel_sub_group_block_read((const __global uint *)(dst_tok + off))); + float value = dst_value + src_value; + intel_sub_group_block_write_us((__global ushort *)(dst_tok + off), as_uint(value)); + #else + dst_tok[off] += src_tok[off]; + #endif +} +#endif diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_3gemm_swiglu_mlp.cl b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_3gemm_swiglu_mlp.cl new file mode 100644 index 00000000000000..c2a5c0be808b66 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_3gemm_swiglu_mlp.cl @@ -0,0 +1,319 @@ + +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#define unroll_for __attribute__((opencl_unroll_hint)) for + +#if GATE_UP_ENABLE +inline void gemv_n2x(const __global uchar* weight, + __global half* scales, + __global uchar* zps, + const __global half* x, + __global half* y, int N, int K, + half* x2, + float* xg_sum, + const bool silu) { + int num_sg = get_num_sub_groups(); + int id_sg = get_sub_group_id(); + int id_local = get_sub_group_local_id(); + + int n_start = get_global_id(2) * N_BLOCK; + int n_end = n_start + N_BLOCK; + unroll_for (int n = n_start; n < n_end; n+=2) { + const __global uchar* B = weight + n * K / 2; + float sum_all0 = 0; + float sum_all1 = 0; +#if SZ_LAYOUT == 0 + __global half* S = scales + n; + __global uchar* Z = zps + n / 2; + unroll_for (int gk = 0; gk < K / GROUP_SIZE; gk++, S += N, Z += N / 2) { + half s0 = S[0]; + half s1 = S[1]; + ushort z = Z[0]; + half z_hf0 = convert_half(z & 0xf); + half z_hf1 = convert_half(z >> 4); +#else + __global half* S = scales + n*K/GROUP_SIZE; + __global uchar* Z = zps + n*K/GROUP_SIZE/2; + + half scale_values = as_half(intel_sub_group_block_read_us((const __global ushort*)S)); + uchar zp_values = intel_sub_group_block_read_uc((const __global uchar*)Z); + half zp_even = convert_half(zp_values & 0xF); + half zp_odd = convert_half(zp_values >> 4); + unroll_for (int gk = 0; gk < K / GROUP_SIZE; gk++) { + half s0 = sub_group_broadcast(scale_values, 2*gk + 0); + half s1 = sub_group_broadcast(scale_values, 2*gk + 1); + half z_hf0 = sub_group_broadcast(zp_even, gk); + half z_hf1 = sub_group_broadcast(zp_odd, gk); +#endif + +#if SUBGROUP_SIZE == 32 + half2 sum0; + half2 sum1; + half4 a = as_half4(intel_sub_group_block_read_us4((const __local ushort*)x2 + gk*GROUP_SIZE)); + uchar2 b = intel_sub_group_block_read_uc2((const __global uchar*)B + gk*GROUP_SIZE/2); + uchar2 b2 = intel_sub_group_block_read_uc2((const __global uchar*)(B + (K/2) + gk*GROUP_SIZE/2)); + + sum0.s0 = fma(a.s0, (convert_half(b.s0 & 0x0F)), 0); + sum0.s1 = fma(a.s1, 
(convert_half(b.s1 & 0x0F)), 0); + sum0.s0 = fma(a.s2, (convert_half(b.s0 >> 4)), sum0.s0); + sum0.s1 = fma(a.s3, (convert_half(b.s1 >> 4)), sum0.s1); + + sum1.s0 = fma(a.s0, (convert_half(b2.s0 & 0x0F)), 0); + sum1.s1 = fma(a.s1, (convert_half(b2.s1 & 0x0F)), 0); + sum1.s0 = fma(a.s2, (convert_half(b2.s0 >> 4)), sum1.s0); + sum1.s1 = fma(a.s3, (convert_half(b2.s1 >> 4)), sum1.s1); + + sum_all0 += (sum0[0] + sum0[1] - xg_sum[gk] * z_hf0) * s0; + sum_all1 += (sum1[0] + sum1[1] - xg_sum[gk] * z_hf1) * s1; +#else + half4 sum0; + half4 sum1; + half8 a = as_half8(intel_sub_group_block_read_us8((const __local ushort*)x2 + gk*GROUP_SIZE)); + uchar4 b = intel_sub_group_block_read_uc4((const __global uchar*)B + gk*GROUP_SIZE/2); + uchar4 b2 = intel_sub_group_block_read_uc4((const __global uchar*)(B + (K/2) + gk*GROUP_SIZE/2)); + + sum0.s0 = fma(a.s0, (convert_half(b.s0 & 0x0F)), 0); + sum0.s1 = fma(a.s1, (convert_half(b.s1 & 0x0F)), 0); + sum0.s2 = fma(a.s2, (convert_half(b.s2 & 0x0F)), 0); + sum0.s3 = fma(a.s3, (convert_half(b.s3 & 0x0F)), 0); + + sum0.s0 = fma(a.s4, (convert_half(b.s0 >> 4)), sum0.s0); + sum0.s1 = fma(a.s5, (convert_half(b.s1 >> 4)), sum0.s1); + sum0.s2 = fma(a.s6, (convert_half(b.s2 >> 4)), sum0.s2); + sum0.s3 = fma(a.s7, (convert_half(b.s3 >> 4)), sum0.s3); + + sum1.s0 = fma(a.s0, (convert_half(b2.s0 & 0x0F)), 0); + sum1.s1 = fma(a.s1, (convert_half(b2.s1 & 0x0F)), 0); + sum1.s2 = fma(a.s2, (convert_half(b2.s2 & 0x0F)), 0); + sum1.s3 = fma(a.s3, (convert_half(b2.s3 & 0x0F)), 0); + + sum1.s0 = fma(a.s4, (convert_half(b2.s0 >> 4)), sum1.s0); + sum1.s1 = fma(a.s5, (convert_half(b2.s1 >> 4)), sum1.s1); + sum1.s2 = fma(a.s6, (convert_half(b2.s2 >> 4)), sum1.s2); + sum1.s3 = fma(a.s7, (convert_half(b2.s3 >> 4)), sum1.s3); + + sum_all0 += (sum0[0] + sum0[1] + sum0[2] + sum0[3] - xg_sum[gk] * z_hf0) * s0; + sum_all1 += (sum1[0] + sum1[1] + sum1[2] + sum1[3] - xg_sum[gk] * z_hf1) * s1; +#endif + } + + sum_all0 = sub_group_reduce_add(sum_all0); + sum_all1 = sub_group_reduce_add(sum_all1); + if (id_local == 0) { + if (silu) { + y[n] *= sum_all0 / (1 + exp(-sum_all0)); + y[n+1] *= sum_all1 / (1 + exp(-sum_all1)); + } else { + y[n] = sum_all0; + y[n+1] = sum_all1; + } + } + } +} + +__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) +KERNEL (mlp_gate_up)( + const __global int* expert_list, + const __global uchar* gate_weight_addr, + const __global uchar* gate_scale_addr, + const __global uchar* gate_zp_addr, + const __global uchar* up_weight_addr, + const __global uchar* up_scale_addr, + const __global uchar* up_zp_addr, + __global MOE_DTYPE* x, // [1, HIDDEN_SIZE] + __global MOE_DTYPE* y) { // [MAX_TOPK, INTERMEDIATE_SIZE] + // global: [expert, SUBGROUP_SIZE, N//N_BLOCK],[1, SUBGROUP_SIZE, SUBGROUP_NUM] + int expert_no = get_global_id(0); + y += expert_no * INTERMEDIATE_SIZE; + + const int expert_wei_size = INTERMEDIATE_SIZE * HIDDEN_SIZE / 2; + const int expert_scale_size = INTERMEDIATE_SIZE * HIDDEN_SIZE * 2 / GROUP_SIZE; + const int expert_zp_size = INTERMEDIATE_SIZE * HIDDEN_SIZE / 2 / GROUP_SIZE; + int expert_id = expert_list[expert_no]; + + // gate, [HIDDEN_SIZE, INTERMEDIATE_SIZE] + __global uchar* gate_weight = (__global uchar*)(gate_weight_addr + expert_id * expert_wei_size); + __global half* gate_scale = (__global half*)(gate_scale_addr + expert_id * expert_scale_size); + __global uchar* gate_zp = (__global uchar*)(gate_zp_addr + expert_id * expert_zp_size); + + // up, [HIDDEN_SIZE, INTERMEDIATE_SIZE] + __global uchar* up_weight = (__global uchar*)(up_weight_addr + 
expert_id * expert_wei_size); + __global half* up_scale = (__global half*)(up_scale_addr + expert_id * expert_scale_size); + __global uchar* up_zp = (__global uchar*)(up_zp_addr + expert_id * expert_zp_size); + + __local half x2[HIDDEN_SIZE]; + __local float xg_sum[HIDDEN_SIZE/32]; + //# interleaving x into x2 + int id_sg = get_sub_group_id(); + int num_sg = get_num_sub_groups(); + int id_local = get_sub_group_local_id(); + half * px = x + id_sg*GROUP_SIZE; + half * px2 = x2 + id_sg*GROUP_SIZE; + unroll_for(int i = id_sg; i < HIDDEN_SIZE/GROUP_SIZE; i += num_sg, px += num_sg*GROUP_SIZE, px2 += num_sg*GROUP_SIZE) { + //# quantization group + float x_group_sum = 0; + unroll_for(int j = id_local; j < GROUP_SIZE/2; j += SUBGROUP_SIZE) { + half even = px[2*j + 0]; + half odd = px[2*j + 1]; + px2[j] = even; + px2[j + GROUP_SIZE/2] = odd; + x_group_sum += even + odd; + } + x_group_sum = sub_group_reduce_add(x_group_sum); + if (id_local == 0) { + xg_sum[i] = x_group_sum / SUBGROUP_SIZE; + } + } + barrier(CLK_LOCAL_MEM_FENCE); + + gemv_n2x(up_weight, up_scale, up_zp, x, y, INTERMEDIATE_SIZE, HIDDEN_SIZE, x2, xg_sum, false); + gemv_n2x(gate_weight, gate_scale, gate_zp, x, y, INTERMEDIATE_SIZE, HIDDEN_SIZE, x2, xg_sum, true); +} + +#elif DOWN_ENABLE +__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) +KERNEL (mlp_down)( + const __global int* expert_list, + const __global uchar* down_weight_addr, + const __global uchar* down_scale_addr, + const __global uchar* down_zp_addr, + const __global MOE_DTYPE* x, // [MAX_TOPK, INTERMEDIATE_SIZE] + __global MOE_DTYPE* routing_weights, // [MAX_TOPK] + __global MOE_DTYPE* y) { // [MAX_TOPK, HIDDEN_SIZE] + // global: [expert, SUBGROUP_SIZE, N//N_BLOCK],[1, SUBGROUP_SIZE, SUBGROUP_NUM] + int expert_no = get_global_id(0); + x += expert_no * INTERMEDIATE_SIZE; + y += expert_no * HIDDEN_SIZE; + + const int expert_wei_size = INTERMEDIATE_SIZE * HIDDEN_SIZE / 2; + const int expert_scale_size = INTERMEDIATE_SIZE * HIDDEN_SIZE * 2 / GROUP_SIZE; + const int expert_zp_size = INTERMEDIATE_SIZE * HIDDEN_SIZE / 2 / GROUP_SIZE; + int expert_id = expert_list[expert_no]; + + // down, [INTERMEDIATE_SIZE, HIDDEN_SIZE] + __global uchar* weight = (__global uchar*)(down_weight_addr + expert_id * expert_wei_size); + __global half* scales = (__global half*)(down_scale_addr + expert_id * expert_scale_size); + __global uchar* zps = (__global uchar*)(down_zp_addr + expert_id * expert_zp_size); + + int N = HIDDEN_SIZE; + int K = INTERMEDIATE_SIZE; + int num_sg = get_num_sub_groups(); + int id_sg = get_sub_group_id(); + int id_local = get_sub_group_local_id(); + + __local half x2[INTERMEDIATE_SIZE]; + __local float xg_sum[INTERMEDIATE_SIZE/32]; + + //# interleaving x into x2 + __global half * px = x + id_sg*GROUP_SIZE; + __local half * px2 = x2 + id_sg*GROUP_SIZE; + unroll_for(int i = id_sg; i < INTERMEDIATE_SIZE/GROUP_SIZE; i += num_sg, px += num_sg*GROUP_SIZE, px2 += num_sg*GROUP_SIZE) { + //# quantization group + float x_group_sum = 0; + unroll_for(int j = id_local; j < GROUP_SIZE/2; j += SUBGROUP_SIZE) { + half even = px[2*j + 0]; + half odd = px[2*j + 1]; + px2[j] = even; + px2[j + GROUP_SIZE/2] = odd; + x_group_sum += even + odd; + } + x_group_sum = sub_group_reduce_add(x_group_sum); + if (id_local == 0) { + xg_sum[i] = x_group_sum / SUBGROUP_SIZE; + } + } + barrier(CLK_LOCAL_MEM_FENCE); + + int n_start = get_global_id(2) * N_BLOCK; + int n_end = n_start + N_BLOCK; + + unroll_for (int n = n_start; n < n_end; n+=2) { + const __global uchar* B = weight + n * K / 2; + __global 
half* S = scales + n; + __global uchar* Z = zps + n / 2; + float sum_all0 = 0; + float sum_all1 = 0; + unroll_for (int gk = 0; gk < K / GROUP_SIZE; gk++, S += N, Z += N / 2) { + half s0 = S[0]; + half s1 = S[1]; + ushort z = Z[0]; + half z_hf0 = convert_half(z & 0xf); + half z_hf1 = convert_half(z >> 4); + +#if SUBGROUP_SIZE == 32 + half2 sum0; + half2 sum1; + half4 a = as_half4(intel_sub_group_block_read_us4((const __local ushort*)x2 + gk*GROUP_SIZE)); + uchar2 b = intel_sub_group_block_read_uc2((const __global uchar*)B + gk*GROUP_SIZE/2); + uchar2 b2 = intel_sub_group_block_read_uc2((const __global uchar*)(B + (K/2) + gk*GROUP_SIZE/2)); + + sum0.s0 = fma(a.s0, (convert_half(b.s0 & 0x0F)), 0); + sum0.s1 = fma(a.s1, (convert_half(b.s1 & 0x0F)), 0); + sum0.s0 = fma(a.s2, (convert_half(b.s0 >> 4)), sum0.s0); + sum0.s1 = fma(a.s3, (convert_half(b.s1 >> 4)), sum0.s1); + + sum1.s0 = fma(a.s0, (convert_half(b2.s0 & 0x0F)), 0); + sum1.s1 = fma(a.s1, (convert_half(b2.s1 & 0x0F)), 0); + sum1.s0 = fma(a.s2, (convert_half(b2.s0 >> 4)), sum1.s0); + sum1.s1 = fma(a.s3, (convert_half(b2.s1 >> 4)), sum1.s1); + + sum_all0 += (sum0[0] + sum0[1] - xg_sum[gk] * z_hf0) * s0; + sum_all1 += (sum1[0] + sum1[1] - xg_sum[gk] * z_hf1) * s1; +#else + half4 sum0; + half4 sum1; + half8 a = as_half8(intel_sub_group_block_read_us8((const __local ushort*)x2 + gk*GROUP_SIZE)); + uchar4 b = intel_sub_group_block_read_uc4((const __global uchar*)B + gk*GROUP_SIZE/2); + uchar4 b2 = intel_sub_group_block_read_uc4((const __global uchar*)(B + (K/2) + gk*GROUP_SIZE/2)); + + sum0.s0 = fma(a.s0, (convert_half(b.s0 & 0x0F)), 0); + sum0.s1 = fma(a.s1, (convert_half(b.s1 & 0x0F)), 0); + sum0.s2 = fma(a.s2, (convert_half(b.s2 & 0x0F)), 0); + sum0.s3 = fma(a.s3, (convert_half(b.s3 & 0x0F)), 0); + + sum0.s0 = fma(a.s4, (convert_half(b.s0 >> 4)), sum0.s0); + sum0.s1 = fma(a.s5, (convert_half(b.s1 >> 4)), sum0.s1); + sum0.s2 = fma(a.s6, (convert_half(b.s2 >> 4)), sum0.s2); + sum0.s3 = fma(a.s7, (convert_half(b.s3 >> 4)), sum0.s3); + + sum1.s0 = fma(a.s0, (convert_half(b2.s0 & 0x0F)), 0); + sum1.s1 = fma(a.s1, (convert_half(b2.s1 & 0x0F)), 0); + sum1.s2 = fma(a.s2, (convert_half(b2.s2 & 0x0F)), 0); + sum1.s3 = fma(a.s3, (convert_half(b2.s3 & 0x0F)), 0); + + sum1.s0 = fma(a.s4, (convert_half(b2.s0 >> 4)), sum1.s0); + sum1.s1 = fma(a.s5, (convert_half(b2.s1 >> 4)), sum1.s1); + sum1.s2 = fma(a.s6, (convert_half(b2.s2 >> 4)), sum1.s2); + sum1.s3 = fma(a.s7, (convert_half(b2.s3 >> 4)), sum1.s3); + + sum_all0 += (sum0[0] + sum0[1] + sum0[2] + sum0[3] - xg_sum[gk] * z_hf0) * s0; + sum_all1 += (sum1[0] + sum1[1] + sum1[2] + sum1[3] - xg_sum[gk] * z_hf1) * s1; +#endif + } + sum_all0 = sub_group_reduce_add(sum_all0); + sum_all1 = sub_group_reduce_add(sum_all1); + if (id_local == 0) { + y[n] = sum_all0 * routing_weights[expert_no]; + y[n+1] = sum_all1 * routing_weights[expert_no]; + } + } +} + +#else +__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) +KERNEL (mlp_reduce)(const __global MOE_DTYPE* x, // [MAX_TOPK, HIDDEN_SIZE] + __global MOE_DTYPE* y) { // [1, HIDDEN_SIZE] + int n = get_global_id(1); + half sum[MAX_TOPK] = {0}; + __attribute__((opencl_unroll_hint(MAX_TOPK))) + for (int i = 0; i < MAX_TOPK; i++) { + sum[i] = as_half(intel_sub_group_block_read_us((const __global ushort*)(x + i*HIDDEN_SIZE + n))); + } + for (int i = 1; i < MAX_TOPK; i++) { + sum[0] += sum[i]; + } + intel_sub_group_block_write_us((__global ushort*)(y + n), as_ushort(sum[0])); +} +#endif diff --git a/src/plugins/intel_gpu/src/graph/include/moe_3gemm_fused_inst.h 
b/src/plugins/intel_gpu/src/graph/include/moe_3gemm_fused_inst.h new file mode 100644 index 00000000000000..f3cf3e1fc4a7cf --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/include/moe_3gemm_fused_inst.h @@ -0,0 +1,52 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include + +#include "intel_gpu/primitives/moe_3gemm_fused_compressed.hpp" +#include "primitive_inst.h" + +namespace cldnn { +namespace details {} + +template <> +struct typed_program_node : public typed_program_node_base { +private: + using parent = typed_program_node_base; + +public: + using parent::parent; + + typed_program_node(std::shared_ptr prim, program& prog) : parent(prim, prog) {} + + using parent::get_kernel_impl_params; + std::unique_ptr get_kernel_impl_params(const std::vector& in_layouts, const std::vector& out_layouts) const override { + auto params = parent::get_kernel_impl_params(in_layouts, out_layouts); + + return params; + } +}; + +using moe_node = typed_program_node; + +template <> +class typed_primitive_inst : public typed_primitive_inst_base { + using parent = typed_primitive_inst_base; + using parent::parent; + using primitive_inst::update_output_memory; + +public: + template + static std::vector calc_output_layouts(const moe_node& /*node*/, const kernel_impl_params& impl_param); + static layout calc_output_layout(const moe_node& /* node */, const kernel_impl_params& impl_param); + static std::string to_string(const moe_node& node); + typed_primitive_inst(network& network, const moe_node& node); +}; + +using moe_inst = typed_primitive_inst; +} // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/moe_3gemm_fused.cpp b/src/plugins/intel_gpu/src/graph/moe_3gemm_fused.cpp new file mode 100644 index 00000000000000..dd177b67129c76 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/moe_3gemm_fused.cpp @@ -0,0 +1,54 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "intel_gpu/runtime/error_handler.hpp" +#include "json_object.h" +#include "moe_3gemm_fused_inst.h" +#include "openvino/core/except.hpp" +#include "openvino/core/parallel.hpp" +#include "primitive_type_base.h" +#include "program_node.h" + +namespace cldnn { +GPU_DEFINE_PRIMITIVE_TYPE_ID(moe_3gemm_fused_compressed) + +/* + Calc_output_layout method is called only when output layout is invalidated. + It means, that it is called when: + 1) It has never been called. + 2) Dependency has changed output layout. + In this both cases, we need to recalc branch_true and branch_false. + !* We can be sure, that this method was called AT LEAST once during graph compilation.*! +*/ +layout moe_inst::calc_output_layout(const moe_node& /* node */, const kernel_impl_params& impl_param) { + return impl_param.input_layouts[0]; +} + +template +std::vector moe_inst::calc_output_layouts(const moe_node& /* node */, const kernel_impl_params& impl_param) { + return {impl_param.input_layouts[0]}; +} + +template std::vector moe_inst::calc_output_layouts(const moe_node& node, const kernel_impl_params& impl_param); + +std::string moe_inst::to_string(const moe_node& node) { + auto desc = node.get_primitive(); + auto node_info = node.desc_to_json(); + json_composite moe_info; + + node_info->add("moe info", moe_info); + + std::stringstream primitive_description; + node_info->dump(primitive_description); + return primitive_description.str(); +} + +/* +moe primitive is reusing memory with the input. 
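+The fused MoE produces exactly one hidden-state row per input token, so both
+calc_output_layout() and calc_output_layouts() above simply forward
+impl_param.input_layouts[0]: e.g. an input of shape [num_tokens, hidden_size]
+yields an output with the same shape, data type and format.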
+*/ +moe_inst::typed_primitive_inst(network& network, const moe_node& node) : parent(network, node) {} + +} // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/registry/moe_3gemm_swiglu_impls.cpp b/src/plugins/intel_gpu/src/graph/registry/moe_3gemm_swiglu_impls.cpp new file mode 100644 index 00000000000000..30027f667a7d1b --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/registry/moe_3gemm_swiglu_impls.cpp @@ -0,0 +1,23 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "intel_gpu/primitives/moe_3gemm_fused_compressed.hpp" +#include "primitive_inst.h" +#include "registry.hpp" + +#if OV_GPU_WITH_OCL +# include "impls/ocl_v2/moe/moe_3gemm_swiglu_opt.hpp" +#endif + +namespace ov::intel_gpu { + +using namespace cldnn; + +const std::vector>& Registry::get_implementations() { + static const std::vector> impls = {OV_GPU_CREATE_INSTANCE_OCL(ocl::moe_3gemm_swiglu_opt, shape_types::any)}; + + return impls; +} + +} // namespace ov::intel_gpu diff --git a/src/plugins/intel_gpu/src/graph/registry/registry.hpp b/src/plugins/intel_gpu/src/graph/registry/registry.hpp index 910c3a07292db4..d703831f6b0c1e 100644 --- a/src/plugins/intel_gpu/src/graph/registry/registry.hpp +++ b/src/plugins/intel_gpu/src/graph/registry/registry.hpp @@ -166,6 +166,7 @@ REGISTER_IMPLS(strided_slice); REGISTER_IMPLS(tile); REGISTER_IMPLS(col2im); REGISTER_IMPLS(vl_sdpa); +REGISTER_IMPLS(moe_3gemm_fused_compressed); REGISTER_IMPLS(moe_mask_gen); REGISTER_IMPLS(moe_mask_gen_reshape); REGISTER_IMPLS(moe_gemm); diff --git a/src/plugins/intel_gpu/src/plugin/ops/moe.cpp b/src/plugins/intel_gpu/src/plugin/ops/moe.cpp index 5c4cd75c7a46ba..f3fd6f46cb240a 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/moe.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/moe.cpp @@ -5,7 +5,10 @@ #include "openvino/op/moe.hpp" #include "intel_gpu/op/moe_compressed.hpp" #include "intel_gpu/plugin/program_builder.hpp" - +#include "intel_gpu/op/moe_3gemm_fused_compressed.hpp" +#include "intel_gpu/plugin/common_utils.hpp" +#include "intel_gpu/plugin/program_builder.hpp" +#include "intel_gpu/primitives/moe_3gemm_fused_compressed.hpp" #include "intel_gpu/primitives/moe_gemm.hpp" #include "intel_gpu/primitives/moe_mask_gen.hpp" #include @@ -18,14 +21,46 @@ namespace ov { namespace op { namespace internal { +using MOE3GemmFusedCompressed = ov::intel_gpu::op::MOE3GemmFusedCompressed; using MOECompressed = ov::intel_gpu::op::MOECompressed; } // namespace internal } // namespace op } // namespace ov - namespace ov::intel_gpu { using namespace cldnn; + +static void CreateMOE3GemmFusedCompressedOp(ProgramBuilder& p, const std::shared_ptr& op) { + auto inputs = p.GetInputInfo(op); + const auto& config = op->get_config(); + /// 0: hidden_states - input tensor with hidden representations + /// 1: routing_weights - [num_seq, num_experts] routing weights for all experts + /// 2: w0_weight - expert weights for first projection, + /// shape [num_experts, inter_size, group_num, group_size] + /// 3: w0_scale - expert scale for first projection for compressed experts, + /// shape [num_experts, inter_size, group_num, 1] + /// 4: w0_zp - expert zp for first projection for compressed experts, + /// shape [num_experts, inter_size, group_num, 1] + /// 5: w1_weight - expert weights for second projection, + /// shape [num_experts, inter_size, group_num, group_size] + /// 6: w1_scale - expert scale for second projection for compressed experts, + /// shape [num_experts, inter_size, group_num, 1] + /// 7: w1_zp - expert 
zp for second projection for compressed experts, + /// shape [num_experts, inter_size, group_num, 1] + /// 8: w2_weight - expert weights for final projection, + /// shape [num_experts, hidden_size, group_num, group_size] + /// 9: w2_scale - expert scale for final projection for compressed experts, + /// shape [num_experts, hidden_size, group_num, 1] + /// 10: w2_zp - expert zp for final projection for compressed experts, + /// shape [num_experts, hidden_size, group_num, 1] + validate_inputs_count(op, {11}); + + const std::string layerName = layer_type_name_ID(op); + const cldnn::moe_3gemm_fused_compressed moe(layerName, inputs, config); + + p.add_primitive(*op, moe); +} + static void CreateMOECompressedOp(ProgramBuilder& p, const std::shared_ptr& op) { auto inputs = p.GetInputInfo(op); auto& config = op->get_config(); @@ -35,6 +70,30 @@ static void CreateMOECompressedOp(ProgramBuilder& p, const std::shared_ptr