Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,22 @@
#include "openvino/core/node.hpp"
#include "openvino/core/node_vector.hpp"
#include "openvino/op/op.hpp"
#include "transformations_visibility.hpp"

namespace ov::intel_cpu {
namespace ov::op::internal {

class BatchGatherMatmul : public ov::op::Op {
class TRANSFORMATIONS_API GatherMatmul : public ov::op::Op {
public:
OPENVINO_OP("BatchGatherMatmul", "cpu_plugin_opset");
OPENVINO_OP("GatherMatmul");

BatchGatherMatmul() = default;
GatherMatmul() = default;

BatchGatherMatmul(const ov::Output<Node>& A,
const ov::Output<Node>& B,
const ov::Output<Node>& indices,
const ov::Output<Node>& bias);
GatherMatmul(const ov::Output<Node>& A,
const ov::Output<Node>& B,
const ov::Output<Node>& indices,
const ov::Output<Node>& bias);

BatchGatherMatmul(const ov::Output<Node>& A, const ov::Output<Node>& B, const ov::Output<Node>& indices);
GatherMatmul(const ov::Output<Node>& A, const ov::Output<Node>& B, const ov::Output<Node>& indices);

std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

Expand All @@ -36,4 +37,4 @@ class BatchGatherMatmul : public ov::op::Op {
static constexpr bool transp_b = true;
};

} // namespace ov::intel_cpu
} // namespace ov::op::internal
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (C) 2018-2026 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <memory>

#include "openvino/core/attribute_visitor.hpp"
#include "openvino/core/node.hpp"
#include "openvino/core/node_vector.hpp"
#include "openvino/op/op.hpp"
#include "ov_ops/gather_matmul.hpp"

namespace ov::op::internal {

/// \brief GatherMatmul variant whose weight input B is stored in a compressed
///        (quantized) form, accompanied by decompression scales and optional
///        zero points (names suggest per-weight-group parameters — confirm the
///        exact grouping in validate_and_infer_types()).
class TRANSFORMATIONS_API GatherMatmulCompressed : public GatherMatmul {
public:
    OPENVINO_OP("GatherMatmulCompressed", "", GatherMatmul);

    GatherMatmulCompressed() = default;

    /// \brief Constructs a compressed GatherMatmul.
    /// \param A                  First matmul operand (same role as in GatherMatmul).
    /// \param B                  Compressed weight tensor.
    /// \param indices            Gather indices selecting slices of B.
    /// \param bias               Bias input (same role as in GatherMatmul).
    /// \param weight_scales      Decompression scales applied to B.
    /// \param weight_zero_points Decompression zero points applied to B.
    GatherMatmulCompressed(const ov::Output<Node>& A,
                           const ov::Output<Node>& B,
                           const ov::Output<Node>& indices,
                           const ov::Output<Node>& bias,
                           const ov::Output<Node>& weight_scales,
                           const ov::Output<Node>& weight_zero_points);

    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

    void validate_and_infer_types() override;
};

} // namespace ov::op::internal
93 changes: 93 additions & 0 deletions src/common/transformations/include/ov_ops/moe_compressed.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright (C) 2018-2026 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/moe.hpp"
#include "openvino/op/op.hpp"
#include "transformations_visibility.hpp"

namespace ov::op::internal {

/// \brief MOECompressed experts that support compressed weights for GEMM3_SWIGLU MOE.
class TRANSFORMATIONS_API MOECompressed : public ov::op::internal::MOE {
public:
    OPENVINO_OP("MOECompressed", "", ov::op::internal::MOE);

    MOECompressed() = default;

    /// \brief Constructs a MOECompressed operation with a default config.
    /// \param args Input tensors; see the config-taking constructor below for the expected order.
    // explicit: a single-argument constructor would otherwise allow an accidental
    // implicit conversion from an OutputVector to a MOECompressed node.
    explicit MOECompressed(const OutputVector& args) : MOE(args) {}

