Changes from all commits
44 commits
af6463d
qwen3 moe_compressed primitive_impl
riverlijunjie Oct 20, 2025
1b34668
MOECompressed internal op
chenhu-wang Oct 20, 2025
54bb288
update
riverlijunjie Oct 20, 2025
80a6509
Update weight management
riverlijunjie Oct 20, 2025
b891113
MOECompressed internal op
chenhu-wang Oct 20, 2025
f324a12
Remove softmax_top out of moe primitive implement
riverlijunjie Oct 21, 2025
4d1d72d
MOE to MOECompressed
chenhu-wang Oct 20, 2025
cfcbacc
update
riverlijunjie Oct 21, 2025
72ab1fd
Enable moe transformation - FuseVectorizedMOE3GEMM and ConvertMOEToMO…
riverlijunjie Oct 21, 2025
7b4b2d5
update dnnl weight convert
riverlijunjie Oct 21, 2025
c05f001
Fix double free issue and kernel build errors
riverlijunjie Oct 21, 2025
d630503
Fix router weight gather issue
riverlijunjie Oct 22, 2025
d5ba90d
MOECompressed to MOEFusedCompressed
chenhu-wang Oct 23, 2025
5d63e20
Fuse softmax_topk_oneshot with moe_compressed
riverlijunjie Oct 23, 2025
c19ca31
Fix windows compiling error
riverlijunjie Oct 24, 2025
5fc87c8
Switch on FuseMOE moc transformation
riverlijunjie Oct 26, 2025
d3b8778
align scale and zp format
chenhu-wang Oct 28, 2025
3527786
minor update
riverlijunjie Oct 28, 2025
078d825
Restore keeping FuseMOE off by default
riverlijunjie Oct 29, 2025
eaac162
Update moe kernel
riverlijunjie Oct 29, 2025
f814faf
cleanup & optimize intermediate memory
riverlijunjie Oct 30, 2025
a081a9c
inherit from moe, keep moe const pass and code clean up
chenhu-wang Oct 30, 2025
420d3f4
Switch on FuseMOE
riverlijunjie Oct 30, 2025
cb9cbe2
WA: OCL OUT_OF_RESOURCE issue when input token size < 8
riverlijunjie Oct 31, 2025
db436a8
add unit_test
zhaixuejun1993 Oct 31, 2025
1b57d17
Fix OUT_OF_RESOURCE issue and remove WA
riverlijunjie Oct 31, 2025
e025c81
WA: OCL OUT_OF_RESOURCE when input token size < 8
riverlijunjie Oct 31, 2025
71e6db6
add accuracy ut
zaixing-wang Oct 31, 2025
b277578
Fix typo issue
riverlijunjie Oct 31, 2025
79d6a13
add transformations test
chenhu-wang Nov 2, 2025
457c810
Fix 32/1024 out of resource issue on PTL
riverlijunjie Nov 3, 2025
f7d8ead
fix
zaixing-wang Nov 3, 2025
e15bed9
Remove WA for CVS-175938
peterchen-intel Nov 3, 2025
2a10a4f
Merge branch 'master' into river/qwen3_moe_fused_compressed
peterchen-intel Nov 3, 2025
5d0101f
add supports_immad condition
zaixing-wang Nov 3, 2025
18b8cf3
Merge branch 'master' into river/qwen3_moe_fused_compressed
riverlijunjie Nov 4, 2025
0aa8c53
Update for review comments
riverlijunjie Nov 4, 2025
9465fc7
update ut
zaixing-wang Nov 4, 2025
0ead32d
separate
zaixing-wang Nov 4, 2025
8a0fa98
clean
zaixing-wang Nov 4, 2025
6c32e83
Align moe cpp file path
riverlijunjie Nov 4, 2025
358a015
Update for reviewing comments
riverlijunjie Nov 4, 2025
0eb170e
update code comment
chenhu-wang Nov 4, 2025
6f69933
Merge branch 'master' into river/qwen3_moe_fused_compressed
riverlijunjie Nov 4, 2025
@@ -15,6 +15,7 @@ class MOECompressed : public ov::op::internal::MOE {
OPENVINO_OP("MOECompressed", "gpu_opset", ov::op::internal::MOE);

MOECompressed() = default;
MOECompressed(const OutputVector& args) : MOE(args) {}

struct Config : public MOE::Config {
size_t hidden_size = 0;
@@ -0,0 +1,46 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "intel_gpu/op/moe_compressed.hpp"

namespace ov::intel_gpu::op {

/// \brief MOEFusedCompressed supports compressed and fused MOE for GEMM3_SWIGLU.
class MOEFusedCompressed : public MOECompressed {
public:
OPENVINO_OP("MOEFusedCompressed", "gpu_opset", MOECompressed);

MOEFusedCompressed() = default;

/// \brief Constructs a MOEFusedCompressed operation from its inputs and config
/// \param args The input tensors, in the following order:
/// 0: hidden_states - input tensor with hidden representations
/// 1: routing_weights - [num_seq, num_experts] routing weights for all experts
/// 2: w0_weight - expert weights for first projection,
/// shape [num_experts, inter_size, group_num, group_size]
/// 3: w0_scale - expert scale for first projection for compressed experts,
/// shape [num_experts, inter_size, group_num, 1]
/// 4: w0_zp - expert zp for first projection for compressed experts,
/// shape [num_experts, inter_size, group_num, 1]
/// 5: w1_weight - expert weights for second projection,
/// shape [num_experts, inter_size, group_num, group_size]
/// 6: w1_scale - expert scale for second projection for compressed experts,
/// shape [num_experts, inter_size, group_num, 1]
/// 7: w1_zp - expert zp for second projection for compressed experts,
/// shape [num_experts, inter_size, group_num, 1]
/// 8: w2_weight - expert weights for final projection,
/// shape [num_experts, hidden_size, group_num, group_size]
/// 9: w2_scale - expert scale for final projection for compressed experts,
/// shape [num_experts, hidden_size, group_num, 1]
/// 10: w2_zp - expert zp for final projection for compressed experts,
/// shape [num_experts, hidden_size, group_num, 1]
/// \param config Configuration for the MOE operation
Review comment (Contributor):
This description is for 3gemm_Swiglu_type only. Please mention that

MOEFusedCompressed(const OutputVector& args, const MOECompressed::Config config);

std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
};

} // namespace ov::intel_gpu::op
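For illustration only, a minimal sketch of constructing this op once the eleven inputs above have been collected; the helper function, its name, and the hidden_size plumbing are assumptions for this sketch, not code from the PR:

#include <memory>

#include "intel_gpu/op/moe_compressed.hpp"
#include "intel_gpu/op/moe_fused_compressed.hpp"

// Hypothetical helper: builds the fused op from the 11 inputs documented above
// (hidden_states, routing_weights, then w0/w1/w2 weight/scale/zp).
std::shared_ptr<ov::intel_gpu::op::MOEFusedCompressed> make_moe_fused_compressed(
    const ov::OutputVector& inputs,
    size_t hidden_size) {
    ov::intel_gpu::op::MOECompressed::Config config{};
    config.hidden_size = hidden_size;  // field of MOECompressed::Config shown in this PR
    // Remaining MOE::Config fields are left at their defaults for this sketch.
    return std::make_shared<ov::intel_gpu::op::MOEFusedCompressed>(inputs, config);
}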
@@ -312,4 +312,5 @@ REGISTER_FACTORY(internal, PagedAttentionExtension);
REGISTER_FACTORY(internal, LoraSubgraph);
REGISTER_FACTORY(internal, LoraSubgraphFused);
REGISTER_FACTORY(internal, VLSDPA);
REGISTER_FACTORY(internal, MOEFusedCompressed);
REGISTER_FACTORY(internal, MOECompressed);
@@ -0,0 +1,72 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once
#include <vector>

#include "intel_gpu/op/moe_fused_compressed.hpp"
#include "intel_gpu/runtime/engine.hpp"
#include "primitive.hpp"

namespace cldnn {
using MOEFusedCompressed = ov::intel_gpu::op::MOEFusedCompressed;

/// @brief moe fused compressed primitive
/// @details Performs compressed and fused MOE computation
struct moe_fused_compressed : public primitive_base<moe_fused_compressed> {
CLDNN_DECLARE_PRIMITIVE(moe_fused_compressed)

moe_fused_compressed() : primitive_base("", {}) {}
Review comment (Contributor):
Please modify the primitive name too, for the specific target pattern.


// @brief Constructs moe_fused_compressed primitive / layer.
//
// @param id An identifier of new primitive.
// @param inputs A list of Input primitive ids (inputs).
// 0: hidden_states - input tensor with hidden representations
// 1: routing_weights - [num_seq, num_experts] routing weights for all experts
// 2: w0_weight - expert weights for first projection,
// shape [num_experts, inter_size, group_num, group_size]
// 3: w0_scale - expert scale for first projection for compressed experts,
// shape [num_experts, inter_size, group_num, 1]
// 4: w0_zp - expert zp for first projection for compressed experts,
// shape [num_experts, inter_size, group_num, 1]
// 5: w1_weight - expert weights for second projection,
// shape [num_experts, inter_size, group_num, group_size]
// 6: w1_scale - expert scale for second projection for compressed experts,
// shape [num_experts, inter_size, group_num, 1]
// 7: w1_zp - expert zp for second projection for compressed experts,
// shape [num_experts, inter_size, group_num, 1]
// 8: w2_weight - expert weights for final projection,
// shape [num_experts, hidden_size, group_num, group_size]
// 9: w2_scale - expert scale for final projection for compressed experts,
// shape [num_experts, hidden_size, group_num, 1]
// 10: w2_zp - expert zp for final projection for compressed experts,
// shape [num_experts, hidden_size, group_num, 1]
//
moe_fused_compressed(const primitive_id& id, const std::vector<input_info>& inputs, const MOEFusedCompressed::Config& config)
: primitive_base(id, inputs, 1, {optional_data_type()}),
_config(config) {}

MOEFusedCompressed::Config _config;
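// Note: _config must stay trivially copyable, since operator== below compares it byte-wise
// and save/load serialize it as raw bytes via make_data.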

bool operator==(const primitive& rhs) const override {
if (!compare_common_params(rhs))
return false;

auto rhs_casted = downcast<const moe_fused_compressed>(rhs);

return std::memcmp(&_config, &rhs_casted._config, sizeof(_config)) == 0;
}

void save(BinaryOutputBuffer& ob) const override {
primitive_base<moe_fused_compressed>::save(ob);
ob << make_data(&_config, sizeof(_config));
}

void load(BinaryInputBuffer& ib) override {
primitive_base<moe_fused_compressed>::load(ib);
ib >> make_data(&_config, sizeof(_config));
}
};

} // namespace cldnn
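As a rough usage sketch (not part of this PR), the primitive would be added to a cldnn topology with the eleven inputs in the documented order; the primitive id, the input ids, and the include paths below are assumptions:

#include <vector>

#include "intel_gpu/graph/topology.hpp"                    // assumed include path
#include "intel_gpu/primitives/moe_fused_compressed.hpp"   // assumed include path for this new header

void add_moe_fused(cldnn::topology& topology, const cldnn::MOEFusedCompressed::Config& config) {
    // Inputs follow the order documented in the constructor comment above.
    std::vector<cldnn::input_info> inputs = {
        cldnn::input_info("hidden_states"), cldnn::input_info("routing_weights"),
        cldnn::input_info("w0_weight"), cldnn::input_info("w0_scale"), cldnn::input_info("w0_zp"),
        cldnn::input_info("w1_weight"), cldnn::input_info("w1_scale"), cldnn::input_info("w1_zp"),
        cldnn::input_info("w2_weight"), cldnn::input_info("w2_scale"), cldnn::input_info("w2_zp")};
    topology.add(cldnn::moe_fused_compressed("moe_fused", inputs, config));
}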