Skip to content

Commit cbaffb6

Browse files
authored
[GPU] Gated mlp (#34427)
Implement gated_mlp primitive, op for onednn and FuseGatedMLP. Default fusion is disable. (disable_gated_mlp_fusion: TRUE) Needs onednn update. Depends on #34446 ### Tickets: - [CVS-181656](https://jira.devtools.intel.com/browse/CVS-181656) --------- Signed-off-by: hyunback <hyunback.kim@intel.com>
1 parent c7c101f commit cbaffb6

File tree

21 files changed

+1880
-5
lines changed

21 files changed

+1880
-5
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// Copyright (C) 2018-2026 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/core/node.hpp"
#include "openvino/op/op.hpp"
#include "ov_ops/glu.hpp"

namespace ov::intel_gpu::op {

/// \brief Internal GPU-opset op representing a fused gated-MLP block
///        (the result of fusing the FC + GLU/SwiGLU pattern, see the
///        FuseGatedMLP pass). Three weight sets are consumed: gate, up and
///        down projections; the GLU activation is applied on the gate branch.
///        Optional per-projection decompression scales and zero points
///        support weight-compressed models.
class GatedMLP : public ov::op::Op {
public:
    OPENVINO_OP("GatedMLP", "gpu_opset");

    GatedMLP() = default;

    /// \brief Plain (non-compressed) weights variant.
    /// \param src         Input activations.
    /// \param w_gate      Gate-projection weights.
    /// \param w_up        Up-projection weights.
    /// \param w_down      Down-projection weights.
    /// \param activation  GLU activation type applied to the gate branch.
    /// \param output_type Output element type; dynamic means "same as input".
    GatedMLP(const ov::Output<Node>& src,
             const ov::Output<Node>& w_gate,
             const ov::Output<Node>& w_up,
             const ov::Output<Node>& w_down,
             ov::op::internal::GLU::GluType activation,
             const ov::element::Type output_type = ov::element::dynamic);

    /// \brief Compressed-weights variant: adds per-projection decompression
    ///        scales (sets m_compressed_weights in the implementation —
    ///        defined elsewhere).
    GatedMLP(const ov::Output<Node>& src,
             const ov::Output<Node>& w_gate,
             const ov::Output<Node>& w_up,
             const ov::Output<Node>& w_down,
             const ov::Output<Node>& scale_gate,
             const ov::Output<Node>& scale_up,
             const ov::Output<Node>& scale_down,
             ov::op::internal::GLU::GluType activation,
             const ov::element::Type output_type = ov::element::dynamic);

    /// \brief Compressed-weights variant with asymmetric quantization: adds
    ///        per-projection decompression zero points on top of the scales.
    GatedMLP(const ov::Output<Node>& src,
             const ov::Output<Node>& w_gate,
             const ov::Output<Node>& w_up,
             const ov::Output<Node>& w_down,
             const ov::Output<Node>& scale_gate,
             const ov::Output<Node>& scale_up,
             const ov::Output<Node>& scale_down,
             const ov::Output<Node>& zp_gate,
             const ov::Output<Node>& zp_up,
             const ov::Output<Node>& zp_down,
             ov::op::internal::GLU::GluType activation,
             const ov::element::Type output_type = ov::element::dynamic);

    bool visit_attributes(ov::AttributeVisitor& visitor) override;
    void validate_and_infer_types() override;
    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

    // Attribute accessors used by the plugin when lowering to the
    // cldnn::gated_mlp primitive.
    ov::op::internal::GLU::GluType get_activation() const { return m_activation; }
    ov::element::Type get_output_type() const { return m_output_type; }
    bool is_compressed_weights() const { return m_compressed_weights; }
    bool has_decompression_zero_points() const { return m_has_decompression_zero_points; }

private:
    ov::op::internal::GLU::GluType m_activation = ov::op::internal::GLU::GluType::Swish;
    ov::element::Type m_output_type = ov::element::dynamic;
    bool m_compressed_weights = false;           // true when scale inputs are present
    bool m_has_decompression_zero_points = false; // true when zero-point inputs are present
};

}  // namespace ov::intel_gpu::op

src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ REGISTER_FACTORY(internal, ReadValue);
302302
REGISTER_FACTORY(internal, ReadValues);
303303
REGISTER_FACTORY(internal, Gemm);
304304
REGISTER_FACTORY(internal, GLU);
305+
REGISTER_FACTORY(internal, GatedMLP);
305306
REGISTER_FACTORY(internal, IndirectGemm);
306307
REGISTER_FACTORY(internal, Convolution);
307308
REGISTER_FACTORY(internal, Placeholder);
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
// Copyright (C) 2018-2026 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "ov_ops/glu.hpp"
8+
#include "primitive.hpp"
9+
10+
namespace cldnn {

/// @brief Fused gated-MLP primitive: gate/up/down projections plus a GLU
///        activation in a single primitive. Mirrors the
///        ov::intel_gpu::op::GatedMLP internal op. Supports optional weight
///        decompression (scales, and optionally zero points) per projection.
struct gated_mlp : public primitive_base<gated_mlp> {
    CLDNN_DECLARE_PRIMITIVE(gated_mlp)

    // Default constructor required for deserialization.
    gated_mlp() : primitive_base("", {}) {}

    /// @brief Plain (non-compressed) weights.
    /// @param id          Primitive id.
    /// @param src         Input activations.
    /// @param w_gate      Gate-projection weights.
    /// @param w_up        Up-projection weights.
    /// @param w_down      Down-projection weights.
    /// @param activation  GLU activation applied to the gate branch.
    /// @param output_size Static output tensor size (used by calc_output_layout).
    /// @param output_dt   Output data type.
    gated_mlp(const primitive_id& id,
              const input_info& src,
              const input_info& w_gate,
              const input_info& w_up,
              const input_info& w_down,
              ov::op::internal::GLU::GluType activation,
              const tensor& output_size,
              const data_types output_dt)
        : primitive_base(id, {src}, 1, {optional_data_type{output_dt}}),
          weights_gate(w_gate),
          weights_up(w_up),
          weights_down(w_down),
          activation(activation),
          output_size(output_size) {}

    /// @brief Compressed weights with per-projection decompression scales.
    ///        All three scales must be valid (asserted below).
    gated_mlp(const primitive_id& id,
              const input_info& src,
              const input_info& w_gate,
              const input_info& w_up,
              const input_info& w_down,
              const input_info& scale_gate,
              const input_info& scale_up,
              const input_info& scale_down,
              ov::op::internal::GLU::GluType activation,
              const tensor& output_size,
              const data_types output_dt)
        : primitive_base(id, {src}, 1, {optional_data_type{output_dt}}),
          weights_gate(w_gate),
          weights_up(w_up),
          weights_down(w_down),
          decompression_scale_gate(scale_gate),
          decompression_scale_up(scale_up),
          decompression_scale_down(scale_down),
          compressed_weights(true),
          activation(activation),
          output_size(output_size) {
        OPENVINO_ASSERT(decompression_scale_gate.is_valid() && decompression_scale_up.is_valid() && decompression_scale_down.is_valid(),
                        "GatedMLP compressed mode requires decompression scales.");
    }

    /// @brief Compressed weights with decompression scales and zero points
    ///        (asymmetric quantization). All scales and all zero points must
    ///        be valid (asserted below).
    gated_mlp(const primitive_id& id,
              const input_info& src,
              const input_info& w_gate,
              const input_info& w_up,
              const input_info& w_down,
              const input_info& scale_gate,
              const input_info& scale_up,
              const input_info& scale_down,
              const input_info& zp_gate,
              const input_info& zp_up,
              const input_info& zp_down,
              ov::op::internal::GLU::GluType activation,
              const tensor& output_size,
              const data_types output_dt)
        : primitive_base(id, {src}, 1, {optional_data_type{output_dt}}),
          weights_gate(w_gate),
          weights_up(w_up),
          weights_down(w_down),
          decompression_scale_gate(scale_gate),
          decompression_scale_up(scale_up),
          decompression_scale_down(scale_down),
          decompression_zero_point_gate(zp_gate),
          decompression_zero_point_up(zp_up),
          decompression_zero_point_down(zp_down),
          compressed_weights(true),
          has_decompression_zero_points(true),
          activation(activation),
          output_size(output_size) {
        OPENVINO_ASSERT(decompression_scale_gate.is_valid() && decompression_scale_up.is_valid() && decompression_scale_down.is_valid(),
                        "GatedMLP compressed mode requires decompression scales.");
        OPENVINO_ASSERT(decompression_zero_point_gate.is_valid() && decompression_zero_point_up.is_valid() && decompression_zero_point_down.is_valid(),
                        "GatedMLP compressed mode with zero points requires decompression zero points.");
    }

    // Extra (non-positional) inputs; exposed via get_dependencies_map().
    input_info weights_gate;
    input_info weights_up;
    input_info weights_down;
    input_info decompression_scale_gate;
    input_info decompression_scale_up;
    input_info decompression_scale_down;
    input_info decompression_zero_point_gate;
    input_info decompression_zero_point_up;
    input_info decompression_zero_point_down;
    bool compressed_weights = false;
    bool has_decompression_zero_points = false;
    ov::op::internal::GLU::GluType activation = ov::op::internal::GLU::GluType::Swish;
    tensor output_size;

    // Hash covers only the attribute flags and activation; input topology is
    // covered by the base-class hash.
    // NOTE(review): output_size is serialized but not part of hash()/
    // operator== — confirm this matches the convention of sibling primitives.
    size_t hash() const override {
        size_t seed = primitive::hash();
        seed = hash_combine(seed, compressed_weights);
        seed = hash_combine(seed, has_decompression_zero_points);
        seed = hash_combine(seed, static_cast<size_t>(activation));
        return seed;
    }

    bool operator==(const primitive& rhs) const override {
        if (!compare_common_params(rhs))
            return false;
        auto rhs_casted = downcast<const gated_mlp>(rhs);
        return activation == rhs_casted.activation &&
               compressed_weights == rhs_casted.compressed_weights &&
               has_decompression_zero_points == rhs_casted.has_decompression_zero_points;
    }

    // Serialization: field order here must stay in lock-step with load().
    void save(BinaryOutputBuffer& ob) const override {
        primitive_base<gated_mlp>::save(ob);
        ob << weights_gate;
        ob << weights_up;
        ob << weights_down;
        ob << decompression_scale_gate;
        ob << decompression_scale_up;
        ob << decompression_scale_down;
        ob << decompression_zero_point_gate;
        ob << decompression_zero_point_up;
        ob << decompression_zero_point_down;
        ob << compressed_weights;
        ob << has_decompression_zero_points;
        ob << make_data(&activation, sizeof(activation));  // raw enum bytes
        ob << output_size;
    }

    void load(BinaryInputBuffer& ib) override {
        primitive_base<gated_mlp>::load(ib);
        ib >> weights_gate;
        ib >> weights_up;
        ib >> weights_down;
        ib >> decompression_scale_gate;
        ib >> decompression_scale_up;
        ib >> decompression_scale_down;
        ib >> decompression_zero_point_gate;
        ib >> decompression_zero_point_up;
        ib >> decompression_zero_point_down;
        ib >> compressed_weights;
        ib >> has_decompression_zero_points;
        ib >> make_data(&activation, sizeof(activation));
        ib >> output_size;
    }

protected:
    // Maps extra inputs to dependency indices following the positional inputs.
    // NOTE(review): this ordering defines the node's dependency layout and is
    // presumably relied on by the impl — keep in sync with the kernel side.
    std::map<size_t, const input_info*> get_dependencies_map() const override {
        auto ret = std::map<size_t, const input_info*>{};
        auto idx = input.size();

        // The three weight tensors are mandatory.
        OPENVINO_ASSERT(weights_gate.is_valid());
        OPENVINO_ASSERT(weights_up.is_valid());
        OPENVINO_ASSERT(weights_down.is_valid());
        ret[idx++] = &weights_gate;
        ret[idx++] = &weights_up;
        ret[idx++] = &weights_down;

        // Optional decompression inputs are appended only when present.
        if (decompression_scale_gate.is_valid())
            ret[idx++] = &decompression_scale_gate;
        if (decompression_scale_up.is_valid())
            ret[idx++] = &decompression_scale_up;
        if (decompression_scale_down.is_valid())
            ret[idx++] = &decompression_scale_down;

        if (decompression_zero_point_gate.is_valid())
            ret[idx++] = &decompression_zero_point_gate;
        if (decompression_zero_point_up.is_valid())
            ret[idx++] = &decompression_zero_point_up;
        if (decompression_zero_point_down.is_valid())
            ret[idx++] = &decompression_zero_point_down;

        return ret;
    }
};

}  // namespace cldnn

src/plugins/intel_gpu/include/intel_gpu/runtime/internal_properties.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ static constexpr Property<bool, ov::PropertyMutability::RW> disable_memory_reuse
164164
static constexpr Property<size_t, ov::PropertyMutability::RW> disable_post_ops_fusions{"GPU_DISABLE_POST_OPS_FUSIONS"};
165165
static constexpr Property<bool, ov::PropertyMutability::RW> disable_horizontal_fc_fusion{"GPU_DISABLE_HORIZONTAL_FC_FUSION"};
166166
static constexpr Property<bool, ov::PropertyMutability::RW> disable_fc_swiglu_fusion{"GPU_DISABLE_FC_SWIGLU_FUSION"};
167+
static constexpr Property<bool, ov::PropertyMutability::RW> disable_gated_mlp_fusion{"GPU_DISABLE_GATED_MLP_FUSION"};
167168
static constexpr Property<bool, ov::PropertyMutability::RW> disable_fake_alignment{"GPU_DISABLE_FAKE_ALIGNMENT"};
168169
static constexpr Property<bool, ov::PropertyMutability::RW> disable_moe_opt{"GPU_DISABLE_MOE_OPT"};
169170
static constexpr Property<bool, ov::PropertyMutability::RW> disable_runtime_skip_reorder{"GPU_DISABLE_RUNTIME_SKIP_REORDER"};

src/plugins/intel_gpu/include/intel_gpu/runtime/options.inl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, disable_runtime_buffer_fusing, false, "Dis
9090
OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, disable_post_ops_fusions, 0, "Disable fusions of operations as post-ops/fused-ops. Detailed debugging is possible by entering specific numbers. 1 specifies to disable all fusions of post-ops. 2-8 specifies to enable only single fusion sub-module from fuse_reorder() to optimize_fused_opt(). 11-13 specifies to enable only single fusion sub-module in fuse_simple_primitives.")
9191
OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, disable_horizontal_fc_fusion, false, "Disable pass which merges QKV projections into single MatMul")
9292
OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, disable_fc_swiglu_fusion, false, "Disable pass which merges FC and SwiGLU ops")
93+
OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, disable_gated_mlp_fusion, true, "Disable pass which fuses FC+SwiGLU to GatedMLP")
9394
OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, disable_fake_alignment, false, "Disable fake alignment feature which tries to keep gpu friendly memory alignment for arbitrary tensor shapes")
9495
OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, disable_moe_opt, false, "Disable mixture of expert optimization")
9596
OV_CONFIG_DEBUG_OPTION(ov::intel_gpu, disable_memory_reuse, false, "Disable memory reuse for activation tensors")
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// Copyright (C) 2018-2026 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "gated_mlp_inst.h"
6+
7+
#include "json_object.h"
8+
#include "matmul_shape_inference.hpp"
9+
#include "primitive_type_base.h"
10+
11+
#include <string>
12+
13+
namespace cldnn {
14+
15+
GPU_DEFINE_PRIMITIVE_TYPE_ID(gated_mlp)
16+
17+
// Static-shape path: output layout comes straight from the primitive's
// pre-computed output_size; dtype is the requested one or the input's.
layout gated_mlp_inst::calc_output_layout(gated_mlp_node const& node, kernel_impl_params const& impl_param) {
    const auto prim = impl_param.typed_desc<gated_mlp>();
    const auto in_layout = impl_param.get_input_layout();
    const auto out_dt = impl_param.desc->output_data_types[0].value_or(in_layout.data_type);
    return layout(out_dt, in_layout.format, prim->output_size);
}
25+
26+
// Dynamic-shape inference: out = MatMul(MatMul(src, w_up), w_down), with the
// gate branch required to produce the same shape as the up branch (they are
// combined element-wise).
template <typename ShapeType>
std::vector<layout> gated_mlp_inst::calc_output_layouts(gated_mlp_node const& node, const kernel_impl_params& impl_param) {
    auto desc = impl_param.typed_desc<gated_mlp>();
    auto input_layout = impl_param.get_input_layout();
    // Honor an explicitly requested output dtype, else inherit the input's.
    auto output_type = impl_param.desc->output_data_types[0].value_or(input_layout.data_type);
    auto output_format = input_layout.format;

    // Inputs: 0 = src, 1 = w_gate, 2 = w_up, 3 = w_down.
    std::vector<ShapeType> input_shapes = {
        impl_param.get_input_layout(0).get<ShapeType>(),
        impl_param.get_input_layout(1).get<ShapeType>(),
        impl_param.get_input_layout(2).get<ShapeType>(),
        impl_param.get_input_layout(3).get<ShapeType>()
    };

    // NOTE(review): weights are fed untransposed (transpose_b = false), which
    // assumes a [K, N] weight layout — confirm against the FuseGatedMLP pass.
    ov::op::v0::MatMul matmul;
    matmul.set_transpose_a(false);
    matmul.set_transpose_b(false);

    auto up_shapes = ov::op::v0::shape_infer(&matmul, std::vector<ShapeType>{input_shapes[0], input_shapes[2]});
    auto gate_shapes = ov::op::v0::shape_infer(&matmul, std::vector<ShapeType>{input_shapes[0], input_shapes[1]});

    OPENVINO_ASSERT(up_shapes[0].compatible(gate_shapes[0]),
                    "GatedMLP requires gate/up projection output shapes to match.");

    auto out_shapes = ov::op::v0::shape_infer(&matmul, std::vector<ShapeType>{up_shapes[0], input_shapes[3]});

    return {layout(out_shapes[0], output_type, output_format)};
}

// Explicit instantiation for the PartialShape (dynamic) path.
template std::vector<layout> gated_mlp_inst::calc_output_layouts<ov::PartialShape>(gated_mlp_node const& node,
                                                                                  const kernel_impl_params& impl_param);
57+
58+
// Produces a JSON-style textual dump of the node for graph debugging.
std::string gated_mlp_inst::to_string(gated_mlp_node const& node) {
    const auto prim = node.get_primitive();
    auto dump = node.desc_to_json();

    json_composite info;
    info.add("input_id", node.input().id());
    info.add("weights_gate_id", node.weights_gate().id());
    info.add("weights_up_id", node.weights_up().id());
    info.add("weights_down_id", node.weights_down().id());
    info.add("compressed_weights", prim->compressed_weights);
    info.add("has_decompression_zero_points", prim->has_decompression_zero_points);

    // Decompression inputs only exist in compressed-weights mode.
    if (prim->compressed_weights) {
        info.add("decompression_scale_gate_id", node.decompression_scale_gate().id());
        info.add("decompression_scale_up_id", node.decompression_scale_up().id());
        info.add("decompression_scale_down_id", node.decompression_scale_down().id());
        if (prim->has_decompression_zero_points) {
            info.add("decompression_zero_point_gate_id", node.decompression_zero_point_gate().id());
            info.add("decompression_zero_point_up_id", node.decompression_zero_point_up().id());
            info.add("decompression_zero_point_down_id", node.decompression_zero_point_down().id());
        }
    }
    info.add("activation", static_cast<int64_t>(prim->activation));

    dump->add("gated_mlp_info", info);

    std::stringstream out;
    dump->dump(out);
    return out.str();
}
86+
87+
// Instance constructor: no extra runtime state beyond the base class.
gated_mlp_inst::typed_primitive_inst(network& network, gated_mlp_node const& node) : parent(network, node) {}
88+
89+
} // namespace cldnn