    /// \brief Router score normalization scheme used for expert selection.
    enum class RoutingType { SOFTMAX, SIGMOID_BIAS };

    /// \brief Static attributes describing the compressed-MOE layout.
    struct Config : public MOE::Config {
        size_t hidden_size = 0;  ///< Model hidden dimension.
        size_t inter_size = 0;   ///< Expert intermediate (FFN) dimension.
        size_t num_expert = 0;   ///< Total number of experts.
        size_t top_k = 0;        ///< Number of experts selected per token.
        // numeric_limits<size_t>::max() means per_channel compression (single group).
        // other non-zero value means group compression with this given group_size.
        // NOTE(review): the meaning of the default value 0 is not documented here —
        // confirm whether 0 is "unset" or a valid setting before relying on it.
        size_t group_size = 0;
        // In CB, intermediate shapes are expanded to {SeqLen, 1, HiddenSize}
        // In Non-CB, intermediate shapes are expanded to {Batch, SeqLen, HiddenSize}
        // NOTE(review): used as a flag but declared size_t — presumably only 0/1 are valid.
        size_t has_batch_dim = 0;
        bool has_zp = false;  ///< True when zero-point inputs (w*_zp) are present.
        ov::element::Type out_type = ov::element::dynamic;  ///< Requested output element type.
        RoutingType routing_type = RoutingType::SOFTMAX;    ///< Router normalization scheme.
    };

    /// \brief Constructs a MOECompressed operation with config only
    /// \param args The input tensors, in the following order:
    ///        0: hidden_states - input tensor with hidden representations
    ///        1: routing_weights - [num_experts, ...] normalized weights for selected experts
    ///           (input to final multiplication)
    ///        2: router_topk_output_indices - [..., topk] indices of selected top-k experts
    ///        3: w0_weight - expert weights for first projection,
    ///           shape [num_experts, inter_size, group_num, group_size]
    ///        4: w0_scale - expert scale for first projection for compressed experts,
    ///           shape [num_experts, inter_size, group_num, 1]
    ///        5: w0_zp - expert zp for first projection for compressed experts,
    ///           shape [num_experts, inter_size, group_num, 1]
    ///        6: w1_weight - expert weights for second projection,
    ///           shape [num_experts, inter_size, group_num, group_size]
    ///        7: w1_scale - expert scale for second projection for compressed experts,
    ///           shape [num_experts, inter_size, group_num, 1]
    ///        8: w1_zp - expert zp for second projection for compressed experts,
    ///           shape [num_experts, inter_size, group_num, 1]
    ///        9: w2_weight - expert weights for final projection,
    ///           shape [num_experts, hidden_size, group_num, group_size]
    ///        10: w2_scale - expert scale for final projection for compressed experts,
    ///            shape [num_experts, hidden_size, group_num, 1]
    ///        11: w2_zp - expert zp for final projection for compressed experts,
    ///            shape [num_experts, hidden_size, group_num, 1]
    /// \param config Configuration for the MOE operation
    MOECompressed(const OutputVector& args, const Config& config);

    /// \brief Returns the operation's configuration.
    const Config& get_config() const;
    /// \brief Replaces the operation's configuration.
    void set_config(const Config& config);

