From c65dd566077336d35162f1a5ac61ccf8cbddc2dd Mon Sep 17 00:00:00 2001 From: "Lv, Tao A" Date: Thu, 5 Mar 2026 01:05:57 -0800 Subject: [PATCH 01/15] examples: graph: update gated-mlp with intermediate data type --- examples/graph/gated_mlp.cpp | 25 +++++++++++++++++------ examples/graph/gated_mlp_wei_combined.cpp | 25 +++++++++++++++++------ 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/examples/graph/gated_mlp.cpp b/examples/graph/gated_mlp.cpp index 81bf0496013..7670df83216 100644 --- a/examples/graph/gated_mlp.cpp +++ b/examples/graph/gated_mlp.cpp @@ -108,44 +108,56 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt, // Incremental IDs used to create logical tensors and operations. size_t id = 0; + // Intermediate data type + const logical_tensor::data_type dt_inter = logical_tensor::data_type::f32; + // fc_gate auto src = logical_tensor(id++, dt, src_sz, layout_type::strided); auto wei0 = logical_tensor(id++, dt, wei0_sz, layout_type::strided); - auto out0 = logical_tensor(id++, dt, hd_sz, layout_type::strided); + auto out0 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided); auto fc_gate = op(id++, op::kind::MatMul, "fc_gate"); fc_gate.add_inputs({src, wei0}); fc_gate.add_outputs({out0}); // fc_up auto wei1 = logical_tensor(id++, dt, wei0_sz, layout_type::strided); - auto out1 = logical_tensor(id++, dt, hd_sz, layout_type::strided); + auto out1 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided); auto fc_up = op(id++, op::kind::MatMul, "fc_up"); fc_up.add_inputs({src, wei1}); fc_up.add_outputs({out1}); // activation swish: sigmoid - auto out2 = logical_tensor(id++, dt, hd_sz, layout_type::strided); + auto out2 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided); auto swi_sig = op(id++, op::kind::Sigmoid, "swish/sigmoid"); swi_sig.add_inputs({out0}); swi_sig.add_outputs({out2}); // activation swish: multiply - auto out3 = logical_tensor(id++, dt, hd_sz, layout_type::strided); + auto out3 = 
logical_tensor(id++, dt_inter, hd_sz, layout_type::strided); auto swi_mul = op(id++, op::kind::Multiply, "swish/multiply"); swi_mul.add_inputs({out0, out2}); swi_mul.add_outputs({out3}); // multiplication - auto out4 = logical_tensor(id++, dt, hd_sz, layout_type::strided); + auto out4 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided); auto mul = op(id++, op::kind::Multiply, "mul"); mul.add_inputs({out3, out1}); mul.add_outputs({out4}); + // downconversion when needed + auto out4_dt = out4; + auto typecast = op(id++, op::kind::TypeCast, "typecast"); + if (dt != dt_inter) { + out4_dt = logical_tensor(id++, dt, hd_sz, layout_type::strided); + typecast.add_inputs({out4}); + typecast.add_outputs({out4_dt}); + } + // fc_down auto wei2 = logical_tensor(id++, dt, wei2_sz, layout_type::strided); auto dst = logical_tensor(id++, dt, out_sz, layout_type::strided); auto fc_down = op(id++, op::kind::MatMul, "fc_down"); - fc_down.add_inputs({out4, wei2}); + fc_down.add_inputs({out4_dt, wei2}); fc_down.add_outputs({dst}); // Construct a gated mlp graph with engine kind and operations. @@ -155,6 +167,7 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt, mlp.add_op(swi_sig); mlp.add_op(swi_mul); mlp.add_op(mul); + if (dt != dt_inter) mlp.add_op(typecast); mlp.add_op(fc_down); mlp.finalize(); diff --git a/examples/graph/gated_mlp_wei_combined.cpp b/examples/graph/gated_mlp_wei_combined.cpp index f8f3ecb4f46..488eec89384 100644 --- a/examples/graph/gated_mlp_wei_combined.cpp +++ b/examples/graph/gated_mlp_wei_combined.cpp @@ -125,6 +125,9 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt, // Incremental IDs used to create logical tensors and operations. 
size_t id = 0; + // Intermediate data type + const logical_tensor::data_type dt_inter = logical_tensor::data_type::f32; + // This logical tensor is not part of the graph but is used to generate the // big chunk of device memory which should be already there in real user // application or framework. @@ -134,41 +137,50 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt, // fc_gate: wei0 is non-contiguous now. auto src = logical_tensor(id++, dt, src_sz, layout_type::strided); auto wei0 = logical_tensor(id++, dt, wei0_sz, combined_wei0_st); - auto out0 = logical_tensor(id++, dt, hd_sz, layout_type::strided); + auto out0 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided); auto fc_gate = op(id++, op::kind::MatMul, "fc_gate"); fc_gate.add_inputs({src, wei0}); fc_gate.add_outputs({out0}); // fc_up: wei1 is non-contiguous now. auto wei1 = logical_tensor(id++, dt, wei0_sz, combined_wei0_st); - auto out1 = logical_tensor(id++, dt, hd_sz, layout_type::strided); + auto out1 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided); auto fc_up = op(id++, op::kind::MatMul, "fc_up"); fc_up.add_inputs({src, wei1}); fc_up.add_outputs({out1}); // activation swish: sigmoid - auto out2 = logical_tensor(id++, dt, hd_sz, layout_type::strided); + auto out2 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided); auto swi_sig = op(id++, op::kind::Sigmoid, "swish/sigmoid"); swi_sig.add_inputs({out0}); swi_sig.add_outputs({out2}); // activation swish: multiply - auto out3 = logical_tensor(id++, dt, hd_sz, layout_type::strided); + auto out3 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided); auto swi_mul = op(id++, op::kind::Multiply, "swish/multiply"); swi_mul.add_inputs({out0, out2}); swi_mul.add_outputs({out3}); // multiplication - auto out4 = logical_tensor(id++, dt, hd_sz, layout_type::strided); + auto out4 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided); auto mul = op(id++, op::kind::Multiply, "mul"); 
mul.add_inputs({out3, out1}); mul.add_outputs({out4}); + // downconversion when needed + auto out4_dt = out4; + auto typecast = op(id++, op::kind::TypeCast, "typecast"); + if (dt != dt_inter) { + out4_dt = logical_tensor(id++, dt, hd_sz, layout_type::strided); + typecast.add_inputs({out4}); + typecast.add_outputs({out4_dt}); + } + // fc_down auto wei2 = logical_tensor(id++, dt, wei2_sz, layout_type::strided); auto dst = logical_tensor(id++, dt, out_sz, layout_type::strided); auto fc_down = op(id++, op::kind::MatMul, "fc_down"); - fc_down.add_inputs({out4, wei2}); + fc_down.add_inputs({out4_dt, wei2}); fc_down.add_outputs({dst}); // Construct a gated mlp graph with engine kind and operations. @@ -178,6 +190,7 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt, mlp.add_op(swi_sig); mlp.add_op(swi_mul); mlp.add_op(mul); + if (dt != dt_inter) mlp.add_op(typecast); mlp.add_op(fc_down); mlp.finalize(); From 28b08112929e2d909c672130501e7c689e6b5a0d Mon Sep 17 00:00:00 2001 From: "Lv, Tao A" Date: Thu, 5 Mar 2026 01:18:10 -0800 Subject: [PATCH 02/15] graph: backend: dnnl: patterns: support typecast in gated mlp --- src/graph/backend/dnnl/patterns/mlp.cpp | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/graph/backend/dnnl/patterns/mlp.cpp b/src/graph/backend/dnnl/patterns/mlp.cpp index 4f1bf9aa7da..08c78746f32 100644 --- a/src/graph/backend/dnnl/patterns/mlp.cpp +++ b/src/graph/backend/dnnl/patterns/mlp.cpp @@ -72,9 +72,18 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, gated_mlp) auto bin = pgraph->append_alternation( get_binary_ops(), edges); + // optional typecast + auto tc = std::make_shared(); + pm::pb_op_t *ptypecast + = tc->append_op(graph::op_kind::TypeCast); + tc->create_input_port(0, ptypecast, 0); + tc->create_output_port(0, ptypecast, 0); + auto pre_tc + = pgraph->append_optional(tc, {in_edge(0, bin, 0)}); + // fc_down pgraph->append_op(graph::op_kind::MatMul, - in_edges_t {in_edge(0, bin, 
0)}); + in_edges_t {in_edge(0, pre_tc, 0)}); }) .set_attr("FCreateKernel", []() -> kernel_ptr { return std::make_shared(); @@ -107,9 +116,18 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, gated_mlp_v1) auto bin = pgraph->append_alternation( get_binary_ops(), edges); + // optional typecast + auto tc = std::make_shared(); + pm::pb_op_t *ptypecast + = tc->append_op(graph::op_kind::TypeCast); + tc->create_input_port(0, ptypecast, 0); + tc->create_output_port(0, ptypecast, 0); + auto pre_tc + = pgraph->append_optional(tc, {in_edge(0, bin, 0)}); + // fc_down pgraph->append_op(graph::op_kind::MatMul, - in_edges_t {in_edge(0, bin, 0)}); + in_edges_t {in_edge(0, pre_tc, 0)}); }) .set_attr("FCreateKernel", []() -> kernel_ptr { return std::make_shared(); From c8ddcc8a4b9c463db8b27e5b52d8ec8fb1b1105a Mon Sep 17 00:00:00 2001 From: "Lv, Tao A" Date: Thu, 5 Mar 2026 07:03:40 -0800 Subject: [PATCH 03/15] benchdnn: inputs: graph: add cases for f16/bf16 gated mlp --- .../graph/complex_fusion/harness_mlp_ci | 3 +- .../complex_fusion/mlp/gated-mlp-f16-f32.json | 403 ++++++++++++++++++ 2 files changed, 405 insertions(+), 1 deletion(-) create mode 100644 tests/benchdnn/inputs/graph/complex_fusion/mlp/gated-mlp-f16-f32.json diff --git a/tests/benchdnn/inputs/graph/complex_fusion/harness_mlp_ci b/tests/benchdnn/inputs/graph/complex_fusion/harness_mlp_ci index f0454a5c584..9c92e8308ca 100644 --- a/tests/benchdnn/inputs/graph/complex_fusion/harness_mlp_ci +++ b/tests/benchdnn/inputs/graph/complex_fusion/harness_mlp_ci @@ -1,4 +1,5 @@ ---reset --dt=bf16,f16 --case=complex_fusion/mlp/gated-mlp-f32.json +--reset --case=complex_fusion/mlp/gated-mlp-f16-f32.json +--reset --dt=0:bf16+1:bf16+4:bf16+14:bf16+15:bf16+16:bf16 --case=complex_fusion/mlp/gated-mlp-f16-f32.json # WA1: use smaller problem to pass correctness check for f32 on pvc. # WA2: use subtract binary to avoid precision issue for f32 on xe-lpg. 
diff --git a/tests/benchdnn/inputs/graph/complex_fusion/mlp/gated-mlp-f16-f32.json b/tests/benchdnn/inputs/graph/complex_fusion/mlp/gated-mlp-f16-f32.json new file mode 100644 index 00000000000..05c713c651b --- /dev/null +++ b/tests/benchdnn/inputs/graph/complex_fusion/mlp/gated-mlp-f16-f32.json @@ -0,0 +1,403 @@ +{ + "version": "3.12.0", + "engine_kind": "cpu", + "fpmath_mode": "strict", + "fpmath_mode_apply_to_int": "false", + "input_ports": [ + 0, + 1, + 0, + 4, + 15 + ], + "output_ports": [ + 16 + ], + "graph": [ + { + "id": 3, + "name": "fc_gate", + "kind": "MatMul", + "attrs": { + "transpose_a": { + "type": "bool", + "value": 0 + }, + "transpose_b": { + "type": "bool", + "value": 0 + }, + "accumulation_mode": { + "type": "string", + "value": "strict" + } + }, + "inputs": [ + { + "id": 0, + "dtype": "f16", + "shape": [ + 1, + 4096 + ], + "stride": [ + 4096, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 1, + "dtype": "f16", + "shape": [ + 4096, + 14336 + ], + "stride": [ + 14336, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 2, + "dtype": "f32", + "shape": [ + 1, + 14336 + ], + "stride": [ + 14336, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 8, + "name": "swish/sigmoid", + "kind": "Sigmoid", + "attrs": {}, + "inputs": [ + { + "id": 2, + "dtype": "f32", + "shape": [ + 1, + 14336 + ], + "stride": [ + 14336, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 7, + "dtype": "f32", + "shape": [ + 1, + 14336 + ], + "stride": [ + 14336, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 10, + "name": "swish/multiply", + "kind": "Multiply", + "attrs": { + "auto_broadcast": { + "type": "string", + "value": "numpy" + } + }, + "inputs": [ + { + "id": 2, + "dtype": "f32", + "shape": [ + 1, + 14336 + ], + "stride": [ + 14336, + 1 + ], + "layout_type": 
"strided", + "property_type": "undef" + }, + { + "id": 7, + "dtype": "f32", + "shape": [ + 1, + 14336 + ], + "stride": [ + 14336, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 9, + "dtype": "f32", + "shape": [ + 1, + 14336 + ], + "stride": [ + 14336, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 6, + "name": "fc_up", + "kind": "MatMul", + "attrs": { + "transpose_a": { + "type": "bool", + "value": 0 + }, + "transpose_b": { + "type": "bool", + "value": 0 + }, + "accumulation_mode": { + "type": "string", + "value": "strict" + } + }, + "inputs": [ + { + "id": 0, + "dtype": "f16", + "shape": [ + 1, + 4096 + ], + "stride": [ + 4096, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 4, + "dtype": "f16", + "shape": [ + 4096, + 14336 + ], + "stride": [ + 14336, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 5, + "dtype": "f32", + "shape": [ + 1, + 14336 + ], + "stride": [ + 14336, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 12, + "name": "mul", + "kind": "Multiply", + "attrs": { + "auto_broadcast": { + "type": "string", + "value": "numpy" + } + }, + "inputs": [ + { + "id": 9, + "dtype": "f32", + "shape": [ + 1, + 14336 + ], + "stride": [ + 14336, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 5, + "dtype": "f32", + "shape": [ + 1, + 14336 + ], + "stride": [ + 14336, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 11, + "dtype": "f32", + "shape": [ + 1, + 14336 + ], + "stride": [ + 14336, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 13, + "name": "typecast", + "kind": "TypeCast", + "attrs": {}, + "inputs": [ + { + "id": 11, + "dtype": "f32", + "shape": [ + 1, + 14336 + ], + "stride": [ + 14336, + 1 + ], + "layout_type": 
"strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 14, + "dtype": "f16", + "shape": [ + 1, + 14336 + ], + "stride": [ + 14336, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 17, + "name": "fc_down", + "kind": "MatMul", + "attrs": { + "transpose_a": { + "type": "bool", + "value": 0 + }, + "transpose_b": { + "type": "bool", + "value": 0 + }, + "accumulation_mode": { + "type": "string", + "value": "strict" + } + }, + "inputs": [ + { + "id": 14, + "dtype": "f16", + "shape": [ + 1, + 14336 + ], + "stride": [ + 14336, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 15, + "dtype": "f16", + "shape": [ + 14336, + 4096 + ], + "stride": [ + 4096, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 16, + "dtype": "f16", + "shape": [ + 1, + 4096 + ], + "stride": [ + 4096, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + } + ] +} From ade83113794c4c26e4d50c1003a2f50d6355c807 Mon Sep 17 00:00:00 2001 From: "Lv, Tao A" Date: Thu, 5 Mar 2026 18:34:03 -0800 Subject: [PATCH 04/15] doc: graph: patterns: clarify data types for bf16/f16 gated mlp --- doc/graph/fusion_patterns/gated_mlp.md | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/doc/graph/fusion_patterns/gated_mlp.md b/doc/graph/fusion_patterns/gated_mlp.md index 73611ab6e1f..7a47dbae61e 100644 --- a/doc/graph/fusion_patterns/gated_mlp.md +++ b/doc/graph/fusion_patterns/gated_mlp.md @@ -80,10 +80,13 @@ optional. ## Data Types -oneDNN supports the floating-point Gated-MLP pattern with data types f32, bf16, -and f16. You can specify the data type via the input and output data type fields -of logical tensors for each operation. oneDNN does not support mixing different -floating data types in a floating-point Gated-MLP pattern. +oneDNN supports the floating-point Gated-MLP pattern with data types `f32`, `bf16`, +and `f16`. 
You can specify the data type via the input and output data type fields +of logical tensors for each operation. For `bf16` and `f16` Gated-MLP, the output +data types of the UP and Gate MatMuls need to be `f32` to preserve the accuracy of +intermediate results. A [TypeCast](@ref dev_guide_op_typecast) operation is +needed before the Down MatMul to downconvert the intermediate results from `f32` +to `bf16` or `f16`. The definition of the data types and support status on different CPU and GPU platforms follow the general description in @ref dev_guide_data_types. From c5993174994263f88ef10de4b216e8050dd3316b Mon Sep 17 00:00:00 2001 From: "Lv, Tao A" Date: Tue, 3 Mar 2026 21:26:47 -0800 Subject: [PATCH 05/15] graph: interface: init internal gated_mlp op --- src/graph/interface/c_types_map.hpp | 1 + src/graph/interface/op.hpp | 1 + src/graph/interface/op_def.hpp | 16 ++++++++++++++ src/graph/interface/opset.hpp | 1 + src/graph/interface/shape_infer.cpp | 34 +++++++++++++++++++++++++++++ src/graph/interface/shape_infer.hpp | 4 ++++ 6 files changed, 57 insertions(+) diff --git a/src/graph/interface/c_types_map.hpp b/src/graph/interface/c_types_map.hpp index aba3a8a05aa..e39eb9fc8bc 100644 --- a/src/graph/interface/c_types_map.hpp +++ b/src/graph/interface/c_types_map.hpp @@ -269,6 +269,7 @@ const op_kind_t _sdpa = 1069; const op_kind_t _host_scalar = 1070; const op_kind_t _identity = 1071; const op_kind_t _dropout = 1072; +const op_kind_t _gated_mlp = 1073; } // namespace op_kind using op_attr_t = typename std::underlying_type::type; diff --git a/src/graph/interface/op.hpp b/src/graph/interface/op.hpp index 61be61ab972..446405e17d6 100644 --- a/src/graph/interface/op.hpp +++ b/src/graph/interface/op.hpp @@ -545,6 +545,7 @@ struct dnnl_graph_op : public std::enable_shared_from_this { CASE(_host_scalar); CASE(_identity); CASE(_dropout); + CASE(_gated_mlp); default: return "undefined_op"; } #undef CASE diff --git a/src/graph/interface/op_def.hpp 
b/src/graph/interface/op_def.hpp index 054d2c19e3c..fb069445412 100644 --- a/src/graph/interface/op_def.hpp +++ b/src/graph/interface/op_def.hpp @@ -2375,6 +2375,22 @@ DNNL_GRAPH_OP_SCHEMA(_sdpa, 1, .set_attr(op_attr::vs_acc_mode, true, attribute_kind::s) .set_shape_inference_function(infer_dnnl_sdpa_output_shape)) +DNNL_GRAPH_OP_SCHEMA(_gated_mlp, 1, + op_schema_t() + .set_inputs_option(op_schema_t::param_num_option::variadic) + .set_num_inputs(std::set({4, 32})) + .set_num_outputs(2) + .set_input(0, "src") + .set_input(1, "gate_weights") + .set_input(2, "up_weights") + .set_input(3, "down_weights") + .set_output(0, "dst") + .set_output(1, "scratchpad") + .set_attr(op_attr::fusion_info, false, + attribute_kind::fusion_info) + .set_attr(op_attr::alg_kind, true, attribute_kind::i) + .set_shape_inference_function(infer_gated_mlp_output_shape)) + } // namespace graph } // namespace impl } // namespace dnnl diff --git a/src/graph/interface/opset.hpp b/src/graph/interface/opset.hpp index 6531592a6f1..6fea9a203ac 100644 --- a/src/graph/interface/opset.hpp +++ b/src/graph/interface/opset.hpp @@ -196,6 +196,7 @@ class opset_v1_t { fn(get_op_schema()); fn(get_op_schema()); fn(get_op_schema()); + fn(get_op_schema()); } }; diff --git a/src/graph/interface/shape_infer.cpp b/src/graph/interface/shape_infer.cpp index 188f3f11c32..a5c35d93c60 100644 --- a/src/graph/interface/shape_infer.cpp +++ b/src/graph/interface/shape_infer.cpp @@ -2416,6 +2416,40 @@ status_t infer_dnnl_layernorm_output_shape(op_t *n, return status::success; } +status_t infer_gated_mlp_output_shape(op_t *n, + std::vector &inputs, + std::vector &outputs) { + auto src = ltw(inputs[0]); + auto wei0 = ltw(inputs[1]); + auto wei1 = ltw(inputs[2]); + auto wei2 = ltw(inputs[3]); + auto dst = ltw(outputs[0]); + + auto wei0_ndims = wei0.ndims(); + + VCHECK_INVALID_SHAPE(wei0_ndims == 2, + "%s, only support 2D weight for gated mlp, but got weight dim: %d", + op_t::kind2str(n->get_kind()).c_str(), wei0_ndims); + 
VCHECK_INVALID_SHAPE(wei0.vdims() == wei1.vdims(), + "%s, wei0 and wei1 should have the same shape, but got wei0 shape: " + "%s, wei1 shape: %s", + op_t::kind2str(n->get_kind()).c_str(), + dims2str(wei0.vdims()).c_str(), dims2str(wei1.vdims()).c_str()); + + dims inferred = src.vdims(); + inferred.back() = wei2.vdims().back(); + + if (dst.ndims() != -1) { + VCHECK_INVALID_SHAPE(validate(inferred, dst.vdims()), + "%s, inferred out shape is not compatible with the given " + "output shape", + op_t::kind2str(n->get_kind()).c_str()); + } + set_shape_and_strides(*outputs[0], inferred); + + return status::success; +} + } // namespace graph } // namespace impl } // namespace dnnl diff --git a/src/graph/interface/shape_infer.hpp b/src/graph/interface/shape_infer.hpp index 077d867c8b8..c63bc7be914 100644 --- a/src/graph/interface/shape_infer.hpp +++ b/src/graph/interface/shape_infer.hpp @@ -311,6 +311,10 @@ status_t infer_dnnl_layernorm_output_shape(op_t *n, std::vector &inputs, std::vector &outputs); +status_t infer_gated_mlp_output_shape(op_t *n, + std::vector &inputs, + std::vector &outputs); + } // namespace graph } // namespace impl } // namespace dnnl From 20b5bc769b2653884ebc3821f71a260cd098b98e Mon Sep 17 00:00:00 2001 From: "Lv, Tao A" Date: Tue, 3 Mar 2026 19:32:03 -0800 Subject: [PATCH 06/15] graph: backend: dnnl: init gated_mlp kernels --- src/graph/backend/dnnl/kernels/gated_mlp.hpp | 112 +++++++++ .../dnnl/kernels/gated_mlp_primitive.cpp | 226 ++++++++++++++++++ .../dnnl/kernels/gated_mlp_primitive.hpp | 98 ++++++++ src/graph/backend/dnnl/patterns/mlp.cpp | 5 +- 4 files changed, 439 insertions(+), 2 deletions(-) create mode 100644 src/graph/backend/dnnl/kernels/gated_mlp.hpp create mode 100644 src/graph/backend/dnnl/kernels/gated_mlp_primitive.cpp create mode 100644 src/graph/backend/dnnl/kernels/gated_mlp_primitive.hpp diff --git a/src/graph/backend/dnnl/kernels/gated_mlp.hpp b/src/graph/backend/dnnl/kernels/gated_mlp.hpp new file mode 100644 index 
00000000000..7fdbc447768 --- /dev/null +++ b/src/graph/backend/dnnl/kernels/gated_mlp.hpp @@ -0,0 +1,112 @@ +/******************************************************************************* +* Copyright 2026 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GRAPH_BACKEND_DNNL_KERNELS_GATED_MLP_HPP +#define GRAPH_BACKEND_DNNL_KERNELS_GATED_MLP_HPP + +#include +#include +#include +#include +#include + +#include "graph/backend/dnnl/kernels/gated_mlp_primitive.hpp" +#include "graph/backend/dnnl/kernels/kernel_base.hpp" +#include "graph/backend/dnnl/kernels/large_partition.hpp" + +#include "graph/backend/dnnl/dnnl_partition_impl.hpp" + +#define VDISPATCH_GRAPH_GATED_MLP(msg, ...) 
\ + VINFO(graph, create, dispatch, compile, msg, ##__VA_ARGS__) + +namespace dnnl { +namespace impl { +namespace graph { +namespace dnnl_impl { + +struct gated_mlp_base_t : public kernel_base_t { +private: + std::shared_ptr kernel; + +public: + status_t compile_impl(const dnnl_partition_impl_t *part, + const engine_t *engine, const std::vector &inputs, + const std::vector &outputs) override { + const engine_kind_t ekind = engine->kind(); + bool enable_ukernel = false; + + if (ekind == engine_kind::gpu) { enable_ukernel = !force_primitive(); } + + status_t ret = status::unimplemented; + + if (enable_ukernel) { + kernel = std::make_shared(); + ret = kernel->compile_impl(part, engine, inputs, outputs); + } + + if (ret != status::success) { + kernel = std::make_shared(); + ret = kernel->compile_impl(part, engine, inputs, outputs); + } + if (ret == status::success) + VDISPATCH_GRAPH_GATED_MLP( + "gated_mlp is dispatched to (%s)", kernel->str().c_str()); + else + VDISPATCH_GRAPH_GATED_MLP("gated_mlp is failed to dispatch"); + return ret; + } + + // Use large partition kernel as default. Turn the env var to 0 to select + // gated_mlp primitive. 
+ bool force_primitive() const { + const int force = graph::utils::getenv_int_internal( + "GRAPH_GATED_MLP_FORCE_PRIMITIVE", 1); + return force > 0; + } + + status_t execute_impl(const stream_t *stream, + const std::vector &inputs, + const std::vector &outputs) override { + return kernel->execute_impl(stream, inputs, outputs); + } + +#ifdef DNNL_WITH_SYCL + status_t sycl_execute_impl(const stream_t *stream, + const std::vector &inputs, + const std::vector &outputs, + const std::vector<::sycl::event> &deps, + ::sycl::event *event) override { + return kernel->sycl_execute_impl(stream, inputs, outputs, deps, event); + } +#endif + +#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL + status_t ocl_execute_impl(const stream_t *stream, + const std::vector &inputs, + const std::vector &outputs, + const std::vector &deps, cl_event *event) override { + return kernel->ocl_execute_impl(stream, inputs, outputs, deps, event); + } +#endif + + std::string str() const override { return kernel->str(); } +}; +} // namespace dnnl_impl +} // namespace graph +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/graph/backend/dnnl/kernels/gated_mlp_primitive.cpp b/src/graph/backend/dnnl/kernels/gated_mlp_primitive.cpp new file mode 100644 index 00000000000..9a6d9feae99 --- /dev/null +++ b/src/graph/backend/dnnl/kernels/gated_mlp_primitive.cpp @@ -0,0 +1,226 @@ +/******************************************************************************* +* Copyright 2026 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "graph/backend/dnnl/kernels/gated_mlp_primitive.hpp" + +#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL +#include "gpu/intel/ocl/stream.hpp" +#elif DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL +#include "gpu/intel/sycl/stream.hpp" +#endif + +#include "graph/backend/dnnl/passes/compile_ops.hpp" +#include "graph/backend/dnnl/passes/constant_propagation.hpp" +#include "graph/backend/dnnl/passes/insert_ops.hpp" +#include "graph/backend/dnnl/passes/layout_propagation.hpp" +#include "graph/backend/dnnl/passes/lower.hpp" +#include "graph/backend/dnnl/passes/memory_planning.hpp" +#include "graph/backend/dnnl/passes/transform.hpp" +#include "graph/backend/dnnl/passes/utils.hpp" + +#include "graph/backend/dnnl/op_executable.hpp" + +namespace dnnl { +namespace impl { +namespace graph { +namespace dnnl_impl { + +status_t gated_mlp_primitive_kernel_t::compile_impl( + const dnnl_partition_impl_t *part, const engine_t *eng, + const std::vector &inputs, + const std::vector &outputs) { +// gated_mlp_primitive_kernel_t only supports Intel GPU. 
+#if defined(DNNL_WITH_SYCL) && DNNL_GPU_VENDOR != DNNL_VENDOR_INTEL + return status::unimplemented; +#endif + + p_engine_ = make_dnnl_engine(*eng); + g_alloc_ = reinterpret_cast(eng->get_allocator()); + + // First, dry run on a deep copy + subgraph_ + = std::make_shared(graph_t::deep_copy(part->get_ops()), + p_engine_, part->get_fpmath_mode(), false, true); + CHECK(set_given_inputs_outputs(subgraph_, inputs, outputs)); + + subgraph_visualizer_t vis(part->id(), [this](const value_t *val) { + return this->memory_planner_.get_memory_info(val); + }); + pass_pipeline_t pipeline = pass_pipeline_t(vis); + + BACKEND_DNNL_ADD_PASS(pipeline, lower_down); + pipeline.reset_visualize_arg(true, false); + BACKEND_DNNL_ADD_PASS(pipeline, infer_shape); + BACKEND_DNNL_ADD_PASS(pipeline, fuse_gated_mlp); + BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation); + + // bind the memory for each op + auto memory_plan = [&](std::shared_ptr &sg) { + return memory_planner_.run(sg); + }; + pipeline.reset_visualize_arg(true, true); + BACKEND_DNNL_ADD_PASS(pipeline, memory_plan); + BACKEND_DNNL_ADD_PASS(pipeline, compile_ops); + + // Run the added passes + BACKEND_DNNL_CHECK(pipeline.run(subgraph_)); + + // fill information for inputs logical tensors + for (size_t i = 0; i < inputs.size(); i++) { + auto &in = const_cast(inputs[i]); + in = subgraph_->ins_[i]; + } + + // fill information for outputs logical tensors + for (size_t i = 0; i < outputs.size(); i++) { + auto &out = const_cast(outputs[i]); + out = subgraph_->outs_[i]; + } + + resource_ctor_ = [this]() { + return this->memory_planner_.get_exec_args_set().clone(); + }; + + return status::success; +} + +void gated_mlp_primitive_kernel_t::prepare_args_set( + const execution_args_set_t *res, const std::vector &inputs, + const std::vector &outputs, const scratchpad_t &scratchpad) { + // update the data of partition in/outputs args + for (const auto &mem_idx : res->get_mems_use_external_inputs()) { + const dnnl::memory &mem = mem_idx.first; + 
const tensor_t &ts = inputs[mem_idx.second]; + const logical_tensor_t lt = ts.get_logical_tensor(); + const logical_tensor_wrapper_t ltw(lt); + if (ltw.is_host_scalar()) { + DNNL_HOST_SCALAR_TYPE_SWITCH(ltw.data_type(), DType, { + mem.set_host_scalar_value( + *static_cast(ts.get_data_handle())); + }); + } else { + mem.set_data_handle(ts.get_data_handle()); + } + } + + for (const auto &mem_idx : res->get_mems_use_external_outputs()) { + mem_idx.first.set_data_handle( + outputs[mem_idx.second].get_data_handle()); + } + + grantor_t var_grantor = memory_planner_.internal_temporary_grantor( + scratchpad.get_buffer()); + + for (auto &mem_offkey : res->get_mems_use_internal_temporary()) { + mem_offkey.first.set_data_handle(var_grantor.get(mem_offkey.second)); + } +} + +status_t gated_mlp_primitive_kernel_t::execute_impl(const stream_t *stream, + const std::vector &inputs, + const std::vector &outputs) { + dnnl::stream p_stream = make_dnnl_stream(p_engine_, *stream); + + thread_local_cache_t res_cache; + execution_args_set_t *res = res_cache.get_or_add( + reinterpret_cast(this), resource_ctor_); + + temporary_scratchpad_t scratchpad( + memory_planner_.total_internal_temporary_size(), p_engine_, + *g_alloc_); + prepare_args_set(res, inputs, outputs, scratchpad); + + for (size_t i = 0; i < subgraph_->execs_.size(); i++) { + subgraph_->execs_[i]->execute(p_stream, res->get_exec_args()[i]); + } + + return status::success; +} + +#ifdef DNNL_WITH_SYCL +status_t gated_mlp_primitive_kernel_t::sycl_execute_impl(const stream_t *stream, + const std::vector &inputs, + const std::vector &outputs, + const std::vector<::sycl::event> &sycl_deps, ::sycl::event *ret_event) { +// gated_mlp_primitive_kernel_t only supports Intel GPU. 
+#if DNNL_GPU_VENDOR != DNNL_VENDOR_INTEL + return status::unimplemented; +#endif + auto deps = sycl_deps; + ::sycl::event returned_event; + + dnnl::stream p_stream = make_dnnl_stream(p_engine_, *stream); + + thread_local_cache_t res_cache; + execution_args_set_t *res = res_cache.get_or_add( + reinterpret_cast(this), resource_ctor_); + + temporary_scratchpad_t scratchpad( + memory_planner_.total_internal_temporary_size(), p_engine_, + *g_alloc_); + prepare_args_set(res, inputs, outputs, scratchpad); + + for (size_t i = 0; i < subgraph_->execs_.size(); i++) { + if (subgraph_->is_constant_[i]) continue; + returned_event = subgraph_->execs_[i]->execute_sycl( + p_stream, res->get_exec_args()[i], deps); + deps = {returned_event}; + } + + scratchpad.set_deps(returned_event); + if (ret_event) *ret_event = returned_event; + + return status::success; +} +#endif + +#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL +status_t gated_mlp_primitive_kernel_t::ocl_execute_impl(const stream_t *stream, + const std::vector &inputs, + const std::vector &outputs, + const std::vector &ocl_deps, cl_event *ret_event) { + auto deps = ocl_deps; + cl_event returned_event {}; + + dnnl::stream p_stream = make_dnnl_stream(p_engine_, *stream); + + thread_local_cache_t res_cache; + execution_args_set_t *res = res_cache.get_or_add( + reinterpret_cast(this), resource_ctor_); + + temporary_scratchpad_t scratchpad( + memory_planner_.total_internal_temporary_size(), p_engine_, + *g_alloc_); + prepare_args_set(res, inputs, outputs, scratchpad); + + for (size_t i = 0; i < subgraph_->execs_.size(); i++) { + if (subgraph_->is_constant_[i]) continue; + returned_event = subgraph_->execs_[i]->execute_ocl( + p_stream, res->get_exec_args()[i], deps); + deps.assign(1, returned_event); + } + + scratchpad.set_deps(returned_event); + if (ret_event) *ret_event = returned_event; + + return status::success; +} +#endif + +} // namespace dnnl_impl +} // namespace graph +} // namespace impl +} // namespace dnnl diff --git 
a/src/graph/backend/dnnl/kernels/gated_mlp_primitive.hpp b/src/graph/backend/dnnl/kernels/gated_mlp_primitive.hpp new file mode 100644 index 00000000000..27788b8218a --- /dev/null +++ b/src/graph/backend/dnnl/kernels/gated_mlp_primitive.hpp @@ -0,0 +1,98 @@ +/******************************************************************************* +* Copyright 2026 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GRAPH_BACKEND_DNNL_KERNELS_GATED_MLP_PRIMITIVE_HPP +#define GRAPH_BACKEND_DNNL_KERNELS_GATED_MLP_PRIMITIVE_HPP + +#include +#include +#include +#include +#include + +#include "graph/backend/dnnl/common.hpp" +#include "graph/backend/dnnl/dnnl_constant_tensor_cache.hpp" +#include "graph/backend/dnnl/dnnl_partition_impl.hpp" +#include "graph/backend/dnnl/op_executable.hpp" +#include "graph/backend/dnnl/scratchpad.hpp" +#include "graph/backend/dnnl/thread_local_cache.hpp" +#include "graph/backend/dnnl/utils.hpp" + +#include "graph/backend/dnnl/passes/memory_planning.hpp" + +namespace dnnl { +namespace impl { +namespace graph { +namespace dnnl_impl { + +struct gated_mlp_primitive_kernel_t : public kernel_base_t { +private: + allocator_t *g_alloc_ = nullptr; + + std::shared_ptr subgraph_; + memory_planner_t memory_planner_; + std::function()> resource_ctor_; + +public: + gated_mlp_primitive_kernel_t() { + thread_local_cache_t res_cache; + res_cache.retain(); + } + + 
~gated_mlp_primitive_kernel_t() override { + thread_local_cache_t res_cache; + res_cache.remove_if_exist(reinterpret_cast(this)); + res_cache.release(); + } + + status_t compile_impl(const dnnl_partition_impl_t *part, + const engine_t *engine, const std::vector &inputs, + const std::vector &outputs) override; + + void prepare_args_set(const execution_args_set_t *res, + const std::vector &inputs, + const std::vector &outputs, + const scratchpad_t &scratchpad); + + status_t execute_impl(const stream_t *stream, + const std::vector &inputs, + const std::vector &outputs) override; + +#ifdef DNNL_WITH_SYCL + status_t sycl_execute_impl(const stream_t *stream, + const std::vector &inputs, + const std::vector &outputs, + const std::vector<::sycl::event> &deps, + ::sycl::event *ret_event) override; +#endif + +#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL + status_t ocl_execute_impl(const stream_t *stream, + const std::vector &inputs, + const std::vector &outputs, + const std::vector &deps, cl_event *ret_event) override; +#endif + + DEF_KERNEL_METHOD_STR(gated_mlp_primitive_kernel_t) + DNNL_DISALLOW_COPY_AND_ASSIGN(gated_mlp_primitive_kernel_t) +}; + +} // namespace dnnl_impl +} // namespace graph +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/graph/backend/dnnl/patterns/mlp.cpp b/src/graph/backend/dnnl/patterns/mlp.cpp index 08c78746f32..86ac06d7bd8 100644 --- a/src/graph/backend/dnnl/patterns/mlp.cpp +++ b/src/graph/backend/dnnl/patterns/mlp.cpp @@ -14,6 +14,7 @@ * limitations under the License. 
*******************************************************************************/ +#include "graph/backend/dnnl/kernels/gated_mlp.hpp" #include "graph/backend/dnnl/kernels/large_partition.hpp" #include "graph/backend/dnnl/patterns/fusions.hpp" @@ -86,7 +87,7 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, gated_mlp) in_edges_t {in_edge(0, pre_tc, 0)}); }) .set_attr("FCreateKernel", []() -> kernel_ptr { - return std::make_shared(); + return std::make_shared(); }); // gated mlp with swish decomposed to sigmoid and multiply. @@ -130,7 +131,7 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, gated_mlp_v1) in_edges_t {in_edge(0, pre_tc, 0)}); }) .set_attr("FCreateKernel", []() -> kernel_ptr { - return std::make_shared(); + return std::make_shared(); }); /* From 8c9bc2e5f6dba4aa2d31175fca214c1ade8dd8b6 Mon Sep 17 00:00:00 2001 From: "Lv, Tao A" Date: Tue, 3 Mar 2026 23:45:28 -0800 Subject: [PATCH 07/15] graph: backend: dnnl: add gated mlp executable --- .../backend/dnnl/executables/gated_mlp.cpp | 128 ++++++++++++++++++ .../backend/dnnl/executables/gated_mlp.hpp | 61 +++++++++ src/graph/backend/dnnl/layout_propagator.cpp | 31 +++++ src/graph/backend/dnnl/layout_propagator.hpp | 1 + src/graph/backend/dnnl/op_executable.cpp | 3 + src/graph/backend/dnnl/op_executable.hpp | 1 + 6 files changed, 225 insertions(+) create mode 100644 src/graph/backend/dnnl/executables/gated_mlp.cpp create mode 100644 src/graph/backend/dnnl/executables/gated_mlp.hpp diff --git a/src/graph/backend/dnnl/executables/gated_mlp.cpp b/src/graph/backend/dnnl/executables/gated_mlp.cpp new file mode 100644 index 00000000000..18ce7d73c49 --- /dev/null +++ b/src/graph/backend/dnnl/executables/gated_mlp.cpp @@ -0,0 +1,128 @@ +/******************************************************************************* + * Copyright 2026 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#include "graph/backend/dnnl/executables/gated_mlp.hpp" + +#include "common/gated_mlp_iface.hpp" + +namespace dnnl { +namespace impl { +namespace graph { +namespace dnnl_impl { + +gated_mlp_executable_t::gated_mlp_executable_t(std::shared_ptr &op, + const dnnl::engine &p_engine, pd_cache_t &pd_cache, + const fpmath_t &fpmath, bool use_block_layout) { + auto src_md = make_dnnl_memory_desc(op->get_input_logical_tensor(0)); + auto wei0_md = make_dnnl_memory_desc(op->get_input_logical_tensor(1)); + auto wei1_md = make_dnnl_memory_desc(op->get_input_logical_tensor(2)); + auto wei2_md = make_dnnl_memory_desc(op->get_input_logical_tensor(3)); + auto dst_md = make_dnnl_memory_desc(op->get_output_logical_tensor(0)); + + dnnl_primitive_attr_t attr = nullptr; + auto act_algo = op->has_attr(op_attr::alg_kind) + ? 
static_cast( + op->get_attr(op_attr::alg_kind)) + : dnnl::algorithm::undef; + auto ret = dnnl_gated_mlp_primitive_desc_create(&pd_, p_engine.get(), + src_md.get(), wei0_md.get(), wei1_md.get(), wei2_md.get(), + dst_md.get(), static_cast(act_algo), attr); + + dnnl::error::wrap_c_api(ret, + "could not create a primitive descriptor for a gated mlp " + "primitive"); + + ret = dnnl_primitive_create(&prim_, pd_); + dnnl::error::wrap_c_api( + ret, "could not create a primitive for a gated mlp primitive"); +} + +gated_mlp_executable_t::~gated_mlp_executable_t() { + if (prim_) dnnl_primitive_destroy(prim_); + if (pd_) dnnl_primitive_desc_destroy(pd_); +} + +void gated_mlp_executable_t::execute(const stream &stream, + const std::unordered_map &args) const { + UNUSED(stream); + UNUSED(args); + assert(!"gated_mlp_executable_t::execute() is not implemented on cpu"); +} + +#ifdef DNNL_WITH_SYCL +::sycl::event gated_mlp_executable_t::execute_sycl(const stream &stream, + const std::unordered_map &args, + const std::vector<::sycl::event> &deps) const { + std::vector c_args; + c_args.reserve(args.size()); + for (const auto &a : args) + c_args.push_back({a.first, a.second.get()}); + + sycl::event return_event; + auto ret = dnnl_sycl_interop_primitive_execute(prim_, stream.get(), + c_args.size(), c_args.data(), &deps, &return_event); + dnnl::error::wrap_c_api( + ret, "could not execute gated mlp primitive with sycl runtime"); + + return return_event; +} +#endif + +#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL +cl_event gated_mlp_executable_t::execute_ocl(const stream &stream, + const std::unordered_map &args, + const std::vector &deps) const { + std::vector c_args; + c_args.reserve(args.size()); + for (const auto &a : args) + c_args.push_back({a.first, a.second.get()}); + + const cl_event *c_deps = deps.empty() ? 
nullptr : deps.data(); + + cl_event return_event = nullptr; + auto ret = dnnl_ocl_interop_primitive_execute(prim_, stream.get(), + static_cast(c_args.size()), c_args.data(), c_deps, + static_cast(deps.size()), &return_event); + dnnl::error::wrap_c_api( + ret, "could not execute gated mlp primitive with ocl runtime"); + + return return_event; +} +#endif + +#define DNNL_ARG_WEIGHTS_GATE DNNL_ARG_WEIGHTS_0 +#define DNNL_ARG_WEIGHTS_UP DNNL_ARG_WEIGHTS_1 +#define DNNL_ARG_WEIGHTS_DOWN DNNL_ARG_WEIGHTS_2 + +arg_indices_t gated_mlp_executable_t::get_arg_indices(const op_t *op) { + arg_indices_t args; + // inputs: src, gate weights, up weights, down weights + size_t idx = 0; + args.insert({DNNL_ARG_SRC, {indices_t::type_t::input, idx++}}); + args.insert({DNNL_ARG_WEIGHTS_GATE, {indices_t::type_t::input, idx++}}); + args.insert({DNNL_ARG_WEIGHTS_UP, {indices_t::type_t::input, idx++}}); + args.insert({DNNL_ARG_WEIGHTS_DOWN, {indices_t::type_t::input, idx++}}); + + // outputs + args.insert({DNNL_ARG_DST, {indices_t::type_t::output, 0}}); + args.insert({DNNL_ARG_SCRATCHPAD, {indices_t::type_t::output, 1}}); + return args; +} + +} // namespace dnnl_impl +} // namespace graph +} // namespace impl +} // namespace dnnl diff --git a/src/graph/backend/dnnl/executables/gated_mlp.hpp b/src/graph/backend/dnnl/executables/gated_mlp.hpp new file mode 100644 index 00000000000..5f8d1899ee9 --- /dev/null +++ b/src/graph/backend/dnnl/executables/gated_mlp.hpp @@ -0,0 +1,61 @@ +/******************************************************************************* + * Copyright 2026 Intel Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + *******************************************************************************/ + +#ifndef GRAPH_BACKEND_DNNL_EXECUTABLES_GATED_MLP_HPP +#define GRAPH_BACKEND_DNNL_EXECUTABLES_GATED_MLP_HPP + +#include "graph/backend/dnnl/executables/base.hpp" + +namespace dnnl { +namespace impl { +namespace graph { +namespace dnnl_impl { + +struct gated_mlp_executable_t : public op_executable_t { + DECLARE_ARG_INDICES_GETTER; + + gated_mlp_executable_t(std::shared_ptr &op, + const dnnl::engine &p_engine, pd_cache_t &pd_cache, + const fpmath_t &fpmath, bool use_block_layout); + + ~gated_mlp_executable_t() override; + + void execute(const stream &stream, + const std::unordered_map &args) const override; + +#ifdef DNNL_WITH_SYCL + ::sycl::event execute_sycl(const stream &stream, + const std::unordered_map &args, + const std::vector<::sycl::event> &deps) const override; +#endif + +#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL + cl_event execute_ocl(const stream &stream, + const std::unordered_map &args, + const std::vector &deps) const override; +#endif + +private: + dnnl_primitive_desc_t pd_ = nullptr; + dnnl_primitive_t prim_ = nullptr; +}; + +} // namespace dnnl_impl +} // namespace graph +} // namespace impl +} // namespace dnnl + +#endif // GRAPH_BACKEND_DNNL_EXECUTABLES_GATED_MLP_HPP diff --git a/src/graph/backend/dnnl/layout_propagator.cpp b/src/graph/backend/dnnl/layout_propagator.cpp index f7b3010639a..d4ac7ebe619 100644 --- a/src/graph/backend/dnnl/layout_propagator.cpp +++ b/src/graph/backend/dnnl/layout_propagator.cpp @@ -1800,6 +1800,37 @@ status_t 
layout_propagator_for_host_scalar(std::shared_ptr &op, return status::success; } +status_t layout_propagator_for_gated_mlp(std::shared_ptr &op, + const dnnl::engine &p_engine, pd_cache_t &pd_cache, + const fpmath_t &fpmath, bool use_block_layout, + subgraph_rewriter_t &rewriter) { + UNUSED(p_engine); + UNUSED(pd_cache); + UNUSED(fpmath); + UNUSED(use_block_layout); + UNUSED(rewriter); + + value_ptr dst_val = op->get_output_value(0); + const logical_tensor_t &dst_lt = dst_val->get_logical_tensor(); + + dnnl::memory::desc expected_md; + if (ltw(dst_lt).is_any()) { + const auto tag = get_ncx_format(ltw(dst_lt).ndims()); + expected_md = {ltw(dst_lt).vdims(), + static_cast(ltw(dst_lt).data_type()), + tag}; + } else { + expected_md = make_dnnl_memory_desc(dst_lt); + } + status_t status = fill_layout_info(dst_val, expected_md); + + // fill scratchpads dimensions and data type to scratchpad value_t + value_ptr scratchpad_val = op->get_output_value(1); + const memory::desc scratchpad_desc; + status = fill_layout_info(scratchpad_val, scratchpad_desc); + return status; +} + } // namespace dnnl_impl } // namespace graph } // namespace impl diff --git a/src/graph/backend/dnnl/layout_propagator.hpp b/src/graph/backend/dnnl/layout_propagator.hpp index 4c7f9b780de..8f12fd966ca 100644 --- a/src/graph/backend/dnnl/layout_propagator.hpp +++ b/src/graph/backend/dnnl/layout_propagator.hpp @@ -96,6 +96,7 @@ DECLARE_LAYOUT_PROPAGATOR(mask); DECLARE_LAYOUT_PROPAGATOR(sdpa); DECLARE_LAYOUT_PROPAGATOR(host_scalar); DECLARE_LAYOUT_PROPAGATOR(identity); +DECLARE_LAYOUT_PROPAGATOR(gated_mlp); #undef DECLARE_LAYOUT_PROPAGATOR diff --git a/src/graph/backend/dnnl/op_executable.cpp b/src/graph/backend/dnnl/op_executable.cpp index 5f042d1fba7..27b281c49a4 100644 --- a/src/graph/backend/dnnl/op_executable.cpp +++ b/src/graph/backend/dnnl/op_executable.cpp @@ -77,6 +77,7 @@ executable_creator_func op_func_t::get_executable_creator(op_kind_t kind) { {_host_scalar, executable_creator}, {_identity, 
executable_creator}, {_dropout, dummy_executable_creator}, + {_gated_mlp, executable_creator}, }; if (_map.count(kind) == 0) { @@ -140,6 +141,7 @@ arg_indices_getter_func op_func_t::get_arg_indices_getter(op_kind_t kind) { {_host_scalar, host_scalar_executable_t::get_arg_indices}, {_identity, memory_reparser_t::get_arg_indices}, {_dropout, dummy_arg_indices_getter}, + {_gated_mlp, gated_mlp_executable_t::get_arg_indices}, }; if (_map.count(kind) == 0) { @@ -201,6 +203,7 @@ layout_propagator_func op_func_t::get_layout_propagator(op_kind_t kind) { {_sdpa, layout_propagator_for_sdpa}, {_host_scalar, layout_propagator_for_host_scalar}, {_identity, layout_propagator_for_identity}, + {_gated_mlp, layout_propagator_for_gated_mlp}, }; if (_map.count(kind) == 0) { diff --git a/src/graph/backend/dnnl/op_executable.hpp b/src/graph/backend/dnnl/op_executable.hpp index 2a07e56064c..85aecaff098 100644 --- a/src/graph/backend/dnnl/op_executable.hpp +++ b/src/graph/backend/dnnl/op_executable.hpp @@ -24,6 +24,7 @@ #include "graph/backend/dnnl/executables/conv.hpp" #include "graph/backend/dnnl/executables/deconv.hpp" #include "graph/backend/dnnl/executables/eltwise.hpp" +#include "graph/backend/dnnl/executables/gated_mlp.hpp" #include "graph/backend/dnnl/executables/gen_index.hpp" #include "graph/backend/dnnl/executables/group_norm.hpp" #include "graph/backend/dnnl/executables/host_scalar.hpp" From 39c716edb8f139882b2a4bc5860c6ce0a6e0a7bb Mon Sep 17 00:00:00 2001 From: "Lv, Tao A" Date: Wed, 4 Mar 2026 07:09:42 -0800 Subject: [PATCH 08/15] graph: backend: dnnl: passes: fuse gated mlp subgraph --- src/graph/backend/dnnl/passes/transform.cpp | 105 ++++++++++++++++++++ src/graph/backend/dnnl/passes/transform.hpp | 3 + 2 files changed, 108 insertions(+) diff --git a/src/graph/backend/dnnl/passes/transform.cpp b/src/graph/backend/dnnl/passes/transform.cpp index 1e3cd098f7b..5aaf0e4584c 100644 --- a/src/graph/backend/dnnl/passes/transform.cpp +++ 
b/src/graph/backend/dnnl/passes/transform.cpp @@ -4584,6 +4584,111 @@ status_t fuse_sdpa(std::shared_ptr &sg) { return status::success; } +// The pass is called against a gated mlp subgraph matched by the gated mlp +// patterns. Hence we have the basic assumptions for the topology which +// simplifies the pass logic below. +status_t fuse_gated_mlp(std::shared_ptr &sg) { + std::vector candidates; + const auto ops = sg->get_ops(); + size_t matmul_count = 0; + dnnl::algorithm act_algo = dnnl::algorithm::undef; + op_ptr gate = nullptr, up = nullptr, down = nullptr; + + for (const auto &op : ops) { + if (op->get_kind() == op_kind::_matmul) { + matmul_count++; + candidates.emplace_back(op); + // if it has no consumer, it's the down matmul. If it has a binary consumer, it's the up matmul. Otherwise, it's the gate matmul. + auto out_val = op->get_output_value(0); + auto &consumers = out_val->get_consumers(); + if (consumers.empty()) { + down = op; + } else if (consumers.size() == 1 + && consumers[0].get_op().get_kind() == op_kind::_binary) { + up = op; + } else { + gate = op; + } + } else if (op->get_kind() == op_kind::_eltwise) { + candidates.emplace_back(op); + act_algo = static_cast( + op->get_attr(op_attr::alg_kind)); + if (act_algo == dnnl::algorithm::eltwise_logistic) { + // check the consumer of sigmoid, it should be binary_mul + auto out_val = op->get_output_value(0); + if (out_val->get_consumers().size() != 1) { break; } + auto &consumer = out_val->get_consumers()[0].get_op(); + if (consumer.get_kind() != op_kind::_binary) { break; } + auto consumer_alg = static_cast( + consumer.get_attr(op_attr::alg_kind)); + if (consumer_alg != dnnl::algorithm::binary_mul) { break; } + // check the consumer of binary_mul, it should be the second matmul or another binary_mul. 
+ auto out_val2 = consumer.get_output_value(0); + if (out_val2->get_consumers().size() != 1) { break; } + auto &consumer2 = out_val2->get_consumers()[0].get_op(); + if (consumer2.get_kind() != op_kind::_matmul + && consumer2.get_kind() != op_kind::_binary) { + break; + } + // if it's binary_mul, it's gated mlp with swish activation: sigmoid + binary_mul. + if (consumer2.get_kind() == op_kind::_binary) { + act_algo = dnnl::algorithm::eltwise_swish; + } + } + } else if (op->get_kind() == op_kind::_binary) { + auto alg = static_cast( + op->get_attr(op_attr::alg_kind)); + if (alg != dnnl::algorithm::binary_mul) { break; } + candidates.emplace_back(op); + } else if (op->get_kind() == op_kind::_reorder) { + // Optional typecast. + candidates.emplace_back(op); + } else { + // strange op for gated mlp pattern, bail out. + break; + } + } + + // seems not a gated mlp pattern, bail out. + if (matmul_count != 3) { return status::unimplemented; } + if (candidates.size() != ops.size()) { return status::unimplemented; } + if (gate == nullptr || up == nullptr || down == nullptr) { + return status::unimplemented; + } + + subgraph_rewriter_t rewriter(sg); + op_ptr gated_mlp_op = std::make_shared(op_kind::_gated_mlp); + gated_mlp_op->set_attr( + op_attr::alg_kind, static_cast(act_algo)); + + // connect inputs and outputs + auto src_val = gate->get_input_value(0); + auto wei0_val = gate->get_input_value(1); + auto wei1_val = up->get_input_value(1); + auto wei2_val = down->get_input_value(1); + src_val->remove_consumer(*gate, 0); + wei0_val->remove_consumer(*gate, 1); + wei1_val->remove_consumer(*up, 1); + wei2_val->remove_consumer(*down, 1); + gated_mlp_op->connect_input(0, src_val); + gated_mlp_op->connect_input(1, wei0_val); + gated_mlp_op->connect_input(2, wei1_val); + gated_mlp_op->connect_input(3, wei2_val); + auto dst_val = down->get_output_value(0); + dst_val->set_producer(*gated_mlp_op); + gated_mlp_op->add_output(dst_val); + + insert_empty_scratchpad(gated_mlp_op); + + for 
(auto &op : candidates) { + rewriter.to_remove(op); + } + rewriter.to_insert(gated_mlp_op); + rewriter.run(); + + return status::success; +} + } // namespace dnnl_impl } // namespace graph } // namespace impl diff --git a/src/graph/backend/dnnl/passes/transform.hpp b/src/graph/backend/dnnl/passes/transform.hpp index 37bcd905675..f2b07754360 100644 --- a/src/graph/backend/dnnl/passes/transform.hpp +++ b/src/graph/backend/dnnl/passes/transform.hpp @@ -302,6 +302,9 @@ status_t fuse_implicit_causal_mask(std::shared_ptr &sg); /// This pass will transform the sdpa subgraph into a dnnl_sdpa op. status_t fuse_sdpa(std::shared_ptr &sg); +/// This pass will transform the gated mlp subgraph into a _gated_mlp op. +status_t fuse_gated_mlp(std::shared_ptr &sg); + } // namespace dnnl_impl } // namespace graph } // namespace impl From a51db0eceb96e30b38efa43f2e6be2daa4e8eac0 Mon Sep 17 00:00:00 2001 From: "Lv, Tao A" Date: Tue, 10 Mar 2026 23:53:19 -0700 Subject: [PATCH 09/15] examples: graph: fix intermediate type for int4 gated mlp --- examples/graph/gated_mlp_int4.cpp | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/examples/graph/gated_mlp_int4.cpp b/examples/graph/gated_mlp_int4.cpp index 2910dba8712..0a4f3cde2a4 100644 --- a/examples/graph/gated_mlp_int4.cpp +++ b/examples/graph/gated_mlp_int4.cpp @@ -113,6 +113,9 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt, // Incremental IDs used to create logical tensors and operations. 
size_t id = 0; + // Intermediate data type + const logical_tensor::data_type dt_inter = logical_tensor::data_type::f32; + // dequantize for fc_gate weights auto wei0_int4 = logical_tensor( id++, data_type::u4, wei0_sz, layout_type::strided); @@ -130,7 +133,7 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt, // fc_gate auto src = logical_tensor(id++, dt, src_sz, layout_type::strided); - auto out0 = logical_tensor(id++, dt, hd_sz, layout_type::strided); + auto out0 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided); auto fc_gate = op(id++, op::kind::MatMul, "fc_gate"); fc_gate.add_inputs({src, wei0_dt}); fc_gate.add_outputs({out0}); @@ -151,29 +154,38 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt, deq_up.add_outputs({wei1_dt}); // fc_up - auto out1 = logical_tensor(id++, dt, hd_sz, layout_type::strided); + auto out1 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided); auto fc_up = op(id++, op::kind::MatMul, "fc_up"); fc_up.add_inputs({src, wei1_dt}); fc_up.add_outputs({out1}); // activation swish: sigmoid - auto out2 = logical_tensor(id++, dt, hd_sz, layout_type::strided); + auto out2 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided); auto swi_sig = op(id++, op::kind::Sigmoid, "swish/sigmoid"); swi_sig.add_inputs({out0}); swi_sig.add_outputs({out2}); // activation swish: multiply - auto out3 = logical_tensor(id++, dt, hd_sz, layout_type::strided); + auto out3 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided); auto swi_mul = op(id++, op::kind::Multiply, "swish/multiply"); swi_mul.add_inputs({out0, out2}); swi_mul.add_outputs({out3}); // multiplication - auto out4 = logical_tensor(id++, dt, hd_sz, layout_type::strided); + auto out4 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided); auto mul = op(id++, op::kind::Multiply, "mul"); mul.add_inputs({out3, out1}); mul.add_outputs({out4}); + // downconversion when needed + auto out4_dt = out4; + auto typecast = op(id++, 
op::kind::TypeCast, "typecast"); + if (dt != dt_inter) { + out4_dt = logical_tensor(id++, dt, hd_sz, layout_type::strided); + typecast.add_inputs({out4}); + typecast.add_outputs({out4_dt}); + } + // dequantize for fc_down weights auto wei2_int4 = logical_tensor( id++, data_type::u4, wei2_sz, layout_type::strided); @@ -192,7 +204,7 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt, // fc_down auto dst = logical_tensor(id++, dt, out_sz, layout_type::strided); auto fc_down = op(id++, op::kind::MatMul, "fc_down"); - fc_down.add_inputs({out4, wei2_dt}); + fc_down.add_inputs({out4_dt, wei2_dt}); fc_down.add_outputs({dst}); // Construct a gated mlp graph with engine kind and operations. @@ -205,6 +217,7 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt, mlp.add_op(swi_sig); mlp.add_op(swi_mul); mlp.add_op(mul); + if (dt != dt_inter) { mlp.add_op(typecast); } mlp.add_op(deq_down); mlp.add_op(fc_down); mlp.finalize(); From aabb1672f149358aa9caa65787eb402b1e5b5556 Mon Sep 17 00:00:00 2001 From: "Lv, Tao A" Date: Tue, 10 Mar 2026 23:57:09 -0700 Subject: [PATCH 10/15] graph: backend: dnnl: patterns: support typecast in quant gated mlp --- src/graph/backend/dnnl/patterns/mlp.cpp | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/graph/backend/dnnl/patterns/mlp.cpp b/src/graph/backend/dnnl/patterns/mlp.cpp index 86ac06d7bd8..3c95e5dea04 100644 --- a/src/graph/backend/dnnl/patterns/mlp.cpp +++ b/src/graph/backend/dnnl/patterns/mlp.cpp @@ -178,11 +178,20 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, quantized_gated_mlp) auto bin = pgraph->append_alternation( get_binary_ops(), edges); + // optional typecast + auto tc = std::make_shared(); + pm::pb_op_t *ptypecast + = tc->append_op(graph::op_kind::TypeCast); + tc->create_input_port(0, ptypecast, 0); + tc->create_output_port(0, ptypecast, 0); + auto pre_tc + = pgraph->append_optional(tc, {in_edge(0, bin, 0)}); + // fc_down pm::pb_op_t 
*deq_down = pgraph->append_op( graph::op_kind::DynamicDequantize); in_edges_t fc_down_edges - = {in_edge(0, bin, 0), in_edge(1, deq_down, 0)}; + = {in_edge(0, pre_tc, 0), in_edge(1, deq_down, 0)}; pgraph->append_op(graph::op_kind::MatMul, fc_down_edges); }) .set_attr("FCreateKernel", []() -> kernel_ptr { @@ -220,11 +229,20 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, quantized_gated_mlp_v1) auto bin = pgraph->append_alternation( get_binary_ops(), edges); + // optional typecast + auto tc = std::make_shared(); + pm::pb_op_t *ptypecast + = tc->append_op(graph::op_kind::TypeCast); + tc->create_input_port(0, ptypecast, 0); + tc->create_output_port(0, ptypecast, 0); + auto pre_tc + = pgraph->append_optional(tc, {in_edge(0, bin, 0)}); + // fc_down pm::pb_op_t *deq_down = pgraph->append_op( graph::op_kind::DynamicDequantize); in_edges_t fc_down_edges - = {in_edge(0, bin, 0), in_edge(1, deq_down, 0)}; + = {in_edge(0, pre_tc, 0), in_edge(1, deq_down, 0)}; pgraph->append_op(graph::op_kind::MatMul, fc_down_edges); }) .set_attr("FCreateKernel", []() -> kernel_ptr { From 7706ffd2fb35a3f4e39e203e50b90032e8803735 Mon Sep 17 00:00:00 2001 From: "Lv, Tao A" Date: Tue, 10 Mar 2026 23:58:08 -0700 Subject: [PATCH 11/15] benchdnn: inputs: graph: add cases for int4 gated mlp --- .../complex_fusion/mlp/gated-mlp-int4.json | 276 +++++++++++------- 1 file changed, 163 insertions(+), 113 deletions(-) diff --git a/tests/benchdnn/inputs/graph/complex_fusion/mlp/gated-mlp-int4.json b/tests/benchdnn/inputs/graph/complex_fusion/mlp/gated-mlp-int4.json index 706521b529e..65c167f31aa 100644 --- a/tests/benchdnn/inputs/graph/complex_fusion/mlp/gated-mlp-int4.json +++ b/tests/benchdnn/inputs/graph/complex_fusion/mlp/gated-mlp-int4.json @@ -1,23 +1,23 @@ { - "version": "3.7.0", - "engine_kind": "cpu", + "version": "3.12.0", + "engine_kind": "gpu", "fpmath_mode": "strict", "fpmath_mode_apply_to_int": "true", "input_ports": [ - 0, - 1, - 2, - 5, - 8, - 9, - 10, - 5, - 21, - 22, - 23 + 0, + 1, + 
2, + 5, + 8, + 9, + 10, + 5, + 23, + 24, + 25 ], "output_ports": [ - 26 + 28 ], "graph": [ { @@ -32,7 +32,7 @@ "group_shape": { "type": "s64[]", "value": [ - 1, + 1, 128 ] }, @@ -46,39 +46,39 @@ "id": 0, "dtype": "u4", "shape": [ - 4096, + 4096, 14336 ], "stride": [ - 14336, + 14336, 1 ], "layout_type": "strided", "property_type": "undef" - }, + }, { "id": 1, "dtype": "f16", "shape": [ - 4096, + 4096, 112 ], "stride": [ - 112, + 112, 1 ], "layout_type": "strided", "property_type": "undef" - }, + }, { "id": 2, "dtype": "u8", "shape": [ - 4096, + 4096, 112 ], "stride": [ - 112, + 112, 1 ], "layout_type": "strided", @@ -90,18 +90,18 @@ "id": 3, "dtype": "f16", "shape": [ - 4096, + 4096, 14336 ], "stride": [ - 14336, + 14336, 1 ], "layout_type": "strided", "property_type": "undef" } ] - }, + }, { "id": 7, "name": "fc_gate", @@ -114,6 +114,10 @@ "transpose_b": { "type": "bool", "value": 0 + }, + "accumulation_mode": { + "type": "string", + "value": "strict" } }, "inputs": [ @@ -121,25 +125,25 @@ "id": 5, "dtype": "f16", "shape": [ - 1, + 1, 4096 ], "stride": [ - 4096, + 4096, 1 ], "layout_type": "strided", "property_type": "undef" - }, + }, { "id": 3, "dtype": "f16", "shape": [ - 4096, + 4096, 14336 ], "stride": [ - 14336, + 14336, 1 ], "layout_type": "strided", @@ -149,20 +153,20 @@ "outputs": [ { "id": 6, - "dtype": "f16", + "dtype": "f32", "shape": [ - 1, + 1, 14336 ], "stride": [ - 14336, + 14336, 1 ], "layout_type": "strided", "property_type": "undef" } ] - }, + }, { "id": 16, "name": "swish/sigmoid", @@ -171,13 +175,13 @@ "inputs": [ { "id": 6, - "dtype": "f16", + "dtype": "f32", "shape": [ - 1, + 1, 14336 ], "stride": [ - 14336, + 14336, 1 ], "layout_type": "strided", @@ -187,20 +191,20 @@ "outputs": [ { "id": 15, - "dtype": "f16", + "dtype": "f32", "shape": [ - 1, + 1, 14336 ], "stride": [ - 14336, + 14336, 1 ], "layout_type": "strided", "property_type": "undef" } ] - }, + }, { "id": 18, "name": "swish/multiply", @@ -214,27 +218,27 @@ "inputs": [ { "id": 6, - 
"dtype": "f16", + "dtype": "f32", "shape": [ - 1, + 1, 14336 ], "stride": [ - 14336, + 14336, 1 ], "layout_type": "strided", "property_type": "undef" - }, + }, { "id": 15, - "dtype": "f16", + "dtype": "f32", "shape": [ - 1, + 1, 14336 ], "stride": [ - 14336, + 14336, 1 ], "layout_type": "strided", @@ -244,20 +248,20 @@ "outputs": [ { "id": 17, - "dtype": "f16", + "dtype": "f32", "shape": [ - 1, + 1, 14336 ], "stride": [ - 14336, + 14336, 1 ], "layout_type": "strided", "property_type": "undef" } ] - }, + }, { "id": 12, "name": "deq_up", @@ -270,7 +274,7 @@ "group_shape": { "type": "s64[]", "value": [ - 1, + 1, 128 ] }, @@ -284,39 +288,39 @@ "id": 8, "dtype": "u4", "shape": [ - 4096, + 4096, 14336 ], "stride": [ - 14336, + 14336, 1 ], "layout_type": "strided", "property_type": "undef" - }, + }, { "id": 9, "dtype": "f16", "shape": [ - 4096, + 4096, 112 ], "stride": [ - 112, + 112, 1 ], "layout_type": "strided", "property_type": "undef" - }, + }, { "id": 10, "dtype": "u8", "shape": [ - 4096, + 4096, 112 ], "stride": [ - 112, + 112, 1 ], "layout_type": "strided", @@ -328,18 +332,18 @@ "id": 11, "dtype": "f16", "shape": [ - 4096, + 4096, 14336 ], "stride": [ - 14336, + 14336, 1 ], "layout_type": "strided", "property_type": "undef" } ] - }, + }, { "id": 14, "name": "fc_up", @@ -352,6 +356,10 @@ "transpose_b": { "type": "bool", "value": 0 + }, + "accumulation_mode": { + "type": "string", + "value": "strict" } }, "inputs": [ @@ -359,25 +367,25 @@ "id": 5, "dtype": "f16", "shape": [ - 1, + 1, 4096 ], "stride": [ - 4096, + 4096, 1 ], "layout_type": "strided", "property_type": "undef" - }, + }, { "id": 11, "dtype": "f16", "shape": [ - 4096, + 4096, 14336 ], "stride": [ - 14336, + 14336, 1 ], "layout_type": "strided", @@ -387,20 +395,20 @@ "outputs": [ { "id": 13, - "dtype": "f16", + "dtype": "f32", "shape": [ - 1, + 1, 14336 ], "stride": [ - 14336, + 14336, 1 ], "layout_type": "strided", "property_type": "undef" } ] - }, + }, { "id": 20, "name": "mul", @@ -414,27 +422,27 @@ 
"inputs": [ { "id": 17, - "dtype": "f16", + "dtype": "f32", "shape": [ - 1, + 1, 14336 ], "stride": [ - 14336, + 14336, 1 ], "layout_type": "strided", "property_type": "undef" - }, + }, { "id": 13, - "dtype": "f16", + "dtype": "f32", "shape": [ - 1, + 1, 14336 ], "stride": [ - 14336, + 14336, 1 ], "layout_type": "strided", @@ -444,22 +452,60 @@ "outputs": [ { "id": 19, + "dtype": "f32", + "shape": [ + 1, + 14336 + ], + "stride": [ + 14336, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 21, + "name": "typecast", + "kind": "TypeCast", + "attrs": {}, + "inputs": [ + { + "id": 19, + "dtype": "f32", + "shape": [ + 1, + 14336 + ], + "stride": [ + 14336, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 22, "dtype": "f16", "shape": [ - 1, + 1, 14336 ], "stride": [ - 14336, + 14336, 1 ], "layout_type": "strided", "property_type": "undef" } ] - }, + }, { - "id": 25, + "id": 27, "name": "deq_down", "kind": "DynamicDequantize", "attrs": { @@ -470,7 +516,7 @@ "group_shape": { "type": "s64[]", "value": [ - 1, + 1, 128 ] }, @@ -481,42 +527,42 @@ }, "inputs": [ { - "id": 21, + "id": 23, "dtype": "u4", "shape": [ - 14336, + 14336, 4096 ], "stride": [ - 4096, + 4096, 1 ], "layout_type": "strided", "property_type": "undef" - }, + }, { - "id": 22, + "id": 24, "dtype": "f16", "shape": [ - 14336, + 14336, 32 ], "stride": [ - 32, + 32, 1 ], "layout_type": "strided", "property_type": "undef" - }, + }, { - "id": 23, + "id": 25, "dtype": "u8", "shape": [ - 14336, + 14336, 32 ], "stride": [ - 32, + 32, 1 ], "layout_type": "strided", @@ -525,23 +571,23 @@ ], "outputs": [ { - "id": 24, + "id": 26, "dtype": "f16", "shape": [ - 14336, + 14336, 4096 ], "stride": [ - 4096, + 4096, 1 ], "layout_type": "strided", "property_type": "undef" } ] - }, + }, { - "id": 27, + "id": 29, "name": "fc_down", "kind": "MatMul", "attrs": { @@ -552,32 +598,36 @@ "transpose_b": { "type": "bool", "value": 0 + }, + 
"accumulation_mode": { + "type": "string", + "value": "strict" } }, "inputs": [ { - "id": 19, + "id": 22, "dtype": "f16", "shape": [ - 1, + 1, 14336 ], "stride": [ - 14336, + 14336, 1 ], "layout_type": "strided", "property_type": "undef" - }, + }, { - "id": 24, + "id": 26, "dtype": "f16", "shape": [ - 14336, + 14336, 4096 ], "stride": [ - 4096, + 4096, 1 ], "layout_type": "strided", @@ -586,14 +636,14 @@ ], "outputs": [ { - "id": 26, + "id": 28, "dtype": "f16", "shape": [ - 1, + 1, 4096 ], "stride": [ - 4096, + 4096, 1 ], "layout_type": "strided", @@ -602,4 +652,4 @@ ] } ] -} \ No newline at end of file +} From 775dce1238a74ea1634d499d88720645e57a5d05 Mon Sep 17 00:00:00 2001 From: "Lv, Tao A" Date: Wed, 11 Mar 2026 00:01:13 -0700 Subject: [PATCH 12/15] graph: backend: dnnl: kernels: supports quantized gated mlp --- src/graph/backend/dnnl/kernels/gated_mlp.hpp | 4 ++- .../dnnl/kernels/gated_mlp_primitive.cpp | 34 ++++++++++++++----- .../dnnl/kernels/gated_mlp_primitive.hpp | 1 + 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/src/graph/backend/dnnl/kernels/gated_mlp.hpp b/src/graph/backend/dnnl/kernels/gated_mlp.hpp index 7fdbc447768..bbc1fed4c4b 100644 --- a/src/graph/backend/dnnl/kernels/gated_mlp.hpp +++ b/src/graph/backend/dnnl/kernels/gated_mlp.hpp @@ -37,6 +37,7 @@ namespace impl { namespace graph { namespace dnnl_impl { +template struct gated_mlp_base_t : public kernel_base_t { private: std::shared_ptr kernel; @@ -53,7 +54,8 @@ struct gated_mlp_base_t : public kernel_base_t { status_t ret = status::unimplemented; if (enable_ukernel) { - kernel = std::make_shared(); + kernel = std::make_shared< + gated_mlp_primitive_kernel_t>(); ret = kernel->compile_impl(part, engine, inputs, outputs); } diff --git a/src/graph/backend/dnnl/kernels/gated_mlp_primitive.cpp b/src/graph/backend/dnnl/kernels/gated_mlp_primitive.cpp index 9a6d9feae99..276797efb7c 100644 --- a/src/graph/backend/dnnl/kernels/gated_mlp_primitive.cpp +++ 
b/src/graph/backend/dnnl/kernels/gated_mlp_primitive.cpp @@ -38,7 +38,8 @@ namespace impl { namespace graph { namespace dnnl_impl { -status_t gated_mlp_primitive_kernel_t::compile_impl( +template +status_t gated_mlp_primitive_kernel_t::compile_impl( const dnnl_partition_impl_t *part, const engine_t *eng, const std::vector &inputs, const std::vector &outputs) { @@ -62,6 +63,16 @@ status_t gated_mlp_primitive_kernel_t::compile_impl( pass_pipeline_t pipeline = pass_pipeline_t(vis); BACKEND_DNNL_ADD_PASS(pipeline, lower_down); + + if (quantized) { + BACKEND_DNNL_ADD_PASS(pipeline, fuse_typecast_to_matmul_or_conv); + BACKEND_DNNL_ADD_PASS(pipeline, fuse_post_typecast_to_predecessor); + BACKEND_DNNL_ADD_PASS(pipeline, fuse_src_scales); + BACKEND_DNNL_ADD_PASS(pipeline, fuse_src_zero_points); + BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_scales); + BACKEND_DNNL_ADD_PASS(pipeline, fuse_dst_zero_points); + } + pipeline.reset_visualize_arg(true, false); BACKEND_DNNL_ADD_PASS(pipeline, infer_shape); BACKEND_DNNL_ADD_PASS(pipeline, fuse_gated_mlp); @@ -97,7 +108,8 @@ status_t gated_mlp_primitive_kernel_t::compile_impl( return status::success; } -void gated_mlp_primitive_kernel_t::prepare_args_set( +template +void gated_mlp_primitive_kernel_t::prepare_args_set( const execution_args_set_t *res, const std::vector &inputs, const std::vector &outputs, const scratchpad_t &scratchpad) { // update the data of partition in/outputs args @@ -129,8 +141,9 @@ void gated_mlp_primitive_kernel_t::prepare_args_set( } } -status_t gated_mlp_primitive_kernel_t::execute_impl(const stream_t *stream, - const std::vector &inputs, +template +status_t gated_mlp_primitive_kernel_t::execute_impl( + const stream_t *stream, const std::vector &inputs, const std::vector &outputs) { dnnl::stream p_stream = make_dnnl_stream(p_engine_, *stream); @@ -151,8 +164,9 @@ status_t gated_mlp_primitive_kernel_t::execute_impl(const stream_t *stream, } #ifdef DNNL_WITH_SYCL -status_t 
gated_mlp_primitive_kernel_t::sycl_execute_impl(const stream_t *stream, - const std::vector &inputs, +template +status_t gated_mlp_primitive_kernel_t::sycl_execute_impl( + const stream_t *stream, const std::vector &inputs, const std::vector &outputs, const std::vector<::sycl::event> &sycl_deps, ::sycl::event *ret_event) { // gated_mlp_primitive_kernel_t only supports Intel GPU. @@ -188,8 +202,9 @@ status_t gated_mlp_primitive_kernel_t::sycl_execute_impl(const stream_t *stream, #endif #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL -status_t gated_mlp_primitive_kernel_t::ocl_execute_impl(const stream_t *stream, - const std::vector &inputs, +template +status_t gated_mlp_primitive_kernel_t::ocl_execute_impl( + const stream_t *stream, const std::vector &inputs, const std::vector &outputs, const std::vector &ocl_deps, cl_event *ret_event) { auto deps = ocl_deps; @@ -220,6 +235,9 @@ status_t gated_mlp_primitive_kernel_t::ocl_execute_impl(const stream_t *stream, } #endif +template struct gated_mlp_primitive_kernel_t; +template struct gated_mlp_primitive_kernel_t; + } // namespace dnnl_impl } // namespace graph } // namespace impl diff --git a/src/graph/backend/dnnl/kernels/gated_mlp_primitive.hpp b/src/graph/backend/dnnl/kernels/gated_mlp_primitive.hpp index 27788b8218a..e1cdbfa64ac 100644 --- a/src/graph/backend/dnnl/kernels/gated_mlp_primitive.hpp +++ b/src/graph/backend/dnnl/kernels/gated_mlp_primitive.hpp @@ -38,6 +38,7 @@ namespace impl { namespace graph { namespace dnnl_impl { +template struct gated_mlp_primitive_kernel_t : public kernel_base_t { private: allocator_t *g_alloc_ = nullptr; From 8600354a9bf28b8f9b89bf5dc0b352c9cfd27d9e Mon Sep 17 00:00:00 2001 From: "Lv, Tao A" Date: Wed, 11 Mar 2026 00:03:18 -0700 Subject: [PATCH 13/15] graph: backend: dnnl: executables: prepare args for quantized gated mlp --- .../backend/dnnl/executables/gated_mlp.cpp | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/graph/backend/dnnl/executables/gated_mlp.cpp 
b/src/graph/backend/dnnl/executables/gated_mlp.cpp index 18ce7d73c49..6db9736ec0d 100644 --- a/src/graph/backend/dnnl/executables/gated_mlp.cpp +++ b/src/graph/backend/dnnl/executables/gated_mlp.cpp @@ -116,9 +116,39 @@ arg_indices_t gated_mlp_executable_t::get_arg_indices(const op_t *op) { args.insert({DNNL_ARG_WEIGHTS_UP, {indices_t::type_t::input, idx++}}); args.insert({DNNL_ARG_WEIGHTS_DOWN, {indices_t::type_t::input, idx++}}); + // optional scales/zps for quantization + const auto &fusion_info = op->has_attr(op_attr::fusion_info) + ? op->get_attr(op_attr::fusion_info) + : fusion_info_t(); + if (fusion_info.with_runtime_scales(true, DNNL_ARG_WEIGHTS_GATE)) { + args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS_GATE, + {indices_t::type_t::input, idx++}}); + } + if (fusion_info.with_runtime_zero_points(true, DNNL_ARG_WEIGHTS_GATE)) { + args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS_GATE, + {indices_t::type_t::input, idx++}}); + } + if (fusion_info.with_runtime_scales(true, DNNL_ARG_WEIGHTS_UP)) { + args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS_UP, + {indices_t::type_t::input, idx++}}); + } + if (fusion_info.with_runtime_zero_points(true, DNNL_ARG_WEIGHTS_UP)) { + args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS_UP, + {indices_t::type_t::input, idx++}}); + } + if (fusion_info.with_runtime_scales(true, DNNL_ARG_WEIGHTS_DOWN)) { + args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS_DOWN, + {indices_t::type_t::input, idx++}}); + } + if (fusion_info.with_runtime_zero_points(true, DNNL_ARG_WEIGHTS_DOWN)) { + args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS_DOWN, + {indices_t::type_t::input, idx++}}); + } + // outputs args.insert({DNNL_ARG_DST, {indices_t::type_t::output, 0}}); args.insert({DNNL_ARG_SCRATCHPAD, {indices_t::type_t::output, 1}}); + return args; } From 9245b6b83a81fac46645afd9f27ac1782ce9aa09 Mon Sep 17 00:00:00 2001 From: "Lv, Tao A" Date: Wed, 11 Mar 2026 00:04:48 -0700 Subject: [PATCH 14/15] graph: backend: dnnl: 
passes: fuse quantized gated mlp --- src/graph/backend/dnnl/passes/transform.cpp | 69 +++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/src/graph/backend/dnnl/passes/transform.cpp b/src/graph/backend/dnnl/passes/transform.cpp index 5aaf0e4584c..b1f0459f508 100644 --- a/src/graph/backend/dnnl/passes/transform.cpp +++ b/src/graph/backend/dnnl/passes/transform.cpp @@ -4584,6 +4584,10 @@ status_t fuse_sdpa(std::shared_ptr &sg) { return status::success; } +#define DNNL_ARG_WEIGHTS_GATE DNNL_ARG_WEIGHTS_0 +#define DNNL_ARG_WEIGHTS_UP DNNL_ARG_WEIGHTS_1 +#define DNNL_ARG_WEIGHTS_DOWN DNNL_ARG_WEIGHTS_2 + // The pass is called against a gated mlp subgraph matched by the gated mlp // patterns. Hence we have the basic assumptions for the topology which // simplifies the pass logic below. @@ -4656,11 +4660,64 @@ status_t fuse_gated_mlp(std::shared_ptr &sg) { return status::unimplemented; } + fusion_info_t fusion_info; + if (gate->has_attr(op_attr::fusion_info)) { + auto gate_fusion_info + = gate->get_attr(op_attr::fusion_info); + if (gate_fusion_info.get_mutable_scales(true, 1)) { + fusion_info.set_runtime_scales( + gate_fusion_info.get_mutable_scales(true, 1) + ->shared_from_this(), + true, DNNL_ARG_WEIGHTS_GATE); + } + if (gate_fusion_info.with_runtime_zero_points(true, 1)) { + fusion_info.set_zero_points( + gate_fusion_info.get_mutable_zero_points(true, 1) + ->shared_from_this(), + true, DNNL_ARG_WEIGHTS_GATE); + } + } + + if (up->has_attr(op_attr::fusion_info)) { + auto up_fusion_info = up->get_attr(op_attr::fusion_info); + if (up_fusion_info.get_mutable_scales(true, 1)) { + fusion_info.set_runtime_scales( + up_fusion_info.get_mutable_scales(true, 1) + ->shared_from_this(), + true, DNNL_ARG_WEIGHTS_UP); + } + if (up_fusion_info.with_runtime_zero_points(true, 1)) { + fusion_info.set_zero_points( + up_fusion_info.get_mutable_zero_points(true, 1) + ->shared_from_this(), + true, DNNL_ARG_WEIGHTS_UP); + } + } + + if (down->has_attr(op_attr::fusion_info)) { + 
auto down_fusion_info + = down->get_attr(op_attr::fusion_info); + if (down_fusion_info.get_mutable_scales(true, 1)) { + fusion_info.set_runtime_scales( + down_fusion_info.get_mutable_scales(true, 1) + ->shared_from_this(), + true, DNNL_ARG_WEIGHTS_DOWN); + } + if (down_fusion_info.with_runtime_zero_points(true, 1)) { + fusion_info.set_zero_points( + down_fusion_info.get_mutable_zero_points(true, 1) + ->shared_from_this(), + true, DNNL_ARG_WEIGHTS_DOWN); + } + } + subgraph_rewriter_t rewriter(sg); op_ptr gated_mlp_op = std::make_shared(op_kind::_gated_mlp); gated_mlp_op->set_attr( op_attr::alg_kind, static_cast(act_algo)); + gated_mlp_op->set_attr(op_attr::fusion_info, fusion_info); + // connect inputs and outputs auto src_val = gate->get_input_value(0); auto wei0_val = gate->get_input_value(1); @@ -4674,6 +4731,18 @@ status_t fuse_gated_mlp(std::shared_ptr &sg) { gated_mlp_op->connect_input(1, wei0_val); gated_mlp_op->connect_input(2, wei1_val); gated_mlp_op->connect_input(3, wei2_val); + + size_t input_idx = 4; + // Handle quantization parameters from matmuls + for (const auto &matmul : {gate, up, down}) { + auto inputs = matmul->get_input_values(); + for (size_t idx = 2; idx < inputs.size(); ++idx) { + const auto &qparam_val = inputs[idx]; + qparam_val->remove_consumer(*matmul, idx); + gated_mlp_op->connect_input(input_idx++, qparam_val); + } + } + auto dst_val = down->get_output_value(0); dst_val->set_producer(*gated_mlp_op); gated_mlp_op->add_output(dst_val); From 6c4449028948e742d3ebcf1d2697fe05f69a88ad Mon Sep 17 00:00:00 2001 From: "Lv, Tao A" Date: Wed, 11 Mar 2026 00:09:11 -0700 Subject: [PATCH 15/15] graph: backend: dnnl: patterns: enable quantized gated mlp kernels --- src/graph/backend/dnnl/patterns/mlp.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/graph/backend/dnnl/patterns/mlp.cpp b/src/graph/backend/dnnl/patterns/mlp.cpp index 3c95e5dea04..784fa9a1b8d 100644 --- a/src/graph/backend/dnnl/patterns/mlp.cpp +++ 
b/src/graph/backend/dnnl/patterns/mlp.cpp @@ -15,7 +15,6 @@ *******************************************************************************/ #include "graph/backend/dnnl/kernels/gated_mlp.hpp" -#include "graph/backend/dnnl/kernels/large_partition.hpp" #include "graph/backend/dnnl/patterns/fusions.hpp" #include "graph/backend/dnnl/patterns/pattern_matcher_pass.hpp" @@ -87,7 +86,7 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, gated_mlp) in_edges_t {in_edge(0, pre_tc, 0)}); }) .set_attr("FCreateKernel", []() -> kernel_ptr { - return std::make_shared(); + return std::make_shared>(); }); // gated mlp with swish decomposed to sigmoid and multiply. @@ -131,7 +130,7 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, gated_mlp_v1) in_edges_t {in_edge(0, pre_tc, 0)}); }) .set_attr("FCreateKernel", []() -> kernel_ptr { - return std::make_shared(); + return std::make_shared>(); }); /* @@ -195,7 +194,7 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, quantized_gated_mlp) pgraph->append_op(graph::op_kind::MatMul, fc_down_edges); }) .set_attr("FCreateKernel", []() -> kernel_ptr { - return std::make_shared(); + return std::make_shared>(); }); // quantized gated mlp with swish decomposed to sigmoid and multiply. @@ -246,7 +245,7 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, quantized_gated_mlp_v1) pgraph->append_op(graph::op_kind::MatMul, fc_down_edges); }) .set_attr("FCreateKernel", []() -> kernel_ptr { - return std::make_shared(); + return std::make_shared>(); }); DNNL_BACKEND_REGISTER_PATTERN_DEF_END