11 changes: 7 additions & 4 deletions doc/graph/fusion_patterns/gated_mlp.md
@@ -80,10 +80,13 @@ optional.
 
 ## Data Types
 
-oneDNN supports the floating-point Gated-MLP pattern with data types f32, bf16,
-and f16. You can specify the data type via the input and output data type fields
-of logical tensors for each operation. oneDNN does not support mixing different
-floating data types in a floating-point Gated-MLP pattern.
+oneDNN supports the floating-point Gated-MLP pattern with data types `f32`, `bf16`,
+and `f16`. You can specify the data type via the input and output data type fields
+of logical tensors for each operation. For `bf16` and `f16` Gated-MLP, the output
+data types of the UP and Gate MatMuls need to be `f32` to preserve the accuracy of
+intermediate results. A [TypeCast](@ref dev_guide_op_typecast) operation is
+needed before the Down MatMul to downconvert the intermediate results from `f32`
+to `bf16` or `f16`.
 
 The definition of the data types and support status on different CPU and GPU
 platforms follow the general description in @ref dev_guide_data_types.
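In Graph API terms, the documented requirement boils down to keeping the Gate/Up branch in `f32` and casting back only when the graph itself runs in `bf16` or `f16`. The sketch below condenses the change applied to examples/graph/gated_mlp.cpp further down in this diff; the running `id` counter, the `hd_sz`/`wei2_sz`/`out_sz` dimensions, and the `out4` tensor produced by the final elementwise multiplication are assumed from that example rather than defined here.

```cpp
using namespace dnnl::graph;
using data_type = logical_tensor::data_type;
using layout_type = logical_tensor::layout_type;

// Intermediate results of the Gate/Up branch are kept in f32 for accuracy;
// `out4` (the result of the elementwise multiplication) already carries it.
const data_type dt_inter = data_type::f32;

// Downconvert f32 -> bf16/f16 only when the graph data type differs from f32;
// for a pure f32 graph the TypeCast op is simply never added to the graph.
auto out4_dt = out4;
auto typecast = op(id++, op::kind::TypeCast, "typecast");
if (dt != dt_inter) {
    out4_dt = logical_tensor(id++, dt, hd_sz, layout_type::strided);
    typecast.add_inputs({out4});
    typecast.add_outputs({out4_dt});
}

// The Down MatMul consumes the downconverted tensor (or the f32 one directly).
auto wei2 = logical_tensor(id++, dt, wei2_sz, layout_type::strided);
auto dst = logical_tensor(id++, dt, out_sz, layout_type::strided);
auto fc_down = op(id++, op::kind::MatMul, "fc_down");
fc_down.add_inputs({out4_dt, wei2});
fc_down.add_outputs({dst});
```

When the graph is assembled, the op is added under the same condition (`if (dt != dt_inter) mlp.add_op(typecast);`), so one code path serves the f32, bf16, and f16 configurations, as the example diffs below show.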
25 changes: 19 additions & 6 deletions examples/graph/gated_mlp.cpp
@@ -108,44 +108,56 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt,
     // Incremental IDs used to create logical tensors and operations.
     size_t id = 0;

+    // Intermediate data type
+    const logical_tensor::data_type dt_inter = logical_tensor::data_type::f32;
+
     // fc_gate
     auto src = logical_tensor(id++, dt, src_sz, layout_type::strided);
     auto wei0 = logical_tensor(id++, dt, wei0_sz, layout_type::strided);
-    auto out0 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto out0 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided);
     auto fc_gate = op(id++, op::kind::MatMul, "fc_gate");
     fc_gate.add_inputs({src, wei0});
     fc_gate.add_outputs({out0});

     // fc_up
     auto wei1 = logical_tensor(id++, dt, wei0_sz, layout_type::strided);
-    auto out1 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto out1 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided);
     auto fc_up = op(id++, op::kind::MatMul, "fc_up");
     fc_up.add_inputs({src, wei1});
     fc_up.add_outputs({out1});

     // activation swish: sigmoid
-    auto out2 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto out2 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided);
     auto swi_sig = op(id++, op::kind::Sigmoid, "swish/sigmoid");
     swi_sig.add_inputs({out0});
     swi_sig.add_outputs({out2});

     // activation swish: multiply
-    auto out3 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto out3 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided);
     auto swi_mul = op(id++, op::kind::Multiply, "swish/multiply");
     swi_mul.add_inputs({out0, out2});
     swi_mul.add_outputs({out3});

     // multiplication
-    auto out4 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto out4 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided);
     auto mul = op(id++, op::kind::Multiply, "mul");
     mul.add_inputs({out3, out1});
     mul.add_outputs({out4});

+    // downconversion when needed
+    auto out4_dt = out4;
+    auto typecast = op(id++, op::kind::TypeCast, "typecast");
+    if (dt != dt_inter) {
+        out4_dt = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+        typecast.add_inputs({out4});
+        typecast.add_outputs({out4_dt});
+    }
+
     // fc_down
     auto wei2 = logical_tensor(id++, dt, wei2_sz, layout_type::strided);
     auto dst = logical_tensor(id++, dt, out_sz, layout_type::strided);
     auto fc_down = op(id++, op::kind::MatMul, "fc_down");
-    fc_down.add_inputs({out4, wei2});
+    fc_down.add_inputs({out4_dt, wei2});
     fc_down.add_outputs({dst});

     // Construct a gated mlp graph with engine kind and operations.
@@ -155,6 +167,7 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt,
     mlp.add_op(swi_sig);
     mlp.add_op(swi_mul);
     mlp.add_op(mul);
+    if (dt != dt_inter) mlp.add_op(typecast);
     mlp.add_op(fc_down);
     mlp.finalize();

25 changes: 19 additions & 6 deletions examples/graph/gated_mlp_int4.cpp
@@ -113,6 +113,9 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt,
     // Incremental IDs used to create logical tensors and operations.
     size_t id = 0;

+    // Intermediate data type
+    const logical_tensor::data_type dt_inter = logical_tensor::data_type::f32;
+
     // dequantize for fc_gate weights
     auto wei0_int4 = logical_tensor(
             id++, data_type::u4, wei0_sz, layout_type::strided);
@@ -130,7 +133,7 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt,

     // fc_gate
     auto src = logical_tensor(id++, dt, src_sz, layout_type::strided);
-    auto out0 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto out0 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided);
     auto fc_gate = op(id++, op::kind::MatMul, "fc_gate");
     fc_gate.add_inputs({src, wei0_dt});
     fc_gate.add_outputs({out0});
@@ -151,29 +154,38 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt,
     deq_up.add_outputs({wei1_dt});

     // fc_up
-    auto out1 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto out1 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided);
     auto fc_up = op(id++, op::kind::MatMul, "fc_up");
     fc_up.add_inputs({src, wei1_dt});
     fc_up.add_outputs({out1});

     // activation swish: sigmoid
-    auto out2 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto out2 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided);
     auto swi_sig = op(id++, op::kind::Sigmoid, "swish/sigmoid");
     swi_sig.add_inputs({out0});
     swi_sig.add_outputs({out2});

     // activation swish: multiply
-    auto out3 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto out3 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided);
     auto swi_mul = op(id++, op::kind::Multiply, "swish/multiply");
     swi_mul.add_inputs({out0, out2});
     swi_mul.add_outputs({out3});

     // multiplication
-    auto out4 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto out4 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided);
     auto mul = op(id++, op::kind::Multiply, "mul");
     mul.add_inputs({out3, out1});
     mul.add_outputs({out4});

+    // downconversion when needed
+    auto out4_dt = out4;
+    auto typecast = op(id++, op::kind::TypeCast, "typecast");
+    if (dt != dt_inter) {
+        out4_dt = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+        typecast.add_inputs({out4});
+        typecast.add_outputs({out4_dt});
+    }
+
     // dequantize for fc_down weights
     auto wei2_int4 = logical_tensor(
             id++, data_type::u4, wei2_sz, layout_type::strided);
@@ -192,7 +204,7 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt,
     // fc_down
     auto dst = logical_tensor(id++, dt, out_sz, layout_type::strided);
     auto fc_down = op(id++, op::kind::MatMul, "fc_down");
-    fc_down.add_inputs({out4, wei2_dt});
+    fc_down.add_inputs({out4_dt, wei2_dt});
     fc_down.add_outputs({dst});

     // Construct a gated mlp graph with engine kind and operations.
@@ -205,6 +217,7 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt,
     mlp.add_op(swi_sig);
     mlp.add_op(swi_mul);
     mlp.add_op(mul);
+    if (dt != dt_inter) { mlp.add_op(typecast); }
     mlp.add_op(deq_down);
     mlp.add_op(fc_down);
     mlp.finalize();
25 changes: 19 additions & 6 deletions examples/graph/gated_mlp_wei_combined.cpp
@@ -125,6 +125,9 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt,
     // Incremental IDs used to create logical tensors and operations.
     size_t id = 0;

+    // Intermediate data type
+    const logical_tensor::data_type dt_inter = logical_tensor::data_type::f32;
+
     // This logical tensor is not part of the graph but is used to generate the
     // big chunk of device memory which should be already there in real user
     // application or framework.
@@ -134,41 +137,50 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt,
     // fc_gate: wei0 is non-contiguous now.
     auto src = logical_tensor(id++, dt, src_sz, layout_type::strided);
     auto wei0 = logical_tensor(id++, dt, wei0_sz, combined_wei0_st);
-    auto out0 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto out0 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided);
     auto fc_gate = op(id++, op::kind::MatMul, "fc_gate");
     fc_gate.add_inputs({src, wei0});
     fc_gate.add_outputs({out0});

     // fc_up: wei1 is non-contiguous now.
     auto wei1 = logical_tensor(id++, dt, wei0_sz, combined_wei0_st);
-    auto out1 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto out1 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided);
     auto fc_up = op(id++, op::kind::MatMul, "fc_up");
     fc_up.add_inputs({src, wei1});
     fc_up.add_outputs({out1});

     // activation swish: sigmoid
-    auto out2 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto out2 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided);
     auto swi_sig = op(id++, op::kind::Sigmoid, "swish/sigmoid");
     swi_sig.add_inputs({out0});
     swi_sig.add_outputs({out2});

     // activation swish: multiply
-    auto out3 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto out3 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided);
     auto swi_mul = op(id++, op::kind::Multiply, "swish/multiply");
     swi_mul.add_inputs({out0, out2});
     swi_mul.add_outputs({out3});

     // multiplication
-    auto out4 = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+    auto out4 = logical_tensor(id++, dt_inter, hd_sz, layout_type::strided);
     auto mul = op(id++, op::kind::Multiply, "mul");
     mul.add_inputs({out3, out1});
     mul.add_outputs({out4});

+    // downconversion when needed
+    auto out4_dt = out4;
+    auto typecast = op(id++, op::kind::TypeCast, "typecast");
+    if (dt != dt_inter) {
+        out4_dt = logical_tensor(id++, dt, hd_sz, layout_type::strided);
+        typecast.add_inputs({out4});
+        typecast.add_outputs({out4_dt});
+    }
+
     // fc_down
     auto wei2 = logical_tensor(id++, dt, wei2_sz, layout_type::strided);
     auto dst = logical_tensor(id++, dt, out_sz, layout_type::strided);
     auto fc_down = op(id++, op::kind::MatMul, "fc_down");
-    fc_down.add_inputs({out4, wei2});
+    fc_down.add_inputs({out4_dt, wei2});
     fc_down.add_outputs({dst});

     // Construct a gated mlp graph with engine kind and operations.
@@ -178,6 +190,7 @@ void bench_gated_mlp(engine::kind ekind, logical_tensor::data_type dt,
     mlp.add_op(swi_sig);
     mlp.add_op(swi_mul);
     mlp.add_op(mul);
+    if (dt != dt_inter) mlp.add_op(typecast);
     mlp.add_op(fc_down);
     mlp.finalize();