    bool visit_attributes(AttributeVisitor& visitor) override;
    void validate_and_infer_types() override;
    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

protected:
    Config m_config;  ///< See get_config()/set_config().
};

TRANSFORMATIONS_API std::ostream& operator<<(std::ostream& s, const MOECompressed::RoutingType& type);

} // namespace ov::op::internal

namespace ov {
// Enables (de)serialization of MOECompressed::RoutingType through OpenVINO's
// enum attribute adapter machinery (used by visit_attributes()).
template <>
class AttributeAdapter<ov::op::internal::MOECompressed::RoutingType>
    : public EnumAttributeAdapterBase<ov::op::internal::MOECompressed::RoutingType> {
public:
    // Binds the adapter to the enum value it reads/writes.
    AttributeAdapter(ov::op::internal::MOECompressed::RoutingType& value)
        : EnumAttributeAdapterBase<ov::op::internal::MOECompressed::RoutingType>(value) {}

    OPENVINO_RTTI("AttributeAdapter<ov::op::internal::MOECompressed::RoutingType>");
    ~AttributeAdapter() override = default;
};

} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright (C) 2018-2026 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/matcher_pass.hpp"
#include "transformations/common_optimizations/moe_op_fusion.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

// BGM-producing passes (IR → GatherMatmul)
class TRANSFORMATIONS_API ConvertTiledMoeBlockTo2GatherMatmuls;
class TRANSFORMATIONS_API ConvertTiledMoeBlockTo3GatherMatmuls;

class TRANSFORMATIONS_API ConvertTiledMoeBlockToGatherMatmuls;

} // namespace pass
} // namespace ov

// BGM-producing passes — create GatherMatmul ops + routing reconstruction
// Matcher pass rewriting a tiled MOE block into GatherMatmul nodes.
// NOTE(review): "2" presumably refers to the number of GatherMatmuls produced
// per expert block — confirm against the pass implementation.
class ov::pass::ConvertTiledMoeBlockTo2GatherMatmuls : public ov::pass::MatcherPass {
public:
    OPENVINO_MATCHER_PASS_RTTI("ConvertTiledMoeBlockTo2GatherMatmuls");
    ConvertTiledMoeBlockTo2GatherMatmuls();
};

// Matcher pass rewriting a tiled MOE block into GatherMatmul nodes.
// NOTE(review): "3" presumably refers to the number of GatherMatmuls produced
// per expert block — confirm against the pass implementation.
class ov::pass::ConvertTiledMoeBlockTo3GatherMatmuls : public ov::pass::MatcherPass {
public:
    OPENVINO_MATCHER_PASS_RTTI("ConvertTiledMoeBlockTo3GatherMatmuls");
    ConvertTiledMoeBlockTo3GatherMatmuls();
};

// CPU uses BGM-producing passes only (stops at BGMs)
// GraphRewrite bundling both tiled-MOE-to-GatherMatmul conversions so callers
// can register them as a single pass.
class ov::pass::ConvertTiledMoeBlockToGatherMatmuls : public ov::pass::GraphRewrite {
public:
    OPENVINO_GRAPH_REWRITE_RTTI("ConvertTiledMoeBlockToGatherMatmuls");
    ConvertTiledMoeBlockToGatherMatmuls() {
        // Each matcher is independent; registration order carries no dependency here.
        add_matcher<ov::pass::ConvertTiledMoeBlockTo2GatherMatmuls>();
        add_matcher<ov::pass::ConvertTiledMoeBlockTo3GatherMatmuls>();
    }
};

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (C) 2018-2026 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/core/type/element_type.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

// BGM→MOE passes (GatherMatmul graph → MOE op, used by GPU)
class TRANSFORMATIONS_API Convert2GatherMatmulMoeBlockToMoeOp;
class TRANSFORMATIONS_API Convert3GatherMatmulMoeBlockToMoeOp;
class TRANSFORMATIONS_API MoeOpFusion;

} // namespace pass
} // namespace ov

// BGM→MOE passes — convert post-BGM graph (GatherMatmul + compact routing) to MOE op.
// When BGMCompressed nodes are present, produces MOECompressed instead of MOE.
class ov::pass::Convert2GatherMatmulMoeBlockToMoeOp : public ov::pass::MatcherPass {
public:
    OPENVINO_MATCHER_PASS_RTTI("Convert2GatherMatmulMoeBlockToMoeOp");
    /// \brief Matches a 2-GatherMatmul MOE block and replaces it with a single MOE op.
    /// \param has_batch_dim 1 when intermediate shapes carry a batch dimension (see MOECompressed::Config).
    /// \param out_type Element type of the produced output; dynamic by default.
    // explicit: with both parameters defaulted this constructor is callable with a
    // single size_t, which would otherwise enable an unintended implicit conversion.
    explicit Convert2GatherMatmulMoeBlockToMoeOp(size_t has_batch_dim = 1,
                                                 ov::element::Type out_type = ov::element::dynamic);
};

