From 6f13594d4e1e483c53108a4c66941e0293c3ec61 Mon Sep 17 00:00:00 2001 From: "River, Li" Date: Mon, 20 Oct 2025 10:15:43 +0800 Subject: [PATCH 01/13] qwen3 moe_compressed primitive_impl --- .../intel_gpu/plugin/primitives_list.hpp | 1 + .../intel_gpu/primitives/moe_compressed.hpp | 54 + .../src/graph/impls/ocl_v2/moe_mlp.cl | 317 +++++ .../src/graph/impls/ocl_v2/moe_opt.cl | 127 ++ .../src/graph/impls/ocl_v2/moe_opt.cpp | 1132 +++++++++++++++++ .../src/graph/impls/ocl_v2/moe_opt.hpp | 63 + .../intel_gpu/src/graph/include/moe_inst.h | 52 + src/plugins/intel_gpu/src/graph/moe.cpp | 55 + .../src/graph/registry/moe_impls.cpp | 26 + .../intel_gpu/src/graph/registry/registry.hpp | 1 + src/plugins/intel_gpu/src/plugin/ops/moe.cpp | 27 + 11 files changed, 1855 insertions(+) create mode 100644 src/plugins/intel_gpu/include/intel_gpu/primitives/moe_compressed.hpp create mode 100644 src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_mlp.cl create mode 100644 src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cl create mode 100644 src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp create mode 100644 src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.hpp create mode 100644 src/plugins/intel_gpu/src/graph/include/moe_inst.h create mode 100644 src/plugins/intel_gpu/src/graph/moe.cpp create mode 100644 src/plugins/intel_gpu/src/graph/registry/moe_impls.cpp create mode 100644 src/plugins/intel_gpu/src/plugin/ops/moe.cpp diff --git a/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp b/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp index 37734352b9f698..ff8ed815e94d45 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp @@ -311,3 +311,4 @@ REGISTER_FACTORY(internal, PagedAttentionExtension); REGISTER_FACTORY(internal, LoraSubgraph); REGISTER_FACTORY(internal, LoraSubgraphFused); REGISTER_FACTORY(internal, VLSDPA); +REGISTER_FACTORY(internal, MOECompressed); diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/moe_compressed.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/moe_compressed.hpp new file mode 100644 index 00000000000000..92b96fb9fc7da2 --- /dev/null +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/moe_compressed.hpp @@ -0,0 +1,54 @@ + // Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once +#include "intel_gpu/runtime/engine.hpp" +#include "primitive.hpp" +#include "ov_ops/moe_compressed.hpp" +#include + +namespace cldnn { +using MOECompressed = ov::op::internal::MOECompressed; + +/// @brief moe compressed primitive +/// @details Performs moe compressed +struct moe_compressed : public primitive_base { + CLDNN_DECLARE_PRIMITIVE(moe_compressed) + + moe_compressed() : primitive_base("", {}) {} + + /// @brief Constructs moe primitive / layer. + /// + /// @param id An identifier of new primitive. + /// @param inputs A list of Input primitive ids (inputs). + moe_compressed(const primitive_id& id, + const std::vector& inputs, + const MOE::Config& config) + : primitive_base(id, inputs, 15, {optional_data_type()}), + _config(config) { + } + + MOECompressed::Config _config; + + bool operator==(const primitive& rhs) const override { + if (!compare_common_params(rhs)) + return false; + + auto rhs_casted = downcast(rhs); + + return std::memcmp(&_config, &rhs_casted._config, sizeof(_config)) == 0; + } + + void save(BinaryOutputBuffer& ob) const override { + primitive_base::save(ob); + ob << make_data(&_config, sizeof(_config)); + } + + void load(BinaryInputBuffer& ib) override { + primitive_base::load(ib); + ib >> make_data(&_config, sizeof(_config)); + } +}; + +} // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_mlp.cl b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_mlp.cl new file mode 100644 index 00000000000000..f5cd1297ff9e0e --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_mlp.cl @@ -0,0 +1,317 @@ + +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#if GATE_UP_ENABLE +inline void gemv_n2x(const __global uchar* weight, + __global half* scales, + __global uchar* zps, + const __global half* x, + __global half* y, int N, int K, + half* x2, + float* xg_sum, + const bool silu) { + int num_sg = get_num_sub_groups(); + int id_sg = get_sub_group_id(); + int id_local = get_sub_group_local_id(); + + //# interleaving x into x2 + half * px = x + id_sg*GROUP_SIZE; + half * px2 = x2 + id_sg*GROUP_SIZE; + for(int i = id_sg; i < HIDDEN_SIZE/GROUP_SIZE; i += num_sg, px += num_sg*GROUP_SIZE, px2 += num_sg*GROUP_SIZE) { + //# quantization group + float x_group_sum = 0; + for(int j = id_local; j < GROUP_SIZE/2; j += SUBGROUP_SIZE) { + half even = px[2*j + 0]; + half odd = px[2*j + 1]; + px2[j] = even; + px2[j + GROUP_SIZE/2] = odd; + x_group_sum += even + odd; + } + x_group_sum = sub_group_reduce_add(x_group_sum); + if (id_local == 0) { + xg_sum[i] = x_group_sum / SUBGROUP_SIZE; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + int n_start = get_global_id(2) * N_BLOCK; + int n_end = n_start + N_BLOCK; + + for (int n = n_start; n < n_end; n+=2) { + const __global uchar* B = weight + n * K / 2; + float sum_all0 = 0; + float sum_all1 = 0; +#if SZ_LAYOUT == 0 + __global half* S = scales + n; + __global uchar* Z = zps + n / 2; + for (int gk = 0; gk < K / GROUP_SIZE; gk++, S += N, Z += N / 2) { + half s0 = S[0]; + half s1 = S[1]; + ushort z = Z[0]; + half z_hf0 = convert_half(z & 0xf); + half z_hf1 = convert_half(z >> 4); +#else + __global half* S = scales + n*K/GROUP_SIZE; + __global uchar* Z = zps + n*K/GROUP_SIZE/2; + + half scale_values = as_half(intel_sub_group_block_read_us((const __global ushort*)S)); + uchar zp_values = intel_sub_group_block_read_uc((const __global uchar*)Z); + half zp_even = convert_half(zp_values & 0xF); + half zp_odd = convert_half(zp_values >> 4); + + for (int gk = 0; gk < K / GROUP_SIZE; gk++) { + half s0 = sub_group_broadcast(scale_values, 2*gk + 0); + half s1 = sub_group_broadcast(scale_values, 2*gk + 1); + half z_hf0 = sub_group_broadcast(zp_even, gk); + half z_hf1 = sub_group_broadcast(zp_odd, gk); +#endif + +#if SUBGROUP_SIZE == 32 + half2 sum0; + half2 sum1; + half4 a = as_half4(intel_sub_group_block_read_us4((const __local ushort*)x2 + gk*GROUP_SIZE)); + uchar2 b = intel_sub_group_block_read_uc2((const __global uchar*)B + gk*GROUP_SIZE/2); + uchar2 b2 = intel_sub_group_block_read_uc2((const __global uchar*)(B + (K/2) + gk*GROUP_SIZE/2)); + + sum0.s0 = fma(a.s0, (convert_half(b.s0 & 0x0F)), 0); + sum0.s1 = fma(a.s1, (convert_half(b.s1 & 0x0F)), 0); + sum0.s0 = fma(a.s2, (convert_half(b.s0 >> 4)), sum0.s0); + sum0.s1 = fma(a.s3, (convert_half(b.s1 >> 4)), sum0.s1); + + sum1.s0 = fma(a.s0, (convert_half(b2.s0 & 0x0F)), 0); + sum1.s1 = fma(a.s1, (convert_half(b2.s1 & 0x0F)), 0); + sum1.s0 = fma(a.s2, (convert_half(b2.s0 >> 4)), sum1.s0); + sum1.s1 = fma(a.s3, (convert_half(b2.s1 >> 4)), sum1.s1); + + sum_all0 += (sum0[0] + sum0[1] - xg_sum[gk] * z_hf0) * s0; + sum_all1 += (sum1[0] + sum1[1] - xg_sum[gk] * z_hf1) * s1; +#else + half4 sum0; + half4 sum1; + half8 a = as_half8(intel_sub_group_block_read_us8((const __local ushort*)x2 + gk*GROUP_SIZE)); + uchar4 b = intel_sub_group_block_read_uc4((const __global uchar*)B + gk*GROUP_SIZE/2); + uchar4 b2 = intel_sub_group_block_read_uc4((const __global uchar*)(B + (K/2) + gk*GROUP_SIZE/2)); + + sum0.s0 = fma(a.s0, (convert_half(b.s0 & 0x0F)), 0); + sum0.s1 = fma(a.s1, (convert_half(b.s1 & 0x0F)), 0); + sum0.s2 = fma(a.s2, (convert_half(b.s2 & 0x0F)), 0); + sum0.s3 = fma(a.s3, (convert_half(b.s3 & 0x0F)), 0); + + sum0.s0 = fma(a.s4, (convert_half(b.s0 >> 4)), sum0.s0); + sum0.s1 = fma(a.s5, (convert_half(b.s1 >> 4)), sum0.s1); + sum0.s2 = fma(a.s6, (convert_half(b.s2 >> 4)), sum0.s2); + sum0.s3 = fma(a.s7, (convert_half(b.s3 >> 4)), sum0.s3); + + sum1.s0 = fma(a.s0, (convert_half(b2.s0 & 0x0F)), 0); + sum1.s1 = fma(a.s1, (convert_half(b2.s1 & 0x0F)), 0); + sum1.s2 = fma(a.s2, (convert_half(b2.s2 & 0x0F)), 0); + sum1.s3 = fma(a.s3, (convert_half(b2.s3 & 0x0F)), 0); + + sum1.s0 = fma(a.s4, (convert_half(b2.s0 >> 4)), sum1.s0); + sum1.s1 = fma(a.s5, (convert_half(b2.s1 >> 4)), sum1.s1); + sum1.s2 = fma(a.s6, (convert_half(b2.s2 >> 4)), sum1.s2); + sum1.s3 = fma(a.s7, (convert_half(b2.s3 >> 4)), sum1.s3); + + sum_all0 += (sum0[0] + sum0[1] + sum0[2] + sum0[3] - xg_sum[gk] * z_hf0) * s0; + sum_all1 += (sum1[0] + sum1[1] + sum1[2] + sum1[3] - xg_sum[gk] * z_hf1) * s1; +#endif + } + + sum_all0 = sub_group_reduce_add(sum_all0); + sum_all1 = sub_group_reduce_add(sum_all1); + if (id_local == 0) { + if (silu) { + y[n] *= sum_all0 / (1 + exp(-sum_all0)); + y[n+1] *= sum_all1 / (1 + exp(-sum_all1)); + } else { + y[n] = sum_all0; + y[n+1] = sum_all1; + } + } + } +} + +__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) +__kernel void mlp_gate_up( + const __global int* expert_list, + const __global uchar* gate_weight_addr, + const __global uchar* gate_scale_addr, + const __global uchar* gate_zp_addr, + const __global uchar* up_weight_addr, + const __global uchar* up_scale_addr, + const __global uchar* up_zp_addr, + __global TYPE* x, // [1, HIDDEN_SIZE] + __global TYPE* y) { // [MAX_TOPK, INTERMEDIATE_SIZE] + // global: [expert, SUBGROUP_SIZE, N//N_BLOCK],[1, SUBGROUP_SIZE, SUBGROUP_NUM] + int expert_no = get_global_id(0); + y += expert_no * INTERMEDIATE_SIZE; + + const int expert_wei_size = INTERMEDIATE_SIZE * HIDDEN_SIZE / 2; + const int expert_scale_size = INTERMEDIATE_SIZE * HIDDEN_SIZE * 2 / GROUP_SIZE; + const int expert_zp_size = INTERMEDIATE_SIZE * HIDDEN_SIZE / 2 / GROUP_SIZE; + int expert_id = expert_list[expert_no]; + + // gate, [HIDDEN_SIZE, INTERMEDIATE_SIZE] + __global uchar* gate_weight = (__global uchar*)(gate_weight_addr + expert_id * expert_wei_size); + __global half* gate_scale = (__global half*)(gate_scale_addr + expert_id * expert_scale_size); + __global uchar* gate_zp = (__global uchar*)(gate_zp_addr + expert_id * expert_zp_size); + + // up, [HIDDEN_SIZE, INTERMEDIATE_SIZE] + __global uchar* up_weight = (__global uchar*)(up_weight_addr + expert_id * expert_wei_size); + __global half* up_scale = (__global half*)(up_scale_addr + expert_id * expert_scale_size); + __global uchar* up_zp = (__global uchar*)(up_zp_addr + expert_id * expert_zp_size); + + __local half x2[HIDDEN_SIZE]; + __local float xg_sum[HIDDEN_SIZE/32]; + gemv_n2x(up_weight, up_scale, up_zp, x, y, INTERMEDIATE_SIZE, HIDDEN_SIZE, x2, xg_sum, false); + gemv_n2x(gate_weight, gate_scale, gate_zp, x, y, INTERMEDIATE_SIZE, HIDDEN_SIZE, x2, xg_sum, true); +} + +#elif DOWN_ENABLE +__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) +__kernel void mlp_down( + const __global int* expert_list, + const __global uchar* down_weight_addr, + const __global uchar* down_scale_addr, + const __global uchar* down_zp_addr, + const __global TYPE* x, // [MAX_TOPK, INTERMEDIATE_SIZE] + __global TYPE* routing_weights, // [MAX_TOPK] + __global TYPE* y) { // [MAX_TOPK, HIDDEN_SIZE] + // global: [expert, SUBGROUP_SIZE, N//N_BLOCK],[1, SUBGROUP_SIZE, SUBGROUP_NUM] + int expert_no = get_global_id(0); + x += expert_no * INTERMEDIATE_SIZE; + y += expert_no * HIDDEN_SIZE; + + const int expert_wei_size = INTERMEDIATE_SIZE * HIDDEN_SIZE / 2; + const int expert_scale_size = INTERMEDIATE_SIZE * HIDDEN_SIZE * 2 / GROUP_SIZE; + const int expert_zp_size = INTERMEDIATE_SIZE * HIDDEN_SIZE / 2 / GROUP_SIZE; + int expert_id = expert_list[expert_no]; + + // down, [INTERMEDIATE_SIZE, HIDDEN_SIZE] + __global uchar* weight = (__global uchar*)(down_weight_addr + expert_id * expert_wei_size); + __global half* scales = (__global half*)(down_scale_addr + expert_id * expert_scale_size); + __global uchar* zps = (__global uchar*)(down_zp_addr + expert_id * expert_zp_size); + + int N = HIDDEN_SIZE; + int K = INTERMEDIATE_SIZE; + int num_sg = get_num_sub_groups(); + int id_sg = get_sub_group_id(); + int id_local = get_sub_group_local_id(); + + __local half x2[INTERMEDIATE_SIZE]; + __local float xg_sum[INTERMEDIATE_SIZE/32]; + + //# interleaving x into x2 + __global half * px = x + id_sg*GROUP_SIZE; + __local half * px2 = x2 + id_sg*GROUP_SIZE; + for(int i = id_sg; i < INTERMEDIATE_SIZE/GROUP_SIZE; i += num_sg, px += num_sg*GROUP_SIZE, px2 += num_sg*GROUP_SIZE) { + //# quantization group + float x_group_sum = 0; + for(int j = id_local; j < GROUP_SIZE/2; j += SUBGROUP_SIZE) { + half even = px[2*j + 0]; + half odd = px[2*j + 1]; + px2[j] = even; + px2[j + GROUP_SIZE/2] = odd; + x_group_sum += even + odd; + } + x_group_sum = sub_group_reduce_add(x_group_sum); + if (id_local == 0) { + xg_sum[i] = x_group_sum / SUBGROUP_SIZE; + } + } + barrier(CLK_LOCAL_MEM_FENCE); + + int n_start = get_global_id(2) * N_BLOCK; + int n_end = n_start + N_BLOCK; + + for (int n = n_start; n < n_end; n+=2) { + const __global uchar* B = weight + n * K / 2; + __global half* S = scales + n; + __global uchar* Z = zps + n / 2; + float sum_all0 = 0; + float sum_all1 = 0; + for (int gk = 0; gk < K / GROUP_SIZE; gk++, S += N, Z += N / 2) { + half s0 = S[0]; + half s1 = S[1]; + ushort z = Z[0]; + half z_hf0 = convert_half(z & 0xf); + half z_hf1 = convert_half(z >> 4); + +#if SUBGROUP_SIZE == 32 + half2 sum0; + half2 sum1; + half4 a = as_half4(intel_sub_group_block_read_us4((const __local ushort*)x2 + gk*GROUP_SIZE)); + uchar2 b = intel_sub_group_block_read_uc2((const __global uchar*)B + gk*GROUP_SIZE/2); + uchar2 b2 = intel_sub_group_block_read_uc2((const __global uchar*)(B + (K/2) + gk*GROUP_SIZE/2)); + + sum0.s0 = fma(a.s0, (convert_half(b.s0 & 0x0F)), 0); + sum0.s1 = fma(a.s1, (convert_half(b.s1 & 0x0F)), 0); + sum0.s0 = fma(a.s2, (convert_half(b.s0 >> 4)), sum0.s0); + sum0.s1 = fma(a.s3, (convert_half(b.s1 >> 4)), sum0.s1); + + sum1.s0 = fma(a.s0, (convert_half(b2.s0 & 0x0F)), 0); + sum1.s1 = fma(a.s1, (convert_half(b2.s1 & 0x0F)), 0); + sum1.s0 = fma(a.s2, (convert_half(b2.s0 >> 4)), sum1.s0); + sum1.s1 = fma(a.s3, (convert_half(b2.s1 >> 4)), sum1.s1); + + sum_all0 += (sum0[0] + sum0[1] - xg_sum[gk] * z_hf0) * s0; + sum_all1 += (sum1[0] + sum1[1] - xg_sum[gk] * z_hf1) * s1; +#else + half4 sum0; + half4 sum1; + half8 a = as_half8(intel_sub_group_block_read_us8((const __local ushort*)x2 + gk*GROUP_SIZE)); + uchar4 b = intel_sub_group_block_read_uc4((const __global uchar*)B + gk*GROUP_SIZE/2); + uchar4 b2 = intel_sub_group_block_read_uc4((const __global uchar*)(B + (K/2) + gk*GROUP_SIZE/2)); + + sum0.s0 = fma(a.s0, (convert_half(b.s0 & 0x0F)), 0); + sum0.s1 = fma(a.s1, (convert_half(b.s1 & 0x0F)), 0); + sum0.s2 = fma(a.s2, (convert_half(b.s2 & 0x0F)), 0); + sum0.s3 = fma(a.s3, (convert_half(b.s3 & 0x0F)), 0); + + sum0.s0 = fma(a.s4, (convert_half(b.s0 >> 4)), sum0.s0); + sum0.s1 = fma(a.s5, (convert_half(b.s1 >> 4)), sum0.s1); + sum0.s2 = fma(a.s6, (convert_half(b.s2 >> 4)), sum0.s2); + sum0.s3 = fma(a.s7, (convert_half(b.s3 >> 4)), sum0.s3); + + sum1.s0 = fma(a.s0, (convert_half(b2.s0 & 0x0F)), 0); + sum1.s1 = fma(a.s1, (convert_half(b2.s1 & 0x0F)), 0); + sum1.s2 = fma(a.s2, (convert_half(b2.s2 & 0x0F)), 0); + sum1.s3 = fma(a.s3, (convert_half(b2.s3 & 0x0F)), 0); + + sum1.s0 = fma(a.s4, (convert_half(b2.s0 >> 4)), sum1.s0); + sum1.s1 = fma(a.s5, (convert_half(b2.s1 >> 4)), sum1.s1); + sum1.s2 = fma(a.s6, (convert_half(b2.s2 >> 4)), sum1.s2); + sum1.s3 = fma(a.s7, (convert_half(b2.s3 >> 4)), sum1.s3); + + sum_all0 += (sum0[0] + sum0[1] + sum0[2] + sum0[3] - xg_sum[gk] * z_hf0) * s0; + sum_all1 += (sum1[0] + sum1[1] + sum1[2] + sum1[3] - xg_sum[gk] * z_hf1) * s1; +#endif + } + sum_all0 = sub_group_reduce_add(sum_all0); + sum_all1 = sub_group_reduce_add(sum_all1); + if (id_local == 0) { + y[n] = sum_all0 * routing_weights[expert_no]; + y[n+1] = sum_all1 * routing_weights[expert_no]; + } + } +} + +#else +__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) +__kernel void mlp_reduce(const __global TYPE* x, // [MAX_TOPK, HIDDEN_SIZE] + __global TYPE* y) { // [1, HIDDEN_SIZE] + int n = get_global_id(1); + half sum[MAX_TOPK] = {0}; + __attribute__((opencl_unroll_hint(MAX_TOPK))) + for (int i = 0; i < MAX_TOPK; i++) { + sum[i] = as_half(intel_sub_group_block_read_us((const __global ushort*)(x + i*HIDDEN_SIZE + n))); + } + for (int i = 1; i < MAX_TOPK; i++) { + sum[0] += sum[i]; + } + intel_sub_group_block_write_us((const __global ushort*)(y + n), as_ushort(sum[0])); +} +#endif diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cl b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cl new file mode 100644 index 00000000000000..3f7837aae82e88 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cl @@ -0,0 +1,127 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#if SOFTMAX_TOPK_ENABLE + +__kernel void softmax_topk( + const __global TYPE* input, // [input_batch, sort_in_num] + __global uint* output_index, // [input_batch, TOP_K] + __global TYPE* output // [input_batch, TOP_K] +) { + // gws [batch, sort_in_num] + const uint batch = (uint)get_global_id(0); + const uint sort_index = (uint)get_global_id(1); + const uint sort_cnt = (uint)get_global_size(1); + + input += batch * sort_cnt + sort_index; + + uint sort_position = 0; + + __local TYPE local_input[VALUE_NUM]; + __local TYPE local_output[TOP_K]; + __local uint local_index[TOP_K]; + + TYPE in_value = as_half(intel_sub_group_block_read_us((const __global ushort*)(input))); + local_input[sort_index] = in_value; + barrier(CLK_LOCAL_MEM_FENCE); + + __attribute__((opencl_unroll_hint(8))) + for(uint i = 0; i < sort_index; i++) { + TYPE value = local_input[i]; + if(value >= in_value) { + sort_position++; + } + } + + __attribute__((opencl_unroll_hint(8))) + for(uint i = sort_index; i < sort_cnt; i++) { + TYPE value = local_input[i]; + if(value > in_value) { + sort_position++; + } + } + if (sort_position < TOP_K) { + local_output[sort_position] = in_value; + local_index[sort_position] = sort_index; + } + barrier(CLK_LOCAL_MEM_FENCE); + + if(sort_position == 0) { + float softmax_total = 1.0; + TYPE max_v = local_output[0]; + local_output[0] = 1; + for(uint i = 1; i < TOP_K; i++) { + local_output[i] = native_exp(local_output[i] - max_v); + softmax_total += local_output[i]; + } + output_index += batch * TOP_K; + output += batch * TOP_K; + + for(uint i = 0; i < TOP_K; i++) { + output[i] = local_output[i]/softmax_total; + output_index[i] = local_index[i]; + } + } +} + +#elif GATHER_ENABLE +__kernel void gather_2d_ref( + const __global TYPE* src_tok, + const __global TYPE* src_rweight, + __global int * tok_index, + __global int * top_index, + __global TYPE* dst_tok, + __global TYPE* dst_rweight) { + + int k = get_global_id(0); + int off = get_global_id(1); + int tok_idx = tok_index[k]; + + src_tok += tok_idx * HIDDEN_SIZE; + dst_tok += k * HIDDEN_SIZE; + + #if TYPE_SIZE == 2 + ushort value = intel_sub_group_block_read_us((const __global ushort *)(src_tok + off)); + intel_sub_group_block_write_us((__global ushort *)(dst_tok + off), value); + #elif TYPE_SIZE == 4 + uint value = intel_sub_group_block_read((const __global uint *)(src_tok + off)); + intel_sub_group_block_write((__global uint *)(dst_tok + off), value); + #else + dst_tok[off] = src_tok[off]; + #endif + + if (off == 0) { + int top_idx = top_index[k]; + dst_rweight[k] = src_rweight[top_idx]; + } +} + +#elif SCATTER_ENABLE + +__kernel void index_add_(const __global TYPE* src_tok, + __global int * tok_index, + __global TYPE* dst_tok) { + + int k = get_global_id(0); + int off = get_global_id(1); + int tok_idx = tok_index[k]; + + src_tok += k * HIDDEN_SIZE; + dst_tok += tok_idx * HIDDEN_SIZE; + + #if TYPE_SIZE == 2 + half src_value = as_half(intel_sub_group_block_read_us((const __global ushort *)(src_tok + off))); + half dst_value = as_half(intel_sub_group_block_read_us((const __global ushort *)(dst_tok + off))); + half value = dst_value + src_value; + intel_sub_group_block_write_us((__global ushort *)(dst_tok + off), as_ushort(value)); + #elif TYPE_SIZE == 4 + float src_value = as_float(intel_sub_group_block_read((const __global uint *)(src_tok + off))); + float dst_value = as_float(intel_sub_group_block_read((const __global uint *)(dst_tok + off))); + float value = dst_value + src_value; + intel_sub_group_block_write_us((__global ushort *)(dst_tok + off), as_uint(value)); + #else + dst_tok[off] += src_tok[off]; + #endif +} +#endif diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp new file mode 100644 index 00000000000000..615f6ff4dde284 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp @@ -0,0 +1,1132 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "moe_opt.hpp" + +#define ENABLE_ONEDNN_FOR_GPU +#ifdef ENABLE_ONEDNN_FOR_GPU +# include +# include +# include +# include +# include +# include +# include + +# include "cm/utils/kernel_generator.hpp" +# include "common_utils/jitter.hpp" +# include "debug_helper.hpp" +# include "intel_gpu/graph/kernel_impl_params.hpp" +# include "intel_gpu/primitives/moe.hpp" +# include "intel_gpu/runtime/lru_cache.hpp" +# include "intel_gpu/runtime/stream.hpp" +# include "intel_gpu/runtime/utils.hpp" +# include "moe_inst.h" +# include "ocl_v2/utils/fused_ops_jitter.hpp" +# include "ocl_v2/utils/jitter.hpp" +# include "primitive_inst.h" +# include "primitive_ocl_base.hpp" +# include "utils/kernel_generator.hpp" + +namespace ov::intel_gpu::ocl { + +namespace { + +using namespace ov::intel_gpu::ocl; + +dnnl::memory::data_type convert_data_type(cldnn::data_types dt) { + switch (dt) { + case cldnn::data_types::f32: + return dnnl::memory::data_type::f32; + case cldnn::data_types::f16: + return dnnl::memory::data_type::f16; + case cldnn::data_types::i8: + return dnnl::memory::data_type::s8; + case cldnn::data_types::u8: + return dnnl::memory::data_type::u8; + case cldnn::data_types::i32: + return dnnl::memory::data_type::s32; + case cldnn::data_types::i4: + return dnnl::memory::data_type::s4; + case cldnn::data_types::u4: + return dnnl::memory::data_type::u4; + default: + throw std::invalid_argument("[clDNN] Unsupported conversion from cldnn to onednn type"); + } +} + +struct onednn_matmul { + dnnl::matmul m_prim; + dnnl::memory::desc m_wei_md; + dnnl::memory::data_type m_w_type; + dnnl::memory::data_type m_a_type; // activation dtype + dnnl::memory::dim m_K; + dnnl::memory::dim m_N; + dnnl::memory::dim m_M; + dnnl::memory::dim m_K_groups; + + dnnl::primitive_attr attr; + dnnl::post_ops postops; + + onednn_matmul(dnnl::memory::data_type act_dtype, dnnl::memory::data_type weight_dtype, int batch_size, int ic, int oc, int ic_group_size = -1) { + m_a_type = act_dtype; + m_w_type = weight_dtype; + m_K_groups = 0; + m_K = ic; + m_N = oc; + m_M = DNNL_RUNTIME_DIM_VAL; + if (batch_size > 0) { + // jit-gemm kernel only support static batch size + m_M = batch_size; + } + if (ic_group_size >= 0) { + w_scale(ic_group_size).w_zp(ic_group_size).fpmath_f16(); + } + } + + onednn_matmul& w_scale(int k_group_size) { + if (k_group_size <= 0) { + m_K_groups = 1; + // per-OC, no grouping in K dimension + attr.set_scales(DNNL_ARG_WEIGHTS, (0 << 0) + (1 << 1), {1}, dnnl::memory::data_type::f16); + } else { + OPENVINO_ASSERT((k_group_size % 32) == 0); + OPENVINO_ASSERT((m_K % k_group_size) == 0); + m_K_groups = m_K / k_group_size; + attr.set_scales(DNNL_ARG_WEIGHTS, (1 << 0) + (1 << 1), {k_group_size, 1}, dnnl::memory::data_type::f16); + } + return *this; + } + + onednn_matmul& w_zp(int k_group_size) { + if (k_group_size <= 0) { + OPENVINO_ASSERT(m_K_groups == 1); + attr.set_zero_points(DNNL_ARG_WEIGHTS, (0 << 0) + (1 << 1), {1}, m_w_type); + } else { + OPENVINO_ASSERT(m_K_groups = (m_K / k_group_size)); + attr.set_zero_points(DNNL_ARG_WEIGHTS, (1 << 0) + (1 << 1), {k_group_size, 1}, m_w_type); + } + return *this; + } + + onednn_matmul& fpmath_f16() { + attr.set_fpmath_mode(dnnl::fpmath_mode::f16, true); + return *this; + } + onednn_matmul& post_op_silu() { + float alpha = 1.0f; + float beta = 0.0f; + postops.append_eltwise(dnnl::algorithm::eltwise_swish, alpha, beta); + return *this; + } + onednn_matmul& post_op_bin_mul(bool per_oc = true) { + dnnl::memory::dim batch_size = m_M; + if (batch_size == DNNL_RUNTIME_DIM_VAL) + batch_size = 1024 * 1024; // big enough fake static batch + + dnnl::memory::desc bin_mul_md = dnnl::memory::desc(dnnl::memory::dims({batch_size, per_oc ? m_N : 1}), m_a_type, dnnl::memory::format_tag::ab); + postops.append_binary(dnnl::algorithm::binary_mul, bin_mul_md); + return *this; + } + + onednn_matmul& post_op_sum(float scale = 1.f, int32_t zero_point = 0) { + postops.append_sum(scale, zero_point, dnnl::memory::data_type::undef); + return *this; + } + + void create(dnnl::engine eng) { + if (postops.len() > 0) { + attr.set_post_ops(postops); + } + + dnnl::memory::desc src_md = dnnl::memory::desc(dnnl::memory::dims({m_M, m_K}), m_a_type, dnnl::memory::format_tag::ab); + dnnl::memory::desc dst_md = dnnl::memory::desc(dnnl::memory::dims({m_M, m_N}), m_a_type, dnnl::memory::format_tag::ab); + + // use fixed weight-layout to prevent shape-dependent weight-layout changes + dnnl::memory::desc wei_md = dnnl::memory::desc(dnnl::memory::dims({m_K, m_N}), m_w_type, dnnl::memory::format_tag::ba); + + // Create primitive descriptor. + auto matmul_pd = dnnl::matmul::primitive_desc(eng, src_md, wei_md, dst_md, attr); + + // Pre-packed weights stored as int8_t + m_wei_md = matmul_pd.weights_desc(); + + // Create the primitive. + m_prim = dnnl::matmul(matmul_pd); + } + + // this creator is for predefined matmul primitive types + enum class type { + none, + with_bin_mul, + with_bin_mul_per_row, + with_bin_mul_per_row_sum, + with_silu, + with_silu_bin_mul, + }; + int bin_post_id = -1; + bool bin_per_row = false; + onednn_matmul(dnnl::engine eng, + dnnl::memory::data_type act_dtype, + dnnl::memory::data_type weight_dtype, + int batch, + int ic, + int oc, + int ic_group_size, + type t) + : onednn_matmul(act_dtype, weight_dtype, batch, ic, oc, ic_group_size) { + if (t == type::with_bin_mul) { + bin_post_id = 0; + post_op_bin_mul(true); + } + if (t == type::with_bin_mul_per_row) { + bin_post_id = 0; + bin_per_row = true; + post_op_bin_mul(false); + } + if (t == type::with_bin_mul_per_row_sum) { + bin_post_id = 0; + bin_per_row = true; + post_op_bin_mul(false); + post_op_sum(); + } + if (t == type::with_silu) + post_op_silu(); + if (t == type::with_silu_bin_mul) { + bin_post_id = 1; + post_op_silu(); + post_op_bin_mul(true); + } + + create(eng); + } +}; + +// all jit-based/performance-aware function should be a functor/callable because: +// - it needs to hold reference to kernel (to save build time & resources) +// - it needs to do other compile time preparation work and hold the relevant +// runtime-data-struct (to make runtime faster) +// to optimize compile-time-workload itself, the functor instance itself should be +// cached with compile-time parameter as the key. +// +// because it's a functor, which supposed to have no states, so cache-factory should +// always return shared_ptr to constant object, so it won't behave differently when being +// called by different caller, and this also ensure it's multi-threading safe since it +// won't modify it's content. +// +template +class tuple_hasher { +private: + typedef std::tuple Tuple; + template + size_t hash(Tuple& value) const { + return 0; + } + template + size_t hash(Tuple& value) const { + constexpr int Index = N - sizeof...(TTail) - 1; + return std::hash()(std::get(value)) ^ hash(value); + } + +public: + size_t operator()(Tuple value) const { + auto hv = hash(value); + return hv; + } +}; + +// create const object with internal cache with constructor-args as the key +// this helps reduces construction time overhead, and perfectly suitable +// for caching functor/callable. +template +std::shared_ptr make_cacheable(dnnl::engine eng, CArgs... cargs) { + std::shared_ptr sptr; + auto key = std::make_tuple(cargs...); + static std::unordered_map, tuple_hasher> cache; + static std::mutex mutex; + std::lock_guard guard(mutex); + auto it = cache.find(key); + if (it != cache.end()) { + auto& wptr = it->second; + sptr = wptr.lock(); + if (!sptr) { + sptr = std::make_shared(eng, cargs...); + // ECOUT("make_cacheable re-constructed: ", typeid(T).name(), "(", cargs..., ")"); + wptr = sptr; + } + } else { + sptr = std::make_shared(eng, cargs...); + // ECOUT("make_cacheable constructed: ", typeid(T).name(), "(", cargs..., ")"); + cache.emplace(std::make_pair(key, std::weak_ptr(sptr))); + } + return sptr; +} + +struct onednn_linear { + std::shared_ptr mm; + dnnl::memory weight; + dnnl::memory scale; + dnnl::memory zp; + dnnl::matmul m_prim; + dnnl::memory::dim m_K; + dnnl::memory::dim m_N; + dnnl::memory::dim m_batch; + dnnl::memory::data_type m_a_type; + int bin_post_id; + + static onednn_linear create(dnnl::engine eng, + dnnl::memory::data_type act_dtype, + dnnl::memory::data_type weight_dtype, + int batch, + int ic, + int oc, + int ic_group_size, + onednn_matmul::type t, + dnnl::memory weight, // external weight + dnnl::memory scale, + dnnl::memory zp) { + OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("onednn_linear::create()")); + auto mm = make_cacheable(eng, act_dtype, weight_dtype, batch, ic, oc, ic_group_size, t); + onednn_linear linear; + linear.mm = mm; + linear.bin_post_id = mm->bin_post_id; + linear.m_prim = mm->m_prim; + linear.m_K = mm->m_K; + linear.m_N = mm->m_N; + linear.m_batch = batch; + linear.m_a_type = mm->m_a_type; + linear.weight = weight; + + if (scale) { + // https://uxlfoundation.github.io/oneDNN/page_weights_decompression_matmul_cpp.html + // Quantization Group size for scales. Must be divisible by 32. + auto wei_scale_md = dnnl::memory::desc(dnnl::memory::dims({mm->m_K_groups, mm->m_N}), dnnl::memory::data_type::f16, dnnl::memory::format_tag::ab); + linear.scale = scale; // dnnl::ocl_interop::make_memory(wei_scale_md, linear.m_engine, dnnl::ocl_interop::memory_kind::usm, scale); + if (zp) { + auto wei_zp_md = dnnl::memory::desc(dnnl::memory::dims({mm->m_K_groups, mm->m_N}), mm->m_w_type, dnnl::memory::format_tag::ab); + linear.zp = zp; // dnnl::ocl_interop::make_memory(wei_zp_md, linear.m_engine, dnnl::ocl_interop::memory_kind::usm, zp); + } + } + return linear; + } + + void forward(dnnl::stream& stream, int m, dnnl::memory src_mem, dnnl::memory dst_mem, dnnl::memory bin_mem) { + OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("onednn_linear::forward()")); + dnnl::memory::dim M = m; + + OPENVINO_ASSERT(m_batch == 0 || m_batch == M, "m_batch=", m_batch, " M=", M); + + dnnl::memory::desc rt_src_md = dnnl::memory::desc(dnnl::memory::dims({M, m_K}), m_a_type, dnnl::memory::format_tag::ab); + dnnl::memory::desc rt_dst_md = dnnl::memory::desc(dnnl::memory::dims({M, m_N}), m_a_type, dnnl::memory::format_tag::ab); + dnnl::memory::desc rt_bin_md; + if (mm->bin_per_row) { + rt_bin_md = dnnl::memory::desc(dnnl::memory::dims({M, 1}), m_a_type, dnnl::memory::format_tag::ab); + } else { + rt_bin_md = dnnl::memory::desc(dnnl::memory::dims({M, m_N}), m_a_type, dnnl::memory::format_tag::ab); + } + + std::unordered_map args; + args.insert({DNNL_ARG_SRC, src_mem}); + args.insert({DNNL_ARG_WEIGHTS, weight}); + // args.insert({DNNL_ARG_BIAS, bias_mem}); + args.insert({DNNL_ARG_DST, dst_mem}); + + if (scale) { + args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, scale}); + } + if (zp) { + args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, zp}); + } + if (bin_mem) { + // auto bin_mem = dnnl::ocl_interop::make_memory(rt_bin_md, m_engine, dnnl::ocl_interop::memory_kind::usm, (void *)(bin_input)); + args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(bin_post_id) | DNNL_ARG_SRC_1, bin_mem}); + } + m_prim.execute(stream, args); + } +}; + +class MOEOptSoftMaxTopK : public KernelGenerator { +public: + MOEOptSoftMaxTopK() : KernelGenerator("moe_opt", "softmax_topk") {} + +protected: + [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { + auto jit = KernelGenerator::get_jit_constants(params); + auto desc = params.typed_desc(); + jit.make("SOFTMAX_TOPK_ENABLE", 1); + jit.make("TOP_K", desc->_config.topk); + jit.make("VALUE_NUM", desc->_config.expert_num); + jit.make("TYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float"); + jit.make("TYPE_SIZE", params.get_input_layout(0).data_type == ov::element::f16 ? 2 : 4); + return jit; + } + + [[nodiscard]] Arguments get_arguments_desc(const RuntimeParams& params) const override { + Arguments args; + + return args; + } + + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override { + return DispatchDataFunc{nullptr}; + } +}; + +class MOEOptGather : public KernelGenerator { +public: + MOEOptGather() : KernelGenerator("moe_opt", "gather") {} + +protected: + [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { + auto jit = KernelGenerator::get_jit_constants(params); + auto desc = params.typed_desc(); + jit.make("GATHER_ENABLE", 1); + jit.make("HIDDEN_SIZE", desc->_config.hidden_size); + jit.make("TYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float"); + jit.make("TYPE_SIZE", params.get_input_layout(0).data_type == ov::element::f16 ? 2 : 4); + return jit; + } + + [[nodiscard]] Arguments get_arguments_desc(const RuntimeParams& params) const override { + Arguments args; + + return args; + } + + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override { + return DispatchDataFunc{nullptr}; + } +}; + +class MOEOptScatter : public KernelGenerator { +public: + MOEOptScatter() : KernelGenerator("moe_opt", "index_add") {} + +protected: + [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { + auto jit = KernelGenerator::get_jit_constants(params); + auto desc = params.typed_desc(); + jit.make("SCATTER_ENABLE", 1); + jit.make("HIDDEN_SIZE", desc->_config.hidden_size); + jit.make("TYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float"); + jit.make("TYPE_SIZE", params.get_input_layout(0).data_type == ov::element::f16 ? 2 : 4); + return jit; + } + + [[nodiscard]] Arguments get_arguments_desc(const RuntimeParams& params) const override { + Arguments args; + return args; + } + + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override { + return DispatchDataFunc{nullptr}; + } +}; + +# define N_BLOCK 4 +# define SUBGROUP_NUM 8 + +static void add_common_consts(const RuntimeParams& params, JitConstants& jit) { + auto desc = params.typed_desc(); + auto& engine = params.prog->get_engine(); + const auto& info = engine.get_device_info(); + jit.make("MAX_TOPK", desc->_config.topk); + jit.make("EXPERT_NUM", desc->_config.expert_num); + jit.make("HIDDEN_SIZE", desc->_config.hidden_size); + jit.make("INTERMEDIATE_SIZE", desc->_config.intermediate_size); + jit.make("N_BLOCK", N_BLOCK); + jit.make("SUBGROUP_SIZE", info.arch >= gpu_arch::xe2 ? 32 : 16); + jit.make("SUBGROUP_NUM", SUBGROUP_NUM); + jit.make("GROUP_SIZE", desc->_config.group_size); + jit.make("TYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float"); + jit.make("TYPE_SIZE", params.get_input_layout(0).data_type == ov::element::f16 ? 2 : 4); +} + +class MOEOptMLPGateUp : public KernelGenerator { +public: + MOEOptMLPGateUp() : KernelGenerator("moe_mlp", "gate_up") {} + +protected: + [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { + auto jit = KernelGenerator::get_jit_constants(params); + auto desc = params.typed_desc(); + add_common_consts(params, jit); + jit.make("GATE_UP_ENABLE", 1); + return jit; + } + + [[nodiscard]] Arguments get_arguments_desc(const RuntimeParams& params) const override { + Arguments args; + return args; + } + + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override { + return DispatchDataFunc{nullptr}; + } +}; + +class MOEOptMLPDown : public KernelGenerator { +public: + MOEOptMLPDown() : KernelGenerator("moe_mlp", "down") {} + +protected: + [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { + auto jit = KernelGenerator::get_jit_constants(params); + auto desc = params.typed_desc(); + add_common_consts(params, jit); + jit.make("DOWN_ENABLE", 1); + return jit; + } + + [[nodiscard]] Arguments get_arguments_desc(const RuntimeParams& params) const override { + Arguments args; + return args; + } + + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override { + return DispatchDataFunc{nullptr}; + } +}; + +class MOEOptMLPReduce : public KernelGenerator { +public: + MOEOptMLPReduce() : KernelGenerator("moe_mlp", "reduce") {} + +protected: + [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { + auto jit = KernelGenerator::get_jit_constants(params); + auto desc = params.typed_desc(); + add_common_consts(params, jit); + jit.make("REDUCE_ENABLE", 1); + return jit; + } + + [[nodiscard]] Arguments get_arguments_desc(const RuntimeParams& params) const override { + Arguments args; + return args; + } + + [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override { + return DispatchDataFunc{nullptr}; + } +}; + +dnnl::memory convert2dnnl(const memory::ptr& ptr, const std::vector& dim, dnnl::memory::format_tag tag, int offset = 0) { + OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("convert2dnnl")); + return ptr->get_onednn_memory(dnnl::memory::desc(dnnl::memory::dims(dim), convert_data_type(ptr->get_layout().data_type), tag), offset); +} + +class MOEOptImpl : public PrimitiveImplOCL { +public: + DECLARE_OBJECT_TYPE_SERIALIZATION(ov::intel_gpu::ocl::MOEOptImpl) + Stage::Ptr softmax_topk = make_stage(); + Stage::Ptr gather = make_stage(); + Stage::Ptr scatter = make_stage(); + Stage::Ptr mlp_gate_up = make_stage(); + Stage::Ptr mlp_down = make_stage(); + Stage::Ptr mlp_reduce = make_stage(); + + struct dnnl_weights { + dnnl::memory weight; + dnnl::memory scale; + dnnl::memory zp; + int ic, oc, ic_group_size; + }; + + // expert_mask result in cpu side + struct expert_mask_cpu { + std::vector pred_flag; + // shape: [expert_num, batch_no] + std::vector> batch; + // shape: [expert_num, topk_no] + std::vector> topk; + }; + + // store expert_mask for gpu kernel + struct expert_mask_gpu { + memory::ptr batch; + memory::ptr topk; + }; + + struct moe_fusion_weights_base_addr { + memory::ptr weight[3]; // gate/up/down weights, experts fusion + memory::ptr scale[3]; + memory::ptr zp[3]; + memory::ptr bias[3]; + } moe_fusion_weights; + + struct scratch_buffers { + // softmax+topk + memory::ptr topk_id; + memory::ptr topk_weights; + + // fast single batch: scratch.up = up(x) * silu(gate(x)) + // scratch.y = down(scratch.up) * routing_weights + memory::ptr up; + memory::ptr y; + // onednn: scratch.x, scratch.routing_weights = gather(x, ...) + // scratch.up = up(scratch.x) + // scratch.gate = gate(scratch.x) * scratch.up + // scratch.y = down(scratch.gate) * routing_weights + memory::ptr x; + memory::ptr routing_weights; + memory::ptr gate; + // buffers for batch and topk from cpu, each expert has one + std::vector expert_masks; + + moe_fusion_weights_base_addr moe_fusion_wei_addr; + }; + + std::vector> _dnnl_weights; + int _hidden_size; + int _intermediate_size; + int _group_size; + + MOEOptImpl() : PrimitiveImplOCL(MOEOpt::get_type_info_static()) {} + MOEOptImpl(const program_node& node, const RuntimeParams& params) : MOEOptImpl() { + init(node.as().get_primitive()); + + add_stage(softmax_topk, params); + add_stage(gather, params); + add_stage(scatter, params); + add_stage(mlp_gate_up, params); + add_stage(mlp_down, params); + add_stage(mlp_reduce, params); + } + + void init(const std::shared_ptr& cur_moe) { + _hidden_size = static_cast(cur_moe->_config.hidden_size); + _intermediate_size = static_cast(cur_moe->_config.intermediate_size); + _group_size = static_cast(cur_moe->_config.group_size); + } + + void init_dnnl_weights(const std::shared_ptr& cur_moe, + const struct moe_fusion_weights_base_addr& moe_fusion_wei_addr) { + if(_dnnl_weights.size() == cur_moe->_config.expert_num) + return; + init(cur_moe); + + _dnnl_weights.resize(cur_moe->_config.expert_num); + for (size_t j = 0; j < cur_moe->_config.expert_num; j++) { + // const auto& mlp_params = moe_mlp_params[j]; + auto& dnnl_weights = _dnnl_weights[j]; + dnnl_weights.resize(3); + dnnl_weights[0].ic = _hidden_size; + dnnl_weights[0].ic_group_size = _group_size; + dnnl_weights[0].oc = _intermediate_size; + dnnl_weights[1].ic = _hidden_size; + dnnl_weights[1].ic_group_size = _group_size; + dnnl_weights[1].oc = _intermediate_size; + dnnl_weights[2].ic = _intermediate_size; + dnnl_weights[2].ic_group_size = _group_size; + dnnl_weights[2].oc = _hidden_size; + for (int i = 0; i < 3; i++) { + if (mlp_params.param[i].scale) { + // scale shape: [ic / ic_group_size, oc], type: f16 + dnnl_weights[i].scale = convert2dnnl(moe_fusion_wei_addr.scale[i] + j * dnnl_weights[i].ic * dnnl_weights[i].oc / dnnl_weights[i].ic_group_size * 2, + {dnnl_weights[i].ic / dnnl_weights[i].ic_group_size, dnnl_weights[i].oc}, + dnnl::memory::format_tag::ab); + } + if (mlp_params.param[i].zp) { + // zp shape: [ic / ic_group_size, oc], type: u4 + dnnl_weights[i].zp = convert2dnnl(moe_fusion_wei_addr.zp[i] + j * dnnl_weights[i].ic * dnnl_weights[i].oc / dnnl_weights[i].ic_group_size / 2, + {dnnl_weights[i].ic / dnnl_weights[i].ic_group_size, dnnl_weights[i].oc}, + dnnl::memory::format_tag::ab); + } + if (mlp_params.param[i].weight) { + // weight shape: [oc, ic], type: u4 + dnnl_weights[i].weight = convert2dnnl(moe_fusion_wei_addr.weight[i] + j * dnnl_weights[i].ic * dnnl_weights[i].oc / 2, + {dnnl_weights[i].ic, dnnl_weights[i].oc}, + dnnl::memory::format_tag::ba); + } + } + } + } + + void load(BinaryInputBuffer& ib) override { + PrimitiveImplOCL::load(ib); + const kernel_impl_params* impl_params = reinterpret_cast(ib.getKernelImplParams()); + init(impl_params->typed_desc()); + } + + [[nodiscard]] std::unique_ptr clone() const override { + auto cur_moe = make_deep_copy(this); + cur_moe->_dnnl_weights = _dnnl_weights; + cur_moe->_hidden_size = _hidden_size; + cur_moe->_intermediate_size = _intermediate_size; + cur_moe->_group_size = _group_size; + return cur_moe; + } + + std::vector get_internal_buffer_descs(const kernel_impl_params& params) const override { + auto cur_moe = params.typed_desc(); + const auto& config = cur_moe->_config; + int max_topk = static_cast(config.topk); + int expert_num = static_cast(config.expert_num); + + auto hidden_states_layout = params.input_layouts[0]; + auto batch = static_cast(hidden_states_layout.get_shape()[0]); + auto data_type = hidden_states_layout.data_type; + + std::vector internal_buffers; + // softmax+topk + layout layout_topk_id(ov::PartialShape{batch, max_topk}, data_types::u32, cldnn::format::bfyx); + layout layout_topk_weights(ov::PartialShape{batch, max_topk}, data_type, cldnn::format::bfyx); + internal_buffers.emplace_back(layout_topk_id, true); // 0: topk_id + internal_buffers.emplace_back(layout_topk_weights, true); // 1: topk_weights + // fast single batch: scratch.up = up(x) * silu(gate(x)); scratch.y = down(scratch.up) * weight[expert_no] + layout layout_gateup_out(ov::PartialShape{batch, static_cast(config.intermediate_size)}, data_type, cldnn::format::bfyx); + layout layout_down_out(ov::PartialShape{batch, static_cast(config.hidden_size)}, data_type, cldnn::format::bfyx); + internal_buffers.emplace_back(layout_gateup_out, true); // 2: up + internal_buffers.emplace_back(layout_down_out, true); // 3: y + // onednn: scratch.x, scratch.routing_weights = gather(x, ...) + // scratch.up = up(scratch.x) + // scratch.gate = gate(scratch.x) * scratch.up + // scratch.y = down(scratch.gate) * routing_weights + internal_buffers.emplace_back(layout_down_out, true); // 4: x, scratch.x has same layout with down output + layout routing_layout(ov::PartialShape{batch * max_topk}, data_type, cldnn::format::bfyx); + internal_buffers.emplace_back(layout_down_out, true); // 5: routing_weights + internal_buffers.emplace_back(layout_gateup_out, true); // 6: gate, scratch.gate has same layout with up + // expert masks for gpu + layout index_layout(ov::PartialShape{batch}, ov::element::i32, cldnn::format::bfyx); + for (int i = 0; i < expert_num; i++) { + internal_buffers.emplace_back(index_layout, true); // batch + internal_buffers.emplace_back(index_layout, true); // topk + } + + return internal_buffers; + } + + void prepare_internal_buffers(typed_primitive_inst& instance, scratch_buffers& scratch, bool is_single_batch) { + const auto& intermediates_memories = instance.get_intermediates_memories(); + scratch.topk_id = intermediates_memories[0]; + scratch.topk_weights = intermediates_memories[1]; + scratch.up = intermediates_memories[2]; + scratch.y = intermediates_memories[3]; + if (!is_single_batch) { + scratch.x = intermediates_memories[4]; + scratch.routing_weights = intermediates_memories[5]; + scratch.gate = intermediates_memories[6]; + const auto& config = instance.get_typed_desc()->_config; + int expert_num = static_cast(config.expert_num); + scratch.expert_masks.resize(expert_num); + for (int i = 0; i < expert_num; i++) { + scratch.expert_masks[i].batch = intermediates_memories[7 + 2 * i + 0]; + scratch.expert_masks[i].topk = intermediates_memories[7 + 2 * i + 1]; + } + } + + // gate + moe_fusion_wei_addr.weight[0] = instance.input_memory_ptr(MOEInputIndex::WEIGHT_0); + moe_fusion_wei_addr.scale[0] = instance.input_memory_ptr(MOEInputIndex::SCALE_0); + moe_fusion_wei_addr.zp[0] = instance.input_memory_ptr(MOEInputIndex::ZP_0); + + // up + moe_fusion_wei_addr.weight[1] = instance.input_memory_ptr(MOEInputIndex::WEIGHT_1); + moe_fusion_wei_addr.scale[1] = instance.input_memory_ptr(MOEInputIndex::SCALE_1); + moe_fusion_wei_addr.zp[1] = instance.input_memory_ptr(MOEInputIndex::ZP_1); + + // down + moe_fusion_wei_addr.weight[2] = instance.input_memory_ptr(MOEInputIndex::WEIGHT_2); + moe_fusion_wei_addr.scale[2] = instance.input_memory_ptr(MOEInputIndex::SCALE_2); + moe_fusion_wei_addr.zp[2] = instance.input_memory_ptr(MOEInputIndex::ZP_2); + } + + void get_expert_mask_from_gpu(const MOE::Config& config, memory::ptr mem, stream& stream, expert_mask_cpu& expert_mask) { + // shape: [batch, topk] + auto layout = mem->get_layout(); + const auto& shape = layout.get_shape(); + + int max_expert_num = static_cast(config.expert_num), max_topk = static_cast(config.topk), max_tokens = static_cast(shape[0]); + + expert_mask.pred_flag.resize(max_expert_num, 0); + expert_mask.batch.resize(max_expert_num, {}); + expert_mask.topk.resize(max_expert_num, {}); + + OPENVINO_ASSERT(!layout.data_padding, "get_expert_mask_from_memory not support padding"); + + std::vector buf(max_topk * max_tokens); + mem->copy_to(stream, buf.data(), 0, 0, buf.size() * sizeof(int32_t), true); + + for (int b = 0; b < max_tokens; b++) { + auto* tok_p = &buf[b * max_topk]; + for (int t = 0; t < max_topk; t++) { + auto expert_no = tok_p[t]; + OPENVINO_ASSERT(expert_no < max_expert_num); + expert_mask.batch[expert_no].push_back(b); + expert_mask.topk[expert_no].push_back(t + b * max_topk); + expert_mask.pred_flag[expert_no] = 1; + } + } + { + // check if the result is ok + int count = 0; + for (int no = 0; no < max_expert_num; no++) { + count += static_cast(expert_mask.batch[no].size()); + } + OPENVINO_ASSERT(count == max_topk * max_tokens, + "With max_expert_num=", + max_expert_num, + ",max_topk=", + max_topk, + ",max_tokens=", + max_tokens, + " should have ", + max_topk * max_tokens, + " tokens, but current is ", + count, + ". layout=", + layout); + } + } + + void copy_expert_mask_to_gpu(stream& stream, const expert_mask_cpu& expert_mask, size_t expert_no, expert_mask_gpu& expert_mask_mem) { + auto size = expert_mask.batch[expert_no].size() * sizeof(int); + + { + mem_lock lock_data{expert_mask_mem.batch, stream}; + memcpy(lock_data.data(), expert_mask.batch[expert_no].data(), size); + } + { + mem_lock lock_data{expert_mask_mem.topk, stream}; + memcpy(lock_data.data(), expert_mask.topk[expert_no].data(), size); + } + } + + cldnn::event::ptr execute_stage(const std::vector& events, + cldnn::primitive_inst& instance, + Stage& stage, + std::vector inputs, + std::vector outputs, + const std::vector& global, + const std::vector& local, + bool needs_completion_event = false) const { + OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("MOEOptImpl::execute_stage")); + cldnn::stream& stream = instance.get_network().get_stream(); + cldnn::kernel_arguments_data args; + cldnn::kernel_arguments_desc desc; + for (uint32_t i = 0; i < inputs.size(); i++) { + desc.arguments.push_back({ArgumentDescriptor::Types::INPUT, i}); + args.inputs.push_back(inputs[i]); + } + + for (uint32_t i = 0; i < outputs.size(); i++) { + desc.arguments.push_back({ArgumentDescriptor::Types::OUTPUT, i}); + args.outputs.push_back(outputs[i]); + } + + stream.set_arguments(*stage.kernel, desc, args); + desc.workGroups.global = global; + desc.workGroups.local = local; + + return stream.enqueue_kernel(*stage.kernel, desc, {}, events, needs_completion_event); + } + + auto get_input_info(typed_primitive_inst& instance, int idx) { + auto mem = instance.input_memory_ptr(idx); + auto dep = instance.dependencies()[idx]; + auto layout = dep.first->get_impl_params()->get_output_layout(dep.second); + return std::make_tuple(mem, layout); + } + + cldnn::event::ptr exec_single_batch(typed_primitive_inst& instance, scratch_buffers& scratch) { + auto cur_moe = instance.get_typed_desc(); + int max_topk = static_cast(cur_moe->_config.topk); + + auto final_hidden_states_mem_ptr = instance.output_memory_ptr(0); + auto batch_mem_ptr = scratch.topk_id; + auto [hidden_states_mem_ptr, hidden_states_layout] = get_input_info(instance, MOEInputIndex::HIDDEN_STATES); + auto routing_mem_ptr = scratch.topk_weights; + + _hidden_size = static_cast(cur_moe->_config.hidden_size); + _intermediate_size = static_cast(cur_moe->_config.intermediate_size); + + const size_t subgroup_size = instance.get_impl_params()->get_device_info().arch >= gpu_arch::xe2 ? 32 : 16; + const size_t max_work_group_size = instance.get_impl_params()->get_device_info().max_work_group_size; + + // gate + const auto mlp_gate_wei_mem = scratch.moe_fusion_wei_addr.weight[0]; + const auto mlp_gate_scale_mem = scratch.moe_fusion_wei_addr.scale[0]; + const auto mlp_gate_zp_mem = scratch.moe_fusion_wei_addr.zp[0]; + + // up + const auto mlp_up_wei_mem = scratch.moe_fusion_wei_addr.weight[1]; + const auto mlp_up_scale_mem = scratch.moe_fusion_wei_addr.scale[1]; + const auto mlp_up_zp_mem = scratch.moe_fusion_wei_addr.zp[1]; + + // down + const auto mlp_down_wei_mem = scratch.moe_fusion_wei_addr.weight[2]; + const auto mlp_down_scale_mem = scratch.moe_fusion_wei_addr.scale[2]; + const auto mlp_down_zp_mem = scratch.moe_fusion_wei_addr.zp[2]; + event::ptr ret; + + { + // scratch.up = up(x) * silu(gate(x)) + execute_stage({}, + instance, + *mlp_gate_up, + {batch_mem_ptr, mlp_gate_wei_mem, mlp_gate_scale_mem, mlp_gate_zp_mem, mlp_up_wei_mem, mlp_up_scale_mem, mlp_up_zp_mem, hidden_states_mem_ptr}, + {scratch.up}, + {static_cast(max_topk), subgroup_size, static_cast(_intermediate_size / N_BLOCK)}, + {1, subgroup_size, SUBGROUP_NUM}); + + // scratch.y = down(scratch.up) * weight[expert_no] + execute_stage({}, + instance, + *mlp_down, + {batch_mem_ptr, mlp_down_wei_mem, mlp_down_scale_mem, mlp_down_zp_mem, scratch.up, routing_mem_ptr}, + {scratch.y}, + {static_cast(max_topk), subgroup_size, static_cast(_hidden_size / N_BLOCK)}, + {1, subgroup_size, SUBGROUP_NUM}); + + // final = sum(scratch.y) + ret = execute_stage({}, + instance, + *mlp_reduce, + {scratch.y}, + {final_hidden_states_mem_ptr}, + {static_cast(1), static_cast(_hidden_size)}, + {1, std::min(max_work_group_size, size_t{1024})}, + instance.needs_completion_event()); + } + return ret; + } + + struct onednn_kernel { + onednn_linear up; + onednn_linear gate; + onednn_linear down; + }; + struct PairHash { + template + size_t operator()(const std::pair& p) const { + // Combine hash values of the pair elements + return std::hash()(p.first) ^ std::hash()(p.second); + } + }; + + using lru_cache_hash = LruCache, std::shared_ptr, PairHash>; + lru_cache_hash _kernels = lru_cache_hash(1024); + onednn_kernel& get_kernel(int n_token, int expert_no, typed_primitive_inst& instance) { + auto key = std::make_pair(n_token, expert_no); + if (_kernels.has(key)) { + return *_kernels.get(key); + } + + auto& cur_net = instance.get_network(); + auto& stream = cur_net.get_stream(); + // auto cur_moe = instance.get_typed_desc(); + // const auto& moe_mlp_params = cur_moe->_mlp_params; + // const auto& mlp_params = moe_mlp_params[expert_no]; + auto& dnn_stream = stream.get_onednn_stream(); + auto hidden_states_layout_dt = convert_data_type(instance.input_memory_ptr(MOEInputIndex::HIDDEN_STATES)->get_layout().data_type); + + auto& dnnl_weights = _dnnl_weights[expert_no]; + auto kernel = std::make_shared(); + + // gate + auto gate_weight_layout_dt = convert_data_type(instance.input_memory_ptr(MOEInputIndex::WEIGHT_0)->get_layout().data_type); + kernel->gate = onednn_linear::create(dnn_stream.get_engine(), + hidden_states_layout_dt, + gate_weight_layout_dt, + n_token, + dnnl_weights[0].ic, + dnnl_weights[0].oc, + dnnl_weights[0].ic_group_size, + onednn_matmul::type::with_silu_bin_mul, + dnnl_weights[0].weight, + dnnl_weights[0].scale, + dnnl_weights[0].zp); + + // up + auto up_weight_layout_dt = convert_data_type(instance.input_memory_ptr(MOEInputIndex::WEIGHT_1)->get_layout().data_type); + kernel->up = onednn_linear::create(dnn_stream.get_engine(), + hidden_states_layout_dt, + up_weight_layout_dt, + n_token, + dnnl_weights[1].ic, + dnnl_weights[1].oc, + dnnl_weights[1].ic_group_size, + onednn_matmul::type::none, + dnnl_weights[1].weight, + dnnl_weights[1].scale, + dnnl_weights[1].zp); + + // down + auto down_weight_layout_dt = convert_data_type(instance.input_memory_ptr(MOEInputIndex::WEIGHT_2)->get_layout().data_type); + kernel->down = onednn_linear::create(dnn_stream.get_engine(), + hidden_states_layout_dt, + down_weight_layout_dt, + n_token, + dnnl_weights[2].ic, + dnnl_weights[2].oc, + dnnl_weights[2].ic_group_size, + onednn_matmul::type::with_bin_mul_per_row, + dnnl_weights[2].weight, + dnnl_weights[2].scale, + dnnl_weights[2].zp); + _kernels.add(key, kernel); + return *_kernels.get(key); + } + + // inputs 0 is hidden_states, inputs 1 is router_logits[num_tokens, NUM_EXPERTS=128] + // extra step Softmax_TopK is fused to give topk-id & router_weights + // + // scratch.topk_id, scratch.full_router_weights = Softmax_TopK(router_logits) + // + // generate expert_mask from topk-id + // expert_mask.batch[i][j] : j'th token index for i'th expert + // expert_mask.topk[i][j] : topk-output offset for j'th token for i'th expert, used to get weights + // expert_mask.pred_flag[i]: bool, if expert i can be skipped + // + // + // scratch.x, scratch.routing_weights = gather(hidden_states, scratch.full_router_weights, expert_mask.batch, expert_mask.topk) + // scratch.y = MLP(scratch.x, .gate/up/down) * scratch.routing_weights + // scatter(final_hidden, scratch.y, expert_mask.batch) + // + cldnn::event::ptr execute(const std::vector& events, cldnn::primitive_inst& ins) override { + OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("MOEOptImpl::execute")); + auto& instance = reinterpret_cast&>(ins); + auto cur_moe = instance.get_typed_desc(); + const auto& config = cur_moe->_config; + int max_topk = static_cast(config.topk); + auto& cur_net = instance.get_network(); + auto& stream = cur_net.get_stream(); + + auto [hidden_states_mem_ptr, hidden_states_layout] = get_input_info(instance, MOEInputIndex::HIDDEN_STATES); + auto batch = static_cast(hidden_states_layout.get_shape()[0]); + + scratch_buffers scratch; + prepare_internal_buffers(instance, scratch, batch == 1); + + // softmax+topk + auto lws_size = cur_moe->_config.expert_num; + auto topk_event = execute_stage(events, + instance, + *softmax_topk, + {instance.input_memory_ptr(MOEInputIndex::ROUTING_WEIGHTS)}, + {scratch.topk_id, scratch.topk_weights}, + {static_cast(batch), lws_size}, + {1, lws_size}); + + // Single batch is a special case, we don't need to do gather/scatter, + // and we can apply optimal kernels against memory bound to improve performance. + // It is very important for MoE's second token performance. + if (batch == 1) { + return exec_single_batch(instance, scratch); + } + + init_dnnl_weights(cur_moe, scratch.moe_fusion_wei_addr); + auto final_hidden_states_mem_ptr = instance.output_memory_ptr(0); + auto final_hidden_states_layout = instance.get_output_layout(0); + + // onednn path will accumulate to the output + final_hidden_states_mem_ptr->fill(stream, false); + + // Wait for topk is ready + topk_event->wait(); + // [batch, max_topk] + auto topk_id_mem = scratch.topk_id; + expert_mask_cpu expert_mask; + get_expert_mask_from_gpu(config, topk_id_mem, stream, expert_mask); + + auto& dnn_stream = stream.get_onednn_stream(); + cldnn::event::ptr result_event; + + auto routing_mem_ptr = scratch.topk_weights; + auto get_best_lws = [](size_t hidden_size) { + const size_t candidate[] = {128, 64, 32, 16, 8}; + for (size_t i = 0; i < sizeof(candidate) / sizeof(size_t); i++) { + if (hidden_size % candidate[i] == 0) { + return candidate[i]; + } + } + OPENVINO_ASSERT(false, "hidden_size=", hidden_size, " is not divisible by any of ", sizeof(candidate) / sizeof(size_t), " candidates"); + }; + lws_size = get_best_lws(_hidden_size); + + OPENVINO_ASSERT(batch != 1, "batch size shouldn't be 1 for this path!"); + for (size_t expert_no = 0; expert_no < config.expert_num; expert_no++) { + OPENVINO_ASSERT(expert_no < expert_mask.pred_flag.size()); + auto can_skip_subgraph = !expert_mask.pred_flag[expert_no]; + if (can_skip_subgraph) { + continue; + } + auto& dnnl_weights = _dnnl_weights[expert_no]; + + // expert_mask + expert_mask_gpu& expert_mask_mem = scratch.expert_masks[expert_no]; + copy_expert_mask_to_gpu(stream, expert_mask, expert_no, expert_mask_mem); + + auto n_token = static_cast(expert_mask.batch[expert_no].size()); + onednn_kernel& kernel = get_kernel(n_token, static_cast(expert_no), instance); + memory::ptr& x = scratch.x; + + // gather + execute_stage(events, + instance, + *gather, + {hidden_states_mem_ptr, routing_mem_ptr, expert_mask_mem.batch, expert_mask_mem.topk}, + {x, scratch.routing_weights}, + {static_cast(n_token), static_cast(_hidden_size)}, + {1, lws_size}); + + // up + kernel.up.forward(dnn_stream, + n_token, + convert2dnnl(x, {static_cast(n_token), dnnl_weights[1].ic}, dnnl::memory::format_tag::ab), + convert2dnnl(scratch.up, {static_cast(n_token), _intermediate_size}, dnnl::memory::format_tag::ab), + dnnl::memory()); + + // gate + kernel.gate.forward(dnn_stream, + n_token, + convert2dnnl(x, {static_cast(n_token), dnnl_weights[0].ic}, dnnl::memory::format_tag::ab), + convert2dnnl(scratch.gate, {static_cast(n_token), _intermediate_size}, dnnl::memory::format_tag::ab), + convert2dnnl(scratch.up, {static_cast(n_token), _intermediate_size}, dnnl::memory::format_tag::ab)); + + // down + kernel.down.forward(dnn_stream, + n_token, + convert2dnnl(scratch.gate, {static_cast(n_token), _intermediate_size}, dnnl::memory::format_tag::ab), + convert2dnnl(scratch.y, {static_cast(n_token), _hidden_size}, dnnl::memory::format_tag::ab), + convert2dnnl(scratch.routing_weights, {n_token * max_topk}, dnnl::memory::format_tag::a)); + // index_add + result_event = execute_stage(events, + instance, + *scatter, + {scratch.y, expert_mask_mem.batch}, + {final_hidden_states_mem_ptr}, + {static_cast(n_token), static_cast(_hidden_size)}, + {1, lws_size}, + instance.needs_completion_event()); + } + + return result_event; + } +}; + +} // namespace + +std::unique_ptr MOEOpt::create_impl(const program_node& node, const RuntimeParams& params) const { + assert(node.is_type()); + return std::make_unique(node, params); +} + +} // namespace ov::intel_gpu::ocl + +BIND_BINARY_BUFFER_WITH_TYPE(cldnn::moe) +BIND_BINARY_BUFFER_WITH_TYPE(ov::intel_gpu::ocl::MOEOptImpl) + +#else + +namespace ov::intel_gpu::ocl { + +std::unique_ptr MOEOpt::create_impl(const program_node& node, const RuntimeParams& params) const { + OPENVINO_THROW("MOEOpt depends on onednn."); +} + +} // namespace ov::intel_gpu::ocl + +#endif diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.hpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.hpp new file mode 100644 index 00000000000000..da9b8fbe582e6d --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.hpp @@ -0,0 +1,63 @@ +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "intel_gpu/primitives/activation.hpp" +#include "intel_gpu/primitives/eltwise.hpp" +#include "program_node.h" +#include "registry/implementation_manager.hpp" + +using namespace cldnn; // TODO: Remove once namespaces are aligned +namespace ov::intel_gpu::ocl { + +//TODO: need confirm, gate is 1st matmul or up is 1st matmul? +// mlp_gate: 0 +// mlp_up: 1 +// mlp_down: 2 +enum class MOEInputIndex : uint8_t { HIDDEN_STATES = 0, ROUTING_WEIGHTS = 1, ROUTER_TOPK_OUTPUT_INDICES = 2, + WEIGHT_0 = 3, SCALE_0 = 4, ZP_0 = 5, WEIGHT_1 = 6, SCALE_1 = 7, ZP_1 = 8, WEIGHT_2 = 9, + SCALE_2 = 10, ZP_2 = 11}; + +struct MOEOpt : public ImplementationManager { + OV_GPU_PRIMITIVE_IMPL("ocl::moe::opt") + explicit MOEOpt(shape_types shape_type, ValidateFunc vf = nullptr) : ImplementationManager(impl_types::ocl, shape_type, std::move(vf)) {} + [[nodiscard]] std::unique_ptr create_impl(const program_node& node, const RuntimeParams& params) const override; + [[nodiscard]] bool validate_impl(const program_node& node) const override { + static constexpr std::array supported_fmts = { + format::bfyx, + }; + + // TODO(MOE): support more precision + static constexpr std::array supported_types = { + ov::element::f16, + }; + + const auto& in0_layout = node.get_input_layout(MOEInputIndex::HIDDEN_STATES); + const auto& out_layout = node.get_output_layout(0); + if (!one_of(in0_layout.format, supported_fmts) || !one_of(out_layout.format, supported_fmts)) { + return false; + } + + if (!one_of(in0_layout.data_type, supported_types) || !one_of(out_layout.data_type, supported_types)) { + return false; + } + + // Only support u4 weights for now + static constexpr std::array supported_wei_type = { + ov::element::u4, + }; + const auto& wei_layout = node.get_input_layout(static_cast(MOEInputIndex::WEIGHT_0)); + if (!one_of(wei_layout.data_type, supported_wei_type)) { + return false; + } + + return true; + } +}; + +} // namespace ov::intel_gpu::ocl diff --git a/src/plugins/intel_gpu/src/graph/include/moe_inst.h b/src/plugins/intel_gpu/src/graph/include/moe_inst.h new file mode 100644 index 00000000000000..4af2e33678222c --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/include/moe_inst.h @@ -0,0 +1,52 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "intel_gpu/primitives/moe_compressed.hpp" +#include "primitive_inst.h" + +#include +#include +#include + +namespace cldnn { +namespace details {} + +template <> +struct typed_program_node : public typed_program_node_base { +private: + using parent = typed_program_node_base; + +public: + using parent::parent; + + typed_program_node(std::shared_ptr prim, program& prog) : parent(prim, prog) {} + + using parent::get_kernel_impl_params; + std::unique_ptr get_kernel_impl_params(const std::vector& in_layouts, const std::vector& out_layouts) const override { + auto params = parent::get_kernel_impl_params(in_layouts, out_layouts); + + return params; + } +}; + +using moe_node = typed_program_node; + +template <> +class typed_primitive_inst : public typed_primitive_inst_base { + using parent = typed_primitive_inst_base; + using parent::parent; + using primitive_inst::update_output_memory; + +public: + template + static std::vector calc_output_layouts(moe_node const& /*node*/, kernel_impl_params const& impl_param); + static layout calc_output_layout(moe_node const& /* node */, kernel_impl_params const& impl_param); + static std::string to_string(moe_node const& node); + typed_primitive_inst(network& network, moe_node const& node); +}; + +using moe_inst = typed_primitive_inst; +} // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/moe.cpp b/src/plugins/intel_gpu/src/graph/moe.cpp new file mode 100644 index 00000000000000..3d36b1dbf30864 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/moe.cpp @@ -0,0 +1,55 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "moe_inst.h" +#include "openvino/core/except.hpp" +#include "program_node.h" +#include "intel_gpu/runtime/error_handler.hpp" +#include "json_object.h" +#include "primitive_type_base.h" +#include "openvino/core/parallel.hpp" +#include + +namespace cldnn { +GPU_DEFINE_PRIMITIVE_TYPE_ID(moe) + +/* + Calc_output_layout method is called only when output layout is invalidated. + It means, that it is called when: + 1) It has never been called. + 2) Dependency has changed output layout. + In this both cases, we need to recalc branch_true and branch_false. + !* We can be sure, that this method was called AT LEAST once during graph compilation.*! +*/ +layout moe_inst::calc_output_layout(moe_node const& /* node */, kernel_impl_params const& impl_param) { + return impl_param.input_layouts[0]; +} + +template +std::vector moe_inst::calc_output_layouts(moe_node const& /* node */, kernel_impl_params const& impl_param) { + return {impl_param.input_layouts[0]}; +} + +template std::vector moe_inst::calc_output_layouts(moe_node const& node, const kernel_impl_params& impl_param); + +std::string moe_inst::to_string(moe_node const& node) { + auto desc = node.get_primitive(); + auto node_info = node.desc_to_json(); + json_composite moe_info; + + node_info->add("moe info", moe_info); + + std::stringstream primitive_description; + node_info->dump(primitive_description); + return primitive_description.str(); +} + +/* +moe primitive is reusing memory with the input. +*/ +moe_inst::typed_primitive_inst(network& network, moe_node const& node) + : parent(network, node) { +} + +} // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/registry/moe_impls.cpp b/src/plugins/intel_gpu/src/graph/registry/moe_impls.cpp new file mode 100644 index 00000000000000..9f150a77f83545 --- /dev/null +++ b/src/plugins/intel_gpu/src/graph/registry/moe_impls.cpp @@ -0,0 +1,26 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "intel_gpu/primitives/moe_compressed.hpp" +#include "registry.hpp" +#include "primitive_inst.h" + +#if OV_GPU_WITH_OCL + #include "impls/ocl_v2/moe_opt.hpp" +#endif + + +namespace ov::intel_gpu { + +using namespace cldnn; + +const std::vector>& Registry::get_implementations() { + static const std::vector> impls = { + OV_GPU_CREATE_INSTANCE_OCL(ocl::MOEOpt, shape_types::any) + }; + + return impls; +} + +} // namespace ov::intel_gpu diff --git a/src/plugins/intel_gpu/src/graph/registry/registry.hpp b/src/plugins/intel_gpu/src/graph/registry/registry.hpp index 341503a2322125..a6b8ed625625ba 100644 --- a/src/plugins/intel_gpu/src/graph/registry/registry.hpp +++ b/src/plugins/intel_gpu/src/graph/registry/registry.hpp @@ -166,6 +166,7 @@ REGISTER_IMPLS(strided_slice); REGISTER_IMPLS(tile); REGISTER_IMPLS(col2im); REGISTER_IMPLS(vl_sdpa); +REGISTER_IMPLS(moe_compressed); REGISTER_DEFAULT_IMPLS(assign, CPU_S, CPU_D); REGISTER_DEFAULT_IMPLS(read_value, CPU_S, CPU_D); diff --git a/src/plugins/intel_gpu/src/plugin/ops/moe.cpp b/src/plugins/intel_gpu/src/plugin/ops/moe.cpp new file mode 100644 index 00000000000000..ddc76d44d8c0bf --- /dev/null +++ b/src/plugins/intel_gpu/src/plugin/ops/moe.cpp @@ -0,0 +1,27 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "intel_gpu/plugin/program_builder.hpp" +#include "intel_gpu/plugin/common_utils.hpp" +#include + + +namespace ov::intel_gpu { + +static void CreateMOEOp(ProgramBuilder& p, const std::shared_ptr& op) { + auto inputs = p.GetInputInfo(op); + const auto& config = op->get_config(); + OPENVINO_ASSERT(inputs.size() == 11, "Inputs count of MOE should be 11"); + + const std::string layerName = layer_type_name_ID(op); + auto& engine = p.get_engine(); + + const cldnn::moe_compressed moe(layerName, inputs, config); + + p.add_primitive(*op, moe); +} + +REGISTER_FACTORY_IMPL(internal, MOECompressed); + +} // namespace ov::intel_gpu From b4f78f6ffbb8a5458287c15ca6ac3065b1ffee72 Mon Sep 17 00:00:00 2001 From: chenhu-wang Date: Mon, 20 Oct 2025 08:46:25 +0800 Subject: [PATCH 02/13] MOECompressed internal op --- .../include/ov_ops/moe_compressed.hpp | 68 +++++++++++++++++++ .../src/ov_ops/moe_compressed.cpp | 46 +++++++++++++ 2 files changed, 114 insertions(+) create mode 100644 src/common/transformations/include/ov_ops/moe_compressed.hpp create mode 100644 src/common/transformations/src/ov_ops/moe_compressed.cpp diff --git a/src/common/transformations/include/ov_ops/moe_compressed.hpp b/src/common/transformations/include/ov_ops/moe_compressed.hpp new file mode 100644 index 00000000000000..69fba3e3fe465c --- /dev/null +++ b/src/common/transformations/include/ov_ops/moe_compressed.hpp @@ -0,0 +1,68 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/core/node.hpp" +#include "openvino/op/op.hpp" +// #include "openvino/op/moe.hpp" +#include "transformations_visibility.hpp" + +namespace ov { +namespace op { +namespace internal { + +/// \brief MOECompressed experts that support compressed weights for GEMM3_SWIGLU MOE. +class TRANSFORMATIONS_API MOECompressed : public ov::op::Op { +public: + OPENVINO_OP("MOECompressed", "ie_internal_opset"); + + MOECompressed() = default; + + struct Config { + ov::element::Type out_type = ov::element::dynamic; // fp16 + size_t group_size = 0; + }; + + /// \brief Constructs a MOECompressed operation with config only + /// \param args The input tensors, in the following order: + /// 0: hidden_states - input tensor with hidden representations + /// 1: routing_weights - [num_experts, ...] normalized weights for selected experts + /// (input to final multiplication) + /// 2: router_topk_output_indices - [..., topk] indices of selected top-k experts + /// 3: w0_weight - expert weights for first projection, + /// shape [num_experts, inter_size, group_num, group_size] + /// 4: w0_scale - expert scale for first projection for compressed experts, + /// shape [num_experts, inter_size, group_num, 1] + /// 5: w0_zp - expert zp for first projection for compressed experts, + /// shape [num_experts, inter_size, group_num, 1] + /// 6: w1_weight - expert weights for second projection, + /// shape [num_experts, inter_size, hidden_size] + /// 7: w1_scale - expert scale for second projection for compressed experts, + /// shape [num_experts, inter_size, group_num, 1] + /// 8: w1_zp - expert zp for second projection for compressed experts, + /// shape [num_experts, inter_size, group_num, 1] + /// 9: w2_weight - expert weights for final projection, + /// shape [num_experts, hidden_size, inter_size] + /// 10: w2_scale - expert scale for final projection for compressed experts, + /// shape [num_experts, inter_size, group_num, 1] + /// 11: w2_zp - expert zp for final projection for compressed experts, + /// shape [num_experts, inter_size, group_num, 1] + /// \param config Configuration for the MOE operation + MOECompressed(const OutputVector& args, const Config& config); + + const Config& get_config() const; + void set_config(const Config& config); + + bool visit_attributes(AttributeVisitor& visitor) override; + void validate_and_infer_types() override; + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; + +private: + Config m_config; +}; + +} // namespace internal +} // namespace op +} // namespace ov diff --git a/src/common/transformations/src/ov_ops/moe_compressed.cpp b/src/common/transformations/src/ov_ops/moe_compressed.cpp new file mode 100644 index 00000000000000..acc3be917276b7 --- /dev/null +++ b/src/common/transformations/src/ov_ops/moe_compressed.cpp @@ -0,0 +1,46 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "ov_ops/moe_compressed.hpp" + +#include "itt.hpp" + +namespace ov { +namespace op { +namespace internal { + +MOECompressed::MOECompressed(const OutputVector& args, const Config& config) : Op(args), m_config(config) { + constructor_validate_and_infer_types(); +} + +const MOECompressed::Config& MOECompressed::get_config() const { + return m_config; +} + +void MOECompressed::set_config(const Config& config) { + m_config = config; +} + +std::shared_ptr MOECompressed::clone_with_new_inputs(const ov::OutputVector& new_args) const { + check_new_args_count(this, new_args); + + return std::make_shared(new_args, m_config); +} + +void MOECompressed::validate_and_infer_types() { + auto output_type = m_config.out_type == ov::element::dynamic ? get_input_element_type(0) : m_config.out_type; + + set_output_type(0, output_type, get_input_partial_shape(0)); +} + +bool MOECompressed::visit_attributes(ov::AttributeVisitor& visitor) { + visitor.on_attribute("out_type", m_config.out_type); + visitor.on_attribute("group_size", m_config.group_size); + + return true; +} + +} // namespace internal +} // namespace op +} // namespace ov From 758babdcf43e429570d6de41f0586d23780a272c Mon Sep 17 00:00:00 2001 From: "River, Li" Date: Mon, 20 Oct 2025 10:23:43 +0800 Subject: [PATCH 03/13] update --- src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp | 1 - src/plugins/intel_gpu/src/plugin/ops/moe.cpp | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp index 615f6ff4dde284..63573696118318 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp @@ -4,7 +4,6 @@ #include "moe_opt.hpp" -#define ENABLE_ONEDNN_FOR_GPU #ifdef ENABLE_ONEDNN_FOR_GPU # include # include diff --git a/src/plugins/intel_gpu/src/plugin/ops/moe.cpp b/src/plugins/intel_gpu/src/plugin/ops/moe.cpp index ddc76d44d8c0bf..184b6e2837c512 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/moe.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/moe.cpp @@ -12,7 +12,7 @@ namespace ov::intel_gpu { static void CreateMOEOp(ProgramBuilder& p, const std::shared_ptr& op) { auto inputs = p.GetInputInfo(op); const auto& config = op->get_config(); - OPENVINO_ASSERT(inputs.size() == 11, "Inputs count of MOE should be 11"); + OPENVINO_ASSERT(inputs.size() == 12, "Inputs count of MOE should be 12"); const std::string layerName = layer_type_name_ID(op); auto& engine = p.get_engine(); From 57749a4377a63c0bb1f64832a97a8b73651f94b8 Mon Sep 17 00:00:00 2001 From: "River, Li" Date: Mon, 20 Oct 2025 14:16:21 +0800 Subject: [PATCH 04/13] Update weight management --- .../include/ov_ops/moe_compressed.hpp | 6 +- .../intel_gpu/primitives/moe_compressed.hpp | 20 +- .../src/graph/impls/ocl_v2/moe_opt.cpp | 217 +++++++++--------- .../src/graph/impls/ocl_v2/moe_opt.hpp | 27 ++- .../intel_gpu/src/graph/include/moe_inst.h | 18 +- src/plugins/intel_gpu/src/graph/moe.cpp | 27 ++- .../src/graph/registry/moe_impls.cpp | 13 +- src/plugins/intel_gpu/src/plugin/ops/moe.cpp | 9 +- 8 files changed, 178 insertions(+), 159 deletions(-) diff --git a/src/common/transformations/include/ov_ops/moe_compressed.hpp b/src/common/transformations/include/ov_ops/moe_compressed.hpp index 69fba3e3fe465c..0d16619ec4e813 100644 --- a/src/common/transformations/include/ov_ops/moe_compressed.hpp +++ b/src/common/transformations/include/ov_ops/moe_compressed.hpp @@ -22,7 +22,11 @@ class TRANSFORMATIONS_API MOECompressed : public ov::op::Op { struct Config { ov::element::Type out_type = ov::element::dynamic; // fp16 - size_t group_size = 0; + size_t group_size{}; + size_t hidden_size{}; + size_t inter_size{}; + size_t num_experts{}; + size_t topk{}; }; /// \brief Constructs a MOECompressed operation with config only diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/moe_compressed.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/moe_compressed.hpp index 92b96fb9fc7da2..730dd9f2faaaa5 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/moe_compressed.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/moe_compressed.hpp @@ -1,12 +1,13 @@ - // Copyright (C) 2025 Intel Corporation +// Copyright (C) 2025 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // #pragma once +#include + #include "intel_gpu/runtime/engine.hpp" -#include "primitive.hpp" #include "ov_ops/moe_compressed.hpp" -#include +#include "primitive.hpp" namespace cldnn { using MOECompressed = ov::op::internal::MOECompressed; @@ -22,12 +23,9 @@ struct moe_compressed : public primitive_base { /// /// @param id An identifier of new primitive. /// @param inputs A list of Input primitive ids (inputs). - moe_compressed(const primitive_id& id, - const std::vector& inputs, - const MOE::Config& config) + moe_compressed(const primitive_id& id, const std::vector& inputs, const MOECompressed::Config& config) : primitive_base(id, inputs, 15, {optional_data_type()}), - _config(config) { - } + _config(config) {} MOECompressed::Config _config; @@ -35,18 +33,18 @@ struct moe_compressed : public primitive_base { if (!compare_common_params(rhs)) return false; - auto rhs_casted = downcast(rhs); + auto rhs_casted = downcast(rhs); return std::memcmp(&_config, &rhs_casted._config, sizeof(_config)) == 0; } void save(BinaryOutputBuffer& ob) const override { - primitive_base::save(ob); + primitive_base::save(ob); ob << make_data(&_config, sizeof(_config)); } void load(BinaryInputBuffer& ib) override { - primitive_base::load(ib); + primitive_base::load(ib); ib >> make_data(&_config, sizeof(_config)); } }; diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp index 63573696118318..502eb2b17f109f 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp @@ -4,6 +4,7 @@ #include "moe_opt.hpp" +// #define ENABLE_ONEDNN_FOR_GPU #ifdef ENABLE_ONEDNN_FOR_GPU # include # include @@ -17,7 +18,7 @@ # include "common_utils/jitter.hpp" # include "debug_helper.hpp" # include "intel_gpu/graph/kernel_impl_params.hpp" -# include "intel_gpu/primitives/moe.hpp" +# include "intel_gpu/primitives/moe_compressed.hpp" # include "intel_gpu/runtime/lru_cache.hpp" # include "intel_gpu/runtime/stream.hpp" # include "intel_gpu/runtime/utils.hpp" @@ -103,7 +104,8 @@ struct onednn_matmul { OPENVINO_ASSERT(m_K_groups == 1); attr.set_zero_points(DNNL_ARG_WEIGHTS, (0 << 0) + (1 << 1), {1}, m_w_type); } else { - OPENVINO_ASSERT(m_K_groups = (m_K / k_group_size)); + OPENVINO_ASSERT(m_K / k_group_size); + m_K_groups = (m_K / k_group_size); attr.set_zero_points(DNNL_ARG_WEIGHTS, (1 << 0) + (1 << 1), {k_group_size, 1}, m_w_type); } return *this; @@ -352,10 +354,10 @@ class MOEOptSoftMaxTopK : public KernelGenerator { protected: [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { auto jit = KernelGenerator::get_jit_constants(params); - auto desc = params.typed_desc(); + auto desc = params.typed_desc(); jit.make("SOFTMAX_TOPK_ENABLE", 1); jit.make("TOP_K", desc->_config.topk); - jit.make("VALUE_NUM", desc->_config.expert_num); + jit.make("VALUE_NUM", desc->_config.num_experts); jit.make("TYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float"); jit.make("TYPE_SIZE", params.get_input_layout(0).data_type == ov::element::f16 ? 2 : 4); return jit; @@ -379,7 +381,7 @@ class MOEOptGather : public KernelGenerator { protected: [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { auto jit = KernelGenerator::get_jit_constants(params); - auto desc = params.typed_desc(); + auto desc = params.typed_desc(); jit.make("GATHER_ENABLE", 1); jit.make("HIDDEN_SIZE", desc->_config.hidden_size); jit.make("TYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float"); @@ -405,7 +407,7 @@ class MOEOptScatter : public KernelGenerator { protected: [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { auto jit = KernelGenerator::get_jit_constants(params); - auto desc = params.typed_desc(); + auto desc = params.typed_desc(); jit.make("SCATTER_ENABLE", 1); jit.make("HIDDEN_SIZE", desc->_config.hidden_size); jit.make("TYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float"); @@ -427,13 +429,13 @@ class MOEOptScatter : public KernelGenerator { # define SUBGROUP_NUM 8 static void add_common_consts(const RuntimeParams& params, JitConstants& jit) { - auto desc = params.typed_desc(); + auto desc = params.typed_desc(); auto& engine = params.prog->get_engine(); const auto& info = engine.get_device_info(); jit.make("MAX_TOPK", desc->_config.topk); - jit.make("EXPERT_NUM", desc->_config.expert_num); + jit.make("EXPERT_NUM", desc->_config.num_experts); jit.make("HIDDEN_SIZE", desc->_config.hidden_size); - jit.make("INTERMEDIATE_SIZE", desc->_config.intermediate_size); + jit.make("INTERMEDIATE_SIZE", desc->_config.inter_size); jit.make("N_BLOCK", N_BLOCK); jit.make("SUBGROUP_SIZE", info.arch >= gpu_arch::xe2 ? 32 : 16); jit.make("SUBGROUP_NUM", SUBGROUP_NUM); @@ -449,7 +451,7 @@ class MOEOptMLPGateUp : public KernelGenerator { protected: [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { auto jit = KernelGenerator::get_jit_constants(params); - auto desc = params.typed_desc(); + auto desc = params.typed_desc(); add_common_consts(params, jit); jit.make("GATE_UP_ENABLE", 1); return jit; @@ -472,7 +474,7 @@ class MOEOptMLPDown : public KernelGenerator { protected: [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { auto jit = KernelGenerator::get_jit_constants(params); - auto desc = params.typed_desc(); + auto desc = params.typed_desc(); add_common_consts(params, jit); jit.make("DOWN_ENABLE", 1); return jit; @@ -495,7 +497,7 @@ class MOEOptMLPReduce : public KernelGenerator { protected: [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { auto jit = KernelGenerator::get_jit_constants(params); - auto desc = params.typed_desc(); + auto desc = params.typed_desc(); add_common_consts(params, jit); jit.make("REDUCE_ENABLE", 1); return jit; @@ -549,7 +551,7 @@ class MOEOptImpl : public PrimitiveImplOCL { }; struct moe_fusion_weights_base_addr { - memory::ptr weight[3]; // gate/up/down weights, experts fusion + memory::ptr weight[3]; // gate/up/down weights, experts fusion memory::ptr scale[3]; memory::ptr zp[3]; memory::ptr bias[3]; @@ -584,7 +586,7 @@ class MOEOptImpl : public PrimitiveImplOCL { MOEOptImpl() : PrimitiveImplOCL(MOEOpt::get_type_info_static()) {} MOEOptImpl(const program_node& node, const RuntimeParams& params) : MOEOptImpl() { - init(node.as().get_primitive()); + init(node.as().get_primitive()); add_stage(softmax_topk, params); add_stage(gather, params); @@ -594,21 +596,21 @@ class MOEOptImpl : public PrimitiveImplOCL { add_stage(mlp_reduce, params); } - void init(const std::shared_ptr& cur_moe) { + void init(const std::shared_ptr& cur_moe) { _hidden_size = static_cast(cur_moe->_config.hidden_size); - _intermediate_size = static_cast(cur_moe->_config.intermediate_size); + _intermediate_size = static_cast(cur_moe->_config.inter_size); _group_size = static_cast(cur_moe->_config.group_size); } - void init_dnnl_weights(const std::shared_ptr& cur_moe, - const struct moe_fusion_weights_base_addr& moe_fusion_wei_addr) { - if(_dnnl_weights.size() == cur_moe->_config.expert_num) + void init_dnnl_weights(const std::shared_ptr& cur_moe, + cldnn::engine& engine, + const struct moe_fusion_weights_base_addr& moe_fusion_wei_addr) { + if (_dnnl_weights.size() == cur_moe->_config.num_experts) return; init(cur_moe); - _dnnl_weights.resize(cur_moe->_config.expert_num); - for (size_t j = 0; j < cur_moe->_config.expert_num; j++) { - // const auto& mlp_params = moe_mlp_params[j]; + _dnnl_weights.resize(cur_moe->_config.num_experts); + for (size_t j = 0; j < cur_moe->_config.num_experts; j++) { auto& dnnl_weights = _dnnl_weights[j]; dnnl_weights.resize(3); dnnl_weights[0].ic = _hidden_size; @@ -621,24 +623,29 @@ class MOEOptImpl : public PrimitiveImplOCL { dnnl_weights[2].ic_group_size = _group_size; dnnl_weights[2].oc = _hidden_size; for (int i = 0; i < 3; i++) { - if (mlp_params.param[i].scale) { - // scale shape: [ic / ic_group_size, oc], type: f16 - dnnl_weights[i].scale = convert2dnnl(moe_fusion_wei_addr.scale[i] + j * dnnl_weights[i].ic * dnnl_weights[i].oc / dnnl_weights[i].ic_group_size * 2, - {dnnl_weights[i].ic / dnnl_weights[i].ic_group_size, dnnl_weights[i].oc}, - dnnl::memory::format_tag::ab); - } - if (mlp_params.param[i].zp) { - // zp shape: [ic / ic_group_size, oc], type: u4 - dnnl_weights[i].zp = convert2dnnl(moe_fusion_wei_addr.zp[i] + j * dnnl_weights[i].ic * dnnl_weights[i].oc / dnnl_weights[i].ic_group_size / 2, - {dnnl_weights[i].ic / dnnl_weights[i].ic_group_size, dnnl_weights[i].oc}, - dnnl::memory::format_tag::ab); - } - if (mlp_params.param[i].weight) { - // weight shape: [oc, ic], type: u4 - dnnl_weights[i].weight = convert2dnnl(moe_fusion_wei_addr.weight[i] + j * dnnl_weights[i].ic * dnnl_weights[i].oc / 2, - {dnnl_weights[i].ic, dnnl_weights[i].oc}, - dnnl::memory::format_tag::ba); - } + // weight shape: [ic, oc], type: u4 + ov::Shape wei_shape = {static_cast(dnnl_weights[i].ic), static_cast(dnnl_weights[i].oc)}; + auto wei_layout = cldnn::layout(wei_shape, cldnn::data_types::u4, cldnn::format::get_default_format(wei_shape.size())); + auto wei_mem = engine.create_subbuffer(*moe_fusion_wei_addr.weight[i], wei_layout, j * dnnl_weights[i].ic * dnnl_weights[i].oc / 2); + dnnl_weights[i].weight = convert2dnnl(wei_mem, {dnnl_weights[i].ic, dnnl_weights[i].oc}, dnnl::memory::format_tag::ba); + + // scale shape: [ic / ic_group_size, oc], type: f16 + ov::Shape scale_shape = {static_cast(dnnl_weights[i].ic / dnnl_weights[i].ic_group_size), static_cast(dnnl_weights[i].oc)}; + auto scale_layout = cldnn::layout(scale_shape, cldnn::data_types::f16, cldnn::format::get_default_format(scale_shape.size())); + auto scale_mem = engine.create_subbuffer(*moe_fusion_wei_addr.scale[i], + scale_layout, + j * dnnl_weights[i].ic * dnnl_weights[i].oc / dnnl_weights[i].ic_group_size * 2); + dnnl_weights[i].scale = + convert2dnnl(scale_mem, {dnnl_weights[i].ic / dnnl_weights[i].ic_group_size, dnnl_weights[i].oc}, dnnl::memory::format_tag::ab); + + // zp shape: [ic / ic_group_size, oc], type: u4 + ov::Shape zp_shape = {static_cast(dnnl_weights[i].ic / dnnl_weights[i].ic_group_size), static_cast(dnnl_weights[i].oc)}; + auto zp_layout = cldnn::layout(zp_shape, cldnn::data_types::u4, cldnn::format::get_default_format(zp_shape.size())); + auto zp_mem = engine.create_subbuffer(*moe_fusion_wei_addr.zp[i], + zp_layout, + j * dnnl_weights[i].ic * dnnl_weights[i].oc / dnnl_weights[i].ic_group_size / 2); + dnnl_weights[i].zp = + convert2dnnl(zp_mem, {dnnl_weights[i].ic / dnnl_weights[i].ic_group_size, dnnl_weights[i].oc}, dnnl::memory::format_tag::ab); } } } @@ -646,7 +653,7 @@ class MOEOptImpl : public PrimitiveImplOCL { void load(BinaryInputBuffer& ib) override { PrimitiveImplOCL::load(ib); const kernel_impl_params* impl_params = reinterpret_cast(ib.getKernelImplParams()); - init(impl_params->typed_desc()); + init(impl_params->typed_desc()); } [[nodiscard]] std::unique_ptr clone() const override { @@ -659,10 +666,10 @@ class MOEOptImpl : public PrimitiveImplOCL { } std::vector get_internal_buffer_descs(const kernel_impl_params& params) const override { - auto cur_moe = params.typed_desc(); + auto cur_moe = params.typed_desc(); const auto& config = cur_moe->_config; int max_topk = static_cast(config.topk); - int expert_num = static_cast(config.expert_num); + int expert_num = static_cast(config.num_experts); auto hidden_states_layout = params.input_layouts[0]; auto batch = static_cast(hidden_states_layout.get_shape()[0]); @@ -675,7 +682,7 @@ class MOEOptImpl : public PrimitiveImplOCL { internal_buffers.emplace_back(layout_topk_id, true); // 0: topk_id internal_buffers.emplace_back(layout_topk_weights, true); // 1: topk_weights // fast single batch: scratch.up = up(x) * silu(gate(x)); scratch.y = down(scratch.up) * weight[expert_no] - layout layout_gateup_out(ov::PartialShape{batch, static_cast(config.intermediate_size)}, data_type, cldnn::format::bfyx); + layout layout_gateup_out(ov::PartialShape{batch, static_cast(config.inter_size)}, data_type, cldnn::format::bfyx); layout layout_down_out(ov::PartialShape{batch, static_cast(config.hidden_size)}, data_type, cldnn::format::bfyx); internal_buffers.emplace_back(layout_gateup_out, true); // 2: up internal_buffers.emplace_back(layout_down_out, true); // 3: y @@ -697,7 +704,7 @@ class MOEOptImpl : public PrimitiveImplOCL { return internal_buffers; } - void prepare_internal_buffers(typed_primitive_inst& instance, scratch_buffers& scratch, bool is_single_batch) { + void prepare_internal_buffers(typed_primitive_inst& instance, scratch_buffers& scratch, bool is_single_batch) { const auto& intermediates_memories = instance.get_intermediates_memories(); scratch.topk_id = intermediates_memories[0]; scratch.topk_weights = intermediates_memories[1]; @@ -707,8 +714,8 @@ class MOEOptImpl : public PrimitiveImplOCL { scratch.x = intermediates_memories[4]; scratch.routing_weights = intermediates_memories[5]; scratch.gate = intermediates_memories[6]; - const auto& config = instance.get_typed_desc()->_config; - int expert_num = static_cast(config.expert_num); + const auto& config = instance.get_typed_desc()->_config; + int expert_num = static_cast(config.num_experts); scratch.expert_masks.resize(expert_num); for (int i = 0; i < expert_num; i++) { scratch.expert_masks[i].batch = intermediates_memories[7 + 2 * i + 0]; @@ -717,27 +724,27 @@ class MOEOptImpl : public PrimitiveImplOCL { } // gate - moe_fusion_wei_addr.weight[0] = instance.input_memory_ptr(MOEInputIndex::WEIGHT_0); - moe_fusion_wei_addr.scale[0] = instance.input_memory_ptr(MOEInputIndex::SCALE_0); - moe_fusion_wei_addr.zp[0] = instance.input_memory_ptr(MOEInputIndex::ZP_0); + scratch.moe_fusion_wei_addr.weight[0] = instance.input_memory_ptr(static_cast(MOEInputIndex::WEIGHT_0)); + scratch.moe_fusion_wei_addr.scale[0] = instance.input_memory_ptr(static_cast(MOEInputIndex::SCALE_0)); + scratch.moe_fusion_wei_addr.zp[0] = instance.input_memory_ptr(static_cast(MOEInputIndex::ZP_0)); // up - moe_fusion_wei_addr.weight[1] = instance.input_memory_ptr(MOEInputIndex::WEIGHT_1); - moe_fusion_wei_addr.scale[1] = instance.input_memory_ptr(MOEInputIndex::SCALE_1); - moe_fusion_wei_addr.zp[1] = instance.input_memory_ptr(MOEInputIndex::ZP_1); + scratch.moe_fusion_wei_addr.weight[1] = instance.input_memory_ptr(static_cast(MOEInputIndex::WEIGHT_1)); + scratch.moe_fusion_wei_addr.scale[1] = instance.input_memory_ptr(static_cast(MOEInputIndex::SCALE_1)); + scratch.moe_fusion_wei_addr.zp[1] = instance.input_memory_ptr(static_cast(MOEInputIndex::ZP_1)); // down - moe_fusion_wei_addr.weight[2] = instance.input_memory_ptr(MOEInputIndex::WEIGHT_2); - moe_fusion_wei_addr.scale[2] = instance.input_memory_ptr(MOEInputIndex::SCALE_2); - moe_fusion_wei_addr.zp[2] = instance.input_memory_ptr(MOEInputIndex::ZP_2); + scratch.moe_fusion_wei_addr.weight[2] = instance.input_memory_ptr(static_cast(MOEInputIndex::WEIGHT_2)); + scratch.moe_fusion_wei_addr.scale[2] = instance.input_memory_ptr(static_cast(MOEInputIndex::SCALE_2)); + scratch.moe_fusion_wei_addr.zp[2] = instance.input_memory_ptr(static_cast(MOEInputIndex::ZP_2)); } - void get_expert_mask_from_gpu(const MOE::Config& config, memory::ptr mem, stream& stream, expert_mask_cpu& expert_mask) { + void get_expert_mask_from_gpu(const MOECompressed::Config& config, memory::ptr mem, stream& stream, expert_mask_cpu& expert_mask) { // shape: [batch, topk] auto layout = mem->get_layout(); const auto& shape = layout.get_shape(); - int max_expert_num = static_cast(config.expert_num), max_topk = static_cast(config.topk), max_tokens = static_cast(shape[0]); + int max_expert_num = static_cast(config.num_experts), max_topk = static_cast(config.topk), max_tokens = static_cast(shape[0]); expert_mask.pred_flag.resize(max_expert_num, 0); expert_mask.batch.resize(max_expert_num, {}); @@ -822,53 +829,54 @@ class MOEOptImpl : public PrimitiveImplOCL { return stream.enqueue_kernel(*stage.kernel, desc, {}, events, needs_completion_event); } - auto get_input_info(typed_primitive_inst& instance, int idx) { + auto get_input_info(typed_primitive_inst& instance, int idx) { auto mem = instance.input_memory_ptr(idx); auto dep = instance.dependencies()[idx]; auto layout = dep.first->get_impl_params()->get_output_layout(dep.second); return std::make_tuple(mem, layout); } - cldnn::event::ptr exec_single_batch(typed_primitive_inst& instance, scratch_buffers& scratch) { - auto cur_moe = instance.get_typed_desc(); + cldnn::event::ptr exec_single_batch(typed_primitive_inst& instance, scratch_buffers& scratch) { + auto cur_moe = instance.get_typed_desc(); int max_topk = static_cast(cur_moe->_config.topk); auto final_hidden_states_mem_ptr = instance.output_memory_ptr(0); auto batch_mem_ptr = scratch.topk_id; - auto [hidden_states_mem_ptr, hidden_states_layout] = get_input_info(instance, MOEInputIndex::HIDDEN_STATES); + auto [hidden_states_mem_ptr, hidden_states_layout] = get_input_info(instance, static_cast(MOEInputIndex::HIDDEN_STATES)); auto routing_mem_ptr = scratch.topk_weights; _hidden_size = static_cast(cur_moe->_config.hidden_size); - _intermediate_size = static_cast(cur_moe->_config.intermediate_size); + _intermediate_size = static_cast(cur_moe->_config.inter_size); const size_t subgroup_size = instance.get_impl_params()->get_device_info().arch >= gpu_arch::xe2 ? 32 : 16; const size_t max_work_group_size = instance.get_impl_params()->get_device_info().max_work_group_size; // gate - const auto mlp_gate_wei_mem = scratch.moe_fusion_wei_addr.weight[0]; - const auto mlp_gate_scale_mem = scratch.moe_fusion_wei_addr.scale[0]; - const auto mlp_gate_zp_mem = scratch.moe_fusion_wei_addr.zp[0]; + const auto& mlp_gate_wei_mem = scratch.moe_fusion_wei_addr.weight[0]; + const auto& mlp_gate_scale_mem = scratch.moe_fusion_wei_addr.scale[0]; + const auto& mlp_gate_zp_mem = scratch.moe_fusion_wei_addr.zp[0]; // up - const auto mlp_up_wei_mem = scratch.moe_fusion_wei_addr.weight[1]; - const auto mlp_up_scale_mem = scratch.moe_fusion_wei_addr.scale[1]; - const auto mlp_up_zp_mem = scratch.moe_fusion_wei_addr.zp[1]; + const auto& mlp_up_wei_mem = scratch.moe_fusion_wei_addr.weight[1]; + const auto& mlp_up_scale_mem = scratch.moe_fusion_wei_addr.scale[1]; + const auto& mlp_up_zp_mem = scratch.moe_fusion_wei_addr.zp[1]; // down - const auto mlp_down_wei_mem = scratch.moe_fusion_wei_addr.weight[2]; - const auto mlp_down_scale_mem = scratch.moe_fusion_wei_addr.scale[2]; - const auto mlp_down_zp_mem = scratch.moe_fusion_wei_addr.zp[2]; + const auto& mlp_down_wei_mem = scratch.moe_fusion_wei_addr.weight[2]; + const auto& mlp_down_scale_mem = scratch.moe_fusion_wei_addr.scale[2]; + const auto& mlp_down_zp_mem = scratch.moe_fusion_wei_addr.zp[2]; event::ptr ret; { // scratch.up = up(x) * silu(gate(x)) - execute_stage({}, - instance, - *mlp_gate_up, - {batch_mem_ptr, mlp_gate_wei_mem, mlp_gate_scale_mem, mlp_gate_zp_mem, mlp_up_wei_mem, mlp_up_scale_mem, mlp_up_zp_mem, hidden_states_mem_ptr}, - {scratch.up}, - {static_cast(max_topk), subgroup_size, static_cast(_intermediate_size / N_BLOCK)}, - {1, subgroup_size, SUBGROUP_NUM}); + execute_stage( + {}, + instance, + *mlp_gate_up, + {batch_mem_ptr, mlp_gate_wei_mem, mlp_gate_scale_mem, mlp_gate_zp_mem, mlp_up_wei_mem, mlp_up_scale_mem, mlp_up_zp_mem, hidden_states_mem_ptr}, + {scratch.up}, + {static_cast(max_topk), subgroup_size, static_cast(_intermediate_size / N_BLOCK)}, + {1, subgroup_size, SUBGROUP_NUM}); // scratch.y = down(scratch.up) * weight[expert_no] execute_stage({}, @@ -907,7 +915,7 @@ class MOEOptImpl : public PrimitiveImplOCL { using lru_cache_hash = LruCache, std::shared_ptr, PairHash>; lru_cache_hash _kernels = lru_cache_hash(1024); - onednn_kernel& get_kernel(int n_token, int expert_no, typed_primitive_inst& instance) { + onednn_kernel& get_kernel(int n_token, int expert_no, typed_primitive_inst& instance) { auto key = std::make_pair(n_token, expert_no); if (_kernels.has(key)) { return *_kernels.get(key); @@ -919,27 +927,27 @@ class MOEOptImpl : public PrimitiveImplOCL { // const auto& moe_mlp_params = cur_moe->_mlp_params; // const auto& mlp_params = moe_mlp_params[expert_no]; auto& dnn_stream = stream.get_onednn_stream(); - auto hidden_states_layout_dt = convert_data_type(instance.input_memory_ptr(MOEInputIndex::HIDDEN_STATES)->get_layout().data_type); - + auto hidden_states_layout_dt = convert_data_type(instance.input_memory_ptr(static_cast(MOEInputIndex::HIDDEN_STATES))->get_layout().data_type); + auto& dnnl_weights = _dnnl_weights[expert_no]; auto kernel = std::make_shared(); // gate - auto gate_weight_layout_dt = convert_data_type(instance.input_memory_ptr(MOEInputIndex::WEIGHT_0)->get_layout().data_type); + auto gate_weight_layout_dt = convert_data_type(instance.input_memory_ptr(static_cast(MOEInputIndex::WEIGHT_0))->get_layout().data_type); kernel->gate = onednn_linear::create(dnn_stream.get_engine(), - hidden_states_layout_dt, - gate_weight_layout_dt, - n_token, - dnnl_weights[0].ic, - dnnl_weights[0].oc, - dnnl_weights[0].ic_group_size, - onednn_matmul::type::with_silu_bin_mul, - dnnl_weights[0].weight, - dnnl_weights[0].scale, - dnnl_weights[0].zp); + hidden_states_layout_dt, + gate_weight_layout_dt, + n_token, + dnnl_weights[0].ic, + dnnl_weights[0].oc, + dnnl_weights[0].ic_group_size, + onednn_matmul::type::with_silu_bin_mul, + dnnl_weights[0].weight, + dnnl_weights[0].scale, + dnnl_weights[0].zp); // up - auto up_weight_layout_dt = convert_data_type(instance.input_memory_ptr(MOEInputIndex::WEIGHT_1)->get_layout().data_type); + auto up_weight_layout_dt = convert_data_type(instance.input_memory_ptr(static_cast(MOEInputIndex::WEIGHT_1))->get_layout().data_type); kernel->up = onednn_linear::create(dnn_stream.get_engine(), hidden_states_layout_dt, up_weight_layout_dt, @@ -953,7 +961,7 @@ class MOEOptImpl : public PrimitiveImplOCL { dnnl_weights[1].zp); // down - auto down_weight_layout_dt = convert_data_type(instance.input_memory_ptr(MOEInputIndex::WEIGHT_2)->get_layout().data_type); + auto down_weight_layout_dt = convert_data_type(instance.input_memory_ptr(static_cast(MOEInputIndex::WEIGHT_2))->get_layout().data_type); kernel->down = onednn_linear::create(dnn_stream.get_engine(), hidden_states_layout_dt, down_weight_layout_dt, @@ -986,25 +994,25 @@ class MOEOptImpl : public PrimitiveImplOCL { // cldnn::event::ptr execute(const std::vector& events, cldnn::primitive_inst& ins) override { OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, openvino::itt::handle("MOEOptImpl::execute")); - auto& instance = reinterpret_cast&>(ins); - auto cur_moe = instance.get_typed_desc(); + auto& instance = reinterpret_cast&>(ins); + auto cur_moe = instance.get_typed_desc(); const auto& config = cur_moe->_config; int max_topk = static_cast(config.topk); auto& cur_net = instance.get_network(); auto& stream = cur_net.get_stream(); - auto [hidden_states_mem_ptr, hidden_states_layout] = get_input_info(instance, MOEInputIndex::HIDDEN_STATES); + auto [hidden_states_mem_ptr, hidden_states_layout] = get_input_info(instance, static_cast(MOEInputIndex::HIDDEN_STATES)); auto batch = static_cast(hidden_states_layout.get_shape()[0]); scratch_buffers scratch; prepare_internal_buffers(instance, scratch, batch == 1); // softmax+topk - auto lws_size = cur_moe->_config.expert_num; + auto lws_size = cur_moe->_config.num_experts; auto topk_event = execute_stage(events, instance, *softmax_topk, - {instance.input_memory_ptr(MOEInputIndex::ROUTING_WEIGHTS)}, + {instance.input_memory_ptr(static_cast(MOEInputIndex::ROUTING_WEIGHTS))}, {scratch.topk_id, scratch.topk_weights}, {static_cast(batch), lws_size}, {1, lws_size}); @@ -1016,7 +1024,8 @@ class MOEOptImpl : public PrimitiveImplOCL { return exec_single_batch(instance, scratch); } - init_dnnl_weights(cur_moe, scratch.moe_fusion_wei_addr); + auto& engine = instance.get_network().get_engine(); + init_dnnl_weights(cur_moe, engine, scratch.moe_fusion_wei_addr); auto final_hidden_states_mem_ptr = instance.output_memory_ptr(0); auto final_hidden_states_layout = instance.get_output_layout(0); @@ -1046,7 +1055,7 @@ class MOEOptImpl : public PrimitiveImplOCL { lws_size = get_best_lws(_hidden_size); OPENVINO_ASSERT(batch != 1, "batch size shouldn't be 1 for this path!"); - for (size_t expert_no = 0; expert_no < config.expert_num; expert_no++) { + for (size_t expert_no = 0; expert_no < config.num_experts; expert_no++) { OPENVINO_ASSERT(expert_no < expert_mask.pred_flag.size()); auto can_skip_subgraph = !expert_mask.pred_flag[expert_no]; if (can_skip_subgraph) { @@ -1109,13 +1118,13 @@ class MOEOptImpl : public PrimitiveImplOCL { } // namespace std::unique_ptr MOEOpt::create_impl(const program_node& node, const RuntimeParams& params) const { - assert(node.is_type()); + assert(node.is_type()); return std::make_unique(node, params); } } // namespace ov::intel_gpu::ocl -BIND_BINARY_BUFFER_WITH_TYPE(cldnn::moe) +BIND_BINARY_BUFFER_WITH_TYPE(cldnn::moe_compressed) BIND_BINARY_BUFFER_WITH_TYPE(ov::intel_gpu::ocl::MOEOptImpl) #else diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.hpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.hpp index da9b8fbe582e6d..4f5c971854fedf 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.hpp @@ -15,13 +15,24 @@ using namespace cldnn; // TODO: Remove once namespaces are aligned namespace ov::intel_gpu::ocl { -//TODO: need confirm, gate is 1st matmul or up is 1st matmul? -// mlp_gate: 0 -// mlp_up: 1 -// mlp_down: 2 -enum class MOEInputIndex : uint8_t { HIDDEN_STATES = 0, ROUTING_WEIGHTS = 1, ROUTER_TOPK_OUTPUT_INDICES = 2, - WEIGHT_0 = 3, SCALE_0 = 4, ZP_0 = 5, WEIGHT_1 = 6, SCALE_1 = 7, ZP_1 = 8, WEIGHT_2 = 9, - SCALE_2 = 10, ZP_2 = 11}; +// TODO: need confirm, gate is 1st matmul or up is 1st matmul? +// mlp_gate: 0 +// mlp_up: 1 +// mlp_down: 2 +enum class MOEInputIndex : uint8_t { + HIDDEN_STATES = 0, + ROUTING_WEIGHTS = 1, + ROUTER_TOPK_OUTPUT_INDICES = 2, + WEIGHT_0 = 3, + SCALE_0 = 4, + ZP_0 = 5, + WEIGHT_1 = 6, + SCALE_1 = 7, + ZP_1 = 8, + WEIGHT_2 = 9, + SCALE_2 = 10, + ZP_2 = 11 +}; struct MOEOpt : public ImplementationManager { OV_GPU_PRIMITIVE_IMPL("ocl::moe::opt") @@ -37,7 +48,7 @@ struct MOEOpt : public ImplementationManager { ov::element::f16, }; - const auto& in0_layout = node.get_input_layout(MOEInputIndex::HIDDEN_STATES); + const auto& in0_layout = node.get_input_layout(static_cast(MOEInputIndex::HIDDEN_STATES)); const auto& out_layout = node.get_output_layout(0); if (!one_of(in0_layout.format, supported_fmts) || !one_of(out_layout.format, supported_fmts)) { return false; diff --git a/src/plugins/intel_gpu/src/graph/include/moe_inst.h b/src/plugins/intel_gpu/src/graph/include/moe_inst.h index 4af2e33678222c..a57ab01cf09d70 100644 --- a/src/plugins/intel_gpu/src/graph/include/moe_inst.h +++ b/src/plugins/intel_gpu/src/graph/include/moe_inst.h @@ -4,13 +4,13 @@ #pragma once -#include "intel_gpu/primitives/moe_compressed.hpp" -#include "primitive_inst.h" - -#include #include +#include #include +#include "intel_gpu/primitives/moe_compressed.hpp" +#include "primitive_inst.h" + namespace cldnn { namespace details {} @@ -42,11 +42,11 @@ class typed_primitive_inst : public typed_primitive_inst_base - static std::vector calc_output_layouts(moe_node const& /*node*/, kernel_impl_params const& impl_param); - static layout calc_output_layout(moe_node const& /* node */, kernel_impl_params const& impl_param); - static std::string to_string(moe_node const& node); - typed_primitive_inst(network& network, moe_node const& node); + static std::vector calc_output_layouts(const moe_node& /*node*/, const kernel_impl_params& impl_param); + static layout calc_output_layout(const moe_node& /* node */, const kernel_impl_params& impl_param); + static std::string to_string(const moe_node& node); + typed_primitive_inst(network& network, const moe_node& node); }; -using moe_inst = typed_primitive_inst; +using moe_inst = typed_primitive_inst; } // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/moe.cpp b/src/plugins/intel_gpu/src/graph/moe.cpp index 3d36b1dbf30864..b19434320ba74a 100644 --- a/src/plugins/intel_gpu/src/graph/moe.cpp +++ b/src/plugins/intel_gpu/src/graph/moe.cpp @@ -2,17 +2,18 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "moe_inst.h" -#include "openvino/core/except.hpp" -#include "program_node.h" +#include + #include "intel_gpu/runtime/error_handler.hpp" #include "json_object.h" -#include "primitive_type_base.h" +#include "moe_inst.h" +#include "openvino/core/except.hpp" #include "openvino/core/parallel.hpp" -#include +#include "primitive_type_base.h" +#include "program_node.h" namespace cldnn { -GPU_DEFINE_PRIMITIVE_TYPE_ID(moe) +GPU_DEFINE_PRIMITIVE_TYPE_ID(moe_compressed) /* Calc_output_layout method is called only when output layout is invalidated. @@ -22,18 +23,18 @@ GPU_DEFINE_PRIMITIVE_TYPE_ID(moe) In this both cases, we need to recalc branch_true and branch_false. !* We can be sure, that this method was called AT LEAST once during graph compilation.*! */ -layout moe_inst::calc_output_layout(moe_node const& /* node */, kernel_impl_params const& impl_param) { +layout moe_inst::calc_output_layout(const moe_node& /* node */, const kernel_impl_params& impl_param) { return impl_param.input_layouts[0]; } -template -std::vector moe_inst::calc_output_layouts(moe_node const& /* node */, kernel_impl_params const& impl_param) { +template +std::vector moe_inst::calc_output_layouts(const moe_node& /* node */, const kernel_impl_params& impl_param) { return {impl_param.input_layouts[0]}; } -template std::vector moe_inst::calc_output_layouts(moe_node const& node, const kernel_impl_params& impl_param); +template std::vector moe_inst::calc_output_layouts(const moe_node& node, const kernel_impl_params& impl_param); -std::string moe_inst::to_string(moe_node const& node) { +std::string moe_inst::to_string(const moe_node& node) { auto desc = node.get_primitive(); auto node_info = node.desc_to_json(); json_composite moe_info; @@ -48,8 +49,6 @@ std::string moe_inst::to_string(moe_node const& node) { /* moe primitive is reusing memory with the input. */ -moe_inst::typed_primitive_inst(network& network, moe_node const& node) - : parent(network, node) { -} +moe_inst::typed_primitive_inst(network& network, const moe_node& node) : parent(network, node) {} } // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/registry/moe_impls.cpp b/src/plugins/intel_gpu/src/graph/registry/moe_impls.cpp index 9f150a77f83545..10079dc53a0fbd 100644 --- a/src/plugins/intel_gpu/src/graph/registry/moe_impls.cpp +++ b/src/plugins/intel_gpu/src/graph/registry/moe_impls.cpp @@ -3,24 +3,21 @@ // #include "intel_gpu/primitives/moe_compressed.hpp" -#include "registry.hpp" #include "primitive_inst.h" +#include "registry.hpp" #if OV_GPU_WITH_OCL - #include "impls/ocl_v2/moe_opt.hpp" +# include "impls/ocl_v2/moe_opt.hpp" #endif - namespace ov::intel_gpu { using namespace cldnn; -const std::vector>& Registry::get_implementations() { - static const std::vector> impls = { - OV_GPU_CREATE_INSTANCE_OCL(ocl::MOEOpt, shape_types::any) - }; +const std::vector>& Registry::get_implementations() { + static const std::vector> impls = {OV_GPU_CREATE_INSTANCE_OCL(ocl::MOEOpt, shape_types::any)}; return impls; } -} // namespace ov::intel_gpu +} // namespace ov::intel_gpu diff --git a/src/plugins/intel_gpu/src/plugin/ops/moe.cpp b/src/plugins/intel_gpu/src/plugin/ops/moe.cpp index 184b6e2837c512..49a26521866ca0 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/moe.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/moe.cpp @@ -2,20 +2,21 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "intel_gpu/plugin/program_builder.hpp" -#include "intel_gpu/plugin/common_utils.hpp" #include +#include "intel_gpu/plugin/common_utils.hpp" +#include "intel_gpu/plugin/program_builder.hpp" +#include "intel_gpu/primitives/moe_compressed.hpp" namespace ov::intel_gpu { -static void CreateMOEOp(ProgramBuilder& p, const std::shared_ptr& op) { +static void CreateMOECompressedOp(ProgramBuilder& p, const std::shared_ptr& op) { auto inputs = p.GetInputInfo(op); const auto& config = op->get_config(); OPENVINO_ASSERT(inputs.size() == 12, "Inputs count of MOE should be 12"); const std::string layerName = layer_type_name_ID(op); - auto& engine = p.get_engine(); + // auto& engine = p.get_engine(); const cldnn::moe_compressed moe(layerName, inputs, config); From 402bcb49f7652fd0f12502ef255c8d42fd5f9463 Mon Sep 17 00:00:00 2001 From: chenhu-wang Date: Mon, 20 Oct 2025 08:46:25 +0800 Subject: [PATCH 05/13] MOECompressed internal op move to gpu internal --- .../include/intel_gpu/op}/moe_compressed.hpp | 19 ++++++------------- .../intel_gpu/plugin/primitives_list.hpp | 2 +- .../intel_gpu/primitives/moe_compressed.hpp | 4 ++-- src/plugins/intel_gpu/src/plugin/ops/moe.cpp | 14 +++++++++++--- .../transformations/op}/moe_compressed.cpp | 12 +++--------- 5 files changed, 23 insertions(+), 28 deletions(-) rename src/{common/transformations/include/ov_ops => plugins/intel_gpu/include/intel_gpu/op}/moe_compressed.hpp (84%) rename src/{common/transformations/src/ov_ops => plugins/intel_gpu/src/plugin/transformations/op}/moe_compressed.cpp (86%) diff --git a/src/common/transformations/include/ov_ops/moe_compressed.hpp b/src/plugins/intel_gpu/include/intel_gpu/op/moe_compressed.hpp similarity index 84% rename from src/common/transformations/include/ov_ops/moe_compressed.hpp rename to src/plugins/intel_gpu/include/intel_gpu/op/moe_compressed.hpp index 0d16619ec4e813..f24921ab685c9e 100644 --- a/src/common/transformations/include/ov_ops/moe_compressed.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/op/moe_compressed.hpp @@ -4,19 +4,14 @@ #pragma once -#include "openvino/core/node.hpp" #include "openvino/op/op.hpp" -// #include "openvino/op/moe.hpp" -#include "transformations_visibility.hpp" -namespace ov { -namespace op { -namespace internal { +namespace ov::intel_gpu::op { /// \brief MOECompressed experts that support compressed weights for GEMM3_SWIGLU MOE. -class TRANSFORMATIONS_API MOECompressed : public ov::op::Op { +class MOECompressed : public ov::op::Op { public: - OPENVINO_OP("MOECompressed", "ie_internal_opset"); + OPENVINO_OP("MOECompressed", "gpu_opset"); MOECompressed() = default; @@ -50,9 +45,9 @@ class TRANSFORMATIONS_API MOECompressed : public ov::op::Op { /// 9: w2_weight - expert weights for final projection, /// shape [num_experts, hidden_size, inter_size] /// 10: w2_scale - expert scale for final projection for compressed experts, - /// shape [num_experts, inter_size, group_num, 1] + /// shape [num_experts, hidden_size, group_num, 1] /// 11: w2_zp - expert zp for final projection for compressed experts, - /// shape [num_experts, inter_size, group_num, 1] + /// shape [num_experts, hidden_size, group_num, 1] /// \param config Configuration for the MOE operation MOECompressed(const OutputVector& args, const Config& config); @@ -67,6 +62,4 @@ class TRANSFORMATIONS_API MOECompressed : public ov::op::Op { Config m_config; }; -} // namespace internal -} // namespace op -} // namespace ov +} // namespace ov::intel_gpu::op diff --git a/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp b/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp index ff8ed815e94d45..80b5338d0cef0f 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp @@ -311,4 +311,4 @@ REGISTER_FACTORY(internal, PagedAttentionExtension); REGISTER_FACTORY(internal, LoraSubgraph); REGISTER_FACTORY(internal, LoraSubgraphFused); REGISTER_FACTORY(internal, VLSDPA); -REGISTER_FACTORY(internal, MOECompressed); +REGISTER_FACTORY(internal, MOECompressed); \ No newline at end of file diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/moe_compressed.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/moe_compressed.hpp index 730dd9f2faaaa5..ca0fd496117bd2 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/moe_compressed.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/moe_compressed.hpp @@ -6,11 +6,11 @@ #include #include "intel_gpu/runtime/engine.hpp" -#include "ov_ops/moe_compressed.hpp" +#include "intel_gpu/op/moe_compressed.hpp" #include "primitive.hpp" namespace cldnn { -using MOECompressed = ov::op::internal::MOECompressed; +using MOECompressed = ov::intel_gpu::op::MOECompressed; /// @brief moe compressed primitive /// @details Performs moe compressed diff --git a/src/plugins/intel_gpu/src/plugin/ops/moe.cpp b/src/plugins/intel_gpu/src/plugin/ops/moe.cpp index 49a26521866ca0..2253a82d0e28ef 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/moe.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/moe.cpp @@ -2,15 +2,23 @@ // SPDX-License-Identifier: Apache-2.0 // -#include - +#include "intel_gpu/op/moe_compressed.hpp" #include "intel_gpu/plugin/common_utils.hpp" #include "intel_gpu/plugin/program_builder.hpp" #include "intel_gpu/primitives/moe_compressed.hpp" + +namespace ov { +namespace op { +namespace internal { +using MOECompressed = ov::intel_gpu::op::MOECompressed; +} // namespace internal +} // namespace op +} // namespace ov + namespace ov::intel_gpu { -static void CreateMOECompressedOp(ProgramBuilder& p, const std::shared_ptr& op) { +static void CreateMOECompressedOp(ProgramBuilder& p, const std::shared_ptr& op) { auto inputs = p.GetInputInfo(op); const auto& config = op->get_config(); OPENVINO_ASSERT(inputs.size() == 12, "Inputs count of MOE should be 12"); diff --git a/src/common/transformations/src/ov_ops/moe_compressed.cpp b/src/plugins/intel_gpu/src/plugin/transformations/op/moe_compressed.cpp similarity index 86% rename from src/common/transformations/src/ov_ops/moe_compressed.cpp rename to src/plugins/intel_gpu/src/plugin/transformations/op/moe_compressed.cpp index acc3be917276b7..7898260b4677d4 100644 --- a/src/common/transformations/src/ov_ops/moe_compressed.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/op/moe_compressed.cpp @@ -2,13 +2,9 @@ // SPDX-License-Identifier: Apache-2.0 // -#include "ov_ops/moe_compressed.hpp" +#include "intel_gpu/op/moe_compressed.hpp" -#include "itt.hpp" - -namespace ov { -namespace op { -namespace internal { +namespace ov::intel_gpu::op { MOECompressed::MOECompressed(const OutputVector& args, const Config& config) : Op(args), m_config(config) { constructor_validate_and_infer_types(); @@ -41,6 +37,4 @@ bool MOECompressed::visit_attributes(ov::AttributeVisitor& visitor) { return true; } -} // namespace internal -} // namespace op -} // namespace ov +} // namespace ov::intel_gpu::op From 8da94cd45956d69f10d5cc4fbb00d84bc64f9eaa Mon Sep 17 00:00:00 2001 From: "River, Li" Date: Tue, 21 Oct 2025 08:57:28 +0800 Subject: [PATCH 06/13] Remove softmax_top out of moe primitive implement --- .../intel_gpu/primitives/moe_compressed.hpp | 2 +- .../src/graph/impls/ocl_v2/moe_opt.cpp | 139 ++++++++++-------- src/plugins/intel_gpu/src/plugin/ops/moe.cpp | 2 +- 3 files changed, 78 insertions(+), 65 deletions(-) diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/moe_compressed.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/moe_compressed.hpp index ca0fd496117bd2..8a7a960a046c40 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/moe_compressed.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/moe_compressed.hpp @@ -24,7 +24,7 @@ struct moe_compressed : public primitive_base { /// @param id An identifier of new primitive. /// @param inputs A list of Input primitive ids (inputs). moe_compressed(const primitive_id& id, const std::vector& inputs, const MOECompressed::Config& config) - : primitive_base(id, inputs, 15, {optional_data_type()}), + : primitive_base(id, inputs, 1, {optional_data_type()}), _config(config) {} MOECompressed::Config _config; diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp index 502eb2b17f109f..dda9364535db3c 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp @@ -347,32 +347,32 @@ struct onednn_linear { } }; -class MOEOptSoftMaxTopK : public KernelGenerator { -public: - MOEOptSoftMaxTopK() : KernelGenerator("moe_opt", "softmax_topk") {} - -protected: - [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { - auto jit = KernelGenerator::get_jit_constants(params); - auto desc = params.typed_desc(); - jit.make("SOFTMAX_TOPK_ENABLE", 1); - jit.make("TOP_K", desc->_config.topk); - jit.make("VALUE_NUM", desc->_config.num_experts); - jit.make("TYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float"); - jit.make("TYPE_SIZE", params.get_input_layout(0).data_type == ov::element::f16 ? 2 : 4); - return jit; - } - - [[nodiscard]] Arguments get_arguments_desc(const RuntimeParams& params) const override { - Arguments args; - - return args; - } - - [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override { - return DispatchDataFunc{nullptr}; - } -}; +// class MOEOptSoftMaxTopK : public KernelGenerator { +// public: +// MOEOptSoftMaxTopK() : KernelGenerator("moe_opt", "softmax_topk") {} + +// protected: +// [[nodiscard]] JitConstants get_jit_constants(const RuntimeParams& params) const override { +// auto jit = KernelGenerator::get_jit_constants(params); +// auto desc = params.typed_desc(); +// jit.make("SOFTMAX_TOPK_ENABLE", 1); +// jit.make("TOP_K", desc->_config.topk); +// jit.make("VALUE_NUM", desc->_config.num_experts); +// jit.make("TYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float"); +// jit.make("TYPE_SIZE", params.get_input_layout(0).data_type == ov::element::f16 ? 2 : 4); +// return jit; +// } + +// [[nodiscard]] Arguments get_arguments_desc(const RuntimeParams& params) const override { +// Arguments args; + +// return args; +// } + +// [[nodiscard]] DispatchDataFunc get_dispatch_data_func() const override { +// return DispatchDataFunc{nullptr}; +// } +// }; class MOEOptGather : public KernelGenerator { public: @@ -521,7 +521,7 @@ dnnl::memory convert2dnnl(const memory::ptr& ptr, const std::vector& di class MOEOptImpl : public PrimitiveImplOCL { public: DECLARE_OBJECT_TYPE_SERIALIZATION(ov::intel_gpu::ocl::MOEOptImpl) - Stage::Ptr softmax_topk = make_stage(); + // Stage::Ptr softmax_topk = make_stage(); Stage::Ptr gather = make_stage(); Stage::Ptr scatter = make_stage(); Stage::Ptr mlp_gate_up = make_stage(); @@ -559,8 +559,8 @@ class MOEOptImpl : public PrimitiveImplOCL { struct scratch_buffers { // softmax+topk - memory::ptr topk_id; - memory::ptr topk_weights; + // memory::ptr topk_id; + // memory::ptr topk_weights; // fast single batch: scratch.up = up(x) * silu(gate(x)) // scratch.y = down(scratch.up) * routing_weights @@ -577,6 +577,8 @@ class MOEOptImpl : public PrimitiveImplOCL { std::vector expert_masks; moe_fusion_weights_base_addr moe_fusion_wei_addr; + memory::ptr input_routing_weights; + memory::ptr input_router_topk_idx; }; std::vector> _dnnl_weights; @@ -588,7 +590,7 @@ class MOEOptImpl : public PrimitiveImplOCL { MOEOptImpl(const program_node& node, const RuntimeParams& params) : MOEOptImpl() { init(node.as().get_primitive()); - add_stage(softmax_topk, params); + // add_stage(softmax_topk, params); add_stage(gather, params); add_stage(scatter, params); add_stage(mlp_gate_up, params); @@ -677,23 +679,23 @@ class MOEOptImpl : public PrimitiveImplOCL { std::vector internal_buffers; // softmax+topk - layout layout_topk_id(ov::PartialShape{batch, max_topk}, data_types::u32, cldnn::format::bfyx); - layout layout_topk_weights(ov::PartialShape{batch, max_topk}, data_type, cldnn::format::bfyx); - internal_buffers.emplace_back(layout_topk_id, true); // 0: topk_id - internal_buffers.emplace_back(layout_topk_weights, true); // 1: topk_weights + // layout layout_topk_id(ov::PartialShape{batch, max_topk}, data_types::u32, cldnn::format::bfyx); + // layout layout_topk_weights(ov::PartialShape{batch, max_topk}, data_type, cldnn::format::bfyx); + // internal_buffers.emplace_back(layout_topk_id, true); // 0: topk_id + // internal_buffers.emplace_back(layout_topk_weights, true); // 1: topk_weights // fast single batch: scratch.up = up(x) * silu(gate(x)); scratch.y = down(scratch.up) * weight[expert_no] layout layout_gateup_out(ov::PartialShape{batch, static_cast(config.inter_size)}, data_type, cldnn::format::bfyx); layout layout_down_out(ov::PartialShape{batch, static_cast(config.hidden_size)}, data_type, cldnn::format::bfyx); - internal_buffers.emplace_back(layout_gateup_out, true); // 2: up - internal_buffers.emplace_back(layout_down_out, true); // 3: y + internal_buffers.emplace_back(layout_gateup_out, true); // 0: up + internal_buffers.emplace_back(layout_down_out, true); // 1: y // onednn: scratch.x, scratch.routing_weights = gather(x, ...) // scratch.up = up(scratch.x) // scratch.gate = gate(scratch.x) * scratch.up // scratch.y = down(scratch.gate) * routing_weights - internal_buffers.emplace_back(layout_down_out, true); // 4: x, scratch.x has same layout with down output + internal_buffers.emplace_back(layout_down_out, true); // 2: x, scratch.x has same layout with down output layout routing_layout(ov::PartialShape{batch * max_topk}, data_type, cldnn::format::bfyx); - internal_buffers.emplace_back(layout_down_out, true); // 5: routing_weights - internal_buffers.emplace_back(layout_gateup_out, true); // 6: gate, scratch.gate has same layout with up + internal_buffers.emplace_back(layout_down_out, true); // 3: routing_weights + internal_buffers.emplace_back(layout_gateup_out, true); // 4: gate, scratch.gate has same layout with up // expert masks for gpu layout index_layout(ov::PartialShape{batch}, ov::element::i32, cldnn::format::bfyx); for (int i = 0; i < expert_num; i++) { @@ -706,23 +708,26 @@ class MOEOptImpl : public PrimitiveImplOCL { void prepare_internal_buffers(typed_primitive_inst& instance, scratch_buffers& scratch, bool is_single_batch) { const auto& intermediates_memories = instance.get_intermediates_memories(); - scratch.topk_id = intermediates_memories[0]; - scratch.topk_weights = intermediates_memories[1]; - scratch.up = intermediates_memories[2]; - scratch.y = intermediates_memories[3]; + // scratch.topk_id = intermediates_memories[0]; + // scratch.topk_weights = intermediates_memories[1]; + scratch.up = intermediates_memories[0]; + scratch.y = intermediates_memories[1]; if (!is_single_batch) { - scratch.x = intermediates_memories[4]; - scratch.routing_weights = intermediates_memories[5]; - scratch.gate = intermediates_memories[6]; + scratch.x = intermediates_memories[2]; + scratch.routing_weights = intermediates_memories[3]; + scratch.gate = intermediates_memories[4]; const auto& config = instance.get_typed_desc()->_config; int expert_num = static_cast(config.num_experts); scratch.expert_masks.resize(expert_num); for (int i = 0; i < expert_num; i++) { - scratch.expert_masks[i].batch = intermediates_memories[7 + 2 * i + 0]; - scratch.expert_masks[i].topk = intermediates_memories[7 + 2 * i + 1]; + scratch.expert_masks[i].batch = intermediates_memories[5 + 2 * i + 0]; + scratch.expert_masks[i].topk = intermediates_memories[5 + 2 * i + 1]; } } + scratch.input_routing_weights = instance.input_memory_ptr(static_cast(MOEInputIndex::ROUTING_WEIGHTS)); + scratch.input_router_topk_idx = instance.input_memory_ptr(static_cast(MOEInputIndex::ROUTER_TOPK_OUTPUT_INDICES)); + // gate scratch.moe_fusion_wei_addr.weight[0] = instance.input_memory_ptr(static_cast(MOEInputIndex::WEIGHT_0)); scratch.moe_fusion_wei_addr.scale[0] = instance.input_memory_ptr(static_cast(MOEInputIndex::SCALE_0)); @@ -744,7 +749,9 @@ class MOEOptImpl : public PrimitiveImplOCL { auto layout = mem->get_layout(); const auto& shape = layout.get_shape(); - int max_expert_num = static_cast(config.num_experts), max_topk = static_cast(config.topk), max_tokens = static_cast(shape[0]); + int max_expert_num = static_cast(config.num_experts); + int max_topk = static_cast(config.topk); + int max_tokens = static_cast(shape[0]); expert_mask.pred_flag.resize(max_expert_num, 0); expert_mask.batch.resize(max_expert_num, {}); @@ -841,9 +848,11 @@ class MOEOptImpl : public PrimitiveImplOCL { int max_topk = static_cast(cur_moe->_config.topk); auto final_hidden_states_mem_ptr = instance.output_memory_ptr(0); - auto batch_mem_ptr = scratch.topk_id; + // auto batch_mem_ptr = scratch.topk_id; + auto batch_mem_ptr = scratch.input_router_topk_idx; auto [hidden_states_mem_ptr, hidden_states_layout] = get_input_info(instance, static_cast(MOEInputIndex::HIDDEN_STATES)); - auto routing_mem_ptr = scratch.topk_weights; + // auto routing_mem_ptr = scratch.topk_weights; + auto routing_mem_ptr = scratch.input_routing_weights; _hidden_size = static_cast(cur_moe->_config.hidden_size); _intermediate_size = static_cast(cur_moe->_config.inter_size); @@ -1008,14 +1017,14 @@ class MOEOptImpl : public PrimitiveImplOCL { prepare_internal_buffers(instance, scratch, batch == 1); // softmax+topk - auto lws_size = cur_moe->_config.num_experts; - auto topk_event = execute_stage(events, - instance, - *softmax_topk, - {instance.input_memory_ptr(static_cast(MOEInputIndex::ROUTING_WEIGHTS))}, - {scratch.topk_id, scratch.topk_weights}, - {static_cast(batch), lws_size}, - {1, lws_size}); + // auto lws_size = cur_moe->_config.num_experts; + // auto topk_event = execute_stage(events, + // instance, + // *softmax_topk, + // {instance.input_memory_ptr(static_cast(MOEInputIndex::ROUTING_WEIGHTS))}, + // {scratch.topk_id, scratch.topk_weights}, + // {static_cast(batch), lws_size}, + // {1, lws_size}); // Single batch is a special case, we don't need to do gather/scatter, // and we can apply optimal kernels against memory bound to improve performance. @@ -1033,16 +1042,20 @@ class MOEOptImpl : public PrimitiveImplOCL { final_hidden_states_mem_ptr->fill(stream, false); // Wait for topk is ready - topk_event->wait(); + // topk_event->wait(); // [batch, max_topk] - auto topk_id_mem = scratch.topk_id; + // auto topk_id_mem = scratch.topk_id; + auto topk_id_mem = scratch.input_router_topk_idx; + + expert_mask_cpu expert_mask; get_expert_mask_from_gpu(config, topk_id_mem, stream, expert_mask); auto& dnn_stream = stream.get_onednn_stream(); cldnn::event::ptr result_event; - auto routing_mem_ptr = scratch.topk_weights; + // auto routing_mem_ptr = scratch.topk_weights; + auto routing_mem_ptr = scratch.input_routing_weights; auto get_best_lws = [](size_t hidden_size) { const size_t candidate[] = {128, 64, 32, 16, 8}; for (size_t i = 0; i < sizeof(candidate) / sizeof(size_t); i++) { @@ -1052,7 +1065,7 @@ class MOEOptImpl : public PrimitiveImplOCL { } OPENVINO_ASSERT(false, "hidden_size=", hidden_size, " is not divisible by any of ", sizeof(candidate) / sizeof(size_t), " candidates"); }; - lws_size = get_best_lws(_hidden_size); + auto lws_size = get_best_lws(_hidden_size); OPENVINO_ASSERT(batch != 1, "batch size shouldn't be 1 for this path!"); for (size_t expert_no = 0; expert_no < config.num_experts; expert_no++) { diff --git a/src/plugins/intel_gpu/src/plugin/ops/moe.cpp b/src/plugins/intel_gpu/src/plugin/ops/moe.cpp index 2253a82d0e28ef..f1b1b5bcf2a3c6 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/moe.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/moe.cpp @@ -21,7 +21,7 @@ namespace ov::intel_gpu { static void CreateMOECompressedOp(ProgramBuilder& p, const std::shared_ptr& op) { auto inputs = p.GetInputInfo(op); const auto& config = op->get_config(); - OPENVINO_ASSERT(inputs.size() == 12, "Inputs count of MOE should be 12"); + OPENVINO_ASSERT(inputs.size() == 12, "Inputs count of MOECompressed should be 12"); const std::string layerName = layer_type_name_ID(op); // auto& engine = p.get_engine(); From d41574f2af2b14598ed030abc258465524b0b206 Mon Sep 17 00:00:00 2001 From: chenhu-wang Date: Mon, 20 Oct 2025 16:35:44 +0800 Subject: [PATCH 07/13] MOE to MOECompressed --- .../include/intel_gpu/op/moe_compressed.hpp | 14 +- .../convert_moe_to_compressed.cpp | 154 ++++++++++++++++++ .../convert_moe_to_compressed.hpp | 17 ++ 3 files changed, 178 insertions(+), 7 deletions(-) create mode 100644 src/plugins/intel_gpu/src/plugin/transformations/convert_moe_to_compressed.cpp create mode 100644 src/plugins/intel_gpu/src/plugin/transformations/convert_moe_to_compressed.hpp diff --git a/src/plugins/intel_gpu/include/intel_gpu/op/moe_compressed.hpp b/src/plugins/intel_gpu/include/intel_gpu/op/moe_compressed.hpp index f24921ab685c9e..07eededcdaf06e 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/op/moe_compressed.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/op/moe_compressed.hpp @@ -16,12 +16,12 @@ class MOECompressed : public ov::op::Op { MOECompressed() = default; struct Config { + size_t hidden_size = 0; + size_t inter_size = 0; + size_t num_expert = 0; + size_t top_k = 0; + size_t group_size = 0; ov::element::Type out_type = ov::element::dynamic; // fp16 - size_t group_size{}; - size_t hidden_size{}; - size_t inter_size{}; - size_t num_experts{}; - size_t topk{}; }; /// \brief Constructs a MOECompressed operation with config only @@ -37,13 +37,13 @@ class MOECompressed : public ov::op::Op { /// 5: w0_zp - expert zp for first projection for compressed experts, /// shape [num_experts, inter_size, group_num, 1] /// 6: w1_weight - expert weights for second projection, - /// shape [num_experts, inter_size, hidden_size] + /// shape [num_experts, inter_size, group_num, group_size] /// 7: w1_scale - expert scale for second projection for compressed experts, /// shape [num_experts, inter_size, group_num, 1] /// 8: w1_zp - expert zp for second projection for compressed experts, /// shape [num_experts, inter_size, group_num, 1] /// 9: w2_weight - expert weights for final projection, - /// shape [num_experts, hidden_size, inter_size] + /// shape [num_experts, hidden_size, group_num, group_size] /// 10: w2_scale - expert scale for final projection for compressed experts, /// shape [num_experts, hidden_size, group_num, 1] /// 11: w2_zp - expert zp for final projection for compressed experts, diff --git a/src/plugins/intel_gpu/src/plugin/transformations/convert_moe_to_compressed.cpp b/src/plugins/intel_gpu/src/plugin/transformations/convert_moe_to_compressed.cpp new file mode 100644 index 00000000000000..b254c7bc067bcd --- /dev/null +++ b/src/plugins/intel_gpu/src/plugin/transformations/convert_moe_to_compressed.cpp @@ -0,0 +1,154 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "convert_moe_to_compressed.hpp" + +#include + +#include "intel_gpu/op/moe_compressed.hpp" +#include "openvino/core/graph_util.hpp" +#include "openvino/core/rt_info.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/moe.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/subtract.hpp" +#include "openvino/pass/pattern/op/pattern.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/rt_info/keep_const_precision.hpp" +#include "transformations/utils/utils.hpp" + +namespace ov::intel_gpu { +using namespace ov::pass::pattern; + +ConvertMOEToMOECompressed::ConvertMOEToMOECompressed() { + auto reshape_ungroup = [](const ov::Output& output) { + auto in_ps = output.get_node()->get_input_partial_shape(0); + auto out_ps = output.get_node()->get_output_partial_shape(0); + return in_ps.rank().is_static() && out_ps.rank().is_static() && (in_ps.size() == 4 && out_ps.size() == 3); + }; + // first proj + auto compressed_weights_m_0 = wrap_type(type_matches(ov::element::u4)); + auto zp_m_0 = wrap_type(type_matches(ov::element::u4)); + + auto weight_convert_const_m_0 = wrap_type({compressed_weights_m_0}, type_matches(ov::element::f16)); + auto zp_convert_const_m_0 = wrap_type({zp_m_0}, type_matches(ov::element::f16)); + auto sub_m_0 = wrap_type({weight_convert_const_m_0, zp_convert_const_m_0}); + + auto scale_m_0 = wrap_type(type_matches(ov::element::f16)); + auto mul_m_0 = wrap_type({sub_m_0, scale_m_0}); + + auto reshape_const_m_0 = wrap_type(); + auto reshape_m_0 = wrap_type({mul_m_0, reshape_const_m_0}, reshape_ungroup); + + auto convert_m_0 = wrap_type({reshape_m_0}, type_matches(ov::element::f32)); + + // second proj + auto compressed_weights_m_1 = wrap_type(type_matches(ov::element::u4)); + auto zp_m_1 = wrap_type(type_matches(ov::element::u4)); + + auto weight_convert_const_m_1 = wrap_type({compressed_weights_m_1}, type_matches(ov::element::f16)); + auto zp_convert_const_m_1 = wrap_type({zp_m_1}, type_matches(ov::element::f16)); + auto sub_m_1 = wrap_type({weight_convert_const_m_1, zp_convert_const_m_1}); + + auto scale_m_1 = wrap_type(type_matches(ov::element::f16)); + auto mul_m_1 = wrap_type({sub_m_1, scale_m_1}); + + auto reshape_const_m_1 = wrap_type(); + auto reshape_m_1 = wrap_type({mul_m_1, reshape_const_m_1}, reshape_ungroup); + + auto convert_m_1 = wrap_type({reshape_m_1}, type_matches(ov::element::f32)); + + // third proj + auto compressed_weights_m_2 = wrap_type(type_matches(ov::element::u4)); + auto zp_m_2 = wrap_type(type_matches(ov::element::u4)); + + auto weight_convert_const_m_2 = wrap_type({compressed_weights_m_2}, type_matches(ov::element::f16)); + auto zp_convert_const_m_2 = wrap_type({zp_m_2}, type_matches(ov::element::f16)); + auto sub_m_2 = wrap_type({weight_convert_const_m_2, zp_convert_const_m_2}); + + auto scale_m_2 = wrap_type(type_matches(ov::element::f16)); + auto mul_m_2 = wrap_type({sub_m_2, scale_m_2}); + + auto reshape_const_m_2 = wrap_type(); + auto reshape_m_2 = wrap_type({mul_m_2, reshape_const_m_2}, reshape_ungroup); + + auto convert_m_2 = wrap_type({reshape_m_2}, type_matches(ov::element::f32)); + + auto hidden_states_m = any_input(); + auto routing_weights_m = any_input(); + auto topk_m = any_input(); + + auto moe_root = wrap_type({hidden_states_m, routing_weights_m, topk_m, convert_m_0, convert_m_1, convert_m_2}, + [](const ov::Output& output) { + auto moe = ov::as_type_ptr(output.get_node_shared_ptr()); + return moe->get_config().expert_type == ov::op::internal::MOE::Expert_type::GEMM3_SWIGLU; + }); + + ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + + auto moe = ov::as_type_ptr(pattern_map.at(moe_root).get_node_shared_ptr()); + if (!moe || transformation_callback(moe)) { + return false; + } + OutputVector args(11); + args[0] = pattern_map.at(hidden_states_m); + args[1] = pattern_map.at(routing_weights_m); + args[2] = pattern_map.at(topk_m); + args[3] = pattern_map.at(compressed_weights_m_0); + args[4] = pattern_map.at(scale_m_0); + args[5] = pattern_map.at(zp_m_0); + args[6] = pattern_map.at(compressed_weights_m_1); + args[7] = pattern_map.at(scale_m_1); + args[8] = pattern_map.at(zp_m_1); + args[9] = pattern_map.at(compressed_weights_m_2); + args[10] = pattern_map.at(scale_m_2); + args[11] = pattern_map.at(zp_m_2); + ov::intel_gpu::op::MOECompressed::Config config; + auto weight_shape = pattern_map.at(compressed_weights_m_0).get_shape(); + if (weight_shape.size() != 4) { + return false; + } + auto topk_shape = pattern_map.at(topk_m).get_shape(); + config.hidden_size = weight_shape[2] * weight_shape[3]; + config.inter_size = weight_shape[1]; + config.num_expert = weight_shape[0]; + config.group_size = weight_shape[3]; + config.top_k = topk_shape.back(); + config.out_type = ov::element::f16; + auto moe_compressed = std::make_shared(args, config); + + auto w0 = pattern_map.at(compressed_weights_m_0).get_node_shared_ptr(); + auto s0 = pattern_map.at(scale_m_0).get_node_shared_ptr(); + auto z0 = pattern_map.at(zp_m_0).get_node_shared_ptr(); + auto w1 = pattern_map.at(compressed_weights_m_1).get_node_shared_ptr(); + auto s1 = pattern_map.at(scale_m_1).get_node_shared_ptr(); + auto z1 = pattern_map.at(zp_m_1).get_node_shared_ptr(); + auto w2 = pattern_map.at(compressed_weights_m_2).get_node_shared_ptr(); + auto s2 = pattern_map.at(scale_m_2).get_node_shared_ptr(); + auto z2 = pattern_map.at(zp_m_2).get_node_shared_ptr(); + ov::enable_keep_const_precision(w0); + ov::enable_keep_const_precision(s0); + ov::enable_keep_const_precision(z0); + ov::enable_keep_const_precision(w1); + ov::enable_keep_const_precision(s1); + ov::enable_keep_const_precision(z1); + ov::enable_keep_const_precision(w2); + ov::enable_keep_const_precision(s2); + ov::enable_keep_const_precision(z2); + + moe_compressed->set_friendly_name(moe->get_friendly_name()); + ov::copy_runtime_info(moe, moe_compressed); + ov::replace_node(moe, moe_compressed); + + return true; + }; + + auto m = std::make_shared(moe_root, "ConvertMOEToMOECompressed"); + this->register_matcher(m, callback); +} + +} // namespace ov::intel_gpu diff --git a/src/plugins/intel_gpu/src/plugin/transformations/convert_moe_to_compressed.hpp b/src/plugins/intel_gpu/src/plugin/transformations/convert_moe_to_compressed.hpp new file mode 100644 index 00000000000000..6d5fd0b7bde1c7 --- /dev/null +++ b/src/plugins/intel_gpu/src/plugin/transformations/convert_moe_to_compressed.hpp @@ -0,0 +1,17 @@ +// Copyright (C) 2018-2025 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" + +namespace ov::intel_gpu { + +class ConvertMOEToMOECompressed: public ov::pass::MatcherPass { +public: + OPENVINO_MATCHER_PASS_RTTI("ConvertMOEToMOECompressed"); + ConvertMOEToMOECompressed(); +}; + +} // namespace ov::intel_gpu From 9477586ceffd4ae2d1b7e9a784fa5f494709e883 Mon Sep 17 00:00:00 2001 From: "River, Li" Date: Tue, 21 Oct 2025 09:22:48 +0800 Subject: [PATCH 08/13] update --- .../src/graph/impls/ocl_v2/moe_opt.cpp | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp index dda9364535db3c..a5f4c906314287 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp @@ -356,8 +356,8 @@ struct onednn_linear { // auto jit = KernelGenerator::get_jit_constants(params); // auto desc = params.typed_desc(); // jit.make("SOFTMAX_TOPK_ENABLE", 1); -// jit.make("TOP_K", desc->_config.topk); -// jit.make("VALUE_NUM", desc->_config.num_experts); +// jit.make("TOP_K", desc->_config.top_k); +// jit.make("VALUE_NUM", desc->_config.num_expert); // jit.make("TYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float"); // jit.make("TYPE_SIZE", params.get_input_layout(0).data_type == ov::element::f16 ? 2 : 4); // return jit; @@ -432,8 +432,8 @@ static void add_common_consts(const RuntimeParams& params, JitConstants& jit) { auto desc = params.typed_desc(); auto& engine = params.prog->get_engine(); const auto& info = engine.get_device_info(); - jit.make("MAX_TOPK", desc->_config.topk); - jit.make("EXPERT_NUM", desc->_config.num_experts); + jit.make("MAX_TOPK", desc->_config.top_k); + jit.make("EXPERT_NUM", desc->_config.num_expert); jit.make("HIDDEN_SIZE", desc->_config.hidden_size); jit.make("INTERMEDIATE_SIZE", desc->_config.inter_size); jit.make("N_BLOCK", N_BLOCK); @@ -607,12 +607,12 @@ class MOEOptImpl : public PrimitiveImplOCL { void init_dnnl_weights(const std::shared_ptr& cur_moe, cldnn::engine& engine, const struct moe_fusion_weights_base_addr& moe_fusion_wei_addr) { - if (_dnnl_weights.size() == cur_moe->_config.num_experts) + if (_dnnl_weights.size() == cur_moe->_config.num_expert) return; init(cur_moe); - _dnnl_weights.resize(cur_moe->_config.num_experts); - for (size_t j = 0; j < cur_moe->_config.num_experts; j++) { + _dnnl_weights.resize(cur_moe->_config.num_expert); + for (size_t j = 0; j < cur_moe->_config.num_expert; j++) { auto& dnnl_weights = _dnnl_weights[j]; dnnl_weights.resize(3); dnnl_weights[0].ic = _hidden_size; @@ -670,8 +670,8 @@ class MOEOptImpl : public PrimitiveImplOCL { std::vector get_internal_buffer_descs(const kernel_impl_params& params) const override { auto cur_moe = params.typed_desc(); const auto& config = cur_moe->_config; - int max_topk = static_cast(config.topk); - int expert_num = static_cast(config.num_experts); + int max_topk = static_cast(config.top_k); + int expert_num = static_cast(config.num_expert); auto hidden_states_layout = params.input_layouts[0]; auto batch = static_cast(hidden_states_layout.get_shape()[0]); @@ -717,7 +717,7 @@ class MOEOptImpl : public PrimitiveImplOCL { scratch.routing_weights = intermediates_memories[3]; scratch.gate = intermediates_memories[4]; const auto& config = instance.get_typed_desc()->_config; - int expert_num = static_cast(config.num_experts); + int expert_num = static_cast(config.num_expert); scratch.expert_masks.resize(expert_num); for (int i = 0; i < expert_num; i++) { scratch.expert_masks[i].batch = intermediates_memories[5 + 2 * i + 0]; @@ -749,8 +749,8 @@ class MOEOptImpl : public PrimitiveImplOCL { auto layout = mem->get_layout(); const auto& shape = layout.get_shape(); - int max_expert_num = static_cast(config.num_experts); - int max_topk = static_cast(config.topk); + int max_expert_num = static_cast(config.num_expert); + int max_topk = static_cast(config.top_k); int max_tokens = static_cast(shape[0]); expert_mask.pred_flag.resize(max_expert_num, 0); @@ -845,7 +845,7 @@ class MOEOptImpl : public PrimitiveImplOCL { cldnn::event::ptr exec_single_batch(typed_primitive_inst& instance, scratch_buffers& scratch) { auto cur_moe = instance.get_typed_desc(); - int max_topk = static_cast(cur_moe->_config.topk); + int max_topk = static_cast(cur_moe->_config.top_k); auto final_hidden_states_mem_ptr = instance.output_memory_ptr(0); // auto batch_mem_ptr = scratch.topk_id; @@ -1006,7 +1006,7 @@ class MOEOptImpl : public PrimitiveImplOCL { auto& instance = reinterpret_cast&>(ins); auto cur_moe = instance.get_typed_desc(); const auto& config = cur_moe->_config; - int max_topk = static_cast(config.topk); + int max_topk = static_cast(config.top_k); auto& cur_net = instance.get_network(); auto& stream = cur_net.get_stream(); @@ -1017,7 +1017,7 @@ class MOEOptImpl : public PrimitiveImplOCL { prepare_internal_buffers(instance, scratch, batch == 1); // softmax+topk - // auto lws_size = cur_moe->_config.num_experts; + // auto lws_size = cur_moe->_config.num_expert; // auto topk_event = execute_stage(events, // instance, // *softmax_topk, @@ -1068,7 +1068,7 @@ class MOEOptImpl : public PrimitiveImplOCL { auto lws_size = get_best_lws(_hidden_size); OPENVINO_ASSERT(batch != 1, "batch size shouldn't be 1 for this path!"); - for (size_t expert_no = 0; expert_no < config.num_experts; expert_no++) { + for (size_t expert_no = 0; expert_no < config.num_expert; expert_no++) { OPENVINO_ASSERT(expert_no < expert_mask.pred_flag.size()); auto can_skip_subgraph = !expert_mask.pred_flag[expert_no]; if (can_skip_subgraph) { From 982699d94ed56bdfd60db3bc784d8c448a28edbd Mon Sep 17 00:00:00 2001 From: "River, Li" Date: Tue, 21 Oct 2025 10:18:41 +0800 Subject: [PATCH 09/13] Enable moe transformation - FuseVectorizedMOE3GEMM and ConvertMOEToMOECompressed --- .../intel_gpu/src/plugin/transformations_pipeline.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index f8edfe407159fc..7c0fef9f23699b 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -74,6 +74,7 @@ #include "plugin/transformations/convert_convolution.hpp" #include "plugin/transformations/convert_fc_to_compressed.hpp" #include "plugin/transformations/convert_matmul_to_fc.hpp" +#include "plugin/transformations/convert_moe_to_compressed.hpp" #include "plugin/transformations/convert_stridedslices_to_variadicsplit.hpp" #include "plugin/transformations/decompose_reduce_scalar_output.hpp" #include "plugin/transformations/dynamic_quantize_fully_connected.hpp" @@ -107,6 +108,7 @@ #include "transformations/common_optimizations/lstm_cell_fusion.hpp" #include "transformations/common_optimizations/move_eltwise_up_data_movement.hpp" #include "transformations/common_optimizations/mvn_fusion.hpp" +#include "transformations/common_optimizations/matmul_experts_fusion.hpp" #include "transformations/common_optimizations/nop_elimination.hpp" #include "transformations/common_optimizations/rms_fusion.hpp" #include "transformations/common_optimizations/sdpa_scale_fusion.hpp" @@ -188,6 +190,7 @@ #include "openvino/op/ceiling.hpp" #include "openvino/op/clamp.hpp" #include "openvino/op/matmul.hpp" +#include "openvino/op/moe.hpp" #include "openvino/op/reverse_sequence.hpp" #include "openvino/op/roll.hpp" #include "openvino/op/shuffle_channels.hpp" @@ -395,6 +398,9 @@ void TransformationsPipeline::apply(std::shared_ptr func) { std::vector{ ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4 }, !device_info.supports_immad); + manager.register_pass(); + manager.register_pass(); + manager.register_pass(); manager.register_pass(); From b77097f016683e79c813e19a8a749c99f33ea63b Mon Sep 17 00:00:00 2001 From: "River, Li" Date: Tue, 21 Oct 2025 10:45:54 +0800 Subject: [PATCH 10/13] update dnnl weight convert --- .../src/graph/impls/ocl_v2/moe_opt.cpp | 32 ++++++++----------- .../convert_moe_to_compressed.cpp | 5 +-- 2 files changed, 16 insertions(+), 21 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp index a5f4c906314287..0261a843bafaae 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp @@ -626,28 +626,23 @@ class MOEOptImpl : public PrimitiveImplOCL { dnnl_weights[2].oc = _hidden_size; for (int i = 0; i < 3; i++) { // weight shape: [ic, oc], type: u4 - ov::Shape wei_shape = {static_cast(dnnl_weights[i].ic), static_cast(dnnl_weights[i].oc)}; - auto wei_layout = cldnn::layout(wei_shape, cldnn::data_types::u4, cldnn::format::get_default_format(wei_shape.size())); - auto wei_mem = engine.create_subbuffer(*moe_fusion_wei_addr.weight[i], wei_layout, j * dnnl_weights[i].ic * dnnl_weights[i].oc / 2); - dnnl_weights[i].weight = convert2dnnl(wei_mem, {dnnl_weights[i].ic, dnnl_weights[i].oc}, dnnl::memory::format_tag::ba); + size_t wei_offset = j * dnnl_weights[i].ic * dnnl_weights[i].oc / 2; + dnnl_weights[i].weight = + convert2dnnl(moe_fusion_wei_addr.weight[i], {dnnl_weights[i].ic, dnnl_weights[i].oc}, dnnl::memory::format_tag::ba, wei_offset); // scale shape: [ic / ic_group_size, oc], type: f16 - ov::Shape scale_shape = {static_cast(dnnl_weights[i].ic / dnnl_weights[i].ic_group_size), static_cast(dnnl_weights[i].oc)}; - auto scale_layout = cldnn::layout(scale_shape, cldnn::data_types::f16, cldnn::format::get_default_format(scale_shape.size())); - auto scale_mem = engine.create_subbuffer(*moe_fusion_wei_addr.scale[i], - scale_layout, - j * dnnl_weights[i].ic * dnnl_weights[i].oc / dnnl_weights[i].ic_group_size * 2); - dnnl_weights[i].scale = - convert2dnnl(scale_mem, {dnnl_weights[i].ic / dnnl_weights[i].ic_group_size, dnnl_weights[i].oc}, dnnl::memory::format_tag::ab); + size_t scale_offset = j * dnnl_weights[i].ic * dnnl_weights[i].oc / dnnl_weights[i].ic_group_size * 2; + dnnl_weights[i].scale = convert2dnnl(moe_fusion_wei_addr.scale[i], + {dnnl_weights[i].ic / dnnl_weights[i].ic_group_size, dnnl_weights[i].oc}, + dnnl::memory::format_tag::ab, + scale_offset); // zp shape: [ic / ic_group_size, oc], type: u4 - ov::Shape zp_shape = {static_cast(dnnl_weights[i].ic / dnnl_weights[i].ic_group_size), static_cast(dnnl_weights[i].oc)}; - auto zp_layout = cldnn::layout(zp_shape, cldnn::data_types::u4, cldnn::format::get_default_format(zp_shape.size())); - auto zp_mem = engine.create_subbuffer(*moe_fusion_wei_addr.zp[i], - zp_layout, - j * dnnl_weights[i].ic * dnnl_weights[i].oc / dnnl_weights[i].ic_group_size / 2); - dnnl_weights[i].zp = - convert2dnnl(zp_mem, {dnnl_weights[i].ic / dnnl_weights[i].ic_group_size, dnnl_weights[i].oc}, dnnl::memory::format_tag::ab); + size_t zp_offset = j * dnnl_weights[i].ic * dnnl_weights[i].oc / dnnl_weights[i].ic_group_size / 2; + dnnl_weights[i].zp = convert2dnnl(moe_fusion_wei_addr.zp[i], + {dnnl_weights[i].ic / dnnl_weights[i].ic_group_size, dnnl_weights[i].oc}, + dnnl::memory::format_tag::ab, + zp_offset); } } } @@ -1047,7 +1042,6 @@ class MOEOptImpl : public PrimitiveImplOCL { // auto topk_id_mem = scratch.topk_id; auto topk_id_mem = scratch.input_router_topk_idx; - expert_mask_cpu expert_mask; get_expert_mask_from_gpu(config, topk_id_mem, stream, expert_mask); diff --git a/src/plugins/intel_gpu/src/plugin/transformations/convert_moe_to_compressed.cpp b/src/plugins/intel_gpu/src/plugin/transformations/convert_moe_to_compressed.cpp index b254c7bc067bcd..d53fdc35b4c7b4 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/convert_moe_to_compressed.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/convert_moe_to_compressed.cpp @@ -112,12 +112,12 @@ ConvertMOEToMOECompressed::ConvertMOEToMOECompressed() { if (weight_shape.size() != 4) { return false; } - auto topk_shape = pattern_map.at(topk_m).get_shape(); + auto topk_shape = pattern_map.at(topk_m).get_partial_shape(); config.hidden_size = weight_shape[2] * weight_shape[3]; config.inter_size = weight_shape[1]; config.num_expert = weight_shape[0]; config.group_size = weight_shape[3]; - config.top_k = topk_shape.back(); + config.top_k = topk_shape[1].get_length(); config.out_type = ov::element::f16; auto moe_compressed = std::make_shared(args, config); @@ -144,6 +144,7 @@ ConvertMOEToMOECompressed::ConvertMOEToMOECompressed() { ov::copy_runtime_info(moe, moe_compressed); ov::replace_node(moe, moe_compressed); + std::cout << "ConvertMOEToMOECompressed is done : config.top_k = " << config.top_k << std::endl; return true; }; From 818678d69fbf19d20a0d9a1bbff37540576df9b7 Mon Sep 17 00:00:00 2001 From: "River, Li" Date: Tue, 21 Oct 2025 13:40:44 +0800 Subject: [PATCH 11/13] Fix double free issue and kernel build errors --- .../src/graph/impls/ocl_v2/moe_mlp.cl | 18 ++++----- .../src/graph/impls/ocl_v2/moe_opt.cl | 40 +++++++++---------- .../src/graph/impls/ocl_v2/moe_opt.cpp | 16 ++++---- .../src/graph/impls/ocl_v2/moe_opt.hpp | 2 +- .../convert_moe_to_compressed.cpp | 2 +- 5 files changed, 39 insertions(+), 39 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_mlp.cl b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_mlp.cl index f5cd1297ff9e0e..78fdc0ef30807b 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_mlp.cl +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_mlp.cl @@ -135,7 +135,7 @@ inline void gemv_n2x(const __global uchar* weight, } __attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) -__kernel void mlp_gate_up( +KERNEL (mlp_gate_up)( const __global int* expert_list, const __global uchar* gate_weight_addr, const __global uchar* gate_scale_addr, @@ -143,8 +143,8 @@ __kernel void mlp_gate_up( const __global uchar* up_weight_addr, const __global uchar* up_scale_addr, const __global uchar* up_zp_addr, - __global TYPE* x, // [1, HIDDEN_SIZE] - __global TYPE* y) { // [MAX_TOPK, INTERMEDIATE_SIZE] + __global MOE_TYPE* x, // [1, HIDDEN_SIZE] + __global MOE_TYPE* y) { // [MAX_TOPK, INTERMEDIATE_SIZE] // global: [expert, SUBGROUP_SIZE, N//N_BLOCK],[1, SUBGROUP_SIZE, SUBGROUP_NUM] int expert_no = get_global_id(0); y += expert_no * INTERMEDIATE_SIZE; @@ -172,14 +172,14 @@ __kernel void mlp_gate_up( #elif DOWN_ENABLE __attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) -__kernel void mlp_down( +KERNEL (mlp_down)( const __global int* expert_list, const __global uchar* down_weight_addr, const __global uchar* down_scale_addr, const __global uchar* down_zp_addr, - const __global TYPE* x, // [MAX_TOPK, INTERMEDIATE_SIZE] - __global TYPE* routing_weights, // [MAX_TOPK] - __global TYPE* y) { // [MAX_TOPK, HIDDEN_SIZE] + const __global MOE_TYPE* x, // [MAX_TOPK, INTERMEDIATE_SIZE] + __global MOE_TYPE* routing_weights, // [MAX_TOPK] + __global MOE_TYPE* y) { // [MAX_TOPK, HIDDEN_SIZE] // global: [expert, SUBGROUP_SIZE, N//N_BLOCK],[1, SUBGROUP_SIZE, SUBGROUP_NUM] int expert_no = get_global_id(0); x += expert_no * INTERMEDIATE_SIZE; @@ -301,8 +301,8 @@ __kernel void mlp_down( #else __attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) -__kernel void mlp_reduce(const __global TYPE* x, // [MAX_TOPK, HIDDEN_SIZE] - __global TYPE* y) { // [1, HIDDEN_SIZE] +KERNEL (mlp_reduce)(const __global MOE_TYPE* x, // [MAX_TOPK, HIDDEN_SIZE] + __global MOE_TYPE* y) { // [1, HIDDEN_SIZE] int n = get_global_id(1); half sum[MAX_TOPK] = {0}; __attribute__((opencl_unroll_hint(MAX_TOPK))) diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cl b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cl index 3f7837aae82e88..e1c717525354ff 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cl +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cl @@ -4,10 +4,10 @@ #if SOFTMAX_TOPK_ENABLE -__kernel void softmax_topk( - const __global TYPE* input, // [input_batch, sort_in_num] +KERNEL(softmax_topk)( + const __global MOE_TYPE* input, // [input_batch, sort_in_num] __global uint* output_index, // [input_batch, TOP_K] - __global TYPE* output // [input_batch, TOP_K] + __global MOE_TYPE* output // [input_batch, TOP_K] ) { // gws [batch, sort_in_num] const uint batch = (uint)get_global_id(0); @@ -18,17 +18,17 @@ __kernel void softmax_topk( uint sort_position = 0; - __local TYPE local_input[VALUE_NUM]; - __local TYPE local_output[TOP_K]; + __local MOE_TYPE local_input[VALUE_NUM]; + __local MOE_TYPE local_output[TOP_K]; __local uint local_index[TOP_K]; - TYPE in_value = as_half(intel_sub_group_block_read_us((const __global ushort*)(input))); + MOE_TYPE in_value = as_half(intel_sub_group_block_read_us((const __global ushort*)(input))); local_input[sort_index] = in_value; barrier(CLK_LOCAL_MEM_FENCE); __attribute__((opencl_unroll_hint(8))) for(uint i = 0; i < sort_index; i++) { - TYPE value = local_input[i]; + MOE_TYPE value = local_input[i]; if(value >= in_value) { sort_position++; } @@ -36,7 +36,7 @@ __kernel void softmax_topk( __attribute__((opencl_unroll_hint(8))) for(uint i = sort_index; i < sort_cnt; i++) { - TYPE value = local_input[i]; + MOE_TYPE value = local_input[i]; if(value > in_value) { sort_position++; } @@ -49,7 +49,7 @@ __kernel void softmax_topk( if(sort_position == 0) { float softmax_total = 1.0; - TYPE max_v = local_output[0]; + MOE_TYPE max_v = local_output[0]; local_output[0] = 1; for(uint i = 1; i < TOP_K; i++) { local_output[i] = native_exp(local_output[i] - max_v); @@ -66,13 +66,13 @@ __kernel void softmax_topk( } #elif GATHER_ENABLE -__kernel void gather_2d_ref( - const __global TYPE* src_tok, - const __global TYPE* src_rweight, +KERNEL (gather_2d_ref)( + const __global MOE_TYPE* src_tok, + const __global MOE_TYPE* src_rweight, __global int * tok_index, __global int * top_index, - __global TYPE* dst_tok, - __global TYPE* dst_rweight) { + __global MOE_TYPE* dst_tok, + __global MOE_TYPE* dst_rweight) { int k = get_global_id(0); int off = get_global_id(1); @@ -81,10 +81,10 @@ __kernel void gather_2d_ref( src_tok += tok_idx * HIDDEN_SIZE; dst_tok += k * HIDDEN_SIZE; - #if TYPE_SIZE == 2 + #if MOE_TYPE_SIZE == 2 ushort value = intel_sub_group_block_read_us((const __global ushort *)(src_tok + off)); intel_sub_group_block_write_us((__global ushort *)(dst_tok + off), value); - #elif TYPE_SIZE == 4 + #elif MOE_TYPE_SIZE == 4 uint value = intel_sub_group_block_read((const __global uint *)(src_tok + off)); intel_sub_group_block_write((__global uint *)(dst_tok + off), value); #else @@ -99,9 +99,9 @@ __kernel void gather_2d_ref( #elif SCATTER_ENABLE -__kernel void index_add_(const __global TYPE* src_tok, +KERNEL (index_add_)(const __global MOE_TYPE* src_tok, __global int * tok_index, - __global TYPE* dst_tok) { + __global MOE_TYPE* dst_tok) { int k = get_global_id(0); int off = get_global_id(1); @@ -110,12 +110,12 @@ __kernel void index_add_(const __global TYPE* src_tok, src_tok += k * HIDDEN_SIZE; dst_tok += tok_idx * HIDDEN_SIZE; - #if TYPE_SIZE == 2 + #if MOE_TYPE_SIZE == 2 half src_value = as_half(intel_sub_group_block_read_us((const __global ushort *)(src_tok + off))); half dst_value = as_half(intel_sub_group_block_read_us((const __global ushort *)(dst_tok + off))); half value = dst_value + src_value; intel_sub_group_block_write_us((__global ushort *)(dst_tok + off), as_ushort(value)); - #elif TYPE_SIZE == 4 + #elif MOE_TYPE_SIZE == 4 float src_value = as_float(intel_sub_group_block_read((const __global uint *)(src_tok + off))); float dst_value = as_float(intel_sub_group_block_read((const __global uint *)(dst_tok + off))); float value = dst_value + src_value; diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp index 0261a843bafaae..b5a553870d1a81 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp @@ -358,8 +358,8 @@ struct onednn_linear { // jit.make("SOFTMAX_TOPK_ENABLE", 1); // jit.make("TOP_K", desc->_config.top_k); // jit.make("VALUE_NUM", desc->_config.num_expert); -// jit.make("TYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float"); -// jit.make("TYPE_SIZE", params.get_input_layout(0).data_type == ov::element::f16 ? 2 : 4); +// jit.make("MOE_TYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float"); +// jit.make("MOE_TYPE_SIZE", params.get_input_layout(0).data_type == ov::element::f16 ? 2 : 4); // return jit; // } @@ -384,8 +384,8 @@ class MOEOptGather : public KernelGenerator { auto desc = params.typed_desc(); jit.make("GATHER_ENABLE", 1); jit.make("HIDDEN_SIZE", desc->_config.hidden_size); - jit.make("TYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float"); - jit.make("TYPE_SIZE", params.get_input_layout(0).data_type == ov::element::f16 ? 2 : 4); + jit.make("MOE_TYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float"); + jit.make("MOE_TYPE_SIZE", params.get_input_layout(0).data_type == ov::element::f16 ? 2 : 4); return jit; } @@ -410,8 +410,8 @@ class MOEOptScatter : public KernelGenerator { auto desc = params.typed_desc(); jit.make("SCATTER_ENABLE", 1); jit.make("HIDDEN_SIZE", desc->_config.hidden_size); - jit.make("TYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float"); - jit.make("TYPE_SIZE", params.get_input_layout(0).data_type == ov::element::f16 ? 2 : 4); + jit.make("MOE_TYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float"); + jit.make("MOE_TYPE_SIZE", params.get_input_layout(0).data_type == ov::element::f16 ? 2 : 4); return jit; } @@ -440,8 +440,8 @@ static void add_common_consts(const RuntimeParams& params, JitConstants& jit) { jit.make("SUBGROUP_SIZE", info.arch >= gpu_arch::xe2 ? 32 : 16); jit.make("SUBGROUP_NUM", SUBGROUP_NUM); jit.make("GROUP_SIZE", desc->_config.group_size); - jit.make("TYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float"); - jit.make("TYPE_SIZE", params.get_input_layout(0).data_type == ov::element::f16 ? 2 : 4); + jit.make("MOE_TYPE", params.get_input_layout(0).data_type == ov::element::f16 ? "half" : "float"); + jit.make("MOE_TYPE_SIZE", params.get_input_layout(0).data_type == ov::element::f16 ? 2 : 4); } class MOEOptMLPGateUp : public KernelGenerator { diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.hpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.hpp index 4f5c971854fedf..1fd48115877176 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.hpp @@ -66,7 +66,7 @@ struct MOEOpt : public ImplementationManager { if (!one_of(wei_layout.data_type, supported_wei_type)) { return false; } - + std::cout << "ocl::moe::opt is supported..." << std::endl; return true; } }; diff --git a/src/plugins/intel_gpu/src/plugin/transformations/convert_moe_to_compressed.cpp b/src/plugins/intel_gpu/src/plugin/transformations/convert_moe_to_compressed.cpp index d53fdc35b4c7b4..6a427ff197f918 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/convert_moe_to_compressed.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/convert_moe_to_compressed.cpp @@ -94,7 +94,7 @@ ConvertMOEToMOECompressed::ConvertMOEToMOECompressed() { if (!moe || transformation_callback(moe)) { return false; } - OutputVector args(11); + OutputVector args(12); args[0] = pattern_map.at(hidden_states_m); args[1] = pattern_map.at(routing_weights_m); args[2] = pattern_map.at(topk_m); From 6bf7222332039fbc5bc8244864d842cac7b63530 Mon Sep 17 00:00:00 2001 From: "River, Li" Date: Wed, 22 Oct 2025 20:32:49 +0800 Subject: [PATCH 12/13] Fix router weight gather issue --- .../src/graph/impls/ocl_v2/moe_opt.cl | 23 ++++++++-------- .../src/graph/impls/ocl_v2/moe_opt.cpp | 26 ++++++++++++------- .../convert_moe_to_compressed.cpp | 2 +- 3 files changed, 29 insertions(+), 22 deletions(-) diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cl b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cl index e1c717525354ff..e06f73f14d677d 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cl +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cl @@ -67,15 +67,15 @@ KERNEL(softmax_topk)( #elif GATHER_ENABLE KERNEL (gather_2d_ref)( - const __global MOE_TYPE* src_tok, - const __global MOE_TYPE* src_rweight, - __global int * tok_index, - __global int * top_index, - __global MOE_TYPE* dst_tok, - __global MOE_TYPE* dst_rweight) { - - int k = get_global_id(0); - int off = get_global_id(1); + const __global MOE_TYPE* src_tok, // input tokens [total_token, hidden_size] - hidden_states_mem_ptr + const __global MOE_TYPE* src_rweight, // topk_weights [total_token, topk_experts] + __global int * tok_index, // token index [expert_idx][] = [actual_token_num] - expert_mask_mem.batch + __global int * top_index, // topk index [expert_idx][] = [actual_token_num] - expert_mask_mem.topk + __global MOE_TYPE* dst_tok, // output tokens [batch_size, hidden_size] - scratch.x + __global MOE_TYPE* dst_rweight) { // output topk_weights [batch_size] - scratch.routing_weights + + int k = get_global_id(0); // token_idx + int off = get_global_id(1); // hidden_size offset int tok_idx = tok_index[k]; src_tok += tok_idx * HIDDEN_SIZE; @@ -92,8 +92,9 @@ KERNEL (gather_2d_ref)( #endif if (off == 0) { - int top_idx = top_index[k]; - dst_rweight[k] = src_rweight[top_idx]; + // int top_idx = top_index[k]; + // dst_rweight[k] = src_rweight[top_idx]; + dst_rweight[k] = src_rweight[tok_idx]; } } diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp index b5a553870d1a81..93066b305ab847 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl_v2/moe_opt.cpp @@ -689,7 +689,7 @@ class MOEOptImpl : public PrimitiveImplOCL { // scratch.y = down(scratch.gate) * routing_weights internal_buffers.emplace_back(layout_down_out, true); // 2: x, scratch.x has same layout with down output layout routing_layout(ov::PartialShape{batch * max_topk}, data_type, cldnn::format::bfyx); - internal_buffers.emplace_back(layout_down_out, true); // 3: routing_weights + internal_buffers.emplace_back(routing_layout, true); // 3: routing_weights internal_buffers.emplace_back(layout_gateup_out, true); // 4: gate, scratch.gate has same layout with up // expert masks for gpu layout index_layout(ov::PartialShape{batch}, ov::element::i32, cldnn::format::bfyx); @@ -927,9 +927,6 @@ class MOEOptImpl : public PrimitiveImplOCL { auto& cur_net = instance.get_network(); auto& stream = cur_net.get_stream(); - // auto cur_moe = instance.get_typed_desc(); - // const auto& moe_mlp_params = cur_moe->_mlp_params; - // const auto& mlp_params = moe_mlp_params[expert_no]; auto& dnn_stream = stream.get_onednn_stream(); auto hidden_states_layout_dt = convert_data_type(instance.input_memory_ptr(static_cast(MOEInputIndex::HIDDEN_STATES))->get_layout().data_type); @@ -1042,13 +1039,16 @@ class MOEOptImpl : public PrimitiveImplOCL { // auto topk_id_mem = scratch.topk_id; auto topk_id_mem = scratch.input_router_topk_idx; + // Wait for topk is ready + for(auto &ev : events) { + ev->wait(); + } expert_mask_cpu expert_mask; get_expert_mask_from_gpu(config, topk_id_mem, stream, expert_mask); auto& dnn_stream = stream.get_onednn_stream(); cldnn::event::ptr result_event; - // auto routing_mem_ptr = scratch.topk_weights; auto routing_mem_ptr = scratch.input_routing_weights; auto get_best_lws = [](size_t hidden_size) { const size_t candidate[] = {128, 64, 32, 16, 8}; @@ -1060,6 +1060,7 @@ class MOEOptImpl : public PrimitiveImplOCL { OPENVINO_ASSERT(false, "hidden_size=", hidden_size, " is not divisible by any of ", sizeof(candidate) / sizeof(size_t), " candidates"); }; auto lws_size = get_best_lws(_hidden_size); + // std::cout << "routing_mem_ptr layout: " << routing_mem_ptr->get_layout().to_short_string() << std::endl; OPENVINO_ASSERT(batch != 1, "batch size shouldn't be 1 for this path!"); for (size_t expert_no = 0; expert_no < config.num_expert; expert_no++) { @@ -1076,28 +1077,33 @@ class MOEOptImpl : public PrimitiveImplOCL { auto n_token = static_cast(expert_mask.batch[expert_no].size()); onednn_kernel& kernel = get_kernel(n_token, static_cast(expert_no), instance); - memory::ptr& x = scratch.x; + + ov::Shape router_wei_shape = {static_cast(1), static_cast(batch)}; + auto router_wei_layout = cldnn::layout(router_wei_shape, cldnn::data_types::f16, cldnn::format::get_default_format(router_wei_shape.size())); + auto current_expert_routing_mem_ptr = engine.create_subbuffer(*routing_mem_ptr, router_wei_layout, batch * expert_no * 2); // f16 size=2 + // auto current_expert_routing_mem_ptr = scratch.input_routing_weights; + // std::cout << "MOEOptImpl::execute expert_no=" << expert_no << ", n_token=" << n_token << ", total_token_num = " << batch << std::endl; // gather execute_stage(events, instance, *gather, - {hidden_states_mem_ptr, routing_mem_ptr, expert_mask_mem.batch, expert_mask_mem.topk}, - {x, scratch.routing_weights}, + {hidden_states_mem_ptr, current_expert_routing_mem_ptr, expert_mask_mem.batch, expert_mask_mem.topk}, + {scratch.x, scratch.routing_weights}, {static_cast(n_token), static_cast(_hidden_size)}, {1, lws_size}); // up kernel.up.forward(dnn_stream, n_token, - convert2dnnl(x, {static_cast(n_token), dnnl_weights[1].ic}, dnnl::memory::format_tag::ab), + convert2dnnl(scratch.x, {static_cast(n_token), dnnl_weights[1].ic}, dnnl::memory::format_tag::ab), convert2dnnl(scratch.up, {static_cast(n_token), _intermediate_size}, dnnl::memory::format_tag::ab), dnnl::memory()); // gate kernel.gate.forward(dnn_stream, n_token, - convert2dnnl(x, {static_cast(n_token), dnnl_weights[0].ic}, dnnl::memory::format_tag::ab), + convert2dnnl(scratch.x, {static_cast(n_token), dnnl_weights[0].ic}, dnnl::memory::format_tag::ab), convert2dnnl(scratch.gate, {static_cast(n_token), _intermediate_size}, dnnl::memory::format_tag::ab), convert2dnnl(scratch.up, {static_cast(n_token), _intermediate_size}, dnnl::memory::format_tag::ab)); diff --git a/src/plugins/intel_gpu/src/plugin/transformations/convert_moe_to_compressed.cpp b/src/plugins/intel_gpu/src/plugin/transformations/convert_moe_to_compressed.cpp index 6a427ff197f918..d88e23a2f524be 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/convert_moe_to_compressed.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/convert_moe_to_compressed.cpp @@ -144,7 +144,7 @@ ConvertMOEToMOECompressed::ConvertMOEToMOECompressed() { ov::copy_runtime_info(moe, moe_compressed); ov::replace_node(moe, moe_compressed); - std::cout << "ConvertMOEToMOECompressed is done : config.top_k = " << config.top_k << std::endl; + std::cout << "ConvertMOEToMOECompressed is hit : config.top_k = " << config.top_k << std::endl; return true; }; From cdfb75bbbc209a3fb0d70c1b059a4e6b1a033723 Mon Sep 17 00:00:00 2001 From: "Zhai, Xuejun" Date: Wed, 22 Oct 2025 21:20:48 +0800 Subject: [PATCH 13/13] Add test for MOECompressed Signed-off-by: Zhai, Xuejun --- .../transformations/moe_compressed_tests.cpp | 115 ++++++++++++++++++ 1 file changed, 115 insertions(+) create mode 100644 src/plugins/intel_gpu/tests/unit/transformations/moe_compressed_tests.cpp diff --git a/src/plugins/intel_gpu/tests/unit/transformations/moe_compressed_tests.cpp b/src/plugins/intel_gpu/tests/unit/transformations/moe_compressed_tests.cpp new file mode 100644 index 00000000000000..6de4bfc9a12cc0 --- /dev/null +++ b/src/plugins/intel_gpu/tests/unit/transformations/moe_compressed_tests.cpp @@ -0,0 +1,115 @@ +#include + +#include "common_test_utils/ov_test_utils.hpp" +#include "intel_gpu/op/moe_compressed.hpp" +#include "intel_gpu/op/placeholder.hpp" +#include "openvino/op/add.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/convert.hpp" +#include "openvino/op/matmul.hpp" +#include "openvino/op/moe.hpp" +#include "openvino/op/multiply.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/op/shape_of.hpp" +#include "openvino/op/split.hpp" +#include "openvino/op/subtract.hpp" +#include "openvino/op/variadic_split.hpp" +#include "plugin/transformations/convert_moe_to_compressed.hpp" + +using namespace testing; +using namespace ov::intel_gpu; + +namespace ov { +namespace test { +namespace intel_gpu { +TEST_F(TransformationTestsF, MoeCompressedTest) { + { + // Construct inputs + auto hidden_states = std::make_shared(element::f32, Shape{2, 2048}); + auto routing_weights = std::make_shared(element::f32, Shape{2, 4}); + auto topk = std::make_shared(element::i32, Shape{2, 2}); + + // Construct constant weights + auto w0 = op::v0::Constant::create(element::u4, Shape{4, 768, 16, 128}, {1}); + auto zp0 = op::v0::Constant::create(element::u4, Shape{4, 768, 16, 1}, {0}); + auto scale0 = op::v0::Constant::create(element::f16, Shape{4, 768, 16, 1}, {0.01f}); + auto reshape0 = op::v0::Constant::create(element::i64, Shape{3}, {4, 768, 2048}); + + // First projection + auto w0_f16 = std::make_shared(w0, element::f16); + auto zp0_f16 = std::make_shared(zp0, element::f16); + auto sub0 = std::make_shared(w0_f16, zp0_f16); + auto mul0 = std::make_shared(sub0, scale0); + auto reshape_m0 = std::make_shared(mul0, reshape0, false); + auto convert_m0 = std::make_shared(reshape_m0, element::f32); + + // Second projection + auto w1 = op::v0::Constant::create(element::u4, Shape{4, 768, 16, 128}, {1}); + auto zp1 = op::v0::Constant::create(element::u4, Shape{4, 768, 16, 1}, {0}); + auto scale1 = op::v0::Constant::create(element::f16, Shape{4, 768, 16, 1}, {0.01f}); + auto reshape1 = op::v0::Constant::create(element::i64, Shape{3}, {4, 768, 2048}); + + auto w1_f16 = std::make_shared(w1, element::f16); + auto zp1_f16 = std::make_shared(zp1, element::f16); + auto sub1 = std::make_shared(w1_f16, zp1_f16); + auto mul1 = std::make_shared(sub1, scale1); + auto reshape_m1 = std::make_shared(mul1, reshape1, false); + auto convert_m1 = std::make_shared(reshape_m1, element::f32); + + // Third projection + auto w2 = op::v0::Constant::create(element::u4, Shape{4, 768, 16, 128}, {1}); + auto zp2 = op::v0::Constant::create(element::u4, Shape{4, 768, 16, 1}, {0}); + auto scale2 = op::v0::Constant::create(element::f16, Shape{4, 768, 16, 1}, {0.01f}); + auto reshape2 = op::v0::Constant::create(element::i64, Shape{3}, {4, 768, 2048}); + + auto w2_f16 = std::make_shared(w2, element::f16); + auto zp2_f16 = std::make_shared(zp2, element::f16); + auto sub2 = std::make_shared(w2_f16, zp2_f16); + auto mul2 = std::make_shared(sub2, scale2); + auto reshape_m2 = std::make_shared(mul2, reshape2, false); + auto convert_m2 = std::make_shared(reshape_m2, element::f32); + + // Construct MOE node + ov::op::internal::MOE::Config config; + config.expert_type = ov::op::internal::MOE::Expert_type::GEMM3_SWIGLU; + auto moe = std::make_shared(ov::OutputVector{hidden_states, routing_weights, topk, convert_m0, convert_m1, convert_m2}, config); + model = std::make_shared(moe, ov::ParameterVector{hidden_states, routing_weights, topk}); + manager.register_pass(); + } + { + // Construct inputs + auto hidden_states = std::make_shared(element::f32, Shape{2, 2048}); + auto routing_weights = std::make_shared(element::f32, Shape{2, 4}); + auto topk = std::make_shared(element::i32, Shape{2, 2}); + + // Construct constant weights + auto w0 = op::v0::Constant::create(element::u4, Shape{4, 768, 16, 128}, {1}); + auto zp0 = op::v0::Constant::create(element::u4, Shape{4, 768, 16, 1}, {0}); + auto scale0 = op::v0::Constant::create(element::f16, Shape{4, 768, 16, 1}, {0.01f}); + + // Second projection + auto w1 = op::v0::Constant::create(element::u4, Shape{4, 768, 16, 128}, {1}); + auto zp1 = op::v0::Constant::create(element::u4, Shape{4, 768, 16, 1}, {0}); + auto scale1 = op::v0::Constant::create(element::f16, Shape{4, 768, 16, 1}, {0.01f}); + + // Third projection + auto w2 = op::v0::Constant::create(element::u4, Shape{4, 768, 16, 128}, {1}); + auto zp2 = op::v0::Constant::create(element::u4, Shape{4, 768, 16, 1}, {0}); + auto scale2 = op::v0::Constant::create(element::f16, Shape{4, 768, 16, 1}, {0.01f}); + + ov::intel_gpu::op::MOECompressed::Config config; + config.hidden_size = 2048; + config.inter_size = 768; + config.num_expert = 4; + config.group_size = 128; + config.top_k = 2; + config.out_type = ov::element::f16; + auto moe_compressed = std::make_shared( + ov::OutputVector{hidden_states, routing_weights, topk, w0, scale0, zp0, w1, scale1, zp1, w2, scale2, zp2}, + config); + model_ref = std::make_shared(moe_compressed, ov::ParameterVector{hidden_states, routing_weights, topk}); + } +} +} // namespace intel_gpu +} // namespace test +} // namespace ov \ No newline at end of file