src/plugins/intel_gpu/src/graph/graph_optimizer/add_required_reorders.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "program_node.h"
88
#include "convert_color_inst.h"
99
#include "fully_connected_inst.h"
10+
#include "gated_mlp_inst.h"
1011
#include "assign_inst.h"
1112
#include "mvn_inst.h"
1213

@@ -276,7 +277,7 @@ void add_required_reorders::run(program& p) {
276277
continue;
277278

278279
bool correct_layout_selected = false;
279-
bool weights_data = (usr->is_type<convolution>() || usr->is_type<deconvolution>() || usr->is_type<fully_connected>());
280+
bool weights_data = (usr->is_type<convolution>() || usr->is_type<deconvolution>() || usr->is_type<fully_connected>() || usr->is_type<gated_mlp>());
280281

281282
layout original_layout = usr->get_output_layout();
282283

@@ -339,6 +340,6 @@ void add_required_reorders::run(program& p) {
339340
OPENVINO_ASSERT(correct_layout_selected,
340341
"[GPU] No layout format available for ", usr->id(), ", impl_type: ", usr->get_preferred_impl_type(),
341342
" (format: ", original_layout.format.to_string(),
342-
", data_type: ", ov::element::Type(original_layout.data_type), ") ");
343+
", data_type: ", ov::element::Type(original_layout.data_type), ") ", original_layout.to_string(), ", ", correct_layout_selected);
343344
}
344345
}

0 commit comments

Comments
 (0)