// Converts a 3-GatherMatmul MOE block (post-BGM graph) to a single MOE op.
class ov::pass::Convert3GatherMatmulMoeBlockToMoeOp : public ov::pass::MatcherPass {
public:
    OPENVINO_MATCHER_PASS_RTTI("Convert3GatherMatmulMoeBlockToMoeOp");
    /// \brief Matches a 3-GatherMatmul MOE block and replaces it with a single MOE op.
    /// \param has_batch_dim 1 when intermediate shapes carry a batch dimension.
    /// \param out_type Element type of the produced output.
    // NOTE(review): default out_type differs from Convert2GatherMatmulMoeBlockToMoeOp
    // (f16 here vs dynamic there) — confirm the asymmetry is intentional.
    // explicit: callable with a single size_t via the defaults, so guard against
    // unintended implicit conversion.
    explicit Convert3GatherMatmulMoeBlockToMoeOp(size_t has_batch_dim = 1,
                                                 ov::element::Type out_type = ov::element::f16);
};

// GraphRewrite bundling both GatherMatmul-block → MOE conversions.
class ov::pass::MoeOpFusion : public ov::pass::GraphRewrite {
public:
    OPENVINO_GRAPH_REWRITE_RTTI("MoeOpFusion");
    /// \param has_batch_dim Forwarded to both sub-passes; their out_type defaults apply
    ///        (this wrapper provides no way to override out_type).
    // explicit: single-argument constructor; prevents implicit size_t → pass conversion.
    explicit MoeOpFusion(size_t has_batch_dim = 1) {
        add_matcher<ov::pass::Convert2GatherMatmulMoeBlockToMoeOp>(has_batch_dim);
        add_matcher<ov::pass::Convert3GatherMatmulMoeBlockToMoeOp>(has_batch_dim);
    }
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (C) 2018-2026 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <functional>
#include <memory>
#include <vector>

#include "openvino/pass/matcher_pass.hpp"
#include "ov_ops/gather_matmul_compressed.hpp"
#include "transformations_visibility.hpp"

namespace ov::pass {

// Matcher pass replacing GatherMatmul nodes fed by decompression subgraphs with
// GatherMatmulCompressed nodes that consume the quantized weights directly.
class TRANSFORMATIONS_API ConvertGatherMatmulToGatherMatmulCompressed : public ov::pass::MatcherPass {
public:
    OPENVINO_MATCHER_PASS_RTTI("ConvertGatherMatmulToGatherMatmulCompressed");

    // Callback letting the caller veto a candidate conversion; receives the
    // would-be GatherMatmulCompressed node plus three size_t arguments
    // (their semantics are defined at the call site in the .cpp — confirm there).
    using SupportsPredicate =
        std::function<bool(const std::shared_ptr<ov::op::internal::GatherMatmulCompressed>&, size_t, size_t, size_t)>;

    /// \param supported_activation_types Activation element types eligible for conversion.
    /// \param supported_weights_types    Compressed-weight element types eligible for conversion.
    /// \param supports_config Optional predicate rejecting specific configurations
    ///        (nullptr presumably means "accept all" — confirm in the implementation).
    /// \param convert_u4zp_to_u8 Name suggests u4 zero points are widened to u8 when true.
    ConvertGatherMatmulToGatherMatmulCompressed(const std::vector<ov::element::Type>& supported_activation_types,
                                                const std::vector<ov::element::Type>& supported_weights_types,
                                                const SupportsPredicate& supports_config = nullptr,
                                                bool convert_u4zp_to_u8 = false);
};

} // namespace ov::pass
Loading
Loading