From cbe271f456f9944883c6b3107678c5204743b442 Mon Sep 17 00:00:00 2001 From: syurkevi Date: Tue, 24 Feb 2026 13:49:40 -0800 Subject: [PATCH 01/23] common: sdpa: adds backwards sdpa primitives --- src/common/memory_tracking.hpp | 4 + src/common/primitive_hashing.cpp | 6 + src/common/primitive_serialization.cpp | 6 + src/common/sdpa_pd.hpp | 295 +++++++++++++++++++------ src/common/sdpa_types.hpp | 15 ++ src/gpu/intel/sdpa/ref.hpp | 4 +- 6 files changed, 259 insertions(+), 71 deletions(-) diff --git a/src/common/memory_tracking.hpp b/src/common/memory_tracking.hpp index 32309b48ca4..54e199e2740 100644 --- a/src/common/memory_tracking.hpp +++ b/src/common/memory_tracking.hpp @@ -325,6 +325,10 @@ enum { key_rnn_ptrs_wei_iter, key_rnn_ptrs_wei_projection, key_shuffle_precompute_transpose, + key_sdpa_Di, + key_sdpa_dQ_reduction, + key_sdpa_dK_reduction, + key_sdpa_dV_reduction, key_softmax_dst_scales, key_softmax_reduction, key_softmax_interim_store, diff --git a/src/common/primitive_hashing.cpp b/src/common/primitive_hashing.cpp index ac15ebc3e31..38ba0367414 100644 --- a/src/common/primitive_hashing.cpp +++ b/src/common/primitive_hashing.cpp @@ -735,6 +735,7 @@ size_t get_desc_hash(const sdpa_desc_t &desc) { size_t seed = 0; // Kinds seed = hash_combine(seed, static_cast(desc.primitive_kind)); + seed = hash_combine(seed, static_cast(desc.prop_kind)); // Memory descriptors seed = hash_combine(seed, get_md_hash(desc.q_desc)); seed = hash_combine(seed, get_md_hash(desc.k_desc)); @@ -743,7 +744,12 @@ size_t get_desc_hash(const sdpa_desc_t &desc) { seed = hash_combine(seed, desc.kq_zero_points.get_hash()); seed = hash_combine(seed, desc.vs_scales.get_hash()); seed = hash_combine(seed, desc.vs_zero_points.get_hash()); + seed = hash_combine(seed, get_md_hash(desc.dS_desc)); seed = hash_combine(seed, get_md_hash(desc.dst_desc)); + seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc)); + seed = hash_combine(seed, get_md_hash(desc.diff_q_desc)); + seed = hash_combine(seed, 
get_md_hash(desc.diff_k_desc)); + seed = hash_combine(seed, get_md_hash(desc.diff_v_desc)); seed = hash_combine(seed, get_md_hash(desc.attn_mask_desc)); seed = hash_combine(seed, get_md_hash(desc.scale_desc)); // Scale type diff --git a/src/common/primitive_serialization.cpp b/src/common/primitive_serialization.cpp index 1778c0af743..fe3ee400767 100644 --- a/src/common/primitive_serialization.cpp +++ b/src/common/primitive_serialization.cpp @@ -581,6 +581,7 @@ void serialize(serialization_stream_t &sstream, const sum_desc_t &desc) { void serialize(serialization_stream_t &sstream, const sdpa_desc_t &desc) { // Kind sstream.append(desc.primitive_kind); + sstream.append(desc.prop_kind); serialize(sstream, desc.q_desc); serialize(sstream, desc.k_desc); serialize(sstream, desc.v_desc); @@ -588,7 +589,12 @@ void serialize(serialization_stream_t &sstream, const sdpa_desc_t &desc) { desc.kq_zero_points.serialize(sstream); desc.vs_scales.serialize(sstream); desc.vs_zero_points.serialize(sstream); + serialize(sstream, desc.dS_desc); serialize(sstream, desc.dst_desc); + serialize(sstream, desc.diff_dst_desc); + serialize(sstream, desc.diff_q_desc); + serialize(sstream, desc.diff_k_desc); + serialize(sstream, desc.diff_v_desc); serialize(sstream, desc.attn_mask_desc); serialize(sstream, desc.scale_desc); sstream.append(desc.kq_acc_dt); diff --git a/src/common/sdpa_pd.hpp b/src/common/sdpa_pd.hpp index dbccef487c1..9eef733b0e6 100644 --- a/src/common/sdpa_pd.hpp +++ b/src/common/sdpa_pd.hpp @@ -37,71 +37,29 @@ namespace impl { this->info(engine), ##__VA_ARGS__) // NOLINTBEGIN(google-default-arguments) + +struct sdpa_fwd_pd_t; + struct sdpa_pd_t : public primitive_desc_t { static constexpr auto base_pkind = primitive_kind::sdpa; + static constexpr int mask_mb_index = 0; + static constexpr int mask_q_index = 2; + static constexpr int mask_k_index = 3; + static constexpr int ndims = 4; + using base_class = sdpa_pd_t; - using hint_class = sdpa_pd_t; + using hint_class = 
sdpa_fwd_pd_t; const sdpa_desc_t *desc() const { return &desc_; } const op_desc_t *op_desc() const override { return reinterpret_cast(this->desc()); } - arg_usage_t arg_usage(int arg) const override { - // TODO: this is broken for cases when the user passes quantization - // memories unconditionally but the primitive desc is not set up for - // quantization. - if (utils::one_of(arg, DNNL_ARG_QUERIES, DNNL_ARG_KEYS, DNNL_ARG_VALUES, - DNNL_ARG_ATTN_MASK, DNNL_ARG_SCALE, - DNNL_ARG_ATTR_SCALES | DNNL_ARG_KEYS, - DNNL_ARG_ATTR_SCALES | DNNL_ARG_VALUES, - DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_KEYS, - DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_VALUES)) - return arg_usage_t::input; - - if (arg == DNNL_ARG_DST) return arg_usage_t::output; - - return primitive_desc_t::arg_usage(arg); - } - - const memory_desc_t *arg_md( - int arg, bool user_input = false) const override { - switch (arg) { - case DNNL_ARG_QUERIES: return src_md(0); - case DNNL_ARG_KEYS: return src_md(1); - case DNNL_ARG_VALUES: return src_md(2); - case DNNL_ARG_ATTN_MASK: return src_md(3); - case DNNL_ARG_DST: return dst_md(0, user_input); - default: return primitive_desc_t::arg_md(arg); - } - } - - const memory_desc_t *src_md( - int index = 0, bool user_input = false) const override { - switch (index) { - case 0: return &desc_.q_desc; - case 1: return &desc_.k_desc; - case 2: return &desc_.v_desc; - case 3: return &desc_.attn_mask_desc; - default: return &glob_zero_md; - } - } - const memory_desc_t *dst_md( - int index = 0, bool user_input = false) const override { - return index == 0 ? 
&desc_.dst_desc : &glob_zero_md; - } - - const memory_desc_t *qry_md() const { return &desc_.q_desc; } - const memory_desc_t *key_md() const { return &desc_.k_desc; } - const memory_desc_t *val_md() const { return &desc_.v_desc; } - const memory_desc_t *attn_mask_md() const { return &desc_.attn_mask_desc; } - const memory_desc_t *scale_md() const { return &desc_.scale_desc; } - - int n_inputs() const override { - return 3 + int(with_attn_mask()) + int(with_attn_scale()); + bool is_fwd() const { + return utils::one_of(desc_.prop_kind, prop_kind::forward_training, + prop_kind::forward_inference); } - int n_outputs() const override { return 1; } bool with_attn_scale() const { return (scale_md()->data_type != data_type::undef); @@ -115,6 +73,10 @@ struct sdpa_pd_t : public primitive_desc_t { return (attn_mask_md()->data_type != data_type::undef); } + bool with_dS() const { + return (desc_.dS_desc.data_type != data_type::undef); + } + /// Returns the accumulation data type of the KQ matmul data_type_t kq_acc_dt() const { return desc()->kq_acc_dt; } @@ -189,13 +151,31 @@ struct sdpa_pd_t : public primitive_desc_t { return out; } + const memory_desc_t *qry_md() const { return &desc_.q_desc; } + const memory_desc_t *key_md() const { return &desc_.k_desc; } + const memory_desc_t *val_md() const { return &desc_.v_desc; } + const memory_desc_t *attn_mask_md() const { return &desc_.attn_mask_desc; } + const memory_desc_t *scale_md() const { return &desc_.scale_desc; } + protected: sdpa_desc_t desc_; + const sdpa_fwd_pd_t *hint_fwd_pd_; + + memory_desc_t ws_md_; sdpa_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, const hint_class *hint_fwd_pd) : primitive_desc_t(attr, base_pkind) - , desc_(*op_desc_t::to_desc(adesc)) {} + , desc_(*op_desc_t::to_desc(adesc)) + , hint_fwd_pd_(hint_fwd_pd) {} + + void init_default_ws() { + dims_t d; + d[0] = desc()->batch_size() + * desc()->queries(); // (logsumexp) per query + + memory_desc_init_by_tag(ws_md_, 1, d, data_type::f32, 
format_tag::a); + } bool set_default_format(memory_desc_t *md) { memory_desc_wrapper mdw(md); @@ -204,20 +184,6 @@ struct sdpa_pd_t : public primitive_desc_t { return true; } - bool set_default_formats() { - bool ok = true; - - for (auto md : {&desc_.q_desc, &desc_.k_desc, &desc_.v_desc, - &desc_.dst_desc}) { - ok = ok && set_default_format(md); - } - - auto status = attr_.post_ops_.set_default_formats(&desc_.dst_desc); - ok = ok && (status == status::success); - - return ok; - } - private: static int group_size( const quant_entry_t &scales, const memory_desc_t &desc) { @@ -239,6 +205,197 @@ struct sdpa_pd_t : public primitive_desc_t { return static_cast(out); } }; + +struct sdpa_fwd_pd_t : public sdpa_pd_t { + using base_class = sdpa_fwd_pd_t; + using hint_class = sdpa_fwd_pd_t; + + arg_usage_t arg_usage(int arg) const override { + // TODO: this is broken for cases when the user passes quantization + // memories unconditionally but the primitive desc is not set up for + // quantization. + if (utils::one_of(arg, DNNL_ARG_QUERIES, DNNL_ARG_KEYS, DNNL_ARG_VALUES, + DNNL_ARG_ATTN_MASK, DNNL_ARG_SCALE, + DNNL_ARG_ATTR_SCALES | DNNL_ARG_KEYS, + DNNL_ARG_ATTR_SCALES | DNNL_ARG_VALUES, + DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_KEYS, + DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_VALUES)) + return arg_usage_t::input; + + if (arg == DNNL_ARG_DST) return arg_usage_t::output; + + if (arg == DNNL_ARG_WORKSPACE) + return !types::is_zero_md(workspace_md()) ? 
arg_usage_t::output + : arg_usage_t::unused; + + return primitive_desc_t::arg_usage(arg); + } + + const memory_desc_t *arg_md( + int arg, bool user_input = false) const override { + switch (arg) { + case DNNL_ARG_QUERIES: return src_md(0); + case DNNL_ARG_KEYS: return src_md(1); + case DNNL_ARG_VALUES: return src_md(2); + case DNNL_ARG_ATTN_MASK: return src_md(3); + case DNNL_ARG_DST: return dst_md(0, user_input); + default: return primitive_desc_t::arg_md(arg); + } + } + + const memory_desc_t *src_md( + int index = 0, bool user_input = false) const override { + switch (index) { + case 0: return &desc_.q_desc; + case 1: return &desc_.k_desc; + case 2: return &desc_.v_desc; + case 3: return &desc_.attn_mask_desc; + default: return &glob_zero_md; + } + } + const memory_desc_t *dst_md( + int index = 0, bool user_input = false) const override { + return index == 0 ? &desc_.dst_desc : &glob_zero_md; + } + const memory_desc_t *workspace_md(int index = 0) const override { + return index == 0 && !types::is_zero_md(&ws_md_) ? 
&ws_md_ + : &glob_zero_md; + } + + int n_inputs() const override { + return 3 + int(with_attn_mask()) + int(with_attn_scale()); + } + int n_outputs() const override { + return 1 + (!types::is_zero_md(workspace_md())); + } + +protected: + sdpa_fwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, + const hint_class *hint_fwd_pd) + : sdpa_pd_t(adesc, attr, hint_fwd_pd) {} + + bool set_default_formats() { + bool ok = true; + + for (auto md : {&desc_.q_desc, &desc_.k_desc, &desc_.v_desc, + &desc_.dst_desc}) { + ok = ok && set_default_format(md); + } + + auto status = attr_.post_ops_.set_default_formats(&desc_.dst_desc); + ok = ok && (status == status::success); + + return ok; + } +}; + +struct sdpa_bwd_pd_t : public sdpa_pd_t { + using base_class = sdpa_bwd_pd_t; + using hint_class = sdpa_fwd_pd_t; + + arg_usage_t arg_usage(int arg) const override { + if (utils::one_of(arg, DNNL_ARG_QUERIES, DNNL_ARG_KEYS, DNNL_ARG_VALUES, + DNNL_ARG_DST, DNNL_ARG_DIFF_DST, DNNL_ARG_ATTN_MASK, + DNNL_ARG_SCALE)) + return arg_usage_t::input; + + if (utils::one_of(arg, DNNL_ARG_ATTR_SCALES | DNNL_ARG_KEYS, + DNNL_ARG_ATTR_SCALES | DNNL_ARG_VALUES, + DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_KEYS, + DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_VALUES)) + return arg_usage_t::unused; + + if (utils::one_of(arg, DNNL_ARG_DIFF_QUERIES, DNNL_ARG_DIFF_KEYS, + DNNL_ARG_DIFF_VALUES)) + return arg_usage_t::output; + + if (arg == DNNL_ARG_DS) + return with_dS() ? arg_usage_t::output : arg_usage_t::unused; + + if (arg == DNNL_ARG_WORKSPACE) + return !types::is_zero_md(workspace_md()) ? 
arg_usage_t::input + : arg_usage_t::unused; + + return primitive_desc_t::arg_usage(arg); + } + + const memory_desc_t *arg_md( + int arg, bool user_input = false) const override { + switch (arg) { + case DNNL_ARG_QUERIES: return src_md(0); + case DNNL_ARG_KEYS: return src_md(1); + case DNNL_ARG_VALUES: return src_md(2); + case DNNL_ARG_ATTN_MASK: return src_md(3); + case DNNL_ARG_DST: return src_md(4); + case DNNL_ARG_DIFF_DST: return src_md(5); + case DNNL_ARG_DIFF_QUERIES: return dst_md(0, user_input); + case DNNL_ARG_DIFF_KEYS: return dst_md(1, user_input); + case DNNL_ARG_DIFF_VALUES: return dst_md(2, user_input); + case DNNL_ARG_DS: return dst_md(3, user_input); + default: return primitive_desc_t::arg_md(arg); + } + } + + const memory_desc_t *src_md( + int index = 0, bool user_input = false) const override { + switch (index) { + case 0: return &desc_.q_desc; + case 1: return &desc_.k_desc; + case 2: return &desc_.v_desc; + case 3: return &desc_.attn_mask_desc; + case 4: return &desc_.dst_desc; + case 5: return &desc_.diff_dst_desc; + default: return &glob_zero_md; + } + } + const memory_desc_t *dst_md( + int index = 0, bool user_input = false) const override { + switch (index) { + case 0: return &desc_.diff_q_desc; + case 1: return &desc_.diff_k_desc; + case 2: return &desc_.diff_v_desc; + case 3: return &desc_.dS_desc; + default: return &glob_zero_md; + } + } + const memory_desc_t *workspace_md(int index = 0) const override { + return index == 0 && !types::is_zero_md(&ws_md_) ? 
&ws_md_ + : &glob_zero_md; + } + + int n_inputs() const override { + // Q, K, V, O, dO + return 5 + int(with_attn_mask()) + int(with_attn_scale()) + + int(!types::is_zero_md(workspace_md())); + } + int n_outputs() const override { return 3 + int(with_dS()); } + + const memory_desc_t *diff_qry_md() const { return &desc_.diff_q_desc; } + const memory_desc_t *diff_key_md() const { return &desc_.diff_k_desc; } + const memory_desc_t *diff_val_md() const { return &desc_.diff_v_desc; } + const memory_desc_t *diff_dst_md() const { return &desc_.diff_dst_desc; } + +protected: + sdpa_bwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, + const hint_class *hint_fwd_pd) + : sdpa_pd_t(adesc, attr, hint_fwd_pd) {} + + bool set_default_formats() { + bool ok = true; + + for (auto md : {&desc_.q_desc, &desc_.k_desc, &desc_.v_desc, + &desc_.dst_desc, &desc_.diff_dst_desc, &desc_.diff_q_desc, + &desc_.diff_k_desc, &desc_.diff_v_desc}) { + ok = ok && set_default_format(md); + } + + auto status = attr_.post_ops_.set_default_formats(&desc_.dst_desc); + ok = ok && (status == status::success); + + return ok; + } +}; + // NOLINTEND(google-default-arguments) } // namespace impl diff --git a/src/common/sdpa_types.hpp b/src/common/sdpa_types.hpp index 7e836812f33..95898979ad6 100644 --- a/src/common/sdpa_types.hpp +++ b/src/common/sdpa_types.hpp @@ -33,6 +33,11 @@ namespace impl { #define DNNL_ARG_VALUES DNNL_ARG_SRC_2 #define DNNL_ARG_ATTN_MASK DNNL_ARG_SHIFT +#define DNNL_ARG_DIFF_QUERIES DNNL_ARG_DIFF_SRC_0 +#define DNNL_ARG_DIFF_KEYS DNNL_ARG_DIFF_SRC_1 +#define DNNL_ARG_DIFF_VALUES DNNL_ARG_DIFF_SRC_2 +#define DNNL_ARG_DS DNNL_ARG_DIFF_SRC_3 + // NOLINTBEGIN(modernize-use-using) /// Types of attention mask typedef enum { @@ -66,6 +71,8 @@ struct sdpa_desc_t : public op_desc_t { return utils::make_unique(*this); } + prop_kind_t prop_kind {}; + memory_desc_t q_desc; /* queries */ memory_desc_t k_desc; /* keys */ memory_desc_t v_desc; /* values */ @@ -77,7 +84,13 @@ struct 
sdpa_desc_t : public op_desc_t { quant_entry_t vs_scales; quant_entry_t vs_zero_points; + memory_desc_t dS_desc; + memory_desc_t dst_desc; + memory_desc_t diff_dst_desc; + memory_desc_t diff_q_desc; + memory_desc_t diff_k_desc; + memory_desc_t diff_v_desc; memory_desc_t attn_mask_desc; memory_desc_t scale_desc; data_type_t kq_acc_dt {}; @@ -98,6 +111,8 @@ struct sdpa_desc_t : public op_desc_t { dnnl_dim_t keys() const { return k_desc.dims[k_desc.ndims - 1]; } // Number of values. dnnl_dim_t values() const { return v_desc.dims[v_desc.ndims - 1]; } + dim_t num_q_heads() const { return q_desc.dims[1]; } + dim_t num_kv_heads() const { return kv_head_number; } // Total batch size. dnnl_dim_t batch_size() const { dnnl_dim_t batch = 1; diff --git a/src/gpu/intel/sdpa/ref.hpp b/src/gpu/intel/sdpa/ref.hpp index 07f1705e1ba..3856dc2dd24 100644 --- a/src/gpu/intel/sdpa/ref.hpp +++ b/src/gpu/intel/sdpa/ref.hpp @@ -29,8 +29,8 @@ namespace sdpa { struct ref_t : public primitive_t { using primitive_t::primitive_t; - struct pd_t : public sdpa::pd_t { - using sdpa::pd_t::pd_t; + struct pd_t : public sdpa_fwd_pd_t { + using sdpa_fwd_pd_t::sdpa_fwd_pd_t; DECLARE_COMMON_PD_T("ocl:ref:any", ref_t); From fa301291afa27973f619e0b77168d040e739af24 Mon Sep 17 00:00:00 2001 From: syurkevi Date: Tue, 24 Feb 2026 13:56:20 -0800 Subject: [PATCH 02/23] common: sdpa: prepares internal primitive for bwd testing --- src/common/sdpa_test_iface.cpp | 34 +++++++++- src/common/sdpa_utils.hpp | 42 +++++++++++- tests/gtests/internals/sdpa_internal.hpp | 84 +++++++++++++++++++++--- 3 files changed, 146 insertions(+), 14 deletions(-) diff --git a/src/common/sdpa_test_iface.cpp b/src/common/sdpa_test_iface.cpp index e04b0dcd5ff..759d4066804 100644 --- a/src/common/sdpa_test_iface.cpp +++ b/src/common/sdpa_test_iface.cpp @@ -32,7 +32,7 @@ dnnl_status_t DNNL_API sdpa_primitive_desc_create( bool invert_scale, dnnl_dim_t kv_head_number, int attn_mask_type, dnnl_alg_kind_t softmax_alg, 
const_dnnl_primitive_attr_t attr, const_dnnl_primitive_attr_t kq_attr, - const_dnnl_primitive_attr_t vs_attr) { + const_dnnl_primitive_attr_t vs_attr, prop_kind_t prop) { CHECK(sdpa_desc_check(query_desc, key_desc, value_desc, dst_desc, mask_desc, engine, attr, kq_attr, vs_attr)); CHECK(sdpa_attr_check( @@ -41,7 +41,37 @@ dnnl_status_t DNNL_API sdpa_primitive_desc_create( dnnl::impl::sdpa_desc_t sdpa_desc = dnnl::impl::create_sdpa_desc(query_desc, key_desc, value_desc, dst_desc, mask_desc, scale_desc, invert_scale, kv_head_number, static_cast(attn_mask_type), - softmax_alg, kq_attr, vs_attr); + softmax_alg, kq_attr, vs_attr, prop); return dnnl::impl::primitive_desc_create(primitive_desc_iface, engine, (const dnnl::impl::op_desc_t *)&sdpa_desc, nullptr, attr); } + +dnnl_status_t DNNL_API sdpa_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc_iface, dnnl_engine_t engine, + const_dnnl_memory_desc_t query_desc, const_dnnl_memory_desc_t key_desc, + const_dnnl_memory_desc_t value_desc, const_dnnl_memory_desc_t dst_desc, + const_dnnl_memory_desc_t diff_query_desc, + const_dnnl_memory_desc_t diff_key_desc, + const_dnnl_memory_desc_t diff_value_desc, + const_dnnl_memory_desc_t diff_dst_desc, + const_dnnl_memory_desc_t dS_desc, const_dnnl_memory_desc_t mask_desc, + const_dnnl_memory_desc_t scale_desc, bool invert_scale, + dnnl_dim_t kv_head_number, int attn_mask_type, + dnnl_alg_kind_t softmax_alg, const_dnnl_primitive_attr_t attr, + const_dnnl_primitive_attr_t kq_attr, + const_dnnl_primitive_attr_t vs_attr, + const_dnnl_primitive_desc_t hint_fwd_pd = nullptr) { + CHECK(sdpa_desc_check(query_desc, key_desc, value_desc, dst_desc, mask_desc, + engine, attr, kq_attr, vs_attr)); + CHECK(sdpa_attr_check( + query_desc, key_desc, value_desc, engine, attr, kq_attr, vs_attr)); + + dnnl::impl::sdpa_desc_t sdpa_desc = dnnl::impl::create_sdpa_desc(query_desc, + key_desc, value_desc, dst_desc, diff_query_desc, diff_key_desc, + diff_value_desc, diff_dst_desc, dS_desc, 
mask_desc, scale_desc, + invert_scale, kv_head_number, + static_cast(attn_mask_type), softmax_alg, kq_attr, + vs_attr); + return dnnl::impl::primitive_desc_create(primitive_desc_iface, engine, + (const dnnl::impl::op_desc_t *)&sdpa_desc, hint_fwd_pd, attr); +} diff --git a/src/common/sdpa_utils.hpp b/src/common/sdpa_utils.hpp index 3354c09b943..8cc015288d5 100644 --- a/src/common/sdpa_utils.hpp +++ b/src/common/sdpa_utils.hpp @@ -153,7 +153,8 @@ static inline sdpa_desc_t create_sdpa_desc(const memory_desc_t *q_md, const memory_desc_t *dst_md, const memory_desc_t *attn_mask_md, const memory_desc_t *scale_md, bool invert_scale, dim_t kv_head_number, attn_mask_type_t attn_mask_type, alg_kind_t softmax_alg, - const primitive_attr_t *kq_attr, const primitive_attr_t *vs_attr) { + const primitive_attr_t *kq_attr, const primitive_attr_t *vs_attr, + prop_kind_t prop) { auto sdpa_desc = sdpa_desc_t(); sdpa_desc.primitive_kind = primitive_kind::sdpa; sdpa_desc.q_desc = *q_md; @@ -182,6 +183,40 @@ static inline sdpa_desc_t create_sdpa_desc(const memory_desc_t *q_md, sdpa_desc.kv_head_number = kv_head_number; sdpa_desc.mask_type = attn_mask_type; sdpa_desc.softmax_alg = softmax_alg; + sdpa_desc.prop_kind = prop; + return sdpa_desc; +} + +static inline sdpa_desc_t create_sdpa_desc(const memory_desc_t *q_md, + const memory_desc_t *k_md, const memory_desc_t *v_md, + const memory_desc_t *dst_md, const memory_desc_t *diff_q_md, + const memory_desc_t *diff_k_md, const memory_desc_t *diff_v_md, + const memory_desc_t *diff_dst_md, const memory_desc_t *dS_md, + const memory_desc_t *attn_mask_md, const memory_desc_t *scale_md, + bool invert_scale, dim_t kv_head_number, + attn_mask_type_t attn_mask_type, alg_kind_t softmax_alg, + const primitive_attr_t *kq_attr, const primitive_attr_t *vs_attr) { + auto sdpa_desc = sdpa_desc_t(); + sdpa_desc.primitive_kind = primitive_kind::sdpa; + sdpa_desc.q_desc = *q_md; + sdpa_desc.k_desc = *k_md; + sdpa_desc.v_desc = *v_md; + + sdpa_desc.dst_desc = 
*dst_md; + sdpa_desc.diff_dst_desc = *diff_dst_md; + + if (dS_md) sdpa_desc.dS_desc = *dS_md; + + sdpa_desc.diff_q_desc = *diff_q_md; + sdpa_desc.diff_k_desc = *diff_k_md; + sdpa_desc.diff_v_desc = *diff_v_md; + if (attn_mask_md) sdpa_desc.attn_mask_desc = *attn_mask_md; + sdpa_desc.scale_desc = *scale_md; + sdpa_desc.invert_scale = invert_scale; + sdpa_desc.kv_head_number = kv_head_number; + sdpa_desc.mask_type = attn_mask_type; + sdpa_desc.softmax_alg = softmax_alg; + sdpa_desc.prop_kind = prop_kind::backward; return sdpa_desc; } @@ -193,14 +228,15 @@ static inline status_t create_sdpa_pd( bool invert_scale, dim_t kv_head_number, attn_mask_type_t attn_mask_type, alg_kind_t softmax_alg, const primitive_attr_t *attr, const primitive_attr_t *kq_attr = nullptr, - const primitive_attr_t *vs_attr = nullptr) { + const primitive_attr_t *vs_attr = nullptr, + prop_kind_t prop = prop_kind::forward_inference) { CHECK(sdpa_attr_check(q_md, k_md, v_md, engine, attr, kq_attr, vs_attr)); CHECK(sdpa_desc_check(q_md, k_md, v_md, dst_md, attn_mask_md, engine, attr, kq_attr, vs_attr)); auto sdpa_desc = create_sdpa_desc(q_md, k_md, v_md, dst_md, attn_mask_md, scale_md, invert_scale, kv_head_number, attn_mask_type, softmax_alg, - kq_attr, vs_attr); + kq_attr, vs_attr, prop); primitive_attr_t sdpa_attr = attr ? 
*attr : default_attr(); diff --git a/tests/gtests/internals/sdpa_internal.hpp b/tests/gtests/internals/sdpa_internal.hpp index 8ab35a07884..901eec4a383 100644 --- a/tests/gtests/internals/sdpa_internal.hpp +++ b/tests/gtests/internals/sdpa_internal.hpp @@ -44,7 +44,22 @@ dnnl_status_t DNNL_API sdpa_primitive_desc_create( bool invert_scale, dnnl_dim_t kv_head_number, int attn_mask_type, dnnl_alg_kind_t softmax_alg, const_dnnl_primitive_attr_t attr, const_dnnl_primitive_attr_t kq_attr, - const_dnnl_primitive_attr_t vs_attr); + const_dnnl_primitive_attr_t vs_attr, dnnl_prop_kind_t prop); + +dnnl_status_t DNNL_API sdpa_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc_iface, dnnl_engine_t engine, + const_dnnl_memory_desc_t query_desc, const_dnnl_memory_desc_t key_desc, + const_dnnl_memory_desc_t value_desc, const_dnnl_memory_desc_t dst_desc, + const_dnnl_memory_desc_t diff_query_desc, + const_dnnl_memory_desc_t diff_key_desc, + const_dnnl_memory_desc_t diff_value_desc, + const_dnnl_memory_desc_t diff_dst_desc, + const_dnnl_memory_desc_t dS_desc, const_dnnl_memory_desc_t mask_desc, + dnnl_data_type_t scale_dt, bool invert_scale, dnnl_dim_t kv_head_number, + int attn_mask_type, dnnl_alg_kind_t softmax_alg, + const_dnnl_primitive_attr_t attr, const_dnnl_primitive_attr_t kq_attr, + const_dnnl_primitive_attr_t vs_attr, + const dnnl_primitive_desc *hint_fwd_pd); namespace dnnl { namespace impl { @@ -65,16 +80,17 @@ struct sdpa : public dnnl::primitive { int attn_mask_type, int softmax_alg, const primitive_attr &attr = default_attr(), const primitive_attr &kq_attr = default_attr(), - const primitive_attr &vs_attr = default_attr()) { + const primitive_attr &vs_attr = default_attr(), + prop_kind_t prop_kind = prop_kind::forward_inference) { dnnl_primitive_desc_t pd = nullptr; - dnnl_status_t status - = sdpa_primitive_desc_create(&pd, aengine.get(), - query_desc.get(), key_desc.get(), value_desc.get(), - output_desc.get(), optional_arg(attn_mask_desc), - 
scale_desc.get(), invert_scale, kv_head_number, - attn_mask_type, (dnnl_alg_kind_t)softmax_alg, - attr.get(), kq_attr.get(), vs_attr.get()); + dnnl_status_t status = sdpa_primitive_desc_create(&pd, + aengine.get(), query_desc.get(), key_desc.get(), + value_desc.get(), output_desc.get(), + optional_arg(attn_mask_desc), scale_desc.get(), + invert_scale, kv_head_number, attn_mask_type, + (dnnl_alg_kind_t)softmax_alg, attr.get(), kq_attr.get(), + vs_attr.get(), (prop_kind_t)prop_kind); dnnl::error::wrap_c_api(status, "could not create a primitive descriptor for a sdpa " @@ -90,6 +106,56 @@ struct sdpa : public dnnl::primitive { /// @param pd Primitive descriptor for a sdpa primitive. sdpa(const primitive_desc &pd) : primitive(pd) {} }; + +/// Scaled Dot Product Attention (sdpa) backward propagation internal primitive. +/// Implementing internally for more flexible validation +struct sdpa_backward : public dnnl::primitive { + /// Primitive descriptor for a sdpa_backward primitive. + struct primitive_desc : public dnnl::primitive_desc { + /// Default constructor. Produces an empty object. 
+ primitive_desc() = default; + + primitive_desc(const engine &aengine, const memory::desc &query_desc, + const memory::desc &key_desc, const memory::desc &value_desc, + const memory::desc *attn_mask_desc, memory::data_type scale_dt, + const memory::desc &output_desc, + const memory::desc &diff_query_desc, + const memory::desc &diff_key_desc, + const memory::desc &diff_value_desc, + const memory::desc &diff_output_desc, + const memory::desc *dS_desc, bool invert_scale, + memory::dim kv_head_number, int attn_mask_type, int softmax_alg, + const sdpa::primitive_desc &hint_fwd_pd, + const primitive_attr &attr = default_attr(), + const primitive_attr &kq_attr = default_attr(), + const primitive_attr &vs_attr = default_attr()) { + + dnnl_primitive_desc_t pd = nullptr; + dnnl_status_t status = sdpa_primitive_desc_create(&pd, + aengine.get(), query_desc.get(), key_desc.get(), + value_desc.get(), output_desc.get(), diff_query_desc.get(), + diff_key_desc.get(), diff_value_desc.get(), + diff_output_desc.get(), dS_desc ? dS_desc->get() : nullptr, + optional_arg(attn_mask_desc), (dnnl_data_type_t)scale_dt, + invert_scale, kv_head_number, attn_mask_type, + (dnnl_alg_kind_t)softmax_alg, attr.get(), kq_attr.get(), + vs_attr.get(), hint_fwd_pd.get()); + + dnnl::error::wrap_c_api(status, + "could not create a primitive descriptor for a sdpa " + "primitive"); + reset(pd); + } + }; + + /// Default constructor. Produces an empty object. + sdpa_backward() = default; + + /// Constructs a sdpa primitive. + /// @param pd Primitive descriptor for a sdpa primitive. 
+ sdpa_backward(const primitive_desc &pd) : primitive(pd) {} +}; + } // namespace impl } // namespace dnnl From c6222686ee1af162b25b84836c59be3a8b18d40c Mon Sep 17 00:00:00 2001 From: syurkevi Date: Tue, 24 Feb 2026 14:01:45 -0800 Subject: [PATCH 03/23] gtests: internals: adds sdpa training tests --- tests/gtests/internals/sdpa_internal.hpp | 13 +- tests/gtests/internals/test_sdpa.cpp | 1060 +++++++++++++++++++++- 2 files changed, 1041 insertions(+), 32 deletions(-) diff --git a/tests/gtests/internals/sdpa_internal.hpp b/tests/gtests/internals/sdpa_internal.hpp index 901eec4a383..1f3a23957c9 100644 --- a/tests/gtests/internals/sdpa_internal.hpp +++ b/tests/gtests/internals/sdpa_internal.hpp @@ -55,9 +55,10 @@ dnnl_status_t DNNL_API sdpa_primitive_desc_create( const_dnnl_memory_desc_t diff_value_desc, const_dnnl_memory_desc_t diff_dst_desc, const_dnnl_memory_desc_t dS_desc, const_dnnl_memory_desc_t mask_desc, - dnnl_data_type_t scale_dt, bool invert_scale, dnnl_dim_t kv_head_number, - int attn_mask_type, dnnl_alg_kind_t softmax_alg, - const_dnnl_primitive_attr_t attr, const_dnnl_primitive_attr_t kq_attr, + const_dnnl_memory_desc_t scale_desc, bool invert_scale, + dnnl_dim_t kv_head_number, int attn_mask_type, + dnnl_alg_kind_t softmax_alg, const_dnnl_primitive_attr_t attr, + const_dnnl_primitive_attr_t kq_attr, const_dnnl_primitive_attr_t vs_attr, const dnnl_primitive_desc *hint_fwd_pd); @@ -117,8 +118,8 @@ struct sdpa_backward : public dnnl::primitive { primitive_desc(const engine &aengine, const memory::desc &query_desc, const memory::desc &key_desc, const memory::desc &value_desc, - const memory::desc *attn_mask_desc, memory::data_type scale_dt, - const memory::desc &output_desc, + const memory::desc *attn_mask_desc, + const memory::desc &scale_desc, const memory::desc &output_desc, const memory::desc &diff_query_desc, const memory::desc &diff_key_desc, const memory::desc &diff_value_desc, @@ -136,7 +137,7 @@ struct sdpa_backward : public dnnl::primitive { 
value_desc.get(), output_desc.get(), diff_query_desc.get(), diff_key_desc.get(), diff_value_desc.get(), diff_output_desc.get(), dS_desc ? dS_desc->get() : nullptr, - optional_arg(attn_mask_desc), (dnnl_data_type_t)scale_dt, + optional_arg(attn_mask_desc), scale_desc.get(), invert_scale, kv_head_number, attn_mask_type, (dnnl_alg_kind_t)softmax_alg, attr.get(), kq_attr.get(), vs_attr.get(), hint_fwd_pd.get()); diff --git a/tests/gtests/internals/test_sdpa.cpp b/tests/gtests/internals/test_sdpa.cpp index f786eef95de..d8e491ac27e 100644 --- a/tests/gtests/internals/test_sdpa.cpp +++ b/tests/gtests/internals/test_sdpa.cpp @@ -28,6 +28,8 @@ #include #include +#define DEBUG_PRINT_MEM 0 + using mdt = memory::data_type; using dnnl::accumulation_mode; @@ -267,6 +269,9 @@ struct sdpa_tensors_t { memory m_scale; // tested sdpa arg, can be host-side scalar memory m_scale_prim; // reference (prim) sdpa arg + memory m_diff_query, m_diff_key, m_diff_value, m_diff_output, m_dS; + memory m_diff_query_quantized, m_diff_key_quantized, m_diff_value_quantized; + memory m_key_scales, m_key_zp, m_value_scales, m_value_zp; dnnl::primitive_attr sdpa_attr_quantized, sdpa_kq_attr_quantized, sdpa_vs_attr_quantized; @@ -519,6 +524,7 @@ sdpa_tensors_t get_descriptors(dnnl::engine &eng, dnnl::stream &strm, = {p.mb, p.heads.kv, p.seq_len.kv * 2, p.head_group.head_size}; const memory::dims v_sz = {p.mb, p.heads.kv, p.seq_len.kv, p.head_group.head_size}; + const memory::dims dS_sz = {p.mb, p.heads.q, p.seq_len.q, p.seq_len.kv}; const memory::dims scale_sz = {1, 1, 1, 1}; bool with_host_scale = p.stype == scale_type::host_side; @@ -602,24 +608,45 @@ sdpa_tensors_t get_descriptors(dnnl::engine &eng, dnnl::stream &strm, auto mask_md = memory::desc(mask_sz, p.mask.dt != mdt::undef ? 
p.mask.dt : p.dt.dt, abcd); auto output_md = memory::desc(q_sz, p.dt.dt, abcd); auto output_quantized_md = memory::desc(q_sz, p.dt.dt, abcd); + + auto dS_md = memory::desc(dS_sz, p.dt.dt, abcd); // clang-format on // Create memory objects out.m_query = double_and_resize(query_md, eng, strm, doubled_memory); + out.m_diff_query = double_and_resize(query_md, eng, strm, doubled_memory); + out.m_diff_query_quantized + = double_and_resize(query_md, eng, strm, doubled_memory); + out.m_key_quantized = double_and_resize(key_quantized_md, eng, strm, doubled_memory); + out.m_diff_key + = double_and_resize(key_quantized_md, eng, strm, doubled_memory); + out.m_diff_key_quantized + = double_and_resize(key_quantized_md, eng, strm, doubled_memory); + out.m_key_scales = double_and_resize(key_scales_md, eng, strm, doubled_memory); out.m_key_zp = double_and_resize(key_zp_md, eng, strm, doubled_memory); + out.m_value_quantized = double_and_resize(val_quantized_md, eng, strm, doubled_memory); + out.m_diff_value + = double_and_resize(val_quantized_md, eng, strm, doubled_memory); + out.m_diff_value_quantized + = double_and_resize(val_quantized_md, eng, strm, doubled_memory); + out.m_value_scales = double_and_resize(val_scales_md, eng, strm, doubled_memory); out.m_value_zp = double_and_resize(val_zp_md, eng, strm, doubled_memory); out.m_mask = double_and_resize(mask_md, eng, strm, doubled_memory); + out.m_output = double_and_resize(output_md, eng, strm, doubled_memory); out.m_output_quantized = double_and_resize(output_quantized_md, eng, strm, doubled_memory); + out.m_diff_output = double_and_resize(output_md, eng, strm, doubled_memory); + + out.m_dS = double_and_resize(dS_md, eng, strm, doubled_memory); // Allocate user data. 
std::vector query_data(product(q_sz), 0.f); @@ -628,6 +655,7 @@ sdpa_tensors_t get_descriptors(dnnl::engine &eng, dnnl::stream &strm, std::vector val_quantized_data(product(v_sz), 0); std::vector key_scale_data(product(key_scales_sz), std::nanf("1")); std::vector val_scale_data(product(val_scales_sz), std::nanf("1")); + std::vector diff_output_data(product(q_sz), 0.f); std::vector key_zp_data_signed(product(key_scales_sz), INT_MAX); std::vector val_zp_data_signed(product(val_scales_sz), INT_MAX); @@ -638,6 +666,8 @@ sdpa_tensors_t get_descriptors(dnnl::engine &eng, dnnl::stream &strm, std::vector mask_data(product(mask_sz), NAN); std::vector output_data(product(q_sz), NAN); + std::vector dS_data(product(dS_sz), 0); + out.sdpa_attr_quantized.set_scratchpad_mode(dnnl::scratchpad_mode::library); out.kq_mask = 0; @@ -700,6 +730,7 @@ sdpa_tensors_t get_descriptors(dnnl::engine &eng, dnnl::stream &strm, } fill_random(query_data, query_md); + fill_random(diff_output_data, output_md); fill_random_quantized(key_quantized_data, key_quantized_md, (p.key.dt == mdt::u4 || p.key.dt == mdt::u8)); fill_random_quantized(val_quantized_data, val_quantized_md, @@ -912,6 +943,8 @@ sdpa_tensors_t get_descriptors(dnnl::engine &eng, dnnl::stream &strm, // Write data to tensor object's handle. 
write_to_dnnl_memory(query_data.data(), out.m_query, eng, strm); + write_to_dnnl_memory(diff_output_data.data(), out.m_diff_output, eng, strm); + write_to_dnnl_memory( key_quantized_data.data(), out.m_key_quantized, eng, strm); @@ -962,6 +995,7 @@ sdpa_tensors_t get_descriptors(dnnl::engine &eng, dnnl::stream &strm, setup_device_scale(&out.m_scale); setup_device_scale(&out.m_scale_prim); + write_to_dnnl_memory(dS_data.data(), out.m_dS, eng, strm); return out; } @@ -1309,12 +1343,456 @@ void prim_sdpa_quant(const sdpa_dims_t &p, const sdpa_tensors_t &t, loop(); strm.wait(); + + void *output_ptr_ = (void *)output.map_data(); + void *grouped_output_ptr_ = (void *)grouped_output.map_data(); + memcpy(output_ptr_, grouped_output_ptr_, grouped_query_md.get_size()); + grouped_output.unmap_data(grouped_output_ptr_); + output.unmap_data(output_ptr_); + strm.wait(); +} + +std::vector timeit( + const std::function &func, dnnl::stream &str, int iterations) { + using namespace std::chrono; + func(); + func(); + std::vector times; + for (int j = 0; j < 5; j++) { + auto e = steady_clock::now(); + str.wait(); + auto s = steady_clock::now(); + for (int i = 0; i < iterations; i++) { + func(); + } + str.wait(); + e = steady_clock::now(); + printf("timeit: %f \n", + (float)std::chrono::duration_cast(e - s).count() + / 1e6 / iterations); + times.push_back(std::chrono::duration_cast(e - s)); + } + return times; +} + +std::chrono::nanoseconds prim_sdpa_quant_bwd(const sdpa_dims_t &p, + const sdpa_tensors_t &t, dnnl::engine &eng, dnnl::stream &strm, + dnnl::memory &query, dnnl::memory &key, + dnnl::memory::data_type scale_dt, dnnl::memory &scale, + dnnl::memory &mask, dnnl::memory &value, dnnl::memory &output, + dnnl::memory &diff_output, bool invert_scale, + std::vector &doubled_memory, dnnl::memory &diff_query, + dnnl::memory &diff_key, dnnl::memory &diff_value, + bool with_timing = false) { + + using namespace dnnl; + + primitive_attr bmm1_attr; + 
bmm1_attr.set_scratchpad_mode(dnnl::scratchpad_mode::library); + + post_ops bmm1_po; + auto scale_f32 = as(strm, scale, mdt::f32); + auto mask_f32 = as(strm, mask, mdt::f32); + auto mask_sz = mask.get_desc().get_dims(); + + if (scale_dt != mdt::undef) { + scale_f32 = reshape(strm, scale_f32, + {{1, 1, 1, 1, 1}, mdt::f32, memory::format_tag::abcde}); + if (invert_scale) + bmm1_po.append_binary(algorithm::binary_div, scale_f32.get_desc()); + else + bmm1_po.append_binary(algorithm::binary_mul, scale_f32.get_desc()); + } + if (p.mask.type != mask_type::no_mask) { + mask_f32 = reshape(strm, mask_f32, + {{mask_sz[0], 1, 1, mask_sz[2], mask_sz[3]}, mdt::f32, + memory::format_tag::abcde}); + bmm1_po.append_binary(algorithm::binary_add, mask_f32.get_desc()); + } + bmm1_attr.set_post_ops(bmm1_po); + // Keep a copy of the 5D scale for the backward pass. + // bmm1_args will std::move scale_f32, so save it before that happens. + memory scale_f32_bwd = scale_f32; + + int head_kv_group_size = 0; + int head_q_group_size = 0; + int head_group_batches = 0; + if (p.heads.kv == p.heads.q) { + head_kv_group_size = p.heads.kv; + head_q_group_size = p.heads.q; + head_group_batches = 1; + } else { + head_kv_group_size = 1; + head_q_group_size = p.heads.q / p.heads.kv; + head_group_batches = p.heads.kv; + } + + auto original_k_sz = key.get_desc().get_dims(); + const memory::dims k_sz {p.mb, head_group_batches, head_kv_group_size, + original_k_sz[2], original_k_sz[3]}; + const memory::dims v_sz {p.mb, head_group_batches, head_kv_group_size, + p.seq_len.kv, p.head_group.head_size}; + const memory::dims q_sz {p.mb, head_group_batches, head_q_group_size, + p.seq_len.q, p.head_group.head_size}; + + memory::desc grouped_key_md(k_sz, p.dt.dt, memory::format_tag::abcde); + memory::desc grouped_value_md(v_sz, p.dt.dt, memory::format_tag::abcde); + memory::desc grouped_query_md(q_sz, p.dt.dt, memory::format_tag::abcde); + + memory key_dequantized; + auto keytmp = as(strm, key, p.dt.dt); + 
grouped_key_md = p.key_format_tag == memory::format_tag::abcd + ? memory::desc(k_sz, p.dt.dt, memory::format_tag::abcde) + : memory::desc(k_sz, p.dt.dt, memory::format_tag::abced); + + key_dequantized = reshape(strm, keytmp, grouped_key_md); + + memory value_dequantized; + auto value32 = as(strm, value, p.dt.dt); + value_dequantized = reshape(strm, value32, grouped_value_md); + + memory grouped_query = reshape(strm, query, grouped_query_md); + + const memory::dims score_sz = {p.mb, head_group_batches, head_q_group_size, + p.seq_len.q, p.seq_len.kv}; + memory::desc score_md {score_sz, mdt::f32, memory::format_tag::abcde}; + memory::desc score_f16_md {score_sz, mdt::f16, memory::format_tag::abcde}; + + auto score = memory(score_md, eng); + auto score_f16 = memory(score_f16_md, eng); + auto score2 = memory(score_md, eng); + auto score2_f16 = memory(score_f16_md, eng); + + // matmul primitive for QK + auto bmm1_pd = matmul::primitive_desc(eng, grouped_query_md, + key_dequantized.get_desc(), score_md, bmm1_attr); + auto bmm1_prim = matmul(bmm1_pd); + + // softmax forward + primitive_attr softmax_attr; + softmax_attr.set_scratchpad_mode(scratchpad_mode::library); + auto softmax_fwd_pd = softmax_forward::primitive_desc(eng, + prop_kind::forward_training, + (algorithm)dnnl::impl::alg_kind::softmax_accurate_inf_as_zero, + score.get_desc(), score.get_desc(), 4, softmax_attr); + auto softmax_prim = softmax_forward(softmax_fwd_pd); + + // attention_output = attention_probs x value + primitive_attr bmm2_attr; + + bmm2_attr.set_scratchpad_mode(scratchpad_mode::library); + auto grouped_output + = double_and_resize(grouped_query_md, eng, strm, doubled_memory); + auto bmm2_pd = matmul::primitive_desc( + eng, score_md, grouped_value_md, grouped_query_md, bmm2_attr); + auto bmm2_prim = matmul(bmm2_pd); + + // setup args + std::unordered_map bmm1_args = {{DNNL_ARG_SRC, grouped_query}, + {DNNL_ARG_WEIGHTS, key_dequantized}, {DNNL_ARG_DST, score}}; + + if (scale_dt != mdt::undef) { + 
bmm1_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1] + = std::move(scale_f32); + if (p.mask.type != mask_type::no_mask) { + bmm1_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1] + = std::move(mask_f32); + } + } else { + if (p.mask.type != mask_type::no_mask) { + bmm1_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1] + = std::move(mask_f32); + } + } + + std::unordered_map bmm2_args + = {{DNNL_ARG_SRC, score2}, {DNNL_ARG_WEIGHTS, value_dequantized}, + {DNNL_ARG_DST, grouped_output}}; + + const auto fwd_loop = [&]() { + // QK + bmm1_prim.execute(strm, bmm1_args); + + // softmax + softmax_prim.execute(strm, + { + {DNNL_ARG_SRC, score}, + {DNNL_ARG_DST, score2}, + }); + + // SV + bmm2_prim.execute(strm, bmm2_args); + }; + + // Warmup run. + // Execute primitives of sdpa. + fwd_loop(); + strm.wait(); +#if DEBUG_PRINT_MEM + print_mem(grouped_query, "FWD grouped_query"); + print_mem(key_dequantized, "FWD keq_deq"); + print_mem(scale, "FWD scale"); + print_mem(mask, "FWD mask"); + print_mem(score, "FWD intermediate score"); + print_mem(score2, "FWD intermediate score2"); +#endif + void *output_ptr_ = (void *)output.map_data(); void *grouped_output_ptr_ = (void *)grouped_output.map_data(); memcpy(output_ptr_, grouped_output_ptr_, grouped_query_md.get_size()); grouped_output.unmap_data(grouped_output_ptr_); output.unmap_data(output_ptr_); + + strm.wait(); + + // end forward portion + // backwards pass of SDPA + + // init incoming gradient, dO + memory::desc grouped_output_md = grouped_output.get_desc(); + memory dO_mem = reshape(strm, diff_output, grouped_output_md); + +#if DEBUG_PRINT_MEM + printf("\n\n================\n\n\n"); + print_mem(dO_mem, "BWD incoming dO"); +#endif + + ///////// final MM + // matmul forward, O = s2 * v + + // init v, score2 gradients + memory::desc diff_score2_md(score_sz, mdt::f32, memory::format_tag::abcde); + memory dV_mem(grouped_value_md, eng); + memory diff_score2_mem(diff_score2_md, eng); + + // backwards pass 
gradient of s2 (dS2 = dO * v^t) + memory::desc v_t_md + = memory::desc({v_sz[0], v_sz[1], v_sz[2], v_sz[4], v_sz[3]}, + /*dO.get_dt*/ p.dt.dt, memory::format_tag::abced); + matmul::primitive_desc mm_bwd_dS2_pd( + eng, grouped_output_md, v_t_md, diff_score2_md); + matmul mm_bwd_dS2(mm_bwd_dS2_pd); + mm_bwd_dS2.execute(strm, + {{DNNL_ARG_SRC, dO_mem}, {DNNL_ARG_WEIGHTS, value_dequantized}, + {DNNL_ARG_DST, diff_score2_mem}}); strm.wait(); + +#if DEBUG_PRINT_MEM + print_mem(diff_score2_mem, "BWD dS2"); +#endif + + // backwards pass gradient of v (dV = s2^t * dO) + + // downcast score to p.dt.dt + memory::desc s2_t_md = memory::desc( + {score_sz[0], score_sz[1], score_sz[2], score_sz[4], score_sz[3]}, + p.dt.dt, memory::format_tag::abced); + + memory score2_downcast; + auto score32 = as(strm, score2, p.dt.dt); + score2_downcast = reshape(strm, score32, s2_t_md); + + const bool is_gqa = (head_q_group_size != head_kv_group_size); + memory::dims dV_full_dims = {v_sz[0], v_sz[1], + is_gqa ? head_q_group_size : head_kv_group_size, v_sz[3], v_sz[4]}; + memory::desc dV_full_md(dV_full_dims, p.dt.dt, memory::format_tag::abcde); + memory dV_full_mem = is_gqa ? 
memory(dV_full_md, eng) : dV_mem; + matmul::primitive_desc mm_bwd_dV_pd( + eng, s2_t_md, grouped_output_md, dV_full_mem.get_desc()); + matmul mm_bwd_dV(mm_bwd_dV_pd); + mm_bwd_dV.execute(strm, + {{DNNL_ARG_SRC, score2_downcast}, {DNNL_ARG_WEIGHTS, dO_mem}, + {DNNL_ARG_DST, dV_full_mem}}); + strm.wait(); + // reduce [mb, groups, hq_group, kv, D] -> [mb, groups, 1, kv, D] for GQA + dnnl::reduction dV_reduce; + if (is_gqa) { + dnnl::reduction::primitive_desc dV_reduce_pd(eng, + algorithm::reduction_sum, dV_full_md, grouped_value_md, 0.f, + 0.f); + dV_reduce = dnnl::reduction(dV_reduce_pd); + dV_reduce.execute( + strm, {{DNNL_ARG_SRC, dV_full_mem}, {DNNL_ARG_DST, dV_mem}}); + strm.wait(); + } + +#if DEBUG_PRINT_MEM + print_mem(dV_mem, "BWD dV"); +#endif + + ///////// intermediate softmax(QK) + + // init memory for softmax gradients + memory::desc diff_score_md(score_sz, p.dt.dt, memory::format_tag::abcde); + memory diff_score_mem(diff_score_md, eng); + + // backwards pass gradient of softmax + softmax_backward::primitive_desc softmax_bwd_pd(eng, + algorithm::softmax_accurate, diff_score_md, diff_score2_md, + score_md, 4, softmax_fwd_pd); + softmax_backward softmax_bwd(softmax_bwd_pd); + softmax_bwd.execute(strm, + {{DNNL_ARG_DST, score2}, {DNNL_ARG_DIFF_DST, diff_score2_mem}, + {DNNL_ARG_DIFF_SRC, diff_score_mem}}); + + binary scale_bwd_prim; + if (scale_dt != mdt::undef) { + auto scale_algo_bwd + = invert_scale ? 
algorithm::binary_div : algorithm::binary_mul; + auto scale_bwd_pd = binary::primitive_desc(eng, scale_algo_bwd, + diff_score_md, scale_f32_bwd.get_desc(), diff_score_md); + scale_bwd_prim = binary(scale_bwd_pd); + scale_bwd_prim.execute(strm, + {{DNNL_ARG_SRC_0, diff_score_mem}, + {DNNL_ARG_SRC_1, scale_f32_bwd}, + {DNNL_ARG_DST, diff_score_mem}}); + } + + strm.wait(); + // print softmax gradient +#if DEBUG_PRINT_MEM + print_mem(diff_score_mem, "BWD dS"); +#endif + + ///////// first MM + // matmul forward, s = q * k + + // init q,k gradients + memory dQ_mem(grouped_query_md, eng); + memory dK_mem(key_dequantized.get_desc(), eng); + + // backwards pass gradient of q (dQ = dS * k^t) + // TODO: handle transposed K test case + // memory::desc k_t_md = p.with_key_transposed // k^t requires transposed format and dims + // ? memory::desc({k_sz[0], k_sz[1], k_sz[2], k_sz[4], k_sz[3]}, + // memory::data_type::f32, memory::format_tag::abcde) + // : memory::desc({k_sz[0], k_sz[1], k_sz[2], k_sz[4], k_sz[3]}, + // memory::data_type::f32, memory::format_tag::abced); + memory::desc k_t_md + = memory::desc({k_sz[0], k_sz[1], k_sz[2], k_sz[4], k_sz[3]}, + p.dt.dt, memory::format_tag::abced); + matmul::primitive_desc mm_bwd_dq_pd( + eng, diff_score_md, k_t_md, grouped_query_md); + matmul mm_bwd_dq(mm_bwd_dq_pd); + mm_bwd_dq.execute(strm, + {{DNNL_ARG_SRC, diff_score_mem}, + {DNNL_ARG_WEIGHTS, key_dequantized}, + {DNNL_ARG_DST, dQ_mem}}); + strm.wait(); + + // backwards pass gradient of k (dK = q^t * dS) + memory::desc q_t_md({q_sz[0], q_sz[1], q_sz[2], q_sz[4], q_sz[3]}, p.dt.dt, + memory::format_tag::abced); + + // for GQA cases, perform additional reduction of dK along headq_group + auto key_fmt = p.key_format_tag == memory::format_tag::abcd + ? memory::format_tag::abcde + : memory::format_tag::abced; + memory::dims dK_full_dims = {k_sz[0], k_sz[1], + is_gqa ? 
head_q_group_size : head_kv_group_size, k_sz[3], k_sz[4]}; + memory::desc dK_full_md(dK_full_dims, p.dt.dt, key_fmt); + memory dK_full_mem = is_gqa ? memory(dK_full_md, eng) : dK_mem; + matmul::primitive_desc mm_bwd_dk_pd( + eng, q_t_md, diff_score_md, dK_full_mem.get_desc()); + matmul mm_bwd_dk(mm_bwd_dk_pd); + mm_bwd_dk.execute(strm, + {{DNNL_ARG_SRC, grouped_query}, {DNNL_ARG_WEIGHTS, diff_score_mem}, + {DNNL_ARG_DST, dK_full_mem}}); + strm.wait(); + // reduce [mb, groups, hq_group, D, K] -> [mb, groups, 1, D, K] for GQA cases + dnnl::reduction dK_reduce; + if (is_gqa) { + dnnl::reduction::primitive_desc dK_reduce_pd(eng, + algorithm::reduction_sum, dK_full_md, + key_dequantized.get_desc(), 0.f, 0.f); + dK_reduce = dnnl::reduction(dK_reduce_pd); + dK_reduce.execute( + strm, {{DNNL_ARG_SRC, dK_full_mem}, {DNNL_ARG_DST, dK_mem}}); + strm.wait(); + } + + const auto bwd_loop = [&]() { + strm.wait(); + mm_bwd_dS2.execute(strm, + {{DNNL_ARG_SRC, dO_mem}, {DNNL_ARG_WEIGHTS, value_dequantized}, + {DNNL_ARG_DST, diff_score2_mem}}); + mm_bwd_dV.execute(strm, + {{DNNL_ARG_SRC, score2_downcast}, {DNNL_ARG_WEIGHTS, dO_mem}, + {DNNL_ARG_DST, dV_full_mem}}); + if (is_gqa) + dV_reduce.execute(strm, + {{DNNL_ARG_SRC, dV_full_mem}, {DNNL_ARG_DST, dV_mem}}); + + softmax_bwd.execute(strm, + {{DNNL_ARG_DST, score2}, {DNNL_ARG_DIFF_DST, diff_score2_mem}, + {DNNL_ARG_DIFF_SRC, diff_score_mem}}); + + if (scale_dt != mdt::undef) { + scale_bwd_prim.execute(strm, + {{DNNL_ARG_SRC_0, diff_score_mem}, + {DNNL_ARG_SRC_1, scale_f32_bwd}, + {DNNL_ARG_DST, diff_score_mem}}); + } + + mm_bwd_dq.execute(strm, + {{DNNL_ARG_SRC, diff_score_mem}, + {DNNL_ARG_WEIGHTS, key_dequantized}, + {DNNL_ARG_DST, dQ_mem}}); + mm_bwd_dk.execute(strm, + {{DNNL_ARG_SRC, grouped_query}, + {DNNL_ARG_WEIGHTS, diff_score_mem}, + {DNNL_ARG_DST, dK_full_mem}}); + if (is_gqa) + dK_reduce.execute(strm, + {{DNNL_ARG_SRC, dK_full_mem}, {DNNL_ARG_DST, dK_mem}}); + strm.wait(); + }; + bwd_loop(); + strm.wait(); + + using 
namespace std::chrono; + nanoseconds qtime_bwd {}; + + if (with_timing) { + auto loop_bwd_prim = [&] { bwd_loop(); }; + + int iterations = 20; + auto quantized_bwd_time = timeit(loop_bwd_prim, strm, iterations); + + auto min_time = [](const std::vector &a) { + return *std::min_element(a.begin(), a.end()); + }; + + qtime_bwd = min_time(quantized_bwd_time) / iterations; + printf("qtime_training_backwards_prim %f\n", (float)qtime_bwd.count()); + } + + // print q,k gradients +#if DEBUG_PRINT_MEM + print_mem(dQ_mem, "BWD dQ"); + print_mem(dK_mem, "BWD dK"); +#endif + + // reshape gradients from grouped layout back to 4d + void *diff_query_ptr = (void *)diff_query.map_data(); + void *dQ_ptr = (void *)dQ_mem.map_data(); + memcpy(diff_query_ptr, dQ_ptr, grouped_query_md.get_size()); + dQ_mem.unmap_data(dQ_ptr); + diff_query.unmap_data(diff_query_ptr); + + void *diff_key_ptr = (void *)diff_key.map_data(); + void *dK_ptr = (void *)dK_mem.map_data(); + memcpy(diff_key_ptr, dK_ptr, key_dequantized.get_desc().get_size()); + dK_mem.unmap_data(dK_ptr); + diff_key.unmap_data(diff_key_ptr); + + void *diff_value_ptr = (void *)diff_value.map_data(); + void *dV_ptr = (void *)dV_mem.map_data(); + memcpy(diff_value_ptr, dV_ptr, grouped_value_md.get_size()); + dV_mem.unmap_data(dV_ptr); + diff_value.unmap_data(diff_value_ptr); + + return qtime_bwd; } template @@ -1334,7 +1812,7 @@ void check_memory(dnnl::stream &strm, memory &gold, memory &test, float max_diff = std::numeric_limits::min(); std::map> hist; - const bool verbose = false; + const bool verbose = true; for_(int l = 0; l < dims[0]; l++) for_(int k = 0; k < dims[1]; k++) for_(int j = 0; j < dims[2]; j++) @@ -1380,7 +1858,7 @@ void check_memory(dnnl::stream &strm, memory &gold, memory &test, gold.unmap_data(mapped_ptr_gold); test.unmap_data(mapped_ptr_test); - int threshold = total * 0.0006; + int threshold = total * 0.002; ASSERT_LE(mismatches, threshold) << mismatches << " out of: " << total; ASSERT_LE(max_diff, 
max_diff_threshold); @@ -1397,26 +1875,6 @@ int to_attn_mask_type(mask_type t) { return static_cast(attn_mask); } -std::vector timeit( - const std::function &func, dnnl::stream &str, int iterations) { - using namespace std::chrono; - func(); - func(); - std::vector times; - for (int j = 0; j < 5; j++) { - auto e = steady_clock::now(); - str.wait(); - auto s = steady_clock::now(); - for (int i = 0; i < iterations; i++) { - func(); - } - str.wait(); - e = steady_clock::now(); - times.push_back(std::chrono::duration_cast(e - s)); - } - return times; -} - template O magnitude_cast(I input) { using ratio = std::ratio_divide; @@ -1642,6 +2100,186 @@ class sdpa_test_t : public ::testing::TestWithParam { #endif } + void compare_bwd() { + using namespace dnnl::impl; + auto mask = t.m_mask.get_desc(); + + memory::desc *mask_ptr = nullptr; + + switch (p.mask.type) { + case mask_type::no_mask: + case mask_type::causal_tl: + case mask_type::causal_br: mask_ptr = nullptr; break; + case mask_type::oneD: + case mask_type::twoD: mask_ptr = &mask; break; + } + + auto dS_desc = t.m_dS.get_desc(); + memory::desc *dS_ptr = nullptr; + //dS_ptr = &dS_desc; // uncomment for optional dS output (expensive) + + // fwd sdpa primitive to populate dst, col_maxes + sdpa::primitive_desc sdpa_fwd_pd; + sdpa sdpa_fwd; + + sdpa_backward::primitive_desc sdpa_bwd_pd; + sdpa_backward sdpa_bwd; + try { + sdpa_fwd_pd = sdpa::primitive_desc(eng, t.m_query.get_desc(), + t.m_key_quantized.get_desc(), + t.m_value_quantized.get_desc(), mask_ptr, + t.m_scale.get_desc(), t.m_output_quantized.get_desc(), + invert_scale, p.heads.kv, to_attn_mask_type(p.mask.type), + dnnl::impl::alg_kind::softmax_accurate_inf_as_zero, + t.sdpa_attr_quantized, t.sdpa_kq_attr_quantized, + t.sdpa_vs_attr_quantized, prop_kind::forward_training); + sdpa_fwd = sdpa(sdpa_fwd_pd); + + sdpa_bwd_pd = sdpa_backward::primitive_desc(eng, + t.m_query.get_desc(), t.m_key_quantized.get_desc(), + t.m_value_quantized.get_desc(), mask_ptr, + 
t.m_scale.get_desc(), t.m_output_quantized.get_desc(), + t.m_diff_query_quantized.get_desc(), + t.m_diff_key_quantized.get_desc(), + t.m_diff_value_quantized.get_desc(), + t.m_diff_output.get_desc(), dS_ptr, invert_scale, + p.heads.kv, to_attn_mask_type(p.mask.type), + dnnl::impl::alg_kind::softmax_accurate_inf_as_zero, + sdpa_fwd_pd, t.sdpa_attr_quantized, + t.sdpa_kq_attr_quantized, t.sdpa_vs_attr_quantized); + sdpa_bwd = sdpa_backward(sdpa_bwd_pd); + } catch (const dnnl::error &e) { + if (e.status == dnnl_unimplemented) + GTEST_SKIP() << "Unimplemented: " << e.what(); + else + throw; + } + + auto sdpa_fwd_workspace_memory + = memory(sdpa_fwd_pd.workspace_desc(), eng); + + std::unordered_map sdpa_fwd_args + = {{DNNL_ARG_QUERIES, t.m_query}, + {DNNL_ARG_KEYS, t.m_key_quantized}, + {DNNL_ARG_VALUES, t.m_value_quantized}, + {DNNL_ARG_DST, t.m_output_quantized}, + {DNNL_ARG_WORKSPACE, sdpa_fwd_workspace_memory}}; + + if (scale_dt != mdt::undef) { + sdpa_fwd_args[DNNL_ARG_SCALE] = t.m_scale; + } + if (mask_ptr) { sdpa_fwd_args[DNNL_ARG_ATTN_MASK] = t.m_mask; } + + strm.wait(); + sdpa_fwd.execute(strm, sdpa_fwd_args); + strm.wait(); + +#if DEBUG_PRINT_MEM + print_mem(sdpa_fwd_workspace_memory, "sharedworkspace"); +#endif + + std::unordered_map sdpa_bwd_args + = {{DNNL_ARG_QUERIES, t.m_query}, + {DNNL_ARG_KEYS, t.m_key_quantized}, + {DNNL_ARG_VALUES, t.m_value_quantized}, + {DNNL_ARG_DST, t.m_output_quantized}, + {DNNL_ARG_DIFF_DST, t.m_diff_output}, + {DNNL_ARG_DIFF_QUERIES, t.m_diff_query_quantized}, + {DNNL_ARG_DIFF_KEYS, t.m_diff_key_quantized}, + {DNNL_ARG_DIFF_VALUES, t.m_diff_value_quantized}, + {DNNL_ARG_DS, t.m_dS}, + {DNNL_ARG_WORKSPACE, sdpa_fwd_workspace_memory}}; + +#if DEBUG_PRINT_MEM + print_mem(t.m_query, "input t_m_query"); + print_mem(t.m_key_quantized, "input t_m_key"); + print_mem(t.m_value_quantized, "input t_m_value"); + print_mem(t.m_diff_output, "input dA"); +#endif + + if (scale_dt != mdt::undef) { + sdpa_bwd_args[DNNL_ARG_SCALE] = t.m_scale; + 
} + if (mask_ptr) { sdpa_bwd_args[DNNL_ARG_ATTN_MASK] = t.m_mask; } + + sdpa_bwd.execute(strm, sdpa_bwd_args); + strm.wait(); + +#if DEBUG_PRINT_MEM + print_mem(t.m_dS, "computed dS"); + print_mem(t.m_diff_value_quantized, "dV bwd out"); +#endif + + printf("-------------- Primitives based implementation -------------- " + "\n"); + + // perform primitives based backwards sdpa pass to generate "gold" gradient outputs + prim_sdpa_quant_bwd(p, t, eng, strm, t.m_query, t.m_key_quantized, + scale_dt, t.m_scale, t.m_mask, t.m_value_quantized, t.m_output, + t.m_diff_output, invert_scale, doubled_memory, t.m_diff_query, + t.m_diff_key, t.m_diff_value); + + float max_diff_threshold = 0.3f; + float fthreshold = 0.f; + if (p.dt.dt == mdt::bf16 || p.dt.dt == mdt::f16) { + //fthreshold = 0.0079f; //todo: correct threshold or better values + fthreshold = 0.1; + } else { + fthreshold = 0.001466f; + } + + strm.wait(); + + if (t.m_output.get_desc().get_data_type() == mdt::f16) { + check_memory(strm, t.m_output, t.m_output_quantized, + max_diff_threshold, fthreshold); + check_memory(strm, t.m_diff_query, + t.m_diff_query_quantized, max_diff_threshold, fthreshold); + check_memory(strm, t.m_diff_key, t.m_diff_key_quantized, + max_diff_threshold, fthreshold); + check_memory(strm, t.m_diff_value, + t.m_diff_value_quantized, max_diff_threshold, fthreshold); + + } else if (t.m_output.get_desc().get_data_type() == mdt::bf16) { + check_memory(strm, t.m_output, t.m_output_quantized, + max_diff_threshold, fthreshold); + check_memory(strm, t.m_diff_query, + t.m_diff_query_quantized, max_diff_threshold, fthreshold); + check_memory(strm, t.m_diff_key, t.m_diff_key_quantized, + max_diff_threshold, fthreshold); + check_memory(strm, t.m_diff_value, + t.m_diff_value_quantized, max_diff_threshold, fthreshold); + + } else if (t.m_output.get_desc().get_data_type() == mdt::f32) { + check_memory(strm, t.m_output, t.m_output_quantized, + max_diff_threshold, fthreshold); + check_memory(strm, t.m_diff_query, 
+ t.m_diff_query_quantized, max_diff_threshold, fthreshold); + check_memory(strm, t.m_diff_key, t.m_diff_key_quantized, + max_diff_threshold, fthreshold); + check_memory(strm, t.m_diff_value, + t.m_diff_value_quantized, max_diff_threshold, fthreshold); + } + +#if DEBUG_PRINT_MEM + print_mem(t.m_output, "gold m_output"); + print_mem(t.m_output_quantized, "test m_output"); + printf("----------\n\n"); + + print_mem(t.m_diff_query, "gold m_diff_query"); + print_mem(t.m_diff_query_quantized, "test m_diff_query"); + printf("----------\n\n"); + + print_mem(t.m_diff_key, "gold m_diff_key"); + print_mem(t.m_diff_key_quantized, "test m_diff_key"); + printf("----------\n\n"); + + print_mem(t.m_diff_value, "gold m_diff_value"); + print_mem(t.m_diff_value_quantized, "test m_diff_value"); + printf("----------\n\n"); +#endif + } + void perf() { using namespace dnnl::impl; auto mask = t.m_mask.get_desc(); @@ -1788,6 +2426,194 @@ class sdpa_test_t : public ::testing::TestWithParam { << "|" << std::endl; } + void perf_bwd(bool time_reference = false) { + using namespace dnnl::impl; + auto mask = t.m_mask.get_desc(); + + memory::desc *mask_ptr = nullptr; + + switch (p.mask.type) { + case mask_type::no_mask: + case mask_type::causal_tl: + case mask_type::causal_br: mask_ptr = nullptr; break; + case mask_type::oneD: + case mask_type::twoD: mask_ptr = &mask; break; + } + + // Forward training pass (needed for workspace) + sdpa::primitive_desc sdpa_fwd_pd; + sdpa sdpa_fwd; + + sdpa_backward::primitive_desc sdpa_bwd_pd; + sdpa_backward sdpa_bwd; + try { + sdpa_fwd_pd = sdpa::primitive_desc(eng, t.m_query.get_desc(), + t.m_key_quantized.get_desc(), + t.m_value_quantized.get_desc(), mask_ptr, + t.m_scale.get_desc(), t.m_output_quantized.get_desc(), + invert_scale, p.heads.kv, to_attn_mask_type(p.mask.type), + dnnl::impl::alg_kind::softmax_accurate_inf_as_zero, + t.sdpa_attr_quantized, t.sdpa_kq_attr_quantized, + t.sdpa_vs_attr_quantized, prop_kind::forward_training); + sdpa_fwd = 
sdpa(sdpa_fwd_pd); + + sdpa_bwd_pd = sdpa_backward::primitive_desc(eng, + t.m_query.get_desc(), t.m_key_quantized.get_desc(), + t.m_value_quantized.get_desc(), mask_ptr, + t.m_scale.get_desc(), t.m_output_quantized.get_desc(), + t.m_diff_query_quantized.get_desc(), + t.m_diff_key_quantized.get_desc(), + t.m_diff_value_quantized.get_desc(), + t.m_diff_output.get_desc(), nullptr, invert_scale, + p.heads.kv, to_attn_mask_type(p.mask.type), + dnnl::impl::alg_kind::softmax_accurate_inf_as_zero, + sdpa_fwd_pd, t.sdpa_attr_quantized, + t.sdpa_kq_attr_quantized, t.sdpa_vs_attr_quantized); + sdpa_bwd = sdpa_backward(sdpa_bwd_pd); + } catch (const dnnl::error &e) { + if (e.status == dnnl_unimplemented) + GTEST_SKIP() << "Unimplemented: " << e.what(); + else + throw; + } + + // Execute forward pass to populate workspace + auto sdpa_fwd_workspace_memory + = memory(sdpa_fwd_pd.workspace_desc(), eng); + + std::unordered_map sdpa_fwd_args + = {{DNNL_ARG_QUERIES, t.m_query}, + {DNNL_ARG_KEYS, t.m_key_quantized}, + {DNNL_ARG_VALUES, t.m_value_quantized}, + {DNNL_ARG_DST, t.m_output_quantized}, + {DNNL_ARG_WORKSPACE, sdpa_fwd_workspace_memory}}; + + if (scale_dt != mdt::undef) { + sdpa_fwd_args[DNNL_ARG_SCALE] = t.m_scale; + } + if (mask_ptr) { sdpa_fwd_args[DNNL_ARG_ATTN_MASK] = t.m_mask; } + + sdpa_fwd.execute(strm, sdpa_fwd_args); + strm.wait(); + + // Build backward args + std::unordered_map sdpa_bwd_args + = {{DNNL_ARG_QUERIES, t.m_query}, + {DNNL_ARG_KEYS, t.m_key_quantized}, + {DNNL_ARG_VALUES, t.m_value_quantized}, + {DNNL_ARG_DST, t.m_output_quantized}, + {DNNL_ARG_DIFF_DST, t.m_diff_output}, + {DNNL_ARG_DIFF_QUERIES, t.m_diff_query_quantized}, + {DNNL_ARG_DIFF_KEYS, t.m_diff_key_quantized}, + {DNNL_ARG_DIFF_VALUES, t.m_diff_value_quantized}, + {DNNL_ARG_WORKSPACE, sdpa_fwd_workspace_memory}}; + + if (scale_dt != mdt::undef) { + sdpa_bwd_args[DNNL_ARG_SCALE] = t.m_scale; + } + if (mask_ptr) { sdpa_bwd_args[DNNL_ARG_ATTN_MASK] = t.m_mask; } + + // Time backward pass + auto 
loop_bwd = [&] { sdpa_bwd.execute(strm, sdpa_bwd_args); }; + + int iterations = 20; + auto bwd_time = timeit(loop_bwd, strm, iterations); + + using namespace std::chrono; + auto min_time = [](const std::vector &a) { + return *std::min_element(a.begin(), a.end()); + }; + + auto qtime = min_time(bwd_time) / iterations; + + // Backward reads: Q, K, V, O, dO, workspace(logsumexp) + // Backward writes: dQ, dK, dV + byte_t<> total_bytes = t.m_query.get_desc().get_size() + + t.m_key_quantized.get_desc().get_size() + + t.m_value_quantized.get_desc().get_size() + + t.m_output_quantized.get_desc().get_size() + + t.m_diff_output.get_desc().get_size() + + t.m_diff_query_quantized.get_desc().get_size() + + t.m_diff_key_quantized.get_desc().get_size() + + t.m_diff_value_quantized.get_desc().get_size() + + (mask_ptr ? t.m_mask.get_desc().get_size() : 0); + + size_t kv_slice_tensor_elements + = (p.head_group.head_size * p.seq_len.kv); + auto mask_slice_elements = 0; + switch (p.mask.type) { + case mask_type::twoD: + mask_slice_elements = p.seq_len.kv * p.seq_len.q; + break; + case mask_type::oneD: mask_slice_elements = p.seq_len.kv; break; + default: mask_slice_elements = 0; break; + } + size_t batch_elements = p.mb * std::max(p.heads.q, p.heads.kv); + + // Effective bytes: expand broadcast tensors to full batch size + // Reads: Q, K, V, O, dO (all expanded); Writes: dQ, dK, dV + byte_t<> total_bytes_effective = (batch_elements + * (byte_t<>(p.key.dt) * kv_slice_tensor_elements + + byte_t<>(p.value.dt) * kv_slice_tensor_elements + + byte_t<>(p.dt.dt) + * (5 * p.head_group.head_size * p.seq_len.q) + + byte_t<>(p.dt.dt) * (2 * kv_slice_tensor_elements) + + (mask_ptr ? byte_t<>(p.mask.dt) * mask_slice_elements + : 0))); + + // Backward has 5 matmuls: S=K*Q, dS2=V*dO, dV=S*dO, dK=dS*Q, dQ=dS*K + // plus softmax_bwd, scale, Di reduction + float causal_divisor = (p.mask.type == mask_type::causal_tl + || p.mask.type == mask_type::causal_br) + ? 
2.f + : 1.f; + + // 5 matmuls, flash-attention paper scales FWD flops by 5/2 + // softmax_bwd (~5*K*Q), scale, Di + num_ops_t<> total_flops = std::max(p.heads.kv, p.heads.q) * p.mb + * (2.f + * (2.5f * 2.f * p.head_group.head_size + * p.seq_len.kv * p.seq_len.q) + + (scale_dt != mdt::undef + ? (2 * p.seq_len.kv * p.seq_len.q) + : 0) + + (p.mask.type != mask_type::no_mask + ? (p.seq_len.kv * p.seq_len.q) + : 0) + + (5 * p.seq_len.kv * p.seq_len.q)) + / causal_divisor; + + // Just the 5 matmul flops + num_ops_t<> flash_flops + = (2.5f * 2.f * p.mb * p.heads.q * p.seq_len.kv * p.seq_len.q + * p.head_group.head_size) + / causal_divisor; + + std::call_once(header_flag, print_table_header); + std::cout << "BWD" << print_row(p) << "|" << qtime.count() << "|" + << bandwidth(magnitude_cast(total_bytes_effective), + qtime) + << "/" + << bandwidth(magnitude_cast(total_bytes), qtime) + << "|" << compute(magnitude_cast(flash_flops), qtime) + << "/" << compute(magnitude_cast(total_flops), qtime) + << "|" << std::endl; + + if (time_reference) { + auto ref_time = prim_sdpa_quant_bwd(p, t, eng, strm, t.m_query, + t.m_key_quantized, scale_dt, t.m_scale, t.m_mask, + t.m_value_quantized, t.m_output, t.m_diff_output, + invert_scale, doubled_memory, t.m_diff_query, t.m_diff_key, + t.m_diff_value, /*with_timing=*/true); + + float speedup = ref_time.count() > 0 + ? 
(float)ref_time.count() / (float)qtime.count() + : 0.f; + std::cout << "REF" << print_row(p) << "|" << ref_time.count() << "|" + << speedup << "x speedup|" << std::endl; + } + } + protected: dnnl::engine eng; dnnl::stream strm; @@ -1804,6 +2630,8 @@ memory::format_tag no_key_transposed = memory::format_tag::abcd; using sdpa_test = sdpa_test_t; using sdpa_test_datatypes = sdpa_test_t; +using sdpa_bwd_test = sdpa_test_t; +using sdpa_bwd_test_datatypes = sdpa_test_t; // clang-format off @@ -2071,6 +2899,32 @@ INSTANTIATE_TEST_SUITE_P(phi3_mini_4k_instruct, sdpa_dims_t{ 1, 32, 32, 2049, 1, 96, 96, 96, mdt::f16, mdt::s8, mdt::f16, mdt::s8, mdt::s8, mdt::f16, mdt::s8, mdt::f16, quantize_type::per_token_with_groups, with_key_transposed, mask_type::twoD } ), &print_to_string); +INSTANTIATE_TEST_SUITE_P(bwd_perf, + sdpa_bwd_test, + // mb,hd_num,kv_hd_num,seq_len,qry_num,hd_size, kg_sz, vgrp_sz, dt, kdt, ksdt, kzpdt, vdt, vsdt, vzpdt, mskdt, qtype + testing::Values( + sdpa_dims_t{ 1, 1, 1, 32, 32, 32, 32, 32, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 1, 1, 1, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 1, 1, 1, 8192, 8192, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 1, 2, 2, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 1, 2, 2, 8192, 8192, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, 
no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 1, 4, 4, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 1, 4, 4, 8192, 8192, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 2, 1, 1, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 2, 1, 1, 8192, 8192, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 2, 2, 2, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 2, 2, 2, 8192, 8192, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 2, 4, 4, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 2, 4, 4, 8192, 8192, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 4, 1, 1, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 4, 1, 1, 8192, 8192, 128, 128, 128, 
mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 4, 2, 2, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 4, 2, 2, 8192, 8192, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 4, 4, 4, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::causal_tl }, + sdpa_dims_t{ 4, 4, 2, 8192, 8192, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::causal_tl }, + sdpa_dims_t{ 4, 12, 2, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::causal_tl } + ), &print_to_string); + // clang-format on GPU_TEST_P(sdpa_test, compare) { @@ -2081,12 +2935,166 @@ GPU_TEST_P(sdpa_test_datatypes, compare) { compare(); } +/* +GPU_TEST_P(sdpa_bwd_test, compare_bwd) { + compare_bwd(); +} +*/ + GPU_TEST_P(sdpa_test, perf) { perf(); } -/* -GPU_TEST_P(sdpa_test_datatypes, perf) { - perf(); +GPU_TEST_P(sdpa_bwd_test, perf_bwd) { + const bool time_reference = true; + perf_bwd(time_reference); } -*/ + +GPU_TEST_P(sdpa_bwd_test_datatypes, compare_bwd) { + compare_bwd(); +} + +// clang-format off + +// backward pass: f16 +INSTANTIATE_TEST_SUITE_P(bwd_f16, sdpa_bwd_test_datatypes, + testing::Combine(testing::Values(1, 2), // mb + testing::Values(num_heads_t {1, 1}, num_heads_t {2, 2}, + num_heads_t {8, 8}), // heads + 
testing::Values(seq_len_size_t {32, 32}, seq_len_size_t {64, 64}, + seq_len_size_t {384, 384}), // seq_len + testing::Values(head_group_size_t {32, 32, 32}, + head_group_size_t {64, 64, 64}, + head_group_size_t {128, 128, 128}), // head_size + testing::Values(tensor_type_t("Q", mdt::f16)), // dt + testing::Values(tensor_type_t("K", mdt::f16)), // kdt + testing::Values(tensor_type_t("V", mdt::f16)), // vdt + testing::Values(quantize_type::no_quantization), // qtype + testing::Values(dnnl::memory::format_tag::abcd), // key_format_tag + testing::Values(mask_config_t {mask_type::no_mask}, + mask_config_t {mask_type::causal_tl}, mask_config_t {mask_type::causal_br}, + mask_config_t {mask_type::twoD}, mask_config_t {mask_type::oneD} + ), // mask_type + testing::Values(default_scale_type), // scale_type + testing::Values( + accumulation_t {accumulation_mode::f32, + accumulation_mode::f32}) // accumulation_mode + ), + &print_to_string2); + +// backward pass: bf16 +INSTANTIATE_TEST_SUITE_P(bwd_bf16, sdpa_bwd_test_datatypes, + testing::Combine(testing::Values(1, 2), // mb + testing::Values(num_heads_t {1, 1}, num_heads_t {2, 2}), // heads + testing::Values(seq_len_size_t {32, 32}, seq_len_size_t {128, 128}, + seq_len_size_t {384, 384}), // seq_len + testing::Values(head_group_size_t {32, 32, 32}, + head_group_size_t {64, 64, 64}, + head_group_size_t {128, 128, 128}), // head_size + testing::Values(tensor_type_t("Q", mdt::bf16)), // dt + testing::Values(tensor_type_t("K", mdt::bf16)), // kdt + testing::Values(tensor_type_t("V", mdt::bf16)), // vdt + testing::Values(quantize_type::no_quantization), // qtype + testing::Values(dnnl::memory::format_tag::abcd), // key_format_tag + testing::Values(mask_config_t {mask_type::causal_tl}), // mask_type + testing::Values(default_scale_type), // scale_type + testing::Values( + accumulation_t {accumulation_mode::f32, + accumulation_mode::f32}) // accumulation_mode + ), + &print_to_string2); + +// backward pass: GQA configurations 
+INSTANTIATE_TEST_SUITE_P(bwd_gqa, sdpa_bwd_test_datatypes, + testing::Combine(testing::Values(1), // mb + testing::Values(num_heads_t {4, 2}, num_heads_t {8, 2}, + num_heads_t {32, 2}), // heads (q > kv) + testing::Values(seq_len_size_t {32, 32}, seq_len_size_t {128, 128}, + seq_len_size_t {384, 384}), // seq_len + testing::Values(head_group_size_t {64, 64, 64}, + head_group_size_t {128, 128, 128}), // head_size + testing::Values(tensor_type_t("Q", mdt::f16)), // dt + testing::Values(tensor_type_t("K", mdt::f16)), // kdt + testing::Values(tensor_type_t("V", mdt::f16)), // vdt + testing::Values(quantize_type::no_quantization), // qtype + testing::Values(dnnl::memory::format_tag::abcd), // key_format_tag + testing::Values(mask_config_t {mask_type::no_mask}, + mask_config_t {mask_type::causal_tl}), // mask_type + testing::Values(default_scale_type), // scale_type + testing::Values( + accumulation_t {accumulation_mode::f32, + accumulation_mode::f32}) // accumulation_mode + ), + &print_to_string2); + +// backward pass: non-uniform sequence lengths (q != kv) +INSTANTIATE_TEST_SUITE_P(bwd_nonuniform_seq, sdpa_bwd_test_datatypes, + testing::Combine(testing::Values(1), // mb + testing::Values(num_heads_t {1, 1}, + num_heads_t {2, 2}), // heads + testing::Values(seq_len_size_t {33, 65}, + seq_len_size_t {65, 4097}, + seq_len_size_t {4096, 64}, + seq_len_size_t {1025, 15}), // seq_len (q, kv) + testing::Values(head_group_size_t {32, 32, 32}, + head_group_size_t {64, 64, 64}), // head_size + testing::Values(tensor_type_t("Q", mdt::f16)), // dt + testing::Values(tensor_type_t("K", mdt::f16)), // kdt + testing::Values(tensor_type_t("V", mdt::f16)), // vdt + testing::Values(quantize_type::no_quantization), // qtype + testing::Values(dnnl::memory::format_tag::abcd), // key_format_tag + testing::Values(mask_config_t {mask_type::no_mask}), // mask_type + testing::Values(default_scale_type), // scale_type + testing::Values( + accumulation_t {accumulation_mode::f32, + 
accumulation_mode::f32}) // accumulation_mode + ), + &print_to_string2); + +// backward pass: f32 +INSTANTIATE_TEST_SUITE_P(bwd_f32, sdpa_bwd_test_datatypes, + testing::Combine(testing::Values(1, 4), // mb + testing::Values(num_heads_t {1, 1}, num_heads_t {2, 2}, + num_heads_t {12, 12}), // heads + testing::Values(seq_len_size_t {32, 32}, seq_len_size_t {64, 64}, + seq_len_size_t {384, 384}, + seq_len_size_t {4096, 4096}), // seq_len + testing::Values(head_group_size_t {16, 16, 16}, + head_group_size_t {32, 32, 32}, + head_group_size_t {64, 64, 64}, + head_group_size_t {128, 128, 128}), // head_size + testing::Values(tensor_type_t("Q", mdt::f32)), // dt + testing::Values(tensor_type_t("K", mdt::f32)), // kdt + testing::Values(tensor_type_t("V", mdt::f32)), // vdt + testing::Values(quantize_type::no_quantization), // qtype + testing::Values(dnnl::memory::format_tag::abcd), // key_format_tag + testing::Values(mask_config_t {mask_type::no_mask}, + mask_config_t {mask_type::causal_tl}), // mask_type + testing::Values(default_scale_type), // scale_type + testing::Values( + accumulation_t {accumulation_mode::f32, + accumulation_mode::f32}) // accumulation_mode + ), + &print_to_string2); + + +// backward pass: large batch and head counts +INSTANTIATE_TEST_SUITE_P(bwd_large_batch, sdpa_bwd_test_datatypes, + testing::Combine(testing::Values(4), // mb + testing::Values(num_heads_t {12, 12}), // heads + testing::Values(seq_len_size_t {4096, 4096}), // seq_len + testing::Values(head_group_size_t {32, 32, 32}), // head_size + testing::Values(tensor_type_t("Q", mdt::f16)), // dt + testing::Values(tensor_type_t("K", mdt::f16)), // kdt + testing::Values(tensor_type_t("V", mdt::f16)), // vdt + testing::Values(quantize_type::no_quantization), // qtype + testing::Values(dnnl::memory::format_tag::abcd), // key_format_tag + testing::Values(mask_config_t {mask_type::no_mask}), // mask_type + testing::Values(default_scale_type), // scale_type + testing::Values( + accumulation_t 
{accumulation_mode::f32, + accumulation_mode::f32}) // accumulation_mode + ), + &print_to_string2); + +// clang-format on From 8b5c39b7c03b354cb602f0e1f5b2b41058b8426b Mon Sep 17 00:00:00 2001 From: syurkevi Date: Tue, 24 Feb 2026 14:10:10 -0800 Subject: [PATCH 04/23] xe: sdpa: prepares config structs for bwd pass --- src/gpu/intel/sdpa/configs.cpp | 220 ++++++++++++++++++++++++++++++--- src/gpu/intel/sdpa/configs.hpp | 101 ++++++++++++--- src/gpu/intel/sdpa/micro.cpp | 2 +- src/gpu/intel/sdpa/micro.hpp | 2 +- 4 files changed, 289 insertions(+), 36 deletions(-) diff --git a/src/gpu/intel/sdpa/configs.cpp b/src/gpu/intel/sdpa/configs.cpp index 1af3b975811..9e0571997aa 100644 --- a/src/gpu/intel/sdpa/configs.cpp +++ b/src/gpu/intel/sdpa/configs.cpp @@ -78,6 +78,14 @@ std::string to_string(const config_t &c) { << c.wg_m_vs << "," << c.wg_n_vs; return s.str(); } +std::string to_string(const bwd_config_t &c) { + std::stringstream s; + s << c.unroll_m_BcBr << "," << c.unroll_n_BcBr << "," << c.unroll_m_DBc + << "," << c.unroll_n_DBc << "," << c.unroll_m_DBr << "," << c.unroll_n_DBr + << "," << c.wg_m_BcBr << "," << c.wg_n_BcBr << "," << c.wg_m_DBc << "," + << c.wg_n_DBc << "," << c.wg_m_DBr << "," << c.wg_n_DBr; + return s.str(); +} // A matching config is a combination of mandatory and optional requirements // it is assumed the query criteria are specific whereas key criteria may be approximate @@ -86,32 +94,35 @@ std::string to_string(const config_t &c) { // head size and sequence length must strictly match the inequality with a caveat for // the more general key criteria "any = -1" // properties must match exactly if they are specified in the key criteria -bool operator==(const config_record_t &key, const config_query_t &query) { - bool result = ((query.arch == key.criteria.arch) - && (query.head_size <= key.criteria.head_size) - && ((key.criteria.seq_len == -1) - || (key.criteria.seq_len != -1 - && query.seq_len <= key.criteria.seq_len)) +bool criteria_matches( + 
const config_criteria_t &key, const config_query_t &query) { + return ((query.arch == key.arch) && (query.head_size <= key.head_size) + && ((key.seq_len == -1) + || (key.seq_len != -1 && query.seq_len <= key.seq_len)) && ((((query.prop & property::second_token) - == (key.criteria.prop & property::second_token))) + == (key.prop & property::second_token))) && (((query.prop & property::quantized) - == (key.criteria.prop & property::quantized))) + == (key.prop & property::quantized))) && (((query.prop & property::fma) - == (key.criteria.prop & property::fma))) - && (((key.criteria.prop & property::f32) == property::none) + == (key.prop & property::fma))) + && (((key.prop & property::f32) == property::none) || ((query.prop & property::f32) - == (key.criteria.prop & property::f32))) - && (((key.criteria.prop & property::f16_accumulate) + == (key.prop & property::f32))) + && (((key.prop & property::f16_accumulate) == property::none) || ((query.prop & property::f16_accumulate) - == (key.criteria.prop - & property::f16_accumulate))) - && (((key.criteria.prop & property::integrated) - == property::none) + == (key.prop & property::f16_accumulate))) + && (((key.prop & property::integrated) == property::none) || ((query.prop & property::integrated) - == (key.criteria.prop - & property::integrated))))); - return result; + == (key.prop & property::integrated))))); +} + +bool operator==(const config_record_t &key, const config_query_t &query) { + return criteria_matches(key.criteria, query); +} + +bool operator==(const bwd_config_record_t &key, const config_query_t &query) { + return criteria_matches(key.criteria, query); } template @@ -185,6 +196,10 @@ bool operator<(const config_record_t &lhs, const config_record_t &rhs) { return lhs.criteria < rhs.criteria; } +bool operator<(const bwd_config_record_t &lhs, const bwd_config_record_t &rhs) { + return lhs.criteria < rhs.criteria; +} + static auto constexpr second_token = property::second_token; static auto constexpr quantized = 
property::quantized; static auto constexpr integrated = property::integrated; @@ -669,11 +684,162 @@ dim_t nearest_conf_seq_interval(compute::gpu_arch_t arch, dim_t head_size, return utils::rnd_up_pow2(seq); } +// Backward pass kernel configurations: +// [ arch, head_size, {sequence length}, {properties} ] -> bwd_config +// bwd_config_t fields: {unroll_m_BcBr, unroll_n_BcBr, +// unroll_m_DBc, unroll_n_DBc, +// unroll_m_DBr, unroll_n_DBr, +// wg_m_BcBr, wg_n_BcBr, +// wg_m_DBc, wg_n_DBc, +// wg_m_DBr, wg_n_DBr} +static std::vector sorted_bwd_configs = []() { + // clang-format off + std::vector configs = { + // xe_hpc + {{compute::gpu_arch_t::xe_hpc, 32}, { 16, 32, 16, 16, 16, 16, 4, 8, 2, 4, 2, 16 }}, + {{compute::gpu_arch_t::xe_hpc, 64}, { 16, 32, 16, 16, 32, 32, 4, 8, 4, 4, 2, 8 }}, + {{compute::gpu_arch_t::xe_hpc, 128}, { 16, 32, 16, 16, 32, 32, 4, 8, 8, 4, 4, 8 }}, + + /* xe2 todo: + {{compute::gpu_arch_t::xe2, 64}, {16, 64, 64, 16, 64, 16, 4, 1, 1, 4, 1, 4}}, + {{compute::gpu_arch_t::xe2, 128}, {16, 64, 64, 16, 64, 16, 4, 2, 2, 4, 2, 4}}, + {{compute::gpu_arch_t::xe2, 256}, {16, 64, 64, 16, 64, 16, 4, 2, 4, 4, 4, 4}}, + */ + }; + // clang-format on + + // ensures configs appear in order of most to least defined/desirable + std::sort(std::begin(configs), std::end(configs)); + return configs; +}(); + +bwd_config_t *choose_bwd_config(compute::gpu_arch_t arch, dim_t head_size, + dim_t seq, bool is_thin_q, bool is_quantized, bool is_integrated, + bool is_fma, bool is_f32) { + const bool is_f16_accumulate = false; + compute::gpu_arch_t arch_query = (arch >= compute::gpu_arch_t::xe3) + ? 
compute::gpu_arch_t::xe2 + : arch; + property query_properties = set_properties(is_thin_q, is_quantized, + is_integrated, is_fma, is_f32, is_f16_accumulate); + + config_query_t query(arch_query, static_cast(head_size), + static_cast(seq), query_properties); + auto it = find(begin(sorted_bwd_configs), end(sorted_bwd_configs), query); + if (it != end(sorted_bwd_configs)) { + VDEBUGINFO(4, primitive, sdpa, + "bwd config search: {query %s} -> {%s config:%s},", + to_string(query).c_str(), to_string(it->criteria).c_str(), + to_string(it->config).c_str()); + return &it->config; + } + return nullptr; +} + void deserialize_config_to_gemmstone(micro::HWInformation &hwInfo, gemmstone::GEMMProblem &problem_kq, gemmstone::GEMMProblem &problem_vs, micro::GEMMOptions &opts_kq, micro::GEMMOptions &opts_vs, gemmstone::SizeParams &sizes_kq, gemmstone::SizeParams &sizes_vs, - const micro_ukernel_params_t &ukernel_config) { + const micro_fwd_ukernel_params_t &ukernel_config) { + + // hardware info + hwInfo.gmdid = ukernel_config.hwinfo.gmdid; + hwInfo.euCount = ukernel_config.hwinfo.euCount; + hwInfo.systolicAvailable = ukernel_config.hwinfo.systolicAvailable; + + // options kq, vs + auto deserialize_options + = [](micro::GEMMOptions &gemmstone_opts, + const ukernel_serialized_opts_t &serialized_opts) { + gemmstone_opts.localA = serialized_opts.localA; + gemmstone_opts.localB = serialized_opts.localB; + gemmstone_opts.slmPtr = serialized_opts.slmPtr; + gemmstone_opts.scaleA = serialized_opts.scaleA; + gemmstone_opts.offsetA = serialized_opts.offsetA; + }; + + deserialize_options(opts_kq, ukernel_config.opts_kq); + deserialize_options(opts_vs, ukernel_config.opts_vs); + + // problems kq, vs + auto deserialize_problem + = [](gemmstone::GEMMProblem &problem, + const ukernel_serialized_problem_t &serialized_problem) { + problem.Ta_ext = { + static_cast(serialized_problem.Ta_ext)}; + problem.Tb_ext = { + static_cast(serialized_problem.Tb_ext)}; + problem.Ta + = 
{static_cast(serialized_problem.Ta)}; + problem.Tb + = {static_cast(serialized_problem.Tb)}; + problem.Tc_ext = { + static_cast(serialized_problem.Tc_ext)}; + problem.Tc + = {static_cast(serialized_problem.Tc)}; + problem.Ts + = {static_cast(serialized_problem.Ts)}; + problem.A.layout = static_cast( + serialized_problem.A_layout); + + problem.Ta_scale = {static_cast( + serialized_problem.Ta_scale)}; + problem.A_scale.setAlignment(serialized_problem.A_scale_alignment); + problem.A_scale.layout = static_cast( + serialized_problem.A_scale_layout); + problem.asPtrDims = serialized_problem.asPtrDims; + problem.Tao + = {static_cast(serialized_problem.Tao)}; + problem.AO.setAlignment(serialized_problem.AO_alignment); + problem.AO.layout = static_cast( + serialized_problem.AO_layout); + problem.aoPtrDims = serialized_problem.aoPtrDims; + problem.aOffset + = static_cast(serialized_problem.aOffset); + problem.aqGroupM = serialized_problem.aqGroupM; + problem.aqGroupK = serialized_problem.aqGroupK; + + problem.B.layout = static_cast( + serialized_problem.B_layout); + problem.C.layout = static_cast( + serialized_problem.C_layout); + problem.A.setAlignment(serialized_problem.A_alignment); + problem.A.crosspack = serialized_problem.A_crosspack; + problem.A.tileR = serialized_problem.A_tileR; + problem.A.tileC = serialized_problem.A_tileC; + + problem.B.setAlignment(serialized_problem.B_alignment); + problem.B.crosspack = serialized_problem.B_crosspack; + problem.B.tileR = serialized_problem.B_tileR; + problem.B.tileC = serialized_problem.B_tileC; + }; + deserialize_problem(problem_kq, ukernel_config.problem_kq); + deserialize_problem(problem_vs, ukernel_config.problem_vs); + + // sizes kq, vs + auto deserialize_sizes + = [](gemmstone::SizeParams &sizes, + const ukernel_serialized_sizes_t &serialized_sizes) { + sizes.m = serialized_sizes.m; + sizes.n = serialized_sizes.n; + sizes.k = serialized_sizes.k; + sizes.batch = serialized_sizes.batch; + }; + deserialize_sizes(sizes_kq, 
ukernel_config.sizes_kq); + deserialize_sizes(sizes_vs, ukernel_config.sizes_vs); +} + +void deserialize_config_to_gemmstone(micro::HWInformation &hwInfo, + gemmstone::GEMMProblem &problem_kq, gemmstone::GEMMProblem &problem_vs, + gemmstone::GEMMProblem &problem_vtdA, + gemmstone::GEMMProblem &problem_ktq, + gemmstone::GEMMProblem &problem_qdSt, micro::GEMMOptions &opts_kq, + micro::GEMMOptions &opts_vs, micro::GEMMOptions &opts_vtdA, + micro::GEMMOptions &opts_ktq, micro::GEMMOptions &opts_qdSt, + gemmstone::SizeParams &sizes_kq, gemmstone::SizeParams &sizes_vs, + gemmstone::SizeParams &sizes_vtdA, gemmstone::SizeParams &sizes_ktq, + gemmstone::SizeParams &sizes_qdSt, + const micro_bwd_ukernel_params_t &ukernel_config) { // hardware info hwInfo.gmdid = ukernel_config.hwinfo.gmdid; @@ -684,6 +850,7 @@ void deserialize_config_to_gemmstone(micro::HWInformation &hwInfo, auto deserialize_options = [](micro::GEMMOptions &gemmstone_opts, const ukernel_serialized_opts_t &serialized_opts) { + gemmstone_opts.localA = serialized_opts.localA; gemmstone_opts.localB = serialized_opts.localB; gemmstone_opts.slmPtr = serialized_opts.slmPtr; gemmstone_opts.scaleA = serialized_opts.scaleA; @@ -691,6 +858,9 @@ void deserialize_config_to_gemmstone(micro::HWInformation &hwInfo, }; deserialize_options(opts_kq, ukernel_config.opts_kq); deserialize_options(opts_vs, ukernel_config.opts_vs); + deserialize_options(opts_vtdA, ukernel_config.opts_vtdA); + deserialize_options(opts_ktq, ukernel_config.opts_ktq); + deserialize_options(opts_qdSt, ukernel_config.opts_qdSt); // problems kq, vs auto deserialize_problem @@ -735,6 +905,10 @@ void deserialize_config_to_gemmstone(micro::HWInformation &hwInfo, problem.C.layout = static_cast( serialized_problem.C_layout); problem.A.setAlignment(serialized_problem.A_alignment); + problem.A.crosspack = serialized_problem.A_crosspack; + problem.A.tileR = serialized_problem.A_tileR; + problem.A.tileC = serialized_problem.A_tileC; + 
problem.B.setAlignment(serialized_problem.B_alignment); problem.B.crosspack = serialized_problem.B_crosspack; problem.B.tileR = serialized_problem.B_tileR; @@ -742,6 +916,9 @@ void deserialize_config_to_gemmstone(micro::HWInformation &hwInfo, }; deserialize_problem(problem_kq, ukernel_config.problem_kq); deserialize_problem(problem_vs, ukernel_config.problem_vs); + deserialize_problem(problem_vtdA, ukernel_config.problem_vtdA); + deserialize_problem(problem_ktq, ukernel_config.problem_ktq); + deserialize_problem(problem_qdSt, ukernel_config.problem_qdSt); // sizes kq, vs auto deserialize_sizes @@ -754,6 +931,9 @@ void deserialize_config_to_gemmstone(micro::HWInformation &hwInfo, }; deserialize_sizes(sizes_kq, ukernel_config.sizes_kq); deserialize_sizes(sizes_vs, ukernel_config.sizes_vs); + deserialize_sizes(sizes_vtdA, ukernel_config.sizes_vtdA); + deserialize_sizes(sizes_ktq, ukernel_config.sizes_ktq); + deserialize_sizes(sizes_qdSt, ukernel_config.sizes_qdSt); } } // namespace sdpa diff --git a/src/gpu/intel/sdpa/configs.hpp b/src/gpu/intel/sdpa/configs.hpp index 3994691fb96..301f0cd5c29 100644 --- a/src/gpu/intel/sdpa/configs.hpp +++ b/src/gpu/intel/sdpa/configs.hpp @@ -38,6 +38,15 @@ struct config_t { int wg_m_vs, wg_n_vs; // Workgroup configuration for V*S GEMM }; +struct bwd_config_t { + int unroll_m_BcBr, unroll_n_BcBr; // Subgroup tile sizes for Br*Bc GEMMs + int unroll_m_DBc, unroll_n_DBc; // Subgroup tile sizes for Bc*D GEMMs + int unroll_m_DBr, unroll_n_DBr; // Subgroup tile sizes for Br*D GEMMs + int wg_m_BcBr, wg_n_BcBr; // Workgroup configuration for Br*Bc GEMMs + int wg_m_DBc, wg_n_DBc; // Workgroup configuration for Bc*D GEMMs + int wg_m_DBr, wg_n_DBr; // Workgroup configuration for Br*D GEMMs +}; + enum class property : int { none = 0x0, second_token = 0x1, @@ -91,17 +100,31 @@ struct config_record_t { config_t config; }; +struct bwd_config_record_t { + config_criteria_t criteria; + bwd_config_t config; +}; + +// Common criteria matching: returns 
true if query matches key criteria +bool criteria_matches( + const config_criteria_t &key, const config_query_t &query); + std::ostream &operator<<(std::ostream &s, const config_query_t &q); std::ostream &operator<<(std::ostream &s, const config_criteria_t &c); std::ostream &operator<<(std::ostream &s, const config_t &c); bool operator==(const config_record_t &key, const config_query_t &query); +bool operator==(const bwd_config_record_t &key, const config_query_t &query); bool operator<(const config_criteria_t &lhs, const config_criteria_t &rhs); bool operator<(const config_record_t &lhs, const config_record_t &rhs); +bool operator<(const bwd_config_record_t &lhs, const bwd_config_record_t &rhs); config_t *choose_config(compute::gpu_arch_t arch, dim_t head_size, dim_t seq, bool is_thin_q, bool is_quantized, bool is_integrated, bool is_fma, bool is_f32, bool is_f16_accumulate); +bwd_config_t *choose_bwd_config(compute::gpu_arch_t arch, dim_t head_size, + dim_t seq, bool is_thin_q, bool is_quantized, bool is_integrated, + bool is_fma, bool is_f32, bool is_f16_accumulate); dim_t round_up_seq_interval(dim_t seq, compute::gpu_arch_t arch); dim_t nearest_conf_seq_interval(compute::gpu_arch_t arch, dim_t head_size, @@ -116,12 +139,13 @@ struct ukernel_serialized_opts_t ukernel_serialized_opts_t() = default; ukernel_serialized_opts_t(micro::GEMMOptions opts) - : localB(opts.localB) + : localA(opts.localA) + , localB(opts.localB) , slmPtr(opts.slmPtr) , scaleA(opts.scaleA) , offsetA(opts.offsetA) {} - bool localB, slmPtr, scaleA, offsetA; - uint8_t padding[4] = {0}; + bool localA, localB, slmPtr, scaleA, offsetA; + uint8_t padding[3] = {0}; }; DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(ukernel_serialized_opts_t); static_assert(sizeof(ukernel_serialized_opts_t) == 8, @@ -155,6 +179,8 @@ struct ukernel_serialized_sizes_t int64_t m = 0, n = 0, k = 0; }; DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(ukernel_serialized_sizes_t); +static_assert(sizeof(ukernel_serialized_sizes_t) == 32, + "Expected 
sizeof(ukernel_serialized_sizes_t) == 32"); struct ukernel_serialized_problem_t : trivially_serializable_t { @@ -171,8 +197,11 @@ struct ukernel_serialized_problem_t , A_layout(static_cast(problem.A.layout)) , B_layout(static_cast(problem.B.layout)) , C_layout(static_cast(problem.C.layout)) + , A_tileR(problem.A.tileR) , B_tileR(problem.B.tileR) + , A_tileC(problem.A.tileC) , B_tileC(problem.B.tileC) + , A_crosspack(problem.A.crosspack) , B_crosspack(problem.B.crosspack) , A_alignment(problem.A.alignment) , A_scale_alignment(problem.A_scale.alignment) @@ -195,18 +224,18 @@ struct ukernel_serialized_problem_t int B_layout; int C_layout; - uint16_t B_tileR; - uint16_t B_tileC; - uint8_t B_crosspack; + uint16_t A_tileR, B_tileR; + uint16_t A_tileC, B_tileC; + uint8_t A_crosspack, B_crosspack; uint8_t A_alignment; uint8_t A_scale_alignment; uint8_t AO_alignment; uint8_t B_alignment; // trivially serializable classes require alignment to 8-byte boundaries - // padding0 bumps class size from 49->56 bytes so uint8_t arguments + // padding0 bumps class size from 54->56 bytes so uint8_t arguments // related to alignment can be grouped together rather than placed at the end of the struct - uint8_t padding0[7] = {0}; + uint8_t padding0[2] = {0}; int asPtrDims; int aOffset; @@ -219,13 +248,14 @@ struct ukernel_serialized_problem_t int aoPtrDims; int aqGroupM; int aqGroupK; + uint8_t padding1[4] = {0}; }; DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(ukernel_serialized_problem_t); -static_assert(sizeof(ukernel_serialized_problem_t) == 92, - "Expected sizeof(ukernel_serialized_problem_t) == 92"); +static_assert(sizeof(ukernel_serialized_problem_t) == 96, + "Expected sizeof(ukernel_serialized_problem_t) == 96"); -struct micro_ukernel_params_t - : trivially_serializable_t { +struct micro_fwd_ukernel_params_t + : trivially_serializable_t { int unroll_m_kq, unroll_n_kq; int unroll_m_vs, unroll_n_vs; int wg_m_kq, wg_n_kq; @@ -238,13 +268,56 @@ struct micro_ukernel_params_t 
ukernel_serialized_opts_t opts_vs; ukernel_serialized_sizes_t sizes_kq, sizes_vs; }; -DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(micro_ukernel_params_t); +DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(micro_fwd_ukernel_params_t); + +struct micro_bwd_ukernel_params_t + : trivially_serializable_t { + int unroll_m_BcBr, unroll_n_BcBr; + int unroll_m_DBc, unroll_n_DBc; + int unroll_m_DBr, unroll_n_DBr; + + int wg_m_BcBr, wg_n_BcBr; + int wg_m_DBc, wg_n_DBc; + int wg_m_DBr, wg_n_DBr; + + ukernel_serialized_hwinfo_t hwinfo; + + ukernel_serialized_problem_t problem_kq; + ukernel_serialized_problem_t problem_vs; + ukernel_serialized_problem_t problem_vtdA; + ukernel_serialized_problem_t problem_ktq; + ukernel_serialized_problem_t problem_qdSt; + + ukernel_serialized_opts_t opts_kq; + ukernel_serialized_opts_t opts_vs; + ukernel_serialized_opts_t opts_vtdA; + ukernel_serialized_opts_t opts_ktq; + ukernel_serialized_opts_t opts_qdSt; + + ukernel_serialized_sizes_t sizes_kq, sizes_vs; + ukernel_serialized_sizes_t sizes_vtdA; + ukernel_serialized_sizes_t sizes_ktq; + ukernel_serialized_sizes_t sizes_qdSt; +}; +DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(micro_bwd_ukernel_params_t); void deserialize_config_to_gemmstone(micro::HWInformation &hwInfo, gemmstone::GEMMProblem &problem_kq, gemmstone::GEMMProblem &problem_vs, micro::GEMMOptions &opts_kq, micro::GEMMOptions &opts_vs, gemmstone::SizeParams &sizes_kq, gemmstone::SizeParams &sizes_vs, - const micro_ukernel_params_t &ukernel_config); + const micro_fwd_ukernel_params_t &ukernel_config); + +void deserialize_config_to_gemmstone(micro::HWInformation &hwInfo, + gemmstone::GEMMProblem &problem_kq, gemmstone::GEMMProblem &problem_vs, + gemmstone::GEMMProblem &problem_vtdA, + gemmstone::GEMMProblem &problem_ktq, + gemmstone::GEMMProblem &problem_qdSt, micro::GEMMOptions &opts_kq, + micro::GEMMOptions &opts_vs, micro::GEMMOptions &opts_vtdA, + micro::GEMMOptions &opts_ktq, micro::GEMMOptions &opts_qdSt, + gemmstone::SizeParams &sizes_kq, gemmstone::SizeParams 
&sizes_vs, + gemmstone::SizeParams &sizes_vtdA, gemmstone::SizeParams &sizes_ktq, + gemmstone::SizeParams &sizes_qdSt, + const micro_bwd_ukernel_params_t &ukernel_config); } // namespace sdpa } // namespace intel diff --git a/src/gpu/intel/sdpa/micro.cpp b/src/gpu/intel/sdpa/micro.cpp index eb41be65927..38f3a1a999a 100644 --- a/src/gpu/intel/sdpa/micro.cpp +++ b/src/gpu/intel/sdpa/micro.cpp @@ -177,7 +177,7 @@ status_t micro_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { // serializable minimal set of configuration params for ukernels // will be used to generate shim ukernels in reusable kernel_ctx - micro_ukernel_params_t ukernel_params; + micro_fwd_ukernel_params_t ukernel_params; ukernel_params.unroll_m_kq = config->unroll_m_kq; ukernel_params.unroll_n_kq = config->unroll_n_kq; ukernel_params.unroll_m_vs = config->unroll_m_vs; diff --git a/src/gpu/intel/sdpa/micro.hpp b/src/gpu/intel/sdpa/micro.hpp index 33a14f27819..55efdb5729e 100644 --- a/src/gpu/intel/sdpa/micro.hpp +++ b/src/gpu/intel/sdpa/micro.hpp @@ -93,7 +93,7 @@ struct micro_params_t : trivially_serializable_t { bool require_stateless_addressing; uint8_t padding3[6] = {0}; - micro_ukernel_params_t ukernel_config; + micro_fwd_ukernel_params_t ukernel_config; }; DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(micro_params_t); From 88ed3bca46aaf59c4bce17596d9974bd2dd70e1b Mon Sep 17 00:00:00 2001 From: syurkevi Date: Fri, 6 Mar 2026 16:54:17 -0800 Subject: [PATCH 05/23] gpu: gemm: jit: allow LLR gemms --- src/gpu/intel/gemm/jit/generator/microkernel_selector.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/gpu/intel/gemm/jit/generator/microkernel_selector.cpp b/src/gpu/intel/gemm/jit/generator/microkernel_selector.cpp index 4bc3de0f4ba..213d7a19508 100644 --- a/src/gpu/intel/gemm/jit/generator/microkernel_selector.cpp +++ b/src/gpu/intel/gemm/jit/generator/microkernel_selector.cpp @@ -237,10 +237,6 @@ Package selectGEMM(const GEMMOptions &options, HWInformation hwInfo, SizeParams evalParams.beta 
= 0; evalParams.euCount = hwInfo.euCount; - /* Locate appropriate kernel catalog */ - if (localA && localB) - stub("Unsupported protocol"); - /* Generate interface */ InterfaceHandler interface = effOptions.generateInterface(hw); From 868dd363062e661fb6a624f62e377ed95caef267 Mon Sep 17 00:00:00 2001 From: syurkevi Date: Fri, 6 Mar 2026 16:58:08 -0800 Subject: [PATCH 06/23] gpu: gemm: jit: enables ukernel with multiple kernel sources --- .../gemm/jit/generator/microkernel/fuser.cpp | 151 +++++++++--------- 1 file changed, 79 insertions(+), 72 deletions(-) diff --git a/src/gpu/intel/gemm/jit/generator/microkernel/fuser.cpp b/src/gpu/intel/gemm/jit/generator/microkernel/fuser.cpp index 728d337d843..7d19b42e1d9 100644 --- a/src/gpu/intel/gemm/jit/generator/microkernel/fuser.cpp +++ b/src/gpu/intel/gemm/jit/generator/microkernel/fuser.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include "generator/microkernel/elf.hpp" @@ -50,9 +51,8 @@ void fuse(std::vector &binary, "IGC did not generate a valid zebin program binary"); bool foundZeInfo = false; - SectionHeader *text = nullptr; const char *snames = nullptr; - int textSectionID = -1, relSectionID = -1; + std::vector> textSections; auto *sheaders = reinterpret_cast( base + fheaderPtr->sectionTableOff); @@ -70,104 +70,111 @@ void fuse(std::vector &binary, continue; if (sname.substr(0, 6) != ".text.") continue; } - if (text) - throw std::runtime_error("Multiple kernels in program"); - text = sheaders + s; - textSectionID = s; + textSections.emplace_back(sheaders + s, s); break; } default: break; } } - if (!foundZeInfo || !text || text->offset + text->size > bytes) + if (!foundZeInfo || textSections.empty()) throw std::runtime_error( "IGC did not generate a valid zebin program binary"); - std::string rname = ".rel"; - rname += (snames + text->name); - for (int s = 0; s < fheaderPtr->sectionCount; s++) { - if (sheaders[s].type != SectionHeader::Type::Relocation) continue; - if (rname != (snames + 
sheaders[s].name)) continue; - if (relSectionID >= 0) - throw std::runtime_error("Multiple relocation sections for kernel"); - relSectionID = s; - } - - auto *insn = reinterpret_cast(base + text->offset); - auto *iend = reinterpret_cast( - base + text->offset + text->size); - - const uint8_t *spliceStart = nullptr; - const uint8_t *spliceEnd = nullptr; - - for (; insn < iend; insn += 4) { - if (insn[0] & (1u << 29)) - insn -= 2; - else if (insn[3] == (sigilStart ^ id)) - spliceStart = reinterpret_cast(insn); - else if (insn[3] == (sigilEnd ^ id)) { - spliceEnd = reinterpret_cast(insn); - break; + for (auto &entry : textSections) { + auto *text = entry.first; + int textSectionID = entry.second; + if (text->offset + text->size > bytes) continue; + + auto *insn = reinterpret_cast(base + text->offset); + auto *iend = reinterpret_cast( + base + text->offset + text->size); + + const uint8_t *spliceStart = nullptr; + const uint8_t *spliceEnd = nullptr; + + for (; insn < iend; insn += 4) { + if (insn[0] & (1u << 29)) + insn -= 2; + else if (insn[3] == (sigilStart ^ id)) + spliceStart = reinterpret_cast(insn); + else if (insn[3] == (sigilEnd ^ id)) { + spliceEnd = reinterpret_cast(insn); + break; + } } - } - if (!spliceStart || !spliceEnd) return; + if (!spliceStart || !spliceEnd) continue; + + int relSectionID = -1; + std::string rname = ".rel"; + rname += (snames + text->name); + for (int s = 0; s < fheaderPtr->sectionCount; s++) { + if (sheaders[s].type != SectionHeader::Type::Relocation) continue; + if (rname != (snames + sheaders[s].name)) continue; + if (relSectionID >= 0) + throw std::runtime_error( + "Multiple relocation sections for kernel"); + relSectionID = s; + } - auto removeBytes = spliceEnd - spliceStart + 16; + auto removeBytes = spliceEnd - spliceStart + 16; - size_t before = spliceStart - base; - auto after = bytes - before - removeBytes; - ptrdiff_t sizeAdjust = microkernel.size() - removeBytes; + size_t before = spliceStart - base; + auto after = bytes - 
before - removeBytes; + ptrdiff_t sizeAdjust = microkernel.size() - removeBytes; - auto kbefore = before - text->offset; - auto kafter = text->size - kbefore - removeBytes; + auto kbefore = before - text->offset; + auto kafter = text->size - kbefore - removeBytes; - std::vector newBinary(bytes + sizeAdjust); - auto newBase = newBinary.data(); + std::vector newBinary(bytes + sizeAdjust); + auto newBase = newBinary.data(); - memmove(newBase, base, before); - memmove(newBase + before, microkernel.data(), microkernel.size()); - memmove(newBase + before + microkernel.size(), spliceStart + removeBytes, - after); + memmove(newBase, base, before); + memmove(newBase + before, microkernel.data(), microkernel.size()); + memmove(newBase + before + microkernel.size(), + spliceStart + removeBytes, after); - fixupJumpTargets(newBase + text->offset, kbefore, +sizeAdjust); - fixupJumpTargets( - newBase + before + microkernel.size(), kafter, -sizeAdjust); + fixupJumpTargets(newBase + text->offset, kbefore, +sizeAdjust); + fixupJumpTargets( + newBase + before + microkernel.size(), kafter, -sizeAdjust); - fheaderPtr = reinterpret_cast(newBase); + fheaderPtr = reinterpret_cast(newBase); - if (fheaderPtr->sectionTableOff > before) - fheaderPtr->sectionTableOff += sizeAdjust; + if (fheaderPtr->sectionTableOff > before) + fheaderPtr->sectionTableOff += sizeAdjust; - sheaders = reinterpret_cast( - newBase + fheaderPtr->sectionTableOff); - sheaders[textSectionID].size += sizeAdjust; - for (int s = 0; s < fheaderPtr->sectionCount; s++) - if (sheaders[s].offset > before) sheaders[s].offset += sizeAdjust; + sheaders = reinterpret_cast( + newBase + fheaderPtr->sectionTableOff); + sheaders[textSectionID].size += sizeAdjust; + for (int s = 0; s < fheaderPtr->sectionCount; s++) + if (sheaders[s].offset > before) sheaders[s].offset += sizeAdjust; - if (relSectionID >= 0) { - auto relSection = sheaders + relSectionID; - auto rel = reinterpret_cast(newBase + relSection->offset); - auto relEnd = 
reinterpret_cast( - newBase + relSection->offset + relSection->size); - for (; rel < relEnd; rel++) { - if (rel->offset >= kbefore) rel->offset += sizeAdjust; + if (relSectionID >= 0) { + auto relSection = sheaders + relSectionID; + auto rel = reinterpret_cast( + newBase + relSection->offset); + auto relEnd = reinterpret_cast( + newBase + relSection->offset + relSection->size); + for (; rel < relEnd; rel++) { + if (rel->offset >= kbefore) rel->offset += sizeAdjust; + } } - } #ifdef SPLICE_DEBUG - std::ofstream dump0("original.bin"); - dump0.write((const char *)binary.data(), binary.size()); + std::ofstream dump0("original." + std::to_string(id) + ".bin"); + dump0.write((const char *)binary.data(), binary.size()); - std::ofstream dump("patched.bin"); - dump.write((const char *)newBinary.data(), newBinary.size()); + std::ofstream dump("patched." + std::to_string(id) + ".bin"); + dump.write((const char *)newBinary.data(), newBinary.size()); #endif - std::swap(binary, newBinary); + std::swap(binary, newBinary); - // Tail-recurse to handle any further instances of this microkernel - fuse(binary, microkernel, id); + // Tail-recurse to handle any further instances of this microkernel + fuse(binary, microkernel, id); + return; + } } void fuse(std::vector &binary, const char *source) { From 9170719c90a3cff0807b4c33c800da98e43c8eda Mon Sep 17 00:00:00 2001 From: syurkevi Date: Tue, 24 Feb 2026 16:51:31 -0800 Subject: [PATCH 07/23] xe: sdpa: splits fwd/bwd gpu training primitives --- src/gpu/intel/sdpa/micro.cpp | 2 +- src/gpu/intel/sdpa/micro.hpp | 255 +++++++++++++++++++++++++++++++++-- 2 files changed, 245 insertions(+), 12 deletions(-) diff --git a/src/gpu/intel/sdpa/micro.cpp b/src/gpu/intel/sdpa/micro.cpp index 38f3a1a999a..688779db3b8 100644 --- a/src/gpu/intel/sdpa/micro.cpp +++ b/src/gpu/intel/sdpa/micro.cpp @@ -728,7 +728,7 @@ status_t micro_params_t::get_kernel_ctx( return status::success; } -status_t micro_t::execute(const exec_ctx_t &ctx) const { +status_t 
micro_t::execute_forward(const exec_ctx_t &ctx) const { const auto &conf = pd()->conf; const auto &qry = CTX_IN_STORAGE(DNNL_ARG_QUERIES); diff --git a/src/gpu/intel/sdpa/micro.hpp b/src/gpu/intel/sdpa/micro.hpp index 55efdb5729e..da0be1370b2 100644 --- a/src/gpu/intel/sdpa/micro.hpp +++ b/src/gpu/intel/sdpa/micro.hpp @@ -39,8 +39,9 @@ namespace sdpa { struct micro_params_t : trivially_serializable_t { const std::vector &get_kernel_names() const { - static const std::vector kernel_names = {"micro_sdpa"}; - return kernel_names; + static const std::vector kernel_names_fwd + = {"micro_sdpa"}; + return kernel_names_fwd; } status_t create_generator(const intel::engine_t &engine, @@ -91,26 +92,77 @@ struct micro_params_t : trivially_serializable_t { bool use_systolic_ukernel; bool kq_f16_accumulate, vs_f16_accumulate; bool require_stateless_addressing; - uint8_t padding3[6] = {0}; + bool is_training; + uint8_t padding3[5] = {0}; micro_fwd_ukernel_params_t ukernel_config; }; DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(micro_params_t); +struct micro_bwd_params_t : trivially_serializable_t { + + const std::vector &get_kernel_names() const { + static const std::vector kernel_names_bwd = { + "preprocess_Di", "micro_sdpa_bwd", "postprocess_dQ", "zero_dQ"}; + return kernel_names_bwd; + } + + status_t create_generator(const intel::engine_t &engine, + compute::kernel_bundle_t &bundle) const { + compute::kernel_ctx_t kernel_ctx; + CHECK(get_kernel_ctx(kernel_ctx)); + auto status = engine.create_kernel_bundle( + bundle, get_kernel_names(), kernel_ctx); + return status; + } + + status_t get_kernel_ctx(compute::kernel_ctx_t &) const; + + int ndims; + int kv_group_size; + data_type_t data_t; + data_type_t dst_data_t, key_data_t, qry_data_t, val_data_t, msk_data_t; + + int q_align, k_align, v_align, a_align; + bool transpose_k; + uint8_t padding0[3] = {0}; + + int key_group_size, val_group_size; + data_type_t scale_data_t; + + int attn_mask_undef, attn_mask_buffer, attn_mask_top_left, + 
attn_mask_bottom_right; + bool invert_scale, with_attn_scale, with_host_scale, with_attn_mask, + broadcast_mask_q, with_causal_mask; + uint8_t padding1[2] = {0}; + int subgroup_size, d_max; + + bool d_full, arch_gte_hpc; + bool block_k, block_dK, block_dV; + bool prefetch_mask, prefetch_k0, prefetch_k, prefetch_v, + prefetch_remainder; // TODO: prefetch for bwd + bool remainder_q; + bool use_systolic_ukernel; + bool with_dS; + uint8_t padding2[3] = {0}; + int prefetch_d_max; + uint8_t padding3[4] = {0}; + + micro_bwd_ukernel_params_t ukernel_config; +}; +DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(micro_bwd_params_t); + struct micro_t : public primitive_t { using primitive_t::primitive_t; - struct pd_t : public sdpa::pd_t { - using sdpa::pd_t::pd_t; - static constexpr int mask_mb_index = 0; - static constexpr int mask_q_index = 2; - static constexpr int mask_k_index = 3; - static constexpr int ndims = 4; + struct pd_t : public sdpa_fwd_pd_t { + using sdpa_fwd_pd_t::sdpa_fwd_pd_t; DECLARE_COMMON_PD_T("ocl:micro:reusable", micro_t); status_t init(impl::engine_t *engine) { using namespace data_type; + VCHECK_SDPA_COND(is_fwd(), VERBOSE_BAD_PROPKIND); VCHECK_SDPA_COND( utils::everyone_is(4, qry_md()->ndims, key_md()->ndims, val_md()->ndims, dst_md()->ndims), @@ -257,7 +309,6 @@ struct micro_t : public primitive_t { CHECK(init_conf_microkernels(engine)); CHECK(init_conf(engine)); - VCHECK_SDPA_COND( IMPLICATION((arch() == compute::gpu_arch_t::xe_hpc) && (qry_md()->data_type == data_type::f32), @@ -311,13 +362,195 @@ struct micro_t : public primitive_t { status_t init(impl::engine_t *engine) override; + status_t execute(const exec_ctx_t &ctx) const override { + return execute_forward(ctx); + } + private: const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } - status_t execute(const exec_ctx_t &ctx) const override; + status_t execute_forward(const exec_ctx_t &ctx) const; compute::kernel_t kernel_; }; +struct micro_bwd_t : public primitive_t { + + using 
primitive_t::primitive_t; + struct pd_t : public sdpa_bwd_pd_t { + using sdpa_bwd_pd_t::sdpa_bwd_pd_t; + + DECLARE_COMMON_PD_T("ocl:micro:reusable", micro_bwd_t); + + status_t init(impl::engine_t *engine) { + using namespace data_type; + + VCHECK_SDPA_COND(!is_fwd(), VERBOSE_BAD_PROPKIND); + + VCHECK_SDPA_COND( + utils::everyone_is(4, qry_md()->ndims, key_md()->ndims, + val_md()->ndims, dst_md()->ndims), + VERBOSE_UNSUPPORTED_TAG); + VCHECK_SDPA_COND(utils::everyone_is(4, diff_qry_md()->ndims, + diff_key_md()->ndims, diff_val_md()->ndims, + diff_dst_md()->ndims), + VERBOSE_UNSUPPORTED_TAG); + if (with_attn_mask()) { + VCHECK_SDPA_COND( + attn_mask_md()->ndims == 4, VERBOSE_UNSUPPORTED_TAG); + VCHECK_SDPA_COND( + utils::one_of(attn_mask_md()->dims[mask_q_index], + desc()->queries(), 1), + VERBOSE_INVALID_BROADCAST, "attn_mask", mask_q_index); + VCHECK_SDPA_COND( + attn_mask_md()->dims[mask_k_index] == desc()->keys(), + VERBOSE_INVALID_BROADCAST, "attn_mask", mask_k_index); + VCHECK_SDPA_COND( + attn_mask_md()->data_type == qry_md()->data_type, + "Mask data type should match Qry/Dst data type."); + } + VCHECK_SDPA_COND( + (utils::everyone_is(data_type::f16, qry_md()->data_type, + dst_md()->data_type) + || utils::everyone_is(data_type::bf16, + qry_md()->data_type, dst_md()->data_type) + || utils::everyone_is(data_type::f32, + qry_md()->data_type, dst_md()->data_type)), + VERBOSE_UNSUPPORTED_DT); + VCHECK_SDPA_COND(utils::one_of(key_md()->data_type, f32, bf16, f16), + VERBOSE_UNSUPPORTED_DT); + VCHECK_SDPA_COND(utils::one_of(val_md()->data_type, f32, bf16, f16), + VERBOSE_UNSUPPORTED_DT); + VCHECK_SDPA_COND(set_default_formats() == status::success, + VERBOSE_UNSUPPORTED_TAG); + VCHECK_SDPA_COND(desc()->values() == desc()->head_size(), + "values does not match head size"); + + VCHECK_SDPA_COND(qry_md()->dims[1] >= key_md()->dims[1] + && qry_md()->dims[1] >= val_md()->dims[1], + "number of heads in query tensor(%ld) must be greater " + "than the number of heads in the 
key(%ld) and value(%ld) " + "tensors", + static_cast(qry_md()->dims[1]), + static_cast(key_md()->dims[1]), + static_cast(val_md()->dims[1])); + { + memory_desc_wrapper diff_qry_mdw(diff_qry_md()); + memory_desc_wrapper diff_key_mdw(diff_key_md()); + memory_desc_wrapper diff_val_mdw(diff_val_md()); + memory_desc_wrapper diff_dst_mdw(diff_dst_md()); + VCHECK_SDPA_COND( + utils::everyone_is(true, diff_qry_mdw.is_plain(), + diff_key_mdw.is_plain(), + diff_val_mdw.is_plain(), + diff_dst_mdw.is_plain()), + VERBOSE_UNSUPPORTED_TAG); + } + + // make sure gradient outputs match input dimensions + for (int i = 0; i < qry_md()->ndims; i++) + VCHECK_SDPA_COND(diff_qry_md()->dims[i] == qry_md()->dims[i], + "diff_qry dim[%d](%ld) must match qry dim[%d](%ld)", i, + (long)diff_qry_md()->dims[i], i, + (long)qry_md()->dims[i]); + for (int i = 0; i < key_md()->ndims; i++) + VCHECK_SDPA_COND(diff_key_md()->dims[i] == key_md()->dims[i], + "diff_key dim[%d](%ld) must match key dim[%d](%ld)", i, + (long)diff_key_md()->dims[i], i, + (long)key_md()->dims[i]); + for (int i = 0; i < val_md()->ndims; i++) + VCHECK_SDPA_COND(diff_val_md()->dims[i] == val_md()->dims[i], + "diff_val dim[%d](%ld) must match val dim[%d](%ld)", i, + (long)diff_val_md()->dims[i], i, + (long)val_md()->dims[i]); + // dO.dims() == O.dims() + for (int i = 0; i < src_md(4)->ndims; i++) + VCHECK_SDPA_COND(diff_dst_md()->dims[i] == src_md(4)->dims[i], + "diff_dst dim[%d](%ld) must match dst dim[%d](%ld)", i, + (long)diff_dst_md()->dims[i], i, + (long)src_md(4)->dims[i]); + + VCHECK_SDPA_COND( + utils::everyone_is(qry_md()->data_type, + diff_qry_md()->data_type, diff_key_md()->data_type, + diff_val_md()->data_type, diff_dst_md()->data_type), + "diff tensor data types must match qry data type(%s) " + " ?= dQ(%s), dK(%s), dV(%s), dO(%s)", + dnnl_dt2str(qry_md()->data_type), + dnnl_dt2str(diff_qry_md()->data_type), + dnnl_dt2str(diff_key_md()->data_type), + dnnl_dt2str(diff_val_md()->data_type), + 
dnnl_dt2str(diff_dst_md()->data_type)); + + init_default_ws(); + VCHECK_SDPA_COND(compare_ws(hint_fwd_pd_), VERBOSE_WS_MISMATCH); + + CHECK(init_conf_microkernels(engine)); + CHECK(init_conf(engine)); + CHECK(init_scratchpad(engine)); + + return status::success; + } + + status_t set_default_format(memory_desc_t &md, bool allow_transpose) { + using namespace format_tag; + memory_desc_wrapper mdw(md); + if (mdw.format_any()) return status::unimplemented; + if (!is_md_gemm_compatible_plain_format(&md)) + return status::unimplemented; + if (gemm_desc_t::get_trans(md) == dnnl_trans && !allow_transpose) + return status::unimplemented; + return status::success; + } + + status_t set_default_formats() { + CHECK(set_default_format(desc_.q_desc, false)); + CHECK(set_default_format(desc_.k_desc, true)); + CHECK(set_default_format(desc_.v_desc, false)); + CHECK(set_default_format(desc_.dst_desc, false)); + CHECK(set_default_format(desc_.diff_dst_desc, false)); + CHECK(set_default_format(desc_.diff_q_desc, false)); + CHECK(set_default_format(desc_.diff_k_desc, false)); + CHECK(set_default_format(desc_.diff_v_desc, false)); + return status::success; + } + + int sg_size() const { return sg_size_; } + bool use_systolic_ukernel() const { return use_systolic_ukernel_; } + + // Block size for head_size, which must be hard-coded into the kernel. 
+ int d_max() const { + int head_size = into(desc()->head_size()); + for (int i = 32; i <= 1024; i *= 2) + if (head_size <= i) return i; + return head_size; + } + + compute::gpu_arch_t arch() const { return arch_; } + micro_bwd_params_t conf; + + private: + int sg_size_ = 0; + bool use_systolic_ukernel_ = true; + compute::gpu_arch_t arch_ = compute::gpu_arch_t::unknown; + + status_t init_scratchpad(impl::engine_t *engine); + status_t init_conf_microkernels(impl::engine_t *engine); + status_t init_conf(impl::engine_t *engine); + }; + + status_t init(impl::engine_t *engine) override; + + status_t execute(const exec_ctx_t &ctx) const override { + return execute_backward(ctx); + } + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } + status_t execute_backward(const exec_ctx_t &ctx) const; + + compute::kernel_t kernel_, preprocess_, postprocess_, zero_; +}; + } // namespace sdpa } // namespace intel } // namespace gpu From 52114cc5a7d1f3672c8325e047740721c7761c01 Mon Sep 17 00:00:00 2001 From: syurkevi Date: Tue, 24 Feb 2026 21:59:50 -0800 Subject: [PATCH 08/23] xe: sdpa: split gemm setup for backwards pass --- src/gpu/gpu_sdpa_list.cpp | 21 +- src/gpu/intel/sdpa/configs.hpp | 2 +- src/gpu/intel/sdpa/micro.cpp | 1122 ++++++++++++++++++++++++++++++-- src/gpu/intel/sdpa/micro.hpp | 3 +- 4 files changed, 1086 insertions(+), 62 deletions(-) diff --git a/src/gpu/gpu_sdpa_list.cpp b/src/gpu/gpu_sdpa_list.cpp index 0b9d63bc43a..a99336ab66d 100644 --- a/src/gpu/gpu_sdpa_list.cpp +++ b/src/gpu/gpu_sdpa_list.cpp @@ -28,19 +28,34 @@ namespace impl { namespace gpu { namespace { +using namespace dnnl::impl::prop_kind; // clang-format off -constexpr impl_list_item_t impl_list[] = REG_SDPA_P({ +const std::map> + impl_list_map REG_SDPA_P({ + {{forward}, { GPU_INSTANCE_INTEL(intel::sdpa::micro_t) GPU_INSTANCE_INTEL_DEVMODE(intel::sdpa::ref_t) nullptr, + }}, + {{backward}, REG_BWD_PK({ + GPU_INSTANCE_INTEL(intel::sdpa::micro_bwd_t) + nullptr, + })}, }); 
// clang-format on } // namespace const impl_list_item_t *get_sdpa_impl_list(const sdpa_desc_t *desc) { - UNUSED(desc); - return impl_list; + static const impl_list_item_t empty_list[] = {nullptr}; + + const bool is_fwd = utils::one_of( + desc->prop_kind, forward_training, forward_inference); + prop_kind_t prop_kind = is_fwd ? forward : backward; + + const auto impl_list_it = impl_list_map.find({prop_kind}); + return impl_list_it != impl_list_map.cend() ? impl_list_it->second.data() + : empty_list; } } // namespace gpu diff --git a/src/gpu/intel/sdpa/configs.hpp b/src/gpu/intel/sdpa/configs.hpp index 301f0cd5c29..6b15fcd6570 100644 --- a/src/gpu/intel/sdpa/configs.hpp +++ b/src/gpu/intel/sdpa/configs.hpp @@ -124,7 +124,7 @@ config_t *choose_config(compute::gpu_arch_t arch, dim_t head_size, dim_t seq, bool is_f32, bool is_f16_accumulate); bwd_config_t *choose_bwd_config(compute::gpu_arch_t arch, dim_t head_size, dim_t seq, bool is_thin_q, bool is_quantized, bool is_integrated, - bool is_fma, bool is_f32, bool is_f16_accumulate); + bool is_fma, bool is_f32); dim_t round_up_seq_interval(dim_t seq, compute::gpu_arch_t arch); dim_t nearest_conf_seq_interval(compute::gpu_arch_t arch, dim_t head_size, diff --git a/src/gpu/intel/sdpa/micro.cpp b/src/gpu/intel/sdpa/micro.cpp index 688779db3b8..ca5d550b0c3 100644 --- a/src/gpu/intel/sdpa/micro.cpp +++ b/src/gpu/intel/sdpa/micro.cpp @@ -100,6 +100,47 @@ status_t update_config_from_devenv_values(config_t *config, bool quantized) { return status::success; } +status_t update_config_from_devenv_values(bwd_config_t *config) { + std::string bwd_config_str + = gpu_utils::dev_getenv("BWD_SDPA_CONFIG", std::string("")); + if (!bwd_config_str.empty()) { + std::array config_values; + int i; + int num_values = 0; + + stringstream_t ss(bwd_config_str); + while (ss >> i) { + config_values[num_values++] = i; + if (ss.peek() == ',') ss.ignore(); + } + VCHECK_SDPA_COND(num_values == 12, + "BWD_SDPA_CONFIG(%s) is invalid. 
Must be 12 integers " + "separate by a comma: " + ",," + ",," + ",," + ",," + ",," + ",", + bwd_config_str.c_str()); + if (num_values == 12) { + config->unroll_m_BcBr = config_values[0]; + config->unroll_n_BcBr = config_values[1]; + config->unroll_m_DBc = config_values[2]; + config->unroll_n_DBc = config_values[3]; + config->unroll_m_DBr = config_values[4]; + config->unroll_n_DBr = config_values[5]; + config->wg_m_BcBr = config_values[6]; + config->wg_n_BcBr = config_values[7]; + config->wg_m_DBc = config_values[8]; + config->wg_n_DBc = config_values[9]; + config->wg_m_DBr = config_values[10]; + config->wg_n_DBr = config_values[11]; + } + } + return status::success; +} + status_t micro_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { using namespace jit; using gemm::jit::convert_dnnl_to_kernel_type; @@ -182,6 +223,7 @@ status_t micro_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { ukernel_params.unroll_n_kq = config->unroll_n_kq; ukernel_params.unroll_m_vs = config->unroll_m_vs; ukernel_params.unroll_n_vs = config->unroll_n_vs; + ukernel_params.wg_m_kq = config->wg_m_kq; ukernel_params.wg_n_kq = config->wg_n_kq; ukernel_params.wg_m_vs = config->wg_m_vs; @@ -226,6 +268,7 @@ status_t micro_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { problem.Ts = problem.Tc; auto problem_kq = problem; + problem_kq.Tc = problem_kq.Ts = (kq_acc_dt() == data_type::f16) ? 
Type::f16 : Type::f32; @@ -261,7 +304,7 @@ status_t micro_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { const memory_desc_wrapper key_mdw(key_md()); auto ldk = static_cast( gemm_desc_t::get_ld(*key_md()) * key_mdw.data_type_size()); - problem_kq.A.setAlignment(micro::alignmentForLD(ldk)); + problem_kq.A.setAlignment(micro::alignmentForLD(int(ldk))); problem_kq.B.setAlignment(64); // Q is packed in VNNI format in SLM if (use_systolic_ukernel()) { problem_kq.B.crosspack = 2; @@ -282,7 +325,7 @@ status_t micro_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { /* Set up problem size information */ SizeParams heuristic_sizes; - // quanatizing sizes to large intervals allows kernel + // quantizing sizes to large intervals allows kernel // selection search while avoiding recompilation for every new size heuristic_sizes.m = nearest_conf_seq_interval(arch_, d->head_size(), d->keys(), thin_q, quantized, is_integrated, use_fma_config, is_f32, @@ -337,7 +380,7 @@ status_t micro_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { const memory_desc_wrapper val_mdw(val_md()); auto ldv = static_cast( gemm_desc_t::get_ld(*val_md()) * val_mdw.data_type_size()); - problem_vs.A.setAlignment(micro::alignmentForLD(ldv)); + problem_vs.A.setAlignment(micro::alignmentForLD(int(ldv))); problem_vs.B.setAlignment(64); // S is packed in SLM if (use_systolic_ukernel()) { problem_vs.B.crosspack = 16; } @@ -367,17 +410,371 @@ status_t micro_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { return status::success; } +status_t micro_bwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { + using namespace jit; + using gemm::jit::convert_dnnl_to_kernel_type; + + assert(engine->kind() == engine_kind::gpu); + auto *intel_engine = utils::downcast(engine); + auto *dev_info = intel_engine->device_info(); + arch_ = dev_info->gpu_arch(); + auto *d = desc(); + + VCHECK_SDPA_COND(compute::mayiuse_microkernels(intel_engine), + "Microkernels not supported by the OpenCL 
driver."); + + /* Retrieve pre-tuned kernel configuration */ + bwd_config_t *config = nullptr; + const dim_t thin_q_threshold = 16; + auto queries = d->queries(); + // TODO: q=1 batch group optimizations + // if (queries == 1) { queries = (d->q_desc.dims[1] / d->kv_head_number); } + + bool thin_q = (queries <= thin_q_threshold); + bool quantized = false; + bool is_integrated = intel_engine->device_info()->is_integrated(); + bool is_f32 = (qry_md()->data_type == data_type::f32); + use_systolic_ukernel_ + = intel_engine->mayiuse(compute::device_ext_t:: + intel_subgroup_matrix_multiply_accumulate) + && !is_f32; // f32 -> non-systolic kernel only + + bool use_fma_config = !use_systolic_ukernel_; + config = choose_bwd_config(arch_, d->head_size(), d->keys(), thin_q, + quantized, is_integrated, use_fma_config, is_f32); + + VCHECK_SDPA_COND(config != nullptr, + "No suitable kernel configuration found for the given problem " + "size and attributes."); + + CHECK(update_config_from_devenv_values(config)); + + VDEBUGINFO(4, primitive, sdpa, + "D=%d,K=%d,%s%s%s" + "BcBr_tile(%d, %d): unroll_m=%d unroll_n=%d wg_m=%d wg_n=%d," + "DBc_tile(%d, %d): unroll_m=%d unroll_n=%d wg_m=%d wg_n=%d" + "DBr_tile(%d, %d): unroll_m=%d unroll_n=%d wg_m=%d wg_n=%d", + static_cast(d->head_size()), static_cast(d->keys()), + thin_q ? "thin_q," : "", quantized ? "quant," : "", + is_integrated ? 
"integrated" : "", + config->unroll_m_BcBr * config->wg_m_BcBr, + config->unroll_n_BcBr * config->wg_n_BcBr, config->unroll_m_BcBr, + config->unroll_n_BcBr, config->wg_m_BcBr, config->wg_n_BcBr, + config->unroll_m_DBc * config->wg_m_DBc, + config->unroll_n_DBc * config->wg_n_DBc, config->unroll_m_DBc, + config->unroll_n_DBc, config->wg_m_DBc, config->wg_n_DBc, + config->unroll_m_DBr * config->wg_m_DBr, + config->unroll_n_DBr * config->wg_n_DBr, config->unroll_m_DBr, + config->unroll_n_DBr, config->wg_m_DBr, config->wg_n_DBr); + + // Bc(Br) == (D)Bc + VCHECK_SDPA_COND( + ((config->unroll_m_BcBr * config->wg_m_BcBr + == config->unroll_n_DBc * config->wg_n_DBc) + && ((config->wg_m_DBc * config->wg_n_DBc) + <= (config->wg_m_BcBr * config->wg_n_BcBr))), + "[CONFIG] The config BcBr work_group tile M(%d) axis must equal " + "DBc work_group tile N(%d) axis and number of total subgroups " + "should be less than BcBr subgroups (%d ?<= %d)", + config->unroll_m_BcBr * config->wg_m_BcBr, + config->unroll_n_DBc * config->wg_n_DBc, + config->wg_m_DBc * config->wg_n_DBc, + config->wg_m_BcBr * config->wg_n_BcBr); + + // D(Bc) >= head size + VCHECK_SDPA_COND(config->unroll_m_DBc * config->wg_m_DBc >= d->head_size(), + "The DBc matmul config work_group tile N(%d*%d=%d) axis must be " + "greater than or equal to head size(%ld)", + config->unroll_m_DBc, config->wg_m_DBc, + config->unroll_m_DBc * config->wg_m_DBc, + static_cast(d->head_size())); + + // (Bc)Br == (D)Br, ngroups <= BcBr ngroups + VCHECK_SDPA_COND(((config->unroll_n_BcBr * config->wg_n_BcBr + == config->unroll_n_DBr * config->wg_n_DBr) + && (config->wg_m_DBr * config->wg_n_DBr + <= config->wg_m_BcBr * config->wg_n_BcBr)), + "[CONFIG] The config BcBr work_group tile N(%d) axis must equal " + "DBr work_group tile N(%d) axis and number of total subgroups " + "should be less than BcBr subgroups (%d ?<= %d)", + config->unroll_n_BcBr * config->wg_n_BcBr, + config->unroll_n_DBr * config->wg_n_DBr, + config->wg_m_DBr * 
config->wg_n_DBr, + config->wg_m_BcBr * config->wg_n_BcBr); + + // D(Br) >= head size + VCHECK_SDPA_COND(config->unroll_m_DBr * config->wg_m_DBr >= d->head_size(), + "The DBr matmul config work_group tile M(%d*%d=%d) axis must be " + "greater than or equal to head size(%ld)", + config->unroll_m_DBr, config->wg_m_DBr, + config->unroll_m_DBr * config->wg_m_DBr, + static_cast(d->head_size())); + + // serializable minimal set of configuration params for ukernels + // will be used to generate shim ukernels in reusable kernel_ctx + micro_bwd_ukernel_params_t ukernel_params; + + ukernel_params.unroll_m_BcBr = config->unroll_m_BcBr; + ukernel_params.unroll_n_BcBr = config->unroll_n_BcBr; + + ukernel_params.unroll_m_DBc = config->unroll_m_DBc; + ukernel_params.unroll_n_DBc = config->unroll_n_DBc; + + ukernel_params.unroll_m_DBr = config->unroll_m_DBr; + ukernel_params.unroll_n_DBr = config->unroll_n_DBr; + + ukernel_params.wg_m_BcBr = config->wg_m_BcBr; + ukernel_params.wg_n_BcBr = config->wg_n_BcBr; + + ukernel_params.wg_m_DBc = config->wg_m_DBc; + ukernel_params.wg_n_DBc = config->wg_n_DBc; + + ukernel_params.wg_m_DBr = config->wg_m_DBr; + ukernel_params.wg_n_DBr = config->wg_n_DBr; + + /* Get device information */ + micro::HWInformation hw_info; + hw_info.euCount = dev_info->eu_count(); + hw_info.gmdid = dev_info->ip_version(); + hw_info.systolicAvailable = use_systolic_ukernel_; + + if (hw_info.gmdid == 0) return status::unimplemented; + + ukernel_params.hwinfo = {hw_info}; + + sg_size_ = dev_info->min_subgroup_size(); + + auto convert_dnnl_to_kernel_layout = [](const memory_desc_t *md) { + return (gemm_desc_t::get_trans(*md) == dnnl_trans) ? 
MatrixLayout::T + : MatrixLayout::N; + }; + auto transpose_layout = [](const gemmstone::MatrixLayout l) { + switch (l) { + case MatrixLayout::N: return MatrixLayout::T; + case MatrixLayout::T: return MatrixLayout::N; + case MatrixLayout::Pr: return MatrixLayout::Pc; + case MatrixLayout::Pc: return MatrixLayout::Pr; + default: return l; + } + }; + + /* Set up GEMMProblem structure for first GEMM: K^T * Q */ + GEMMProblem problem; + problem.Ta_ext = convert_dnnl_to_kernel_type(key_md()->data_type); + problem.Tb_ext = convert_dnnl_to_kernel_type(qry_md()->data_type); + if (qry_md()->data_type == data_type::f16) { + problem.Ta = problem.Tb = Type::f16; + } else if (qry_md()->data_type == data_type::bf16) { + problem.Ta = problem.Tb = Type::bf16; + } else if (qry_md()->data_type == data_type::f32) { + problem.Ta = problem.Tb = Type::f32; + } else { + VCHECK_SDPA_COND(utils::one_of(qry_md()->data_type, data_type::f16, + data_type::bf16, data_type::f32), + "Q tensor's data type must be bf16, f16, or f32"); + } + problem.Tc = problem.Tc_ext = Type::f32; + problem.Ts = problem.Tc; + + const int wg_tile_m_BcBr = config->wg_m_BcBr * config->unroll_m_BcBr; + const int wg_tile_n_BcBr = config->wg_n_BcBr * config->unroll_n_BcBr; + + auto problem_kq = problem; + + problem_kq.A.layout = MatrixLayout::Pc; + problem_kq.B.layout = MatrixLayout::N; + problem_kq.C.layout = MatrixLayout::N; + const memory_desc_wrapper key_mdw(key_md()); + const memory_desc_wrapper qry_mdw(qry_md()); + auto ldk = static_cast( + gemm_desc_t::get_ld(*key_md()) * key_mdw.data_type_size()); + auto ldq = static_cast( + gemm_desc_t::get_ld(*qry_md()) * qry_mdw.data_type_size()); + problem_kq.A.setAlignment(64); // Q is packed in VNNI format in SLM + if (use_systolic_ukernel()) { + problem_kq.A.crosspack = 2; + problem_kq.A.tileR = into(sg_size_); + problem_kq.A.tileC = into(d_max()); + } + problem_kq.B.setAlignment(micro::alignmentForLD(int(ldq))); + + ukernel_params.problem_kq = {problem_kq}; + + /* Set up 
microkernel options */ + micro::GEMMOptions opts_kq; + opts_kq.localA = true; + opts_kq.slmPtr = true; + opts_kq.scaleA = false; + opts_kq.offsetA = false; + + ukernel_params.opts_kq = {opts_kq}; + + /* Set up problem size information */ + SizeParams heuristic_sizes; + heuristic_sizes.m = wg_tile_m_BcBr; + heuristic_sizes.n = wg_tile_n_BcBr; + heuristic_sizes.k = d->head_size(); + heuristic_sizes.batch = 1; + + ukernel_params.sizes_kq = {heuristic_sizes}; + + /* Set up GEMMProblem structure for second GEMM: V * S */ + auto problem_vs = std::move(problem); + problem_vs.Tc = problem_vs.Ts + = (vs_acc_dt() == data_type::f16) ? Type::f16 : Type::f32; + + problem_vs.Ta_ext = convert_dnnl_to_kernel_type(val_md()->data_type); + problem_vs.A.layout = convert_dnnl_to_kernel_layout(diff_dst_md()); + problem_vs.B.layout = MatrixLayout::Pr; + problem_vs.C.layout = MatrixLayout::N; + const memory_desc_wrapper diff_dst_mdw(diff_dst_md()); + auto lda = static_cast(gemm_desc_t::get_ld(*diff_dst_md()) + * diff_dst_mdw.data_type_size()); + problem_vs.A.setAlignment(micro::alignmentForLD(int(lda))); + problem_vs.B.setAlignment(64); // S is packed in SLM + if (use_systolic_ukernel()) { problem_vs.B.crosspack = 16; } + + ukernel_params.problem_vs = {problem_vs}; + + // directly tied to config, will recompile w/head size and config updates + // no need for interval quantization + heuristic_sizes.m = d->head_size(); + heuristic_sizes.n = wg_tile_m_BcBr; + heuristic_sizes.k = wg_tile_n_BcBr; + + ukernel_params.sizes_vs = {heuristic_sizes}; + + /* Set up microkernel options */ + micro::GEMMOptions opts_vs; + opts_vs.localA = false; + opts_vs.localB = true; + opts_vs.slmPtr = true; + + ukernel_params.opts_vs = {opts_vs}; + + //////// Vt * dA + auto problem_vtdA = problem; + problem_vtdA.Ta_ext = convert_dnnl_to_kernel_type(val_md()->data_type); + + problem_vtdA.A.layout = transpose_layout( + convert_dnnl_to_kernel_layout(val_md())); //TODO hardcode? 
+ problem_vtdA.B.layout + = convert_dnnl_to_kernel_layout(diff_dst_md()); //TODO hardcode? + problem_vtdA.C.layout = MatrixLayout::N; + const memory_desc_wrapper val_mdw(val_md()); + auto ldv = gemm_desc_t::get_ld(*val_md()) * val_mdw.data_type_size(); + problem_vtdA.A.setAlignment(micro::alignmentForLD(int(ldv))); + problem_vtdA.B.setAlignment(micro::alignmentForLD(int(lda))); + + ukernel_params.problem_vtdA = {problem_vtdA}; + + heuristic_sizes.m = wg_tile_m_BcBr; + heuristic_sizes.n = wg_tile_n_BcBr; + heuristic_sizes.k = d->head_size(); + + ukernel_params.sizes_vtdA = {heuristic_sizes}; + + /* Set up microkernel options */ + micro::GEMMOptions opts_vtdA; + opts_vtdA.localA = false; + opts_vtdA.localB = false; + opts_vtdA.slmPtr = true; + ukernel_params.opts_vtdA = {opts_vtdA}; + + //////// Q * dS^t + auto problem_qdSt = problem; + problem_qdSt.Ta_ext = convert_dnnl_to_kernel_type(qry_md()->data_type); + problem_qdSt.A.layout = MatrixLayout::Pc; + problem_qdSt.B.layout + = transpose_layout(convert_dnnl_to_kernel_layout(qry_md())); + problem_qdSt.C.layout = MatrixLayout::N; + + problem_qdSt.A.setAlignment(64); + problem_qdSt.B.setAlignment(micro::alignmentForLD(int(ldq))); + if (use_systolic_ukernel()) { + problem_qdSt.A.crosspack = 2; + problem_qdSt.A.tileR = into( + sg_size_); // tile will be transposed (dS^t -> n x m) + problem_qdSt.A.tileC = into(wg_tile_n_BcBr); + } + + ukernel_params.problem_qdSt = {problem_qdSt}; + + heuristic_sizes.m = wg_tile_m_BcBr; + heuristic_sizes.n = d->values(); + heuristic_sizes.k = wg_tile_n_BcBr; + + ukernel_params.sizes_qdSt = {heuristic_sizes}; + + /* Set up microkernel options */ + micro::GEMMOptions opts_qdSt; + opts_qdSt.localA = true; + opts_qdSt.localB = false; + opts_qdSt.slmPtr = true; + ukernel_params.opts_qdSt = {opts_qdSt}; + + // dS * K + auto problem_ktq = problem; + problem_ktq.Ta_ext = convert_dnnl_to_kernel_type(key_md()->data_type); + + problem_ktq.A.layout + = 
transpose_layout(convert_dnnl_to_kernel_layout(key_md())); + problem_ktq.B.layout = MatrixLayout::Pr; + problem_ktq.C.layout = MatrixLayout::N; + + problem_ktq.A.setAlignment(micro::alignmentForLD(int(ldk))); + problem_ktq.B.setAlignment(64); // S is packed in SLM + if (use_systolic_ukernel()) { problem_ktq.B.crosspack = 16; } + + ukernel_params.problem_ktq = {problem_ktq}; + + heuristic_sizes.m = d->head_size(); + heuristic_sizes.n = wg_tile_n_BcBr; + heuristic_sizes.k = wg_tile_m_BcBr; + + ukernel_params.sizes_ktq = {heuristic_sizes}; + + /* Set up microkernel options */ + micro::GEMMOptions opts_ktq; + opts_ktq.localA = false; + opts_ktq.localB = true; + opts_ktq.slmPtr = true; + ukernel_params.opts_ktq = {opts_ktq}; + + conf.ukernel_config = ukernel_params; + + return status::success; +} + status_t micro_t::init(impl::engine_t *engine) { CHECK(create_kernel( engine, kernel_, pd()->conf.get_kernel_names()[0], pd()->conf)); + if (!kernel_) return status::runtime_error; return status::success; } -status_t micro_t::pd_t::init_conf(impl::engine_t *engine) { - using namespace micro; +status_t micro_bwd_t::init(impl::engine_t *engine) { + std::vector kernel_names = pd()->conf.get_kernel_names(); - auto *pd = this; + std::vector kernels; + CHECK(create_kernels(engine, kernels, kernel_names, pd()->conf)); + + preprocess_ = kernels[0]; + kernel_ = kernels[1]; + postprocess_ = kernels[2]; + zero_ = kernels[3]; + + if (!preprocess_) return status::runtime_error; + if (!kernel_) return status::runtime_error; + if (!postprocess_) return status::runtime_error; + if (!zero_) return status::runtime_error; + return status::success; +} + +template +static void init_conf_common(conf_t &conf, pd_type *pd) { + using pd_t = sdpa_pd_t; auto *d = pd->desc(); data_type_t data_t = pd->dst_md()->data_type; @@ -395,17 +792,9 @@ status_t micro_t::pd_t::init_conf(impl::engine_t *engine) { conf.val_data_t = val_mdw.data_type(); conf.dst_data_t = dst_mdw.data_type(); - 
conf.require_stateless_addressing = has_large_buffers(); - conf.msk_data_t = data_type::undef; if (pd->with_attn_mask()) { conf.msk_data_t = msk_mdw.data_type(); } - conf.key_scales_data_t = pd->key_scales_dt(); - conf.value_scales_data_t = pd->value_scales_dt(); - - conf.key_zp_data_t = pd->key_zp_dt(); - conf.value_zp_data_t = pd->value_zp_dt(); - auto Q_num_heads_dim = qry_mdw.dims()[1]; conf.kv_group_size = static_cast(Q_num_heads_dim / d->kv_head_number); @@ -414,13 +803,61 @@ status_t micro_t::pd_t::init_conf(impl::engine_t *engine) { auto ldv = gemm_desc_t::get_ld(*pd->val_md()) * val_mdw.data_type_size(); auto lda = gemm_desc_t::get_ld(*pd->dst_md()) * dst_mdw.data_type_size(); - conf.q_align = alignmentForLD(int(ldq)); - conf.k_align = alignmentForLD(int(ldk)); - conf.v_align = alignmentForLD(int(ldv)); - conf.a_align = alignmentForLD(int(lda)); + conf.q_align = micro::alignmentForLD(int(ldq)); + conf.k_align = micro::alignmentForLD(int(ldk)); + conf.v_align = micro::alignmentForLD(int(ldv)); + conf.a_align = micro::alignmentForLD(int(lda)); conf.transpose_k = gemm_desc_t::get_trans(*pd->key_md()) == dnnl_trans; + conf.scale_data_t = pd->scale_md()->data_type; + + conf.attn_mask_undef = attn_mask_type::undef; + conf.attn_mask_buffer = attn_mask_type::buffer; + conf.attn_mask_top_left = attn_mask_type::top_left; + conf.attn_mask_bottom_right = attn_mask_type::bottom_right; + + conf.invert_scale = d->invert_scale; + conf.with_attn_scale = pd->with_attn_scale(); + conf.with_host_scale = pd->with_host_scale(); + conf.with_attn_mask = (pd->with_attn_mask() && !pd->with_causal_mask()); + conf.broadcast_mask_q = (msk_mdw.dims()[pd_t::mask_q_index] == 1); + conf.with_causal_mask = pd->with_causal_mask(); + + conf.subgroup_size = pd->sg_size(); + conf.d_max = pd->d_max(); + + bool d_full = (d->head_size() == pd->d_max()); + conf.d_full = d_full; + conf.arch_gte_hpc = (pd->arch() >= compute::gpu_arch_t::xe_hpc); + + conf.use_systolic_ukernel = 
pd->use_systolic_ukernel(); +} + +status_t micro_t::pd_t::init_conf(impl::engine_t *engine) { + using namespace micro; + + auto *pd = this; + auto *d = pd->desc(); + + init_conf_common(conf, pd); + + conf.require_stateless_addressing = has_large_buffers(); + + const memory_desc_wrapper qry_mdw(pd->qry_md()); + const memory_desc_wrapper key_mdw(pd->key_md()); + const memory_desc_wrapper val_mdw(pd->val_md()); + const memory_desc_wrapper dst_mdw(pd->dst_md()); + + conf.key_scales_data_t = pd->key_scales_dt(); + conf.value_scales_data_t = pd->value_scales_dt(); + + conf.key_zp_data_t = pd->key_zp_dt(); + conf.value_zp_data_t = pd->value_zp_dt(); + + auto ldq = gemm_desc_t::get_ld(*pd->qry_md()) * qry_mdw.data_type_size(); + auto lda = gemm_desc_t::get_ld(*pd->dst_md()) * dst_mdw.data_type_size(); + int kq_scale_mask = (static_cast(pd->with_key_scales()) << 1) | static_cast(with_quantize_common(d->kq_scales)); conf.kq_scale_mask = kq_scale_mask; @@ -458,29 +895,12 @@ status_t micro_t::pd_t::init_conf(impl::engine_t *engine) { if (pd->with_value_scales() || pd->with_value_zp()) conf.val_group_size = pd->value_group_size(); - conf.scale_data_t = pd->scale_md()->data_type; - - conf.attn_mask_undef = attn_mask_type::undef; - conf.attn_mask_buffer = attn_mask_type::buffer; - conf.attn_mask_top_left = attn_mask_type::top_left; - conf.attn_mask_bottom_right = attn_mask_type::bottom_right; - - conf.invert_scale = d->invert_scale; - conf.with_attn_scale = pd->with_attn_scale(); - conf.with_host_scale = pd->with_host_scale(); - conf.with_attn_mask = (pd->with_attn_mask() && !pd->with_causal_mask()); - conf.broadcast_mask_q = (msk_mdw.dims()[pd_t::mask_q_index] == 1); - conf.with_causal_mask = pd->with_causal_mask(); - - conf.subgroup_size = pd->sg_size(); - conf.d_max = pd->d_max(); - - /* Set up microkernel strategy */ - const config_t config = {conf.ukernel_config.unroll_m_kq, - conf.ukernel_config.unroll_n_kq, conf.ukernel_config.unroll_m_vs, - 
conf.ukernel_config.unroll_n_vs, conf.ukernel_config.wg_m_kq, - conf.ukernel_config.wg_n_kq, conf.ukernel_config.wg_m_vs, - conf.ukernel_config.wg_n_vs}; + /* Set up microkernel strategy */ + const config_t config = {conf.ukernel_config.unroll_m_kq, + conf.ukernel_config.unroll_n_kq, conf.ukernel_config.unroll_m_vs, + conf.ukernel_config.unroll_n_vs, conf.ukernel_config.wg_m_kq, + conf.ukernel_config.wg_n_kq, conf.ukernel_config.wg_m_vs, + conf.ukernel_config.wg_n_vs}; const int kq_wg_tile_m = config.wg_m_kq * config.unroll_m_kq; const int kq_wg_tile_n = config.wg_n_kq * config.unroll_n_kq; @@ -488,7 +908,7 @@ status_t micro_t::pd_t::init_conf(impl::engine_t *engine) { int tile_k = kq_wg_tile_m; int tile_v = vs_wg_tile_m; - bool d_full = (d->head_size() == pd->d_max()); + bool d_full = conf.d_full; bool v_full = (d->head_size() == tile_v); auto Q = d->queries(); @@ -496,15 +916,12 @@ status_t micro_t::pd_t::init_conf(impl::engine_t *engine) { bool q_full = ((Q_per_kv_group % kq_wg_tile_n) != 0); conf.remainder_q = d_full && q_full; - conf.d_full = d_full; - conf.arch_gte_hpc = (pd->arch() >= compute::gpu_arch_t::xe_hpc); - conf.block_q = conf.block_a = conf.block_2d_a = false; if (d_full) { conf.block_q = (ldq % 4 == 0); - conf.block_a = (lda % 16 == 0 && v_full); + conf.block_a = (lda % 4 == 0 && v_full); } else if (pd->arch() >= compute::gpu_arch_t::xe_hpc - && (config.unroll_m_vs * dst_mdw.data_type_size()) <= 64) { + && config.unroll_m_vs < 64) { auto vbytes = d->values() * val_mdw.data_type_size(); if (lda % 16 == 0 && vbytes % 4 == 0) conf.block_2d_a = true; } @@ -514,29 +931,122 @@ status_t micro_t::pd_t::init_conf(impl::engine_t *engine) { conf.prefetch_k0 = true; conf.prefetch_k = true; conf.prefetch_v = true; + conf.prefetch_d_max = nstl::min(pd->d_max(), 64); bool no_rem = d_full && v_full && (d->keys() % tile_k == 0); conf.prefetch_remainder = !no_rem; - conf.prefetch_d_max = nstl::min(pd->d_max(), 64); } else { conf.prefetch_mask = conf.prefetch_k0 = 
conf.prefetch_k = conf.prefetch_v = conf.prefetch_remainder = false; conf.prefetch_d_max = 0; } - const bool arch_gte_xe2 = pd->arch() >= compute::gpu_arch_t::xe2; - conf.q_arrive_await_barrier = (Q > 1) && !arch_gte_xe2; + conf.q_arrive_await_barrier = (Q > 1); conf.softmax_inf_as_zero = (d->softmax_alg == alg_kind::softmax_accurate_inf_as_zero); - conf.use_systolic_ukernel = pd->use_systolic_ukernel(); conf.kq_f16_accumulate = (kq_acc_dt() == data_type::f16); conf.vs_f16_accumulate = (vs_acc_dt() == data_type::f16); + + bool is_training = desc()->prop_kind == prop_kind::forward_training; + conf.is_training = is_training; + if (is_training) { pd->init_default_ws(); } + + return status::success; +} + +status_t micro_bwd_t::pd_t::init_conf(impl::engine_t *engine) { + auto *pd = this; + auto *d = pd->desc(); + + init_conf_common(conf, pd); + + conf.require_stateless_addressing = has_large_buffers(); + conf.with_dS = pd->with_dS(); + + const memory_desc_wrapper key_mdw(pd->key_md()); + const memory_desc_wrapper val_mdw(pd->val_md()); + + auto ldk = gemm_desc_t::get_ld(*pd->key_md()) * key_mdw.data_type_size(); + auto ldv = gemm_desc_t::get_ld(*pd->val_md()) * val_mdw.data_type_size(); + + /* Set up microkernel strategy */ + const bwd_config_t config = {conf.ukernel_config.unroll_m_BcBr, + conf.ukernel_config.unroll_n_BcBr, conf.ukernel_config.unroll_m_DBc, + conf.ukernel_config.unroll_n_DBc, conf.ukernel_config.unroll_m_DBr, + conf.ukernel_config.unroll_n_DBr, conf.ukernel_config.wg_m_BcBr, + conf.ukernel_config.wg_n_BcBr, conf.ukernel_config.wg_m_DBc, + conf.ukernel_config.wg_n_DBc, conf.ukernel_config.wg_m_DBr, + conf.ukernel_config.wg_n_DBr}; + + const int kq_wg_tile_m = config.wg_m_BcBr * config.unroll_m_BcBr; + const int tile_k = kq_wg_tile_m; + + const int tile_dv = config.wg_n_DBc * config.unroll_n_DBc; + + bool d_full = conf.d_full; + bool dv_full = (d->head_size() == tile_dv); + + conf.block_k = conf.block_dK = conf.block_dV = false; + if (d_full) { + 
conf.block_dK = conf.block_k + = (ldk % 4 == 0) && (d->keys() % tile_k == 0); + conf.block_dV = (ldv % 4 == 0) && (dv_full); + } + + /* + * TODO: prefetching for bwd + * conf.prefetch_mask = conf.prefetch_k0 = conf.prefetch_k + * = conf.prefetch_v = conf.prefetch_remainder = false; + * conf.prefetch_d_max = 0; + */ + + return status::success; +} + +status_t micro_bwd_t::pd_t::init_scratchpad(impl::engine_t *engine) { + auto scratchpad = scratchpad_registry().registrar(); + auto gpu_align + = utils::downcast(engine)->get_buffer_alignment(); + memory_desc_wrapper dQ_wspace(diff_qry_md()); + size_t wspace_size = dQ_wspace.dims()[0] * dQ_wspace.dims()[1] + * dQ_wspace.dims()[2] * dQ_wspace.dims()[3]; + // f32 can directly atomic add to output + // others need intermediate scratchpad before conversion + if (conf.data_t != data_type::f32) { + scratchpad.book(memory_tracking::names::key_sdpa_dQ_reduction, + wspace_size, sizeof(float), gpu_align); + } + + // for GQA cases multiple Q heads atomic add into shared dK/dV + const bool needs_intermediate_dKV + = (conf.kv_group_size > 1 && conf.data_t != data_type::f32); + if (needs_intermediate_dKV) { + memory_desc_wrapper dK_wspace(diff_key_md()); + size_t dK_size = dK_wspace.dims()[0] * dK_wspace.dims()[1] + * dK_wspace.dims()[2] * dK_wspace.dims()[3]; + scratchpad.book(memory_tracking::names::key_sdpa_dK_reduction, dK_size, + sizeof(float), gpu_align); + + memory_desc_wrapper dV_wspace(diff_val_md()); + size_t dV_size = dV_wspace.dims()[0] * dV_wspace.dims()[1] + * dV_wspace.dims()[2] * dV_wspace.dims()[3]; + scratchpad.book(memory_tracking::names::key_sdpa_dV_reduction, dV_size, + sizeof(float), gpu_align); + } + + // space for D_i preprocess result + dim_t batch = qry_md()->dims[0]; + dim_t num_q_heads = qry_md()->dims[1]; + dim_t Q = desc()->queries(); + size_t Di_size = batch * num_q_heads * Q; + scratchpad.book(memory_tracking::names::key_sdpa_Di, Di_size, sizeof(float), + gpu_align); + return status::success; } status_t 
micro_params_t::get_kernel_ctx( compute::kernel_ctx_t &kernel_ctx) const { using namespace micro; - kernel_ctx.require_stateless_addressing(require_stateless_addressing); kernel_ctx.define_int("NDIMS", ndims); kernel_ctx.set_data_type(data_t); @@ -545,7 +1055,6 @@ status_t micro_params_t::get_kernel_ctx( def_data_type(kernel_ctx, qry_data_t, "QRY"); def_data_type(kernel_ctx, val_data_t, "VAL"); def_data_type(kernel_ctx, dst_data_t, "DST"); - def_data_type(kernel_ctx, scale_data_t, "SCALE", !with_host_scale); if (with_attn_mask) { def_data_type(kernel_ctx, msk_data_t, "MSK"); } @@ -577,6 +1086,7 @@ status_t micro_params_t::get_kernel_ctx( kernel_ctx.define_int("KEY_GROUP_SIZE", key_group_size); kernel_ctx.define_int("VAL_GROUP_SIZE", val_group_size); + def_data_type(kernel_ctx, scale_data_t, "SCALE", !with_host_scale); kernel_ctx.define_int("INVERT_SCALE", invert_scale); kernel_ctx.define_int("WITH_ATTN_SCALE", with_attn_scale); kernel_ctx.define_int("WITH_HOST_SCALE", with_host_scale); @@ -609,6 +1119,7 @@ status_t micro_params_t::get_kernel_ctx( kernel_ctx.define_int("USE_SYSTOLIC_UKERNEL", use_systolic_ukernel); kernel_ctx.define_int("KQ_F16_ACC", kq_f16_accumulate); kernel_ctx.define_int("VS_F16_ACC", vs_f16_accumulate); + kernel_ctx.define_int("IS_TRAINING", is_training); micro::HWInformation hw_info; gemmstone::GEMMProblem problem_kq, problem_vs; @@ -639,6 +1150,7 @@ status_t micro_params_t::get_kernel_ctx( reqs_vs.push_back(StrategyRequirement::WGM == config.wg_m_vs); reqs_vs.push_back(StrategyRequirement::WGN == config.wg_n_vs); + /* Ask microkernel provider for microkernel */ auto kq_strat_override = [&](gemmstone::GEMMStrategy &strat) { std::string newStrat; newStrat = gpu_utils::dev_getenv("SDPA_KQ_USTRATEGY", newStrat); @@ -657,7 +1169,6 @@ status_t micro_params_t::get_kernel_ctx( adjustStrategy(hw, problem_kq, strat); } }; - /* Ask microkernel provider for microkernel */ try { gemm_kq = micro::selectGEMM(opts_kq, hw_info, sizes_kq, problem_kq, reqs_kq, 
kq_strat_override); @@ -704,6 +1215,7 @@ status_t micro_params_t::get_kernel_ctx( "gemm_vs microkernel generation failure with message: %s", ex.what()); } + VDEBUGINFO(4, primitive, sdpa, "kq_gemm: %s, vs_gemm: %s,", problem_kq.toString().c_str(), problem_vs.toString().c_str()); @@ -728,6 +1240,221 @@ status_t micro_params_t::get_kernel_ctx( return status::success; } +status_t micro_bwd_params_t::get_kernel_ctx( + compute::kernel_ctx_t &kernel_ctx) const { + using namespace micro; + + kernel_ctx.define_int("NDIMS", ndims); + kernel_ctx.set_data_type(data_t); + + def_data_type(kernel_ctx, key_data_t, "KEY"); + def_data_type(kernel_ctx, qry_data_t, "QRY"); + def_data_type(kernel_ctx, val_data_t, "VAL"); + def_data_type(kernel_ctx, dst_data_t, "DST"); + + if (with_attn_mask) { def_data_type(kernel_ctx, msk_data_t, "MSK"); } + + kernel_ctx.define_int("KV_GROUP_SIZE", kv_group_size); + + kernel_ctx.define_int("Q_ALIGN", q_align); + kernel_ctx.define_int("K_ALIGN", k_align); + kernel_ctx.define_int("V_ALIGN", v_align); + kernel_ctx.define_int("A_ALIGN", a_align); + + kernel_ctx.define_int("TRANSPOSE_K", transpose_k); + + def_data_type(kernel_ctx, scale_data_t, "SCALE", !with_host_scale); + kernel_ctx.define_int("INVERT_SCALE", invert_scale); + kernel_ctx.define_int("WITH_ATTN_SCALE", with_attn_scale); + kernel_ctx.define_int("WITH_HOST_SCALE", with_host_scale); + kernel_ctx.define_int("ATTN_MASK_UNDEF", attn_mask_undef); + kernel_ctx.define_int("ATTN_MASK_BUFFER", attn_mask_buffer); + kernel_ctx.define_int("ATTN_MASK_TOP_LEFT", attn_mask_top_left); + kernel_ctx.define_int("ATTN_MASK_BOTTOM_RIGHT", attn_mask_bottom_right); + + kernel_ctx.define_int("WITH_ATTN_MASK", with_attn_mask); + kernel_ctx.define_int("BROADCAST_MASK_Q", broadcast_mask_q); + kernel_ctx.define_int("WITH_CAUSAL_MASK", with_causal_mask); + kernel_ctx.define_int("WITH_DS", with_dS); + + kernel_ctx.define_int("SUBGROUP_SIZE", subgroup_size); + kernel_ctx.define_int("D_MAX", d_max); + + 
kernel_ctx.define_int("BLOCK_K", block_k); + kernel_ctx.define_int("BLOCK_DK", block_dK); + kernel_ctx.define_int("BLOCK_DV", block_dV); + + //TODO: remove or add prefetching to BWD + kernel_ctx.define_int("PREFETCH_MASK", prefetch_mask); + kernel_ctx.define_int("PREFETCH_K0", prefetch_k0); + kernel_ctx.define_int("PREFETCH_K", prefetch_k); + kernel_ctx.define_int("PREFETCH_V", prefetch_v); + kernel_ctx.define_int("PREFETCH_REMAINDER", prefetch_remainder); + kernel_ctx.define_int("PREFETCH_D_MAX", prefetch_d_max); + + kernel_ctx.define_int("USE_SYSTOLIC_UKERNEL", use_systolic_ukernel); + + micro::HWInformation hw_info; + gemmstone::GEMMProblem problem_kq, problem_vs; + micro::GEMMOptions opts_kq, opts_vs; + gemmstone::SizeParams sizes_kq, sizes_vs; + + gemmstone::GEMMProblem problem_vtdA, problem_ktq, problem_qdSt; + micro::GEMMOptions opts_vtdA, opts_ktq, opts_qdSt; + gemmstone::SizeParams sizes_vtdA, sizes_ktq, sizes_qdSt; + + deserialize_config_to_gemmstone(hw_info, problem_kq, problem_vs, + problem_vtdA, problem_ktq, problem_qdSt, opts_kq, opts_vs, + opts_vtdA, opts_ktq, opts_qdSt, sizes_kq, sizes_vs, sizes_vtdA, + sizes_ktq, sizes_qdSt, ukernel_config); + + micro::Package gemm_kq, gemm_vs, gemm_vtdA, gemm_ktq, gemm_qdSt; + + /* Set up microkernel strategy */ + const bwd_config_t config + = {ukernel_config.unroll_m_BcBr, ukernel_config.unroll_n_BcBr, + ukernel_config.unroll_m_DBc, ukernel_config.unroll_n_DBc, + ukernel_config.unroll_m_DBr, ukernel_config.unroll_n_DBr, + ukernel_config.wg_m_BcBr, ukernel_config.wg_n_BcBr, + ukernel_config.wg_m_DBc, ukernel_config.wg_n_DBc, + ukernel_config.wg_m_DBr, ukernel_config.wg_n_DBr}; + + std::vector reqs_kq; + reqs_kq.push_back(StrategyRequirement::UnrollM == config.unroll_m_BcBr); + reqs_kq.push_back(StrategyRequirement::UnrollN == config.unroll_n_BcBr); + reqs_kq.push_back(StrategyRequirement::WGM == config.wg_m_BcBr); + reqs_kq.push_back(StrategyRequirement::WGN == config.wg_n_BcBr); + + std::vector reqs_vs; + 
reqs_vs.push_back(StrategyRequirement::UnrollM == config.unroll_m_DBc); + reqs_vs.push_back(StrategyRequirement::UnrollN == config.unroll_n_DBc); + reqs_vs.push_back(StrategyRequirement::WGM == config.wg_m_DBc); + reqs_vs.push_back(StrategyRequirement::WGN == config.wg_n_DBc); + + std::vector reqs_vtdA; + reqs_vtdA.push_back(StrategyRequirement::UnrollM == config.unroll_m_BcBr); + reqs_vtdA.push_back(StrategyRequirement::UnrollN == config.unroll_n_BcBr); + reqs_vtdA.push_back(StrategyRequirement::WGM == config.wg_m_BcBr); + reqs_vtdA.push_back(StrategyRequirement::WGN == config.wg_n_BcBr); + + std::vector reqs_ktq; + reqs_ktq.push_back(StrategyRequirement::UnrollM == config.unroll_m_DBr); + reqs_ktq.push_back(StrategyRequirement::UnrollN == config.unroll_n_DBr); + reqs_ktq.push_back(StrategyRequirement::WGM == config.wg_m_DBr); + reqs_ktq.push_back(StrategyRequirement::WGN == config.wg_n_DBr); + + std::vector reqs_qdSt; + reqs_qdSt.push_back(StrategyRequirement::UnrollM == config.unroll_n_DBc); + reqs_qdSt.push_back(StrategyRequirement::UnrollN == config.unroll_m_DBc); + reqs_qdSt.push_back(StrategyRequirement::WGM == config.wg_n_DBc); + reqs_qdSt.push_back(StrategyRequirement::WGN == config.wg_m_DBc); + + /* Ask microkernel provider for microkernel */ + try { + gemm_kq = micro::selectGEMM( + opts_kq, hw_info, sizes_kq, problem_kq, reqs_kq); + } catch (const std::runtime_error &ex) { + VCHECK_SDPA_COND(false, + "gemm_kq microkernel generation failure with message: %s", + ex.what()); + } + + try { + if (use_systolic_ukernel) { + auto adjust_vs = [](GEMMStrategy &strategy) { + /* Enable dpasw */ + strategy.dpasw |= strategy.fused; + }; + gemm_vs = micro::selectGEMM( + opts_vs, hw_info, sizes_vs, problem_vs, reqs_vs, adjust_vs); + } else { + gemm_vs = micro::selectGEMM( + opts_vs, hw_info, sizes_vs, problem_vs, reqs_vs); + } + } catch (const std::runtime_error &ex) { + VCHECK_SDPA_COND(false, + "gemm_vs microkernel generation failure with message: %s", + ex.what()); + 
} + + VDEBUGINFO(4, primitive, sdpa, + "kq_gemm: %s, vs_gemm: %s, vtdA_gemm: %s, ktq_gemm: %s, qdSt: %s\n", + problem_kq.toString().c_str(), problem_vs.toString().c_str(), + problem_vtdA.toString().c_str(), problem_ktq.toString().c_str(), + problem_qdSt.toString().c_str()); + + /* Generate microkernel shims */ + micro::ShimOptions shimOptions; + shimOptions.subgroupSize = subgroup_size; + shimOptions.useTileOps = true; + shimOptions.decorator = "kq"; + + std::string gemm_kq_header + = micro::generateShim(gemm_kq, HostLanguage::OpenCL_C, shimOptions); + kernel_ctx.add_custom_header("gemm_kq.h", std::move(gemm_kq_header)); + + shimOptions.microkernelID++; + shimOptions.decorator = "vs"; + + std::string gemm_vs_header + = micro::generateShim(gemm_vs, HostLanguage::OpenCL_C, shimOptions); + kernel_ctx.add_custom_header("gemm_vs.h", std::move(gemm_vs_header)); + + try { + gemm_vtdA = micro::selectGEMM( + opts_vtdA, hw_info, sizes_vtdA, problem_vtdA, reqs_vtdA); + } catch (const std::runtime_error &ex) { + VCHECK_SDPA_COND(false, + "gemm_vtdA microkernel generation failure with message: %s", + ex.what()); + } + + shimOptions.microkernelID++; + shimOptions.decorator = "vtdA"; + + std::string gemm_vtdA_header = micro::generateShim( + gemm_vtdA, HostLanguage::OpenCL_C, shimOptions); + kernel_ctx.add_custom_header("gemm_vtdA.h", std::move(gemm_vtdA_header)); + + try { + gemm_ktq = micro::selectGEMM( + opts_ktq, hw_info, sizes_ktq, problem_ktq, reqs_ktq); + } catch (const std::runtime_error &ex) { + VCHECK_SDPA_COND(false, + "gemm_ktq microkernel generation failure with message: %s", + ex.what()); + } + + shimOptions.microkernelID++; + shimOptions.decorator = "ktq"; + + std::string gemm_ktq_header = micro::generateShim( + gemm_ktq, HostLanguage::OpenCL_C, shimOptions); + kernel_ctx.add_custom_header("gemm_ktq.h", std::move(gemm_ktq_header)); + + try { + gemm_qdSt = micro::selectGEMM( + opts_qdSt, hw_info, sizes_qdSt, problem_qdSt, reqs_qdSt); + } catch (const 
std::runtime_error &ex) { + VCHECK_SDPA_COND(false, + "gemm_qdSt microkernel generation failure with message: %s", + ex.what()); + } + + shimOptions.microkernelID++; + shimOptions.decorator = "qdSt"; + + std::string gemm_qdSt_header = micro::generateShim( + gemm_qdSt, HostLanguage::OpenCL_C, shimOptions); + kernel_ctx.add_custom_header("gemm_qdSt.h", std::move(gemm_qdSt_header)); + + if (gemm_kq.grfMin > 128 || gemm_vs.grfMin > 128 || gemm_vtdA.grfMin > 128 + || gemm_ktq.grfMin > 128 || gemm_qdSt.grfMin > 128) + kernel_ctx.add_option("-cl-intel-256-GRF-per-thread"); + + return status::success; +} + status_t micro_t::execute_forward(const exec_ctx_t &ctx) const { const auto &conf = pd()->conf; @@ -735,6 +1462,7 @@ status_t micro_t::execute_forward(const exec_ctx_t &ctx) const { const auto &key = CTX_IN_STORAGE(DNNL_ARG_KEYS); const auto &val = CTX_IN_STORAGE(DNNL_ARG_VALUES); auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST); + auto &ws = CTX_OUT_STORAGE(DNNL_ARG_WORKSPACE); const auto &scale = CTX_IN_STORAGE(DNNL_ARG_SCALE); const auto &attn_mask = CTX_IN_STORAGE(DNNL_ARG_ATTN_MASK); @@ -791,9 +1519,6 @@ status_t micro_t::execute_forward(const exec_ctx_t &ctx) const { arg_list.append(strides4); }; - int mask_type = static_cast(pd()->desc()->mask_type); - compute::kernel_arg_list_t arg_list; - const memory_desc_wrapper scale_mdw(pd()->scale_md()); float scalar_scale = 1.f; float inv_scalar_scale = 1.f; @@ -809,9 +1534,12 @@ status_t micro_t::execute_forward(const exec_ctx_t &ctx) const { inv_scalar_scale = 1. 
/ scalar_scale; } + int mask_type = static_cast(pd()->desc()->mask_type); + compute::kernel_arg_list_t arg_list; arg_list.append(key); arg_list.append(qry); arg_list.append(val); + arg_list.append(ws); arg_list.append(dst); if (pd()->with_host_scale()) { arg_list.append(scalar_scale); @@ -858,6 +1586,286 @@ status_t micro_t::execute_forward(const exec_ctx_t &ctx) const { return parallel_for(ctx, nd_range, kernel_, arg_list); } +status_t micro_bwd_t::execute_backward(const exec_ctx_t &ctx) const { + const auto &qry = CTX_IN_STORAGE(DNNL_ARG_QUERIES); + const auto &key = CTX_IN_STORAGE(DNNL_ARG_KEYS); + const auto &val = CTX_IN_STORAGE(DNNL_ARG_VALUES); + auto &ws = CTX_IN_STORAGE(DNNL_ARG_WORKSPACE); + const auto &dst = CTX_IN_STORAGE(DNNL_ARG_DST); + const auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST); + auto &diff_q = CTX_OUT_STORAGE(DNNL_ARG_DIFF_QUERIES); + auto &diff_k = CTX_OUT_STORAGE(DNNL_ARG_DIFF_KEYS); + auto &diff_v = CTX_OUT_STORAGE(DNNL_ARG_DIFF_VALUES); + const auto &scale = CTX_IN_STORAGE(DNNL_ARG_SCALE); + const auto &attn_mask = CTX_IN_STORAGE(DNNL_ARG_ATTN_MASK); + auto Di_scratch = ctx.get_scratchpad_grantor().get_memory_storage( + memory_tracking::names::key_sdpa_Di); + auto diff_q_scratch = ctx.get_scratchpad_grantor().get_memory_storage( + memory_tracking::names::key_sdpa_dQ_reduction); + auto diff_k_scratch = ctx.get_scratchpad_grantor().get_memory_storage( + memory_tracking::names::key_sdpa_dK_reduction); + auto diff_v_scratch = ctx.get_scratchpad_grantor().get_memory_storage( + memory_tracking::names::key_sdpa_dV_reduction); + + const bool with_dS = pd()->with_dS(); + + const int kv_group_size = pd()->conf.kv_group_size; + const dim_t Q = pd()->desc()->queries(); + const dim_t K = pd()->desc()->keys(); + const dim_t D = pd()->desc()->head_size(); + + const data_type_t data_t = pd()->dst_md()->data_type; + const bool needs_intermediate_dQ = (data_t != data_type::f32); + const bool needs_intermediate_dKV + = (kv_group_size > 1 && data_t 
!= data_type::f32); + const bool needs_zero_dKV = (kv_group_size > 1); + + const auto &conf = pd()->conf; + + const bwd_config_t config = {conf.ukernel_config.unroll_m_BcBr, + conf.ukernel_config.unroll_n_BcBr, conf.ukernel_config.unroll_m_DBc, + conf.ukernel_config.unroll_n_DBc, conf.ukernel_config.unroll_m_DBr, + conf.ukernel_config.unroll_n_DBr, conf.ukernel_config.wg_m_BcBr, + conf.ukernel_config.wg_n_BcBr, conf.ukernel_config.wg_m_DBc, + conf.ukernel_config.wg_n_DBc, conf.ukernel_config.wg_m_DBr, + conf.ukernel_config.wg_n_DBr}; + + auto wg_tile_k = config.unroll_m_BcBr * config.wg_m_BcBr; + auto wg_tile_q = config.unroll_n_BcBr * config.wg_n_BcBr; + auto sg_per_wg = config.wg_m_BcBr * config.wg_n_BcBr; + + auto sg_per_wg_BcBr = config.wg_m_BcBr * config.wg_n_BcBr; + auto sg_per_wg_DBc = config.wg_m_DBc * config.wg_n_DBc; + auto sg_per_wg_DBr = config.wg_m_DBr * config.wg_n_DBr; + + using std::max; + sg_per_wg = max(max(sg_per_wg_BcBr, sg_per_wg_DBc), sg_per_wg_DBr); + + const memory_desc_wrapper qry_mdw(pd()->qry_md()); + const memory_desc_wrapper key_mdw(pd()->key_md()); + const memory_desc_wrapper val_mdw(pd()->val_md()); + const memory_desc_wrapper dst_mdw(pd()->dst_md()); + const memory_desc_wrapper msk_mdw(pd()->attn_mask_md()); + const memory_desc_wrapper diff_dst_mdw(pd()->diff_dst_md()); + const memory_desc_wrapper diff_qry_mdw(pd()->diff_qry_md()); + const memory_desc_wrapper diff_key_mdw(pd()->diff_key_md()); + const memory_desc_wrapper diff_val_mdw(pd()->diff_val_md()); + using offset_t = decltype(offsets_t().src_off); + + offset_t qry_off, key_off, val_off, dst_off, msk_off; + + set_offsets(qry_mdw, qry_off); + set_offsets(key_mdw, key_off); + set_offsets(val_mdw, val_off); + set_offsets(dst_mdw, dst_off); + set_offsets(msk_mdw, msk_off); + + auto append_offs + = [](compute::kernel_arg_list_t &arg_list, const offset_t &offs) { + compute::int64x4_t dims4 + = {offs[3][0], offs[3][1], offs[3][2], offs[3][3]}; + compute::int64x4_t strides4 + = 
{offs[1][0], offs[1][1], offs[1][2], offs[1][3]}; + arg_list.append(dims4); + arg_list.append(strides4); + }; + + int mask_type = static_cast(pd()->desc()->mask_type); + + const memory_desc_wrapper scale_mdw(pd()->scale_md()); + float scalar_scale = 1.f; + float inv_scalar_scale = 1.f; + if (pd()->with_host_scale()) { + auto scalar_storage = utils::downcast< + const dnnl::impl::host_scalar_memory_storage_t *>(&scale); + auto status = scalar_storage->get_scalar_value( + &scalar_scale, scale_mdw.data_type_size()); + assert(status == status::success); + if (status != status::success) return status; + scalar_scale = dnnl::impl::cpu::io::load_float_value( + pd()->scale_md()->data_type, &scalar_scale, 0); + inv_scalar_scale = 1. / scalar_scale; + } + + /// preprocess kernel + // will zero dQ, calculate Di + compute::range_t lws = {(size_t)pd()->sg_size(), (size_t)sg_per_wg, 1}; + compute::range_t gws_preprocess = lws; + + gws_preprocess[0] *= utils::div_up(Q, wg_tile_q); + gws_preprocess[1] *= pd()->dst_md()->dims[1]; + gws_preprocess[2] *= pd()->dst_md()->dims[0]; + + auto nd_range_preprocess = compute::nd_range_t(gws_preprocess, lws); + + compute::kernel_arg_list_t preprocess_arg_list; + preprocess_arg_list.append(*Di_scratch); + preprocess_arg_list.append(dst); + preprocess_arg_list.append(diff_dst); + preprocess_arg_list.append((int)D); + preprocess_arg_list.append((int)K); + preprocess_arg_list.append((int)Q); + + append_offs(preprocess_arg_list, qry_off); + append_offs(preprocess_arg_list, dst_off); + + status_t s = parallel_for( + ctx, nd_range_preprocess, preprocess_, preprocess_arg_list); + if (s != status::success) return s; + + // zero f32 intermediates before atomic adds in the main kernel + // dQ always needs atomics, dK/dV only for GQA cases + { + const dim_t num_kv_heads = pd()->dst_md()->dims[1] / kv_group_size; + const dim_t num_q_heads = pd()->dst_md()->dims[1]; + const int lws_zero = 256; + + auto dispatch_zero + = [&](const memory_storage_t &buf, 
dim_t count, + const offset_t &offs, dim_t num_heads) -> status_t { + compute::kernel_arg_list_t args; + args.append(buf); + args.append((int)count); + append_offs(args, offs); + + compute::range_t lws_z = {(size_t)lws_zero, 1, 1}; + compute::range_t gws_z = lws_z; + gws_z[0] *= utils::div_up(count, lws_zero); + gws_z[1] *= num_heads; + gws_z[2] *= pd()->dst_md()->dims[0]; + return parallel_for( + ctx, compute::nd_range_t(gws_z, lws_z), zero_, args); + }; + + // always zero dQ + auto &dQ_buf = needs_intermediate_dQ ? *diff_q_scratch : diff_q; + s = dispatch_zero(dQ_buf, Q * D, qry_off, num_q_heads); + if (s != status::success) return s; + + // zero dK/dV for GQA cases + if (needs_zero_dKV) { + auto &dK_buf = needs_intermediate_dKV ? *diff_k_scratch : diff_k; + auto &dV_buf = needs_intermediate_dKV ? *diff_v_scratch : diff_v; + + s = dispatch_zero(dK_buf, K * D, key_off, num_kv_heads); + if (s != status::success) return s; + s = dispatch_zero(dV_buf, K * D, val_off, num_kv_heads); + if (s != status::success) return s; + } + } + + /// backwards pass kernel, calculates dK, dV, dQ(float) + compute::kernel_arg_list_t arg_list; + arg_list.append(key); + arg_list.append(qry); + arg_list.append(val); + arg_list.append(ws); + arg_list.append(*Di_scratch); + arg_list.append(dst); + arg_list.append(diff_dst); + if (with_dS) arg_list.append(CTX_OUT_STORAGE(DNNL_ARG_DS)); + arg_list.append(needs_intermediate_dKV ? *diff_k_scratch : diff_k); + arg_list.append(needs_intermediate_dQ ? *diff_q_scratch : diff_q); + arg_list.append(needs_intermediate_dKV ? 
*diff_v_scratch : diff_v); + if (pd()->with_host_scale()) { + arg_list.append(scalar_scale); + arg_list.append(inv_scalar_scale); + } else { + arg_list.append(scale); + } + arg_list.append((int)D); + arg_list.append((int)K); + arg_list.append((int)Q); + arg_list.append(mask_type); + if (pd()->with_attn_mask()) arg_list.append(attn_mask); + + append_offs(arg_list, key_off); + append_offs(arg_list, qry_off); + append_offs(arg_list, val_off); + append_offs(arg_list, dst_off); + + if (pd()->with_attn_mask()) { append_offs(arg_list, msk_off); } + const int remainder_k = (K % wg_tile_k) != 0; + + auto *d = pd()->desc(); + const bool d_full = (d->head_size() == pd()->d_max()); + const int remainder_q = d_full && ((Q % wg_tile_q) != 0); + + arg_list.append(remainder_k); + arg_list.append(remainder_q); + + compute::range_t gws = lws; + + gws[0] *= utils::div_up(K, wg_tile_k); + gws[1] *= pd()->dst_md()->dims[1]; + gws[2] *= pd()->dst_md()->dims[0]; + auto nd_range = compute::nd_range_t(gws, lws); + + s = parallel_for(ctx, nd_range, kernel_, arg_list); + if (s != status::success) return s; + + /// postprocessing kernels + // will cast dQ/dK/dV to lower precision outputs if needed + if (needs_intermediate_dQ) { + const int lws_pp = 256; + compute::range_t lws_p = {(size_t)lws_pp, 1, 1}; + compute::range_t gws_p = lws_p; + gws_p[0] *= utils::div_up(Q * D, lws_pp); + gws_p[1] *= pd()->dst_md()->dims[1]; // Q heads + gws_p[2] *= pd()->dst_md()->dims[0]; + + compute::kernel_arg_list_t pp; + pp.append(diff_q); + pp.append(*diff_q_scratch); + pp.append((int)(Q * D)); + append_offs(pp, qry_off); + s = parallel_for( + ctx, compute::nd_range_t(gws_p, lws_p), postprocess_, pp); + if (s != status::success) return s; + } + + if (needs_intermediate_dKV) { + const dim_t num_kv_heads = pd()->dst_md()->dims[1] / kv_group_size; + const int lws_pp = 256; + compute::range_t lws_p = {(size_t)lws_pp, 1, 1}; + + // dK + { + compute::range_t gws_p = lws_p; + gws_p[0] *= utils::div_up(K * D, 
lws_pp); + gws_p[1] *= num_kv_heads; // KV heads + gws_p[2] *= pd()->dst_md()->dims[0]; + + compute::kernel_arg_list_t pp; + pp.append(diff_k); + pp.append(*diff_k_scratch); + pp.append((int)(K * D)); + append_offs(pp, key_off); + s = parallel_for( + ctx, compute::nd_range_t(gws_p, lws_p), postprocess_, pp); + if (s != status::success) return s; + } + // dV + { + compute::range_t gws_p = lws_p; + gws_p[0] *= utils::div_up(K * D, lws_pp); + gws_p[1] *= num_kv_heads; + gws_p[2] *= pd()->dst_md()->dims[0]; + + compute::kernel_arg_list_t pp; + pp.append(diff_v); + pp.append(*diff_v_scratch); + pp.append((int)(K * D)); + append_offs(pp, val_off); + s = parallel_for( + ctx, compute::nd_range_t(gws_p, lws_p), postprocess_, pp); + if (s != status::success) return s; + } + } + + return status::success; +} + } // namespace sdpa } // namespace intel } // namespace gpu diff --git a/src/gpu/intel/sdpa/micro.hpp b/src/gpu/intel/sdpa/micro.hpp index da0be1370b2..3c5936a6609 100644 --- a/src/gpu/intel/sdpa/micro.hpp +++ b/src/gpu/intel/sdpa/micro.hpp @@ -144,7 +144,8 @@ struct micro_bwd_params_t : trivially_serializable_t { bool remainder_q; bool use_systolic_ukernel; bool with_dS; - uint8_t padding2[3] = {0}; + bool require_stateless_addressing; + uint8_t padding2[2] = {0}; int prefetch_d_max; uint8_t padding3[4] = {0}; From 05eee81699ce0409f2a6542e463caac21d68bf10 Mon Sep 17 00:00:00 2001 From: syurkevi Date: Tue, 24 Feb 2026 23:37:48 -0800 Subject: [PATCH 09/23] xe: sdpa: calculates forward logusmexp to ws --- src/gpu/intel/sdpa/micro.cl | 64 +++++++++++++++++++++---------------- 1 file changed, 36 insertions(+), 28 deletions(-) diff --git a/src/gpu/intel/sdpa/micro.cl b/src/gpu/intel/sdpa/micro.cl index d53312cabed..23aafc16327 100644 --- a/src/gpu/intel/sdpa/micro.cl +++ b/src/gpu/intel/sdpa/micro.cl @@ -370,7 +370,7 @@ inline void tile_store_t_slm_src1(q_tile_type *Q_tile, local QRY_DATA_T *Q_slm, __attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) kernel void 
micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, - const global VAL_DATA_T *V, global DST_DATA_T *A, + const global VAL_DATA_T *V, global float *ws, global DST_DATA_T *A, #if WITH_HOST_SCALE float scalar_scale, float inv_scalar_scale, #else @@ -541,7 +541,6 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, #endif #endif #endif - scale *= 1.442695f; // log2(e) } #if PREFETCH_K0 @@ -684,6 +683,7 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, ldkq #endif ); + #if KQ_F16_ACC s_tile_type_float S_tile; tile_copy_reblock(S_tile_f16, &S_tile); @@ -778,11 +778,10 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, /* cache */ LSC_LDCC_L1C_L3C); #endif #endif -#ifndef ALT_MAX + /* Read back WG-wide maxima */ intel_work_group_barrier_wait(CLK_LOCAL_MEM_FENCE); tile_load_full(&S_max_tile, S_max_slm, ugemm_kq_wg_tile_n, sg_j0_kq, 0); -#endif #if SOFTMAX_INF_AS_ZERO #define set_zeros(v) vselect(-FLT_MAX, v, visfinite(v)) @@ -792,39 +791,23 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, tile_vbroadcast_sub(&S_tile, S_max_tile); /* Scale + exponentiate */ -#define scaled_exp(x) native_vexp2(x *scale) +#define scaled_exp(x) native_vexp2(x *scale * 1.442695f) tile_elementwise(S_tile, scaled_exp); - -#ifdef ALT_MAX - /* Read back WG-wide maxima and adjust S to match */ - intel_work_group_barrier_wait(CLK_LOCAL_MEM_FENCE); - s_sum_tile_type S_max_tile1; - tile_copy(S_max_tile, S_max_tile1); - tile_load_full(&S_max_tile, S_max_slm, ugemm_kq_wg_tile_n, sg_j0_kq, 0); - -#define binary_exp_neg(x, y) native_vexp2(scale *((x) - (y))) - tile_binary(S_max_tile1, S_max_tile, binary_exp_neg); - tile_vbroadcast_mul(&S_tile, S_max_tile1); -#endif +#undef scaled_exp /* Accumulate sums. S tile is transposed for easy summation. 
*/ s_sum_tile_type S_sum_tile1; tile_fill(S_sum_tile1, 0.0f); tile_vreduce_add(S_tile, &S_sum_tile1); -#if USE_SYSTOLIC_UKERNEL - /* Convert to half or bf16, VNNI format */ - s_tile_type_packed S_tile_packed; - tile_copy_to_vec2(S_tile, S_tile_packed, VEC_TYPE2); - - /* Store to SLM, in packed format */ - tile_store_t_sys_src2(S_tile_packed, (local uint *)S_slm, - ugemm_vs_sg_tile_n, ugemm_kq_wg_tile_m / 2, sg_i0_kq / 2, - sg_j0_kq); -#else /* Reblock and store to SLM */ s_tile_type_reblock S_tile_reblock; tile_copy_reblock(S_tile, &S_tile_reblock); + +#if USE_SYSTOLIC_UKERNEL + tile_store_t_sys_src2(S_tile_reblock, (local FMA_TYPE *)S_slm, + ugemm_vs_sg_tile_n, ugemm_kq_wg_tile_m, sg_i0_kq, sg_j0_kq); +#else tile_store_block_packed(S_tile_reblock, S_slm, ugemm_vs_sg_tile_n, ugemm_kq_wg_tile_m, sg_j0_kq, sg_i0_kq); #endif @@ -833,7 +816,7 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, /* Rescale existing accumulator and sums to match new maxima */ if (!first) { -#define binary_exp_sub(x, y) native_vexp2(scale *((x) - (y))) +#define binary_exp_sub(x, y) native_vexp2(scale * 1.442695f * ((x) - (y))) #define binary_mul(x, y) ((x) * (y)) tile_binary(S_max_tile_old, S_max_tile, binary_exp_sub); tile_binary(S_sum_tile, S_max_tile_old, binary_mul); @@ -998,6 +981,31 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, if (need_sum_barrier) intel_work_group_barrier_wait(CLK_LOCAL_MEM_FENCE); +#if IS_TRAINING + s_sum_tile_type S_sum_total, S_sum_load; + tile_fill(S_sum_total, 0.f); +#pragma unroll + for (uint sg1 = 0; sg1 < ugemm_kq_sg_per_wg_m; sg1++) { + tile_load_full(&S_sum_load, S_sum_slm, ugemm_kq_wg_tile_n, + ugemm_kq_sg_tile_n * sg_j_kq, sg1); + tile_binary(S_sum_total, S_sum_load, binary_add); + } + +#define log2(x) (native_vlog2(x) * 0.6931471805f) + tile_elementwise(S_sum_total, log2); +#define scale_op(x) ((x) * scale) + tile_elementwise(S_max_tile_old, scale_op); + tile_binary(S_max_tile_old, S_sum_total, binary_add); + + 
// save columns logsumexp to workspace for training pass + const uint preprocess_batch = b1 * (DST_D1 * q) + b0 * q; + + global float *ws_logsumexp = ws + preprocess_batch; + tile_store(S_max_tile_old, ws_logsumexp, q, 1, q, sg_j0_kq + wg_j0, + sg_i0_kq); + // sg_i0 specified to avoid OOB subgroups from aliasing +#endif + /* Load column sums from SLM + reduce in registers */ a_scale_tile_type A_scale_tile, A_scale_tile_load; tile_fill(A_scale_tile, 0.0f); From 1fbb9a99a3c7049a3cdeed6819ba99c04c917805 Mon Sep 17 00:00:00 2001 From: syurkevi Date: Tue, 24 Feb 2026 23:51:27 -0800 Subject: [PATCH 10/23] xe: sdpa: updates tile_ops packed load/stores --- src/gpu/intel/include/generic_vector_ops.h | 17 + src/gpu/intel/include/tile_ops.h | 435 ++++++++++++++++++++- src/gpu/intel/sdpa/micro.cpp | 3 +- src/gpu/intel/sdpa/utils.h | 2 + 4 files changed, 447 insertions(+), 10 deletions(-) diff --git a/src/gpu/intel/include/generic_vector_ops.h b/src/gpu/intel/include/generic_vector_ops.h index 2ed6bcb4d41..0621d4d6718 100644 --- a/src/gpu/intel/include/generic_vector_ops.h +++ b/src/gpu/intel/include/generic_vector_ops.h @@ -73,6 +73,23 @@ float16 __attribute__((overloadable)) native_vexp2(float16 x) { return native_exp2(x); } +float1 __attribute__((overloadable)) native_vlog2(float1 x) { + x[0] = native_log2(x[0]); + return x; +} +float2 __attribute__((overloadable)) native_vlog2(float2 x) { + return native_log2(x); +} +float4 __attribute__((overloadable)) native_vlog2(float4 x) { + return native_log2(x); +} +float8 __attribute__((overloadable)) native_vlog2(float8 x) { + return native_log2(x); +} +float16 __attribute__((overloadable)) native_vlog2(float16 x) { + return native_log2(x); +} + float1 __attribute__((overloadable)) vselect(float1 x, float1 y, int1 c) { x[0] = select(x[0], y[0], c[0]); return x; diff --git a/src/gpu/intel/include/tile_ops.h b/src/gpu/intel/include/tile_ops.h index d9b10de7277..0a65b05b746 100644 --- a/src/gpu/intel/include/tile_ops.h +++ 
b/src/gpu/intel/include/tile_ops.h @@ -21,20 +21,18 @@ #include "gpu/intel/include/types.h" float __builtin_IB_atomic_max_local_f32(__local float *, float); +float __builtin_IB_atomic_add_local_f32(__local float *, float); +float __builtin_IB_atomic_add_global_f32(__global float *, float); +half __builtin_IB_atomic_add_global_f16(__global half *, half); __attribute__((overloadable)) float local_atomic_max(local float *p, float v) { return __builtin_IB_atomic_max_local_f32(p, v); } -__attribute__((overloadable)) half local_atomic_max( - local half *p, half v) { /* not implemented */ - return v; -} - +/* not implemented */ +__attribute__((overloadable)) half local_atomic_max(local half *p, half v); __attribute__((overloadable)) ushort local_atomic_max( - local ushort *p, ushort v) { /* not implemented */ - return v; -} + local ushort *p, ushort v); __attribute__((overloadable)) uint local_atomic_max(local uint *p, uint v) { return atomic_max(p, v); @@ -44,6 +42,44 @@ __attribute__((overloadable)) int local_atomic_max(local int *p, int v) { return atomic_max(p, v); } +__attribute__((overloadable)) float local_atomic_add(local float *p, float v) { + return __builtin_IB_atomic_add_local_f32(p, v); +} + +/* not implemented */ +__attribute__((overloadable)) half local_atomic_add(local half *p, half v); +__attribute__((overloadable)) ushort local_atomic_add( + local ushort *p, ushort v); + +__attribute__((overloadable)) uint local_atomic_add(local uint *p, uint v) { + return atomic_add(p, v); +} + +__attribute__((overloadable)) int local_atomic_add(local int *p, int v) { + return atomic_add(p, v); +} + +__attribute__((overloadable)) float global_atomic_add( + global float *p, float v) { + return __builtin_IB_atomic_add_global_f32(p, v); +} + +__attribute__((overloadable)) half global_atomic_add(global half *p, half v) { + return __builtin_IB_atomic_add_global_f16(p, v); +} + +/* not implemented */ +__attribute__((overloadable)) ushort global_atomic_add( + global ushort *p, 
ushort v); + +__attribute__((overloadable)) uint global_atomic_add(global uint *p, uint v) { + return atomic_add(p, v); +} + +__attribute__((overloadable)) int global_atomic_add(global int *p, int v) { + return atomic_add(p, v); +} + #define DEF_BLOCK_LOAD_STORE(type, itype, suffix, n) \ __attribute__((overloadable)) type##n block_load( \ const global type *p, int vlen) \ @@ -322,6 +358,21 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) } \ } while (0) +#define tile_predicated_assignment( \ + t, sg_offset_r, sg_offset_c, predicate, value, sg, br, bc, nbr, nbc) \ + do { \ + for (int j = 0; j < (bc * nbc); j++) { \ + for (int i0 = 0; i0 < (br * nbr); i0 += sg) { \ + int i = i0 + get_sub_group_local_id(); \ + int offset_r = sg_offset_r + i; \ + int offset_c = sg_offset_c + j; \ + if (predicate(offset_r, offset_c)) { \ + tile_access(t, i0, j, sg, br, bc, nbr) = value; \ + } \ + } \ + } \ + } while (0) + #define DECLARE_2D_TILE_OPS(tile_type, element_type, sg, br, bc, nbr, nbc) \ __attribute__((overloadable)) void tile_load_full(tile_type *t, \ const global element_type *ptr, int ld, int offset_r, \ @@ -345,6 +396,24 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) } \ } \ } \ + __attribute__((overloadable)) void tile_load(tile_type *t, \ + const local element_type *ptr, int m, int n, int ld, int offset_r, \ + int offset_c) { \ + if (m >= offset_r + br * nbr && n >= offset_c + bc * nbc) { \ + tile_load_full(t, ptr, ld, offset_r, offset_c); \ + return; \ + } \ + ptr += ld * offset_c + offset_r; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++, ptr += ld) { \ + if (offset_c + j < n) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + int i = i0 + get_sub_group_local_id(); \ + if (offset_r + i < m) \ + tile_access(*t, i0, j, sg, br, bc, nbr) = ptr[i]; \ + } \ + } \ + } \ + } \ __attribute__((overloadable)) void tile_load(tile_type *t, \ const global element_type *ptr, int m, int n, int ld, \ int offset_r, int 
offset_c) { \ @@ -368,6 +437,38 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) int offset_c) { \ tile_load(t, ptr, m, n, m, offset_r, offset_c); \ } \ + __attribute__((overloadable)) void tile_load_t_full(tile_type *t, \ + const local element_type *ptr, int ld, int offset_r, \ + int offset_c) { \ + ptr += ld * offset_r + offset_c; \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; \ + i0 += sg, ptr += ld * sg) { \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ + tile_access(*t, i0, j, sg, br, bc, nbr) \ + = ptr[get_sub_group_local_id() * ld + j]; \ + } \ + } \ + } \ + __attribute__((overloadable)) void tile_load_t(tile_type *t, \ + const local element_type *ptr, int m, int n, int ld, int offset_r, \ + int offset_c) { \ + if (m >= offset_r + br * nbr && n >= offset_c + bc * nbc) { \ + tile_load_t_full(t, ptr, ld, offset_r, offset_c); \ + return; \ + } \ + ptr += ld * offset_r + offset_c; \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; \ + i0 += sg, ptr += ld * sg) { \ + int i = i0 + get_sub_group_local_id(); \ + if (offset_r + i < m) \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ + if (offset_c + j < n) { \ + tile_access(*t, i0, j, sg, br, bc, nbr) \ + = ptr[get_sub_group_local_id() * ld + j]; \ + } \ + } \ + } \ + } \ __attribute__((overloadable)) void tile_load_t_full(tile_type *t, \ const global element_type *ptr, int ld, int offset_r, \ int offset_c) { \ @@ -405,6 +506,16 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) int offset_c) { \ tile_load_t(t, ptr, m, n, n, offset_r, offset_c); \ } \ + __attribute__((overloadable)) void tile_store_t_full(tile_type t, \ + local element_type *ptr, int ld, int offset_r, int offset_c) { \ + ptr += ld * offset_r + offset_c; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++, ptr++) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + int i = ld * (i0 + get_sub_group_local_id()); \ + ptr[i] = tile_access(t, i0, j, sg, br, bc, nbr); 
\ + } \ + } \ + } \ __attribute__((overloadable)) void tile_store_full(tile_type t, \ local element_type *ptr, int ld, int offset_r, int offset_c) { \ ptr += ld * offset_c + offset_r; \ @@ -415,6 +526,24 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) } \ } \ } \ + __attribute__((overloadable)) void tile_store(tile_type t, \ + local element_type *ptr, int m, int n, int ld, int offset_r, \ + int offset_c) { \ + if (m >= offset_r + br * nbr && n >= offset_c + bc * nbc) { \ + tile_store_full(t, ptr, ld, offset_r, offset_c); \ + return; \ + } \ + ptr += ld * offset_c + offset_r; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++, ptr += ld) { \ + if (offset_c + j < n) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + int i = i0 + get_sub_group_local_id(); \ + if (offset_r + i < m) \ + ptr[i] = tile_access(t, i0, j, sg, br, bc, nbr); \ + } \ + } \ + } \ + } \ __attribute__((overloadable)) void tile_store_full(tile_type t, \ global element_type *ptr, int ld, int offset_r, int offset_c) { \ ptr += ld * offset_c + offset_r; \ @@ -448,6 +577,46 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) int offset_c) { \ tile_store(t, ptr, m, n, m, offset_r, offset_c); \ } \ + __attribute__((overloadable)) void tile_load_t_packed_src1(tile_type *t, \ + local element_type *ptr, int panel, int ld, int offset_r, \ + int offset_c) { \ + offset_c += get_sub_group_local_id(); \ + int offset_r0 = offset_r % panel; \ + int offset_r1 = offset_r - offset_r0; \ + ptr += offset_r0 + panel * offset_c + ld * offset_r1; \ + _Pragma("unroll") for (int j0 = 0; j0 < br * nbr; \ + j0 += sg, ptr += sg * panel) { \ + _Pragma("unroll") for (int i = 0; i < bc * nbc; i++) \ + tile_access(*(t), j0, i, sg, br, bc, nbr) \ + = ptr[i]; \ + } \ + } \ + __attribute__((overloadable)) void tile_load_packed_src1(tile_type *t, \ + local element_type *ptr, int panel, int ld, int offset_r, \ + int offset_c) { \ + ptr += offset_c * panel; \ + _Pragma("unroll") 
for (int j = 0; j < bc * nbc; j++, ptr += panel) \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + int i = i0 + get_sub_group_local_id(); \ + int offset_r0 = (offset_r + i) % panel; \ + int offset_r1 = (offset_r + i) - offset_r0; \ + tile_access(*(t), i0, j, sg, br, bc, nbr) \ + = ptr[offset_r0 + offset_r1 * ld]; \ + } \ + } \ + __attribute__((overloadable)) void tile_store_packed_src1(tile_type t, \ + local element_type *ptr, int panel, int ld, int offset_r, \ + int offset_c) { \ + ptr += offset_c * panel; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++, ptr += panel) \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + int i = i0 + get_sub_group_local_id(); \ + int offset_r0 = (offset_r + i) % panel; \ + int offset_r1 = (offset_r + i) - offset_r0; \ + ptr[offset_r0 + offset_r1 * ld] \ + = tile_access(t, i0, j, sg, br, bc, nbr); \ + } \ + } \ __attribute__((overloadable)) void tile_store_t_packed_src1(tile_type t, \ local element_type *ptr, int panel, int ld, int offset_r, \ int offset_c) { \ @@ -462,6 +631,82 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) = tile_access(t, j0, i, sg, br, bc, nbr); \ } \ } \ + __attribute__((overloadable)) void tile_store_sys_src1(tile_type t, \ + local element_type *ptr, int tileR, int tileC, int wg_tile_m, \ + int wg_tile_n, int offset_r, int offset_c) { \ + const int crosspack = 2; \ + const int tile_panel_size = tileR * tileC; \ + const int num_row_panels = wg_tile_m / tileR; \ + const int num_col_panels = wg_tile_n / tileC; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + const int in_r = offset_r + i0 + get_sub_group_local_id(); \ + const int in_c = offset_c + j; \ + /* Compute 2D panel grid position: */ \ + const int row_panel = in_r \ + / tileR; /* Which vertical panel (every sg rows) */ \ + const int col_panel = in_c \ + / tileC; /* Which horizontal panel (every tile_n columns) */ 
\ + const int panel_base \ + = (col_panel * num_row_panels + row_panel) \ + * tile_panel_size; \ + /*const int panel_base = (row_panel * num_col_panels + col_panel) * tile_panel_size;*/ \ + /* Within-panel offsets using crosspack layout: */ \ + const int in_panel_row = in_r \ + & (tileR - 1); /* Row within panel (in_r % sg) */ \ + const int in_panel_col = in_c \ + & (tileC \ + - 1); /* Column within panel (in_c % tile_n) */ \ + const int col_group_offset = (in_panel_col >> 1) \ + * (crosspack * tileR); /* Column pair group */ \ + const int sg_lane_offset = in_panel_row \ + * crosspack; /* Subgroup lane position */ \ + const int crosspack_offset = (in_panel_col \ + & 1); /* Position within column pair */ \ + const int out_idx = panel_base + col_group_offset \ + + sg_lane_offset + crosspack_offset; \ + ptr[out_idx] = tile_access(t, i0, j, sg, br, bc, nbr); \ + } \ + } \ + } \ + __attribute__((overloadable)) void tile_load_sys_src1(tile_type *t, \ + local element_type *ptr, int tileR, int tileC, int wg_tile_m, \ + int wg_tile_n, int offset_r, int offset_c) { \ + const int crosspack = 2; \ + const int tile_panel_size = tileR * tileC; \ + const int num_row_panels = wg_tile_m / tileR; \ + const int num_col_panels = wg_tile_n / tileC; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + const int in_r = offset_r + i0 + get_sub_group_local_id(); \ + const int in_c = offset_c + j; \ + /* Compute 2D panel grid position: */ \ + const int row_panel = in_r \ + / tileR; /* Which vertical panel (every sg rows) */ \ + const int col_panel = in_c \ + / tileC; /* Which horizontal panel (every tile_n columns) */ \ + const int panel_base \ + = (col_panel * num_row_panels + row_panel) \ + * tile_panel_size; \ + /*const int panel_base = (row_panel * num_col_panels + col_panel) * tile_panel_size;*/ \ + /* Within-panel offsets using crosspack layout: */ \ + const int in_panel_row = in_r \ + & (tileR - 1); /* Row 
within panel (in_r % sg) */ \ + const int in_panel_col = in_c \ + & (tileC \ + - 1); /* Column within panel (in_c % tile_n) */ \ + const int col_group_offset = (in_panel_col >> 1) \ + * (crosspack * tileR); /* Column pair group */ \ + const int sg_lane_offset = in_panel_row \ + * crosspack; /* Subgroup lane position */ \ + const int crosspack_offset = (in_panel_col \ + & 1); /* Position within column pair */ \ + const int out_idx = panel_base + col_group_offset \ + + sg_lane_offset + crosspack_offset; \ + tile_access(*t, i0, j, sg, br, bc, nbr) = ptr[out_idx]; \ + } \ + } \ + } \ __attribute__((overloadable)) void tile_store_t_sys_src1(tile_type t, \ local element_type *ptr, int ld, int offset_r, int offset_c) { \ offset_c += get_sub_group_local_id(); \ @@ -474,6 +719,91 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) = tile_access(t, j0, i, sg, br, bc, nbr); \ } \ } \ + __attribute__((overloadable)) void tile_store_t_sys_src11(tile_type t, \ + local element_type *ptr, int tileR, int tileC, int wg_tile_m, \ + int wg_tile_n, int offset_r, int offset_c) { \ + const int crosspack = 2; \ + const int tile_panel_size = tileR * tileC; \ + const int num_row_panels \ + = wg_tile_m / tileR; /* is correct when _t? 
*/ \ + const int num_col_panels = wg_tile_n / tileC; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + const int in_r = offset_r + i0 + get_sub_group_local_id(); \ + const int in_c = offset_c + j; \ + /* Compute 2D panel grid position: */ \ + const int row_panel = in_c / tileR; \ + const int col_panel = in_r / tileC; \ + const int panel_base \ + = (col_panel * num_row_panels + row_panel) \ + * tile_panel_size; \ + /* Within-panel offsets using crosspack layout: */ \ + const int in_panel_row = in_c \ + & (tileR - 1); /* Row within panel (in_c % tileR) */ \ + const int in_panel_col = in_r \ + & (tileC \ + - 1); /* Column within panel (in_r % tileC) */ \ + const int col_group_offset = (in_panel_col >> 1) \ + * (crosspack * tileR); /* Column pair group */ \ + const int sg_lane_offset = in_panel_row \ + * crosspack; /* Subgroup lane position */ \ + const int crosspack_offset = (in_panel_col \ + & 1); /* Position within column pair */ \ + const int out_idx = panel_base + col_group_offset \ + + sg_lane_offset + crosspack_offset; \ + ptr[out_idx] = tile_access(t, i0, j, sg, br, bc, nbr); \ + } \ + } \ + } \ + __attribute__((overloadable)) void tile_store_sys_src22(tile_type t, \ + local element_type *ptr, int panel_n, int wg_tile_m, \ + int wg_tile_n, int offset_r, int offset_c) { \ + const int crosspack = 16; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + const int in_r = offset_r + i0 + get_sub_group_local_id(); \ + const int in_c = offset_c + j; \ + /* Panel-based addressing: panel_n cols per panel */ \ + const int col_panel = in_c / panel_n; \ + const int in_panel_c = in_c & (panel_n - 1); \ + /* Within-panel offsets using crosspack layout: */ \ + const int col_group_offset = (in_r / crosspack) \ + * (crosspack * panel_n); /* Column pair group */ \ + const int sg_lane_offset \ + = in_panel_c * crosspack; /* 
Subgroup lane position */ \ + const int crosspack_offset = (in_r & (crosspack - 1)); \ + const int out_idx = col_panel * (panel_n * wg_tile_m) \ + + col_group_offset + sg_lane_offset \ + + crosspack_offset; \ + ptr[out_idx] = tile_access(t, i0, j, sg, br, bc, nbr); \ + } \ + } \ + } \ + __attribute__((overloadable)) void tile_store_t_sys_src22(tile_type t, \ + local element_type *ptr, int panel_n, int wg_tile_m, \ + int wg_tile_n, int offset_r, int offset_c) { \ + const int crosspack = 16; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + const int in_r = offset_r + i0 + get_sub_group_local_id(); \ + const int in_c = offset_c + j; \ + /* Panel-based addressing: panel_n cols per panel */ \ + const int col_panel = in_r / panel_n; \ + const int in_panel_c = in_r & (panel_n - 1); \ + /* Within-panel offsets using crosspack layout: */ \ + const int col_group_offset = (in_c / crosspack) \ + * (crosspack * panel_n); /* Column pair group */ \ + const int sg_lane_offset \ + = in_panel_c * crosspack; /* Subgroup lane position */ \ + const int crosspack_offset = (in_c \ + & (crosspack - 1)); /* Position within column pair */ \ + const int out_idx = col_panel * (panel_n * wg_tile_n) \ + + col_group_offset + sg_lane_offset \ + + crosspack_offset; \ + ptr[out_idx] = tile_access(t, i0, j, sg, br, bc, nbr); \ + } \ + } \ + } \ __attribute__((overloadable)) void tile_store_t_sys_src2(tile_type t, \ local element_type *ptr, int tile_n, int ld, int offset_r, \ int offset_c) { \ @@ -494,6 +824,26 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) } \ } \ } \ + __attribute__((overloadable)) void tile_load_t_sys_src2(tile_type *t, \ + local element_type *ptr, int tile_n, int ld, int offset_r, \ + int offset_c) { \ + const int cp = 32 / sizeof(element_type); \ + offset_c += get_sub_group_local_id(); \ + int offset_r0 = offset_r & (cp - 1); \ + int offset_r1 = offset_r & ~(cp - 1); \ + ptr += 
offset_r0 + tile_n * offset_r1; \ + _Pragma("unroll") for (int j0 = 0; j0 < br * nbr; \ + j0 += sg, offset_c += sg) { \ + int offset_c0 = offset_c & (tile_n - 1); \ + int offset_c1 = offset_c & ~(tile_n - 1); \ + local element_type *ptr_j = ptr + cp * offset_c0 + ld * offset_c1; \ + _Pragma("unroll") for (int i = 0; i < bc * nbc; i++) { \ + tile_access(*t, j0, i, sg, br, bc, nbr) = *ptr_j; \ + ptr_j++; \ + if ((~i & (cp - 1)) == 0) ptr_j += cp * (tile_n - 1); \ + } \ + } \ + } \ __attribute__((overloadable)) void tile_atomic_max_full(tile_type t, \ local element_type *ptr, int ld, int offset_r, int offset_c) { \ ptr += ld * offset_c + offset_r; \ @@ -504,6 +854,47 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) ptr + i, tile_access(t, i0, j, sg, br, bc, nbr)); \ } \ } \ + } \ + __attribute__((overloadable)) void tile_atomic_add_full(tile_type t, \ + local element_type *ptr, int ld, int offset_r, int offset_c) { \ + ptr += ld * offset_c + offset_r; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++, ptr += ld) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + int i = i0 + get_sub_group_local_id(); \ + (void)local_atomic_add( \ + ptr + i, tile_access(t, i0, j, sg, br, bc, nbr)); \ + } \ + } \ + } \ + __attribute__((overloadable)) void tile_atomic_add_full(tile_type t, \ + global element_type *ptr, int ld, int offset_r, int offset_c) { \ + ptr += ld * offset_c + offset_r; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++, ptr += ld) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + int i = i0 + get_sub_group_local_id(); \ + (void)global_atomic_add( \ + ptr + i, tile_access(t, i0, j, sg, br, bc, nbr)); \ + } \ + } \ + } \ + __attribute__((overloadable)) void tile_atomic_add(tile_type t, \ + global element_type *ptr, int m, int n, int ld, int offset_r, \ + int offset_c) { \ + if (m >= (offset_r + (br * nbr)) && n >= (offset_c + (bc * nbc))) { \ + tile_atomic_add_full(t, ptr, ld, offset_r, 
offset_c); \ + return; \ + } \ + ptr += ld * offset_c + offset_r; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++, ptr += ld) { \ + if (offset_c + j < n) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + int i = i0 + get_sub_group_local_id(); \ + if (offset_r + i < m) \ + (void)global_atomic_add(ptr + i, \ + tile_access(t, i0, j, sg, br, bc, nbr)); \ + } \ + } \ + } \ } #define DECLARE_2D_TILE_VREDUCE(tile_type, sg, br, bc, nbr, nbc, rtile_type, \ @@ -553,10 +944,29 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) *= tile_access(tr, i0, 0, rsg, rbr, rbc, rnbr); \ } \ } \ + } \ + __attribute__((overloadable)) void tile_vbroadcast_min( \ + tile_type *t, rtile_type tr) { \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + tile_access(*t, i0, j, sg, br, bc, nbr) \ + = min(tile_access(*t, i0, j, sg, br, bc, nbr), \ + tile_access(tr, i0, 0, rsg, rbr, rbc, rnbr)); \ + } \ + } \ } #define DECLARE_2D_TILE_HREDUCE(tile_type, sg, br, bc, nbr, nbc, rtile_type, \ rsg, rbr, rbc, rnbr, rnbc) \ + __attribute__((overloadable)) void tile_hreduce_add( \ + tile_type t, rtile_type *tr) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ + tile_access(*tr, i0, j, rsg, rbr, rbc, rnbr) \ + += tile_access(t, i0, j, sg, br, bc, nbr); \ + } \ + } \ + } \ __attribute__((overloadable)) void tile_hbroadcast_add( \ tile_type *t, rtile_type tr) { \ _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ @@ -566,6 +976,15 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) } \ } \ } \ + __attribute__((overloadable)) void tile_hbroadcast_sub( \ + tile_type *t, rtile_type tr) { \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + tile_access(*t, i0, j, sg, br, bc, nbr) \ + -= xlane_tile_access(tr, j, 0, rsg, rbr, 
rbc, rnbr); \ + } \ + } \ + } \ __attribute__((overloadable)) void tile_hbroadcast_mul( \ tile_type *t, rtile_type tr) { \ _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ diff --git a/src/gpu/intel/sdpa/micro.cpp b/src/gpu/intel/sdpa/micro.cpp index ca5d550b0c3..8530314ba83 100644 --- a/src/gpu/intel/sdpa/micro.cpp +++ b/src/gpu/intel/sdpa/micro.cpp @@ -1632,14 +1632,13 @@ status_t micro_bwd_t::execute_backward(const exec_ctx_t &ctx) const { auto wg_tile_k = config.unroll_m_BcBr * config.wg_m_BcBr; auto wg_tile_q = config.unroll_n_BcBr * config.wg_n_BcBr; - auto sg_per_wg = config.wg_m_BcBr * config.wg_n_BcBr; auto sg_per_wg_BcBr = config.wg_m_BcBr * config.wg_n_BcBr; auto sg_per_wg_DBc = config.wg_m_DBc * config.wg_n_DBc; auto sg_per_wg_DBr = config.wg_m_DBr * config.wg_n_DBr; using std::max; - sg_per_wg = max(max(sg_per_wg_BcBr, sg_per_wg_DBc), sg_per_wg_DBr); + auto sg_per_wg = max(max(sg_per_wg_BcBr, sg_per_wg_DBc), sg_per_wg_DBr); const memory_desc_wrapper qry_mdw(pd()->qry_md()); const memory_desc_wrapper key_mdw(pd()->key_md()); diff --git a/src/gpu/intel/sdpa/utils.h b/src/gpu/intel/sdpa/utils.h index d5b41ab295e..4514f30691f 100644 --- a/src/gpu/intel/sdpa/utils.h +++ b/src/gpu/intel/sdpa/utils.h @@ -56,6 +56,8 @@ #define QRY_S2 QRY_S.array[2] #define VAL_S2 VAL_S.array[2] #define DST_S2 DST_S.array[2] +#define DST_D0 DST_D.array[0] +#define DST_D1 DST_D.array[1] #define MSK_D0 MSK_D.array[0] #define MSK_D1 MSK_D.array[1] #define MSK_S2 MSK_S.array[2] From 92994388a9b86bcd37f9bf6e13e5455e5ddf7802 Mon Sep 17 00:00:00 2001 From: syurkevi Date: Wed, 25 Feb 2026 00:45:03 -0800 Subject: [PATCH 11/23] xe: sdpa: adds bwd kernel implementation --- src/gpu/intel/sdpa/configs.cpp | 26 +- src/gpu/intel/sdpa/micro_bwd.cl | 876 +++++++++++++++++++++++++++ tests/gtests/internals/test_sdpa.cpp | 65 +- 3 files changed, 919 insertions(+), 48 deletions(-) create mode 100644 src/gpu/intel/sdpa/micro_bwd.cl diff --git a/src/gpu/intel/sdpa/configs.cpp 
b/src/gpu/intel/sdpa/configs.cpp index 9e0571997aa..4639f5392ce 100644 --- a/src/gpu/intel/sdpa/configs.cpp +++ b/src/gpu/intel/sdpa/configs.cpp @@ -696,15 +696,27 @@ static std::vector sorted_bwd_configs = []() { // clang-format off std::vector configs = { // xe_hpc - {{compute::gpu_arch_t::xe_hpc, 32}, { 16, 32, 16, 16, 16, 16, 4, 8, 2, 4, 2, 16 }}, - {{compute::gpu_arch_t::xe_hpc, 64}, { 16, 32, 16, 16, 32, 32, 4, 8, 4, 4, 2, 8 }}, + {{compute::gpu_arch_t::xe_hpc, 32}, { 64, 32, 16, 16, 16, 32, 2, 8, 2, 8, 2, 8 }}, + {{compute::gpu_arch_t::xe_hpc, 64}, { 32, 16, 16, 16, 32, 32, 1, 8, 4, 2, 2, 4 }}, {{compute::gpu_arch_t::xe_hpc, 128}, { 16, 32, 16, 16, 32, 32, 4, 8, 8, 4, 4, 8 }}, + //{{compute::gpu_arch_t::xe_hpc, 256}, { 16, 32, 16, 16, 32, 32, 4, 8, 8, 4, 4, 8 }}, - /* xe2 todo: - {{compute::gpu_arch_t::xe2, 64}, {16, 64, 64, 16, 64, 16, 4, 1, 1, 4, 1, 4}}, - {{compute::gpu_arch_t::xe2, 128}, {16, 64, 64, 16, 64, 16, 4, 2, 2, 4, 2, 4}}, - {{compute::gpu_arch_t::xe2, 256}, {16, 64, 64, 16, 64, 16, 4, 2, 4, 4, 4, 4}}, - */ + {{compute::gpu_arch_t::xe_hpc, 32, second_token}, { 16, 16, 16, 16, 32, 16, 1, 2, 2, 1, 1, 2 }}, + {{compute::gpu_arch_t::xe_hpc, 64, second_token}, { 32, 16, 16, 32, 32, 32, 1, 4, 4, 1, 2, 2 }}, + {{compute::gpu_arch_t::xe_hpc, 128, second_token}, { 16, 16, 16, 16, 32, 32, 2, 8, 8, 2, 4, 4 }}, + //{{compute::gpu_arch_t::xe_hpc, 256, second_token}, { 16, 16, 16, 16, 32, 32, 2, 8, 8, 2, 4, 4 }}, + + {{compute::gpu_arch_t::xe_hpc, 32, f32 | fma}, { 32, 32, 16, 16, 32, 32, 1, 4, 2, 2, 1, 4 }}, + {{compute::gpu_arch_t::xe_hpc, 64, f32 | fma}, { 16, 32, 16, 16, 16, 32, 4, 4, 4, 4, 4, 4 }}, + {{compute::gpu_arch_t::xe_hpc, 128, f32 | fma}, { 16, 16, 16, 32, 32, 32, 2, 4, 8, 1, 4, 2 }}, + + {{compute::gpu_arch_t::xe2, 32, integrated}, { 16, 64, 16, 16, 32, 32, 2, 2, 2, 2, 1, 4 }}, + {{compute::gpu_arch_t::xe2, 64, integrated}, { 16, 32, 16, 16, 32, 32, 2, 4, 4, 2, 2, 4 }}, + {{compute::gpu_arch_t::xe2, 128, integrated}, { 16, 32, 16, 16, 32, 32, 
4, 8, 8, 4, 4, 8 }}, + + {{compute::gpu_arch_t::xe2, 32}, { 16, 64, 16, 16, 32, 32, 2, 2, 2, 2, 1, 4 }}, + {{compute::gpu_arch_t::xe2, 64}, { 16, 32, 16, 16, 32, 32, 2, 4, 4, 2, 2, 4 }}, + {{compute::gpu_arch_t::xe2, 128}, { 16, 32, 16, 16, 32, 32, 4, 8, 8, 4, 4, 8 }}, }; // clang-format on diff --git a/src/gpu/intel/sdpa/micro_bwd.cl b/src/gpu/intel/sdpa/micro_bwd.cl new file mode 100644 index 00000000000..fbf019d0c7c --- /dev/null +++ b/src/gpu/intel/sdpa/micro_bwd.cl @@ -0,0 +1,876 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "gpu/intel/include/tile_ops.h" +#include "gpu/intel/include/types_interop.h" +#include "gpu/intel/sdpa/utils.h" + +/* Microkernel headers -- generated at runtime */ +#include "gemm_kq.h" +#include "gemm_ktq.h" +#include "gemm_qdSt.h" +#include "gemm_vs.h" +#include "gemm_vtdA.h" + +#define MAX(a, b) ((a) > (b) ? 
(a) : (b)) +#define DIV_UP(x, y) (((x) + (y) - 1) / (y)) + +#define sg_per_wg_BcBr \ + (ugemm_kq_sg_per_wg_m * ugemm_kq_sg_per_wg_n) // same for kq, vtdA +#define sg_per_wg_BcD \ + (ugemm_vs_sg_per_wg_m * ugemm_vs_sg_per_wg_n) // same for qdSt and vs +#define sg_per_wg_BrD (ugemm_ktq_sg_per_wg_m * ugemm_ktq_sg_per_wg_n) +#define sg_per_wg MAX(sg_per_wg_BcBr, MAX(sg_per_wg_BcD, sg_per_wg_BrD)) + +#define q_tile_sg_n DIV_UP(ugemm_kq_wg_tile_n, sg_per_wg) +#define dmax_tile_sg_n DIV_UP(D_MAX, sg_per_wg) + +/* Instantiate tile types and operations */ +typedef ugemm_kq_c_type s_tile_type; // Bc*Br tile +typedef ugemm_qdSt_c_type a_tile_type; // Bc*D tile +typedef ugemm_vtdA_c_type p_tile_type; // Br*Bc tile (.T) +typedef ugemm_vs_c_type dv_tile_type; // D*Bc tile +typedef ugemm_ktq_c_type ktq_tile_type; // D*Br tile + +#ifdef QRY_DT_F32 +#define FMA_TYPE float +#elif QRY_DT_F16 +#define VEC_TYPE2 half2 +#define FMA_TYPE half +#elif defined(QRY_DT_BF16) +#define VEC_TYPE2 ushort2 +#define FMA_TYPE ushort +#else +#error "Data type not supported for VEC_TYPE2" +#endif + +#ifdef SCALE_DT_BF16 +#define SCALES_TO_FLOAT cvt_bf16_to_f32 +#else +#define SCALES_TO_FLOAT convert_float +#endif + +DECLARE_2D_TILE(q_tile_type, FMA_TYPE, SUBGROUP_SIZE, D_MAX, 1, 1, q_tile_sg_n) + +DECLARE_2D_TILE(dq_tile_type, float, SUBGROUP_SIZE, D_MAX, 1, 1, q_tile_sg_n) +DECLARE_2D_TILE_BLOCK_OPS( + dq_tile_type, float, SUBGROUP_SIZE, D_MAX, 1, 1, q_tile_sg_n) +DECLARE_2D_TILE_COPY_REBLOCK(q_tile_type, SUBGROUP_SIZE, D_MAX, 1, 1, + q_tile_sg_n, dq_tile_type, SUBGROUP_SIZE, D_MAX, 1, 1, q_tile_sg_n, + CONVERT_FLOAT_T) + +DECLARE_2D_TILE(k_tile_type, FMA_TYPE, SUBGROUP_SIZE, ugemm_kq_wg_tile_m, 1, 1, + dmax_tile_sg_n) +#if BLOCK_K +DECLARE_2D_TILE_BLOCK_OPS(k_tile_type, FMA_TYPE, SUBGROUP_SIZE, + ugemm_kq_wg_tile_m, 1, 1, dmax_tile_sg_n) +#endif + +DECLARE_2D_TILE(s_tile_type_packed, uint, SUBGROUP_SIZE, ugemm_kq_c_type_block0, + ugemm_kq_c_type_block1 / 2, ugemm_kq_c_type_nblock0, + 
ugemm_kq_c_type_nblock1) +DECLARE_2D_TILE(s_tile_type_packed_t, uint, SUBGROUP_SIZE, + ugemm_kq_c_type_block1, ugemm_kq_c_type_block0 / 2, + ugemm_kq_c_type_nblock1, ugemm_kq_c_type_nblock0) + +DECLARE_2D_TILE(s_tile_type_reblock, FMA_TYPE, SUBGROUP_SIZE, + ugemm_kq_sg_tile_m, 1, 1, ugemm_kq_sg_tile_n) +DECLARE_2D_TILE_BLOCK_OPS(s_tile_type_reblock, FMA_TYPE, SUBGROUP_SIZE, + ugemm_kq_sg_tile_m, 1, 1, ugemm_kq_sg_tile_n) + +DECLARE_2D_TILE(p_tile_type_reblock, FMA_TYPE, SUBGROUP_SIZE, + ugemm_vtdA_c_type_block0, 1, ugemm_vtdA_c_type_nblock0, + ugemm_vtdA_c_type_block1 *ugemm_vtdA_c_type_nblock1) +DECLARE_2D_TILE_BLOCK_OPS(p_tile_type_reblock, FMA_TYPE, SUBGROUP_SIZE, + ugemm_vtdA_c_type_block0, 1, ugemm_vtdA_c_type_nblock0, + ugemm_vtdA_c_type_block1 *ugemm_vtdA_c_type_nblock1) + +DECLARE_2D_TILE( + s_sum_tile_type, float, SUBGROUP_SIZE, ugemm_kq_sg_tile_n, 1, 1, 1) +DECLARE_2D_TILE( + p_sum_tile_type, float, SUBGROUP_SIZE, ugemm_vtdA_sg_tile_n, 1, 1, 1) + +#if BROADCAST_MASK_Q +#define mask_br ugemm_kq_sg_tile_m +#define mask_bc 1 +#define mask_nbr 1 +#define mask_nbc 1 +#else +#define mask_br ugemm_kq_c_type_block0 +#define mask_bc ugemm_kq_c_type_block1 +#define mask_nbr ugemm_kq_c_type_nblock0 +#define mask_nbc ugemm_kq_c_type_nblock1 +#endif + +DECLARE_2D_TILE(qmask_tile_type_float, float, SUBGROUP_SIZE, ugemm_kq_sg_tile_n, + 1, 1, 1) +DECLARE_2D_TILE(kmask_tile_type_float, float, SUBGROUP_SIZE, ugemm_kq_sg_tile_m, + 1, 1, 1) + +#if WITH_ATTN_MASK +DECLARE_2D_TILE(mask_tile_type, MSK_DATA_T, SUBGROUP_SIZE, mask_br, mask_bc, + mask_nbr, mask_nbc) + +#if BROADCAST_MASK_Q +DECLARE_2D_TILE_BLOCK_OPS(mask_tile_type, MSK_DATA_T, SUBGROUP_SIZE, mask_br, + mask_bc, mask_nbr, mask_nbc) +#endif +DECLARE_2D_TILE(mask_tile_type_float, float, SUBGROUP_SIZE, mask_br, mask_bc, + mask_nbr, mask_nbc) +DECLARE_2D_TILE_COPY_REBLOCK(mask_tile_type, SUBGROUP_SIZE, mask_br, mask_bc, + mask_nbr, mask_nbc, mask_tile_type_float, SUBGROUP_SIZE, mask_br, + mask_bc, mask_nbr, mask_nbc, 
CONVERT_FLOAT_T) +#endif + +DECLARE_2D_TILE(a_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE, + ugemm_qdSt_sg_tile_m, 1, 1, ugemm_qdSt_sg_tile_n) +DECLARE_2D_TILE_BLOCK_OPS(a_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE, + ugemm_qdSt_sg_tile_m, 1, 1, ugemm_qdSt_sg_tile_n) +DECLARE_2D_TILE_COPY_REBLOCK(a_tile_type, SUBGROUP_SIZE, + ugemm_qdSt_c_type_block0, ugemm_qdSt_c_type_block1, + ugemm_qdSt_c_type_nblock0, ugemm_qdSt_c_type_nblock1, a_tile_type_dst, + SUBGROUP_SIZE, ugemm_qdSt_sg_tile_m, 1, 1, ugemm_qdSt_sg_tile_n, + CONVERT_DATA_T) + +DECLARE_2D_TILE(dv_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE, ugemm_vs_sg_tile_m, + 1, 1, ugemm_vs_sg_tile_n) +DECLARE_2D_TILE_BLOCK_OPS(dv_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE, + ugemm_vs_sg_tile_m, 1, 1, ugemm_vs_sg_tile_n) +DECLARE_2D_TILE_COPY_REBLOCK(dv_tile_type, SUBGROUP_SIZE, + ugemm_vs_c_type_block0, ugemm_vs_c_type_block1, ugemm_vs_c_type_nblock0, + ugemm_vs_c_type_nblock1, dv_tile_type_dst, SUBGROUP_SIZE, + ugemm_vs_sg_tile_m, 1, 1, ugemm_vs_sg_tile_n, CONVERT_DATA_T) + +DECLARE_2D_TILE(dq_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE, + ugemm_ktq_sg_tile_m, 1, 1, ugemm_ktq_sg_tile_n) +DECLARE_2D_TILE_BLOCK_OPS(dq_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE, + ugemm_ktq_sg_tile_m, 1, 1, ugemm_ktq_sg_tile_n) +DECLARE_2D_TILE_COPY_REBLOCK(ktq_tile_type, SUBGROUP_SIZE, + ugemm_ktq_c_type_block0, ugemm_ktq_c_type_block1, + ugemm_ktq_c_type_nblock0, ugemm_ktq_c_type_nblock1, dq_tile_type_dst, + SUBGROUP_SIZE, ugemm_ktq_sg_tile_m, 1, 1, ugemm_ktq_sg_tile_n, + CONVERT_DATA_T) + +DECLARE_2D_TILE_COPY_REBLOCK(s_tile_type, SUBGROUP_SIZE, ugemm_kq_c_type_block0, + ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0, + ugemm_kq_c_type_nblock1, s_tile_type_reblock, SUBGROUP_SIZE, + ugemm_kq_sg_tile_m, 1, 1, ugemm_kq_sg_tile_n, CONVERT_DATA_T) +DECLARE_2D_TILE_COPY_REBLOCK(p_tile_type, SUBGROUP_SIZE, + ugemm_vtdA_c_type_block0, ugemm_vtdA_c_type_block1, + ugemm_vtdA_c_type_nblock0, ugemm_vtdA_c_type_nblock1, + p_tile_type_reblock, SUBGROUP_SIZE, 
ugemm_vtdA_c_type_block0, 1, + ugemm_vtdA_c_type_nblock0, + ugemm_vtdA_c_type_block1 *ugemm_vtdA_c_type_nblock1, CONVERT_DATA_T) +DECLARE_2D_TILE_COPY_REBLOCK(p_tile_type_reblock, SUBGROUP_SIZE, + ugemm_vtdA_c_type_block0, 1, ugemm_vtdA_c_type_nblock0, + ugemm_vtdA_c_type_block1 *ugemm_vtdA_c_type_nblock1, p_tile_type, + SUBGROUP_SIZE, ugemm_vtdA_c_type_block0, ugemm_vtdA_c_type_block1, + ugemm_vtdA_c_type_nblock0, ugemm_vtdA_c_type_nblock1, CONVERT_FLOAT_T) + +DECLARE_2D_TILE_VREDUCE(s_tile_type, SUBGROUP_SIZE, ugemm_kq_c_type_block0, + ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0, + ugemm_kq_c_type_nblock1, s_sum_tile_type, SUBGROUP_SIZE, + ugemm_kq_sg_tile_n, 1, 1, 1) +DECLARE_2D_TILE_VREDUCE(s_tile_type, SUBGROUP_SIZE, ugemm_kq_c_type_block0, + ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0, + ugemm_kq_c_type_nblock1, qmask_tile_type_float, SUBGROUP_SIZE, + ugemm_kq_sg_tile_n, 1, 1, 1) +DECLARE_2D_TILE_HREDUCE(s_tile_type, SUBGROUP_SIZE, ugemm_kq_c_type_block0, + ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0, + ugemm_kq_c_type_nblock1, qmask_tile_type_float, SUBGROUP_SIZE, + ugemm_kq_sg_tile_n, 1, 1, 1) +DECLARE_2D_TILE_HREDUCE(p_tile_type, SUBGROUP_SIZE, ugemm_vtdA_c_type_block0, + ugemm_vtdA_c_type_block1, ugemm_vtdA_c_type_nblock0, + ugemm_vtdA_c_type_nblock1, kmask_tile_type_float, SUBGROUP_SIZE, + ugemm_vtdA_sg_tile_m, 1, 1, 1) +DECLARE_2D_TILE_VREDUCE(p_tile_type, SUBGROUP_SIZE, ugemm_vtdA_c_type_block0, + ugemm_vtdA_c_type_block1, ugemm_vtdA_c_type_nblock0, + ugemm_vtdA_c_type_nblock1, kmask_tile_type_float, SUBGROUP_SIZE, + ugemm_vtdA_sg_tile_m, 1, 1, 1) + +DECLARE_2D_TILE_HREDUCE(p_tile_type, SUBGROUP_SIZE, ugemm_vtdA_c_type_block0, + ugemm_vtdA_c_type_block1, ugemm_vtdA_c_type_nblock0, + ugemm_vtdA_c_type_nblock1, p_sum_tile_type, SUBGROUP_SIZE, + ugemm_vtdA_sg_tile_n, 1, 1, 1) + +DECLARE_2D_TILE_HREDUCE(s_tile_type, SUBGROUP_SIZE, ugemm_kq_c_type_block0, + ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0, + ugemm_kq_c_type_nblock1, 
s_sum_tile_type, SUBGROUP_SIZE, + ugemm_kq_sg_tile_n, 1, 1, 1) +#if WITH_ATTN_MASK +DECLARE_2D_TILE_VREDUCE(s_tile_type, SUBGROUP_SIZE, ugemm_kq_c_type_block0, + ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0, + ugemm_kq_c_type_nblock1, mask_tile_type_float, SUBGROUP_SIZE, mask_br, + mask_bc, mask_nbr, mask_nbc) +#endif + +#define tile_load_block_rem_q(t, ptr, n, ld, off_r, off_c, load_rem) \ + if (load_rem) { \ + tile_load_block(t, ptr, n, ld, off_r, off_c); \ + } else { \ + tile_load_block(t, ptr, ld, off_r, off_c); \ + } + +#define tile_store_block_rem_q(t, ptr, n, ld, off_r, off_c, store_rem) \ + if (store_rem) { \ + tile_store_block(t, ptr, n, ld, off_r, off_c); \ + } else { \ + tile_store_block(t, ptr, ld, off_r, off_c); \ + } + +#define binary_add(x, y) ((x) + (y)) + +inline void tile_load_k(k_tile_type *K_tile, const global KEY_DATA_T *K, int m, + int n, int ldk, int offset_r, int offset_c, int load_rem) { +#if BLOCK_K + // can ignore load_rem due to d_full requirement + tile_load_block(K_tile, K, ldk, offset_r, offset_c); +#else + tile_load(K_tile, K, m, n, ldk, offset_r, offset_c); +#endif +} + +#if KV_GROUP_SIZE > 1 +#define DST_DATA_T_DKDV float +#else +#define DST_DATA_T_DKDV DST_DATA_T +#endif + +inline void tile_store_dV(dv_tile_type *dV_tile_slm, global DST_DATA_T_DKDV *dV, + int m, int n, int ld, int offset_r, int offset_c, int rem) { + +#if KV_GROUP_SIZE > 1 // GQA update + tile_atomic_add(*dV_tile_slm, dV, m, n, ld, offset_r, offset_c); +#else // MHA update + + dv_tile_type_dst dV_tile_dst; // convert to half + tile_copy_reblock(*dV_tile_slm, &dV_tile_dst); +#if BLOCK_DV + tile_store_block_rem_q(dV_tile_dst, dV, n, ld, offset_r, offset_c, rem) +#else + tile_store(dV_tile_dst, dV, m, n, ld, offset_r, offset_c); +#endif + +#endif +} + +inline void tile_store_dK(a_tile_type *dK_tile_slm, global DST_DATA_T_DKDV *dK, + int m, int n, int ld, int offset_r, int offset_c) { + +#if KV_GROUP_SIZE > 1 // GQA update + tile_atomic_add(*dK_tile_slm, dK, m, n, 
ld, offset_r, offset_c); +#else // MHA update + + a_tile_type_dst dK_tile_dst; // convert to half + tile_copy_reblock(*dK_tile_slm, &dK_tile_dst); +#if BLOCK_DK + tile_store_block(dK_tile_dst, dK, ld, offset_r, offset_c); +#else + tile_store(dK_tile_dst, dK, m, n, ld, offset_r, offset_c); +#endif + +#endif +} + +#define DO_MM 1 + +__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) kernel void +micro_sdpa_bwd(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, + const global VAL_DATA_T *V, const global float *ws, + const global float *Di, const global DST_DATA_T *A, + const global DST_DATA_T *dA, +#if WITH_DS + global DST_DATA_T *dS, // expensive, optional intermediate +#endif + global DST_DATA_T_DKDV *dK, global float *dQ, + global DST_DATA_T_DKDV *dV, +#if WITH_HOST_SCALE + float scalar_scale, float inv_scalar_scale, +#else + const global SCALE_DATA_T *scale_ptr, +#endif + int d, int k, int q, const int attn_mask_type +#if WITH_ATTN_MASK + , + const global MSK_DATA_T *msk +#endif + , + KEY_OFFSETS, QRY_OFFSETS, VAL_OFFSETS, DST_OFFSETS +#if WITH_ATTN_MASK + , + MSK_OFFSETS +#endif + , + const int remainder_k, const int remainder_q) { + + uint wg_k = get_group_id(0); + + uint sg_ij = sub_group_broadcast(get_local_id(1), 0); + + uint b1 = get_group_id(2); + + // TODO: batch q=1 cases to KV_GROUP_SIZE + uint b0, b0_kv; + b0 = get_group_id(1); + b0_kv = b0 / KV_GROUP_SIZE; + + uint wg_i0 = wg_k * ugemm_kq_wg_tile_m; + + const uint preprocess_batch = b1 * (DST_D1 * q) + b0 * q; + const global float *ws_logsumexp = ws + preprocess_batch; + Di += preprocess_batch; + + /* Calculate the number of keys to process */ + int q0end = q; + int qdiag0 = 0; // potentially offset starting idx in causal mask cases +#if WITH_CAUSAL_MASK + if (attn_mask_type == ATTN_MASK_TOP_LEFT) { + qdiag0 = max(0, (int)(wg_i0)); + } else { + qdiag0 = max(0, (int)(wg_i0 + (q - k))); + } +#endif + + /* Leading dimension for matrices */ + uint ldk = TRANSPOSE_K ? 
KEY_S3 : KEY_S2; + uint ldq = QRY_S2; + uint ldv = VAL_S2; + uint lda = DST_S2; + + /* Subgroup IDs for each GEMM, although total number of + * sg per wg may be shared + * ordering may differ due to transposes */ + uint sg_i_kq = sg_ij % ugemm_kq_sg_per_wg_m; + uint sg_j_kq = sg_ij / ugemm_kq_sg_per_wg_m; + + uint sg_i_vtdA = sg_ij % ugemm_vtdA_sg_per_wg_m; + uint sg_j_vtdA = sg_ij / ugemm_vtdA_sg_per_wg_m; + + uint sg_i_vs = sg_ij % ugemm_vs_sg_per_wg_m; + uint sg_j_vs = sg_ij / ugemm_vs_sg_per_wg_m; + + uint sg_i_qdSt = sg_ij % ugemm_qdSt_sg_per_wg_m; + uint sg_j_qdSt = sg_ij / ugemm_qdSt_sg_per_wg_m; + + uint sg_i_ktq = sg_ij % ugemm_ktq_sg_per_wg_m; + uint sg_j_ktq = sg_ij / ugemm_ktq_sg_per_wg_m; + + /* SLM allocations -- place in one array to work around compiler bug */ +#define K_slm_size (ugemm_kq_wg_tile_m * D_MAX * sizeof(KEY_DATA_T)) +#define S_slm_size (ugemm_kq_wg_tile_m * ugemm_kq_wg_tile_n * sizeof(FMA_TYPE)) + +#define dK_slm_size (ugemm_kq_wg_tile_m * D_MAX * sizeof(float)) +#define dV_slm_size (ugemm_kq_wg_tile_m * D_MAX * sizeof(float)) + +#define ugemm_slm_size \ + MAX(MAX(MAX(MAX(ugemm_kq_slm_size, ugemm_vs_slm_size), \ + ugemm_vtdA_slm_size), \ + ugemm_qdSt_slm_size), \ + ugemm_ktq_slm_size) + + local char slm[K_slm_size + S_slm_size + ugemm_slm_size + dK_slm_size + + dV_slm_size]; + + local KEY_DATA_T *K_slm = (local KEY_DATA_T *)&slm[0]; + + // used for caching various A,B gemm tiles + local FMA_TYPE *S_slm = (local FMA_TYPE *)&slm[K_slm_size]; + + // ugemm scratch space + local uint *ugemm_slm = (local uint *)&slm[K_slm_size + S_slm_size]; + + // used for accumulation of dV, dK across q-loop + local float *dK_slm + = (local float *)&slm[K_slm_size + S_slm_size + ugemm_slm_size]; + local float *dV_slm = (local float *)&slm[K_slm_size + S_slm_size + + ugemm_slm_size + dK_slm_size]; + + const size_t k_offset = KEY_BATCH(b1, b0_kv); + const size_t v_offset = VAL_BATCH(b1, b0_kv); + const size_t q_offset = QRY_BATCH(b1, b0); + const size_t 
a_offset = DST_BATCH(b1, b0); + + /* Locate K/Q/V/A matrices within batch */ + K += k_offset; + Q += q_offset; + V += v_offset; + A += a_offset; + + dK += k_offset; + dQ += q_offset; + dV += v_offset; + dA += a_offset; + +#if WITH_DS + dS += b1 * (DST_D1 * q * k) + b0 * (q * k); +#endif + +#if WITH_ATTN_MASK + msk += MSK_BATCH(b1 % MSK_D0, b0 % MSK_D1); + int mask_aligned = (((size_t)msk) % 4) == 0; + bool block_msk = (b1 < MSK_D0 - ceil((float)ugemm_kq_wg_tile_m / MSK_S2)) + && mask_aligned; +#endif + + if (qdiag0 < q0end) { + /* Load K tile, destined for SLM */ + + k_tile_type K_tile; + tile_fill(K_tile, TO_DATA_T(0.f)); + + uint k0_copy = dmax_tile_sg_n + * sg_ij; //each sg will be responsible for dmax_tile_sg_n columns + tile_load_k(&K_tile, K, k, d, ldk, wg_i0, k0_copy, remainder_k); + ///* Store K tile to SLM */ +#if USE_SYSTOLIC_UKERNEL + tile_store_sys_src1(K_tile, &K_slm[0], SUBGROUP_SIZE, D_MAX, + ugemm_kq_wg_tile_m, D_MAX, 0, k0_copy); +#else + tile_store_packed_src1( + K_tile, K_slm, ugemm_kq_sg_tile_m, D_MAX, 0, k0_copy); +#endif + } + + /* Load scale */ + float scale = 1.f; + float iscale = 1.f; + if (qdiag0 < q0end) { +#if WITH_ATTN_SCALE +#if WITH_HOST_SCALE +#if INVERT_SCALE + iscale = scalar_scale; + scale = inv_scalar_scale; +#else + scale = scalar_scale; + iscale = inv_scalar_scale; +#endif +#else +#if INVERT_SCALE + iscale = SCALES_TO_FLOAT(*scale_ptr); + scale = native_recip(iscale); +#else + scale = SCALES_TO_FLOAT(*scale_ptr); + iscale = native_recip(scale); +#endif +#endif +#endif + } + + /* Initialize dV, dK to zero */ +#pragma unroll + for (int i = get_local_id(0); i < ugemm_kq_wg_tile_m * D_MAX; + i += get_local_size(0)) { + dK_slm[i] = 0.f; + dV_slm[i] = 0.f; + } + + uint sg_i0_kq = sg_i_kq * ugemm_kq_sg_tile_m; + uint sg_j0_kq = sg_j_kq * ugemm_kq_sg_tile_n; + + const int k0 = wg_i0; + + // make sure K_tile in SLM + barrier(CLK_LOCAL_MEM_FENCE); + + /* Main loop over k blocks */ + for (int q0 = qdiag0; q0 < q0end; q0 += 
ugemm_kq_wg_tile_n) { + + const bool first = (q0 == qdiag0); + const int qnext = q0 + ugemm_kq_wg_tile_n; + const bool last = (qnext >= q0end); + + qmask_tile_type_float q_mask; + kmask_tile_type_float k_mask; + + int k_chunk = min(k - k0, ugemm_kq_wg_tile_m); + int q_nchunk = min(q0end - q0, ugemm_kq_wg_tile_n); + /* Calculate S = (K^T) * Q */ +#if DO_MM + s_tile_type S_tile + = ugemm_kq(K_slm, D_MAX, Q + q0 * ldq, ldq, k_chunk, q_nchunk, + d, 0, 0, 0, sg_i_kq, sg_j_kq, (local char *)ugemm_slm); +#else + s_tile_type S_tile; +#endif + + uint sg_i0_s2 = sg_i_kq * ugemm_kq_sg_tile_m + k0; + uint sg_j0_s2 = sg_j_kq * ugemm_kq_sg_tile_n + q0; + + /* Apply attention mask */ +#if WITH_ATTN_MASK + mask_tile_type mask_tile; +#if BROADCAST_MASK_Q + if (block_msk) { + tile_load_block(&mask_tile, msk, MSK_S2, 0, k0 + sg_i0_kq, 0); + } else { + tile_load(&mask_tile, msk, k, 1, MSK_S2, k0 + sg_i0_kq, 0); + } +#else + tile_load(&mask_tile, msk, k, q, MSK_S2, k0 + sg_i0_kq, q0 + sg_j0_kq); +#endif + +#define unscale(x) ((x) * iscale) + mask_tile_type_float mask_tile_float; + tile_copy_reblock(mask_tile, &mask_tile_float); +#if WITH_ATTN_SCALE + tile_elementwise(mask_tile_float, unscale); +#endif +#undef unscale +#if BROADCAST_MASK_Q + tile_vbroadcast_add(&S_tile, mask_tile_float); +#else + tile_binary(S_tile, mask_tile_float, binary_add); +#endif +#endif + + /* Apply q mask */ + if (remainder_q) { +#pragma unroll + for (int jj = get_sub_group_local_id(); jj < ugemm_kq_sg_tile_n; + jj += SUBGROUP_SIZE) { + q_mask.x[0][jj / SUBGROUP_SIZE] + = ((q0 + sg_j0_kq + jj) < q0end) ? 
nan(0u) : -INFINITY; + } + tile_hbroadcast_min(&S_tile, q_mask); + } + +#if WITH_CAUSAL_MASK +#define less_than(offset_k, offset_q) (offset_q < offset_k) + + int col_offset = q0 + sg_j0_kq; + if (q == 1) col_offset = 1; + if (attn_mask_type == ATTN_MASK_BOTTOM_RIGHT) col_offset += k - q; + + /* Apply causal mask */ + const bool is_diag = (q0 + == qdiag0); // first iteration will be on diagonal, requiring partial masking + if (is_diag) { + tile_predicated_assignment(S_tile, k0 + sg_i0_kq, col_offset, + less_than, -INFINITY, SUBGROUP_SIZE, ugemm_kq_c_type_block0, + ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0, + ugemm_kq_c_type_nblock1); + } +#undef less_than +#endif + + s_sum_tile_type S_logsumexp_tile; + tile_fill(S_logsumexp_tile, 0.f); + tile_load(&S_logsumexp_tile, ws_logsumexp, q, 1, ugemm_kq_wg_tile_n, + sg_j0_kq + q0, 0); +#define mulscale(x) (x * scale) + tile_elementwise(S_tile, mulscale); +#undef mulscale + tile_hbroadcast_sub(&S_tile, S_logsumexp_tile); //layout.N + //tile_vbroadcast_sub(&S_tile, S_logsumexp_tile); //layout.T + +/* Scale + exponentiate */ +#define scaled_exp(x) native_vexp2(x * 1.44269504089f) + tile_elementwise(S_tile, scaled_exp); +#undef scaled_exp + + s_tile_type_reblock S_tile_reblock; + tile_copy_reblock(S_tile, &S_tile_reblock); + uint sg_i0_ds = sg_i_kq * ugemm_kq_sg_tile_m; + uint sg_j0_ds = sg_j_kq * ugemm_kq_sg_tile_n; + + barrier(CLK_LOCAL_MEM_FENCE); +#if USE_SYSTOLIC_UKERNEL + tile_store_t_sys_src22(S_tile_reblock, (local FMA_TYPE *)S_slm, + ugemm_vs_sg_tile_n, ugemm_kq_wg_tile_m, ugemm_kq_wg_tile_n, + sg_i0_kq, sg_j0_kq); +#else + tile_store_packed_src1(S_tile_reblock, S_slm, ugemm_vs_sg_tile_n, + ugemm_kq_wg_tile_n, sg_i0_kq, sg_j0_kq); +#endif + barrier(CLK_LOCAL_MEM_FENCE); + +#if DO_MM + dv_tile_type dV_tile1; + dV_tile1 = ugemm_vs(dA + q0 * lda, lda, (local FMA_TYPE *)S_slm, + ugemm_kq_wg_tile_n, d, k_chunk, q_nchunk, 0, 0, 0, sg_i_vs, + sg_j_vs, (local char *)ugemm_slm); +#else + dv_tile_type dV_tile1; +#endif + 
uint sg_i0_vs = sg_i_vs * ugemm_vs_sg_tile_m; + uint sg_j0_vs = sg_j_vs * ugemm_vs_sg_tile_n; + + //slm dv tile + dv_tile_type dV_tile_slm; + tile_load(&dV_tile_slm, dV_slm, D_MAX, ugemm_kq_wg_tile_m, D_MAX, + sg_i0_vs, sg_j0_vs); + tile_binary(dV_tile_slm, dV_tile1, binary_add); + tile_store(dV_tile_slm, dV_slm, D_MAX, ugemm_kq_wg_tile_m, D_MAX, + sg_i0_vs, sg_j0_vs); + +#if DO_MM + p_tile_type dP_tile = ugemm_vtdA(V + k0 * ldv, ldv, dA + q0 * lda, lda, + k_chunk, q_nchunk, d, 0, 0, 0, sg_i_kq, sg_j_kq, + (local char *)ugemm_slm); +#else + p_tile_type dP_tile; +#endif + + // get D_i tile + p_sum_tile_type D_i; + tile_fill(D_i, 0.0f); + tile_load(&D_i, Di, q0end, 1, q0end, q0 + sg_j0_kq, 0); + + tile_hbroadcast_sub(&dP_tile, + D_i); // needs output to be transposed from vtdA layout.C = N + + // reload softmax from SLM since ugemm_vtdA() will clobber registers + p_tile_type S2_tile; + p_tile_type_reblock S2_tile_reblock; + +#if USE_SYSTOLIC_UKERNEL + tile_load_t_sys_src2(&S2_tile_reblock, (local FMA_TYPE *)S_slm, + ugemm_vs_sg_tile_n, ugemm_kq_wg_tile_n, sg_j0_kq, sg_i0_kq); +#else + tile_load_packed_src1(&S2_tile_reblock, S_slm, ugemm_vs_sg_tile_n, + ugemm_kq_wg_tile_n, sg_i0_kq, sg_j0_kq); +#endif + tile_copy_reblock(S2_tile_reblock, &S2_tile); + +#define binary_mul_scale(x, y) ((x) * (y) * scale) + tile_binary(dP_tile, S2_tile, binary_mul_scale); + + if (remainder_k) { +#pragma unroll + for (int ii = 0; ii < ugemm_kq_sg_tile_m / SUBGROUP_SIZE; ii++) { + k_mask.x[0][ii] = (k0 + sg_i0_kq + ii * SUBGROUP_SIZE + + get_sub_group_local_id() + < k) + ? 
1 + : 0; + } + tile_vbroadcast_mul(&dP_tile, k_mask); + } + + p_tile_type_reblock P_tile_reblock; + tile_copy_reblock(dP_tile, &P_tile_reblock); +#if WITH_DS + tile_store(P_tile_reblock, dS, k_chunk, q_nchunk, k, k0 + sg_i0_kq, + q0 + sg_j0_kq); +#endif + + // SLM for dK = dS^t * Q + local FMA_TYPE *dS_slm = (local FMA_TYPE *)S_slm; + barrier(CLK_LOCAL_MEM_FENCE); +#if USE_SYSTOLIC_UKERNEL + tile_store_sys_src1(P_tile_reblock, dS_slm, SUBGROUP_SIZE, + ugemm_kq_wg_tile_n, ugemm_kq_wg_tile_m, ugemm_kq_wg_tile_n, + sg_i0_kq, sg_j0_kq); +#else + tile_store_packed_src1(P_tile_reblock, dS_slm, ugemm_qdSt_sg_tile_m, + ugemm_kq_wg_tile_n, sg_i0_kq, sg_j0_kq); +#endif + barrier(CLK_LOCAL_MEM_FENCE); + +#if DO_MM + a_tile_type dK_tile1; + dK_tile1 = ugemm_qdSt(dS_slm, ugemm_kq_wg_tile_n, Q + q0 * ldq, ldq, + k_chunk, d, q_nchunk, 0, 0, 0, sg_i_qdSt, sg_j_qdSt, + (local char *)ugemm_slm); // dS^t * Q -> Bc x d +#else + a_tile_type dK_tile1; +#endif + uint sg_i0_dk = sg_i_qdSt * ugemm_qdSt_sg_tile_m; + uint sg_j0_dk = sg_j_qdSt * ugemm_qdSt_sg_tile_n; + + //// dk slm tile + a_tile_type dK_tile_slm; + tile_load(&dK_tile_slm, dK_slm, ugemm_kq_wg_tile_m, D_MAX, + ugemm_kq_wg_tile_m, sg_i0_dk, sg_j0_dk); + tile_binary(dK_tile_slm, dK_tile1, binary_add); + tile_store(dK_tile_slm, dK_slm, ugemm_kq_wg_tile_m, D_MAX, + ugemm_kq_wg_tile_m, sg_i0_dk, sg_j0_dk); + + p_tile_type_reblock dS_transpose_tile; +#if USE_SYSTOLIC_UKERNEL + tile_load_sys_src1(&dS_transpose_tile, dS_slm, SUBGROUP_SIZE, + ugemm_kq_wg_tile_n, ugemm_kq_wg_tile_m, ugemm_kq_wg_tile_n, + sg_i0_kq, sg_j0_kq); + barrier(CLK_LOCAL_MEM_FENCE); + tile_store_sys_src22(dS_transpose_tile, dS_slm, ugemm_ktq_sg_tile_n, + ugemm_kq_wg_tile_m, ugemm_kq_wg_tile_n, sg_i0_kq, sg_j0_kq); +#else + tile_load_packed_src1(&dS_transpose_tile, dS_slm, ugemm_qdSt_sg_tile_m, + ugemm_kq_wg_tile_n, sg_i0_kq, sg_j0_kq); + barrier(CLK_LOCAL_MEM_FENCE); + tile_store_t_packed_src1(dS_transpose_tile, dS_slm, ugemm_ktq_sg_tile_n, + ugemm_kq_wg_tile_m, 
sg_j0_kq, sg_i0_kq); +#endif + barrier(CLK_LOCAL_MEM_FENCE); + + // dQ = dS * K +#if DO_MM + ktq_tile_type dQ_tile; + dQ_tile = ugemm_ktq(K + k0, ldk, dS_slm, ugemm_kq_wg_tile_m, d, + q_nchunk, k_chunk, 0, 0, 0, sg_i_ktq, sg_j_ktq, + (local char *)ugemm_slm); +#else + ktq_tile_type dQ_tile; +#endif + uint sg_i0_dq = sg_i_ktq * ugemm_ktq_sg_tile_m; + uint sg_j0_dq = sg_j_ktq * ugemm_ktq_sg_tile_n + q0; + + tile_atomic_add(dQ_tile, dQ, d, q, ldq, sg_i0_dq, sg_j0_dq); + } + + //////// update dV + uint sg_i0_vs = sg_i_vs * ugemm_vs_sg_tile_m; + uint sg_j0_vs = sg_j_vs * ugemm_vs_sg_tile_n; + + // ensure all loops done writing to SLM + barrier(CLK_LOCAL_MEM_FENCE); + + dv_tile_type dV_tile_slm; + tile_load(&dV_tile_slm, dV_slm, D_MAX, ugemm_kq_wg_tile_m, D_MAX, sg_i0_vs, + sg_j0_vs); + + tile_store_dV(&dV_tile_slm, dV, d, k, ldv, sg_i0_vs, wg_i0 + sg_j0_vs, + remainder_k); + // /update dV + + //////// update dK + uint sg_i0_dk = sg_i_qdSt * ugemm_qdSt_sg_tile_m; + uint sg_j0_dk = sg_j_qdSt * ugemm_qdSt_sg_tile_n; + + a_tile_type dK_tile_slm; + tile_load(&dK_tile_slm, dK_slm, ugemm_kq_wg_tile_m, D_MAX, + ugemm_kq_wg_tile_m, sg_i0_dk, sg_j0_dk); + + int wg_k_chunk = min(k - k0, ugemm_kq_wg_tile_m); + tile_store_dK( + &dK_tile_slm, dK + wg_i0, wg_k_chunk, d, ldk, sg_i0_dk, sg_j0_dk); + // /update dK +} + +__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) kernel void +preprocess_Di(global float *Di, const global DST_DATA_T *A, + const global DST_DATA_T *dA, int d, int k, int q, QRY_OFFSETS, + DST_OFFSETS) { + + uint lda = DST_S2; + uint ldq = QRY_S2; + + uint sg_ij = sub_group_broadcast(get_local_id(1), 0); + uint sg_i_kq = sg_ij % ugemm_kq_sg_per_wg_m; + uint sg_j_kq = sg_ij / ugemm_kq_sg_per_wg_m; + + uint b0, b1; + b0 = get_group_id(1); + b1 = get_group_id(2); + + const uint preprocess_batch = b1 * (DST_D1 * q) + b0 * q; + + const size_t q_offset = QRY_BATCH(b1, b0); + const size_t a_offset = DST_BATCH(b1, b0); + + /* Locate A/dA matrices within batch */ + A += 
a_offset; + dA += a_offset; + + Di += preprocess_batch; + + uint wg_q = get_group_id(0); + uint wg_j0 = wg_q * ugemm_kq_wg_tile_n; + +#define Di_slm_size (ugemm_kq_wg_tile_n * sizeof(float)) + local char slm[Di_slm_size]; + + local float *Di_slm = (local float *)&slm[0]; + + uint sg_i0_kq = sg_i_kq * ugemm_kq_sg_tile_m; + uint sg_j0_kq = sg_j_kq * ugemm_kq_sg_tile_n; + + uint q0_copy = q_tile_sg_n * sg_ij; + + if (q > 0) { + // D_i calculation +#if QRY_DT_F32 + dq_tile_type dA_tile, A_tile; + tile_fill(A_tile, 0.f); + tile_fill(dA_tile, 0.f); + tile_load( + &dA_tile, (global FMA_TYPE *)dA, d, q, lda, 0, wg_j0 + q0_copy); + tile_load(&A_tile, (global FMA_TYPE *)A, d, q, lda, 0, wg_j0 + q0_copy); +#else + dq_tile_type dA_tile, A_tile; + q_tile_type dA_tile_reblock, A_tile_reblock; // load native type + tile_fill(A_tile_reblock, TO_DATA_T(0.f)); + tile_fill(dA_tile_reblock, TO_DATA_T(0.f)); + + tile_load(&dA_tile_reblock, (global FMA_TYPE *)dA, d, q, lda, 0, + wg_j0 + q0_copy); + tile_load(&A_tile_reblock, (global FMA_TYPE *)A, d, q, lda, 0, + wg_j0 + q0_copy); + + // convert to float for calculation + tile_copy_reblock(dA_tile_reblock, &dA_tile); + tile_copy_reblock(A_tile_reblock, &A_tile); +#endif + +#define binary_mul(x, y) ((x) * (y)) + tile_binary(A_tile, dA_tile, binary_mul); + + // reduce tile across D_MAX + for (int j = 0; j < q_tile_sg_n; j++) { + float r = 0.f; + for (int i0 = 0; i0 < D_MAX; i0 += SUBGROUP_SIZE) { + r += sub_group_reduce_add( + tile_access(A_tile, i0, j, SUBGROUP_SIZE, D_MAX, 1, 1)); + } + Di_slm[j + q0_copy] = r; + } + barrier(CLK_LOCAL_MEM_FENCE); + + for (int i = get_local_id(0); i < ugemm_kq_wg_tile_n; + i += get_local_size(0)) { + if (get_local_id(1) == 0 && (wg_j0 + i) < q) { + Di[wg_j0 + i] = Di_slm[i]; + } + } + } +} + +__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) kernel void zero_dQ( + global float *dst, int nelems, QRY_OFFSETS) { + uint b0 = get_group_id(1); + uint b1 = get_group_id(2); + + const size_t offset = 
QRY_BATCH(b1, b0); + + dst += offset; + size_t idx = get_global_id(0); + if (idx < nelems) { dst[idx] = 0.f; } +} + +__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) kernel void +postprocess_dQ(global DST_DATA_T *dst, global const float *src, int nelems, + QRY_OFFSETS) { + uint b0 = get_group_id(1); + uint b1 = get_group_id(2); + + const size_t offset = QRY_BATCH(b1, b0); + + /* Locate dQ/dV/dK matrices within batch */ + src += offset; + dst += offset; + size_t idx = get_global_id(0); + if (idx < (nelems)) { dst[idx] = TO_DATA_T(src[idx]); } +} diff --git a/tests/gtests/internals/test_sdpa.cpp b/tests/gtests/internals/test_sdpa.cpp index d8e491ac27e..05741e0dfbf 100644 --- a/tests/gtests/internals/test_sdpa.cpp +++ b/tests/gtests/internals/test_sdpa.cpp @@ -1532,7 +1532,7 @@ std::chrono::nanoseconds prim_sdpa_quant_bwd(const sdpa_dims_t &p, print_mem(grouped_query, "FWD grouped_query"); print_mem(key_dequantized, "FWD keq_deq"); print_mem(scale, "FWD scale"); - print_mem(mask, "FWD mask"); + if (p.mask.type != mask_type::no_mask) { print_mem(mask, "FWD mask"); } print_mem(score, "FWD intermediate score"); print_mem(score2, "FWD intermediate score2"); #endif @@ -1812,7 +1812,7 @@ void check_memory(dnnl::stream &strm, memory &gold, memory &test, float max_diff = std::numeric_limits::min(); std::map> hist; - const bool verbose = true; + const bool verbose = false; for_(int l = 0; l < dims[0]; l++) for_(int k = 0; k < dims[1]; k++) for_(int j = 0; j < dims[2]; j++) @@ -2116,7 +2116,7 @@ class sdpa_test_t : public ::testing::TestWithParam { auto dS_desc = t.m_dS.get_desc(); memory::desc *dS_ptr = nullptr; - //dS_ptr = &dS_desc; // uncomment for optional dS output (expensive) + //dS_ptr = &dS_desc; // uncomment for optional dS output // fwd sdpa primitive to populate dst, col_maxes sdpa::primitive_desc sdpa_fwd_pd; @@ -2174,10 +2174,6 @@ class sdpa_test_t : public ::testing::TestWithParam { sdpa_fwd.execute(strm, sdpa_fwd_args); strm.wait(); -#if DEBUG_PRINT_MEM 
- print_mem(sdpa_fwd_workspace_memory, "sharedworkspace"); -#endif - std::unordered_map sdpa_bwd_args = {{DNNL_ARG_QUERIES, t.m_query}, {DNNL_ARG_KEYS, t.m_key_quantized}, @@ -2208,10 +2204,9 @@ class sdpa_test_t : public ::testing::TestWithParam { #if DEBUG_PRINT_MEM print_mem(t.m_dS, "computed dS"); print_mem(t.m_diff_value_quantized, "dV bwd out"); -#endif - printf("-------------- Primitives based implementation -------------- " "\n"); +#endif // perform primitives based backwards sdpa pass to generate "gold" gradient outputs prim_sdpa_quant_bwd(p, t, eng, strm, t.m_query, t.m_key_quantized, @@ -2525,6 +2520,7 @@ class sdpa_test_t : public ::testing::TestWithParam { }; auto qtime = min_time(bwd_time) / iterations; + printf("qtimebwd %f\n", (float)qtime.count() / 1e6); // Backward reads: Q, K, V, O, dO, workspace(logsumexp) // Backward writes: dQ, dK, dV @@ -2903,26 +2899,15 @@ INSTANTIATE_TEST_SUITE_P(bwd_perf, sdpa_bwd_test, // mb,hd_num,kv_hd_num,seq_len,qry_num,hd_size, kg_sz, vgrp_sz, dt, kdt, ksdt, kzpdt, vdt, vsdt, vzpdt, mskdt, qtype testing::Values( - sdpa_dims_t{ 1, 1, 1, 32, 32, 32, 32, 32, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, - sdpa_dims_t{ 1, 1, 1, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, - sdpa_dims_t{ 1, 1, 1, 8192, 8192, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, - sdpa_dims_t{ 1, 2, 2, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, - sdpa_dims_t{ 1, 2, 2, 8192, 8192, 128, 128, 128, mdt::f16, mdt::f16, 
mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, - sdpa_dims_t{ 1, 4, 4, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, - sdpa_dims_t{ 1, 4, 4, 8192, 8192, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, - sdpa_dims_t{ 2, 1, 1, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, - sdpa_dims_t{ 2, 1, 1, 8192, 8192, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, - sdpa_dims_t{ 2, 2, 2, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, - sdpa_dims_t{ 2, 2, 2, 8192, 8192, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, - sdpa_dims_t{ 2, 4, 4, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, - sdpa_dims_t{ 2, 4, 4, 8192, 8192, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, - sdpa_dims_t{ 4, 1, 1, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, 
quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, - sdpa_dims_t{ 4, 1, 1, 8192, 8192, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, - sdpa_dims_t{ 4, 2, 2, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, - sdpa_dims_t{ 4, 2, 2, 8192, 8192, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, - sdpa_dims_t{ 4, 4, 4, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::causal_tl }, - sdpa_dims_t{ 4, 4, 2, 8192, 8192, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::causal_tl }, - sdpa_dims_t{ 4, 12, 2, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::causal_tl } + sdpa_dims_t{ 4, 4, 4, 4096, 4096, 32, 32, 32, mdt::f32, mdt::f32, mdt::undef, mdt::undef, mdt::f32, mdt::undef, mdt::undef, mdt::f32, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 4, 4, 4, 4096, 4096, 64, 64, 64, mdt::f32, mdt::f32, mdt::undef, mdt::undef, mdt::f32, mdt::undef, mdt::undef, mdt::f32, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 4, 4, 4, 4096, 4096, 128, 128, 128, mdt::f32, mdt::f32, mdt::undef, mdt::undef, mdt::f32, mdt::undef, mdt::undef, mdt::f32, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 4, 
4, 4, 4096, 4096, 32, 32, 32, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 4, 4, 4, 4096, 4096, 64, 64, 64, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 4, 4, 4, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask }, + sdpa_dims_t{ 4, 4, 4, 4096, 4096, 32, 32, 32, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::causal_tl }, + sdpa_dims_t{ 4, 4, 4, 4096, 4096, 64, 64, 64, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::causal_tl }, + sdpa_dims_t{ 4, 4, 4, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::causal_tl } ), &print_to_string); // clang-format on @@ -2935,24 +2920,24 @@ GPU_TEST_P(sdpa_test_datatypes, compare) { compare(); } -/* GPU_TEST_P(sdpa_bwd_test, compare_bwd) { compare_bwd(); } -*/ + +GPU_TEST_P(sdpa_bwd_test_datatypes, compare_bwd) { + compare_bwd(); +} GPU_TEST_P(sdpa_test, perf) { perf(); } +/* GPU_TEST_P(sdpa_bwd_test, perf_bwd) { const bool time_reference = true; perf_bwd(time_reference); } - -GPU_TEST_P(sdpa_bwd_test_datatypes, compare_bwd) { - compare_bwd(); -} +*/ // clang-format off @@ -3032,10 +3017,8 @@ INSTANTIATE_TEST_SUITE_P(bwd_nonuniform_seq, sdpa_bwd_test_datatypes, testing::Combine(testing::Values(1), // mb testing::Values(num_heads_t {1, 1}, num_heads_t {2, 2}), // heads - testing::Values(seq_len_size_t {33, 65}, - 
seq_len_size_t {65, 4097}, - seq_len_size_t {4096, 64}, - seq_len_size_t {1025, 15}), // seq_len (q, kv) + testing::Values(seq_len_size_t {64, 513}, + seq_len_size_t {513, 64}), testing::Values(head_group_size_t {32, 32, 32}, head_group_size_t {64, 64, 64}), // head_size testing::Values(tensor_type_t("Q", mdt::f16)), // dt @@ -3053,7 +3036,7 @@ INSTANTIATE_TEST_SUITE_P(bwd_nonuniform_seq, sdpa_bwd_test_datatypes, // backward pass: f32 INSTANTIATE_TEST_SUITE_P(bwd_f32, sdpa_bwd_test_datatypes, - testing::Combine(testing::Values(1, 4), // mb + testing::Combine(testing::Values(1, 2), // mb testing::Values(num_heads_t {1, 1}, num_heads_t {2, 2}, num_heads_t {12, 12}), // heads testing::Values(seq_len_size_t {32, 32}, seq_len_size_t {64, 64}, @@ -3081,7 +3064,7 @@ INSTANTIATE_TEST_SUITE_P(bwd_f32, sdpa_bwd_test_datatypes, // backward pass: large batch and head counts INSTANTIATE_TEST_SUITE_P(bwd_large_batch, sdpa_bwd_test_datatypes, testing::Combine(testing::Values(4), // mb - testing::Values(num_heads_t {12, 12}), // heads + testing::Values(num_heads_t {8, 8}), // heads testing::Values(seq_len_size_t {4096, 4096}), // seq_len testing::Values(head_group_size_t {32, 32, 32}), // head_size testing::Values(tensor_type_t("Q", mdt::f16)), // dt From 2fa06da4bc40b16fc169d4e452bfffb9653bfc28 Mon Sep 17 00:00:00 2001 From: syurkevi Date: Thu, 26 Feb 2026 12:49:13 -0800 Subject: [PATCH 12/23] xe: sdpa: adds create_sdpa_pd for backwards pass --- src/common/sdpa_pd.hpp | 10 +++++++--- src/common/sdpa_utils.hpp | 32 ++++++++++++++++++++++++++++++++ src/gpu/intel/sdpa/micro.hpp | 2 +- 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/src/common/sdpa_pd.hpp b/src/common/sdpa_pd.hpp index 9eef733b0e6..162966bbc0a 100644 --- a/src/common/sdpa_pd.hpp +++ b/src/common/sdpa_pd.hpp @@ -169,12 +169,13 @@ struct sdpa_pd_t : public primitive_desc_t { , desc_(*op_desc_t::to_desc(adesc)) , hint_fwd_pd_(hint_fwd_pd) {} - void init_default_ws() { + status_t init_default_ws() { dims_t 
d; d[0] = desc()->batch_size() * desc()->queries(); // (logsumexp) per query - memory_desc_init_by_tag(ws_md_, 1, d, data_type::f32, format_tag::a); + return memory_desc_init_by_tag( + ws_md_, 1, d, data_type::f32, format_tag::a); } bool set_default_format(memory_desc_t *md) { @@ -373,7 +374,10 @@ struct sdpa_bwd_pd_t : public sdpa_pd_t { const memory_desc_t *diff_qry_md() const { return &desc_.diff_q_desc; } const memory_desc_t *diff_key_md() const { return &desc_.diff_k_desc; } const memory_desc_t *diff_val_md() const { return &desc_.diff_v_desc; } - const memory_desc_t *diff_dst_md() const { return &desc_.diff_dst_desc; } + const memory_desc_t *diff_dst_md( + int index = 0, bool user_input = false) const override { + return index == 0 ? &desc_.diff_dst_desc : &glob_zero_md; + } protected: sdpa_bwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, diff --git a/src/common/sdpa_utils.hpp b/src/common/sdpa_utils.hpp index 8cc015288d5..2266721e35c 100644 --- a/src/common/sdpa_utils.hpp +++ b/src/common/sdpa_utils.hpp @@ -249,6 +249,38 @@ static inline status_t create_sdpa_pd( return status::success; } +static inline status_t create_sdpa_pd( + std::shared_ptr &sdpa_pd_, engine_t *engine, + const memory_desc_t *q_md, const memory_desc_t *k_md, + const memory_desc_t *v_md, const memory_desc_t *dst_md, + const memory_desc_t *diff_q_md, const memory_desc_t *diff_k_md, + const memory_desc_t *diff_v_md, const memory_desc_t *diff_dst_md, + const memory_desc_t *dS_md, const memory_desc_t *attn_mask_md, + const memory_desc_t *scale_md, bool invert_scale, dim_t kv_head_number, + attn_mask_type_t attn_mask_type, alg_kind_t softmax_alg, + const primitive_attr_t *attr, const primitive_desc_t *hint_fwd_pd, + const primitive_attr_t *kq_attr = nullptr, + const primitive_attr_t *vs_attr = nullptr) { + CHECK(sdpa_attr_check(q_md, k_md, v_md, engine, attr, kq_attr, vs_attr)); + CHECK(sdpa_desc_check(q_md, k_md, v_md, dst_md, attn_mask_md, engine, attr, + kq_attr, vs_attr)); + + 
auto sdpa_desc = create_sdpa_desc(q_md, k_md, v_md, dst_md, diff_q_md, + diff_k_md, diff_v_md, diff_dst_md, dS_md, attn_mask_md, scale_md, + invert_scale, kv_head_number, attn_mask_type, softmax_alg, kq_attr, + vs_attr); + + primitive_attr_t sdpa_attr = attr ? *attr : default_attr(); + + primitive_desc_iterator_t it( + engine, (op_desc_t *)&sdpa_desc, &sdpa_attr, hint_fwd_pd); + + sdpa_pd_ = *(++it); + VCHECK_SDPA_COND(sdpa_pd_, "failed to create the backward SDPA primitive"); + + return status::success; +} + } // namespace impl } // namespace dnnl diff --git a/src/gpu/intel/sdpa/micro.hpp b/src/gpu/intel/sdpa/micro.hpp index 3c5936a6609..19ba7b1923c 100644 --- a/src/gpu/intel/sdpa/micro.hpp +++ b/src/gpu/intel/sdpa/micro.hpp @@ -482,7 +482,7 @@ struct micro_bwd_t : public primitive_t { dnnl_dt2str(diff_val_md()->data_type), dnnl_dt2str(diff_dst_md()->data_type)); - init_default_ws(); + CHECK(init_default_ws()); VCHECK_SDPA_COND(compare_ws(hint_fwd_pd_), VERBOSE_WS_MISMATCH); CHECK(init_conf_microkernels(engine)); From 124cf6c907f282e0f23390fd6cd075ea1614adab Mon Sep 17 00:00:00 2001 From: syurkevi Date: Fri, 6 Mar 2026 00:08:00 -0800 Subject: [PATCH 13/23] xe: sdpa: rename fwd/bwd primitives --- src/gpu/gpu_sdpa_list.cpp | 4 ++-- src/gpu/intel/sdpa/config.hpp | 3 ++- src/gpu/intel/sdpa/configs.cpp | 16 ++++++++-------- src/gpu/intel/sdpa/configs.hpp | 18 +++++++++--------- src/gpu/intel/sdpa/micro.cpp | 21 +++++++++++---------- src/gpu/intel/sdpa/micro.hpp | 10 +++++----- src/gpu/intel/sdpa/micro_bwd.cl | 2 +- src/gpu/intel/sdpa/ref.cpp | 2 +- src/gpu/intel/sdpa/ref.hpp | 4 ++-- tests/gtests/internals/test_sdpa.cpp | 5 +---- 10 files changed, 42 insertions(+), 43 deletions(-) diff --git a/src/gpu/gpu_sdpa_list.cpp b/src/gpu/gpu_sdpa_list.cpp index a99336ab66d..13a4b8f310d 100644 --- a/src/gpu/gpu_sdpa_list.cpp +++ b/src/gpu/gpu_sdpa_list.cpp @@ -34,8 +34,8 @@ using namespace dnnl::impl::prop_kind; const std::map> impl_list_map REG_SDPA_P({ {{forward}, { - 
GPU_INSTANCE_INTEL(intel::sdpa::micro_t) - GPU_INSTANCE_INTEL_DEVMODE(intel::sdpa::ref_t) + GPU_INSTANCE_INTEL(intel::sdpa::micro_fwd_t) + GPU_INSTANCE_INTEL_DEVMODE(intel::sdpa::ref_fwd_t) nullptr, }}, {{backward}, REG_BWD_PK({ diff --git a/src/gpu/intel/sdpa/config.hpp b/src/gpu/intel/sdpa/config.hpp index 9ddf3cf6944..fe9d3bb7a50 100644 --- a/src/gpu/intel/sdpa/config.hpp +++ b/src/gpu/intel/sdpa/config.hpp @@ -25,7 +25,8 @@ namespace gpu { namespace intel { namespace sdpa { -using pd_t = sdpa_pd_t; +using fwd_pd_t = sdpa_fwd_pd_t; +using bwd_pd_t = sdpa_bwd_pd_t; } // namespace sdpa } // namespace intel diff --git a/src/gpu/intel/sdpa/configs.cpp b/src/gpu/intel/sdpa/configs.cpp index 4639f5392ce..1053610cfcf 100644 --- a/src/gpu/intel/sdpa/configs.cpp +++ b/src/gpu/intel/sdpa/configs.cpp @@ -71,7 +71,7 @@ std::string to_string(const config_criteria_t &c) { << ((bool)(c.prop & property::f16_accumulate) ? " f16_acc" : ""); return s.str(); } -std::string to_string(const config_t &c) { +std::string to_string(const fwd_config_t &c) { std::stringstream s; s << c.unroll_m_kq << "," << c.unroll_n_kq << "," << c.unroll_m_vs << "," << c.unroll_n_vs << "," << c.wg_m_kq << "," << c.wg_n_kq << "," @@ -117,7 +117,7 @@ bool criteria_matches( == (key.prop & property::integrated))))); } -bool operator==(const config_record_t &key, const config_query_t &query) { +bool operator==(const fwd_config_record_t &key, const config_query_t &query) { return criteria_matches(key.criteria, query); } @@ -192,7 +192,7 @@ bool operator<(const config_criteria_t &lhs, const config_criteria_t &rhs) { return false; } -bool operator<(const config_record_t &lhs, const config_record_t &rhs) { +bool operator<(const fwd_config_record_t &lhs, const fwd_config_record_t &rhs) { return lhs.criteria < rhs.criteria; } @@ -208,9 +208,9 @@ static auto constexpr f32 = property::f32; static auto constexpr f16_accumulate = property::f16_accumulate; // Kernel configurations: [ arch, head_size, {sequence length}, 
{properties} ] -> config -static std::vector sorted_configs = []() { +static std::vector sorted_configs = []() { // clang-format off - std::vector configs = { + std::vector configs = { // xe_hpg {{compute::gpu_arch_t::xe_hpg, 16, fma | f32}, {8, 16, 8, 16, 2, 4, 2, 4}}, @@ -624,9 +624,9 @@ property set_properties(bool is_thin_q, bool is_quantized, bool is_integrated, return properties; } -config_t *choose_config(compute::gpu_arch_t arch, dim_t head_size, dim_t seq, - bool is_thin_q, bool is_quantized, bool is_integrated, bool is_fma, - bool is_f32, bool is_f16_accumulate) { +fwd_config_t *choose_config(compute::gpu_arch_t arch, dim_t head_size, + dim_t seq, bool is_thin_q, bool is_quantized, bool is_integrated, + bool is_fma, bool is_f32, bool is_f16_accumulate) { // quantized FMA for f16 on MTL not implemented in gemmstone if (arch == compute::gpu_arch_t::xe_hpg && is_fma && !is_f32 && is_quantized) diff --git a/src/gpu/intel/sdpa/configs.hpp b/src/gpu/intel/sdpa/configs.hpp index 6b15fcd6570..394725a70eb 100644 --- a/src/gpu/intel/sdpa/configs.hpp +++ b/src/gpu/intel/sdpa/configs.hpp @@ -31,7 +31,7 @@ namespace sdpa { namespace micro = gemmstone::microkernel; -struct config_t { +struct fwd_config_t { int unroll_m_kq, unroll_n_kq; // Subgroup tile sizes for K*Q GEMM int unroll_m_vs, unroll_n_vs; // Subgroup tile sizes for V*S GEMM int wg_m_kq, wg_n_kq; // Workgroup configuration for K*Q GEMM @@ -95,9 +95,9 @@ struct config_criteria_t { : arch(a), head_size(hs), seq_len(sq), prop(prop) {} }; -struct config_record_t { +struct fwd_config_record_t { config_criteria_t criteria; - config_t config; + fwd_config_t config; }; struct bwd_config_record_t { @@ -111,17 +111,17 @@ bool criteria_matches( std::ostream &operator<<(std::ostream &s, const config_query_t &q); std::ostream &operator<<(std::ostream &s, const config_criteria_t &c); -std::ostream &operator<<(std::ostream &s, const config_t &c); +std::ostream &operator<<(std::ostream &s, const fwd_config_t &c); -bool 
operator==(const config_record_t &key, const config_query_t &query); +bool operator==(const fwd_config_record_t &key, const config_query_t &query); bool operator==(const bwd_config_record_t &key, const config_query_t &query); bool operator<(const config_criteria_t &lhs, const config_criteria_t &rhs); -bool operator<(const config_record_t &lhs, const config_record_t &rhs); +bool operator<(const fwd_config_record_t &lhs, const fwd_config_record_t &rhs); bool operator<(const bwd_config_record_t &lhs, const bwd_config_record_t &rhs); -config_t *choose_config(compute::gpu_arch_t arch, dim_t head_size, dim_t seq, - bool is_thin_q, bool is_quantized, bool is_integrated, bool is_fma, - bool is_f32, bool is_f16_accumulate); +fwd_config_t *choose_config(compute::gpu_arch_t arch, dim_t head_size, + dim_t seq, bool is_thin_q, bool is_quantized, bool is_integrated, + bool is_fma, bool is_f32, bool is_f16_accumulate); bwd_config_t *choose_bwd_config(compute::gpu_arch_t arch, dim_t head_size, dim_t seq, bool is_thin_q, bool is_quantized, bool is_integrated, bool is_fma, bool is_f32); diff --git a/src/gpu/intel/sdpa/micro.cpp b/src/gpu/intel/sdpa/micro.cpp index 8530314ba83..cfcd7a16a4a 100644 --- a/src/gpu/intel/sdpa/micro.cpp +++ b/src/gpu/intel/sdpa/micro.cpp @@ -62,7 +62,8 @@ bool with_quantize_common(const quant_entry_t &entry) { } /* anonymous namespace */ -status_t update_config_from_devenv_values(config_t *config, bool quantized) { +status_t update_config_from_devenv_values( + fwd_config_t *config, bool quantized) { std::string q_config_str = gpu_utils::dev_getenv("QUANTIZED_SDPA_CONFIG", std::string("")); std::string config_str @@ -141,7 +142,7 @@ status_t update_config_from_devenv_values(bwd_config_t *config) { return status::success; } -status_t micro_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { +status_t micro_fwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { using namespace jit; using gemm::jit::convert_dnnl_to_kernel_type; @@ -155,7 +156,7 @@ 
status_t micro_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { "Microkernels not supported by the OpenCL driver."); /* Retrieve pre-tuned kernel configuration */ - config_t *config = nullptr; + fwd_config_t *config = nullptr; const dim_t thin_q_threshold = 16; auto queries = d->queries(); if (queries == 1) { queries = (d->q_desc.dims[1] / d->kv_head_number); } @@ -746,7 +747,7 @@ status_t micro_bwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { return status::success; } -status_t micro_t::init(impl::engine_t *engine) { +status_t micro_fwd_t::init(impl::engine_t *engine) { CHECK(create_kernel( engine, kernel_, pd()->conf.get_kernel_names()[0], pd()->conf)); @@ -834,7 +835,7 @@ static void init_conf_common(conf_t &conf, pd_type *pd) { conf.use_systolic_ukernel = pd->use_systolic_ukernel(); } -status_t micro_t::pd_t::init_conf(impl::engine_t *engine) { +status_t micro_fwd_t::pd_t::init_conf(impl::engine_t *engine) { using namespace micro; auto *pd = this; @@ -896,7 +897,7 @@ status_t micro_t::pd_t::init_conf(impl::engine_t *engine) { conf.val_group_size = pd->value_group_size(); /* Set up microkernel strategy */ - const config_t config = {conf.ukernel_config.unroll_m_kq, + const fwd_config_t config = {conf.ukernel_config.unroll_m_kq, conf.ukernel_config.unroll_n_kq, conf.ukernel_config.unroll_m_vs, conf.ukernel_config.unroll_n_vs, conf.ukernel_config.wg_m_kq, conf.ukernel_config.wg_n_kq, conf.ukernel_config.wg_m_vs, @@ -1044,7 +1045,7 @@ status_t micro_bwd_t::pd_t::init_scratchpad(impl::engine_t *engine) { return status::success; } -status_t micro_params_t::get_kernel_ctx( +status_t micro_fwd_params_t::get_kernel_ctx( compute::kernel_ctx_t &kernel_ctx) const { using namespace micro; @@ -1132,7 +1133,7 @@ status_t micro_params_t::get_kernel_ctx( micro::Package gemm_kq, gemm_vs; /* Set up microkernel strategy */ - const config_t config + const fwd_config_t config = {ukernel_config.unroll_m_kq, ukernel_config.unroll_n_kq, ukernel_config.unroll_m_vs, 
ukernel_config.unroll_n_vs, ukernel_config.wg_m_kq, ukernel_config.wg_n_kq, @@ -1455,7 +1456,7 @@ status_t micro_bwd_params_t::get_kernel_ctx( return status::success; } -status_t micro_t::execute_forward(const exec_ctx_t &ctx) const { +status_t micro_fwd_t::execute_forward(const exec_ctx_t &ctx) const { const auto &conf = pd()->conf; const auto &qry = CTX_IN_STORAGE(DNNL_ARG_QUERIES); @@ -1481,7 +1482,7 @@ status_t micro_t::execute_forward(const exec_ctx_t &ctx) const { const dim_t D = pd()->desc()->head_size(); const dim_t Q_per_kv_group = (Q == 1 ? Q * kv_group_size : Q); - const config_t config = {conf.ukernel_config.unroll_m_kq, + const fwd_config_t config = {conf.ukernel_config.unroll_m_kq, conf.ukernel_config.unroll_n_kq, conf.ukernel_config.unroll_m_vs, conf.ukernel_config.unroll_n_vs, conf.ukernel_config.wg_m_kq, conf.ukernel_config.wg_n_kq, conf.ukernel_config.wg_m_vs, diff --git a/src/gpu/intel/sdpa/micro.hpp b/src/gpu/intel/sdpa/micro.hpp index 19ba7b1923c..61dd39fa7a5 100644 --- a/src/gpu/intel/sdpa/micro.hpp +++ b/src/gpu/intel/sdpa/micro.hpp @@ -36,7 +36,7 @@ namespace gpu { namespace intel { namespace sdpa { -struct micro_params_t : trivially_serializable_t { +struct micro_fwd_params_t : trivially_serializable_t { const std::vector &get_kernel_names() const { static const std::vector kernel_names_fwd @@ -97,7 +97,7 @@ struct micro_params_t : trivially_serializable_t { micro_fwd_ukernel_params_t ukernel_config; }; -DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(micro_params_t); +DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(micro_fwd_params_t); struct micro_bwd_params_t : trivially_serializable_t { @@ -153,12 +153,12 @@ struct micro_bwd_params_t : trivially_serializable_t { }; DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(micro_bwd_params_t); -struct micro_t : public primitive_t { +struct micro_fwd_t : public primitive_t { using primitive_t::primitive_t; struct pd_t : public sdpa_fwd_pd_t { using sdpa_fwd_pd_t::sdpa_fwd_pd_t; - DECLARE_COMMON_PD_T("ocl:micro:reusable", micro_t); + 
DECLARE_COMMON_PD_T("ocl:micro:reusable", micro_fwd_t); status_t init(impl::engine_t *engine) { using namespace data_type; @@ -350,7 +350,7 @@ struct micro_t : public primitive_t { } compute::gpu_arch_t arch() const { return arch_; } - micro_params_t conf; + micro_fwd_params_t conf; private: int sg_size_ = 0; diff --git a/src/gpu/intel/sdpa/micro_bwd.cl b/src/gpu/intel/sdpa/micro_bwd.cl index fbf019d0c7c..3b13e479c1d 100644 --- a/src/gpu/intel/sdpa/micro_bwd.cl +++ b/src/gpu/intel/sdpa/micro_bwd.cl @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2024 Intel Corporation +* Copyright 2026 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/src/gpu/intel/sdpa/ref.cpp b/src/gpu/intel/sdpa/ref.cpp index d86a2dce5c0..e979cf8d1d8 100644 --- a/src/gpu/intel/sdpa/ref.cpp +++ b/src/gpu/intel/sdpa/ref.cpp @@ -25,7 +25,7 @@ namespace gpu { namespace intel { namespace sdpa { -status_t ref_t::execute_ref(const exec_ctx_t &ctx) const { +status_t ref_fwd_t::execute_ref(const exec_ctx_t &ctx) const { const auto &qry = CTX_IN_STORAGE(DNNL_ARG_QUERIES); const auto &key = CTX_IN_STORAGE(DNNL_ARG_KEYS); const auto &val = CTX_IN_STORAGE(DNNL_ARG_VALUES); diff --git a/src/gpu/intel/sdpa/ref.hpp b/src/gpu/intel/sdpa/ref.hpp index 3856dc2dd24..bf686117e3c 100644 --- a/src/gpu/intel/sdpa/ref.hpp +++ b/src/gpu/intel/sdpa/ref.hpp @@ -27,12 +27,12 @@ namespace gpu { namespace intel { namespace sdpa { -struct ref_t : public primitive_t { +struct ref_fwd_t : public primitive_t { using primitive_t::primitive_t; struct pd_t : public sdpa_fwd_pd_t { using sdpa_fwd_pd_t::sdpa_fwd_pd_t; - DECLARE_COMMON_PD_T("ocl:ref:any", ref_t); + DECLARE_COMMON_PD_T("ocl:ref:any", ref_fwd_t); status_t init(impl::engine_t *engine) { using namespace data_type; diff --git a/tests/gtests/internals/test_sdpa.cpp 
b/tests/gtests/internals/test_sdpa.cpp index 05741e0dfbf..8304b1f703f 100644 --- a/tests/gtests/internals/test_sdpa.cpp +++ b/tests/gtests/internals/test_sdpa.cpp @@ -1367,9 +1367,6 @@ std::vector timeit( } str.wait(); e = steady_clock::now(); - printf("timeit: %f \n", - (float)std::chrono::duration_cast(e - s).count() - / 1e6 / iterations); times.push_back(std::chrono::duration_cast(e - s)); } return times; @@ -2218,7 +2215,7 @@ class sdpa_test_t : public ::testing::TestWithParam { float fthreshold = 0.f; if (p.dt.dt == mdt::bf16 || p.dt.dt == mdt::f16) { //fthreshold = 0.0079f; //todo: correct threshold or better values - fthreshold = 0.1; + fthreshold = 0.1f; } else { fthreshold = 0.001466f; } From 7b2daeb6c369fe05f93e51bb72e4f094acf1dc9c Mon Sep 17 00:00:00 2001 From: syurkevi Date: Fri, 6 Mar 2026 17:32:29 -0800 Subject: [PATCH 14/23] gtests: internals: separate sdpa from internals --- tests/gtests/internals/CMakeLists.txt | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/gtests/internals/CMakeLists.txt b/tests/gtests/internals/CMakeLists.txt index a05d7d9b05f..c2b01c37791 100644 --- a/tests/gtests/internals/CMakeLists.txt +++ b/tests/gtests/internals/CMakeLists.txt @@ -57,4 +57,10 @@ register_exe(${TEST_EXE}_gmlp list(REMOVE_ITEM TEST_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/test_gated_mlp.cpp) +# Register SDPA tests as a separate executable due to their runtime +register_exe(${TEST_EXE}_sdpa + "${MAIN_SRC_GTEST};${CMAKE_CURRENT_SOURCE_DIR}/test_sdpa.cpp;${CMAKE_CURRENT_SOURCE_DIR}/test_utils.cpp" + "test" "dnnl_gtest") +list(REMOVE_ITEM TEST_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/test_sdpa.cpp) + register_exe(${TEST_EXE} "${TEST_SOURCES}" "test" "dnnl_gtest") From 8d8ab57a1361cd4c92c0cc24b2954b99921c7b74 Mon Sep 17 00:00:00 2001 From: syurkevi Date: Mon, 9 Mar 2026 12:43:34 -0700 Subject: [PATCH 15/23] common: sdpa: move prop_kind param before attrs --- src/common/sdpa_test_iface.cpp | 8 ++++---- src/common/sdpa_utils.hpp | 12 ++++++------ 
src/graph/backend/dnnl/executables/sdpa.cpp | 3 ++- tests/gtests/internals/sdpa_internal.hpp | 14 +++++++------- tests/gtests/internals/test_sdpa.cpp | 16 ++++++++-------- 5 files changed, 27 insertions(+), 26 deletions(-) diff --git a/src/common/sdpa_test_iface.cpp b/src/common/sdpa_test_iface.cpp index 759d4066804..fb98303ff81 100644 --- a/src/common/sdpa_test_iface.cpp +++ b/src/common/sdpa_test_iface.cpp @@ -30,9 +30,9 @@ dnnl_status_t DNNL_API sdpa_primitive_desc_create( const_dnnl_memory_desc_t value_desc, const_dnnl_memory_desc_t dst_desc, const_dnnl_memory_desc_t mask_desc, const_dnnl_memory_desc_t scale_desc, bool invert_scale, dnnl_dim_t kv_head_number, int attn_mask_type, - dnnl_alg_kind_t softmax_alg, const_dnnl_primitive_attr_t attr, - const_dnnl_primitive_attr_t kq_attr, - const_dnnl_primitive_attr_t vs_attr, prop_kind_t prop) { + dnnl_alg_kind_t softmax_alg, prop_kind_t prop, + const_dnnl_primitive_attr_t attr, const_dnnl_primitive_attr_t kq_attr, + const_dnnl_primitive_attr_t vs_attr) { CHECK(sdpa_desc_check(query_desc, key_desc, value_desc, dst_desc, mask_desc, engine, attr, kq_attr, vs_attr)); CHECK(sdpa_attr_check( @@ -41,7 +41,7 @@ dnnl_status_t DNNL_API sdpa_primitive_desc_create( dnnl::impl::sdpa_desc_t sdpa_desc = dnnl::impl::create_sdpa_desc(query_desc, key_desc, value_desc, dst_desc, mask_desc, scale_desc, invert_scale, kv_head_number, static_cast(attn_mask_type), - softmax_alg, kq_attr, vs_attr, prop); + softmax_alg, prop, kq_attr, vs_attr); return dnnl::impl::primitive_desc_create(primitive_desc_iface, engine, (const dnnl::impl::op_desc_t *)&sdpa_desc, nullptr, attr); } diff --git a/src/common/sdpa_utils.hpp b/src/common/sdpa_utils.hpp index 2266721e35c..0f74a5a49c3 100644 --- a/src/common/sdpa_utils.hpp +++ b/src/common/sdpa_utils.hpp @@ -153,8 +153,8 @@ static inline sdpa_desc_t create_sdpa_desc(const memory_desc_t *q_md, const memory_desc_t *dst_md, const memory_desc_t *attn_mask_md, const memory_desc_t *scale_md, bool invert_scale, 
dim_t kv_head_number, attn_mask_type_t attn_mask_type, alg_kind_t softmax_alg, - const primitive_attr_t *kq_attr, const primitive_attr_t *vs_attr, - prop_kind_t prop) { + prop_kind_t prop, const primitive_attr_t *kq_attr, + const primitive_attr_t *vs_attr) { auto sdpa_desc = sdpa_desc_t(); sdpa_desc.primitive_kind = primitive_kind::sdpa; sdpa_desc.q_desc = *q_md; @@ -227,16 +227,16 @@ static inline status_t create_sdpa_pd( const memory_desc_t *attn_mask_md, const memory_desc_t *scale_md, bool invert_scale, dim_t kv_head_number, attn_mask_type_t attn_mask_type, alg_kind_t softmax_alg, - const primitive_attr_t *attr, const primitive_attr_t *kq_attr = nullptr, - const primitive_attr_t *vs_attr = nullptr, - prop_kind_t prop = prop_kind::forward_inference) { + prop_kind_t prop, const primitive_attr_t *attr, + const primitive_attr_t *kq_attr = nullptr, + const primitive_attr_t *vs_attr = nullptr) { CHECK(sdpa_attr_check(q_md, k_md, v_md, engine, attr, kq_attr, vs_attr)); CHECK(sdpa_desc_check(q_md, k_md, v_md, dst_md, attn_mask_md, engine, attr, kq_attr, vs_attr)); auto sdpa_desc = create_sdpa_desc(q_md, k_md, v_md, dst_md, attn_mask_md, scale_md, invert_scale, kv_head_number, attn_mask_type, softmax_alg, - kq_attr, vs_attr, prop); + prop, kq_attr, vs_attr); primitive_attr_t sdpa_attr = attr ? 
*attr : default_attr(); diff --git a/src/graph/backend/dnnl/executables/sdpa.cpp b/src/graph/backend/dnnl/executables/sdpa.cpp index 3ec9ce405e9..7b84f7feca2 100644 --- a/src/graph/backend/dnnl/executables/sdpa.cpp +++ b/src/graph/backend/dnnl/executables/sdpa.cpp @@ -76,7 +76,8 @@ sdpa_executable_t::sdpa_executable_t(std::shared_ptr &op, status_t s = create_sdpa_pd(sdpa_pd_, p_engine.get(), md_q.get(), md_k.get(), md_v.get(), md_dst.get(), md_mask.get(), md_scale.get(), is_invert_scale_, kv_head_number, mask_type_, softmax_alg, - attr.get(), qk_attr.get(), vs_attr.get()); + impl::prop_kind::forward_inference, attr.get(), qk_attr.get(), + vs_attr.get()); if (s != dnnl::impl::status::success) { is_initialized_ = false; } else { diff --git a/tests/gtests/internals/sdpa_internal.hpp b/tests/gtests/internals/sdpa_internal.hpp index 1f3a23957c9..1ccef3d647f 100644 --- a/tests/gtests/internals/sdpa_internal.hpp +++ b/tests/gtests/internals/sdpa_internal.hpp @@ -42,9 +42,9 @@ dnnl_status_t DNNL_API sdpa_primitive_desc_create( const_dnnl_memory_desc_t value_desc, const_dnnl_memory_desc_t dst_desc, const_dnnl_memory_desc_t mask_desc, const_dnnl_memory_desc_t scale_desc, bool invert_scale, dnnl_dim_t kv_head_number, int attn_mask_type, - dnnl_alg_kind_t softmax_alg, const_dnnl_primitive_attr_t attr, - const_dnnl_primitive_attr_t kq_attr, - const_dnnl_primitive_attr_t vs_attr, dnnl_prop_kind_t prop); + dnnl_alg_kind_t softmax_alg, dnnl_prop_kind_t prop, + const_dnnl_primitive_attr_t attr, const_dnnl_primitive_attr_t kq_attr, + const_dnnl_primitive_attr_t vs_attr); dnnl_status_t DNNL_API sdpa_primitive_desc_create( dnnl_primitive_desc_t *primitive_desc_iface, dnnl_engine_t engine, @@ -79,10 +79,10 @@ struct sdpa : public dnnl::primitive { const memory::desc &scale_desc, const memory::desc &output_desc, bool invert_scale, memory::dim kv_head_number, int attn_mask_type, int softmax_alg, + prop_kind_t prop_kind = prop_kind::forward_inference, const primitive_attr &attr = 
default_attr(), const primitive_attr &kq_attr = default_attr(), - const primitive_attr &vs_attr = default_attr(), - prop_kind_t prop_kind = prop_kind::forward_inference) { + const primitive_attr &vs_attr = default_attr()) { dnnl_primitive_desc_t pd = nullptr; dnnl_status_t status = sdpa_primitive_desc_create(&pd, @@ -90,8 +90,8 @@ struct sdpa : public dnnl::primitive { value_desc.get(), output_desc.get(), optional_arg(attn_mask_desc), scale_desc.get(), invert_scale, kv_head_number, attn_mask_type, - (dnnl_alg_kind_t)softmax_alg, attr.get(), kq_attr.get(), - vs_attr.get(), (prop_kind_t)prop_kind); + (dnnl_alg_kind_t)softmax_alg, (prop_kind_t)prop_kind, + attr.get(), kq_attr.get(), vs_attr.get()); dnnl::error::wrap_c_api(status, "could not create a primitive descriptor for a sdpa " diff --git a/tests/gtests/internals/test_sdpa.cpp b/tests/gtests/internals/test_sdpa.cpp index 8304b1f703f..cafdea11d76 100644 --- a/tests/gtests/internals/test_sdpa.cpp +++ b/tests/gtests/internals/test_sdpa.cpp @@ -2008,8 +2008,8 @@ class sdpa_test_t : public ::testing::TestWithParam { t.m_scale.get_desc(), t.m_output_quantized.get_desc(), invert_scale, p.heads.kv, to_attn_mask_type(p.mask.type), dnnl::impl::alg_kind::softmax_accurate_inf_as_zero, - t.sdpa_attr_quantized, t.sdpa_kq_attr_quantized, - t.sdpa_vs_attr_quantized); + prop_kind::forward_inference, t.sdpa_attr_quantized, + t.sdpa_kq_attr_quantized, t.sdpa_vs_attr_quantized); sdpa_quantized_p = sdpa(sdpa_quantized_pd); } catch (const dnnl::error &e) { if (e.status == dnnl_unimplemented) @@ -2128,8 +2128,8 @@ class sdpa_test_t : public ::testing::TestWithParam { t.m_scale.get_desc(), t.m_output_quantized.get_desc(), invert_scale, p.heads.kv, to_attn_mask_type(p.mask.type), dnnl::impl::alg_kind::softmax_accurate_inf_as_zero, - t.sdpa_attr_quantized, t.sdpa_kq_attr_quantized, - t.sdpa_vs_attr_quantized, prop_kind::forward_training); + prop_kind::forward_training, t.sdpa_attr_quantized, + t.sdpa_kq_attr_quantized, 
t.sdpa_vs_attr_quantized); sdpa_fwd = sdpa(sdpa_fwd_pd); sdpa_bwd_pd = sdpa_backward::primitive_desc(eng, @@ -2295,8 +2295,8 @@ class sdpa_test_t : public ::testing::TestWithParam { t.m_scale.get_desc(), t.m_output_quantized.get_desc(), invert_scale, p.heads.kv, to_attn_mask_type(p.mask.type), alg_kind::softmax_accurate_inf_as_zero, - t.sdpa_attr_quantized, t.sdpa_kq_attr_quantized, - t.sdpa_vs_attr_quantized); + prop_kind::forward_inference, t.sdpa_attr_quantized, + t.sdpa_kq_attr_quantized, t.sdpa_vs_attr_quantized); sdpa_quantized_p = sdpa(sdpa_quantized_pd); } catch (const dnnl::error &e) { if (e.status == dnnl_unimplemented) @@ -2445,8 +2445,8 @@ class sdpa_test_t : public ::testing::TestWithParam { t.m_scale.get_desc(), t.m_output_quantized.get_desc(), invert_scale, p.heads.kv, to_attn_mask_type(p.mask.type), dnnl::impl::alg_kind::softmax_accurate_inf_as_zero, - t.sdpa_attr_quantized, t.sdpa_kq_attr_quantized, - t.sdpa_vs_attr_quantized, prop_kind::forward_training); + prop_kind::forward_training, t.sdpa_attr_quantized, + t.sdpa_kq_attr_quantized, t.sdpa_vs_attr_quantized); sdpa_fwd = sdpa(sdpa_fwd_pd); sdpa_bwd_pd = sdpa_backward::primitive_desc(eng, From 55684e0c8f76b0481fc72e5306b53061190dfa2b Mon Sep 17 00:00:00 2001 From: syurkevi Date: Tue, 10 Mar 2026 21:36:03 -0700 Subject: [PATCH 16/23] xe: sdpa: enable transpose_k for training --- src/gpu/intel/include/tile_ops.h | 18 ++++ src/gpu/intel/sdpa/micro.cpp | 5 +- src/gpu/intel/sdpa/micro.hpp | 2 +- src/gpu/intel/sdpa/micro_bwd.cl | 146 ++++++++++++++++++++++----- tests/gtests/internals/test_sdpa.cpp | 24 ++--- 5 files changed, 154 insertions(+), 41 deletions(-) diff --git a/src/gpu/intel/include/tile_ops.h b/src/gpu/intel/include/tile_ops.h index 0a65b05b746..d82d7aa8dad 100644 --- a/src/gpu/intel/include/tile_ops.h +++ b/src/gpu/intel/include/tile_ops.h @@ -516,6 +516,24 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) } \ } \ } \ + __attribute__((overloadable)) void 
tile_store_t(tile_type t, \ + local element_type *ptr, int m, int n, int ld, int offset_r, \ + int offset_c) { \ + if (m >= offset_r + br * nbr && n >= offset_c + bc * nbc) { \ + tile_store_t_full(t, ptr, ld, offset_r, offset_c); \ + return; \ + } \ + ptr += ld * offset_r + offset_c; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++, ptr++) { \ + if (offset_c + j < n) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + int i = ld * (i0 + get_sub_group_local_id()); \ + if ((offset_r + i0 + get_sub_group_local_id()) < m) \ + ptr[i] = tile_access(t, i0, j, sg, br, bc, nbr); \ + } \ + } \ + } \ + } \ __attribute__((overloadable)) void tile_store_full(tile_type t, \ local element_type *ptr, int ld, int offset_r, int offset_c) { \ ptr += ld * offset_c + offset_r; \ diff --git a/src/gpu/intel/sdpa/micro.cpp b/src/gpu/intel/sdpa/micro.cpp index cfcd7a16a4a..80d22da3d24 100644 --- a/src/gpu/intel/sdpa/micro.cpp +++ b/src/gpu/intel/sdpa/micro.cpp @@ -988,8 +988,9 @@ status_t micro_bwd_t::pd_t::init_conf(impl::engine_t *engine) { conf.block_k = conf.block_dK = conf.block_dV = false; if (d_full) { - conf.block_dK = conf.block_k - = (ldk % 4 == 0) && (d->keys() % tile_k == 0); + bool can_block_load_k = (ldk % 4 == 0) && (d->keys() % tile_k == 0); + conf.block_k = can_block_load_k; + conf.block_dK = can_block_load_k && !conf.transpose_k; conf.block_dV = (ldv % 4 == 0) && (dv_full); } diff --git a/src/gpu/intel/sdpa/micro.hpp b/src/gpu/intel/sdpa/micro.hpp index 61dd39fa7a5..31637ca344f 100644 --- a/src/gpu/intel/sdpa/micro.hpp +++ b/src/gpu/intel/sdpa/micro.hpp @@ -510,7 +510,7 @@ struct micro_bwd_t : public primitive_t { CHECK(set_default_format(desc_.dst_desc, false)); CHECK(set_default_format(desc_.diff_dst_desc, false)); CHECK(set_default_format(desc_.diff_q_desc, false)); - CHECK(set_default_format(desc_.diff_k_desc, false)); + CHECK(set_default_format(desc_.diff_k_desc, true)); CHECK(set_default_format(desc_.diff_v_desc, false)); return 
status::success; } diff --git a/src/gpu/intel/sdpa/micro_bwd.cl b/src/gpu/intel/sdpa/micro_bwd.cl index 3b13e479c1d..812aee4e14e 100644 --- a/src/gpu/intel/sdpa/micro_bwd.cl +++ b/src/gpu/intel/sdpa/micro_bwd.cl @@ -36,7 +36,6 @@ #define sg_per_wg MAX(sg_per_wg_BcBr, MAX(sg_per_wg_BcD, sg_per_wg_BrD)) #define q_tile_sg_n DIV_UP(ugemm_kq_wg_tile_n, sg_per_wg) -#define dmax_tile_sg_n DIV_UP(D_MAX, sg_per_wg) /* Instantiate tile types and operations */ typedef ugemm_kq_c_type s_tile_type; // Bc*Br tile @@ -72,12 +71,26 @@ DECLARE_2D_TILE_COPY_REBLOCK(q_tile_type, SUBGROUP_SIZE, D_MAX, 1, 1, q_tile_sg_n, dq_tile_type, SUBGROUP_SIZE, D_MAX, 1, 1, q_tile_sg_n, CONVERT_FLOAT_T) +#if TRANSPOSE_K + +#define k_tile_t_sg_n DIV_UP(ugemm_kq_wg_tile_m, sg_per_wg) +DECLARE_2D_TILE( + k_tile_type, FMA_TYPE, SUBGROUP_SIZE, D_MAX, 1, 1, k_tile_t_sg_n) +#if BLOCK_K +DECLARE_2D_TILE_BLOCK_OPS( + k_tile_type, FMA_TYPE, SUBGROUP_SIZE, D_MAX, 1, 1, k_tile_t_sg_n) +#endif + +#else + +#define dmax_tile_sg_n DIV_UP(D_MAX, sg_per_wg) DECLARE_2D_TILE(k_tile_type, FMA_TYPE, SUBGROUP_SIZE, ugemm_kq_wg_tile_m, 1, 1, dmax_tile_sg_n) #if BLOCK_K DECLARE_2D_TILE_BLOCK_OPS(k_tile_type, FMA_TYPE, SUBGROUP_SIZE, ugemm_kq_wg_tile_m, 1, 1, dmax_tile_sg_n) #endif +#endif DECLARE_2D_TILE(s_tile_type_packed, uint, SUBGROUP_SIZE, ugemm_kq_c_type_block0, ugemm_kq_c_type_block1 / 2, ugemm_kq_c_type_nblock0, @@ -233,13 +246,57 @@ DECLARE_2D_TILE_VREDUCE(s_tile_type, SUBGROUP_SIZE, ugemm_kq_c_type_block0, #define binary_add(x, y) ((x) + (y)) -inline void tile_load_k(k_tile_type *K_tile, const global KEY_DATA_T *K, int m, - int n, int ldk, int offset_r, int offset_c, int load_rem) { +inline void tile_load_k(k_tile_type *K_tile, const global KEY_DATA_T *K, + int seq_len, int ldk, int seq_off, int sg_ij, int load_rem) { + +#if TRANSPOSE_K + // Bc / n_sg -- each sg loads k_tile_t_sg_n k-columns + uint k0_copy = k_tile_t_sg_n * sg_ij; + // Coalesced load from d×k column-major memory (d contiguous, k strided) +#if 
BLOCK_K + tile_load_block(K_tile, K, ldk, 0, seq_off + k0_copy); +#else + tile_load(K_tile, K, D_MAX, seq_len, ldk, 0, seq_off + k0_copy); +#endif + +#else + // D_MAX / n_sg + uint k0_copy = dmax_tile_sg_n * sg_ij; #if BLOCK_K // can ignore load_rem due to d_full requirement - tile_load_block(K_tile, K, ldk, offset_r, offset_c); + tile_load_block(K_tile, K, ldk, seq_off, k0_copy); #else - tile_load(K_tile, K, m, n, ldk, offset_r, offset_c); + tile_load(K_tile, K, seq_len, D_MAX, ldk, seq_off, k0_copy); +#endif + +#endif +} + +inline void tile_store_k_slm( + k_tile_type *K_tile, local KEY_DATA_T *K_slm, int sg_ij) { + +#if TRANSPOSE_K + // Bc / n_sg -- tile is D*Bc, write transposed to SLM (Bc*D) + uint k0_copy = k_tile_t_sg_n * sg_ij; +#if USE_SYSTOLIC_UKERNEL + tile_store_t_sys_src11(*K_tile, K_slm, SUBGROUP_SIZE, D_MAX, D_MAX, + ugemm_kq_wg_tile_m, 0, k0_copy); +#else + tile_store_t_packed_src1( + *K_tile, K_slm, ugemm_kq_sg_tile_m, D_MAX, k0_copy, 0); +#endif + +#else + + uint k0_copy = dmax_tile_sg_n * sg_ij; +#if USE_SYSTOLIC_UKERNEL + tile_store_sys_src1(*K_tile, K_slm, SUBGROUP_SIZE, D_MAX, + ugemm_kq_wg_tile_m, D_MAX, 0, k0_copy); +#else + tile_store_packed_src1( + *K_tile, K_slm, ugemm_kq_sg_tile_m, D_MAX, 0, k0_copy); +#endif + #endif } @@ -267,15 +324,32 @@ inline void tile_store_dV(dv_tile_type *dV_tile_slm, global DST_DATA_T_DKDV *dV, #endif } -inline void tile_store_dK(a_tile_type *dK_tile_slm, global DST_DATA_T_DKDV *dK, +#if TRANSPOSE_K +// uses transposed dv_tile_type (D*Bc) for dK update +inline void tile_store_dK_t(dv_tile_type *dK_tile, global DST_DATA_T_DKDV *dK, + int m, int n, int ld, int offset_r, int offset_c, int rem) { + +#if KV_GROUP_SIZE > 1 // GQA update + tile_atomic_add(*dK_tile, dK, m, n, ld, offset_r, offset_c); +#else // MHA update + dv_tile_type_dst dK_tile_dst; + tile_copy_reblock(*dK_tile, &dK_tile_dst); + tile_store(dK_tile_dst, dK, m, n, ld, offset_r, offset_c); +#endif +} + +#else + +// uses qdSt tile (Bc*D) for dK update 
+inline void tile_store_dK(a_tile_type *dK_tile, global DST_DATA_T_DKDV *dK, int m, int n, int ld, int offset_r, int offset_c) { #if KV_GROUP_SIZE > 1 // GQA update - tile_atomic_add(*dK_tile_slm, dK, m, n, ld, offset_r, offset_c); + tile_atomic_add(*dK_tile, dK, m, n, ld, offset_r, offset_c); #else // MHA update - a_tile_type_dst dK_tile_dst; // convert to half - tile_copy_reblock(*dK_tile_slm, &dK_tile_dst); + a_tile_type_dst dK_tile_dst; + tile_copy_reblock(*dK_tile, &dK_tile_dst); #if BLOCK_DK tile_store_block(dK_tile_dst, dK, ld, offset_r, offset_c); #else @@ -285,6 +359,8 @@ inline void tile_store_dK(a_tile_type *dK_tile_slm, global DST_DATA_T_DKDV *dK, #endif } +#endif + #define DO_MM 1 __attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) kernel void @@ -427,21 +503,13 @@ micro_sdpa_bwd(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, if (qdiag0 < q0end) { /* Load K tile, destined for SLM */ - k_tile_type K_tile; tile_fill(K_tile, TO_DATA_T(0.f)); - uint k0_copy = dmax_tile_sg_n - * sg_ij; //each sg will be responsible for dmax_tile_sg_n columns - tile_load_k(&K_tile, K, k, d, ldk, wg_i0, k0_copy, remainder_k); - ///* Store K tile to SLM */ -#if USE_SYSTOLIC_UKERNEL - tile_store_sys_src1(K_tile, &K_slm[0], SUBGROUP_SIZE, D_MAX, - ugemm_kq_wg_tile_m, D_MAX, 0, k0_copy); -#else - tile_store_packed_src1( - K_tile, K_slm, ugemm_kq_sg_tile_m, D_MAX, 0, k0_copy); -#endif + tile_load_k(&K_tile, K, k, ldk, wg_i0, sg_ij, remainder_k); + + /* Store K tile to SLM */ + tile_store_k_slm(&K_tile, K_slm, sg_ij); } /* Load scale */ @@ -608,7 +676,7 @@ micro_sdpa_bwd(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, uint sg_i0_vs = sg_i_vs * ugemm_vs_sg_tile_m; uint sg_j0_vs = sg_j_vs * ugemm_vs_sg_tile_n; - //slm dv tile + // accumulate dv tile to slm dv_tile_type dV_tile_slm; tile_load(&dV_tile_slm, dV_slm, D_MAX, ugemm_kq_wg_tile_m, D_MAX, sg_i0_vs, sg_j0_vs); @@ -691,13 +759,22 @@ micro_sdpa_bwd(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, 
uint sg_i0_dk = sg_i_qdSt * ugemm_qdSt_sg_tile_m; uint sg_j0_dk = sg_j_qdSt * ugemm_qdSt_sg_tile_n; - //// dk slm tile + // dk slm tile a_tile_type dK_tile_slm; +#if TRANSPOSE_K + // accumulate transposed to slm + tile_load_t(&dK_tile_slm, dK_slm, ugemm_kq_wg_tile_m, D_MAX, D_MAX, + sg_i0_dk, sg_j0_dk); + tile_binary(dK_tile_slm, dK_tile1, binary_add); + tile_store_t(dK_tile_slm, dK_slm, ugemm_kq_wg_tile_m, D_MAX, D_MAX, + sg_i0_dk, sg_j0_dk); +#else tile_load(&dK_tile_slm, dK_slm, ugemm_kq_wg_tile_m, D_MAX, ugemm_kq_wg_tile_m, sg_i0_dk, sg_j0_dk); tile_binary(dK_tile_slm, dK_tile1, binary_add); tile_store(dK_tile_slm, dK_slm, ugemm_kq_wg_tile_m, D_MAX, ugemm_kq_wg_tile_m, sg_i0_dk, sg_j0_dk); +#endif p_tile_type_reblock dS_transpose_tile; #if USE_SYSTOLIC_UKERNEL @@ -719,9 +796,15 @@ micro_sdpa_bwd(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, // dQ = dS * K #if DO_MM ktq_tile_type dQ_tile; - dQ_tile = ugemm_ktq(K + k0, ldk, dS_slm, ugemm_kq_wg_tile_m, d, - q_nchunk, k_chunk, 0, 0, 0, sg_i_ktq, sg_j_ktq, - (local char *)ugemm_slm); + + dQ_tile = ugemm_ktq( +#if TRANSPOSE_K + K + k0 * ldk, +#else + K + k0, +#endif + ldk, dS_slm, ugemm_kq_wg_tile_m, d, q_nchunk, k_chunk, 0, 0, 0, + sg_i_ktq, sg_j_ktq, (local char *)ugemm_slm); #else ktq_tile_type dQ_tile; #endif @@ -747,6 +830,16 @@ micro_sdpa_bwd(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, // /update dV //////// update dK +#if TRANSPOSE_K + // transposed dK_slm (D*Bc) matches dV tile layout + dv_tile_type dK_tile_t; + tile_load(&dK_tile_t, dK_slm, D_MAX, ugemm_kq_wg_tile_m, D_MAX, sg_i0_vs, + sg_j0_vs); + + tile_store_dK_t( + &dK_tile_t, dK, d, k, ldk, sg_i0_vs, wg_i0 + sg_j0_vs, remainder_k); +#else + // non-transposed dK_slm uses qdSt layout (Bc*D) and indexing uint sg_i0_dk = sg_i_qdSt * ugemm_qdSt_sg_tile_m; uint sg_j0_dk = sg_j_qdSt * ugemm_qdSt_sg_tile_n; @@ -757,6 +850,7 @@ micro_sdpa_bwd(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, int wg_k_chunk = min(k - k0, 
ugemm_kq_wg_tile_m); tile_store_dK( &dK_tile_slm, dK + wg_i0, wg_k_chunk, d, ldk, sg_i0_dk, sg_j0_dk); +#endif // /update dK } diff --git a/tests/gtests/internals/test_sdpa.cpp b/tests/gtests/internals/test_sdpa.cpp index cafdea11d76..98f169a0482 100644 --- a/tests/gtests/internals/test_sdpa.cpp +++ b/tests/gtests/internals/test_sdpa.cpp @@ -1659,15 +1659,12 @@ std::chrono::nanoseconds prim_sdpa_quant_bwd(const sdpa_dims_t &p, memory dK_mem(key_dequantized.get_desc(), eng); // backwards pass gradient of q (dQ = dS * k^t) - // TODO: handle transposed K test case - // memory::desc k_t_md = p.with_key_transposed // k^t requires transposed format and dims - // ? memory::desc({k_sz[0], k_sz[1], k_sz[2], k_sz[4], k_sz[3]}, - // memory::data_type::f32, memory::format_tag::abcde) - // : memory::desc({k_sz[0], k_sz[1], k_sz[2], k_sz[4], k_sz[3]}, - // memory::data_type::f32, memory::format_tag::abced); - memory::desc k_t_md - = memory::desc({k_sz[0], k_sz[1], k_sz[2], k_sz[4], k_sz[3]}, - p.dt.dt, memory::format_tag::abced); + // k^t requires transposed format and dims + auto k_t_fmt = p.key_format_tag == memory::format_tag::abcd + ? 
memory::format_tag::abced + : memory::format_tag::abcde; + memory::desc k_t_md = memory::desc( + {k_sz[0], k_sz[1], k_sz[2], k_sz[4], k_sz[3]}, p.dt.dt, k_t_fmt); matmul::primitive_desc mm_bwd_dq_pd( eng, diff_score_md, k_t_md, grouped_query_md); matmul mm_bwd_dq(mm_bwd_dq_pd); @@ -2952,7 +2949,8 @@ INSTANTIATE_TEST_SUITE_P(bwd_f16, sdpa_bwd_test_datatypes, testing::Values(tensor_type_t("K", mdt::f16)), // kdt testing::Values(tensor_type_t("V", mdt::f16)), // vdt testing::Values(quantize_type::no_quantization), // qtype - testing::Values(dnnl::memory::format_tag::abcd), // key_format_tag + testing::Values(dnnl::memory::format_tag::abcd, + dnnl::memory::format_tag::abdc), // key_format_tag testing::Values(mask_config_t {mask_type::no_mask}, mask_config_t {mask_type::causal_tl}, mask_config_t {mask_type::causal_br}, mask_config_t {mask_type::twoD}, mask_config_t {mask_type::oneD} @@ -2977,7 +2975,8 @@ INSTANTIATE_TEST_SUITE_P(bwd_bf16, sdpa_bwd_test_datatypes, testing::Values(tensor_type_t("K", mdt::bf16)), // kdt testing::Values(tensor_type_t("V", mdt::bf16)), // vdt testing::Values(quantize_type::no_quantization), // qtype - testing::Values(dnnl::memory::format_tag::abcd), // key_format_tag + testing::Values(dnnl::memory::format_tag::abcd, + dnnl::memory::format_tag::abdc), // key_format_tag testing::Values(mask_config_t {mask_type::causal_tl}), // mask_type testing::Values(default_scale_type), // scale_type testing::Values( @@ -2999,7 +2998,8 @@ INSTANTIATE_TEST_SUITE_P(bwd_gqa, sdpa_bwd_test_datatypes, testing::Values(tensor_type_t("K", mdt::f16)), // kdt testing::Values(tensor_type_t("V", mdt::f16)), // vdt testing::Values(quantize_type::no_quantization), // qtype - testing::Values(dnnl::memory::format_tag::abcd), // key_format_tag + testing::Values(dnnl::memory::format_tag::abcd, + dnnl::memory::format_tag::abdc), // key_format_tag testing::Values(mask_config_t {mask_type::no_mask}, mask_config_t {mask_type::causal_tl}), // mask_type 
testing::Values(default_scale_type), // scale_type From 59dd3460582e2c57f2ab41c1672463ac9db4679d Mon Sep 17 00:00:00 2001 From: syurkevi Date: Thu, 12 Mar 2026 00:03:25 -0700 Subject: [PATCH 17/23] common: sdpa: refactors pd accessors, misc cleanup --- src/common/sdpa_pd.hpp | 49 +--- src/common/sdpa_test_iface.cpp | 21 +- src/common/sdpa_types.hpp | 19 +- src/common/sdpa_utils.hpp | 84 +++++- src/common/type_helpers.hpp | 6 + src/common/verbose.cpp | 18 +- src/gpu/intel/sdpa/micro.cpp | 344 +++++++++++------------ src/gpu/intel/sdpa/micro.hpp | 243 ++++++++-------- src/gpu/intel/sdpa/ref.hpp | 29 +- tests/gtests/internals/sdpa_internal.hpp | 19 +- tests/gtests/internals/test_sdpa.cpp | 25 +- 11 files changed, 461 insertions(+), 396 deletions(-) diff --git a/src/common/sdpa_pd.hpp b/src/common/sdpa_pd.hpp index 162966bbc0a..c5bc1a06b48 100644 --- a/src/common/sdpa_pd.hpp +++ b/src/common/sdpa_pd.hpp @@ -43,14 +43,10 @@ struct sdpa_fwd_pd_t; struct sdpa_pd_t : public primitive_desc_t { static constexpr auto base_pkind = primitive_kind::sdpa; - static constexpr int mask_mb_index = 0; static constexpr int mask_q_index = 2; static constexpr int mask_k_index = 3; static constexpr int ndims = 4; - using base_class = sdpa_pd_t; - using hint_class = sdpa_fwd_pd_t; - const sdpa_desc_t *desc() const { return &desc_; } const op_desc_t *op_desc() const override { return reinterpret_cast(this->desc()); @@ -62,19 +58,15 @@ struct sdpa_pd_t : public primitive_desc_t { } bool with_attn_scale() const { - return (scale_md()->data_type != data_type::undef); + return (desc()->scale_md()->data_type != data_type::undef); } bool with_host_scale() const { - return (scale_md()->format_kind == format_kind::host_scalar); + return (desc()->scale_md()->format_kind == format_kind::host_scalar); } bool with_attn_mask() const { - return (attn_mask_md()->data_type != data_type::undef); - } - - bool with_dS() const { - return (desc_.dS_desc.data_type != data_type::undef); + return 
(desc()->attn_mask_md()->data_type != data_type::undef); } /// Returns the accumulation data type of the KQ matmul @@ -133,9 +125,9 @@ struct sdpa_pd_t : public primitive_desc_t { int key_group_size() const { int out = 0; if (with_key_scales()) - out = group_size(desc()->kq_scales, *key_md()); + out = group_size(desc()->kq_scales, *desc()->key_md()); else if (with_key_zp()) { - out = group_size(desc()->kq_zero_points, *key_md()); + out = group_size(desc()->kq_zero_points, *desc()->key_md()); } return out; } @@ -144,19 +136,13 @@ struct sdpa_pd_t : public primitive_desc_t { int value_group_size() const { int out = 0; if (with_value_scales()) - out = group_size(desc()->vs_scales, *val_md()); + out = group_size(desc()->vs_scales, *desc()->val_md()); else if (with_value_zp()) { - out = group_size(desc()->vs_zero_points, *val_md()); + out = group_size(desc()->vs_zero_points, *desc()->val_md()); } return out; } - const memory_desc_t *qry_md() const { return &desc_.q_desc; } - const memory_desc_t *key_md() const { return &desc_.k_desc; } - const memory_desc_t *val_md() const { return &desc_.v_desc; } - const memory_desc_t *attn_mask_md() const { return &desc_.attn_mask_desc; } - const memory_desc_t *scale_md() const { return &desc_.scale_desc; } - protected: sdpa_desc_t desc_; const sdpa_fwd_pd_t *hint_fwd_pd_; @@ -164,14 +150,14 @@ struct sdpa_pd_t : public primitive_desc_t { memory_desc_t ws_md_; sdpa_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, - const hint_class *hint_fwd_pd) + const sdpa_fwd_pd_t *hint_fwd_pd) : primitive_desc_t(attr, base_pkind) , desc_(*op_desc_t::to_desc(adesc)) , hint_fwd_pd_(hint_fwd_pd) {} status_t init_default_ws() { dims_t d; - d[0] = desc()->batch_size() + d[0] = desc()->batch() * desc()->num_q_heads() * desc()->queries(); // (logsumexp) per query return memory_desc_init_by_tag( @@ -300,12 +286,6 @@ struct sdpa_bwd_pd_t : public sdpa_pd_t { DNNL_ARG_SCALE)) return arg_usage_t::input; - if (utils::one_of(arg, 
DNNL_ARG_ATTR_SCALES | DNNL_ARG_KEYS, - DNNL_ARG_ATTR_SCALES | DNNL_ARG_VALUES, - DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_KEYS, - DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_VALUES)) - return arg_usage_t::unused; - if (utils::one_of(arg, DNNL_ARG_DIFF_QUERIES, DNNL_ARG_DIFF_KEYS, DNNL_ARG_DIFF_VALUES)) return arg_usage_t::output; @@ -313,9 +293,7 @@ struct sdpa_bwd_pd_t : public sdpa_pd_t { if (arg == DNNL_ARG_DS) return with_dS() ? arg_usage_t::output : arg_usage_t::unused; - if (arg == DNNL_ARG_WORKSPACE) - return !types::is_zero_md(workspace_md()) ? arg_usage_t::input - : arg_usage_t::unused; + if (arg == DNNL_ARG_WORKSPACE) return arg_usage_t::input; return primitive_desc_t::arg_usage(arg); } @@ -364,6 +342,10 @@ struct sdpa_bwd_pd_t : public sdpa_pd_t { : &glob_zero_md; } + bool with_dS() const { + return (desc_.dS_desc.data_type != data_type::undef); + } + int n_inputs() const override { // Q, K, V, O, dO return 5 + int(with_attn_mask()) + int(with_attn_scale()) @@ -371,9 +353,6 @@ struct sdpa_bwd_pd_t : public sdpa_pd_t { } int n_outputs() const override { return 3 + int(with_dS()); } - const memory_desc_t *diff_qry_md() const { return &desc_.diff_q_desc; } - const memory_desc_t *diff_key_md() const { return &desc_.diff_k_desc; } - const memory_desc_t *diff_val_md() const { return &desc_.diff_v_desc; } const memory_desc_t *diff_dst_md( int index = 0, bool user_input = false) const override { return index == 0 ? 
&desc_.diff_dst_desc : &glob_zero_md; diff --git a/src/common/sdpa_test_iface.cpp b/src/common/sdpa_test_iface.cpp index fb98303ff81..6b6b33e1359 100644 --- a/src/common/sdpa_test_iface.cpp +++ b/src/common/sdpa_test_iface.cpp @@ -50,28 +50,25 @@ dnnl_status_t DNNL_API sdpa_primitive_desc_create( dnnl_primitive_desc_t *primitive_desc_iface, dnnl_engine_t engine, const_dnnl_memory_desc_t query_desc, const_dnnl_memory_desc_t key_desc, const_dnnl_memory_desc_t value_desc, const_dnnl_memory_desc_t dst_desc, + const_dnnl_memory_desc_t mask_desc, const_dnnl_memory_desc_t scale_desc, const_dnnl_memory_desc_t diff_query_desc, const_dnnl_memory_desc_t diff_key_desc, const_dnnl_memory_desc_t diff_value_desc, const_dnnl_memory_desc_t diff_dst_desc, - const_dnnl_memory_desc_t dS_desc, const_dnnl_memory_desc_t mask_desc, - const_dnnl_memory_desc_t scale_desc, bool invert_scale, + const_dnnl_memory_desc_t dS_desc, bool invert_scale, dnnl_dim_t kv_head_number, int attn_mask_type, dnnl_alg_kind_t softmax_alg, const_dnnl_primitive_attr_t attr, - const_dnnl_primitive_attr_t kq_attr, - const_dnnl_primitive_attr_t vs_attr, const_dnnl_primitive_desc_t hint_fwd_pd = nullptr) { CHECK(sdpa_desc_check(query_desc, key_desc, value_desc, dst_desc, mask_desc, - engine, attr, kq_attr, vs_attr)); - CHECK(sdpa_attr_check( - query_desc, key_desc, value_desc, engine, attr, kq_attr, vs_attr)); + diff_query_desc, diff_key_desc, diff_value_desc, diff_dst_desc, + engine, attr)); + CHECK(sdpa_attr_check(engine, attr)); dnnl::impl::sdpa_desc_t sdpa_desc = dnnl::impl::create_sdpa_desc(query_desc, - key_desc, value_desc, dst_desc, diff_query_desc, diff_key_desc, - diff_value_desc, diff_dst_desc, dS_desc, mask_desc, scale_desc, - invert_scale, kv_head_number, - static_cast(attn_mask_type), softmax_alg, kq_attr, - vs_attr); + key_desc, value_desc, dst_desc, mask_desc, scale_desc, + diff_query_desc, diff_key_desc, diff_value_desc, diff_dst_desc, + dS_desc, invert_scale, kv_head_number, + 
static_cast(attn_mask_type), softmax_alg); return dnnl::impl::primitive_desc_create(primitive_desc_iface, engine, (const dnnl::impl::op_desc_t *)&sdpa_desc, hint_fwd_pd, attr); } diff --git a/src/common/sdpa_types.hpp b/src/common/sdpa_types.hpp index 95898979ad6..4c70c8c6903 100644 --- a/src/common/sdpa_types.hpp +++ b/src/common/sdpa_types.hpp @@ -113,13 +113,18 @@ struct sdpa_desc_t : public op_desc_t { dnnl_dim_t values() const { return v_desc.dims[v_desc.ndims - 1]; } dim_t num_q_heads() const { return q_desc.dims[1]; } dim_t num_kv_heads() const { return kv_head_number; } - // Total batch size. - dnnl_dim_t batch_size() const { - dnnl_dim_t batch = 1; - for (int i = 0; i < dst_desc.ndims - 2; i++) - batch *= dst_desc.dims[i]; - return batch; - } + // Batch size (outer batch dimension, excluding heads). + dnnl_dim_t batch() const { return dst_desc.dims[0]; } + + // Memory descriptors + const memory_desc_t *qry_md() const { return &q_desc; } + const memory_desc_t *key_md() const { return &k_desc; } + const memory_desc_t *val_md() const { return &v_desc; } + const memory_desc_t *attn_mask_md() const { return &attn_mask_desc; } + const memory_desc_t *scale_md() const { return &scale_desc; } + const memory_desc_t *diff_qry_md() const { return &diff_q_desc; } + const memory_desc_t *diff_key_md() const { return &diff_k_desc; } + const memory_desc_t *diff_val_md() const { return &diff_v_desc; } }; } // namespace impl diff --git a/src/common/sdpa_utils.hpp b/src/common/sdpa_utils.hpp index 0f74a5a49c3..089d0a2a7ae 100644 --- a/src/common/sdpa_utils.hpp +++ b/src/common/sdpa_utils.hpp @@ -81,6 +81,61 @@ static inline status_t sdpa_desc_check(const memory_desc_t *q_desc, return status::success; } +static inline status_t sdpa_desc_check(const memory_desc_t *q_desc, + const memory_desc_t *k_desc, const memory_desc_t *v_desc, + const memory_desc_t *dst_desc, const memory_desc_t *attn_mask_md, + const memory_desc_t *diff_q_desc, const memory_desc_t *diff_k_desc, + const 
memory_desc_t *diff_v_desc, const memory_desc_t *diff_dst_desc, + const engine_t *engine, const primitive_attr_t *attr) { + int ndims = dst_desc->ndims; + int r = ndims - 2, c = ndims - 1; + VCHECK_SDPA_COND( + utils::everyone_is(ndims, q_desc->ndims, k_desc->ndims, + v_desc->ndims, diff_q_desc->ndims, diff_k_desc->ndims, + diff_v_desc->ndims, diff_dst_desc->ndims), + "number of dimensions have to match. expected: %d q: %d k: %d v: " + "%d dQ: %d dK: %d dV: %d dO: %d", + ndims, q_desc->ndims, k_desc->ndims, v_desc->ndims, + diff_q_desc->ndims, diff_k_desc->ndims, diff_v_desc->ndims, + diff_dst_desc->ndims); + + VCHECK_SDPA_COND(q_desc->dims[c] == k_desc->dims[r], + "q_desc->dims[%d](%s) must match k_desc->dims[%d](%s)", c, + md2dim_str(q_desc).c_str(), r, md2dim_str(k_desc).c_str()); + VCHECK_SDPA_COND(k_desc->dims[c] == v_desc->dims[r], + "k_desc->dims[%d](%s) must match v_desc->dims[%d](%s)", c, + md2dim_str(k_desc).c_str(), r, md2dim_str(v_desc).c_str()); + VCHECK_SDPA_COND(dst_desc->dims[r] == q_desc->dims[r], + "dst_desc->dims[%d](%s) == q_desc->dims[%d](%s)", r, + md2dim_str(dst_desc).c_str(), r, md2dim_str(q_desc).c_str()); + VCHECK_SDPA_COND(dst_desc->dims[c] == v_desc->dims[c], + "dst_desc->dims[%d](%s) == v_desc->dims[%d](%s)", c, + md2dim_str(dst_desc).c_str(), c, md2dim_str(v_desc).c_str()); + + for (int i = 0; i < ndims; i++) { + VCHECK_SDPA_COND(diff_q_desc->dims[i] == q_desc->dims[i], + "diff_q_desc->dims[%d](%s) must match q_desc->dims[%d](%s)", i, + md2dim_str(diff_q_desc).c_str(), i, md2dim_str(q_desc).c_str()); + VCHECK_SDPA_COND(diff_k_desc->dims[i] == k_desc->dims[i], + "diff_k_desc->dims[%d](%s) must match k_desc->dims[%d](%s)", i, + md2dim_str(diff_k_desc).c_str(), i, md2dim_str(k_desc).c_str()); + VCHECK_SDPA_COND(diff_v_desc->dims[i] == v_desc->dims[i], + "diff_v_desc->dims[%d](%s) must match v_desc->dims[%d](%s)", i, + md2dim_str(diff_v_desc).c_str(), i, md2dim_str(v_desc).c_str()); + VCHECK_SDPA_COND(diff_dst_desc->dims[i] == 
dst_desc->dims[i], + "diff_dst_desc->dims[%d](%s) must match dst_desc->dims[%d](%s)", + i, md2dim_str(diff_dst_desc).c_str(), i, + md2dim_str(dst_desc).c_str()); + } + + VCHECK_SDPA_COND(!any_memory_desc_host_scalar(q_desc, k_desc, v_desc, + dst_desc, attn_mask_md, diff_q_desc, diff_k_desc, + diff_v_desc, diff_dst_desc), + VERBOSE_UNSUPPORTED_FORMAT_KIND); + + return status::success; +} + static inline status_t sdpa_attr_check(const memory_desc_t *q_desc, const memory_desc_t *k_desc, const memory_desc_t *v_desc, const engine_t *engine, const primitive_attr_t *attr, @@ -148,6 +203,20 @@ static inline status_t sdpa_attr_check(const memory_desc_t *q_desc, return status::success; } +static inline status_t sdpa_attr_check( + const engine_t *engine, const primitive_attr_t *attr) { + using smask_t = primitive_attr_t::skip_mask_t; + + if (attr == nullptr) return status::success; + if (attr && attr->has_default_values()) { return status::success; } + + if (attr) { + smask_t attr_mask = smask_t::none; + VCHECK_SDPA_UNIMPL( + attr->has_default_values(attr_mask), VERBOSE_UNSUPPORTED_ATTR); + } + return status::success; +} static inline sdpa_desc_t create_sdpa_desc(const memory_desc_t *q_md, const memory_desc_t *k_md, const memory_desc_t *v_md, const memory_desc_t *dst_md, const memory_desc_t *attn_mask_md, @@ -189,13 +258,12 @@ static inline sdpa_desc_t create_sdpa_desc(const memory_desc_t *q_md, static inline sdpa_desc_t create_sdpa_desc(const memory_desc_t *q_md, const memory_desc_t *k_md, const memory_desc_t *v_md, - const memory_desc_t *dst_md, const memory_desc_t *diff_q_md, + const memory_desc_t *dst_md, const memory_desc_t *attn_mask_md, + const memory_desc_t *scale_md, const memory_desc_t *diff_q_md, const memory_desc_t *diff_k_md, const memory_desc_t *diff_v_md, const memory_desc_t *diff_dst_md, const memory_desc_t *dS_md, - const memory_desc_t *attn_mask_md, const memory_desc_t *scale_md, bool invert_scale, dim_t kv_head_number, - attn_mask_type_t attn_mask_type, 
alg_kind_t softmax_alg, - const primitive_attr_t *kq_attr, const primitive_attr_t *vs_attr) { + attn_mask_type_t attn_mask_type, alg_kind_t softmax_alg) { auto sdpa_desc = sdpa_desc_t(); sdpa_desc.primitive_kind = primitive_kind::sdpa; sdpa_desc.q_desc = *q_md; @@ -210,6 +278,7 @@ static inline sdpa_desc_t create_sdpa_desc(const memory_desc_t *q_md, sdpa_desc.diff_q_desc = *diff_q_md; sdpa_desc.diff_k_desc = *diff_k_md; sdpa_desc.diff_v_desc = *diff_v_md; + if (attn_mask_md) sdpa_desc.attn_mask_desc = *attn_mask_md; sdpa_desc.scale_desc = *scale_md; sdpa_desc.invert_scale = invert_scale; @@ -265,10 +334,9 @@ static inline status_t create_sdpa_pd( CHECK(sdpa_desc_check(q_md, k_md, v_md, dst_md, attn_mask_md, engine, attr, kq_attr, vs_attr)); - auto sdpa_desc = create_sdpa_desc(q_md, k_md, v_md, dst_md, diff_q_md, - diff_k_md, diff_v_md, diff_dst_md, dS_md, attn_mask_md, scale_md, - invert_scale, kv_head_number, attn_mask_type, softmax_alg, kq_attr, - vs_attr); + auto sdpa_desc = create_sdpa_desc(q_md, k_md, v_md, dst_md, attn_mask_md, + scale_md, diff_q_md, diff_k_md, diff_v_md, diff_dst_md, dS_md, + invert_scale, kv_head_number, attn_mask_type, softmax_alg); primitive_attr_t sdpa_attr = attr ? 
*attr : default_attr(); diff --git a/src/common/type_helpers.hpp b/src/common/type_helpers.hpp index 27d35537c10..281d6e16166 100644 --- a/src/common/type_helpers.hpp +++ b/src/common/type_helpers.hpp @@ -1022,6 +1022,7 @@ inline bool operator==(const zero_pad_desc_t &lhs, const zero_pad_desc_t &rhs) { inline bool operator==(const sdpa_desc_t &lhs, const sdpa_desc_t &rhs) { bool ret = COMPARE_DESC_MEMBERS(primitive_kind) + && COMPARE_DESC_MEMBERS(prop_kind) && COMPARE_DESC_MEMBERS(q_desc) && COMPARE_DESC_MEMBERS(k_desc) && COMPARE_DESC_MEMBERS(v_desc) @@ -1029,7 +1030,12 @@ inline bool operator==(const sdpa_desc_t &lhs, const sdpa_desc_t &rhs) { && COMPARE_DESC_MEMBERS(kq_zero_points) && COMPARE_DESC_MEMBERS(vs_scales) && COMPARE_DESC_MEMBERS(vs_zero_points) + && COMPARE_DESC_MEMBERS(dS_desc) && COMPARE_DESC_MEMBERS(dst_desc) + && COMPARE_DESC_MEMBERS(diff_dst_desc) + && COMPARE_DESC_MEMBERS(diff_q_desc) + && COMPARE_DESC_MEMBERS(diff_k_desc) + && COMPARE_DESC_MEMBERS(diff_v_desc) && COMPARE_DESC_MEMBERS(attn_mask_desc) && COMPARE_DESC_MEMBERS(scale_desc) && COMPARE_DESC_MEMBERS(kq_acc_dt) diff --git a/src/common/verbose.cpp b/src/common/verbose.cpp index 7c572e74919..0276d63cd5f 100644 --- a/src/common/verbose.cpp +++ b/src/common/verbose.cpp @@ -1618,14 +1618,16 @@ std::string init_info_sdpa(const engine_t *e, const pd_t *pd) { const sdpa_desc_t *desc = pd->desc(); ss << md2fmt_str( - "query", pd->qry_md(), pd->invariant_src_user_format_kind(0)) + "query", desc->qry_md(), pd->invariant_src_user_format_kind(0)) << " "; - ss << md2fmt_str("key", pd->key_md(), pd->invariant_src_user_format_kind(1)) + ss << md2fmt_str( + "key", desc->key_md(), pd->invariant_src_user_format_kind(1)) << " "; - ss << md2fmt_str("val", pd->val_md(), pd->invariant_src_user_format_kind(2)) + ss << md2fmt_str( + "val", desc->val_md(), pd->invariant_src_user_format_kind(2)) << " "; if (pd->with_attn_mask()) - ss << md2fmt_str("msk", pd->attn_mask_md(), + ss << md2fmt_str("msk", 
desc->attn_mask_md(), pd->invariant_src_user_format_kind(3)) << " "; ss << md2fmt_str("dst", pd->dst_md(), pd->invariant_dst_user_format_kind()) @@ -1663,7 +1665,7 @@ std::string init_info_sdpa(const engine_t *e, const pd_t *pd) { delimiter = " "; ss << ",alg:" << desc->softmax_alg; if (pd->with_attn_mask()) { - auto *md = pd->attn_mask_md(); + auto *md = desc->attn_mask_md(); ss << delimiter << "msk:" << (md->dims[2] == 1 ? 1 : 2) << 'd'; } else if (pd->with_causal_mask()) { ss << delimiter; @@ -1678,15 +1680,15 @@ std::string init_info_sdpa(const engine_t *e, const pd_t *pd) { ss << "div:"; else ss << "mul:"; - ss << dnnl_dt2str(pd->scale_md()->data_type) << ":"; + ss << dnnl_dt2str(desc->scale_md()->data_type) << ":"; if (pd->with_host_scale()) ss << "host"; else ss << "device"; } - ss << "," << md2dim_str(pd->qry_md()) << ":" << md2dim_str(pd->key_md()) - << ":" << md2dim_str(pd->val_md()); + ss << "," << md2dim_str(desc->qry_md()) << ":" << md2dim_str(desc->key_md()) + << ":" << md2dim_str(desc->val_md()); return ss.str(); } diff --git a/src/gpu/intel/sdpa/micro.cpp b/src/gpu/intel/sdpa/micro.cpp index 80d22da3d24..692abb5a5dd 100644 --- a/src/gpu/intel/sdpa/micro.cpp +++ b/src/gpu/intel/sdpa/micro.cpp @@ -159,13 +159,13 @@ status_t micro_fwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { fwd_config_t *config = nullptr; const dim_t thin_q_threshold = 16; auto queries = d->queries(); - if (queries == 1) { queries = (d->q_desc.dims[1] / d->kv_head_number); } + if (queries == 1) { queries = (d->q_desc.dims[1] / d->num_kv_heads()); } bool thin_q = (queries <= thin_q_threshold); bool quantized = with_key_scales() || with_key_zp() || with_value_scales() || with_value_zp(); bool is_integrated = intel_engine->device_info()->is_integrated(); - bool is_f32 = (qry_md()->data_type == data_type::f32); + bool is_f32 = (desc()->qry_md()->data_type == data_type::f32); use_systolic_ukernel_ = intel_engine->mayiuse(compute::device_ext_t:: 
intel_subgroup_matrix_multiply_accumulate) @@ -236,7 +236,8 @@ status_t micro_fwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { hw_info.gmdid = dev_info->ip_version(); hw_info.systolicAvailable = use_systolic_ukernel_; - if (hw_info.gmdid == 0) return status::unimplemented; + VDISPATCH_SDPA( + hw_info.gmdid != 0, "gmdid is 0, microkernels not supported."); ukernel_params.hwinfo = {hw_info}; @@ -252,17 +253,17 @@ status_t micro_fwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { /* Set up GEMMProblem structure for first GEMM: K^T * Q */ GEMMProblem problem; - problem.Ta_ext = convert_dnnl_to_kernel_type(key_md()->data_type); - problem.Tb_ext = convert_dnnl_to_kernel_type(qry_md()->data_type); - if (qry_md()->data_type == data_type::f16) { + problem.Ta_ext = convert_dnnl_to_kernel_type(desc()->key_md()->data_type); + problem.Tb_ext = convert_dnnl_to_kernel_type(desc()->qry_md()->data_type); + if (desc()->qry_md()->data_type == data_type::f16) { problem.Ta = problem.Tb = Type::f16; - } else if (qry_md()->data_type == data_type::bf16) { + } else if (desc()->qry_md()->data_type == data_type::bf16) { problem.Ta = problem.Tb = Type::bf16; - } else if (qry_md()->data_type == data_type::f32) { + } else if (desc()->qry_md()->data_type == data_type::f32) { problem.Ta = problem.Tb = Type::f32; } else { - VCHECK_SDPA_COND(utils::one_of(qry_md()->data_type, data_type::f16, - data_type::bf16), + VCHECK_SDPA_COND(utils::one_of(desc()->qry_md()->data_type, + data_type::f16, data_type::bf16), "Q tensor's data type must be bf16 or f16"); } problem.Tc = problem.Tc_ext = Type::f32; @@ -273,7 +274,7 @@ status_t micro_fwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { problem_kq.Tc = problem_kq.Ts = (kq_acc_dt() == data_type::f16) ? 
Type::f16 : Type::f32; - problem_kq.A.layout = convert_dnnl_to_kernel_layout(key_md()); + problem_kq.A.layout = convert_dnnl_to_kernel_layout(desc()->key_md()); if (with_key_scales() && !kq_common_scales) { auto scale_dt = key_scales_dt(); @@ -302,9 +303,9 @@ status_t micro_fwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { problem_kq.B.layout = MatrixLayout::Pr; problem_kq.C.layout = MatrixLayout::T; - const memory_desc_wrapper key_mdw(key_md()); + const memory_desc_wrapper key_mdw(desc()->key_md()); auto ldk = static_cast( - gemm_desc_t::get_ld(*key_md()) * key_mdw.data_type_size()); + gemm_desc_t::get_ld(*desc()->key_md()) * key_mdw.data_type_size()); problem_kq.A.setAlignment(micro::alignmentForLD(int(ldk))); problem_kq.B.setAlignment(64); // Q is packed in VNNI format in SLM if (use_systolic_ukernel()) { @@ -337,7 +338,7 @@ status_t micro_fwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { : utils::rnd_up_pow2(queries); heuristic_sizes.k = d->head_size(); // baked into kernel regardless, no quantization - heuristic_sizes.batch = utils::rnd_up_pow2(d->batch_size()); + heuristic_sizes.batch = utils::rnd_up_pow2(d->batch() * d->num_q_heads()); ukernel_params.sizes_kq = {heuristic_sizes}; @@ -349,8 +350,9 @@ status_t micro_fwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { bool vs_common_scales = with_quantize_common(d->vs_scales); bool vs_common_zp = with_quantize_common(d->vs_zero_points); - problem_vs.Ta_ext = convert_dnnl_to_kernel_type(val_md()->data_type); - problem_vs.A.layout = convert_dnnl_to_kernel_layout(val_md()); + problem_vs.Ta_ext + = convert_dnnl_to_kernel_type(desc()->val_md()->data_type); + problem_vs.A.layout = convert_dnnl_to_kernel_layout(desc()->val_md()); if (with_value_scales() && !vs_common_scales) { auto scale_dt = value_scales_dt(); problem_vs.Ta_scale = convert_dnnl_to_kernel_type(scale_dt); @@ -378,9 +380,9 @@ status_t micro_fwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { 
problem_vs.B.layout = MatrixLayout::Pr; problem_vs.C.layout = MatrixLayout::N; - const memory_desc_wrapper val_mdw(val_md()); + const memory_desc_wrapper val_mdw(desc()->val_md()); auto ldv = static_cast( - gemm_desc_t::get_ld(*val_md()) * val_mdw.data_type_size()); + gemm_desc_t::get_ld(*desc()->val_md()) * val_mdw.data_type_size()); problem_vs.A.setAlignment(micro::alignmentForLD(int(ldv))); problem_vs.B.setAlignment(64); // S is packed in SLM if (use_systolic_ukernel()) { problem_vs.B.crosspack = 16; } @@ -429,12 +431,12 @@ status_t micro_bwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { const dim_t thin_q_threshold = 16; auto queries = d->queries(); // TODO: q=1 batch group optimizations - // if (queries == 1) { queries = (d->q_desc.dims[1] / d->kv_head_number); } + // if (queries == 1) { queries = (d->q_desc.dims[1] / d->num_kv_heads()); } bool thin_q = (queries <= thin_q_threshold); bool quantized = false; bool is_integrated = intel_engine->device_info()->is_integrated(); - bool is_f32 = (qry_md()->data_type == data_type::f32); + bool is_f32 = (desc()->qry_md()->data_type == data_type::f32); use_systolic_ukernel_ = intel_engine->mayiuse(compute::device_ext_t:: intel_subgroup_matrix_multiply_accumulate) @@ -539,7 +541,8 @@ status_t micro_bwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { hw_info.gmdid = dev_info->ip_version(); hw_info.systolicAvailable = use_systolic_ukernel_; - if (hw_info.gmdid == 0) return status::unimplemented; + VDISPATCH_SDPA( + hw_info.gmdid != 0, "gmdid is 0, microkernels not supported."); ukernel_params.hwinfo = {hw_info}; @@ -561,17 +564,18 @@ status_t micro_bwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { /* Set up GEMMProblem structure for first GEMM: K^T * Q */ GEMMProblem problem; - problem.Ta_ext = convert_dnnl_to_kernel_type(key_md()->data_type); - problem.Tb_ext = convert_dnnl_to_kernel_type(qry_md()->data_type); - if (qry_md()->data_type == data_type::f16) { + problem.Ta_ext = 
convert_dnnl_to_kernel_type(desc()->key_md()->data_type); + problem.Tb_ext = convert_dnnl_to_kernel_type(desc()->qry_md()->data_type); + if (desc()->qry_md()->data_type == data_type::f16) { problem.Ta = problem.Tb = Type::f16; - } else if (qry_md()->data_type == data_type::bf16) { + } else if (desc()->qry_md()->data_type == data_type::bf16) { problem.Ta = problem.Tb = Type::bf16; - } else if (qry_md()->data_type == data_type::f32) { + } else if (desc()->qry_md()->data_type == data_type::f32) { problem.Ta = problem.Tb = Type::f32; } else { - VCHECK_SDPA_COND(utils::one_of(qry_md()->data_type, data_type::f16, - data_type::bf16, data_type::f32), + VCHECK_SDPA_COND( + utils::one_of(desc()->qry_md()->data_type, data_type::f16, + data_type::bf16, data_type::f32), "Q tensor's data type must be bf16, f16, or f32"); } problem.Tc = problem.Tc_ext = Type::f32; @@ -585,12 +589,12 @@ status_t micro_bwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { problem_kq.A.layout = MatrixLayout::Pc; problem_kq.B.layout = MatrixLayout::N; problem_kq.C.layout = MatrixLayout::N; - const memory_desc_wrapper key_mdw(key_md()); - const memory_desc_wrapper qry_mdw(qry_md()); + const memory_desc_wrapper key_mdw(desc()->key_md()); + const memory_desc_wrapper qry_mdw(desc()->qry_md()); auto ldk = static_cast( - gemm_desc_t::get_ld(*key_md()) * key_mdw.data_type_size()); + gemm_desc_t::get_ld(*desc()->key_md()) * key_mdw.data_type_size()); auto ldq = static_cast( - gemm_desc_t::get_ld(*qry_md()) * qry_mdw.data_type_size()); + gemm_desc_t::get_ld(*desc()->qry_md()) * qry_mdw.data_type_size()); problem_kq.A.setAlignment(64); // Q is packed in VNNI format in SLM if (use_systolic_ukernel()) { problem_kq.A.crosspack = 2; @@ -624,7 +628,8 @@ status_t micro_bwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { problem_vs.Tc = problem_vs.Ts = (vs_acc_dt() == data_type::f16) ? 
Type::f16 : Type::f32; - problem_vs.Ta_ext = convert_dnnl_to_kernel_type(val_md()->data_type); + problem_vs.Ta_ext + = convert_dnnl_to_kernel_type(desc()->val_md()->data_type); problem_vs.A.layout = convert_dnnl_to_kernel_layout(diff_dst_md()); problem_vs.B.layout = MatrixLayout::Pr; problem_vs.C.layout = MatrixLayout::N; @@ -655,15 +660,16 @@ status_t micro_bwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { //////// Vt * dA auto problem_vtdA = problem; - problem_vtdA.Ta_ext = convert_dnnl_to_kernel_type(val_md()->data_type); + problem_vtdA.Ta_ext + = convert_dnnl_to_kernel_type(desc()->val_md()->data_type); - problem_vtdA.A.layout = transpose_layout( - convert_dnnl_to_kernel_layout(val_md())); //TODO hardcode? - problem_vtdA.B.layout - = convert_dnnl_to_kernel_layout(diff_dst_md()); //TODO hardcode? + problem_vtdA.A.layout + = transpose_layout(convert_dnnl_to_kernel_layout(desc()->val_md())); + problem_vtdA.B.layout = convert_dnnl_to_kernel_layout(diff_dst_md()); problem_vtdA.C.layout = MatrixLayout::N; - const memory_desc_wrapper val_mdw(val_md()); - auto ldv = gemm_desc_t::get_ld(*val_md()) * val_mdw.data_type_size(); + const memory_desc_wrapper val_mdw(desc()->val_md()); + auto ldv + = gemm_desc_t::get_ld(*desc()->val_md()) * val_mdw.data_type_size(); problem_vtdA.A.setAlignment(micro::alignmentForLD(int(ldv))); problem_vtdA.B.setAlignment(micro::alignmentForLD(int(lda))); @@ -684,10 +690,11 @@ status_t micro_bwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { //////// Q * dS^t auto problem_qdSt = problem; - problem_qdSt.Ta_ext = convert_dnnl_to_kernel_type(qry_md()->data_type); + problem_qdSt.Ta_ext + = convert_dnnl_to_kernel_type(desc()->qry_md()->data_type); problem_qdSt.A.layout = MatrixLayout::Pc; problem_qdSt.B.layout - = transpose_layout(convert_dnnl_to_kernel_layout(qry_md())); + = transpose_layout(convert_dnnl_to_kernel_layout(desc()->qry_md())); problem_qdSt.C.layout = MatrixLayout::N; problem_qdSt.A.setAlignment(64); @@ -716,10 
+723,11 @@ status_t micro_bwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { // dS * K auto problem_ktq = problem; - problem_ktq.Ta_ext = convert_dnnl_to_kernel_type(key_md()->data_type); + problem_ktq.Ta_ext + = convert_dnnl_to_kernel_type(desc()->key_md()->data_type); problem_ktq.A.layout - = transpose_layout(convert_dnnl_to_kernel_layout(key_md())); + = transpose_layout(convert_dnnl_to_kernel_layout(desc()->key_md())); problem_ktq.B.layout = MatrixLayout::Pr; problem_ktq.C.layout = MatrixLayout::N; @@ -782,11 +790,11 @@ static void init_conf_common(conf_t &conf, pd_type *pd) { conf.data_t = data_t; conf.ndims = pd_t::ndims; - const memory_desc_wrapper qry_mdw(pd->qry_md()); - const memory_desc_wrapper key_mdw(pd->key_md()); - const memory_desc_wrapper val_mdw(pd->val_md()); + const memory_desc_wrapper qry_mdw(pd->desc()->qry_md()); + const memory_desc_wrapper key_mdw(pd->desc()->key_md()); + const memory_desc_wrapper val_mdw(pd->desc()->val_md()); const memory_desc_wrapper dst_mdw(pd->dst_md()); - const memory_desc_wrapper msk_mdw(pd->attn_mask_md()); + const memory_desc_wrapper msk_mdw(pd->desc()->attn_mask_md()); conf.key_data_t = key_mdw.data_type(); conf.qry_data_t = qry_mdw.data_type(); @@ -797,11 +805,14 @@ static void init_conf_common(conf_t &conf, pd_type *pd) { if (pd->with_attn_mask()) { conf.msk_data_t = msk_mdw.data_type(); } auto Q_num_heads_dim = qry_mdw.dims()[1]; - conf.kv_group_size = static_cast(Q_num_heads_dim / d->kv_head_number); - - auto ldq = gemm_desc_t::get_ld(*pd->qry_md()) * qry_mdw.data_type_size(); - auto ldk = gemm_desc_t::get_ld(*pd->key_md()) * key_mdw.data_type_size(); - auto ldv = gemm_desc_t::get_ld(*pd->val_md()) * val_mdw.data_type_size(); + conf.kv_group_size = static_cast(Q_num_heads_dim / d->num_kv_heads()); + + auto ldq = gemm_desc_t::get_ld(*pd->desc()->qry_md()) + * qry_mdw.data_type_size(); + auto ldk = gemm_desc_t::get_ld(*pd->desc()->key_md()) + * key_mdw.data_type_size(); + auto ldv = 
gemm_desc_t::get_ld(*pd->desc()->val_md()) + * val_mdw.data_type_size(); auto lda = gemm_desc_t::get_ld(*pd->dst_md()) * dst_mdw.data_type_size(); conf.q_align = micro::alignmentForLD(int(ldq)); @@ -809,9 +820,10 @@ static void init_conf_common(conf_t &conf, pd_type *pd) { conf.v_align = micro::alignmentForLD(int(ldv)); conf.a_align = micro::alignmentForLD(int(lda)); - conf.transpose_k = gemm_desc_t::get_trans(*pd->key_md()) == dnnl_trans; + conf.transpose_k + = gemm_desc_t::get_trans(*pd->desc()->key_md()) == dnnl_trans; - conf.scale_data_t = pd->scale_md()->data_type; + conf.scale_data_t = pd->desc()->scale_md()->data_type; conf.attn_mask_undef = attn_mask_type::undef; conf.attn_mask_buffer = attn_mask_type::buffer; @@ -837,42 +849,39 @@ static void init_conf_common(conf_t &conf, pd_type *pd) { status_t micro_fwd_t::pd_t::init_conf(impl::engine_t *engine) { using namespace micro; - - auto *pd = this; - auto *d = pd->desc(); - - init_conf_common(conf, pd); + init_conf_common(conf, this); conf.require_stateless_addressing = has_large_buffers(); - const memory_desc_wrapper qry_mdw(pd->qry_md()); - const memory_desc_wrapper key_mdw(pd->key_md()); - const memory_desc_wrapper val_mdw(pd->val_md()); - const memory_desc_wrapper dst_mdw(pd->dst_md()); + const memory_desc_wrapper qry_mdw(desc()->qry_md()); + const memory_desc_wrapper key_mdw(desc()->key_md()); + const memory_desc_wrapper val_mdw(desc()->val_md()); + const memory_desc_wrapper dst_mdw(dst_md()); - conf.key_scales_data_t = pd->key_scales_dt(); - conf.value_scales_data_t = pd->value_scales_dt(); + conf.key_scales_data_t = key_scales_dt(); + conf.value_scales_data_t = value_scales_dt(); - conf.key_zp_data_t = pd->key_zp_dt(); - conf.value_zp_data_t = pd->value_zp_dt(); + conf.key_zp_data_t = key_zp_dt(); + conf.value_zp_data_t = value_zp_dt(); - auto ldq = gemm_desc_t::get_ld(*pd->qry_md()) * qry_mdw.data_type_size(); - auto lda = gemm_desc_t::get_ld(*pd->dst_md()) * dst_mdw.data_type_size(); + auto ldq + = 
gemm_desc_t::get_ld(*desc()->qry_md()) * qry_mdw.data_type_size(); + auto lda = gemm_desc_t::get_ld(*dst_md()) * dst_mdw.data_type_size(); - int kq_scale_mask = (static_cast(pd->with_key_scales()) << 1) - | static_cast(with_quantize_common(d->kq_scales)); + int kq_scale_mask = (static_cast(with_key_scales()) << 1) + | static_cast(with_quantize_common(desc()->kq_scales)); conf.kq_scale_mask = kq_scale_mask; - int vs_scale_mask = (static_cast(pd->with_value_scales()) << 1) - | static_cast(with_quantize_common(d->vs_scales)); + int vs_scale_mask = (static_cast(with_value_scales()) << 1) + | static_cast(with_quantize_common(desc()->vs_scales)); conf.vs_scale_mask = vs_scale_mask; - int kq_zp_mask = (static_cast(pd->with_key_zp()) << 1) - | static_cast(with_quantize_common(d->kq_zero_points)); + int kq_zp_mask = (static_cast(with_key_zp()) << 1) + | static_cast(with_quantize_common(desc()->kq_zero_points)); conf.kq_zp_mask = kq_zp_mask; - int vs_zp_mask = (static_cast(pd->with_value_zp()) << 1) - | static_cast(with_quantize_common(d->vs_zero_points)); + int vs_zp_mask = (static_cast(with_value_zp()) << 1) + | static_cast(with_quantize_common(desc()->vs_zero_points)); conf.vs_zp_mask = vs_zp_mask; using namespace data_type; @@ -885,16 +894,16 @@ status_t micro_fwd_t::pd_t::init_conf(impl::engine_t *engine) { }; conf.key_elements_per_byte = elems_per_byte(key_mdw.data_type()); - conf.key_zp_elements_per_byte = elems_per_byte(pd->key_zp_dt()); + conf.key_zp_elements_per_byte = elems_per_byte(key_zp_dt()); conf.val_elements_per_byte = elems_per_byte(val_mdw.data_type()); - conf.val_zp_elements_per_byte = elems_per_byte(pd->value_zp_dt()); + conf.val_zp_elements_per_byte = elems_per_byte(value_zp_dt()); conf.key_group_size = 1; conf.val_group_size = 1; - if (pd->with_key_scales() || pd->with_key_zp()) - conf.key_group_size = pd->key_group_size(); - if (pd->with_value_scales() || pd->with_value_zp()) - conf.val_group_size = pd->value_group_size(); + if (with_key_scales() || 
with_key_zp()) + conf.key_group_size = key_group_size(); + if (with_value_scales() || with_value_zp()) + conf.val_group_size = value_group_size(); /* Set up microkernel strategy */ const fwd_config_t config = {conf.ukernel_config.unroll_m_kq, @@ -910,9 +919,9 @@ status_t micro_fwd_t::pd_t::init_conf(impl::engine_t *engine) { int tile_v = vs_wg_tile_m; bool d_full = conf.d_full; - bool v_full = (d->head_size() == tile_v); + bool v_full = (desc()->head_size() == tile_v); - auto Q = d->queries(); + auto Q = desc()->queries(); const dim_t Q_per_kv_group = (Q == 1 ? Q * conf.kv_group_size : Q); bool q_full = ((Q_per_kv_group % kq_wg_tile_n) != 0); conf.remainder_q = d_full && q_full; @@ -921,19 +930,19 @@ status_t micro_fwd_t::pd_t::init_conf(impl::engine_t *engine) { if (d_full) { conf.block_q = (ldq % 4 == 0); conf.block_a = (lda % 4 == 0 && v_full); - } else if (pd->arch() >= compute::gpu_arch_t::xe_hpc + } else if (arch() >= compute::gpu_arch_t::xe_hpc && config.unroll_m_vs < 64) { - auto vbytes = d->values() * val_mdw.data_type_size(); + auto vbytes = desc()->values() * val_mdw.data_type_size(); if (lda % 16 == 0 && vbytes % 4 == 0) conf.block_2d_a = true; } - if (pd->arch() >= compute::gpu_arch_t::xe_hpc) { + if (arch() >= compute::gpu_arch_t::xe_hpc) { conf.prefetch_mask = true; conf.prefetch_k0 = true; conf.prefetch_k = true; conf.prefetch_v = true; - conf.prefetch_d_max = nstl::min(pd->d_max(), 64); - bool no_rem = d_full && v_full && (d->keys() % tile_k == 0); + conf.prefetch_d_max = nstl::min(d_max(), 64); + bool no_rem = d_full && v_full && (desc()->keys() % tile_k == 0); conf.prefetch_remainder = !no_rem; } else { conf.prefetch_mask = conf.prefetch_k0 = conf.prefetch_k @@ -943,31 +952,30 @@ status_t micro_fwd_t::pd_t::init_conf(impl::engine_t *engine) { conf.q_arrive_await_barrier = (Q > 1); conf.softmax_inf_as_zero - = (d->softmax_alg == alg_kind::softmax_accurate_inf_as_zero); + = (desc()->softmax_alg == alg_kind::softmax_accurate_inf_as_zero); 
conf.kq_f16_accumulate = (kq_acc_dt() == data_type::f16); conf.vs_f16_accumulate = (vs_acc_dt() == data_type::f16); bool is_training = desc()->prop_kind == prop_kind::forward_training; conf.is_training = is_training; - if (is_training) { pd->init_default_ws(); } + if (is_training) { init_default_ws(); } return status::success; } status_t micro_bwd_t::pd_t::init_conf(impl::engine_t *engine) { - auto *pd = this; - auto *d = pd->desc(); - - init_conf_common(conf, pd); + init_conf_common(conf, this); conf.require_stateless_addressing = has_large_buffers(); - conf.with_dS = pd->with_dS(); + conf.with_dS = with_dS(); - const memory_desc_wrapper key_mdw(pd->key_md()); - const memory_desc_wrapper val_mdw(pd->val_md()); + const memory_desc_wrapper key_mdw(desc()->key_md()); + const memory_desc_wrapper val_mdw(desc()->val_md()); - auto ldk = gemm_desc_t::get_ld(*pd->key_md()) * key_mdw.data_type_size(); - auto ldv = gemm_desc_t::get_ld(*pd->val_md()) * val_mdw.data_type_size(); + auto ldk + = gemm_desc_t::get_ld(*desc()->key_md()) * key_mdw.data_type_size(); + auto ldv + = gemm_desc_t::get_ld(*desc()->val_md()) * val_mdw.data_type_size(); /* Set up microkernel strategy */ const bwd_config_t config = {conf.ukernel_config.unroll_m_BcBr, @@ -984,22 +992,21 @@ status_t micro_bwd_t::pd_t::init_conf(impl::engine_t *engine) { const int tile_dv = config.wg_n_DBc * config.unroll_n_DBc; bool d_full = conf.d_full; - bool dv_full = (d->head_size() == tile_dv); + bool dv_full = (desc()->head_size() == tile_dv); conf.block_k = conf.block_dK = conf.block_dV = false; if (d_full) { - bool can_block_load_k = (ldk % 4 == 0) && (d->keys() % tile_k == 0); + bool can_block_load_k + = (ldk % 4 == 0) && (desc()->keys() % tile_k == 0); conf.block_k = can_block_load_k; conf.block_dK = can_block_load_k && !conf.transpose_k; conf.block_dV = (ldv % 4 == 0) && (dv_full); } - /* - * TODO: prefetching for bwd - * conf.prefetch_mask = conf.prefetch_k0 = conf.prefetch_k - * = conf.prefetch_v = 
conf.prefetch_remainder = false; - * conf.prefetch_d_max = 0; - */ + //TODO: remove or add prefetching to BWD + conf.prefetch_mask = conf.prefetch_k0 = conf.prefetch_k = conf.prefetch_v + = conf.prefetch_remainder = false; + conf.prefetch_d_max = 0; return status::success; } @@ -1008,9 +1015,7 @@ status_t micro_bwd_t::pd_t::init_scratchpad(impl::engine_t *engine) { auto scratchpad = scratchpad_registry().registrar(); auto gpu_align = utils::downcast(engine)->get_buffer_alignment(); - memory_desc_wrapper dQ_wspace(diff_qry_md()); - size_t wspace_size = dQ_wspace.dims()[0] * dQ_wspace.dims()[1] - * dQ_wspace.dims()[2] * dQ_wspace.dims()[3]; + size_t wspace_size = memory_desc_wrapper(desc()->diff_qry_md()).nelems(); // f32 can directly atomic add to output // others need intermediate scratchpad before conversion if (conf.data_t != data_type::f32) { @@ -1022,24 +1027,18 @@ status_t micro_bwd_t::pd_t::init_scratchpad(impl::engine_t *engine) { const bool needs_intermediate_dKV = (conf.kv_group_size > 1 && conf.data_t != data_type::f32); if (needs_intermediate_dKV) { - memory_desc_wrapper dK_wspace(diff_key_md()); - size_t dK_size = dK_wspace.dims()[0] * dK_wspace.dims()[1] - * dK_wspace.dims()[2] * dK_wspace.dims()[3]; + size_t dK_size = memory_desc_wrapper(desc()->diff_key_md()).nelems(); scratchpad.book(memory_tracking::names::key_sdpa_dK_reduction, dK_size, sizeof(float), gpu_align); - memory_desc_wrapper dV_wspace(diff_val_md()); - size_t dV_size = dV_wspace.dims()[0] * dV_wspace.dims()[1] - * dV_wspace.dims()[2] * dV_wspace.dims()[3]; + size_t dV_size = memory_desc_wrapper(desc()->diff_val_md()).nelems(); scratchpad.book(memory_tracking::names::key_sdpa_dV_reduction, dV_size, sizeof(float), gpu_align); } // space for D_i preprocess result - dim_t batch = qry_md()->dims[0]; - dim_t num_q_heads = qry_md()->dims[1]; - dim_t Q = desc()->queries(); - size_t Di_size = batch * num_q_heads * Q; + size_t Di_size + = desc()->batch() * desc()->num_q_heads() * 
desc()->queries(); scratchpad.book(memory_tracking::names::key_sdpa_Di, Di_size, sizeof(float), gpu_align); @@ -1049,6 +1048,7 @@ status_t micro_bwd_t::pd_t::init_scratchpad(impl::engine_t *engine) { status_t micro_fwd_params_t::get_kernel_ctx( compute::kernel_ctx_t &kernel_ctx) const { using namespace micro; + kernel_ctx.require_stateless_addressing(require_stateless_addressing); kernel_ctx.define_int("NDIMS", ndims); kernel_ctx.set_data_type(data_t); @@ -1245,6 +1245,7 @@ status_t micro_fwd_params_t::get_kernel_ctx( status_t micro_bwd_params_t::get_kernel_ctx( compute::kernel_ctx_t &kernel_ctx) const { using namespace micro; + kernel_ctx.require_stateless_addressing(require_stateless_addressing); kernel_ctx.define_int("NDIMS", ndims); kernel_ctx.set_data_type(data_t); @@ -1286,7 +1287,6 @@ status_t micro_bwd_params_t::get_kernel_ctx( kernel_ctx.define_int("BLOCK_DK", block_dK); kernel_ctx.define_int("BLOCK_DV", block_dV); - //TODO: remove or add prefetching to BWD kernel_ctx.define_int("PREFETCH_MASK", prefetch_mask); kernel_ctx.define_int("PREFETCH_K0", prefetch_k0); kernel_ctx.define_int("PREFETCH_K", prefetch_k); @@ -1493,11 +1493,11 @@ status_t micro_fwd_t::execute_forward(const exec_ctx_t &ctx) const { auto wg_tile_q = kq_wg_tile_n; auto sg_per_wg = config.wg_m_kq * config.wg_n_kq; - const memory_desc_wrapper qry_mdw(pd()->qry_md()); - const memory_desc_wrapper key_mdw(pd()->key_md()); - const memory_desc_wrapper val_mdw(pd()->val_md()); + const memory_desc_wrapper qry_mdw(pd()->desc()->qry_md()); + const memory_desc_wrapper key_mdw(pd()->desc()->key_md()); + const memory_desc_wrapper val_mdw(pd()->desc()->val_md()); const memory_desc_wrapper dst_mdw(pd()->dst_md()); - const memory_desc_wrapper msk_mdw(pd()->attn_mask_md()); + const memory_desc_wrapper msk_mdw(pd()->desc()->attn_mask_md()); using offset_t = decltype(offsets_t().src_off); offset_t key_off, qry_off, val_off, dst_off, msk_off; @@ -1521,7 +1521,7 @@ status_t micro_fwd_t::execute_forward(const 
exec_ctx_t &ctx) const { arg_list.append(strides4); }; - const memory_desc_wrapper scale_mdw(pd()->scale_md()); + const memory_desc_wrapper scale_mdw(pd()->desc()->scale_md()); float scalar_scale = 1.f; float inv_scalar_scale = 1.f; if (pd()->with_host_scale()) { @@ -1532,7 +1532,7 @@ status_t micro_fwd_t::execute_forward(const exec_ctx_t &ctx) const { assert(status == status::success); if (status != status::success) return status; scalar_scale = dnnl::impl::cpu::io::load_float_value( - pd()->scale_md()->data_type, &scalar_scale, 0); + pd()->desc()->scale_md()->data_type, &scalar_scale, 0); inv_scalar_scale = 1. / scalar_scale; } @@ -1582,7 +1582,7 @@ status_t micro_fwd_t::execute_forward(const exec_ctx_t &ctx) const { gws[0] *= utils::div_up(Q, wg_tile_q); gws[1] *= pd()->dst_md()->dims[1]; } - gws[2] *= pd()->dst_md()->dims[0]; + gws[2] *= pd()->desc()->batch(); auto nd_range = compute::nd_range_t(gws, lws); return parallel_for(ctx, nd_range, kernel_, arg_list); @@ -1592,7 +1592,7 @@ status_t micro_bwd_t::execute_backward(const exec_ctx_t &ctx) const { const auto &qry = CTX_IN_STORAGE(DNNL_ARG_QUERIES); const auto &key = CTX_IN_STORAGE(DNNL_ARG_KEYS); const auto &val = CTX_IN_STORAGE(DNNL_ARG_VALUES); - auto &ws = CTX_IN_STORAGE(DNNL_ARG_WORKSPACE); + const auto &ws = CTX_IN_STORAGE(DNNL_ARG_WORKSPACE); const auto &dst = CTX_IN_STORAGE(DNNL_ARG_DST); const auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST); auto &diff_q = CTX_OUT_STORAGE(DNNL_ARG_DIFF_QUERIES); @@ -1639,18 +1639,18 @@ status_t micro_bwd_t::execute_backward(const exec_ctx_t &ctx) const { auto sg_per_wg_DBc = config.wg_m_DBc * config.wg_n_DBc; auto sg_per_wg_DBr = config.wg_m_DBr * config.wg_n_DBr; - using std::max; - auto sg_per_wg = max(max(sg_per_wg_BcBr, sg_per_wg_DBc), sg_per_wg_DBr); + auto sg_per_wg + = std::max(std::max(sg_per_wg_BcBr, sg_per_wg_DBc), sg_per_wg_DBr); - const memory_desc_wrapper qry_mdw(pd()->qry_md()); - const memory_desc_wrapper key_mdw(pd()->key_md()); - const 
memory_desc_wrapper val_mdw(pd()->val_md()); + const memory_desc_wrapper qry_mdw(pd()->desc()->qry_md()); + const memory_desc_wrapper key_mdw(pd()->desc()->key_md()); + const memory_desc_wrapper val_mdw(pd()->desc()->val_md()); const memory_desc_wrapper dst_mdw(pd()->dst_md()); - const memory_desc_wrapper msk_mdw(pd()->attn_mask_md()); + const memory_desc_wrapper msk_mdw(pd()->desc()->attn_mask_md()); const memory_desc_wrapper diff_dst_mdw(pd()->diff_dst_md()); - const memory_desc_wrapper diff_qry_mdw(pd()->diff_qry_md()); - const memory_desc_wrapper diff_key_mdw(pd()->diff_key_md()); - const memory_desc_wrapper diff_val_mdw(pd()->diff_val_md()); + const memory_desc_wrapper diff_qry_mdw(pd()->desc()->diff_qry_md()); + const memory_desc_wrapper diff_key_mdw(pd()->desc()->diff_key_md()); + const memory_desc_wrapper diff_val_mdw(pd()->desc()->diff_val_md()); using offset_t = decltype(offsets_t().src_off); offset_t qry_off, key_off, val_off, dst_off, msk_off; @@ -1673,7 +1673,7 @@ status_t micro_bwd_t::execute_backward(const exec_ctx_t &ctx) const { int mask_type = static_cast(pd()->desc()->mask_type); - const memory_desc_wrapper scale_mdw(pd()->scale_md()); + const memory_desc_wrapper scale_mdw(pd()->desc()->scale_md()); float scalar_scale = 1.f; float inv_scalar_scale = 1.f; if (pd()->with_host_scale()) { @@ -1684,7 +1684,7 @@ status_t micro_bwd_t::execute_backward(const exec_ctx_t &ctx) const { assert(status == status::success); if (status != status::success) return status; scalar_scale = dnnl::impl::cpu::io::load_float_value( - pd()->scale_md()->data_type, &scalar_scale, 0); + pd()->desc()->scale_md()->data_type, &scalar_scale, 0); inv_scalar_scale = 1. 
/ scalar_scale; } @@ -1695,7 +1695,7 @@ status_t micro_bwd_t::execute_backward(const exec_ctx_t &ctx) const { gws_preprocess[0] *= utils::div_up(Q, wg_tile_q); gws_preprocess[1] *= pd()->dst_md()->dims[1]; - gws_preprocess[2] *= pd()->dst_md()->dims[0]; + gws_preprocess[2] *= pd()->desc()->batch(); auto nd_range_preprocess = compute::nd_range_t(gws_preprocess, lws); @@ -1710,16 +1710,16 @@ status_t micro_bwd_t::execute_backward(const exec_ctx_t &ctx) const { append_offs(preprocess_arg_list, qry_off); append_offs(preprocess_arg_list, dst_off); - status_t s = parallel_for( - ctx, nd_range_preprocess, preprocess_, preprocess_arg_list); - if (s != status::success) return s; + CHECK(parallel_for( + ctx, nd_range_preprocess, preprocess_, preprocess_arg_list)); + auto *d = pd()->desc(); // zero f32 intermediates before atomic adds in the main kernel // dQ always needs atomics, dK/dV only for GQA cases { - const dim_t num_kv_heads = pd()->dst_md()->dims[1] / kv_group_size; - const dim_t num_q_heads = pd()->dst_md()->dims[1]; - const int lws_zero = 256; + const dim_t num_kv_heads = d->num_kv_heads(); + const dim_t num_q_heads = d->num_q_heads(); + static constexpr size_t lws_zero = 256; auto dispatch_zero = [&](const memory_storage_t &buf, dim_t count, @@ -1733,25 +1733,22 @@ status_t micro_bwd_t::execute_backward(const exec_ctx_t &ctx) const { compute::range_t gws_z = lws_z; gws_z[0] *= utils::div_up(count, lws_zero); gws_z[1] *= num_heads; - gws_z[2] *= pd()->dst_md()->dims[0]; + gws_z[2] *= pd()->desc()->batch(); return parallel_for( ctx, compute::nd_range_t(gws_z, lws_z), zero_, args); }; // always zero dQ auto &dQ_buf = needs_intermediate_dQ ? *diff_q_scratch : diff_q; - s = dispatch_zero(dQ_buf, Q * D, qry_off, num_q_heads); - if (s != status::success) return s; + CHECK(dispatch_zero(dQ_buf, Q * D, qry_off, num_q_heads)); // zero dK/dV for GQA cases if (needs_zero_dKV) { auto &dK_buf = needs_intermediate_dKV ? 
*diff_k_scratch : diff_k; auto &dV_buf = needs_intermediate_dKV ? *diff_v_scratch : diff_v; - s = dispatch_zero(dK_buf, K * D, key_off, num_kv_heads); - if (s != status::success) return s; - s = dispatch_zero(dV_buf, K * D, val_off, num_kv_heads); - if (s != status::success) return s; + CHECK(dispatch_zero(dK_buf, K * D, key_off, num_kv_heads)); + CHECK(dispatch_zero(dV_buf, K * D, val_off, num_kv_heads)); } } @@ -1788,7 +1785,6 @@ status_t micro_bwd_t::execute_backward(const exec_ctx_t &ctx) const { if (pd()->with_attn_mask()) { append_offs(arg_list, msk_off); } const int remainder_k = (K % wg_tile_k) != 0; - auto *d = pd()->desc(); const bool d_full = (d->head_size() == pd()->d_max()); const int remainder_q = d_full && ((Q % wg_tile_q) != 0); @@ -1799,35 +1795,33 @@ status_t micro_bwd_t::execute_backward(const exec_ctx_t &ctx) const { gws[0] *= utils::div_up(K, wg_tile_k); gws[1] *= pd()->dst_md()->dims[1]; - gws[2] *= pd()->dst_md()->dims[0]; + gws[2] *= pd()->desc()->batch(); auto nd_range = compute::nd_range_t(gws, lws); - s = parallel_for(ctx, nd_range, kernel_, arg_list); - if (s != status::success) return s; + CHECK(parallel_for(ctx, nd_range, kernel_, arg_list)); /// postprocessing kernels // will cast dQ/dK/dV to lower precision outputs if needed if (needs_intermediate_dQ) { - const int lws_pp = 256; + static constexpr size_t lws_pp = 256; compute::range_t lws_p = {(size_t)lws_pp, 1, 1}; compute::range_t gws_p = lws_p; gws_p[0] *= utils::div_up(Q * D, lws_pp); gws_p[1] *= pd()->dst_md()->dims[1]; // Q heads - gws_p[2] *= pd()->dst_md()->dims[0]; + gws_p[2] *= pd()->desc()->batch(); compute::kernel_arg_list_t pp; pp.append(diff_q); pp.append(*diff_q_scratch); pp.append((int)(Q * D)); append_offs(pp, qry_off); - s = parallel_for( - ctx, compute::nd_range_t(gws_p, lws_p), postprocess_, pp); - if (s != status::success) return s; + CHECK(parallel_for( + ctx, compute::nd_range_t(gws_p, lws_p), postprocess_, pp)); } if (needs_intermediate_dKV) { - const dim_t 
num_kv_heads = pd()->dst_md()->dims[1] / kv_group_size; - const int lws_pp = 256; + const dim_t num_kv_heads = d->num_kv_heads(); + static constexpr size_t lws_pp = 256; compute::range_t lws_p = {(size_t)lws_pp, 1, 1}; // dK @@ -1835,32 +1829,30 @@ status_t micro_bwd_t::execute_backward(const exec_ctx_t &ctx) const { compute::range_t gws_p = lws_p; gws_p[0] *= utils::div_up(K * D, lws_pp); gws_p[1] *= num_kv_heads; // KV heads - gws_p[2] *= pd()->dst_md()->dims[0]; + gws_p[2] *= pd()->desc()->batch(); compute::kernel_arg_list_t pp; pp.append(diff_k); pp.append(*diff_k_scratch); pp.append((int)(K * D)); append_offs(pp, key_off); - s = parallel_for( - ctx, compute::nd_range_t(gws_p, lws_p), postprocess_, pp); - if (s != status::success) return s; + CHECK(parallel_for( + ctx, compute::nd_range_t(gws_p, lws_p), postprocess_, pp)); } // dV { compute::range_t gws_p = lws_p; gws_p[0] *= utils::div_up(K * D, lws_pp); gws_p[1] *= num_kv_heads; - gws_p[2] *= pd()->dst_md()->dims[0]; + gws_p[2] *= pd()->desc()->batch(); compute::kernel_arg_list_t pp; pp.append(diff_v); pp.append(*diff_v_scratch); pp.append((int)(K * D)); append_offs(pp, val_off); - s = parallel_for( - ctx, compute::nd_range_t(gws_p, lws_p), postprocess_, pp); - if (s != status::success) return s; + CHECK(parallel_for( + ctx, compute::nd_range_t(gws_p, lws_p), postprocess_, pp)); } } diff --git a/src/gpu/intel/sdpa/micro.hpp b/src/gpu/intel/sdpa/micro.hpp index 31637ca344f..e0d2042b177 100644 --- a/src/gpu/intel/sdpa/micro.hpp +++ b/src/gpu/intel/sdpa/micro.hpp @@ -164,14 +164,14 @@ struct micro_fwd_t : public primitive_t { using namespace data_type; VCHECK_SDPA_COND(is_fwd(), VERBOSE_BAD_PROPKIND); - VCHECK_SDPA_COND( - utils::everyone_is(4, qry_md()->ndims, key_md()->ndims, - val_md()->ndims, dst_md()->ndims), + VCHECK_SDPA_COND(utils::everyone_is(4, desc()->qry_md()->ndims, + desc()->key_md()->ndims, + desc()->val_md()->ndims, dst_md()->ndims), VERBOSE_UNSUPPORTED_TAG); - memory_desc_wrapper 
qry_mdw(qry_md()); - memory_desc_wrapper key_mdw(key_md()); - memory_desc_wrapper val_mdw(val_md()); + memory_desc_wrapper qry_mdw(desc()->qry_md()); + memory_desc_wrapper key_mdw(desc()->key_md()); + memory_desc_wrapper val_mdw(desc()->val_md()); memory_desc_wrapper dst_mdw(dst_md()); VCHECK_SDPA_COND(utils::everyone_is(true, qry_mdw.is_plain(), key_mdw.is_plain(), val_mdw.is_plain(), @@ -179,72 +179,77 @@ struct micro_fwd_t : public primitive_t { VERBOSE_UNSUPPORTED_TAG); if (with_attn_mask()) { + VCHECK_SDPA_COND(desc()->attn_mask_md()->ndims == 4, + VERBOSE_UNSUPPORTED_TAG); VCHECK_SDPA_COND( - attn_mask_md()->ndims == 4, VERBOSE_UNSUPPORTED_TAG); - VCHECK_SDPA_COND( - utils::one_of(attn_mask_md()->dims[mask_q_index], + utils::one_of( + desc()->attn_mask_md()->dims[mask_q_index], desc()->queries(), 1), VERBOSE_INVALID_BROADCAST, "attn_mask", mask_q_index); - VCHECK_SDPA_COND( - attn_mask_md()->dims[mask_k_index] == desc()->keys(), + VCHECK_SDPA_COND(desc()->attn_mask_md()->dims[mask_k_index] + == desc()->keys(), VERBOSE_INVALID_BROADCAST, "attn_mask", mask_k_index); - if (qry_md()->data_type == data_type::f32) { - VCHECK_SDPA_COND( - attn_mask_md()->data_type == qry_md()->data_type, + if (desc()->qry_md()->data_type == data_type::f32) { + VCHECK_SDPA_COND(desc()->attn_mask_md()->data_type + == desc()->qry_md()->data_type, "Mask data type(%s) should match Qry/Dst data " "type(%s).", - dnnl_dt2str(attn_mask_md()->data_type), - dnnl_dt2str(qry_md()->data_type)); + dnnl_dt2str(desc()->attn_mask_md()->data_type), + dnnl_dt2str(desc()->qry_md()->data_type)); } else { - VCHECK_SDPA_COND( - (attn_mask_md()->data_type == qry_md()->data_type) - || (attn_mask_md()->data_type + VCHECK_SDPA_COND((desc()->attn_mask_md()->data_type + == desc()->qry_md()->data_type) + || (desc()->attn_mask_md()->data_type == data_type::f32), "Mask data type(%s) should be xf16 or f32 when " "Qry/Dst(%s) is xf16.", - dnnl_dt2str(attn_mask_md()->data_type), - dnnl_dt2str(qry_md()->data_type)); + 
dnnl_dt2str(desc()->attn_mask_md()->data_type), + dnnl_dt2str(desc()->qry_md()->data_type)); } } VCHECK_SDPA_COND( - (utils::everyone_is(data_type::f16, qry_md()->data_type, - dst_md()->data_type) + (utils::everyone_is(data_type::f16, + desc()->qry_md()->data_type, dst_md()->data_type) || utils::everyone_is(data_type::bf16, - qry_md()->data_type, dst_md()->data_type) + desc()->qry_md()->data_type, + dst_md()->data_type) || utils::everyone_is(data_type::f32, - qry_md()->data_type, dst_md()->data_type)), + desc()->qry_md()->data_type, + dst_md()->data_type)), VERBOSE_UNSUPPORTED_DT); - VCHECK_SDPA_COND(utils::one_of(key_md()->data_type, f32, bf16, f16, - u8, s8, u4, s4), + VCHECK_SDPA_COND(utils::one_of(desc()->key_md()->data_type, f32, + bf16, f16, u8, s8, u4, s4), VERBOSE_UNSUPPORTED_DT); - VCHECK_SDPA_COND(utils::one_of(val_md()->data_type, f32, bf16, f16, - u8, s8, u4, s4), + VCHECK_SDPA_COND(utils::one_of(desc()->val_md()->data_type, f32, + bf16, f16, u8, s8, u4, s4), VERBOSE_UNSUPPORTED_DT); VCHECK_SDPA_COND(set_default_formats() == status::success, VERBOSE_UNSUPPORTED_TAG); VCHECK_SDPA_COND(desc()->values() == desc()->head_size(), "values does not match head size"); - if (utils::one_of(key_md()->data_type, u4, s4)) { + if (utils::one_of(desc()->key_md()->data_type, u4, s4)) { VCHECK_SDPA_COND(desc()->keys() % 2 == 0, "The number of keys must be an even size with the data " "type is u4 or s4."); } - if (utils::one_of(val_md()->data_type, u4, s4)) { + if (utils::one_of(desc()->val_md()->data_type, u4, s4)) { VCHECK_SDPA_COND(desc()->values() % 2 == 0, "The number of values must be an even size with the " "data type is u4 or s4."); } - VCHECK_SDPA_COND(qry_md()->dims[1] >= key_md()->dims[1] - && qry_md()->dims[1] >= val_md()->dims[1], + VCHECK_SDPA_COND( + desc()->qry_md()->dims[1] >= desc()->key_md()->dims[1] + && desc()->qry_md()->dims[1] + >= desc()->val_md()->dims[1], "number of heads in query tensor(%ld) must be greater " "than the number of heads in the 
key(%ld) and value(%ld) " "tensors", - static_cast(qry_md()->dims[1]), - static_cast(key_md()->dims[1]), - static_cast(val_md()->dims[1])); + static_cast(desc()->qry_md()->dims[1]), + static_cast(desc()->key_md()->dims[1]), + static_cast(desc()->val_md()->dims[1])); VCHECK_SDPA_COND(utils::one_of(kq_acc_dt(), f16, f32), "KQ accumulation data type should be f16 or f32"); @@ -302,34 +307,23 @@ struct micro_fwd_t : public primitive_t { int vgs = value_group_size(); VCHECK_SDPA_COND(utils::one_of(vs_scales_mask, 0, 1, 3) || (math::is_pow2(vgs) - || vgs == val_md()->dims[3]), + || vgs == desc()->val_md()->dims[3]), "the value group size(%d) must be a power of 2 or " "equal to the number of values(%ld).", - vgs, static_cast(val_md()->dims[3])); + vgs, static_cast(desc()->val_md()->dims[3])); } CHECK(init_conf_microkernels(engine)); CHECK(init_conf(engine)); - VCHECK_SDPA_COND( - IMPLICATION((arch() == compute::gpu_arch_t::xe_hpc) - && (qry_md()->data_type == data_type::f32), - with_causal_mask()), + VCHECK_SDPA_COND(IMPLICATION((arch() == compute::gpu_arch_t::xe_hpc) + && (desc()->qry_md()->data_type + == data_type::f32), + with_causal_mask()), "fused f32 SDPA only optimized for causal mask"); //TODO: update when performance improved return status::success; } - status_t set_default_format(memory_desc_t &md, bool allow_transpose) { - using namespace format_tag; - memory_desc_wrapper mdw(md); - if (mdw.format_any()) return status::unimplemented; - if (!is_md_gemm_compatible_plain_format(&md)) - return status::unimplemented; - if (gemm_desc_t::get_trans(md) == dnnl_trans && !allow_transpose) - return status::unimplemented; - return status::success; - } - status_t set_default_formats() { CHECK(set_default_format(desc_.q_desc, false)); CHECK(set_default_format(desc_.k_desc, true)); @@ -359,6 +353,19 @@ struct micro_fwd_t : public primitive_t { status_t init_conf_microkernels(impl::engine_t *engine); status_t init_conf(impl::engine_t *engine); + + status_t 
set_default_format(memory_desc_t &md, bool allow_transpose) { + using namespace format_tag; + memory_desc_wrapper mdw(md); + VCHECK_SDPA_UNIMPL(!mdw.format_any(), VERBOSE_UNSUPPORTED_TAG); + VCHECK_SDPA_UNIMPL(is_md_gemm_compatible_plain_format(&md), + VERBOSE_UNSUPPORTED_TAG); + VCHECK_SDPA_UNIMPL( + IMPLICATION(gemm_desc_t::get_trans(md) == dnnl_trans, + allow_transpose), + VERBOSE_UNSUPPORTED_TAG); + return status::success; + } }; status_t init(impl::engine_t *engine) override; @@ -387,57 +394,65 @@ struct micro_bwd_t : public primitive_t { VCHECK_SDPA_COND(!is_fwd(), VERBOSE_BAD_PROPKIND); - VCHECK_SDPA_COND( - utils::everyone_is(4, qry_md()->ndims, key_md()->ndims, - val_md()->ndims, dst_md()->ndims), + VCHECK_SDPA_COND(utils::everyone_is(4, desc()->qry_md()->ndims, + desc()->key_md()->ndims, + desc()->val_md()->ndims, dst_md()->ndims), VERBOSE_UNSUPPORTED_TAG); - VCHECK_SDPA_COND(utils::everyone_is(4, diff_qry_md()->ndims, - diff_key_md()->ndims, diff_val_md()->ndims, - diff_dst_md()->ndims), + VCHECK_SDPA_COND( + utils::everyone_is(4, desc()->diff_qry_md()->ndims, + desc()->diff_key_md()->ndims, + desc()->diff_val_md()->ndims, diff_dst_md()->ndims), VERBOSE_UNSUPPORTED_TAG); if (with_attn_mask()) { + VCHECK_SDPA_COND(desc()->attn_mask_md()->ndims == 4, + VERBOSE_UNSUPPORTED_TAG); VCHECK_SDPA_COND( - attn_mask_md()->ndims == 4, VERBOSE_UNSUPPORTED_TAG); - VCHECK_SDPA_COND( - utils::one_of(attn_mask_md()->dims[mask_q_index], + utils::one_of( + desc()->attn_mask_md()->dims[mask_q_index], desc()->queries(), 1), VERBOSE_INVALID_BROADCAST, "attn_mask", mask_q_index); - VCHECK_SDPA_COND( - attn_mask_md()->dims[mask_k_index] == desc()->keys(), + VCHECK_SDPA_COND(desc()->attn_mask_md()->dims[mask_k_index] + == desc()->keys(), VERBOSE_INVALID_BROADCAST, "attn_mask", mask_k_index); - VCHECK_SDPA_COND( - attn_mask_md()->data_type == qry_md()->data_type, + VCHECK_SDPA_COND(desc()->attn_mask_md()->data_type + == desc()->qry_md()->data_type, "Mask data type should match 
Qry/Dst data type."); } VCHECK_SDPA_COND( - (utils::everyone_is(data_type::f16, qry_md()->data_type, - dst_md()->data_type) + (utils::everyone_is(data_type::f16, + desc()->qry_md()->data_type, dst_md()->data_type) || utils::everyone_is(data_type::bf16, - qry_md()->data_type, dst_md()->data_type) + desc()->qry_md()->data_type, + dst_md()->data_type) || utils::everyone_is(data_type::f32, - qry_md()->data_type, dst_md()->data_type)), + desc()->qry_md()->data_type, + dst_md()->data_type)), VERBOSE_UNSUPPORTED_DT); - VCHECK_SDPA_COND(utils::one_of(key_md()->data_type, f32, bf16, f16), + VCHECK_SDPA_COND( + utils::one_of(desc()->key_md()->data_type, f32, bf16, f16), VERBOSE_UNSUPPORTED_DT); - VCHECK_SDPA_COND(utils::one_of(val_md()->data_type, f32, bf16, f16), + VCHECK_SDPA_COND( + utils::one_of(desc()->val_md()->data_type, f32, bf16, f16), VERBOSE_UNSUPPORTED_DT); VCHECK_SDPA_COND(set_default_formats() == status::success, VERBOSE_UNSUPPORTED_TAG); VCHECK_SDPA_COND(desc()->values() == desc()->head_size(), "values does not match head size"); - VCHECK_SDPA_COND(qry_md()->dims[1] >= key_md()->dims[1] - && qry_md()->dims[1] >= val_md()->dims[1], + VCHECK_SDPA_COND( + desc()->qry_md()->dims[1] >= desc()->key_md()->dims[1] + && desc()->qry_md()->dims[1] + >= desc()->val_md()->dims[1], "number of heads in query tensor(%ld) must be greater " "than the number of heads in the key(%ld) and value(%ld) " "tensors", - static_cast(qry_md()->dims[1]), - static_cast(key_md()->dims[1]), - static_cast(val_md()->dims[1])); + static_cast(desc()->qry_md()->dims[1]), + static_cast(desc()->key_md()->dims[1]), + static_cast(desc()->val_md()->dims[1])); { - memory_desc_wrapper diff_qry_mdw(diff_qry_md()); - memory_desc_wrapper diff_key_mdw(diff_key_md()); - memory_desc_wrapper diff_val_mdw(diff_val_md()); + memory_desc_wrapper diff_qry_mdw(desc()->diff_qry_md()); + memory_desc_wrapper diff_key_mdw(desc()->diff_key_md()); + memory_desc_wrapper diff_val_mdw(desc()->diff_val_md()); 
memory_desc_wrapper diff_dst_mdw(diff_dst_md()); VCHECK_SDPA_COND( utils::everyone_is(true, diff_qry_mdw.is_plain(), @@ -448,21 +463,24 @@ struct micro_bwd_t : public primitive_t { } // make sure gradient outputs match input dimensions - for (int i = 0; i < qry_md()->ndims; i++) - VCHECK_SDPA_COND(diff_qry_md()->dims[i] == qry_md()->dims[i], + for (int i = 0; i < desc()->qry_md()->ndims; i++) + VCHECK_SDPA_COND(desc()->diff_qry_md()->dims[i] + == desc()->qry_md()->dims[i], "diff_qry dim[%d](%ld) must match qry dim[%d](%ld)", i, - (long)diff_qry_md()->dims[i], i, - (long)qry_md()->dims[i]); - for (int i = 0; i < key_md()->ndims; i++) - VCHECK_SDPA_COND(diff_key_md()->dims[i] == key_md()->dims[i], + (long)desc()->diff_qry_md()->dims[i], i, + (long)desc()->qry_md()->dims[i]); + for (int i = 0; i < desc()->key_md()->ndims; i++) + VCHECK_SDPA_COND(desc()->diff_key_md()->dims[i] + == desc()->key_md()->dims[i], "diff_key dim[%d](%ld) must match key dim[%d](%ld)", i, - (long)diff_key_md()->dims[i], i, - (long)key_md()->dims[i]); - for (int i = 0; i < val_md()->ndims; i++) - VCHECK_SDPA_COND(diff_val_md()->dims[i] == val_md()->dims[i], + (long)desc()->diff_key_md()->dims[i], i, + (long)desc()->key_md()->dims[i]); + for (int i = 0; i < desc()->val_md()->ndims; i++) + VCHECK_SDPA_COND(desc()->diff_val_md()->dims[i] + == desc()->val_md()->dims[i], "diff_val dim[%d](%ld) must match val dim[%d](%ld)", i, - (long)diff_val_md()->dims[i], i, - (long)val_md()->dims[i]); + (long)desc()->diff_val_md()->dims[i], i, + (long)desc()->val_md()->dims[i]); // dO.dims() == O.dims() for (int i = 0; i < src_md(4)->ndims; i++) VCHECK_SDPA_COND(diff_dst_md()->dims[i] == src_md(4)->dims[i], @@ -470,16 +488,17 @@ struct micro_bwd_t : public primitive_t { (long)diff_dst_md()->dims[i], i, (long)src_md(4)->dims[i]); - VCHECK_SDPA_COND( - utils::everyone_is(qry_md()->data_type, - diff_qry_md()->data_type, diff_key_md()->data_type, - diff_val_md()->data_type, diff_dst_md()->data_type), + 
VCHECK_SDPA_COND(utils::everyone_is(desc()->qry_md()->data_type, + desc()->diff_qry_md()->data_type, + desc()->diff_key_md()->data_type, + desc()->diff_val_md()->data_type, + diff_dst_md()->data_type), "diff tensor data types must match qry data type(%s) " " ?= dQ(%s), dK(%s), dV(%s), dO(%s)", - dnnl_dt2str(qry_md()->data_type), - dnnl_dt2str(diff_qry_md()->data_type), - dnnl_dt2str(diff_key_md()->data_type), - dnnl_dt2str(diff_val_md()->data_type), + dnnl_dt2str(desc()->qry_md()->data_type), + dnnl_dt2str(desc()->diff_qry_md()->data_type), + dnnl_dt2str(desc()->diff_key_md()->data_type), + dnnl_dt2str(desc()->diff_val_md()->data_type), dnnl_dt2str(diff_dst_md()->data_type)); CHECK(init_default_ws()); @@ -492,17 +511,6 @@ struct micro_bwd_t : public primitive_t { return status::success; } - status_t set_default_format(memory_desc_t &md, bool allow_transpose) { - using namespace format_tag; - memory_desc_wrapper mdw(md); - if (mdw.format_any()) return status::unimplemented; - if (!is_md_gemm_compatible_plain_format(&md)) - return status::unimplemented; - if (gemm_desc_t::get_trans(md) == dnnl_trans && !allow_transpose) - return status::unimplemented; - return status::success; - } - status_t set_default_formats() { CHECK(set_default_format(desc_.q_desc, false)); CHECK(set_default_format(desc_.k_desc, true)); @@ -537,6 +545,19 @@ struct micro_bwd_t : public primitive_t { status_t init_scratchpad(impl::engine_t *engine); status_t init_conf_microkernels(impl::engine_t *engine); status_t init_conf(impl::engine_t *engine); + + status_t set_default_format(memory_desc_t &md, bool allow_transpose) { + using namespace format_tag; + memory_desc_wrapper mdw(md); + VCHECK_SDPA_UNIMPL(!mdw.format_any(), VERBOSE_UNSUPPORTED_TAG); + VCHECK_SDPA_UNIMPL(is_md_gemm_compatible_plain_format(&md), + VERBOSE_UNSUPPORTED_TAG); + VCHECK_SDPA_UNIMPL( + IMPLICATION(gemm_desc_t::get_trans(md) == dnnl_trans, + allow_transpose), + VERBOSE_UNSUPPORTED_TAG); + return status::success; + } }; 
status_t init(impl::engine_t *engine) override; diff --git a/src/gpu/intel/sdpa/ref.hpp b/src/gpu/intel/sdpa/ref.hpp index bf686117e3c..34ad410738b 100644 --- a/src/gpu/intel/sdpa/ref.hpp +++ b/src/gpu/intel/sdpa/ref.hpp @@ -44,13 +44,13 @@ struct ref_fwd_t : public primitive_t { VDISPATCH_SDPA(attr()->has_default_values(smask_t::scales), VERBOSE_UNSUPPORTED_ATTR); - VDISPATCH_SDPA( - utils::everyone_is(4, qry_md()->ndims, key_md()->ndims, - val_md()->ndims, dst_md()->ndims), + VDISPATCH_SDPA(utils::everyone_is(4, desc()->qry_md()->ndims, + desc()->key_md()->ndims, + desc()->val_md()->ndims, dst_md()->ndims), VERBOSE_UNSUPPORTED_TAG); if (with_attn_mask()) { - VDISPATCH_SDPA( - attn_mask_md()->ndims == 4, VERBOSE_UNSUPPORTED_TAG); + VDISPATCH_SDPA(desc()->attn_mask_md()->ndims == 4, + VERBOSE_UNSUPPORTED_TAG); } VDISPATCH_SDPA(set_default_formats(), VERBOSE_UNSUPPORTED_TAG); @@ -66,11 +66,11 @@ struct ref_fwd_t : public primitive_t { int ndims = 4; - const memory_desc_wrapper qry_mdw(pd()->qry_md()); - const memory_desc_wrapper key_mdw(pd()->key_md()); - const memory_desc_wrapper val_mdw(pd()->val_md()); + const memory_desc_wrapper qry_mdw(pd()->desc()->qry_md()); + const memory_desc_wrapper key_mdw(pd()->desc()->key_md()); + const memory_desc_wrapper val_mdw(pd()->desc()->val_md()); const memory_desc_wrapper dst_mdw(pd()->dst_md()); - const memory_desc_wrapper msk_mdw(pd()->attn_mask_md()); + const memory_desc_wrapper msk_mdw(pd()->desc()->attn_mask_md()); using offset_t = decltype(offsets_t().src_off); offset_t qry_off, key_off, val_off, dst_off, msk_off; set_offsets(qry_mdw, qry_off); @@ -90,12 +90,13 @@ struct ref_fwd_t : public primitive_t { kernel_ctx.define_int("WITH_ATTN_SCALE", pd()->with_attn_scale()); kernel_ctx.define_int("WITH_ATTN_MASK", pd()->with_attn_mask()); - def_data_type(kernel_ctx, pd()->qry_md()->data_type, "QRY"); - def_data_type(kernel_ctx, pd()->key_md()->data_type, "KEY"); - def_data_type(kernel_ctx, pd()->val_md()->data_type, "VAL"); + 
def_data_type(kernel_ctx, pd()->desc()->qry_md()->data_type, "QRY"); + def_data_type(kernel_ctx, pd()->desc()->key_md()->data_type, "KEY"); + def_data_type(kernel_ctx, pd()->desc()->val_md()->data_type, "VAL"); def_data_type(kernel_ctx, pd()->dst_md()->data_type, "DST"); - def_data_type(kernel_ctx, pd()->attn_mask_md()->data_type, "MSK"); - def_data_type(kernel_ctx, pd()->scale_md()->data_type, "SCALE"); + def_data_type( + kernel_ctx, pd()->desc()->attn_mask_md()->data_type, "MSK"); + def_data_type(kernel_ctx, pd()->desc()->scale_md()->data_type, "SCALE"); CHECK(create_kernel(engine, &kernel_, "ref_sdpa", kernel_ctx)); if (!kernel_) return status::runtime_error; return status::success; diff --git a/tests/gtests/internals/sdpa_internal.hpp b/tests/gtests/internals/sdpa_internal.hpp index 1ccef3d647f..cbd388802a8 100644 --- a/tests/gtests/internals/sdpa_internal.hpp +++ b/tests/gtests/internals/sdpa_internal.hpp @@ -58,8 +58,6 @@ dnnl_status_t DNNL_API sdpa_primitive_desc_create( const_dnnl_memory_desc_t scale_desc, bool invert_scale, dnnl_dim_t kv_head_number, int attn_mask_type, dnnl_alg_kind_t softmax_alg, const_dnnl_primitive_attr_t attr, - const_dnnl_primitive_attr_t kq_attr, - const_dnnl_primitive_attr_t vs_attr, const dnnl_primitive_desc *hint_fwd_pd); namespace dnnl { @@ -127,20 +125,19 @@ struct sdpa_backward : public dnnl::primitive { const memory::desc *dS_desc, bool invert_scale, memory::dim kv_head_number, int attn_mask_type, int softmax_alg, const sdpa::primitive_desc &hint_fwd_pd, - const primitive_attr &attr = default_attr(), - const primitive_attr &kq_attr = default_attr(), - const primitive_attr &vs_attr = default_attr()) { + const primitive_attr &attr = default_attr()) { dnnl_primitive_desc_t pd = nullptr; dnnl_status_t status = sdpa_primitive_desc_create(&pd, aengine.get(), query_desc.get(), key_desc.get(), - value_desc.get(), output_desc.get(), diff_query_desc.get(), - diff_key_desc.get(), diff_value_desc.get(), - diff_output_desc.get(), dS_desc 
? dS_desc->get() : nullptr, + value_desc.get(), output_desc.get(), optional_arg(attn_mask_desc), scale_desc.get(), - invert_scale, kv_head_number, attn_mask_type, - (dnnl_alg_kind_t)softmax_alg, attr.get(), kq_attr.get(), - vs_attr.get(), hint_fwd_pd.get()); + diff_query_desc.get(), diff_key_desc.get(), + diff_value_desc.get(), diff_output_desc.get(), + dS_desc ? dS_desc->get() : nullptr, invert_scale, + kv_head_number, attn_mask_type, + (dnnl_alg_kind_t)softmax_alg, attr.get(), + hint_fwd_pd.get()); dnnl::error::wrap_c_api(status, "could not create a primitive descriptor for a sdpa " diff --git a/tests/gtests/internals/test_sdpa.cpp b/tests/gtests/internals/test_sdpa.cpp index 98f169a0482..c9ec35c3654 100644 --- a/tests/gtests/internals/test_sdpa.cpp +++ b/tests/gtests/internals/test_sdpa.cpp @@ -2139,8 +2139,7 @@ class sdpa_test_t : public ::testing::TestWithParam { t.m_diff_output.get_desc(), dS_ptr, invert_scale, p.heads.kv, to_attn_mask_type(p.mask.type), dnnl::impl::alg_kind::softmax_accurate_inf_as_zero, - sdpa_fwd_pd, t.sdpa_attr_quantized, - t.sdpa_kq_attr_quantized, t.sdpa_vs_attr_quantized); + sdpa_fwd_pd, t.sdpa_attr_quantized); sdpa_bwd = sdpa_backward(sdpa_bwd_pd); } catch (const dnnl::error &e) { if (e.status == dnnl_unimplemented) @@ -2456,8 +2455,7 @@ class sdpa_test_t : public ::testing::TestWithParam { t.m_diff_output.get_desc(), nullptr, invert_scale, p.heads.kv, to_attn_mask_type(p.mask.type), dnnl::impl::alg_kind::softmax_accurate_inf_as_zero, - sdpa_fwd_pd, t.sdpa_attr_quantized, - t.sdpa_kq_attr_quantized, t.sdpa_vs_attr_quantized); + sdpa_fwd_pd, t.sdpa_attr_quantized); sdpa_bwd = sdpa_backward(sdpa_bwd_pd); } catch (const dnnl::error &e) { if (e.status == dnnl_unimplemented) @@ -2914,9 +2912,11 @@ GPU_TEST_P(sdpa_test_datatypes, compare) { compare(); } +/* GPU_TEST_P(sdpa_bwd_test, compare_bwd) { compare_bwd(); } +*/ GPU_TEST_P(sdpa_bwd_test_datatypes, compare_bwd) { compare_bwd(); @@ -2938,10 +2938,9 @@ GPU_TEST_P(sdpa_bwd_test, perf_bwd) 
{ // backward pass: f16 INSTANTIATE_TEST_SUITE_P(bwd_f16, sdpa_bwd_test_datatypes, testing::Combine(testing::Values(1, 2), // mb - testing::Values(num_heads_t {1, 1}, num_heads_t {2, 2}, - num_heads_t {8, 8}), // heads - testing::Values(seq_len_size_t {32, 32}, seq_len_size_t {64, 64}, - seq_len_size_t {384, 384}), // seq_len + testing::Values(num_heads_t {1, 1}, num_heads_t {2, 2}), // heads + testing::Values(seq_len_size_t {64, 64}, + seq_len_size_t {1024, 1024}), // seq_len testing::Values(head_group_size_t {32, 32, 32}, head_group_size_t {64, 64, 64}, head_group_size_t {128, 128, 128}), // head_size @@ -3012,8 +3011,7 @@ INSTANTIATE_TEST_SUITE_P(bwd_gqa, sdpa_bwd_test_datatypes, // backward pass: non-uniform sequence lengths (q != kv) INSTANTIATE_TEST_SUITE_P(bwd_nonuniform_seq, sdpa_bwd_test_datatypes, testing::Combine(testing::Values(1), // mb - testing::Values(num_heads_t {1, 1}, - num_heads_t {2, 2}), // heads + testing::Values(num_heads_t {2, 2}), // heads testing::Values(seq_len_size_t {64, 513}, seq_len_size_t {513, 64}), testing::Values(head_group_size_t {32, 32, 32}, @@ -3034,12 +3032,11 @@ INSTANTIATE_TEST_SUITE_P(bwd_nonuniform_seq, sdpa_bwd_test_datatypes, // backward pass: f32 INSTANTIATE_TEST_SUITE_P(bwd_f32, sdpa_bwd_test_datatypes, testing::Combine(testing::Values(1, 2), // mb - testing::Values(num_heads_t {1, 1}, num_heads_t {2, 2}, - num_heads_t {12, 12}), // heads - testing::Values(seq_len_size_t {32, 32}, seq_len_size_t {64, 64}, + testing::Values(num_heads_t {1, 1}, num_heads_t {2, 2}), // heads + testing::Values(seq_len_size_t {32, 32}, seq_len_size_t {384, 384}, seq_len_size_t {4096, 4096}), // seq_len - testing::Values(head_group_size_t {16, 16, 16}, + testing::Values( head_group_size_t {32, 32, 32}, head_group_size_t {64, 64, 64}, head_group_size_t {128, 128, 128}), // head_size From 65748a1c6c8a5c47e715a2906d0b446ea44f848a Mon Sep 17 00:00:00 2001 From: "Bao, Yixin" Date: Wed, 25 Feb 2026 19:31:11 -0800 Subject: [PATCH 18/23] graph: 
backend: dnnl: move softmax decomposition to a transform pass --- .../backend/dnnl/kernels/large_partition.cpp | 1 + src/graph/backend/dnnl/passes/lower.cpp | 185 +---------------- src/graph/backend/dnnl/passes/transform.cpp | 195 ++++++++++++++++++ src/graph/backend/dnnl/passes/transform.hpp | 6 + src/graph/interface/op_def.hpp | 12 +- src/graph/interface/shape_infer.cpp | 51 +++++ src/graph/interface/shape_infer.hpp | 4 + 7 files changed, 273 insertions(+), 181 deletions(-) diff --git a/src/graph/backend/dnnl/kernels/large_partition.cpp b/src/graph/backend/dnnl/kernels/large_partition.cpp index 38d57938e1d..848406b67d9 100644 --- a/src/graph/backend/dnnl/kernels/large_partition.cpp +++ b/src/graph/backend/dnnl/kernels/large_partition.cpp @@ -45,6 +45,7 @@ void larger_partition_kernel_t::setup_pipeline_stage1( BACKEND_DNNL_ADD_PASS(pipeline, fuse_mul_sigmoid_to_swish); BACKEND_DNNL_ADD_PASS(pipeline, fuse_to_dnnl_sum); BACKEND_DNNL_ADD_PASS(pipeline, fuse_to_shuffle); + BACKEND_DNNL_ADD_PASS(pipeline, decompose_softmax); // TODO(xx) The implementation of these two passes relay on a non-fully // lowered subgraph. We need to improve them. 
diff --git a/src/graph/backend/dnnl/passes/lower.cpp b/src/graph/backend/dnnl/passes/lower.cpp index 2aaeecc4c4e..09f17e888c9 100644 --- a/src/graph/backend/dnnl/passes/lower.cpp +++ b/src/graph/backend/dnnl/passes/lower.cpp @@ -757,7 +757,6 @@ static status_t softmax_handler( const std::shared_ptr &op, subgraph_rewriter_t &rewriter) { const auto &src = op->get_input_value(0); const auto &dst = op->get_output_value(0); - bool no_stats = op->num_outputs() == 1; auto new_softmax_op = std::make_shared(op_kind::_softmax); new_softmax_op->merge_attributes(op->get_attributes()); @@ -765,185 +764,15 @@ static status_t softmax_handler( src->remove_consumer(*op, 0); src->add_consumer(*new_softmax_op, 0); new_softmax_op->add_input(src); - if (no_stats) { - new_softmax_op->add_output(dst); - insert_empty_scratchpad(new_softmax_op); - rewriter.to_insert(new_softmax_op); - rewriter.to_remove(op); - return status::success; + new_softmax_op->add_output(dst); + insert_empty_scratchpad(new_softmax_op); + if (op->num_outputs() == 2) { + const auto &stats = op->get_output_value(1); + new_softmax_op->add_output(stats); } - auto f32_dst = dst; - if (f32_dst->get_logical_tensor().data_type == impl::data_type::f32) { - // if the dst is already f32, we can just use it as the output - new_softmax_op->add_output(dst); - dst->remove_consumer(*op, 0); - insert_empty_scratchpad(new_softmax_op); - rewriter.to_insert(new_softmax_op); - rewriter.to_remove(op); - } else { - logical_tensor_t softmax_op_out_lt - = empty_logical_tensor_with_default_id(); - f32_dst = std::make_shared( - *new_softmax_op, 0, softmax_op_out_lt, true); - f32_dst->set_data_type(impl::data_type::f32); - new_softmax_op->add_output(f32_dst); - insert_empty_scratchpad(new_softmax_op); - - // create reorder op to convert the output to the original data type - auto reorder_op = std::make_shared(op_kind::_reorder); - reorder_op->set_attr(op_attr::change_layout, false); - reorder_op->add_input(f32_dst); - 
f32_dst->add_consumer(*reorder_op, 0); - reorder_op->add_output(dst); - dst->remove_consumer(*op, 0); - insert_empty_scratchpad(reorder_op); - rewriter.to_insert(new_softmax_op); - rewriter.to_insert(reorder_op); - rewriter.to_remove(op); - } - - // support stats computation: stats = reducemax(src) - log(reducemax(f32_dst)) - const auto &stats = op->get_output_value(1); - // reduction primitive doesn't support identity operation. - // check if reduce ops are needed before creating them. - // if the dims[axis] = 1, no need to add reduce ops. - bool need_reduction = true; - int64_t axis = new_softmax_op->get_attr(op_attr::axis); - axis = axis < 0 ? axis + src->get_logical_tensor().ndims : axis; - if (src->get_logical_tensor().dims[axis] == 1) { need_reduction = false; } - - auto reduce_src_op_out_val = src; - auto reduce_dst_op_out_val = f32_dst; - if (need_reduction) { - // create reduce_src op - auto reduce_src_op = std::make_shared(op_kind::_reduction); - reduce_src_op->set_attr>(op_attr::axes, - {new_softmax_op->get_attr(op_attr::axis)}); - reduce_src_op->set_attr(op_attr::keep_dims, true); - reduce_src_op->set_attr(op_attr::alg_kind, - static_cast(dnnl::algorithm::reduction_max)); - reduce_src_op->add_input(src); - src->add_consumer(*reduce_src_op, 0); - // add output for reduce_src - logical_tensor_t reduce_src_op_out_lt - = empty_logical_tensor_with_default_id(); - reduce_src_op_out_val = std::make_shared( - *reduce_src_op, 0, reduce_src_op_out_lt, true); - reduce_src_op_out_val->set_data_type(impl::data_type::f32); - reduce_src_op->add_output(reduce_src_op_out_val); - insert_empty_scratchpad(reduce_src_op); - - // create reduce_dst op - auto reduce_dst_op = std::make_shared(op_kind::_reduction); - reduce_dst_op->set_attr>(op_attr::axes, - {new_softmax_op->get_attr(op_attr::axis)}); - reduce_dst_op->set_attr(op_attr::keep_dims, true); - reduce_dst_op->set_attr(op_attr::alg_kind, - static_cast(dnnl::algorithm::reduction_max)); - 
reduce_dst_op->add_input(f32_dst); - f32_dst->add_consumer(*reduce_dst_op, 0); - // add output for reduce_dst - logical_tensor_t reduce_dst_op_out_lt - = empty_logical_tensor_with_default_id(); - reduce_dst_op_out_val = std::make_shared( - *reduce_dst_op, 0, reduce_dst_op_out_lt, true); - reduce_dst_op_out_val->set_data_type(impl::data_type::f32); - reduce_dst_op->add_output(reduce_dst_op_out_val); - insert_empty_scratchpad(reduce_dst_op); - - rewriter.to_insert(reduce_src_op); - rewriter.to_insert(reduce_dst_op); - } - - // create log op - auto log_op = std::make_shared(op_kind::_eltwise); - log_op->set_attr(op_attr::alg_kind, - static_cast(dnnl::algorithm::eltwise_log)); - log_op->add_input(reduce_dst_op_out_val); - reduce_dst_op_out_val->add_consumer(*log_op, 0); - // add output for log_op - logical_tensor_t log_op_out_lt = empty_logical_tensor_with_default_id(); - auto log_op_out_val - = std::make_shared(*log_op, 0, log_op_out_lt, true); - log_op_out_val->set_data_type(impl::data_type::f32); - log_op->add_output(log_op_out_val); - insert_empty_scratchpad(log_op); - - // create subtract op - auto sub_op = std::make_shared(op_kind::_binary); - sub_op->set_attr(op_attr::alg_kind, - static_cast(dnnl::algorithm::binary_sub)); - sub_op->add_input(reduce_src_op_out_val); - reduce_src_op_out_val->add_consumer(*sub_op, 0); - sub_op->add_input(log_op_out_val); - log_op_out_val->add_consumer(*sub_op, 1); - // add output for sub_op - logical_tensor_t sub_op_out_lt = empty_logical_tensor_with_default_id(); - auto sub_op_out_val - = std::make_shared(*sub_op, 0, sub_op_out_lt, true); - sub_op_out_val->set_data_type(impl::data_type::f32); - sub_op->add_output(sub_op_out_val); - insert_empty_scratchpad(sub_op); - - // special handling for inf_as_zero: - // stats = reducesum(f32_dst) == 0? 
0: stats - // create reduce_sum_dst op - auto reduce_or_reorder_op_out_val = f32_dst; - if (need_reduction) { - auto reduce_sum_dst_op = std::make_shared(op_kind::_reduction); - reduce_sum_dst_op->set_attr>(op_attr::axes, - {new_softmax_op->get_attr(op_attr::axis)}); - reduce_sum_dst_op->set_attr(op_attr::keep_dims, true); - reduce_sum_dst_op->set_attr(op_attr::alg_kind, - static_cast(dnnl::algorithm::reduction_sum)); - reduce_sum_dst_op->add_input(f32_dst); - f32_dst->add_consumer(*reduce_sum_dst_op, 0); - // add output for reduce_sum_dst - logical_tensor_t reduce_sum_dst_op_out_lt - = empty_logical_tensor_with_default_id(); - reduce_or_reorder_op_out_val = std::make_shared( - *reduce_sum_dst_op, 0, reduce_sum_dst_op_out_lt, true); - reduce_or_reorder_op_out_val->set_data_type(dnnl::impl::data_type::s8); - reduce_sum_dst_op->add_output(reduce_or_reorder_op_out_val); - insert_empty_scratchpad(reduce_sum_dst_op); - - rewriter.to_insert(reduce_sum_dst_op); - } else { - // create reorder op to convert f32_dst to s8 - auto reorder_s8_op = std::make_shared(op_kind::_reorder); - reorder_s8_op->set_attr(op_attr::change_layout, false); - reorder_s8_op->add_input(f32_dst); - f32_dst->add_consumer(*reorder_s8_op, 0); - // add output for reorder_s8_op - logical_tensor_t reorder_s8_op_out_lt - = empty_logical_tensor_with_default_id(); - reduce_or_reorder_op_out_val = std::make_shared( - *reorder_s8_op, 0, reorder_s8_op_out_lt, true); - reduce_or_reorder_op_out_val->set_data_type(dnnl::impl::data_type::s8); - reorder_s8_op->add_output(reduce_or_reorder_op_out_val); - insert_empty_scratchpad(reorder_s8_op); - rewriter.to_insert(reorder_s8_op); - } - - // create select op - auto select_op = std::make_shared(op_kind::_binary); - select_op->set_attr(op_attr::alg_kind, - static_cast(dnnl::algorithm::binary_select)); - select_op->add_input(sub_op_out_val); - sub_op_out_val->add_consumer(*select_op, 0); - select_op->add_input(reduce_dst_op_out_val); - 
reduce_dst_op_out_val->add_consumer(*select_op, 1); - // condition - select_op->add_input(reduce_or_reorder_op_out_val); - reduce_or_reorder_op_out_val->add_consumer(*select_op, 2); - select_op->add_output(stats); - insert_empty_scratchpad(select_op); - - rewriter.to_insert(log_op); - rewriter.to_insert(sub_op); - rewriter.to_insert(select_op); - + rewriter.to_insert(new_softmax_op); + rewriter.to_remove(op); return status::success; } diff --git a/src/graph/backend/dnnl/passes/transform.cpp b/src/graph/backend/dnnl/passes/transform.cpp index 5aaf0e4584c..7b9cec17618 100644 --- a/src/graph/backend/dnnl/passes/transform.cpp +++ b/src/graph/backend/dnnl/passes/transform.cpp @@ -3235,6 +3235,201 @@ impl::status_t convert_dynamic_quantize_ops(std::shared_ptr &sg) { return impl::status::success; } +status_t decompose_softmax(std::shared_ptr &sg) { + subgraph_rewriter_t rewriter(sg); + + for (auto &cur_op : sg->get_ops()) { + if (cur_op->get_kind() != op_kind::dnnl_softmax) continue; + if (cur_op->num_outputs() != 3) continue; + const auto &dst = cur_op->get_output_value(0); + + auto f32_dst = dst; + if (f32_dst->get_logical_tensor().data_type != impl::data_type::f32) { + logical_tensor_t softmax_op_out_lt + = empty_logical_tensor_with_default_id(); + f32_dst = std::make_shared( + *cur_op, 0, softmax_op_out_lt, true); + f32_dst->set_data_type(impl::data_type::f32); + cur_op->connect_output(0, f32_dst); + + // create reorder op to convert the output to the original data type + auto reorder_op = std::make_shared(op_kind::dnnl_reorder); + reorder_op->set_attr(op_attr::change_layout, false); + reorder_op->add_input(f32_dst); + f32_dst->add_consumer(*reorder_op, 0); + reorder_op->add_output(dst); + dst->remove_consumer(*cur_op, 0); + insert_empty_scratchpad(reorder_op); + rewriter.to_insert(reorder_op); + } + + // support stats computation: stats = reducemax(src) - log(reducemax(f32_dst)) + const auto &stats = cur_op->get_output_value(2); + const auto &src = 
cur_op->get_input_value(0); + // reduction primitive doesn't support identity operation. + // check if reduce ops are needed before creating them. + // if the dims[axis] = 1, no need to add reduce ops. + bool need_reduction = true; + int64_t axis = cur_op->get_attr(op_attr::axis); + axis = axis < 0 ? axis + src->get_logical_tensor().ndims : axis; + if (src->get_logical_tensor().dims[axis] == 1) { + need_reduction = false; + } + + auto reduce_src_op_out_val = src; + auto reduce_dst_op_out_val = f32_dst; + if (need_reduction) { + // create reduce_src op + auto reduce_src_op + = std::make_shared(op_kind::dnnl_reduction); + reduce_src_op->set_attr>( + op_attr::axes, {cur_op->get_attr(op_attr::axis)}); + reduce_src_op->set_attr(op_attr::keep_dims, true); + reduce_src_op->set_attr(op_attr::alg_kind, + static_cast(dnnl::algorithm::reduction_max)); + reduce_src_op->add_input(src); + src->add_consumer(*reduce_src_op, 0); + // add output for reduce_src + logical_tensor_t reduce_src_op_out_lt + = empty_logical_tensor_with_default_id(); + reduce_src_op_out_val = std::make_shared( + *reduce_src_op, 0, reduce_src_op_out_lt, true); + reduce_src_op_out_val->set_data_type(impl::data_type::f32); + reduce_src_op->add_output(reduce_src_op_out_val); + insert_empty_scratchpad(reduce_src_op); + + // create reduce_dst op + auto reduce_dst_op + = std::make_shared(op_kind::dnnl_reduction); + reduce_dst_op->set_attr>( + op_attr::axes, {cur_op->get_attr(op_attr::axis)}); + reduce_dst_op->set_attr(op_attr::keep_dims, true); + reduce_dst_op->set_attr(op_attr::alg_kind, + static_cast(dnnl::algorithm::reduction_max)); + reduce_dst_op->add_input(f32_dst); + f32_dst->add_consumer(*reduce_dst_op, 0); + // add output for reduce_dst + logical_tensor_t reduce_dst_op_out_lt + = empty_logical_tensor_with_default_id(); + reduce_dst_op_out_val = std::make_shared( + *reduce_dst_op, 0, reduce_dst_op_out_lt, true); + reduce_dst_op_out_val->set_data_type(impl::data_type::f32); + 
reduce_dst_op->add_output(reduce_dst_op_out_val); + insert_empty_scratchpad(reduce_dst_op); + + rewriter.to_insert(reduce_src_op); + rewriter.to_insert(reduce_dst_op); + } + + // create log op + auto log_op = std::make_shared(op_kind::dnnl_eltwise); + log_op->set_attr(op_attr::alg_kind, + static_cast(dnnl::algorithm::eltwise_log)); + log_op->add_input(reduce_dst_op_out_val); + reduce_dst_op_out_val->add_consumer(*log_op, 0); + // add output for log_op + logical_tensor_t log_op_out_lt = empty_logical_tensor_with_default_id(); + auto log_op_out_val + = std::make_shared(*log_op, 0, log_op_out_lt, true); + log_op_out_val->set_data_type(impl::data_type::f32); + log_op->add_output(log_op_out_val); + insert_empty_scratchpad(log_op); + + // create subtract op + auto sub_op = std::make_shared(op_kind::dnnl_binary); + sub_op->set_attr(op_attr::alg_kind, + static_cast(dnnl::algorithm::binary_sub)); + sub_op->add_input(reduce_src_op_out_val); + reduce_src_op_out_val->add_consumer(*sub_op, 0); + sub_op->add_input(log_op_out_val); + log_op_out_val->add_consumer(*sub_op, 1); + // add output for sub_op + logical_tensor_t sub_op_out_lt = empty_logical_tensor_with_default_id(); + auto sub_op_out_val + = std::make_shared(*sub_op, 0, sub_op_out_lt, true); + sub_op_out_val->set_data_type(impl::data_type::f32); + sub_op->add_output(sub_op_out_val); + insert_empty_scratchpad(sub_op); + + // special handling for inf_as_zero: + // stats = reducesum(f32_dst) == 0? 
0: stats + // create reduce_sum_dst op + auto reduce_or_reorder_op_out_val = f32_dst; + if (need_reduction) { + auto reduce_sum_dst_op + = std::make_shared(op_kind::dnnl_reduction); + reduce_sum_dst_op->set_attr>( + op_attr::axes, {cur_op->get_attr(op_attr::axis)}); + reduce_sum_dst_op->set_attr(op_attr::keep_dims, true); + reduce_sum_dst_op->set_attr(op_attr::alg_kind, + static_cast(dnnl::algorithm::reduction_sum)); + reduce_sum_dst_op->add_input(f32_dst); + f32_dst->add_consumer(*reduce_sum_dst_op, 0); + // add output for reduce_sum_dst + logical_tensor_t reduce_sum_dst_op_out_lt + = empty_logical_tensor_with_default_id(); + reduce_or_reorder_op_out_val = std::make_shared( + *reduce_sum_dst_op, 0, reduce_sum_dst_op_out_lt, true); + reduce_or_reorder_op_out_val->set_data_type( + dnnl::impl::data_type::s8); + reduce_sum_dst_op->add_output(reduce_or_reorder_op_out_val); + insert_empty_scratchpad(reduce_sum_dst_op); + + rewriter.to_insert(reduce_sum_dst_op); + } else { + // create reorder op to convert f32_dst to s8 + auto reorder_s8_op = std::make_shared(op_kind::dnnl_reorder); + reorder_s8_op->set_attr(op_attr::change_layout, false); + reorder_s8_op->add_input(f32_dst); + f32_dst->add_consumer(*reorder_s8_op, 0); + // add output for reorder_s8_op + logical_tensor_t reorder_s8_op_out_lt + = empty_logical_tensor_with_default_id(); + reduce_or_reorder_op_out_val = std::make_shared( + *reorder_s8_op, 0, reorder_s8_op_out_lt, true); + reduce_or_reorder_op_out_val->set_data_type( + dnnl::impl::data_type::s8); + reorder_s8_op->add_output(reduce_or_reorder_op_out_val); + insert_empty_scratchpad(reorder_s8_op); + rewriter.to_insert(reorder_s8_op); + } + + // create select op + auto select_op = std::make_shared(op_kind::dnnl_binary); + select_op->set_attr(op_attr::alg_kind, + static_cast(dnnl::algorithm::binary_select)); + select_op->add_input(sub_op_out_val); + sub_op_out_val->add_consumer(*select_op, 0); + select_op->add_input(reduce_dst_op_out_val); + 
reduce_dst_op_out_val->add_consumer(*select_op, 1); + // condition + select_op->add_input(reduce_or_reorder_op_out_val); + reduce_or_reorder_op_out_val->add_consumer(*select_op, 2); + select_op->add_output(stats); + insert_empty_scratchpad(select_op); + + rewriter.to_insert(log_op); + rewriter.to_insert(sub_op); + rewriter.to_insert(select_op); + + // recreate dnnl_softmax with 2 outputs: output and scratchpad + auto new_softmax_op = std::make_shared(op_kind::dnnl_softmax); + new_softmax_op->merge_attributes(cur_op->get_attributes()); + + src->remove_consumer(*cur_op, 0); + src->add_consumer(*new_softmax_op, 0); + new_softmax_op->add_input(src); + new_softmax_op->add_output(f32_dst); + f32_dst->set_producer(*new_softmax_op); + insert_empty_scratchpad(new_softmax_op); + rewriter.to_insert(new_softmax_op); + rewriter.to_remove(cur_op); + } + + rewriter.run(); + return infer_shape(sg); +} + status_t reorder_canonicalization(std::shared_ptr &sg) { subgraph_rewriter_t rewriter(sg); diff --git a/src/graph/backend/dnnl/passes/transform.hpp b/src/graph/backend/dnnl/passes/transform.hpp index f2b07754360..b00a3c8f8d8 100644 --- a/src/graph/backend/dnnl/passes/transform.hpp +++ b/src/graph/backend/dnnl/passes/transform.hpp @@ -305,6 +305,12 @@ status_t fuse_sdpa(std::shared_ptr &sg); /// This pass will transform the gated mlp subgraph into a _gated_mlp op. status_t fuse_gated_mlp(std::shared_ptr &sg); +/// This pass will decompose the softmax with stats output into a normal softmax +/// without stats output and some small ops to compute the stats. +/// The main reason for this pass is that the current implementation +/// of softmax primitive doesn't support stats. 
+status_t decompose_softmax(std::shared_ptr &sg); + } // namespace dnnl_impl } // namespace graph } // namespace impl diff --git a/src/graph/interface/op_def.hpp b/src/graph/interface/op_def.hpp index fb069445412..b4dbf3597d0 100644 --- a/src/graph/interface/op_def.hpp +++ b/src/graph/interface/op_def.hpp @@ -2207,11 +2207,13 @@ DNNL_GRAPH_OP_SCHEMA(_matmul, 1, DNNL_GRAPH_OP_SCHEMA(_softmax, 1, op_schema_t() .set_inputs_option(op_schema_t::param_num_option::variadic) + .set_outputs_option(op_schema_t::param_num_option::optional) .set_num_inputs(std::set({1, 32})) - .set_num_outputs(2) + .set_num_outputs(std::set({2, 3})) .set_input(0, "input") .set_output(0, "output") .set_output(1, "scratchpad") + .set_output(2, "stats") // optional // Attributes inherited from SoftMax .set_attr(op_attr::axis, false, attribute_kind::i, (int64_t)1) .set_attr(op_attr::mode, false, attribute_kind::s, "none", @@ -2221,7 +2223,7 @@ DNNL_GRAPH_OP_SCHEMA(_softmax, 1, .set_attr(op_attr::fusion_info, false, attribute_kind::fusion_info) // Analysis rules - .set_shape_inference_function(infer_identity_output_shape)) + .set_shape_inference_function(infer_dnnl_softmax_output_shape)) DNNL_GRAPH_OP_SCHEMA(_logsoftmax, 1, op_schema_t() @@ -2353,8 +2355,9 @@ DNNL_GRAPH_OP_SCHEMA(_mask, 1, DNNL_GRAPH_OP_SCHEMA(_sdpa, 1, op_schema_t() .set_inputs_option(op_schema_t::param_num_option::variadic) + .set_outputs_option(op_schema_t::param_num_option::optional) .set_num_inputs(std::set({3, 32})) - .set_num_outputs(2) + .set_num_outputs(std::set({2, 3})) .set_input(0, "query") .set_input(1, "key") .set_input(2, "value") @@ -2362,11 +2365,14 @@ DNNL_GRAPH_OP_SCHEMA(_sdpa, 1, .set_input(4, "mask") // optional .set_output(0, "output") .set_output(1, "scratchpad") + .set_output(2, + "softmax_stats") // optional, only used for sdpa training .set_attr(op_attr::fusion_info, false, attribute_kind::fusion_info) .set_attr(op_attr::with_scale, true, attribute_kind::b) .set_attr(op_attr::is_invert_scale, false, 
attribute_kind::b, false) + .set_attr(op_attr::is_training, true, attribute_kind::b) // mask_type attribute indicates existence of explicit mask, // top-left implicit causal mask or bottom-right implicit causal mask .set_attr(op_attr::mask_type, true, attribute_kind::i) diff --git a/src/graph/interface/shape_infer.cpp b/src/graph/interface/shape_infer.cpp index a5c35d93c60..85415b845e5 100644 --- a/src/graph/interface/shape_infer.cpp +++ b/src/graph/interface/shape_infer.cpp @@ -2388,6 +2388,21 @@ status_t infer_dnnl_sdpa_output_shape(op_t *n, } set_shape_and_strides(*outputs[0], inferred_output_shape); + + if (outputs.size() > 2) { + auto out1 = logical_tensor_wrapper_t(outputs[2]); + dims inferred_stats_shape + = {query_dims[0], query_dims[1], query_dims[2], 1}; + + if (out1.ndims() != -1) { + VCHECK_INVALID_SHAPE(validate(inferred_stats_shape, out1.vdims()), + "%s, given stats shape is not compatible with inferred", + op_t::kind2str(n->get_kind()).c_str()); + } + + set_shape_and_strides(*outputs[2], inferred_stats_shape); + } + return status::success; } @@ -2450,6 +2465,42 @@ status_t infer_gated_mlp_output_shape(op_t *n, return status::success; } +status_t infer_dnnl_softmax_output_shape(op_t *n, + std::vector &inputs, + std::vector &outputs) { + auto out0 = logical_tensor_wrapper_t(outputs[0]); + auto in0 = logical_tensor_wrapper_t(inputs[0]); + + // check if partial set shape aligns with inferred shape + if (out0.ndims() != -1) { + VCHECK_INVALID_SHAPE(validate(in0.vdims(), out0.vdims()), + "%s, input and output shapes are not compatible", + op_t::kind2str(n->get_kind()).c_str()); + } + + // We should compute output dense strides instead of directly copying input + // strides to it + set_shape_and_strides(*outputs[0], in0.vdims()); + if (outputs.size() == 2) return status::success; + + // infer stats output shape + auto out1 = logical_tensor_wrapper_t(outputs[2]); + dims out1_dims = in0.vdims(); + int64_t axis = n->get_attr(op_attr::axis); + if (axis < 0) { 
axis += in0.ndims(); } + out1_dims[axis] = 1; + + if (out1.ndims() != -1) { + VCHECK_INVALID_SHAPE(validate(out1_dims, out1.vdims()), + "%s, given stats shape is not compatible with inferred", + op_t::kind2str(n->get_kind()).c_str()); + } + + set_shape_and_strides(*outputs[2], out1_dims); + + return status::success; +} + } // namespace graph } // namespace impl } // namespace dnnl diff --git a/src/graph/interface/shape_infer.hpp b/src/graph/interface/shape_infer.hpp index c63bc7be914..97cf18be80d 100644 --- a/src/graph/interface/shape_infer.hpp +++ b/src/graph/interface/shape_infer.hpp @@ -315,6 +315,10 @@ status_t infer_gated_mlp_output_shape(op_t *n, std::vector &inputs, std::vector &outputs); +status_t infer_dnnl_softmax_output_shape(op_t *n, + std::vector &inputs, + std::vector &outputs); + } // namespace graph } // namespace impl } // namespace dnnl From bd12f2dbdcb4f84fc03e1781efd2583463bfddb0 Mon Sep 17 00:00:00 2001 From: "Bao, Yixin" Date: Wed, 25 Feb 2026 22:35:33 -0800 Subject: [PATCH 19/23] graph: backend: dnnl: enable sdpa microkernel for training fwd --- src/graph/backend/dnnl/executables/sdpa.cpp | 19 +++++++- src/graph/backend/dnnl/executables/sdpa.hpp | 1 + .../dnnl/kernels/sdp_primitive_config.cpp | 5 ++- src/graph/backend/dnnl/layout_propagator.cpp | 43 ++++++++++++++++++- src/graph/backend/dnnl/passes/insert_ops.cpp | 12 ++++++ src/graph/backend/dnnl/passes/transform.cpp | 15 ++++++- 6 files changed, 87 insertions(+), 8 deletions(-) diff --git a/src/graph/backend/dnnl/executables/sdpa.cpp b/src/graph/backend/dnnl/executables/sdpa.cpp index 7b84f7feca2..b00899a9346 100644 --- a/src/graph/backend/dnnl/executables/sdpa.cpp +++ b/src/graph/backend/dnnl/executables/sdpa.cpp @@ -25,6 +25,7 @@ sdpa_executable_t::sdpa_executable_t(std::shared_ptr &op, const dnnl::engine &p_engine, pd_cache_t &pd_cache, const fpmath_t &fpmath, bool use_block_layout) : with_scale_(op->get_attr(op_attr::with_scale)) + , is_training_(op->get_attr(op_attr::is_training)) , 
mask_type_(static_cast( op->get_attr(op_attr::mask_type))) { @@ -76,8 +77,9 @@ sdpa_executable_t::sdpa_executable_t(std::shared_ptr &op, status_t s = create_sdpa_pd(sdpa_pd_, p_engine.get(), md_q.get(), md_k.get(), md_v.get(), md_dst.get(), md_mask.get(), md_scale.get(), is_invert_scale_, kv_head_number, mask_type_, softmax_alg, - impl::prop_kind::forward_inference, attr.get(), qk_attr.get(), - vs_attr.get()); + is_training_ ? impl::prop_kind::forward_training + : impl::prop_kind::forward_inference, + attr.get(), qk_attr.get(), vs_attr.get()); if (s != dnnl::impl::status::success) { is_initialized_ = false; } else { @@ -93,6 +95,9 @@ void sdpa_executable_t::execute(const stream &stream, memory_arg_t mem_arg_k = {(args.at(DNNL_ARG_KEYS)).get(), true}; memory_arg_t mem_arg_v = {(args.at(DNNL_ARG_VALUES)).get(), true}; memory_arg_t mem_arg_dst = {(args.at(DNNL_ARG_DST)).get(), false}; + memory_arg_t mem_arg_stats + = {is_training_ ? (args.at(DNNL_ARG_WORKSPACE)).get() : nullptr, + false}; memory_arg_t mem_arg_scale = {with_scale_ ? (args.at(DNNL_ARG_SCALE)).get() : nullptr, true}; memory_arg_t mem_arg_mask = { @@ -125,6 +130,7 @@ void sdpa_executable_t::execute(const stream &stream, exec_args[DNNL_ARG_KEYS] = mem_arg_k; exec_args[DNNL_ARG_VALUES] = mem_arg_v; exec_args[DNNL_ARG_DST] = mem_arg_dst; + exec_args[DNNL_ARG_WORKSPACE] = mem_arg_stats; exec_args[DNNL_ARG_SCALE] = mem_arg_scale; exec_args[DNNL_ARG_ATTN_MASK] = mem_arg_mask; exec_args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_KEYS] = mem_arg_k_scale; @@ -148,6 +154,9 @@ ::sycl::event sdpa_executable_t::execute_sycl(const stream &stream, memory_arg_t mem_arg_k = {(args.at(DNNL_ARG_KEYS)).get(), true}; memory_arg_t mem_arg_v = {(args.at(DNNL_ARG_VALUES)).get(), true}; memory_arg_t mem_arg_dst = {(args.at(DNNL_ARG_DST)).get(), false}; + memory_arg_t mem_arg_stats + = {is_training_ ? (args.at(DNNL_ARG_WORKSPACE)).get() : nullptr, + false}; memory_arg_t mem_arg_scale = {with_scale_ ? 
(args.at(DNNL_ARG_SCALE)).get() : nullptr, true}; memory_arg_t mem_arg_mask = { @@ -180,6 +189,7 @@ ::sycl::event sdpa_executable_t::execute_sycl(const stream &stream, exec_args[DNNL_ARG_KEYS] = mem_arg_k; exec_args[DNNL_ARG_VALUES] = mem_arg_v; exec_args[DNNL_ARG_DST] = mem_arg_dst; + exec_args[DNNL_ARG_WORKSPACE] = mem_arg_stats; exec_args[DNNL_ARG_SCALE] = mem_arg_scale; exec_args[DNNL_ARG_ATTN_MASK] = mem_arg_mask; exec_args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_KEYS] = mem_arg_k_scale; @@ -215,6 +225,9 @@ cl_event sdpa_executable_t::execute_ocl(const stream &stream, memory_arg_t mem_arg_k = {(args.at(DNNL_ARG_KEYS)).get(), true}; memory_arg_t mem_arg_v = {(args.at(DNNL_ARG_VALUES)).get(), true}; memory_arg_t mem_arg_dst = {(args.at(DNNL_ARG_DST)).get(), false}; + memory_arg_t mem_arg_stats + = {is_training_ ? (args.at(DNNL_ARG_WORKSPACE)).get() : nullptr, + false}; memory_arg_t mem_arg_scale = {with_scale_ ? (args.at(DNNL_ARG_SCALE)).get() : nullptr, true}; memory_arg_t mem_arg_mask = { @@ -247,6 +260,7 @@ cl_event sdpa_executable_t::execute_ocl(const stream &stream, exec_args[DNNL_ARG_KEYS] = mem_arg_k; exec_args[DNNL_ARG_VALUES] = mem_arg_v; exec_args[DNNL_ARG_DST] = mem_arg_dst; + exec_args[DNNL_ARG_WORKSPACE] = mem_arg_stats; exec_args[DNNL_ARG_SCALE] = mem_arg_scale; exec_args[DNNL_ARG_ATTN_MASK] = mem_arg_mask; exec_args[DNNL_ARG_ATTR_SCALES | DNNL_ARG_KEYS] = mem_arg_k_scale; @@ -322,6 +336,7 @@ arg_indices_t sdpa_executable_t::get_arg_indices(const op_t *op) { // outputs args.insert({DNNL_ARG_DST, {indices_t::type_t::output, 0}}); args.insert({DNNL_ARG_SCRATCHPAD, {indices_t::type_t::output, 1}}); + args.insert({DNNL_ARG_WORKSPACE, {indices_t::type_t::output, 2}}); return args; } diff --git a/src/graph/backend/dnnl/executables/sdpa.hpp b/src/graph/backend/dnnl/executables/sdpa.hpp index a3be2657b7e..a21f3004909 100644 --- a/src/graph/backend/dnnl/executables/sdpa.hpp +++ b/src/graph/backend/dnnl/executables/sdpa.hpp @@ -54,6 +54,7 @@ struct sdpa_executable_t 
: public op_executable_t { std::shared_ptr sdpa_pd_; std::shared_ptr sdpa_prim_; bool with_scale_; + bool is_training_; bool with_explicit_mask_; attn_mask_type_t mask_type_; bool is_invert_scale_; diff --git a/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp b/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp index 31081d5cc55..ebd9e988ca3 100644 --- a/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp +++ b/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp @@ -42,8 +42,6 @@ status_t sdp_primitive_config_t::initial_check( // At least 3 inputs: Q, K, V VCHECK_SDP_PRIMITIVE(inputs.size() >= 3, status::invalid_arguments, "At least 3 inputs are required"); - VCHECK_SDP_PRIMITIVE(outputs.size() == 1, status::unimplemented, - "does not support multiple outputs"); const bool is_f32 = inputs[0].data_type == data_type::f32; bool has_genindex = false; @@ -57,6 +55,9 @@ status_t sdp_primitive_config_t::initial_check( VCHECK_SDP_PRIMITIVE(opk != graph::op_kind::Dequantize && opk != graph::op_kind::Quantize, status::unimplemented, "Not support quantized SDPA"); + // SDPA with Dropout is currently unsupported in the ukernel. 
+ VCHECK_SDP_PRIMITIVE(opk != graph::op_kind::Dropout, status::unimplemented, + "Not support SDPA with Dropout"); if (opk == graph::op_kind::GenIndex) { has_genindex = true; } } diff --git a/src/graph/backend/dnnl/layout_propagator.cpp b/src/graph/backend/dnnl/layout_propagator.cpp index d4ac7ebe619..b6be4d70e9d 100644 --- a/src/graph/backend/dnnl/layout_propagator.cpp +++ b/src/graph/backend/dnnl/layout_propagator.cpp @@ -1738,7 +1738,8 @@ status_t layout_propagator_for_sdpa(std::shared_ptr &op, UNUSED(use_block_layout); UNUSED(rewriter); - value_ptr dst_val = op->get_output_value(0); + size_t output_idx = 0; + value_ptr dst_val = op->get_output_value(output_idx++); const logical_tensor_t &out_lt = dst_val->get_logical_tensor(); dnnl::memory::desc expected_md; @@ -1779,9 +1780,47 @@ status_t layout_propagator_for_sdpa(std::shared_ptr &op, status_t status = fill_layout_info(dst_val, expected_md); // fill scratchpads dimensions and data type to scratchpad value_t - value_ptr scratchpad_val = op->get_output_value(1); + value_ptr scratchpad_val = op->get_output_value(output_idx++); const memory::desc scratchpad_desc; status = fill_layout_info(scratchpad_val, scratchpad_desc); + + if (op->get_attr(op_attr::is_training)) { + value_ptr stats_val = op->get_output_value(output_idx); + const logical_tensor_t &stats_lt = stats_val->get_logical_tensor(); + dnnl::memory::desc stats_md; + // For GQA, we need to check the layout of the dnnl_reshape output + // following dnnl_sdpa, which is given by the user. 
+ if (!stats_val->get_consumers().empty()) { + const auto &consumer_op = stats_val->get_consumers()[0].get_op(); + const logical_tensor_t &consumer_out + = consumer_op.get_output_logical_tensor(0); + if (consumer_op.get_kind() == op_kind::dnnl_reshape + && ltw(consumer_out).ndims() == 5 + && ltw(consumer_out).is_strided()) { + const auto &ori_strides = ltw(consumer_out).vstrides(); + std::vector strides = {ori_strides[0], ori_strides[2], + ori_strides[3], ori_strides[4]}; + stats_md = {ltw(stats_lt).vdims(), + static_cast( + ltw(stats_lt).data_type()), + strides}; + } else { + // Set default output layout format for sdpa as acbd if user + // doesn't specify the layout since no reorder will be required. + stats_md = {ltw(stats_lt).vdims(), + static_cast( + ltw(out_lt).data_type()), + dnnl::memory::format_tag::acbd}; + } + } else { + expected_md = {ltw(out_lt).vdims(), + static_cast( + ltw(out_lt).data_type()), + dnnl::memory::format_tag::acbd}; + } + + status = fill_layout_info(stats_val, stats_md); + } return status; } diff --git a/src/graph/backend/dnnl/passes/insert_ops.cpp b/src/graph/backend/dnnl/passes/insert_ops.cpp index 093561fe7ab..54853ceb7a8 100644 --- a/src/graph/backend/dnnl/passes/insert_ops.cpp +++ b/src/graph/backend/dnnl/passes/insert_ops.cpp @@ -633,6 +633,18 @@ status_t insert_reshape_for_sdpa(std::shared_ptr &sg) { reshape_output->set_attr>( op_attr::shape, expected_output_dims); rewriter.insert_op_after(reshape_output, cur_op, 0); + + // Insert reshape for optional stats output (output 2) + if (cur_op->get_attr(op_attr::is_training)) { + auto stats_dims = ltw(cur_op->get_output_logical_tensor(2)).vdims(); + dims expected_stats_dims = stats_dims; + op_ptr reshape_stats + = std::make_shared(op_kind::dnnl_reshape); + reshape_stats->set_attr(op_attr::special_zero, false); + reshape_stats->set_attr>( + op_attr::shape, expected_stats_dims); + rewriter.insert_op_after(reshape_stats, cur_op, 2); + } } rewriter.run(); return infer_shape(sg); diff --git 
a/src/graph/backend/dnnl/passes/transform.cpp b/src/graph/backend/dnnl/passes/transform.cpp index 7b9cec17618..27b0dd24039 100644 --- a/src/graph/backend/dnnl/passes/transform.cpp +++ b/src/graph/backend/dnnl/passes/transform.cpp @@ -4705,6 +4705,14 @@ status_t fuse_sdpa(std::shared_ptr &sg) { } else if (op->get_kind() == op_kind::_softmax) { sdpa_op->set_attr( op_attr::mode, op->get_attr(op_attr::mode)); + if (op->num_outputs() < 3) { + sdpa_op->set_attr(op_attr::is_training, false); + } else { + sdpa_op->set_attr(op_attr::is_training, true); + auto stats_output = op->get_output_value(2); + stats_output->set_producer(*sdpa_op); + sdpa_op->connect_output(2, stats_output); + } } } @@ -4767,9 +4775,12 @@ status_t fuse_sdpa(std::shared_ptr &sg) { auto final_output = vs->get_output_value(0); final_output->set_producer(*sdpa_op); - sdpa_op->add_output(final_output); + sdpa_op->connect_output(0, final_output); - insert_empty_scratchpad(sdpa_op); + logical_tensor_t lt = empty_logical_tensor_with_default_id(); + auto scratchpad_val = std::make_shared(*sdpa_op, 1, lt); + sdpa_op->connect_output(1, scratchpad_val); + scratchpad_val->set_data_type(graph::data_type::u8); for (auto &op : candidates) { rewriter.to_remove(op); From 218f0e9f795681107a6129277cf2cc0e45f8cb94 Mon Sep 17 00:00:00 2001 From: "Bao, Yixin" Date: Thu, 5 Mar 2026 22:56:51 -0800 Subject: [PATCH 20/23] graph: backend: dnnl: enable sdpa microkernel for training bwd --- src/graph/backend/dnnl/executables/sdpa.cpp | 285 ++++++++++++ src/graph/backend/dnnl/executables/sdpa.hpp | 34 ++ src/graph/backend/dnnl/kernels/sdp_bwd.hpp | 121 +++++ .../dnnl/kernels/sdp_bwd_primitive.cpp | 277 +++++++++++ .../dnnl/kernels/sdp_bwd_primitive.hpp | 103 +++++ src/graph/backend/dnnl/layout_propagator.cpp | 143 +++++- src/graph/backend/dnnl/layout_propagator.hpp | 1 + src/graph/backend/dnnl/op_executable.cpp | 3 + src/graph/backend/dnnl/passes/compile_ops.cpp | 8 + src/graph/backend/dnnl/passes/insert_ops.cpp | 151 +++++- 
src/graph/backend/dnnl/passes/insert_ops.hpp | 8 + src/graph/backend/dnnl/passes/transform.cpp | 430 +++++++++++++++++- src/graph/backend/dnnl/passes/transform.hpp | 2 + src/graph/backend/dnnl/patterns/sdp.cpp | 210 +++++++-- src/graph/interface/c_types_map.hpp | 1 + src/graph/interface/op.hpp | 1 + src/graph/interface/op_def.hpp | 32 ++ src/graph/interface/opset.hpp | 1 + src/graph/interface/shape_infer.cpp | 55 +++ src/graph/interface/shape_infer.hpp | 4 + 20 files changed, 1819 insertions(+), 51 deletions(-) create mode 100644 src/graph/backend/dnnl/kernels/sdp_bwd.hpp create mode 100644 src/graph/backend/dnnl/kernels/sdp_bwd_primitive.cpp create mode 100644 src/graph/backend/dnnl/kernels/sdp_bwd_primitive.hpp diff --git a/src/graph/backend/dnnl/executables/sdpa.cpp b/src/graph/backend/dnnl/executables/sdpa.cpp index b00899a9346..7a6a5e1188f 100644 --- a/src/graph/backend/dnnl/executables/sdpa.cpp +++ b/src/graph/backend/dnnl/executables/sdpa.cpp @@ -340,6 +340,291 @@ arg_indices_t sdpa_executable_t::get_arg_indices(const op_t *op) { return args; } +sdpa_bwd_executable_t::sdpa_bwd_executable_t(std::shared_ptr &op, + const dnnl::engine &p_engine, pd_cache_t &pd_cache, + const fpmath_t &fpmath, bool use_block_layout) + : with_scale_(op->get_attr(op_attr::with_scale)) + , mask_type_(static_cast( + op->get_attr(op_attr::mask_type))) + , is_invert_scale_(op->has_attr(op_attr::is_invert_scale) + ? 
op->get_attr(op_attr::is_invert_scale) + : false) + , with_explicit_mask_(mask_type_ == attn_mask_type::buffer) { + // Op inputs: Q(0) K(1) V(2) dst/O(3) stats(4) diff_dst/dO(5) [scale(6)] [mask(7)] + // Op outputs: diff_q(0) diff_k(1) diff_v(2) scratchpad(3) [diff_mask/dS(4)] + auto md_q = make_dnnl_memory_desc(op->get_input_logical_tensor(0)); + auto md_k = make_dnnl_memory_desc(op->get_input_logical_tensor(1)); + auto md_v = make_dnnl_memory_desc(op->get_input_logical_tensor(2)); + auto md_dst = make_dnnl_memory_desc(op->get_input_logical_tensor(3)); + // stats at input 4 is consumed as DNNL_ARG_WORKSPACE at execute time + auto md_diff_dst = make_dnnl_memory_desc(op->get_input_logical_tensor(5)); + auto md_diff_q = make_dnnl_memory_desc(op->get_output_logical_tensor(0)); + auto md_diff_k = make_dnnl_memory_desc(op->get_output_logical_tensor(1)); + auto md_diff_v = make_dnnl_memory_desc(op->get_output_logical_tensor(2)); + + // Optional scale, attn_mask, dS (diff_mask) + dnnl::memory::desc md_scale, md_attn_mask, md_dS; + size_t idx = 6; + if (with_scale_) { + md_scale = make_dnnl_memory_desc(op->get_input_logical_tensor(idx++)); + } + if (with_explicit_mask_) { + md_attn_mask + = make_dnnl_memory_desc(op->get_input_logical_tensor(idx++)); + if (op->num_outputs() > 4) { + md_dS = make_dnnl_memory_desc(op->get_output_logical_tensor(4)); + } + } + + // Fusion info and attributes (if any) + const auto &sdpa_fusion_info = op->has_attr(op_attr::fusion_info) + ? op->get_attr(op_attr::fusion_info) + : fusion_info_t(); + dnnl::primitive_attr attr, qk_attr, vs_attr; + if (op->has_attr(op_attr::fusion_info)) { + qk_attr = make_dnnl_sdpa_primitive_attr( + op, sdpa_fusion_info, attr_type_t::QK); + vs_attr = make_dnnl_sdpa_primitive_attr( + op, sdpa_fusion_info, attr_type_t::VS); + } + // Set accumulation mode: the two attributes are requested for + // dnnl_sdpa, so we can get them directly without calling has_attr(). 
+ qk_attr.set_accumulation_mode(str2accumulation_mode( + op->get_attr(op_attr::qk_acc_mode))); + vs_attr.set_accumulation_mode(str2accumulation_mode( + op->get_attr(op_attr::vs_acc_mode))); + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + attr.set_fpmath_mode(static_cast(fpmath.mode_)); + + dim_t kv_head_number = op->get_input_logical_tensor(1).dims[1]; + const alg_kind_t softmax_alg = alg_kind::softmax_accurate_inf_as_zero; + + // create hint_fwd pd + std::shared_ptr hint_fwd_pd; + status_t s = create_sdpa_pd(hint_fwd_pd, p_engine.get(), md_q.get(), + md_k.get(), md_v.get(), md_dst.get(), md_attn_mask.get(), + md_scale.get(), is_invert_scale_, kv_head_number, mask_type_, + softmax_alg, impl::prop_kind::forward_training, attr.get(), qk_attr.get(), vs_attr.get()); + if (s != dnnl::impl::status::success) { + is_initialized_ = false; + return; + } + + s = create_sdpa_pd(sdpa_bwd_pd_, p_engine.get(), md_q.get(), md_k.get(), + md_v.get(), md_dst.get(), md_diff_q.get(), md_diff_k.get(), + md_diff_v.get(), md_diff_dst.get(), md_dS.get(), md_attn_mask.get(), + md_scale.get(), is_invert_scale_, kv_head_number, mask_type_, + softmax_alg, attr.get(), hint_fwd_pd.get(), qk_attr.get(), + vs_attr.get()); + + if (s != dnnl::impl::status::success) { + is_initialized_ = false; + } else { + s = sdpa_bwd_pd_->create_primitive(sdpa_bwd_prim_, p_engine.get()); + is_initialized_ = s == status::success; + } + +} + +void sdpa_bwd_executable_t::execute(const stream &stream, + const std::unordered_map &args) const { + exec_args_t exec_args; + memory_arg_t mem_arg_q = {args.at(DNNL_ARG_QUERIES).get(), true}; + memory_arg_t mem_arg_k = {args.at(DNNL_ARG_KEYS).get(), true}; + memory_arg_t mem_arg_v = {args.at(DNNL_ARG_VALUES).get(), true}; + memory_arg_t mem_arg_dst = {args.at(DNNL_ARG_DST).get(), true}; + memory_arg_t mem_arg_diff_dst = {args.at(DNNL_ARG_DIFF_DST).get(), true}; + memory_arg_t mem_arg_workspace = {args.at(DNNL_ARG_WORKSPACE).get(), true}; + memory_arg_t mem_arg_scale 
+ = {with_scale_ ? args.at(DNNL_ARG_SCALE).get() : nullptr, true}; + memory_arg_t mem_arg_mask = { + with_explicit_mask_ ? args.at(DNNL_ARG_ATTN_MASK).get() : nullptr, + true}; + memory_arg_t mem_arg_diff_q = {args.at(DNNL_ARG_DIFF_QUERIES).get(), false}; + memory_arg_t mem_arg_diff_k = {args.at(DNNL_ARG_DIFF_KEYS).get(), false}; + memory_arg_t mem_arg_diff_v = {args.at(DNNL_ARG_DIFF_VALUES).get(), false}; + memory_arg_t mem_arg_scratchpad + = {args.at(DNNL_ARG_SCRATCHPAD).get(), false}; + + exec_args[DNNL_ARG_QUERIES] = mem_arg_q; + exec_args[DNNL_ARG_KEYS] = mem_arg_k; + exec_args[DNNL_ARG_VALUES] = mem_arg_v; + exec_args[DNNL_ARG_DST] = mem_arg_dst; + exec_args[DNNL_ARG_DIFF_DST] = mem_arg_diff_dst; + exec_args[DNNL_ARG_WORKSPACE] = mem_arg_workspace; + exec_args[DNNL_ARG_SCALE] = mem_arg_scale; + exec_args[DNNL_ARG_ATTN_MASK] = mem_arg_mask; + exec_args[DNNL_ARG_DIFF_QUERIES] = mem_arg_diff_q; + exec_args[DNNL_ARG_DIFF_KEYS] = mem_arg_diff_k; + exec_args[DNNL_ARG_DIFF_VALUES] = mem_arg_diff_v; + exec_args[DNNL_ARG_SCRATCHPAD] = mem_arg_scratchpad; + if (args.count(DNNL_ARG_DS)) + exec_args[DNNL_ARG_DS] = {args.at(DNNL_ARG_DS).get(), false}; + + exec_ctx_t ctx(stream.get(), std::move(exec_args)); + + // Set up scratchpad grantor required by the primitive's execute + const memory_storage_t *mem_storage = nullptr; + memory_t *scratchpad_memory = ctx.output(DNNL_ARG_SCRATCHPAD); + if (scratchpad_memory) + mem_storage = scratchpad_memory->memory_storage(); + const void *host_ptr + = ctx.host_ptr(mem_storage, /* require_host_ptr = */ true); + auto *scratchpad_grantor + = sdpa_bwd_pd_->scratchpad_registry().create_grantor( + mem_storage, host_ptr); + ctx.set_scratchpad_grantor(scratchpad_grantor); + + sdpa_bwd_prim_->execute(ctx); +} + +#ifdef DNNL_WITH_SYCL +::sycl::event sdpa_bwd_executable_t::execute_sycl(const stream &stream, + const std::unordered_map &args, + const std::vector<::sycl::event> &deps) const { + exec_args_t exec_args; + exec_args[DNNL_ARG_QUERIES] = 
{args.at(DNNL_ARG_QUERIES).get(), true}; + exec_args[DNNL_ARG_KEYS] = {args.at(DNNL_ARG_KEYS).get(), true}; + exec_args[DNNL_ARG_VALUES] = {args.at(DNNL_ARG_VALUES).get(), true}; + exec_args[DNNL_ARG_DST] = {args.at(DNNL_ARG_DST).get(), true}; + exec_args[DNNL_ARG_DIFF_DST] = {args.at(DNNL_ARG_DIFF_DST).get(), true}; + exec_args[DNNL_ARG_WORKSPACE] = {args.at(DNNL_ARG_WORKSPACE).get(), true}; + exec_args[DNNL_ARG_SCALE] + = {with_scale_ ? args.at(DNNL_ARG_SCALE).get() : nullptr, true}; + exec_args[DNNL_ARG_ATTN_MASK] = { + with_explicit_mask_ ? args.at(DNNL_ARG_ATTN_MASK).get() : nullptr, + true}; + exec_args[DNNL_ARG_DIFF_QUERIES] + = {args.at(DNNL_ARG_DIFF_QUERIES).get(), false}; + exec_args[DNNL_ARG_DIFF_KEYS] = {args.at(DNNL_ARG_DIFF_KEYS).get(), false}; + exec_args[DNNL_ARG_DIFF_VALUES] + = {args.at(DNNL_ARG_DIFF_VALUES).get(), false}; + exec_args[DNNL_ARG_SCRATCHPAD] + = {args.at(DNNL_ARG_SCRATCHPAD).get(), false}; + if (args.count(DNNL_ARG_DS)) + exec_args[DNNL_ARG_DS] = {args.at(DNNL_ARG_DS).get(), false}; + + auto strm_t = stream.get(); + exec_ctx_t ctx(strm_t, std::move(exec_args)); + auto *sycl_stream_impl = dnnl::impl::utils::downcast< + dnnl::impl::xpu::sycl::stream_impl_t *>(strm_t->impl()); + + // Set up scratchpad grantor required by the primitive's execute + const memory_storage_t *mem_storage_sycl = nullptr; + memory_t *scratchpad_memory_sycl = ctx.output(DNNL_ARG_SCRATCHPAD); + if (scratchpad_memory_sycl) + mem_storage_sycl = scratchpad_memory_sycl->memory_storage(); + const void *host_ptr_sycl + = ctx.host_ptr(mem_storage_sycl, /* require_host_ptr = */ true); + auto *scratchpad_grantor_sycl + = sdpa_bwd_pd_->scratchpad_registry().create_grantor( + mem_storage_sycl, host_ptr_sycl); + ctx.set_scratchpad_grantor(scratchpad_grantor_sycl); + + strm_t->before_exec_hook(); + if (!deps.empty()) sycl_stream_impl->sycl_ctx().set_deps(deps); + + sdpa_bwd_prim_->execute(ctx); + + ::sycl::event return_event = sycl_stream_impl->get_output_event(); + 
strm_t->after_exec_hook(); + return return_event; +} +#endif + +#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL +cl_event sdpa_bwd_executable_t::execute_ocl(const stream &stream, + const std::unordered_map &args, + const std::vector &deps) const { + exec_args_t exec_args; + exec_args[DNNL_ARG_QUERIES] = {args.at(DNNL_ARG_QUERIES).get(), true}; + exec_args[DNNL_ARG_KEYS] = {args.at(DNNL_ARG_KEYS).get(), true}; + exec_args[DNNL_ARG_VALUES] = {args.at(DNNL_ARG_VALUES).get(), true}; + exec_args[DNNL_ARG_DST] = {args.at(DNNL_ARG_DST).get(), true}; + exec_args[DNNL_ARG_DIFF_DST] = {args.at(DNNL_ARG_DIFF_DST).get(), true}; + exec_args[DNNL_ARG_WORKSPACE] = {args.at(DNNL_ARG_WORKSPACE).get(), true}; + exec_args[DNNL_ARG_SCALE] + = {with_scale_ ? args.at(DNNL_ARG_SCALE).get() : nullptr, true}; + exec_args[DNNL_ARG_ATTN_MASK] = { + with_explicit_mask_ ? args.at(DNNL_ARG_ATTN_MASK).get() : nullptr, + true}; + exec_args[DNNL_ARG_DIFF_QUERIES] + = {args.at(DNNL_ARG_DIFF_QUERIES).get(), false}; + exec_args[DNNL_ARG_DIFF_KEYS] = {args.at(DNNL_ARG_DIFF_KEYS).get(), false}; + exec_args[DNNL_ARG_DIFF_VALUES] + = {args.at(DNNL_ARG_DIFF_VALUES).get(), false}; + exec_args[DNNL_ARG_SCRATCHPAD] + = {args.at(DNNL_ARG_SCRATCHPAD).get(), false}; + if (args.count(DNNL_ARG_DS)) + exec_args[DNNL_ARG_DS] = {args.at(DNNL_ARG_DS).get(), false}; + + exec_ctx_t ctx(stream.get(), std::move(exec_args)); + + // Set up scratchpad grantor required by the primitive's execute + const memory_storage_t *mem_storage_ocl = nullptr; + memory_t *scratchpad_memory_ocl = ctx.output(DNNL_ARG_SCRATCHPAD); + if (scratchpad_memory_ocl) + mem_storage_ocl = scratchpad_memory_ocl->memory_storage(); + const void *host_ptr_ocl + = ctx.host_ptr(mem_storage_ocl, /* require_host_ptr = */ true); + auto *scratchpad_grantor_ocl + = sdpa_bwd_pd_->scratchpad_registry().create_grantor( + mem_storage_ocl, host_ptr_ocl); + ctx.set_scratchpad_grantor(scratchpad_grantor_ocl); + + auto *ocl_stream = dnnl::impl::utils::downcast( + 
stream.get()); + ocl_stream->before_exec_hook(); + + if (!deps.empty()) { + std::vector> events(deps.size()); + for (size_t i = 0; i < deps.size(); i++) + events[i] = xpu::ocl::wrapper_t(deps[i], true); + ocl_stream->ocl_ctx().set_deps(events); + } + + sdpa_bwd_prim_->execute(ctx); + + cl_event return_event = nullptr; + if ((ocl_stream->flags() & stream_flags::in_order) == 0) { + auto last = ocl_stream->get_output_event(); + return_event = last.release(); + } + + ocl_stream->after_exec_hook(); + return return_event; +} +#endif + +arg_indices_t sdpa_bwd_executable_t::get_arg_indices(const op_t *op) { + arg_indices_t args; + // inputs: Q, K, V, dst(O), stats(workspace), diff_dst(dO) + size_t idx = 0; + args.insert({DNNL_ARG_QUERIES, {indices_t::type_t::input, idx++}}); + args.insert({DNNL_ARG_KEYS, {indices_t::type_t::input, idx++}}); + args.insert({DNNL_ARG_VALUES, {indices_t::type_t::input, idx++}}); + args.insert({DNNL_ARG_DST, {indices_t::type_t::input, idx++}}); + args.insert({DNNL_ARG_WORKSPACE, {indices_t::type_t::input, idx++}}); + args.insert({DNNL_ARG_DIFF_DST, {indices_t::type_t::input, idx++}}); + // optional: scale, mask + if (op->get_attr(op_attr::with_scale)) { + args.insert({DNNL_ARG_SCALE, {indices_t::type_t::input, idx++}}); + } + if (op->get_attr(op_attr::mask_type) + == static_cast(attn_mask_type::buffer)) { + args.insert({DNNL_ARG_ATTN_MASK, {indices_t::type_t::input, idx++}}); + } + // outputs: diff_q, diff_k, diff_v, scratchpad, [dS] + args.insert({DNNL_ARG_DIFF_QUERIES, {indices_t::type_t::output, 0}}); + args.insert({DNNL_ARG_DIFF_KEYS, {indices_t::type_t::output, 1}}); + args.insert({DNNL_ARG_DIFF_VALUES, {indices_t::type_t::output, 2}}); + args.insert({DNNL_ARG_SCRATCHPAD, {indices_t::type_t::output, 3}}); + if (op->num_outputs() > 4) { + args.insert({DNNL_ARG_DS, {indices_t::type_t::output, 4}}); + } + return args; +} + } // namespace dnnl_impl } // namespace graph } // namespace impl diff --git 
a/src/graph/backend/dnnl/executables/sdpa.hpp b/src/graph/backend/dnnl/executables/sdpa.hpp index a21f3004909..eb740f8f458 100644 --- a/src/graph/backend/dnnl/executables/sdpa.hpp +++ b/src/graph/backend/dnnl/executables/sdpa.hpp @@ -61,6 +61,40 @@ struct sdpa_executable_t : public op_executable_t { bool is_initialized_; }; +struct sdpa_bwd_executable_t : public op_executable_t { + DECLARE_ARG_INDICES_GETTER; + + sdpa_bwd_executable_t(std::shared_ptr &op, + const dnnl::engine &p_engine, pd_cache_t &pd_cache, + const fpmath_t &fpmath, bool use_block_layout); + + bool is_initialized() const { return is_initialized_; } + + void execute(const stream &stream, + const std::unordered_map &args) const override; + +#ifdef DNNL_WITH_SYCL + ::sycl::event execute_sycl(const stream &stream, + const std::unordered_map &args, + const std::vector<::sycl::event> &deps) const override; +#endif + +#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL + cl_event execute_ocl(const stream &stream, + const std::unordered_map &args, + const std::vector &deps) const override; +#endif + +private: + std::shared_ptr sdpa_bwd_pd_; + std::shared_ptr sdpa_bwd_prim_; + bool with_scale_; + attn_mask_type_t mask_type_; + bool is_invert_scale_; + bool with_explicit_mask_; + bool is_initialized_; +}; + } // namespace dnnl_impl } // namespace graph } // namespace impl diff --git a/src/graph/backend/dnnl/kernels/sdp_bwd.hpp b/src/graph/backend/dnnl/kernels/sdp_bwd.hpp new file mode 100644 index 00000000000..c9827c4b157 --- /dev/null +++ b/src/graph/backend/dnnl/kernels/sdp_bwd.hpp @@ -0,0 +1,121 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GRAPH_BACKEND_DNNL_KERNELS_SDP_BWD_HPP +#define GRAPH_BACKEND_DNNL_KERNELS_SDP_BWD_HPP + +#include +#include +#include +#include +#include + +#include "graph/backend/dnnl/kernels/kernel_base.hpp" +#include "graph/backend/dnnl/kernels/large_partition.hpp" +#include "graph/backend/dnnl/kernels/sdp_bwd_primitive.hpp" + +#include "graph/backend/dnnl/dnnl_partition_impl.hpp" + +#define VDISPATCH_GRAPH_SDP_BWD(msg, ...) \ + VINFO(graph, create, dispatch, compile, msg, ##__VA_ARGS__) + +namespace dnnl { +namespace impl { +namespace graph { +namespace dnnl_impl { + +struct sdp_bwd_base_t : public kernel_base_t { +private: + std::shared_ptr kernel; + +public: + status_t compile_impl(const dnnl_partition_impl_t *part, + const engine_t *g_engine, + const std::vector &inputs, + const std::vector &outputs) override { + const engine_kind_t ekind = g_engine->kind(); + bool enable_ukernel = false; + + if (ekind == engine_kind::gpu) { + enable_ukernel = !force_primitive(); + } else if (ekind != engine_kind::cpu) { + assert(!"unknown engine kind"); + return status::invalid_arguments; + } + + status_t ret = status::unimplemented; + + if (enable_ukernel) { + kernel = std::make_shared(); + ret = kernel->compile_impl(part, g_engine, inputs, outputs); + } + + if (ret != status::success) { + kernel = std::make_shared(); + ret = kernel->compile_impl(part, g_engine, inputs, outputs); + } + if (ret == status::success) + VDISPATCH_GRAPH_SDP_BWD( + "sdpa_bwd is dispatched to (%s)", 
kernel->str().c_str()); + else + VDISPATCH_GRAPH_SDP_BWD("sdpa_bwd is failed to dispatch"); + return ret; + } + + // An internal env var is provided to force using primitive based SDPA + // backward implementation and skipping ukernel based optimization on GPU. + // Currently it's for oneDNN debug and testing only. + bool force_primitive() const { + const int force = graph::utils::getenv_int_internal( + "GRAPH_SDPA_FORCE_PRIMITIVE", 0); + return force > 0; + } + + status_t execute_impl(const stream_t *g_stream, + const std::vector &inputs, + const std::vector &outputs) override { + return kernel->execute_impl(g_stream, inputs, outputs); + } + +#ifdef DNNL_WITH_SYCL + status_t sycl_execute_impl(const stream_t *g_stream, + const std::vector &inputs, + const std::vector &outputs, + const std::vector<::sycl::event> &sycl_deps, + ::sycl::event *sycl_event) override { + return kernel->sycl_execute_impl( + g_stream, inputs, outputs, sycl_deps, sycl_event); + } +#endif + +#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL + status_t ocl_execute_impl(const stream_t *g_stream, + const std::vector &inputs, + const std::vector &outputs, + const std::vector &deps, cl_event *event) override { + return kernel->ocl_execute_impl(g_stream, inputs, outputs, deps, event); + } +#endif + + std::string str() const override { return kernel->str(); } +}; + +} // namespace dnnl_impl +} // namespace graph +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/graph/backend/dnnl/kernels/sdp_bwd_primitive.cpp b/src/graph/backend/dnnl/kernels/sdp_bwd_primitive.cpp new file mode 100644 index 00000000000..822c2b2285a --- /dev/null +++ b/src/graph/backend/dnnl/kernels/sdp_bwd_primitive.cpp @@ -0,0 +1,277 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "graph/backend/dnnl/kernels/sdp_bwd_primitive.hpp" + +#include "common/sdpa_pd.hpp" + +#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL +#include "gpu/intel/ocl/stream.hpp" +#elif DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL +#include "gpu/intel/sycl/stream.hpp" +#endif + +#include "graph/backend/dnnl/passes/compile_ops.hpp" +#include "graph/backend/dnnl/passes/constant_propagation.hpp" +#include "graph/backend/dnnl/passes/insert_ops.hpp" +#include "graph/backend/dnnl/passes/layout_propagation.hpp" +#include "graph/backend/dnnl/passes/lower.hpp" +#include "graph/backend/dnnl/passes/memory_planning.hpp" +#include "graph/backend/dnnl/passes/transform.hpp" +#include "graph/backend/dnnl/passes/utils.hpp" + +#include "graph/backend/dnnl/op_executable.hpp" + +#include "common/verbose.hpp" + +#define VCHECK_SDP_BWD_PRIMITIVE(cond, status, msg, ...) 
\ + VCONDCHECK(graph, create, check, sdp_bwd_primitive_kernel_t, (cond), \ + status, msg, ##__VA_ARGS__); + +namespace dnnl { +namespace impl { +namespace graph { +namespace dnnl_impl { + +status_t sdp_bwd_primitive_kernel_t::initial_check( + const std::shared_ptr &sg, + const std::vector &inputs, + const std::vector &outputs) { + const bool is_f32 = inputs[0].data_type == data_type::f32; + VCHECK_SDP_BWD_PRIMITIVE(!is_f32, + status::unimplemented, + "SDPA bwd primitive doesn't support f32 because of performance"); + + bool has_dropout = false; + for (const auto &cur_op : sg->get_ops()) { + const auto opk = cur_op->get_kind(); + if (opk == graph::op_kind::Dropout) { + has_dropout = true; + break; + } + } + VCHECK_SDP_BWD_PRIMITIVE(!has_dropout, + status::unimplemented, + "SDPA bwd primitive doesn't support Dropout for now"); + + bool has_host_scalar = false; + for (const auto < : inputs) { + if (logical_tensor_wrapper_t(lt).is_host_scalar()) { + has_host_scalar = true; + break; + } + } + VCHECK_SDP_BWD_PRIMITIVE(!has_host_scalar, + status::unimplemented, + "SDPA bwd primitive doesn't support host scalar inputs for now"); + + return status::success; +} + +status_t sdp_bwd_primitive_kernel_t::compile_impl( + const dnnl_partition_impl_t *part, const engine_t *g_engine, + const std::vector &inputs, + const std::vector &outputs) { +// sdp_bwd_primitive_kernel_t only supports Intel GPU. 
+#if defined(DNNL_WITH_SYCL) && DNNL_GPU_VENDOR != DNNL_VENDOR_INTEL + return status::unimplemented; +#endif + + p_engine_ = make_dnnl_engine(*g_engine); + g_alloc_ + = reinterpret_cast(g_engine->get_allocator()); + + // First, dry run on a deep copy + subgraph_ + = std::make_shared(graph_t::deep_copy(part->get_ops()), + p_engine_, part->get_fpmath_mode(), false, true); + CHECK(set_given_inputs_outputs(subgraph_, inputs, outputs)); + CHECK(initial_check(subgraph_, inputs, outputs)); + + subgraph_visualizer_t vis(part->id(), [this](const value_t *val) { + return this->memory_planner_.get_memory_info(val); + }); + pass_pipeline_t pipeline = pass_pipeline_t(vis); + + BACKEND_DNNL_ADD_PASS(pipeline, lower_down); + BACKEND_DNNL_ADD_PASS(pipeline, fuse_implicit_causal_mask); + BACKEND_DNNL_ADD_PASS(pipeline, insert_permute_for_matmul); + + pipeline.reset_visualize_arg(true, false); + BACKEND_DNNL_ADD_PASS(pipeline, infer_shape); + BACKEND_DNNL_ADD_PASS(pipeline, fuse_sdpa_bwd); + BACKEND_DNNL_ADD_PASS(pipeline, insert_permute_for_sdpa_bwd); + BACKEND_DNNL_ADD_PASS(pipeline, insert_reshape_for_sdpa_bwd); + BACKEND_DNNL_ADD_PASS(pipeline, layout_propagation); + + // bind the memory for each op + auto memory_plan = [&](std::shared_ptr &sg) { + return memory_planner_.run(sg); + }; + pipeline.reset_visualize_arg(true, true); + BACKEND_DNNL_ADD_PASS(pipeline, memory_plan); + BACKEND_DNNL_ADD_PASS(pipeline, compile_ops); + + // Run the added passes + BACKEND_DNNL_CHECK(pipeline.run(subgraph_)); + + // fill information for inputs logical tensors + for (size_t i = 0; i < inputs.size(); i++) { + auto &in = const_cast(inputs[i]); + in = subgraph_->ins_[i]; + } + + // fill information for outputs logical tensors + for (size_t i = 0; i < outputs.size(); i++) { + auto &out = const_cast(outputs[i]); + out = subgraph_->outs_[i]; + } + + resource_ctor_ = [this]() { + return this->memory_planner_.get_exec_args_set().clone(); + }; + + return status::success; +} + +void 
sdp_bwd_primitive_kernel_t::prepare_args_set( + const execution_args_set_t *res, const std::vector &inputs, + const std::vector &outputs, const scratchpad_t &scratchpad) { + // update the data of partition in/outputs args + for (const auto &mem_idx : res->get_mems_use_external_inputs()) { + const dnnl::memory &mem = mem_idx.first; + const tensor_t &ts = inputs[mem_idx.second]; + const logical_tensor_t lt = ts.get_logical_tensor(); + const logical_tensor_wrapper_t ltw(lt); + if (ltw.is_host_scalar()) { + DNNL_HOST_SCALAR_TYPE_SWITCH(ltw.data_type(), DType, { + mem.set_host_scalar_value( + *static_cast(ts.get_data_handle())); + }); + } else { + mem.set_data_handle(ts.get_data_handle()); + } + } + + for (const auto &mem_idx : res->get_mems_use_external_outputs()) { + mem_idx.first.set_data_handle( + outputs[mem_idx.second].get_data_handle()); + } + + grantor_t var_grantor = memory_planner_.internal_temporary_grantor( + scratchpad.get_buffer()); + + for (auto &mem_offkey : res->get_mems_use_internal_temporary()) { + mem_offkey.first.set_data_handle(var_grantor.get(mem_offkey.second)); + } +} + +status_t sdp_bwd_primitive_kernel_t::execute_impl(const stream_t *g_stream, + const std::vector &inputs, + const std::vector &outputs) { + dnnl::stream p_stream = make_dnnl_stream(p_engine_, *g_stream); + + thread_local_cache_t res_cache; + execution_args_set_t *res = res_cache.get_or_add( + reinterpret_cast(this), resource_ctor_); + + temporary_scratchpad_t scratchpad( + memory_planner_.total_internal_temporary_size(), p_engine_, + *g_alloc_); + prepare_args_set(res, inputs, outputs, scratchpad); + + for (size_t i = 0; i < subgraph_->execs_.size(); i++) { + subgraph_->execs_[i]->execute(p_stream, res->get_exec_args()[i]); + } + + return status::success; +} + +#ifdef DNNL_WITH_SYCL +status_t sdp_bwd_primitive_kernel_t::sycl_execute_impl(const stream_t *g_stream, + const std::vector &inputs, + const std::vector &outputs, + const std::vector<::sycl::event> &sycl_deps, + 
::sycl::event *sycl_event) { +// sdp_bwd_primitive_kernel_t only supports Intel GPU. +#if DNNL_GPU_VENDOR != DNNL_VENDOR_INTEL + return status::unimplemented; +#endif + auto deps = sycl_deps; + ::sycl::event returned_event; + + dnnl::stream p_stream = make_dnnl_stream(p_engine_, *g_stream); + + thread_local_cache_t res_cache; + execution_args_set_t *res = res_cache.get_or_add( + reinterpret_cast(this), resource_ctor_); + + temporary_scratchpad_t scratchpad( + memory_planner_.total_internal_temporary_size(), p_engine_, + *g_alloc_); + prepare_args_set(res, inputs, outputs, scratchpad); + + for (size_t i = 0; i < subgraph_->execs_.size(); i++) { + if (subgraph_->is_constant_[i]) continue; + returned_event = subgraph_->execs_[i]->execute_sycl( + p_stream, res->get_exec_args()[i], deps); + deps = {returned_event}; + } + + scratchpad.set_deps(returned_event); + if (sycl_event) *sycl_event = returned_event; + + return status::success; +} +#endif + +#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL +status_t sdp_bwd_primitive_kernel_t::ocl_execute_impl(const stream_t *g_stream, + const std::vector &inputs, + const std::vector &outputs, + const std::vector &cl_deps, cl_event *ret_event) { + auto deps = cl_deps; + cl_event returned_event {}; + + dnnl::stream p_stream = make_dnnl_stream(p_engine_, *g_stream); + + thread_local_cache_t res_cache; + execution_args_set_t *res = res_cache.get_or_add( + reinterpret_cast(this), resource_ctor_); + + temporary_scratchpad_t scratchpad( + memory_planner_.total_internal_temporary_size(), p_engine_, + *g_alloc_); + prepare_args_set(res, inputs, outputs, scratchpad); + + for (size_t i = 0; i < subgraph_->execs_.size(); i++) { + if (subgraph_->is_constant_[i]) continue; + returned_event = subgraph_->execs_[i]->execute_ocl( + p_stream, res->get_exec_args()[i], deps); + deps = {returned_event}; + } + + scratchpad.set_deps(returned_event); + if (ret_event) *ret_event = returned_event; + + return status::success; +} +#endif + +} // namespace dnnl_impl 
+} // namespace graph +} // namespace impl +} // namespace dnnl diff --git a/src/graph/backend/dnnl/kernels/sdp_bwd_primitive.hpp b/src/graph/backend/dnnl/kernels/sdp_bwd_primitive.hpp new file mode 100644 index 00000000000..3fb72967eb2 --- /dev/null +++ b/src/graph/backend/dnnl/kernels/sdp_bwd_primitive.hpp @@ -0,0 +1,103 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef GRAPH_BACKEND_DNNL_KERNELS_SDP_BWD_PRIMITIVE_HPP +#define GRAPH_BACKEND_DNNL_KERNELS_SDP_BWD_PRIMITIVE_HPP + +#include +#include +#include +#include +#include + +#include "graph/backend/dnnl/common.hpp" +#include "graph/backend/dnnl/dnnl_constant_tensor_cache.hpp" +#include "graph/backend/dnnl/dnnl_partition_impl.hpp" +#include "graph/backend/dnnl/op_executable.hpp" +#include "graph/backend/dnnl/scratchpad.hpp" +#include "graph/backend/dnnl/thread_local_cache.hpp" +#include "graph/backend/dnnl/utils.hpp" + +#include "graph/backend/dnnl/passes/memory_planning.hpp" + +namespace dnnl { +namespace impl { +namespace graph { +namespace dnnl_impl { + +struct sdp_bwd_primitive_kernel_t : public kernel_base_t { +private: + allocator_t *g_alloc_ = nullptr; + + std::shared_ptr subgraph_; + memory_planner_t memory_planner_; + std::function()> resource_ctor_; + +public: + sdp_bwd_primitive_kernel_t() { + 
thread_local_cache_t res_cache; + res_cache.retain(); + } + + ~sdp_bwd_primitive_kernel_t() override { + thread_local_cache_t res_cache; + res_cache.remove_if_exist(reinterpret_cast(this)); + res_cache.release(); + } + + status_t compile_impl(const dnnl_partition_impl_t *part, + const engine_t *g_engine, + const std::vector &inputs, + const std::vector &outputs) override; + + void prepare_args_set(const execution_args_set_t *res, + const std::vector &inputs, + const std::vector &outputs, + const scratchpad_t &scratchpad); + + status_t execute_impl(const stream_t *g_stream, + const std::vector &inputs, + const std::vector &outputs) override; + +#ifdef DNNL_WITH_SYCL + status_t sycl_execute_impl(const stream_t *g_stream, + const std::vector &inputs, + const std::vector &outputs, + const std::vector<::sycl::event> &sycl_deps, + ::sycl::event *sycl_event) override; +#endif + +#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL + status_t ocl_execute_impl(const stream_t *g_stream, + const std::vector &inputs, + const std::vector &outputs, + const std::vector &cl_deps, cl_event *ret_event) override; +#endif + + static status_t initial_check(const std::shared_ptr &sg, + const std::vector &inputs, + const std::vector &outputs); + + DEF_KERNEL_METHOD_STR(sdp_bwd_primitive_kernel_t) + DNNL_DISALLOW_COPY_AND_ASSIGN(sdp_bwd_primitive_kernel_t) +}; + +} // namespace dnnl_impl +} // namespace graph +} // namespace impl +} // namespace dnnl + +#endif diff --git a/src/graph/backend/dnnl/layout_propagator.cpp b/src/graph/backend/dnnl/layout_propagator.cpp index b6be4d70e9d..62a2fcfdef0 100644 --- a/src/graph/backend/dnnl/layout_propagator.cpp +++ b/src/graph/backend/dnnl/layout_propagator.cpp @@ -1794,7 +1794,7 @@ status_t layout_propagator_for_sdpa(std::shared_ptr &op, const auto &consumer_op = stats_val->get_consumers()[0].get_op(); const logical_tensor_t &consumer_out = consumer_op.get_output_logical_tensor(0); - if (consumer_op.get_kind() == op_kind::dnnl_reshape + if 
(consumer_op.get_kind() == op_kind::_reshape && ltw(consumer_out).ndims() == 5 && ltw(consumer_out).is_strided()) { const auto &ori_strides = ltw(consumer_out).vstrides(); @@ -1824,6 +1824,147 @@ status_t layout_propagator_for_sdpa(std::shared_ptr &op, return status; } +status_t layout_propagator_for_sdpa_bwd(std::shared_ptr &op, + const dnnl::engine &p_engine, pd_cache_t &pd_cache, + const fpmath_t &fpmath, bool use_block_layout, + subgraph_rewriter_t &rewriter) { + UNUSED(pd_cache); + UNUSED(use_block_layout); + UNUSED(rewriter); + + // Helper: derive a memory desc for a diff output from the corresponding + // forward input logical tensor. If the input layout is already fixed, reuse + // it; otherwise fall back to the canonical acbd format used by sdpa. + auto get_md_for_diff = [](const logical_tensor_t <) { + if (!ltw(lt).is_any()) + return make_dnnl_memory_desc(lt); + return dnnl::memory::desc {ltw(lt).vdims(), + static_cast(ltw(lt).data_type()), + dnnl::memory::format_tag::acbd}; + }; + + status_t status = status::success; + size_t output_idx = 0; + + // diff_query (output 0): propagate layout from query (input 0) + value_ptr diff_query_val = op->get_output_value(output_idx++); + const logical_tensor_t &query_lt = op->get_input_logical_tensor(0); + status = fill_layout_info(diff_query_val, get_md_for_diff(query_lt)); + VCHECK_LAYOUT_PROPAGATOR(status == status::success, status, + "failed to fill layout info for sdpa_bwd diff_query"); + + // diff_key (output 1): propagate layout from key (input 1) + value_ptr diff_key_val = op->get_output_value(output_idx++); + const logical_tensor_t &key_lt = op->get_input_logical_tensor(1); + status = fill_layout_info(diff_key_val, get_md_for_diff(key_lt)); + VCHECK_LAYOUT_PROPAGATOR(status == status::success, status, + "failed to fill layout info for sdpa_bwd diff_key"); + + // diff_value (output 2): propagate layout from value (input 2) + value_ptr diff_value_val = op->get_output_value(output_idx++); + const 
logical_tensor_t &value_lt = op->get_input_logical_tensor(2); + status = fill_layout_info(diff_value_val, get_md_for_diff(value_lt)); + VCHECK_LAYOUT_PROPAGATOR(status == status::success, status, + "failed to fill layout info for sdpa_bwd diff_value"); + + // scratchpad (output 3): create the pd to get the real scratchpad size + { + const bool with_scale = op->get_attr(op_attr::with_scale); + const auto mask_type = static_cast( + op->get_attr(op_attr::mask_type)); + const bool is_invert_scale + = op->has_attr(op_attr::is_invert_scale) + ? op->get_attr(op_attr::is_invert_scale) + : false; + const bool with_explicit_mask + = mask_type == attn_mask_type::buffer; + + auto md_q = make_dnnl_memory_desc(op->get_input_logical_tensor(0)); + auto md_k = make_dnnl_memory_desc(op->get_input_logical_tensor(1)); + auto md_v = make_dnnl_memory_desc(op->get_input_logical_tensor(2)); + auto md_dst = make_dnnl_memory_desc(op->get_input_logical_tensor(3)); + auto md_diff_dst + = make_dnnl_memory_desc(op->get_input_logical_tensor(5)); + auto md_diff_q = get_md_for_diff(op->get_input_logical_tensor(0)); + auto md_diff_k = get_md_for_diff(op->get_input_logical_tensor(1)); + auto md_diff_v = get_md_for_diff(op->get_input_logical_tensor(2)); + + dnnl::memory::desc md_scale, md_attn_mask, md_dS; + size_t idx = 6; + if (with_scale) + md_scale = make_dnnl_memory_desc( + op->get_input_logical_tensor(idx++)); + if (with_explicit_mask) { + md_attn_mask = make_dnnl_memory_desc( + op->get_input_logical_tensor(idx++)); + if (op->num_outputs() > 4) + md_dS = make_dnnl_memory_desc( + op->get_output_logical_tensor(4)); + } + + const auto &sdpa_fusion_info = op->has_attr(op_attr::fusion_info) + ? 
op->get_attr(op_attr::fusion_info) + : fusion_info_t(); + dnnl::primitive_attr attr, qk_attr, vs_attr; + if (op->has_attr(op_attr::fusion_info)) { + qk_attr = make_dnnl_sdpa_primitive_attr( + op, sdpa_fusion_info, attr_type_t::QK); + vs_attr = make_dnnl_sdpa_primitive_attr( + op, sdpa_fusion_info, attr_type_t::VS); + } + qk_attr.set_accumulation_mode(str2accumulation_mode( + op->get_attr(op_attr::qk_acc_mode))); + vs_attr.set_accumulation_mode(str2accumulation_mode( + op->get_attr(op_attr::vs_acc_mode))); + attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); + attr.set_fpmath_mode( + static_cast(fpmath.mode_)); + + dim_t kv_head_number = op->get_input_logical_tensor(1).dims[1]; + const alg_kind_t softmax_alg + = alg_kind::softmax_accurate_inf_as_zero; + + std::shared_ptr hint_fwd_pd; + status = create_sdpa_pd(hint_fwd_pd, p_engine.get(), md_q.get(), + md_k.get(), md_v.get(), md_dst.get(), md_attn_mask.get(), + md_scale.get(), is_invert_scale, kv_head_number, mask_type, + softmax_alg, impl::prop_kind::forward_training, attr.get(), qk_attr.get(), vs_attr.get()); + VCHECK_LAYOUT_PROPAGATOR(status == status::success, status, + "failed to create hint fwd pd for sdpa_bwd scratchpad"); + + std::shared_ptr sdpa_bwd_pd; + status = create_sdpa_pd(sdpa_bwd_pd, p_engine.get(), md_q.get(), + md_k.get(), md_v.get(), md_dst.get(), md_diff_q.get(), + md_diff_k.get(), md_diff_v.get(), md_diff_dst.get(), + md_dS.get(), md_attn_mask.get(), md_scale.get(), + is_invert_scale, kv_head_number, mask_type, softmax_alg, + attr.get(), hint_fwd_pd.get(), qk_attr.get(), + vs_attr.get()); + VCHECK_LAYOUT_PROPAGATOR(status == status::success, status, + "failed to create pd for sdpa_bwd scratchpad"); + + value_ptr scratchpad_val = op->get_output_value(output_idx++); + dnnl_memory_desc_t cloned_md = nullptr; + dnnl_memory_desc_clone(&cloned_md, sdpa_bwd_pd->scratchpad_md()); + dnnl::memory::desc scratchpad_desc; + scratchpad_desc.reset(cloned_md); + status = fill_layout_info(scratchpad_val, 
scratchpad_desc); + } + + // diff_mask (output 4, optional) + if (op->num_outputs() > output_idx) { + value_ptr diff_mask_val = op->get_output_value(output_idx); + const logical_tensor_t &diff_mask_lt + = diff_mask_val->get_logical_tensor(); + status = fill_layout_info( + diff_mask_val, make_dnnl_memory_desc(diff_mask_lt)); + VCHECK_LAYOUT_PROPAGATOR(status == status::success, status, + "failed to fill layout info for sdpa_bwd diff_mask"); + } + + return status; +} + status_t layout_propagator_for_host_scalar(std::shared_ptr &op, const dnnl::engine &p_engine, pd_cache_t &pd_cache, const fpmath_t &fpmath, bool use_block_layout, diff --git a/src/graph/backend/dnnl/layout_propagator.hpp b/src/graph/backend/dnnl/layout_propagator.hpp index 8f12fd966ca..6edb1b80d68 100644 --- a/src/graph/backend/dnnl/layout_propagator.hpp +++ b/src/graph/backend/dnnl/layout_propagator.hpp @@ -94,6 +94,7 @@ DECLARE_LAYOUT_PROPAGATOR(groupnorm); DECLARE_LAYOUT_PROPAGATOR(gen_index); DECLARE_LAYOUT_PROPAGATOR(mask); DECLARE_LAYOUT_PROPAGATOR(sdpa); +DECLARE_LAYOUT_PROPAGATOR(sdpa_bwd); DECLARE_LAYOUT_PROPAGATOR(host_scalar); DECLARE_LAYOUT_PROPAGATOR(identity); DECLARE_LAYOUT_PROPAGATOR(gated_mlp); diff --git a/src/graph/backend/dnnl/op_executable.cpp b/src/graph/backend/dnnl/op_executable.cpp index 27b281c49a4..02d7946ac85 100644 --- a/src/graph/backend/dnnl/op_executable.cpp +++ b/src/graph/backend/dnnl/op_executable.cpp @@ -74,6 +74,7 @@ executable_creator_func op_func_t::get_executable_creator(op_kind_t kind) { {_gen_index, executable_creator}, {_mask, executable_creator}, {_sdpa, executable_creator}, + {_sdpa_bwd, executable_creator}, {_host_scalar, executable_creator}, {_identity, executable_creator}, {_dropout, dummy_executable_creator}, @@ -138,6 +139,7 @@ arg_indices_getter_func op_func_t::get_arg_indices_getter(op_kind_t kind) { {_gen_index, genindex_executable_t::get_arg_indices}, {_mask, memory_reparser_t::get_arg_indices}, {_sdpa, sdpa_executable_t::get_arg_indices}, + 
{_sdpa_bwd, sdpa_bwd_executable_t::get_arg_indices}, {_host_scalar, host_scalar_executable_t::get_arg_indices}, {_identity, memory_reparser_t::get_arg_indices}, {_dropout, dummy_arg_indices_getter}, @@ -201,6 +203,7 @@ layout_propagator_func op_func_t::get_layout_propagator(op_kind_t kind) { {_gen_index, layout_propagator_for_gen_index}, {_mask, layout_propagator_for_mask}, {_sdpa, layout_propagator_for_sdpa}, + {_sdpa_bwd, layout_propagator_for_sdpa_bwd}, {_host_scalar, layout_propagator_for_host_scalar}, {_identity, layout_propagator_for_identity}, {_gated_mlp, layout_propagator_for_gated_mlp}, diff --git a/src/graph/backend/dnnl/passes/compile_ops.cpp b/src/graph/backend/dnnl/passes/compile_ops.cpp index 23bdbe340ea..f5aa642fe58 100644 --- a/src/graph/backend/dnnl/passes/compile_ops.cpp +++ b/src/graph/backend/dnnl/passes/compile_ops.cpp @@ -61,7 +61,15 @@ status_t compile_ops(std::shared_ptr &sg) { status::unimplemented, "failed to create executable for op %s", op->get_name().c_str()); + } else if (cur_op->get_kind() == op_kind::_sdpa_bwd) { + auto sdpa_bwd_exec + = std::dynamic_pointer_cast(exec); + VCHECK_COMPILE_OPS(sdpa_bwd_exec->is_initialized(), + status::unimplemented, + "failed to create executable for op %s", + op->get_name().c_str()); } + sg->execs_.emplace_back(exec); sg->is_constant_.push_back(op->has_attr(op_attr::is_constant) diff --git a/src/graph/backend/dnnl/passes/insert_ops.cpp b/src/graph/backend/dnnl/passes/insert_ops.cpp index 54853ceb7a8..641dc3c7ec4 100644 --- a/src/graph/backend/dnnl/passes/insert_ops.cpp +++ b/src/graph/backend/dnnl/passes/insert_ops.cpp @@ -482,6 +482,35 @@ status_t insert_permute_for_matmul(std::shared_ptr &sg) { return infer_shape(sg); } +status_t insert_permute_for_sdpa_bwd(std::shared_ptr &sg) { + subgraph_rewriter_t rewriter(sg); + + for (auto &cur_op : sg->get_ops()) { + if (cur_op->get_kind() != op_kind::_sdpa_bwd) continue; + + for (size_t i = 0; i <=2; ++i) { + // check if Q,K,V has permute op in front of 
sdpa_bwd, if yes, + // inserting permute for corresponding dQ, dK, dV to keep the + // consistency of data layout + if (!cur_op->get_input_value(i)->has_producer() + || cur_op->get_input_value(i)->get_producer().get_kind() + != op_kind::_permute) + continue; + op_t &input_permute_op = + cur_op->get_input_value(i)->get_producer(); + auto perm = input_permute_op.get_attr>( + op_attr::permutation); + op_ptr output_permute_op = std::make_shared(op_kind::_permute); + output_permute_op->set_attr>( + op_attr::permutation, perm); + rewriter.insert_op_after(output_permute_op, cur_op, i); + } + } + + rewriter.run(); + return infer_shape(sg); +} + status_t insert_reshape_for_ndx2d_matmul(std::shared_ptr &sg) { using ltw = logical_tensor_wrapper_t; subgraph_rewriter_t rewriter(sg); @@ -639,7 +668,7 @@ status_t insert_reshape_for_sdpa(std::shared_ptr &sg) { auto stats_dims = ltw(cur_op->get_output_logical_tensor(2)).vdims(); dims expected_stats_dims = stats_dims; op_ptr reshape_stats - = std::make_shared(op_kind::dnnl_reshape); + = std::make_shared(op_kind::_reshape); reshape_stats->set_attr(op_attr::special_zero, false); reshape_stats->set_attr>( op_attr::shape, expected_stats_dims); @@ -650,6 +679,126 @@ status_t insert_reshape_for_sdpa(std::shared_ptr &sg) { return infer_shape(sg); } +status_t insert_reshape_for_sdpa_bwd(std::shared_ptr &sg) { + using ltw = logical_tensor_wrapper_t; + + subgraph_rewriter_t rewriter(sg); + + for (auto &cur_op : sg->get_ops()) { + if (cur_op->get_kind() != op_kind::_sdpa_bwd) continue; + + int32_t query_ndims = cur_op->get_input_logical_tensor(0).ndims; + if (query_ndims != 5) continue; + + // Helper lambda: create a reshape op collapsing dims[1]*dims[2] -> -1 + // for a 5D tensor [batch, g, h, seq, d] -> [batch, -1, seq, d] + auto make_reshape_5d_to_4d = [](const dims &in_dims) { + dims out_dims {in_dims[0], -1, in_dims[3], in_dims[4]}; + op_ptr reshape = std::make_shared(op_kind::_reshape); + reshape->set_attr(op_attr::special_zero, false); 
+ reshape->set_attr>(op_attr::shape, out_dims); + return reshape; + }; + + // Insert reshape for Query (input 0) + auto query_dims = ltw(cur_op->get_input_logical_tensor(0)).vdims(); + rewriter.insert_op_before(make_reshape_5d_to_4d(query_dims), cur_op, 0); + + // Insert reshape for Key (input 1) + auto key_dims = ltw(cur_op->get_input_logical_tensor(1)).vdims(); + rewriter.insert_op_before(make_reshape_5d_to_4d(key_dims), cur_op, 1); + + // Insert reshape for Value (input 2) + auto value_dims = ltw(cur_op->get_input_logical_tensor(2)).vdims(); + rewriter.insert_op_before(make_reshape_5d_to_4d(value_dims), cur_op, 2); + + // Insert reshape for dst (input 3, forward output, 5D) + auto dst_dims = ltw(cur_op->get_input_logical_tensor(3)).vdims(); + rewriter.insert_op_before(make_reshape_5d_to_4d(dst_dims), cur_op, 3); + + // Insert reshape for stats (input 4) + auto stats_dims = ltw(cur_op->get_input_logical_tensor(4)).vdims(); + rewriter.insert_op_before(make_reshape_5d_to_4d(stats_dims), cur_op, 4); + + // Insert reshape for diff_dst (input 5, 5D same as dst) + auto diff_dst_dims = ltw(cur_op->get_input_logical_tensor(5)).vdims(); + rewriter.insert_op_before( + make_reshape_5d_to_4d(diff_dst_dims), cur_op, 5); + + size_t index = 6; + // Insert reshape for scale (optional) + if (cur_op->get_attr(op_attr::with_scale)) { + int32_t scale_ndims + = cur_op->get_input_logical_tensor(index).ndims; + if (scale_ndims == 5) { + auto scale_dims + = ltw(cur_op->get_input_logical_tensor(index)).vdims(); + rewriter.insert_op_before( + make_reshape_5d_to_4d(scale_dims), cur_op, index); + } + index += 1; + } + // Insert reshape for mask (optional) + if (cur_op->get_attr(op_attr::mask_type) + == static_cast(attn_mask_type::buffer)) { + int32_t mask_ndims + = cur_op->get_input_logical_tensor(index).ndims; + + if (mask_ndims == 5) { + auto mask_dims + = ltw(cur_op->get_input_logical_tensor(index)).vdims(); + rewriter.insert_op_before( + make_reshape_5d_to_4d(mask_dims), cur_op, index); 
+ } + } + + // Insert reshape for diff_query output (output 0) -> 4D to 5D + auto diff_query_dims = ltw(cur_op->get_output_logical_tensor(0)).vdims(); + const dims &expected_diff_query_dims = diff_query_dims; + op_ptr reshape_diff_query + = std::make_shared(op_kind::_reshape); + reshape_diff_query->set_attr(op_attr::special_zero, false); + reshape_diff_query->set_attr>( + op_attr::shape, expected_diff_query_dims); + rewriter.insert_op_after(reshape_diff_query, cur_op, 0); + + // Insert reshape for diff_key output (output 1) -> 4D to 5D + auto diff_key_dims = ltw(cur_op->get_output_logical_tensor(1)).vdims(); + const dims &expected_diff_key_dims = diff_key_dims; + op_ptr reshape_diff_key = std::make_shared(op_kind::_reshape); + reshape_diff_key->set_attr(op_attr::special_zero, false); + reshape_diff_key->set_attr>( + op_attr::shape, expected_diff_key_dims); + rewriter.insert_op_after(reshape_diff_key, cur_op, 1); + + // Insert reshape for diff_value output (output 2) -> 4D to 5D + auto diff_value_dims = ltw(cur_op->get_output_logical_tensor(2)).vdims(); + const dims &expected_diff_value_dims = diff_value_dims; + op_ptr reshape_diff_value + = std::make_shared(op_kind::_reshape); + reshape_diff_value->set_attr(op_attr::special_zero, false); + reshape_diff_value->set_attr>( + op_attr::shape, expected_diff_value_dims); + rewriter.insert_op_after(reshape_diff_value, cur_op, 2); + + // Insert reshape for diff_mask output (output 4) -> 4D to 5D + if (cur_op->num_outputs() > 4) { + auto diff_mask_dims + = ltw(cur_op->get_output_logical_tensor(4)).vdims(); + const dims &expected_diff_mask_dims = diff_mask_dims; + op_ptr reshape_diff_mask + = std::make_shared(op_kind::_reshape); + reshape_diff_mask->set_attr(op_attr::special_zero, false); + reshape_diff_mask->set_attr>( + op_attr::shape, expected_diff_mask_dims); + rewriter.insert_op_after(reshape_diff_mask, cur_op, 4); + } + } + + rewriter.run(); + return infer_shape(sg); +} + status_t 
insert_unsqueeze_and_squeeze_for_matmul( std::shared_ptr &sg) { subgraph_rewriter_t rewriter(sg); diff --git a/src/graph/backend/dnnl/passes/insert_ops.hpp b/src/graph/backend/dnnl/passes/insert_ops.hpp index 8724cc3daf3..bfee98bb60e 100644 --- a/src/graph/backend/dnnl/passes/insert_ops.hpp +++ b/src/graph/backend/dnnl/passes/insert_ops.hpp @@ -63,6 +63,14 @@ status_t insert_reshape_for_ndx2d_matmul(std::shared_ptr &sg); /// 2) reshape output from 4D to 5D status_t insert_reshape_for_sdpa(std::shared_ptr &sg); +/// Insert reshape for 5D sdpa_bwd. sdpa_bwd only supports 4D input/output +/// 1) reshape Q/K/V/dst/diff_dst/scale/mask from 5D to 4D +/// 2) reshape diff_query/diff_key/diff_value from 4D to 5D +status_t insert_reshape_for_sdpa_bwd(std::shared_ptr &sg); + +/// Insert permute for dQ,dK,dV based on whether Q,K,V are permuted or not. +status_t insert_permute_for_sdpa_bwd(std::shared_ptr &sg); + // Insert an unsqueeze-squeeze pair for matmul // // The usage of unsqueeze op: diff --git a/src/graph/backend/dnnl/passes/transform.cpp b/src/graph/backend/dnnl/passes/transform.cpp index 27b0dd24039..109bc0b4ac1 100644 --- a/src/graph/backend/dnnl/passes/transform.cpp +++ b/src/graph/backend/dnnl/passes/transform.cpp @@ -3239,7 +3239,7 @@ status_t decompose_softmax(std::shared_ptr &sg) { subgraph_rewriter_t rewriter(sg); for (auto &cur_op : sg->get_ops()) { - if (cur_op->get_kind() != op_kind::dnnl_softmax) continue; + if (cur_op->get_kind() != op_kind::_softmax) continue; if (cur_op->num_outputs() != 3) continue; const auto &dst = cur_op->get_output_value(0); @@ -3253,7 +3253,7 @@ status_t decompose_softmax(std::shared_ptr &sg) { cur_op->connect_output(0, f32_dst); // create reorder op to convert the output to the original data type - auto reorder_op = std::make_shared(op_kind::dnnl_reorder); + auto reorder_op = std::make_shared(op_kind::_reorder); reorder_op->set_attr(op_attr::change_layout, false); reorder_op->add_input(f32_dst); f32_dst->add_consumer(*reorder_op, 
0); @@ -3281,7 +3281,7 @@ status_t decompose_softmax(std::shared_ptr &sg) { if (need_reduction) { // create reduce_src op auto reduce_src_op - = std::make_shared(op_kind::dnnl_reduction); + = std::make_shared(op_kind::_reduction); reduce_src_op->set_attr>( op_attr::axes, {cur_op->get_attr(op_attr::axis)}); reduce_src_op->set_attr(op_attr::keep_dims, true); @@ -3300,7 +3300,7 @@ status_t decompose_softmax(std::shared_ptr &sg) { // create reduce_dst op auto reduce_dst_op - = std::make_shared(op_kind::dnnl_reduction); + = std::make_shared(op_kind::_reduction); reduce_dst_op->set_attr>( op_attr::axes, {cur_op->get_attr(op_attr::axis)}); reduce_dst_op->set_attr(op_attr::keep_dims, true); @@ -3322,7 +3322,7 @@ status_t decompose_softmax(std::shared_ptr &sg) { } // create log op - auto log_op = std::make_shared(op_kind::dnnl_eltwise); + auto log_op = std::make_shared(op_kind::_eltwise); log_op->set_attr(op_attr::alg_kind, static_cast(dnnl::algorithm::eltwise_log)); log_op->add_input(reduce_dst_op_out_val); @@ -3336,7 +3336,7 @@ status_t decompose_softmax(std::shared_ptr &sg) { insert_empty_scratchpad(log_op); // create subtract op - auto sub_op = std::make_shared(op_kind::dnnl_binary); + auto sub_op = std::make_shared(op_kind::_binary); sub_op->set_attr(op_attr::alg_kind, static_cast(dnnl::algorithm::binary_sub)); sub_op->add_input(reduce_src_op_out_val); @@ -3357,7 +3357,7 @@ status_t decompose_softmax(std::shared_ptr &sg) { auto reduce_or_reorder_op_out_val = f32_dst; if (need_reduction) { auto reduce_sum_dst_op - = std::make_shared(op_kind::dnnl_reduction); + = std::make_shared(op_kind::_reduction); reduce_sum_dst_op->set_attr>( op_attr::axes, {cur_op->get_attr(op_attr::axis)}); reduce_sum_dst_op->set_attr(op_attr::keep_dims, true); @@ -3378,7 +3378,7 @@ status_t decompose_softmax(std::shared_ptr &sg) { rewriter.to_insert(reduce_sum_dst_op); } else { // create reorder op to convert f32_dst to s8 - auto reorder_s8_op = std::make_shared(op_kind::dnnl_reorder); + auto 
reorder_s8_op = std::make_shared(op_kind::_reorder); reorder_s8_op->set_attr(op_attr::change_layout, false); reorder_s8_op->add_input(f32_dst); f32_dst->add_consumer(*reorder_s8_op, 0); @@ -3395,7 +3395,7 @@ status_t decompose_softmax(std::shared_ptr &sg) { } // create select op - auto select_op = std::make_shared(op_kind::dnnl_binary); + auto select_op = std::make_shared(op_kind::_binary); select_op->set_attr(op_attr::alg_kind, static_cast(dnnl::algorithm::binary_select)); select_op->add_input(sub_op_out_val); @@ -3413,7 +3413,7 @@ status_t decompose_softmax(std::shared_ptr &sg) { rewriter.to_insert(select_op); // recreate dnnl_softmax with 2 outputs: output and scratchpad - auto new_softmax_op = std::make_shared(op_kind::dnnl_softmax); + auto new_softmax_op = std::make_shared(op_kind::_softmax); new_softmax_op->merge_attributes(cur_op->get_attributes()); src->remove_consumer(*cur_op, 0); @@ -4895,6 +4895,416 @@ status_t fuse_gated_mlp(std::shared_ptr &sg) { return status::success; } +// Fuses a backward SDPA subgraph into a single dnnl_sdpa_bwd op. 
+// +// Pattern (all optional nodes indicated with []): +// +// [Q,K] → matmul_qk → [scale_pre] → [mask] → sub(stats) → exp (=P) +// | +// [dropout_fwd]→[tc_fwd]→matmul_dv → [dV] +// +// Compute softmax_bwd = Mul(P, dp_corrected), dp_corrected = Sub(dP, correction), +// dP = matmul_do_vt, correction = ReduceSum(Mul(O, dO)) +// matmul_do_vt → [dropout_bwd] → Sub → Mul → softmax_bwd +// / / +// [O,dO] → Mul → ReduceSum ────── P +// +// From softmax_bwd: +// → [scale_post] → [End (dMask)] → [tc_bwd] → matmul_dq +// → matmul_dk → [reducesum] +// +// Resulting dnnl_sdpa_bwd inputs: +// 0:Q 1:K 2:V 3:O(dst) 4:stats 5:dO [6:scale] [7:mask] +// Outputs: +// 0:dQ 1:dK 2:dV 3:scratchpad [4:dMask] +status_t fuse_sdpa_bwd(std::shared_ptr &sg) { + if (sg->get_ops().size() < 13) return status::success; + + // ── Helpers ─────────────────────────────────────────────────────────── + auto is_binary = [](const op_ptr &op, dnnl::algorithm alg) -> bool { + if (op->get_kind() != op_kind::_binary) return false; + return static_cast( + op->get_attr(op_attr::alg_kind)) + == alg; + }; + + auto is_exp = [](const op_ptr &op) -> bool { + if (op->get_kind() != op_kind::_eltwise) return false; + return static_cast( + op->get_attr(op_attr::alg_kind)) + == dnnl::algorithm::eltwise_exp; + }; + + // Walk the sole consumer chain one step; return nullptr if none or >1. 
+ auto sole_consumer = [](const op_ptr &op, size_t out_idx = 0) -> op_ptr { + auto out_val = op->get_output_value(out_idx); + if (!out_val || out_val->get_consumers().size() != 1) return nullptr; + return out_val->get_consumers()[0].get_op().shared_from_this(); + }; + + auto consumers_of + = [](const op_ptr &op, size_t out_idx = 0) -> std::vector { + auto out_val = op->get_output_value(out_idx); + if (!out_val) return {}; + std::vector res; + for (auto &c : out_val->get_consumers()) + res.push_back(c.get_op().shared_from_this()); + return res; + }; + + // ── Main search loop ───────────────────────────────────────────────── + for (auto &cur_op : sg->get_ops()) { + if (cur_op->get_kind() != op_kind::_matmul) continue; + + // Step 1 – walk matmul_qk → [scale_pre] → [mask] → sub → exp + op_ptr matmul_qk = cur_op; + op_ptr scale_pre = nullptr, mask_op = nullptr; + op_ptr sub_op = nullptr, exp_op = nullptr; + + { + op_ptr w = sole_consumer(matmul_qk); + while (w) { + if (is_exp(w)) { + exp_op = w; + break; + } else if (is_binary(w, dnnl::algorithm::binary_mul) + || is_binary(w, dnnl::algorithm::binary_div)) { + if (!scale_pre) scale_pre = w; + } else if (w->get_kind() == op_kind::_mask + || is_binary(w, dnnl::algorithm::binary_add)) { + mask_op = w; + } else if (is_binary(w, dnnl::algorithm::binary_sub)) { + sub_op = w; + } else { + break; + } + w = sole_consumer(w); + } + } + if (!exp_op || !sub_op) + continue; + + // stats tensor feeds Subtract at input 1 + value_ptr stats_val = sub_op->get_input_value(1); + + // Step 2 – classify consumers of exp: matmul_dv branch and + // softmax_bwd (P * dp_corrected). 
+ op_ptr dropout_fwd = nullptr, tc_fwd = nullptr; + op_ptr matmul_dv = nullptr, softmax_bwd = nullptr; + // permute is introduced because matmul_dv = matmul(P_t, dO) + op_ptr permute_p = nullptr; + + for (auto &c : consumers_of(exp_op)) { + if (c->get_kind() == op_kind::_permute) + permute_p = c; + else if (is_binary(c, dnnl::algorithm::binary_mul)) + softmax_bwd = c; + else if (c->get_kind() == op_kind::_reorder) + tc_fwd = c; + else if (c->get_kind() == op_kind::_dropout) + dropout_fwd = c; + } + + // Resolve matmul_dv through optional dropout / typecast + permute chain + if (!permute_p) { + if (dropout_fwd) { + tc_fwd = sole_consumer(dropout_fwd); + if (tc_fwd && tc_fwd->get_kind() == op_kind::_reorder) { + permute_p = sole_consumer(tc_fwd); + } else { + permute_p = dropout_fwd->get_output_value(0) + ->get_consumers()[0] + .get_op() + .shared_from_this(); + } + } else if (tc_fwd) { + permute_p = sole_consumer(tc_fwd); + } else if (permute_p) { + permute_p = sole_consumer(permute_p); + } + } + if (!permute_p || permute_p->get_kind() != op_kind::_permute) + continue; + + matmul_dv = sole_consumer(permute_p); + + if (!matmul_dv || !softmax_bwd) + continue; + + // Optional reduce after matmul_dv (e.g., GQA dV accumulation) + op_ptr reduce_dv = nullptr; + { + auto next = sole_consumer(matmul_dv); + if (next && next->get_kind() == op_kind::_reduction) + reduce_dv = next; + } + + // Step 3 – decode softmax_bwd = Mul(P=exp, dp_corrected) + // dp_corrected = Sub(dP_maybe_dropouted, correction) + // dP_maybe_dropouted = matmul_do_vt, optionally via Dropout + // correction = ReduceSum(o_do) + // o_do = Mul(O, dO) + value_ptr dp_corr_val = softmax_bwd->get_input_value(1); + if (!dp_corr_val->has_producer()) + continue; + + op_ptr dp_corrected_op = dp_corr_val->get_producer().shared_from_this(); + if (!is_binary(dp_corrected_op, dnnl::algorithm::binary_sub)) + continue; + + // dP side (input 0): matmul_v_do, optionally via Dropout + value_ptr dP_val = 
dp_corrected_op->get_input_value(0); + if (!dP_val->has_producer()) + continue; + + op_ptr dP_prod = dP_val->get_producer().shared_from_this(); + + op_ptr dropout_bwd = nullptr, matmul_vt_do = nullptr; + if (dP_prod->get_kind() == op_kind::_matmul) { + matmul_vt_do = dP_prod; + } else if (dP_prod->get_kind() == op_kind::_dropout) { + dropout_bwd = dP_prod; + value_ptr mm_out = dropout_bwd->get_input_value(0); + if (!mm_out->has_producer()) + continue; + + auto mm_prod = mm_out->get_producer().shared_from_this(); + if (mm_prod->get_kind() != op_kind::_matmul) + continue; + + matmul_vt_do = mm_prod; + } else { + continue; + } + + // get permute before matmul_vt_do + op_ptr permute_v = nullptr; + if (matmul_vt_do->get_input_value(1)->has_producer()) { + permute_v = matmul_vt_do->get_input_value(1)->get_producer().shared_from_this(); + if (permute_v->get_kind() != op_kind::_permute) + continue; + } + + // correction side (input 1): ReduceSum → o_do = Mul(O, dO) + value_ptr corr_val = dp_corrected_op->get_input_value(1); + if (!corr_val->has_producer()) + continue; + op_ptr correction_op = corr_val->get_producer().shared_from_this(); + if (correction_op->get_kind() != op_kind::_reduction) + continue; + + value_ptr o_do_out = correction_op->get_input_value(0); + if (!o_do_out->has_producer()) + continue; + + op_ptr o_do_op = o_do_out->get_producer().shared_from_this(); + if (!is_binary(o_do_op, dnnl::algorithm::binary_mul)) + continue; + + value_ptr O_val = o_do_op->get_input_value(0); // forward output O + value_ptr dO_val = o_do_op->get_input_value(1); // diff_dst dO + value_ptr V_val = permute_v->get_input_value(0); // value tensor V + + // Step 4 – walk forward from softmax_bwd to find: + // [scale_post], [End/dMask], [tc_bwd], matmul_dq, permute + matmul_dk + op_ptr scale_post = nullptr, end_op = nullptr, tc_bwd = nullptr; + op_ptr matmul_dq = nullptr, matmul_dk = nullptr; + // permute_ds is introduced because matmul_dk = matmul(ds_t, Q) + op_ptr permute_ds = 
nullptr; + + auto classify_sbwd_consumers = [&](const std::vector &cs) { + for (auto &c : cs) { + if (is_binary(c, dnnl::algorithm::binary_mul) + || is_binary(c, dnnl::algorithm::binary_div)) + scale_post = c; + else if (c->get_kind() == op_kind::_identity) + end_op = c; + else if (c->get_kind() == op_kind::_reorder) + tc_bwd = c; + else if (c->get_kind() == op_kind::_matmul) + matmul_dq = c; + else if (c->get_kind() == op_kind::_permute) + permute_ds = c; + } + }; + + classify_sbwd_consumers(consumers_of(softmax_bwd)); + + // If there's a scale_post, also classify its consumers + if (scale_post) classify_sbwd_consumers(consumers_of(scale_post)); + + // If there's a typecast, its consumers are matmul_dq + permute_ds + if (tc_bwd && (!matmul_dq || !permute_ds)) + classify_sbwd_consumers(consumers_of(tc_bwd)); + + if (permute_ds && !matmul_dk) { + auto next = sole_consumer(permute_ds); + if (next && next->get_kind() == op_kind::_matmul) + matmul_dk = next; + } + + if (!matmul_dq || !matmul_dk) + continue; + + // Detect and handle the permute of K that feeds matmul_dq input 1 + // (matmul_dq computes dS * permute(K), where permute transposes K) + op_ptr permute_k = nullptr; + { + auto dq_in1 = matmul_dq->get_input_value(1); + if (dq_in1->has_producer()) { + auto prod = dq_in1->get_producer().shared_from_this(); + if (prod->get_kind() == op_kind::_permute) + permute_k = prod; + } + } + + // Optional transpose_dk before matmul_dk + op_ptr transpose_dk = nullptr; + { + auto next = sole_consumer(matmul_dk); + if (next && next->get_kind() == op_kind::_transpose) + transpose_dk = next; + } + + // Optional reduce after matmul_dk (e.g., GQA dK accumulation) + op_ptr reduce_dk = nullptr; + { + auto next = transpose_dk? 
sole_consumer(transpose_dk): sole_consumer(matmul_dk); + if (next && next->get_kind() == op_kind::_reduction) + reduce_dk = next; + } + + // ── Step 5: Build dnnl_sdpa_bwd ────────────────────────────────── + subgraph_rewriter_t rewriter(sg); + op_ptr bwd_op = std::make_shared(op_kind::_sdpa_bwd); + + // Attributes + const bool with_scale = (scale_post != nullptr); + bwd_op->set_attr(op_attr::with_scale, with_scale); + if (with_scale) { + auto alg = static_cast( + scale_post->get_attr(op_attr::alg_kind)); + bwd_op->set_attr(op_attr::is_invert_scale, + alg == dnnl::algorithm::binary_div); + } + + int64_t mtype = static_cast(attn_mask_type::undef); + if (mask_op) { + mtype = (mask_op->get_kind() == op_kind::_mask) + ? mask_op->get_attr(op_attr::mask_type) + : static_cast(attn_mask_type::buffer); + } + bwd_op->set_attr(op_attr::mask_type, mtype); + + const std::string qk_acc + = matmul_qk->has_attr(op_attr::accumulation_mode) + ? matmul_qk->get_attr(op_attr::accumulation_mode) + : "strict"; + const std::string vs_acc + = matmul_dv->has_attr(op_attr::accumulation_mode) + ? 
matmul_dv->get_attr(op_attr::accumulation_mode) + : "strict"; + bwd_op->set_attr(op_attr::qk_acc_mode, qk_acc); + bwd_op->set_attr(op_attr::vs_acc_mode, vs_acc); + + // Connect inputs + // 0: Q + auto Qv = matmul_qk->get_input_value(0); + Qv->remove_consumer(*matmul_qk, 0); + Qv->remove_consumer(*matmul_dk, 1); + bwd_op->connect_input(0, Qv); + + // 1: K (original transposed key from matmul_qk) + auto Kv = matmul_qk->get_input_value(1); + Kv->remove_consumer(*matmul_qk, 1); + if (permute_k) { + Kv->remove_consumer(*permute_k, 0); + } + bwd_op->connect_input(1, Kv); + + + // 2: V + V_val->remove_consumer(*permute_v, 0); + bwd_op->connect_input(2, V_val); + + // 3: O (forward output / dst) + O_val->remove_consumer(*o_do_op, 0); + bwd_op->connect_input(3, O_val); + + // 4: stats + stats_val->remove_consumer(*sub_op, 1); + bwd_op->connect_input(4, stats_val); + + // 5: dO (diff_dst); shared across matmul_dv, matmul_v_do, o_do + dO_val->remove_consumer(*o_do_op, 1); + bwd_op->connect_input(5, dO_val); + + size_t in_idx = 6; + + // 6: scale (optional) + if (with_scale) { + auto sv = scale_post->get_input_value(1); + sv->remove_consumer(*scale_post, 1); + bwd_op->connect_input(in_idx++, sv); + } + + // 7: explicit mask (optional, only for buffer-type masks) + if (mask_op && mask_op->get_kind() != op_kind::_mask) { + auto mv = mask_op->get_input_value(1); + mv->remove_consumer(*mask_op, 1); + bwd_op->connect_input(in_idx++, mv); + } + + // Connect outputs + // 0: dQ + auto dQ_val = matmul_dq->get_output_value(0); + dQ_val->set_producer(*bwd_op); + bwd_op->connect_output(0, dQ_val); + + // 1: dK (possibly through optional reduce) + auto dK_val = reduce_dk ? reduce_dk->get_output_value(0) + : transpose_dk ? transpose_dk->get_output_value(0) : matmul_dk->get_output_value(0); + dK_val->set_producer(*bwd_op); + bwd_op->connect_output(1, dK_val); + + // 2: dV (possibly through optional reduce) + auto dV_val = reduce_dv ? 
reduce_dv->get_output_value(0) + : matmul_dv->get_output_value(0); + dV_val->set_producer(*bwd_op); + bwd_op->connect_output(2, dV_val); + + // 3: scratchpad + logical_tensor_t lt = empty_logical_tensor_with_default_id(); + auto scratch = std::make_shared(*bwd_op, 3, lt); + scratch->set_data_type(graph::data_type::u8); + bwd_op->connect_output(3, scratch); + + // 4: diff_mask (optional) + if (end_op) { + auto dM_val = end_op->get_output_value(0); + dM_val->set_producer(*bwd_op); + bwd_op->connect_output(4, dM_val); + } + + // Remove all pattern ops + std::vector to_remove = {matmul_qk, sub_op, exp_op, matmul_dv, + matmul_vt_do, o_do_op, correction_op, dp_corrected_op, + softmax_bwd, matmul_dq, matmul_dk, permute_p, permute_ds, permute_v}; + for (auto *opt : {&scale_pre, &mask_op, &dropout_fwd, &tc_fwd, + &dropout_bwd, &scale_post, &end_op, &tc_bwd, + &reduce_dv, &reduce_dk, &permute_k, &transpose_dk}) + if (*opt) to_remove.push_back(*opt); + + for (auto &op : to_remove) + rewriter.to_remove(op); + rewriter.to_insert(bwd_op); + rewriter.run(); + return status::success; + } + + return status::success; +} + } // namespace dnnl_impl } // namespace graph } // namespace impl diff --git a/src/graph/backend/dnnl/passes/transform.hpp b/src/graph/backend/dnnl/passes/transform.hpp index b00a3c8f8d8..629630bdc30 100644 --- a/src/graph/backend/dnnl/passes/transform.hpp +++ b/src/graph/backend/dnnl/passes/transform.hpp @@ -301,6 +301,8 @@ status_t fuse_implicit_causal_mask(std::shared_ptr &sg); /// This pass will transform the sdpa subgraph into a dnnl_sdpa op. status_t fuse_sdpa(std::shared_ptr &sg); +/// This pass will transform the sdpa bwd subgraph into a dnnl_sdpa_bwd op. +status_t fuse_sdpa_bwd(std::shared_ptr &sg); /// This pass will transform the gated mlp subgraph into a _gated_mlp op. 
status_t fuse_gated_mlp(std::shared_ptr &sg); diff --git a/src/graph/backend/dnnl/patterns/sdp.cpp b/src/graph/backend/dnnl/patterns/sdp.cpp index 4e0a9a5cce2..531dd0f1604 100644 --- a/src/graph/backend/dnnl/patterns/sdp.cpp +++ b/src/graph/backend/dnnl/patterns/sdp.cpp @@ -18,6 +18,7 @@ #include "graph/backend/dnnl/kernels/large_partition.hpp" #include "graph/backend/dnnl/kernels/matmul.hpp" #include "graph/backend/dnnl/kernels/mqa.hpp" +#include "graph/backend/dnnl/kernels/sdp_bwd.hpp" #include "graph/backend/dnnl/patterns/fusions.hpp" #include "graph/backend/dnnl/patterns/pattern_matcher_pass.hpp" @@ -197,16 +198,26 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_backward_fusion) graph::op_kind::MatMul, {in_edge(0, exp, 0)}); auto matmul_v_do = pgraph->append_op(graph::op_kind::MatMul); + // Decompose softmax_bwd: dS = P * (dP - ReduceSum(O * dO)) + // where P = exp, dP = matmul_v_do output + auto o_do = pgraph->append_op(graph::op_kind::Multiply); + auto correction = pgraph->append_op( + graph::op_kind::ReduceSum, {in_edge(0, o_do, 0)}); + auto dp_corrected + = pgraph->append_op(graph::op_kind::Subtract, + {in_edge(0, matmul_v_do, 0), + in_edge(1, correction, 0)}); auto softmax_bwd = pgraph->append_op( - graph::op_kind::SoftMaxBackward, - {in_edge(0, matmul_v_do, 0), in_edge(1, exp, 0)}); + graph::op_kind::Multiply, + {in_edge(0, exp, 0), in_edge(1, dp_corrected, 0)}); auto ds = pgraph->append_alternation( {graph::op_kind::Multiply, graph::op_kind::Divide}, {in_edge(0, softmax_bwd, 0)}); - auto matmul_dq = pgraph->append_op( - graph::op_kind::MatMul, {in_edge(0, ds, 0)}); auto matmul_dk = pgraph->append_op( graph::op_kind::MatMul, {in_edge(0, ds, 0)}); + auto matmul_dq = pgraph->append_op( + graph::op_kind::MatMul, {in_edge(0, ds, 0)}); + // Q is a shared input for matmul_qk and matmul_dk pgraph->create_input_port(0, matmul_qk, 0); pgraph->create_input_port(0, matmul_dk, 1); @@ -216,6 +227,8 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, 
float_sdp_backward_fusion) // dO is a shared input for matmul_dv and matmul_v_do pgraph->create_input_port(2, matmul_dv, 1); pgraph->create_input_port(2, matmul_v_do, 0); + // dO is also shared input for o_do (O * dO correction) + pgraph->create_input_port(2, o_do, 1); }) .set_attr("FCreatePattern", [](const std::shared_ptr &pgraph) -> void { @@ -233,18 +246,28 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_backward_fusion) graph::op_kind::MatMul, {in_edge(0, tc, 0)}); auto matmul_v_do = pgraph->append_op(graph::op_kind::MatMul); + // Decompose softmax_bwd: dS = P * (dP - ReduceSum(O * dO)) + // where P = exp, dP = matmul_v_do output + auto o_do = pgraph->append_op(graph::op_kind::Multiply); + auto correction = pgraph->append_op( + graph::op_kind::ReduceSum, {in_edge(0, o_do, 0)}); + auto dp_corrected + = pgraph->append_op(graph::op_kind::Subtract, + {in_edge(0, matmul_v_do, 0), + in_edge(1, correction, 0)}); auto softmax_bwd = pgraph->append_op( - graph::op_kind::SoftMaxBackward, - {in_edge(0, matmul_v_do, 0), in_edge(1, exp, 0)}); + graph::op_kind::Multiply, + {in_edge(0, exp, 0), in_edge(1, dp_corrected, 0)}); auto ds = pgraph->append_alternation( {graph::op_kind::Multiply, graph::op_kind::Divide}, {in_edge(0, softmax_bwd, 0)}); auto tc2 = pgraph->append_op( graph::op_kind::TypeCast, {in_edge(0, ds, 0)}); - auto matmul_dq = pgraph->append_op( - graph::op_kind::MatMul, {in_edge(0, tc2, 0)}); + auto matmul_dk = pgraph->append_op( graph::op_kind::MatMul, {in_edge(0, tc2, 0)}); + auto matmul_dq = pgraph->append_op( + graph::op_kind::MatMul, {in_edge(0, tc2, 0)}); // Q is a shared input for matmul_qk and matmul_dk pgraph->create_input_port(0, matmul_qk, 0); pgraph->create_input_port(0, matmul_dk, 1); @@ -254,6 +277,8 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_backward_fusion) // dO is a shared input for matmul_dv and matmul_v_do pgraph->create_input_port(2, matmul_dv, 1); pgraph->create_input_port(2, matmul_v_do, 0); + // dO is also 
shared input for o_do (O * dO correction) + pgraph->create_input_port(2, o_do, 1); }) .set_attr("FCreatePattern", [](const std::shared_ptr &pgraph) -> void { @@ -275,9 +300,18 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_backward_fusion) = pgraph->append_op(graph::op_kind::MatMul); auto dropout2 = pgraph->append_op(graph::op_kind::Dropout, {in_edge(0, matmul_v_do, 0)}); + // Decompose softmax_bwd: dS = P * (dP - ReduceSum(O * dO)) + // where P = exp, dP = dropout2 output + auto o_do = pgraph->append_op(graph::op_kind::Multiply); + auto correction = pgraph->append_op( + graph::op_kind::ReduceSum, {in_edge(0, o_do, 0)}); + auto dp_corrected + = pgraph->append_op(graph::op_kind::Subtract, + {in_edge(0, dropout2, 0), + in_edge(1, correction, 0)}); auto softmax_bwd = pgraph->append_op( - graph::op_kind::SoftMaxBackward, - {in_edge(0, dropout2, 0), in_edge(1, exp, 0)}); + graph::op_kind::Multiply, + {in_edge(0, exp, 0), in_edge(1, dp_corrected, 0)}); auto ds = pgraph->append_alternation( {graph::op_kind::Multiply, graph::op_kind::Divide}, {in_edge(0, softmax_bwd, 0)}); @@ -294,6 +328,8 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_backward_fusion) // dO is a shared input for matmul_dv and matmul_v_do pgraph->create_input_port(2, matmul_dv, 1); pgraph->create_input_port(2, matmul_v_do, 0); + // dO is also shared input for o_do (O * dO correction) + pgraph->create_input_port(2, o_do, 1); }) .set_attr("FCreatePattern", [](const std::shared_ptr &pgraph) -> void { @@ -316,9 +352,18 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_backward_fusion) = pgraph->append_op(graph::op_kind::MatMul); auto dropout2 = pgraph->append_op(graph::op_kind::Dropout, {in_edge(0, matmul_v_do, 0)}); + // Decompose softmax_bwd: dS = P * (dP - ReduceSum(O * dO)) + // where P = exp, dP = dropout2 output + auto o_do = pgraph->append_op(graph::op_kind::Multiply); + auto correction = pgraph->append_op( + graph::op_kind::ReduceSum, {in_edge(0, o_do, 0)}); + 
auto dp_corrected + = pgraph->append_op(graph::op_kind::Subtract, + {in_edge(0, dropout2, 0), + in_edge(1, correction, 0)}); auto softmax_bwd = pgraph->append_op( - graph::op_kind::SoftMaxBackward, - {in_edge(0, dropout2, 0), in_edge(1, exp, 0)}); + graph::op_kind::Multiply, + {in_edge(0, exp, 0), in_edge(1, dp_corrected, 0)}); auto ds = pgraph->append_alternation( {graph::op_kind::Multiply, graph::op_kind::Divide}, {in_edge(0, softmax_bwd, 0)}); @@ -337,6 +382,8 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_backward_fusion) // dO is a shared input for matmul_dv and matmul_v_do pgraph->create_input_port(2, matmul_dv, 1); pgraph->create_input_port(2, matmul_v_do, 0); + // dO is also shared input for o_do (O * dO correction) + pgraph->create_input_port(2, o_do, 1); }) .set_attr("FCreatePattern", [](const std::shared_ptr &pgraph) -> void { @@ -354,14 +401,18 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_backward_fusion) graph::op_kind::MatMul, {in_edge(0, exp, 0)}); auto matmul_v_do = pgraph->append_op(graph::op_kind::MatMul); + // Decompose softmax_bwd: dS = P * (dP - ReduceSum(O * dO)) + // where P = exp, dP = matmul_v_do output + auto o_do = pgraph->append_op(graph::op_kind::Multiply); + auto correction = pgraph->append_op( + graph::op_kind::ReduceSum, {in_edge(0, o_do, 0)}); + auto dp_corrected + = pgraph->append_op(graph::op_kind::Subtract, + {in_edge(0, matmul_v_do, 0), + in_edge(1, correction, 0)}); auto softmax_bwd = pgraph->append_op( - graph::op_kind::SoftMaxBackward, - {in_edge(0, matmul_v_do, 0), in_edge(1, exp, 0)}); - - // End op to mark dMask as output - pgraph->append_op( - graph::op_kind::End, {in_edge(0, softmax_bwd, 0)}); - + graph::op_kind::Multiply, + {in_edge(0, exp, 0), in_edge(1, dp_corrected, 0)}); auto ds = pgraph->append_alternation( {graph::op_kind::Multiply, graph::op_kind::Divide}, {in_edge(0, softmax_bwd, 0)}); @@ -378,6 +429,8 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_backward_fusion) 
// dO is a shared input for matmul_dv and matmul_v_do pgraph->create_input_port(2, matmul_dv, 1); pgraph->create_input_port(2, matmul_v_do, 0); + // dO is also shared input for o_do (O * dO correction) + pgraph->create_input_port(2, o_do, 1); }) .set_attr("FCreatePattern", [](const std::shared_ptr &pgraph) -> void { @@ -395,9 +448,18 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_backward_fusion) graph::op_kind::MatMul, {in_edge(0, tc, 0)}); auto matmul_v_do = pgraph->append_op(graph::op_kind::MatMul); + // dedS = P * (dP - ReduceSum(O * dO)) + // where P = exp, dP = matmul_v_do output + auto o_do = pgraph->append_op(graph::op_kind::Multiply); + auto correction = pgraph->append_op( + graph::op_kind::ReduceSum, {in_edge(0, o_do, 0)}); + auto dp_corrected + = pgraph->append_op(graph::op_kind::Subtract, + {in_edge(0, matmul_v_do, 0), + in_edge(1, correction, 0)}); auto softmax_bwd = pgraph->append_op( - graph::op_kind::SoftMaxBackward, - {in_edge(0, matmul_v_do, 0), in_edge(1, exp, 0)}); + graph::op_kind::Multiply, + {in_edge(0, exp, 0), in_edge(1, dp_corrected, 0)}); // End op to mark dMask as output pgraph->append_op( @@ -421,9 +483,11 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_backward_fusion) // dO is a shared input for matmul_dv and matmul_v_do pgraph->create_input_port(2, matmul_dv, 1); pgraph->create_input_port(2, matmul_v_do, 0); + // dO is also shared input for o_do (O * dO correction) + pgraph->create_input_port(2, o_do, 1); }) .set_attr("FCreateKernel", []() -> kernel_ptr { - return std::make_shared(); + return std::make_shared(); }); DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_gqa_fusion) @@ -492,16 +556,25 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_gqa_backward_fusion) {in_edge(0, matmul_dv, 0)}); auto matmul_v_do = pgraph->append_op(graph::op_kind::MatMul); + // Decompose softmax_bwd: dS = P * (dP - ReduceSum(O * dO)) + // where P = exp, dP = matmul_v_do output + auto o_do = 
pgraph->append_op(graph::op_kind::Multiply); + auto correction = pgraph->append_op( + graph::op_kind::ReduceSum, {in_edge(0, o_do, 0)}); + auto dp_corrected + = pgraph->append_op(graph::op_kind::Subtract, + {in_edge(0, matmul_v_do, 0), + in_edge(1, correction, 0)}); auto softmax_bwd = pgraph->append_op( - graph::op_kind::SoftMaxBackward, - {in_edge(0, matmul_v_do, 0), in_edge(1, exp, 0)}); + graph::op_kind::Multiply, + {in_edge(0, exp, 0), in_edge(1, dp_corrected, 0)}); auto ds = pgraph->append_alternation( {graph::op_kind::Multiply, graph::op_kind::Divide}, {in_edge(0, softmax_bwd, 0)}); - auto matmul_dq = pgraph->append_op( - graph::op_kind::MatMul, {in_edge(0, ds, 0)}); auto matmul_dk = pgraph->append_op( graph::op_kind::MatMul, {in_edge(0, ds, 0)}); + auto matmul_dq = pgraph->append_op( + graph::op_kind::MatMul, {in_edge(0, ds, 0)}); // reduction_dk pgraph->append_op(graph::op_kind::ReduceSum, {in_edge(0, matmul_dk, 0)}); @@ -514,6 +587,8 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_gqa_backward_fusion) // dO is a shared input for matmul_dv and matmul_v_do pgraph->create_input_port(2, matmul_dv, 1); pgraph->create_input_port(2, matmul_v_do, 0); + // dO is also shared input for o_do (O * dO correction) + pgraph->create_input_port(2, o_do, 1); }) .set_attr("FCreatePattern", [](const std::shared_ptr &pgraph) -> void { @@ -534,18 +609,28 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_gqa_backward_fusion) {in_edge(0, matmul_dv, 0)}); auto matmul_v_do = pgraph->append_op(graph::op_kind::MatMul); + // dS = P * (dP - ReduceSum(O * dO)) + // where P = exp, dP = matmul_v_do output + auto o_do = pgraph->append_op(graph::op_kind::Multiply); + auto correction = pgraph->append_op( + graph::op_kind::ReduceSum, {in_edge(0, o_do, 0)}); + auto dp_corrected + = pgraph->append_op(graph::op_kind::Subtract, + {in_edge(0, matmul_v_do, 0), + in_edge(1, correction, 0)}); auto softmax_bwd = pgraph->append_op( - graph::op_kind::SoftMaxBackward, - {in_edge(0, 
matmul_v_do, 0), in_edge(1, exp, 0)}); + graph::op_kind::Multiply, + {in_edge(0, exp, 0), in_edge(1, dp_corrected, 0)}); auto ds = pgraph->append_alternation( {graph::op_kind::Multiply, graph::op_kind::Divide}, {in_edge(0, softmax_bwd, 0)}); auto tc2 = pgraph->append_op( graph::op_kind::TypeCast, {in_edge(0, ds, 0)}); - auto matmul_dq = pgraph->append_op( - graph::op_kind::MatMul, {in_edge(0, tc2, 0)}); auto matmul_dk = pgraph->append_op( graph::op_kind::MatMul, {in_edge(0, tc2, 0)}); + auto matmul_dq = pgraph->append_op( + graph::op_kind::MatMul, {in_edge(0, tc2, 0)}); + // reduction_dk pgraph->append_op(graph::op_kind::ReduceSum, {in_edge(0, matmul_dk, 0)}); @@ -558,6 +643,8 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_gqa_backward_fusion) // dO is a shared input for matmul_dv and matmul_v_do pgraph->create_input_port(2, matmul_dv, 1); pgraph->create_input_port(2, matmul_v_do, 0); + // dO is also shared input for o_do (O * dO correction) + pgraph->create_input_port(2, o_do, 1); }) .set_attr("FCreatePattern", [](const std::shared_ptr &pgraph) -> void { @@ -582,9 +669,18 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_gqa_backward_fusion) = pgraph->append_op(graph::op_kind::MatMul); auto dropout2 = pgraph->append_op(graph::op_kind::Dropout, {in_edge(0, matmul_v_do, 0)}); + // dS = P * (dP - ReduceSum(O * dO)) + // where P = exp, dP = dropout2 output + auto o_do = pgraph->append_op(graph::op_kind::Multiply); + auto correction = pgraph->append_op( + graph::op_kind::ReduceSum, {in_edge(0, o_do, 0)}); + auto dp_corrected + = pgraph->append_op(graph::op_kind::Subtract, + {in_edge(0, dropout2, 0), + in_edge(1, correction, 0)}); auto softmax_bwd = pgraph->append_op( - graph::op_kind::SoftMaxBackward, - {in_edge(0, dropout2, 0), in_edge(1, exp, 0)}); + graph::op_kind::Multiply, + {in_edge(0, exp, 0), in_edge(1, dp_corrected, 0)}); auto ds = pgraph->append_alternation( {graph::op_kind::Multiply, graph::op_kind::Divide}, {in_edge(0, softmax_bwd, 0)}); 
@@ -604,6 +700,8 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_gqa_backward_fusion) // dO is a shared input for matmul_dv and matmul_v_do pgraph->create_input_port(2, matmul_dv, 1); pgraph->create_input_port(2, matmul_v_do, 0); + // dO is also shared input for o_do (O * dO correction) + pgraph->create_input_port(2, o_do, 1); }) .set_attr("FCreatePattern", [](const std::shared_ptr &pgraph) @@ -629,9 +727,18 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_gqa_backward_fusion) = pgraph->append_op(graph::op_kind::MatMul); auto dropout2 = pgraph->append_op(graph::op_kind::Dropout, {in_edge(0, matmul_v_do, 0)}); + // dS = P * (dP - ReduceSum(O * dO)) + // where P = exp, dP = dropout2 output + auto o_do = pgraph->append_op(graph::op_kind::Multiply); + auto correction = pgraph->append_op( + graph::op_kind::ReduceSum, {in_edge(0, o_do, 0)}); + auto dp_corrected + = pgraph->append_op(graph::op_kind::Subtract, + {in_edge(0, dropout2, 0), + in_edge(1, correction, 0)}); auto softmax_bwd = pgraph->append_op( - graph::op_kind::SoftMaxBackward, - {in_edge(0, dropout2, 0), in_edge(1, exp, 0)}); + graph::op_kind::Multiply, + {in_edge(0, exp, 0), in_edge(1, dp_corrected, 0)}); auto ds = pgraph->append_alternation( {graph::op_kind::Multiply, graph::op_kind::Divide}, {in_edge(0, softmax_bwd, 0)}); @@ -653,6 +760,8 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_gqa_backward_fusion) // dO is a shared input for matmul_dv and matmul_v_do pgraph->create_input_port(2, matmul_dv, 1); pgraph->create_input_port(2, matmul_v_do, 0); + // dO is also shared input for o_do (O * dO correction) + pgraph->create_input_port(2, o_do, 1); }) .set_attr("FCreatePattern", [](const std::shared_ptr &pgraph) -> void { @@ -672,9 +781,19 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_gqa_backward_fusion) {in_edge(0, matmul_dv, 0)}); auto matmul_v_do = pgraph->append_op(graph::op_kind::MatMul); + // dS = P * (dP - ReduceSum(O * dO)) + // where P = exp, dP = matmul_v_do 
output + auto o_do = pgraph->append_op(graph::op_kind::Multiply); + auto correction = pgraph->append_op( + graph::op_kind::ReduceSum, {in_edge(0, o_do, 0)}); + auto dp_corrected + = pgraph->append_op(graph::op_kind::Subtract, + {in_edge(0, matmul_v_do, 0), + in_edge(1, correction, 0)}); + auto softmax_bwd = pgraph->append_op( - graph::op_kind::SoftMaxBackward, - {in_edge(0, matmul_v_do, 0), in_edge(1, exp, 0)}); + graph::op_kind::Multiply, + {in_edge(0, exp, 0), in_edge(1, dp_corrected, 0)}); // End op to mark dMask as output pgraph->append_op( @@ -699,6 +818,8 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_gqa_backward_fusion) // dO is a shared input for matmul_dv and matmul_v_do pgraph->create_input_port(2, matmul_dv, 1); pgraph->create_input_port(2, matmul_v_do, 0); + // dO is also shared input for o_do (O * dO correction) + pgraph->create_input_port(2, o_do, 1); }) .set_attr("FCreatePattern", [](const std::shared_ptr &pgraph) -> void { @@ -719,9 +840,18 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_gqa_backward_fusion) {in_edge(0, matmul_dv, 0)}); auto matmul_v_do = pgraph->append_op(graph::op_kind::MatMul); + // dS = P * (dP - ReduceSum(O * dO)) + // where P = exp, dP = matmul_v_do output + auto o_do = pgraph->append_op(graph::op_kind::Multiply); + auto correction = pgraph->append_op( + graph::op_kind::ReduceSum, {in_edge(0, o_do, 0)}); + auto dp_corrected + = pgraph->append_op(graph::op_kind::Subtract, + {in_edge(0, matmul_v_do, 0), + in_edge(1, correction, 0)}); auto softmax_bwd = pgraph->append_op( - graph::op_kind::SoftMaxBackward, - {in_edge(0, matmul_v_do, 0), in_edge(1, exp, 0)}); + graph::op_kind::Multiply, + {in_edge(0, exp, 0), in_edge(1, dp_corrected, 0)}); // End op to mark dMask as output pgraph->append_op( @@ -748,9 +878,11 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_gqa_backward_fusion) // dO is a shared input for matmul_dv and matmul_v_do pgraph->create_input_port(2, matmul_dv, 1); 
pgraph->create_input_port(2, matmul_v_do, 0); + // dO is also shared input for o_do (O * dO correction) + pgraph->create_input_port(2, o_do, 1); }) .set_attr("FCreateKernel", []() -> kernel_ptr { - return std::make_shared(); + return std::make_shared(); }); DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_jax_fusion) diff --git a/src/graph/interface/c_types_map.hpp b/src/graph/interface/c_types_map.hpp index e39eb9fc8bc..17ca062f85f 100644 --- a/src/graph/interface/c_types_map.hpp +++ b/src/graph/interface/c_types_map.hpp @@ -270,6 +270,7 @@ const op_kind_t _host_scalar = 1070; const op_kind_t _identity = 1071; const op_kind_t _dropout = 1072; const op_kind_t _gated_mlp = 1073; +const op_kind_t _sdpa_bwd = 1074; } // namespace op_kind using op_attr_t = typename std::underlying_type::type; diff --git a/src/graph/interface/op.hpp b/src/graph/interface/op.hpp index 59c9df35945..ab61153ac7d 100644 --- a/src/graph/interface/op.hpp +++ b/src/graph/interface/op.hpp @@ -546,6 +546,7 @@ struct dnnl_graph_op : public std::enable_shared_from_this { CASE(_identity); CASE(_dropout); CASE(_gated_mlp); + CASE(_sdpa_bwd); default: return "undefined_op"; } #undef CASE diff --git a/src/graph/interface/op_def.hpp b/src/graph/interface/op_def.hpp index b4dbf3597d0..0737098d49c 100644 --- a/src/graph/interface/op_def.hpp +++ b/src/graph/interface/op_def.hpp @@ -2397,6 +2397,38 @@ DNNL_GRAPH_OP_SCHEMA(_gated_mlp, 1, .set_attr(op_attr::alg_kind, true, attribute_kind::i) .set_shape_inference_function(infer_gated_mlp_output_shape)) +// Backward op for SDPA +DNNL_GRAPH_OP_SCHEMA(_sdpa_bwd, 1, + op_schema_t() + .set_inputs_option(op_schema_t::param_num_option::variadic) + .set_outputs_option(op_schema_t::param_num_option::optional) + // Inputs: query, key, value, dst, diff_dst, [dS], [scale], [mask] + .set_num_inputs(std::set({5, 32})) + .set_num_outputs(std::set({4, 5})) + .set_input(0, "query") + .set_input(1, "key") + .set_input(2, "value") + .set_input(3, "dst") + .set_input(4, 
"stats") + .set_input(5, "diff_dst") + .set_input(6, "scale") // optional + .set_input(7, "mask") // optional + // Outputs: diff_query, diff_key, diff_value, scratchpad, diff_mask + .set_output(0, "diff_query") + .set_output(1, "diff_key") + .set_output(2, "diff_value") + .set_output(3, "scratchpad") + .set_output(4, "diff_mask") // optional + .set_attr(op_attr::fusion_info, false, + attribute_kind::fusion_info) + .set_attr(op_attr::with_scale, true, attribute_kind::b) + .set_attr(op_attr::is_invert_scale, false, attribute_kind::b, + false) + .set_attr(op_attr::mask_type, true, attribute_kind::i) + .set_attr(op_attr::qk_acc_mode, true, attribute_kind::s) + .set_attr(op_attr::vs_acc_mode, true, attribute_kind::s) + .set_shape_inference_function(infer_dnnl_sdpa_bwd_output_shape)) + } // namespace graph } // namespace impl } // namespace dnnl diff --git a/src/graph/interface/opset.hpp b/src/graph/interface/opset.hpp index 6fea9a203ac..47b348c9186 100644 --- a/src/graph/interface/opset.hpp +++ b/src/graph/interface/opset.hpp @@ -197,6 +197,7 @@ class opset_v1_t { fn(get_op_schema()); fn(get_op_schema()); fn(get_op_schema()); + fn(get_op_schema()); } }; diff --git a/src/graph/interface/shape_infer.cpp b/src/graph/interface/shape_infer.cpp index 85415b845e5..27af66d44c7 100644 --- a/src/graph/interface/shape_infer.cpp +++ b/src/graph/interface/shape_infer.cpp @@ -2501,6 +2501,61 @@ status_t infer_dnnl_softmax_output_shape(op_t *n, return status::success; } +status_t infer_dnnl_sdpa_bwd_output_shape(op_t *n, + std::vector &inputs, + std::vector &outputs) { + // [batch_size, num_heads_q, seq_len_q, head_size_qk] + auto query = ltw(inputs[0]); + // [batch_size, num_heads_q, head_size_qk, seq_len_kv,] + auto key = ltw(inputs[1]); + // [batch_size, num_heads_v, seq_len_kv, head_size_v] + auto value = ltw(inputs[2]); + + auto dquery = ltw(outputs[0]); + auto dkey = ltw(outputs[1]); + auto dvalue = ltw(outputs[2]); + + if (dquery.ndims() != -1) { + 
VCHECK_INVALID_SHAPE(validate(dquery.vdims(), query.vdims()), + "%s, inferred out shape and output shape are not compatible", + op_t::kind2str(n->get_kind()).c_str()); + } + set_shape_and_strides(*outputs[0], query.vdims()); + + if (dkey.ndims() != -1) { + VCHECK_INVALID_SHAPE(validate(dkey.vdims(), key.vdims()), + "%s, inferred out shape and output shape are not compatible", + op_t::kind2str(n->get_kind()).c_str()); + } + set_shape_and_strides(*outputs[1], key.vdims()); + + if (dvalue.ndims() != -1) { + VCHECK_INVALID_SHAPE(validate(dvalue.vdims(), value.vdims()), + "%s, inferred out shape and output shape are not compatible", + op_t::kind2str(n->get_kind()).c_str()); + } + set_shape_and_strides(*outputs[2], value.vdims()); + + if (outputs.size() > 4) { + // dmask exists + auto dmask = ltw(outputs[4]); + dims inferred_dmask_shape = query.vdims(); + size_t ndims = query.ndims(); + // [batch_size, num_heads_q, seq_len_q, seq_len_kv] + inferred_dmask_shape[ndims - 1] = value.vdims()[ndims - 1]; + + if (dmask.ndims() != -1) { + VCHECK_INVALID_SHAPE(validate(inferred_dmask_shape, dmask.vdims()), + "%s, given dmask shape is not compatible with inferred", + op_t::kind2str(n->get_kind()).c_str()); + } + + set_shape_and_strides(*outputs[4], inferred_dmask_shape); + } + + return status::success; +} + } // namespace graph } // namespace impl } // namespace dnnl diff --git a/src/graph/interface/shape_infer.hpp b/src/graph/interface/shape_infer.hpp index 97cf18be80d..480c06736b3 100644 --- a/src/graph/interface/shape_infer.hpp +++ b/src/graph/interface/shape_infer.hpp @@ -303,6 +303,10 @@ status_t infer_dnnl_sdpa_output_shape(op_t *n, std::vector &inputs, std::vector &outputs); +status_t infer_dnnl_sdpa_bwd_output_shape(op_t *n, + std::vector &inputs, + std::vector &outputs); + status_t infer_dnnl_host_scalar_output_shape(op_t *n, std::vector &inputs, std::vector &outputs); From 4798b551570ec2c9caab46b397b537be069af653 Mon Sep 17 00:00:00 2001 From: "Bao, Yixin" Date: Fri, 13 
Mar 2026 04:59:00 +0000 Subject: [PATCH 21/23] tests, examples: update sdpa training bwd cases --- examples/graph/gqa_training.cpp | 37 +- tests/benchdnn/graph/deserialize.cpp | 32 +- .../graph/complex_fusion/harness_mha_all | 4 +- .../graph/complex_fusion/harness_mha_ci | 5 +- .../gqa-plain-training-backward-bf16-f32.json | 1015 ++++++++++------- .../mha/gqa-plain-training-backward-f32.json | 533 ++++++--- ...sdpa-plain-training-backward-bf16-f32.json | 489 +++++--- .../mha/sdpa-plain-training-backward-f32.json | 373 ++++-- 8 files changed, 1705 insertions(+), 783 deletions(-) diff --git a/examples/graph/gqa_training.cpp b/examples/graph/gqa_training.cpp index fb7b76dc570..e687924f554 100644 --- a/examples/graph/gqa_training.cpp +++ b/examples/graph/gqa_training.cpp @@ -317,13 +317,33 @@ bool bench_gqa_backward(engine::kind ekind, logical_tensor::data_type dt, bmm_do_v.add_inputs({doutput, value}); bmm_do_v.add_outputs({dprobs}); - // compute dmasked_score = dsoftmax(dprobs) + // compute dmasked_score = P * (dprobs - ReduceSum(O * dO)) + // decomposed softmax backward: dS = P * (dP - rowsum(O * dO)) + auto o_do_out + = logical_tensor(id++, dt_inter, output_sz, layout_type::strided); + auto o_do_mul = op(id++, op::kind::Multiply, "mul_o_do"); + o_do_mul.add_inputs({output, doutput}); + o_do_mul.add_outputs({o_do_out}); + + auto correction_out + = logical_tensor(id++, dt_inter, stats_sz, layout_type::strided); + auto correction = op(id++, op::kind::ReduceSum, "reducesum_correction"); + correction.set_attr>(op::attr::axes, {4}); + correction.set_attr(op::attr::keep_dims, true); + correction.add_inputs({o_do_out}); + correction.add_outputs({correction_out}); + + auto dp_corrected_out + = logical_tensor(id++, dt_inter, score_sz, layout_type::strided); + auto dp_corrected_op = op(id++, op::kind::Subtract, "sub_dp_corrected"); + dp_corrected_op.add_inputs({dprobs, correction_out}); + dp_corrected_op.add_outputs({dp_corrected_out}); + auto dmasked_score = 
logical_tensor(id++, dt_inter, score_sz, layout_type::strided); - auto softmax_grad = op(id++, op::kind::SoftMaxBackward, "softmax_bwd"); - softmax_grad.set_attr(op::attr::axis, -1); - softmax_grad.add_inputs({dprobs, probs}); - softmax_grad.add_outputs({dmasked_score}); + auto softmax_bwd_mul = op(id++, op::kind::Multiply, "mul_softmax_bwd"); + softmax_bwd_mul.add_inputs({probs, dp_corrected_out}); + softmax_bwd_mul.add_outputs({dmasked_score}); // compute dscored_score = dmasked_score / scale auto dscaled_score @@ -372,10 +392,13 @@ bool bench_gqa_backward(engine::kind ekind, logical_tensor::data_type dt, gqa_bwd.add_op(exp); gqa_bwd.add_op(bmm_p_do); gqa_bwd.add_op(bmm_do_v); - gqa_bwd.add_op(softmax_grad); + gqa_bwd.add_op(o_do_mul); + gqa_bwd.add_op(correction); + gqa_bwd.add_op(dp_corrected_op); + gqa_bwd.add_op(softmax_bwd_mul); gqa_bwd.add_op(scale_div2); - gqa_bwd.add_op(bmm_dscaled_score_k); gqa_bwd.add_op(bmm_dscaled_score_q); + gqa_bwd.add_op(bmm_dscaled_score_k); gqa_bwd.add_op(reduce_dv); gqa_bwd.add_op(reduce_dk); if (dt != dt_inter) { diff --git a/tests/benchdnn/graph/deserialize.cpp b/tests/benchdnn/graph/deserialize.cpp index 92a615ffac4..fcaafc7d7b4 100644 --- a/tests/benchdnn/graph/deserialize.cpp +++ b/tests/benchdnn/graph/deserialize.cpp @@ -806,17 +806,37 @@ bool deserialized_graph_t::detect_sdpa_bwd_impl() const { continue; } - // find SoftMaxBackward and MatMul for dV + // find softmax bwd decomposed chain and MatMul for dV + // The decomposed softmax bwd is: Multiply(O,dO)->ReduceSum->Subtract->Multiply(P,dp_corr) + // Exp's children: one branch is the final Multiply (P*dp_corr), the other leads to dV MatMul auto cur_op_refs = get_child_ops(cur_op_ref); if (cur_op_refs.size() != 2) continue; + + // Identify which child is the final Multiply of the decomposed softmax bwd + // by verifying the full chain: its input comes from Subtract <- ReduceSum <- Multiply + const auto verify_softmax_bwd_chain + = [&](const deserialized_op_t &mul_op) -> 
bool { + if (mul_op.kind_ != "Multiply") return false; + for (const auto &mul_in : mul_op.in_lts_) { + const auto &sub_op = get_op_by_out_lt(mul_in.id_); + if (sub_op.kind_ != "Subtract") continue; + const auto &rs_op = get_op_by_out_lt(sub_op.in_lts_[1].id_); + if (rs_op.kind_ != "ReduceSum") continue; + const auto &odo_op = get_op_by_out_lt(rs_op.in_lts_[0].id_); + if (odo_op.kind_ == "Multiply") return true; + } + return false; + }; + size_t softmax_bwd_idx; - if (cur_op_refs[0].kind_ == "SoftMaxBackward") { + if (verify_softmax_bwd_chain(cur_op_refs[0])) { softmax_bwd_idx = 0; - } else if (cur_op_refs[1].kind_ == "SoftMaxBackward") { + } else if (verify_softmax_bwd_chain(cur_op_refs[1])) { softmax_bwd_idx = 1; } else { BENCHDNN_PRINT(8, "%s\n", - "[DETECT_SDPA_BWD]: failed due to no SoftMaxBackward"); + "[DETECT_SDPA_BWD]: failed due to no decomposed softmax " + "bwd chain (Multiply->ReduceSum->Subtract->Multiply)"); continue; } // find MatMul for dV @@ -848,8 +868,8 @@ bool deserialized_graph_t::detect_sdpa_bwd_impl() const { // if we find a path that contains: // ->MatMul->[dV] // MatMul->Subtract->Exp - // ->SoftMaxBackward->MatMul->[dQ / dK] - // ->End (optional) + // ->Multiply->MatMul->[dQ / dK] + // Mul->ReduceSum->Sub ->End (optional) // It will be considered as a SDPA bwd implementation. 
return true; } diff --git a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all index eb1fab992ac..fed646237b0 100644 --- a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all +++ b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all @@ -25,9 +25,9 @@ --reset --case=complex_fusion/mha/gqa-plain-bottom-right-implicit-causal-mask-f16-f32.json --reset --case=complex_fusion/mha/sdpa-plain-bottom-right-implicit-causal-mask-f16-f32.json --reset --dt=0:f16+1:f16+4:f16+7:f16+10:f16+13:f16+14:f16 --case=complex_fusion/mha/sdpa-plain-training-forward-bf16-f32.json ---reset --dt=16:f16+17:f16+32:f16+33:f16+34:f16+36:f16+44:f16+45:f16+47:f16 --case=complex_fusion/mha/sdpa-plain-training-backward-bf16-f32.json +--reset --dt=100:f16+101:f16+102:f16+103:f16+12:f16+105:f16+13:f16+104:f16+10:f16+28:f16+29:f16+31:f16 --case=complex_fusion/mha/sdpa-plain-training-backward-bf16-f32.json --reset --dt=0:f16+1:f16+4:f16+10:f16+13:f16+14:f16 --case=complex_fusion/mha/gqa-plain-training-forward-bf16-f32.json ---reset --dt=16:f16+17:f16+20:f16+32:f16+33:f16+34:f16+36:f16+38:f16+46:f16+47:f16+49:f16+51:f16 --case=complex_fusion/mha/gqa-plain-training-backward-bf16-f32.json +--reset --dt=100:f16+101:f16+102:f16+103:f16+12:f16+105:f16+13:f16+15:f16+104:f16+10:f16+30:f16+33:f16+31:f16+35:f16 --case=complex_fusion/mha/gqa-plain-training-backward-bf16-f32.json --reset --dt=3:f16+4:f16+2:f16+1:f16+11:f16+0:f16+12:f16+14:f16+16:f16 --case=complex_fusion/mha/MHA-bert_large-inf-fp32-bs1.json --reset --dt=0:f16+1:f16+3:f16+7:f16+2:f16+8:f16 --case=complex_fusion/mha/MHA-stable_diffusion-inf-fp32-bs1.json --reset --dt=0:f16+1:f16+7:f16+9:f16+10:f16 --case=complex_fusion/mha/sdpa-plain-wo-mask-f32.json diff --git a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci index ed8e03af8e4..50ed39e941e 100644 --- 
a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci +++ b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci @@ -26,10 +26,9 @@ --reset --case=complex_fusion/mha/gqa-plain-bottom-right-implicit-causal-mask-f16-f32.json --reset --case=complex_fusion/mha/sdpa-plain-bottom-right-implicit-causal-mask-f16-f32.json --reset --dt=0:f16+1:f16+4:f16+7:f16+10:f16+13:f16+14:f16 --case=complex_fusion/mha/sdpa-plain-training-forward-bf16-f32.json ---reset --dt=16:f16+17:f16+32:f16+33:f16+34:f16+36:f16+44:f16+45:f16+47:f16 --case=complex_fusion/mha/sdpa-plain-training-backward-bf16-f32.json ---reset --dt=16:f16+17:f16+32:f16+33:f16+34:f16+36:f16+44:f16+45:f16+47:f16 --case=complex_fusion/mha/sdpa-plain-training-backward-w-dmask-bf16-f32.json +--reset --dt=100:f16+101:f16+102:f16+103:f16+12:f16+105:f16+13:f16+104:f16+10:f16+28:f16+29:f16+31:f16 --case=complex_fusion/mha/sdpa-plain-training-backward-bf16-f32.json --reset --dt=0:f16+1:f16+4:f16+10:f16+13:f16+14:f16 --case=complex_fusion/mha/gqa-plain-training-forward-bf16-f32.json ---reset --dt=16:f16+17:f16+20:f16+32:f16+33:f16+34:f16+36:f16+38:f16+46:f16+47:f16+49:f16+51:f16 --case=complex_fusion/mha/gqa-plain-training-backward-bf16-f32.json +--reset --dt=100:f16+101:f16+102:f16+103:f16+12:f16+105:f16+13:f16+15:f16+104:f16+10:f16+30:f16+33:f16+31:f16+35:f16 --case=complex_fusion/mha/gqa-plain-training-backward-bf16-f32.json --reset --dt=3:f16+4:f16+2:f16+1:f16+11:f16+0:f16+12:f16+14:f16+16:f16 --case=complex_fusion/mha/MHA-bert_large-inf-fp32-bs1.json --reset --dt=0:f16+1:f16+3:f16+7:f16+2:f16+8:f16 --case=complex_fusion/mha/MHA-stable_diffusion-inf-fp32-bs1.json --reset --dt=0:f16+1:f16+7:f16+9:f16+10:f16 --case=complex_fusion/mha/sdpa-plain-wo-mask-f32.json diff --git a/tests/benchdnn/inputs/graph/complex_fusion/mha/gqa-plain-training-backward-bf16-f32.json b/tests/benchdnn/inputs/graph/complex_fusion/mha/gqa-plain-training-backward-bf16-f32.json index 64785eb1676..0e2001fb7ba 100644 --- 
a/tests/benchdnn/inputs/graph/complex_fusion/mha/gqa-plain-training-backward-bf16-f32.json +++ b/tests/benchdnn/inputs/graph/complex_fusion/mha/gqa-plain-training-backward-bf16-f32.json @@ -1,29 +1,31 @@ { - "version": "3.10.0", - "engine_kind": "cpu", + "version": "3.12.0", + "engine_kind": "gpu", "fpmath_mode": "strict", "fpmath_mode_apply_to_int": "false", "input_ports": [ - 16, - 17, - 20, - 23, - 26, - 33, - 33, - 38, - 20, - 17, - 16 + 100, + 101, + 102, + 103, + 8, + 105, + 105, + 104, + 10, + 105, + 102, + 100, + 101 ], "output_ports": [ - 47, - 36, - 51 + 15, + 31, + 35 ], "graph": [ { - "id": 19, + "id": 2, "name": "bmm1", "kind": "MatMul", "attrs": { @@ -31,6 +33,10 @@ "type": "bool", "value": 0 }, + "accumulation_mode": { + "type": "string", + "value": "strict" + }, "transpose_b": { "type": "bool", "value": 1 @@ -38,19 +44,19 @@ }, "inputs": [ { - "id": 16, + "id": 100, "dtype": "bf16", "shape": [ - 1, + 2, 2, 8, - 384, + 128, 64 ], "stride": [ - 393216, - 196608, - 24576, + 131072, + 65536, + 8192, 64, 1 ], @@ -58,19 +64,19 @@ "property_type": "undef" }, { - "id": 17, + "id": 101, "dtype": "bf16", "shape": [ - 1, + 2, 2, 1, - 384, + 128, 64 ], "stride": [ - 49152, - 24576, - 24576, + 16384, + 8192, + 8192, 64, 1 ], @@ -80,20 +86,20 @@ ], "outputs": [ { - "id": 18, + "id": 1, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, - 384, - 384 + 128, + 128 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 262144, + 131072, + 16384, + 128, 1 ], "layout_type": "strided", @@ -102,9 +108,9 @@ ] }, { - "id": 22, - "name": "scale_div", - "kind": "Divide", + "id": 4, + "name": "scale_mul", + "kind": "Multiply", "attrs": { "auto_broadcast": { "type": "string", @@ -113,28 +119,28 @@ }, "inputs": [ { - "id": 18, + "id": 1, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, - 384, - 384 + 128, + 128 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 262144, + 131072, + 16384, + 128, 1 ], "layout_type": "strided", "property_type": "undef" }, { - "id": 20, - "dtype": "f32", + 
"id": 102, + "dtype": "bf16", "shape": [ 1 ], @@ -147,20 +153,20 @@ ], "outputs": [ { - "id": 21, + "id": 3, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, - 384, - 384 + 128, + 128 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 262144, + 131072, + 16384, + 128, 1 ], "layout_type": "strided", @@ -169,7 +175,7 @@ ] }, { - "id": 25, + "id": 6, "name": "mask_add", "kind": "Add", "attrs": { @@ -180,40 +186,40 @@ }, "inputs": [ { - "id": 21, + "id": 3, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, - 384, - 384 + 128, + 128 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 262144, + 131072, + 16384, + 128, 1 ], "layout_type": "strided", "property_type": "undef" }, { - "id": 23, - "dtype": "f32", + "id": 103, + "dtype": "bf16", "shape": [ - 1, + 2, 2, 8, - 384, - 384 + 128, + 128 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 262144, + 131072, + 16384, + 128, 1 ], "layout_type": "strided", @@ -222,20 +228,20 @@ ], "outputs": [ { - "id": 24, + "id": 5, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, - 384, - 384 + 128, + 128 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 262144, + 131072, + 16384, + 128, 1 ], "layout_type": "strided", @@ -244,7 +250,7 @@ ] }, { - "id": 28, + "id": 8, "name": "subtract", "kind": "Subtract", "attrs": { @@ -255,39 +261,39 @@ }, "inputs": [ { - "id": 24, + "id": 5, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, - 384, - 384 + 128, + 128 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 262144, + 131072, + 16384, + 128, 1 ], "layout_type": "strided", "property_type": "undef" }, { - "id": 26, + "id": 8, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, - 384, + 128, 1 ], "stride": [ - 6144, - 3072, - 384, + 2048, + 1024, + 128, 1, 1 ], @@ -297,20 +303,20 @@ ], "outputs": [ { - "id": 27, + "id": 7, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, - 384, - 384 + 128, + 128 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 262144, + 131072, + 16384, + 128, 1 ], "layout_type": "strided", @@ -319,26 +325,26 @@ ] }, { - "id": 30, + 
"id": 10, "name": "exp", "kind": "Exp", "attrs": {}, "inputs": [ { - "id": 27, + "id": 7, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, - 384, - 384 + 128, + 128 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 262144, + 131072, + 16384, + 128, 1 ], "layout_type": "strided", @@ -347,20 +353,70 @@ ], "outputs": [ { - "id": 29, + "id": 9, "dtype": "f32", "shape": [ - 1, + 2, + 2, + 8, + 128, + 128 + ], + "stride": [ + 262144, + 131072, + 16384, + 128, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 11, + "name": "typecast", + "kind": "TypeCast", + "attrs": {}, + "inputs": [ + { + "id": 9, + "dtype": "f32", + "shape": [ + 2, 2, 8, - 384, - 384 + 128, + 128 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 262144, + 131072, + 16384, + 128, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 12, + "dtype": "bf16", + "shape": [ + 2, + 2, + 8, + 128, + 128 + ], + "stride": [ + 262144, + 131072, + 16384, + 128, 1 ], "layout_type": "strided", @@ -369,7 +425,7 @@ ] }, { - "id": 35, + "id": 14, "name": "bmm_dv", "kind": "MatMul", "attrs": { @@ -377,6 +433,10 @@ "type": "bool", "value": 0 }, + "accumulation_mode": { + "type": "string", + "value": "strict" + }, "transpose_a": { "type": "bool", "value": 1 @@ -384,39 +444,39 @@ }, "inputs": [ { - "id": 32, + "id": 12, "dtype": "bf16", "shape": [ - 1, + 2, 2, 8, - 384, - 384 + 128, + 128 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 262144, + 131072, + 16384, + 128, 1 ], "layout_type": "strided", "property_type": "undef" }, { - "id": 33, + "id": 105, "dtype": "bf16", "shape": [ - 1, + 2, 2, 8, - 384, + 128, 64 ], "stride": [ - 393216, - 196608, - 24576, + 131072, + 65536, + 8192, 64, 1 ], @@ -426,19 +486,19 @@ ], "outputs": [ { - "id": 34, + "id": 13, "dtype": "bf16", "shape": [ - 1, + 2, 2, 8, - 384, + 128, 64 ], "stride": [ - 393216, - 196608, - 24576, + 131072, + 65536, + 8192, 64, 1 ], @@ -448,7 +508,68 @@ ] }, { - 
"id": 40, + "id": 16, + "name": "reduce_dv", + "kind": "ReduceSum", + "attrs": { + "keep_dims": { + "type": "bool", + "value": 1 + }, + "axes": { + "type": "s64[]", + "value": [ + 2 + ] + } + }, + "inputs": [ + { + "id": 13, + "dtype": "bf16", + "shape": [ + 2, + 2, + 8, + 128, + 64 + ], + "stride": [ + 131072, + 65536, + 8192, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 15, + "dtype": "bf16", + "shape": [ + 2, + 2, + 1, + 128, + 64 + ], + "stride": [ + 16384, + 8192, + 8192, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 18, "name": "bmm_dprobs", "kind": "MatMul", "attrs": { @@ -456,6 +577,10 @@ "type": "bool", "value": 0 }, + "accumulation_mode": { + "type": "string", + "value": "strict" + }, "transpose_b": { "type": "bool", "value": 1 @@ -463,19 +588,19 @@ }, "inputs": [ { - "id": 33, + "id": 105, "dtype": "bf16", "shape": [ - 1, + 2, 2, 8, - 384, + 128, 64 ], "stride": [ - 393216, - 196608, - 24576, + 131072, + 65536, + 8192, 64, 1 ], @@ -483,19 +608,19 @@ "property_type": "undef" }, { - "id": 38, + "id": 104, "dtype": "bf16", "shape": [ - 1, + 2, 2, 1, - 384, + 128, 64 ], "stride": [ - 49152, - 24576, - 24576, + 16384, + 8192, + 8192, 64, 1 ], @@ -505,20 +630,20 @@ ], "outputs": [ { - "id": 39, + "id": 17, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, - 384, - 384 + 128, + 128 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 262144, + 131072, + 16384, + 128, 1 ], "layout_type": "strided", @@ -527,51 +652,112 @@ ] }, { - "id": 42, - "name": "softmax_bwd", - "kind": "SoftMaxBackward", + "id": 20, + "name": "mul_o_do", + "kind": "Multiply", "attrs": { - "axis": { - "type": "s64", - "value": -1 + "auto_broadcast": { + "type": "string", + "value": "numpy" } }, "inputs": [ { - "id": 39, + "id": 10, + "dtype": "bf16", + "shape": [ + 2, + 2, + 8, + 128, + 64 + ], + "stride": [ + 131072, + 65536, + 8192, + 64, + 1 + ], + "layout_type": "strided", + 
"property_type": "undef" + }, + { + "id": 105, + "dtype": "bf16", + "shape": [ + 2, + 2, + 8, + 128, + 64 + ], + "stride": [ + 131072, + 65536, + 8192, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 19, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, - 384, - 384 + 128, + 64 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 131072, + 65536, + 8192, + 64, 1 ], "layout_type": "strided", "property_type": "undef" + } + ] + }, + { + "id": 22, + "name": "reducesum_correction", + "kind": "ReduceSum", + "attrs": { + "keep_dims": { + "type": "bool", + "value": 1 }, + "axes": { + "type": "s64[]", + "value": [ + 4 + ] + } + }, + "inputs": [ { - "id": 29, + "id": 19, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, - 384, - 384 + 128, + 64 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 131072, + 65536, + 8192, + 64, 1 ], "layout_type": "strided", @@ -580,20 +766,20 @@ ], "outputs": [ { - "id": 41, + "id": 21, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, - 384, - 384 + 128, + 1 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 2048, + 1024, + 128, + 1, 1 ], "layout_type": "strided", @@ -602,9 +788,9 @@ ] }, { - "id": 44, - "name": "scale_div", - "kind": "Divide", + "id": 24, + "name": "sub_dp_corrected", + "kind": "Subtract", "attrs": { "auto_broadcast": { "type": "string", @@ -613,32 +799,40 @@ }, "inputs": [ { - "id": 41, + "id": 17, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, - 384, - 384 + 128, + 128 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 262144, + 131072, + 16384, + 128, 1 ], "layout_type": "strided", "property_type": "undef" }, { - "id": 20, + "id": 21, "dtype": "f32", "shape": [ + 2, + 2, + 8, + 128, 1 ], "stride": [ + 2048, + 1024, + 128, + 1, 1 ], "layout_type": "strided", @@ -647,20 +841,20 @@ ], "outputs": [ { - "id": 43, + "id": 23, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, - 384, - 384 + 128, + 128 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 262144, + 131072, + 16384, + 
128, 1 ], "layout_type": "strided", @@ -669,55 +863,51 @@ ] }, { - "id": 48, - "name": "bmm_dq", - "kind": "MatMul", + "id": 26, + "name": "mul_softmax_bwd", + "kind": "Multiply", "attrs": { - "transpose_a": { - "type": "bool", - "value": 0 - }, - "transpose_b": { - "type": "bool", - "value": 0 + "auto_broadcast": { + "type": "string", + "value": "numpy" } }, "inputs": [ { - "id": 46, - "dtype": "bf16", + "id": 9, + "dtype": "f32", "shape": [ - 1, + 2, 2, 8, - 384, - 384 + 128, + 128 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 262144, + 131072, + 16384, + 128, 1 ], "layout_type": "strided", "property_type": "undef" }, { - "id": 17, - "dtype": "bf16", + "id": 23, + "dtype": "f32", "shape": [ - 1, 2, - 1, - 384, - 64 + 2, + 8, + 128, + 128 ], "stride": [ - 49152, - 24576, - 24576, - 64, + 262144, + 131072, + 16384, + 128, 1 ], "layout_type": "strided", @@ -726,20 +916,20 @@ ], "outputs": [ { - "id": 47, - "dtype": "bf16", + "id": 25, + "dtype": "f32", "shape": [ - 1, + 2, 2, 8, - 384, - 64 + 128, + 128 ], "stride": [ - 393216, - 196608, - 24576, - 64, + 262144, + 131072, + 16384, + 128, 1 ], "layout_type": "strided", @@ -748,55 +938,43 @@ ] }, { - "id": 50, - "name": "bmm_dk", - "kind": "MatMul", + "id": 28, + "name": "scale_mul", + "kind": "Multiply", "attrs": { - "transpose_b": { - "type": "bool", - "value": 0 - }, - "transpose_a": { - "type": "bool", - "value": 1 + "auto_broadcast": { + "type": "string", + "value": "numpy" } }, "inputs": [ { - "id": 46, - "dtype": "bf16", + "id": 25, + "dtype": "f32", "shape": [ - 1, + 2, 2, 8, - 384, - 384 + 128, + 128 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 262144, + 131072, + 16384, + 128, 1 ], "layout_type": "strided", "property_type": "undef" }, { - "id": 16, + "id": 102, "dtype": "bf16", "shape": [ - 1, - 2, - 8, - 384, - 64 + 1 ], "stride": [ - 393216, - 196608, - 24576, - 64, 1 ], "layout_type": "strided", @@ -805,20 +983,20 @@ ], "outputs": [ { - "id": 49, - "dtype": "bf16", + "id": 27, + 
"dtype": "f32", "shape": [ - 1, + 2, 2, 8, - 384, - 64 + 128, + 128 ], "stride": [ - 393216, - 196608, - 24576, - 64, + 262144, + 131072, + 16384, + 128, 1 ], "layout_type": "strided", @@ -827,37 +1005,26 @@ ] }, { - "id": 37, - "name": "reduce_dv", - "kind": "ReduceSum", - "attrs": { - "keep_dims": { - "type": "bool", - "value": 1 - }, - "axes": { - "type": "s64[]", - "value": [ - 2 - ] - } - }, + "id": 29, + "name": "typecast", + "kind": "TypeCast", + "attrs": {}, "inputs": [ { - "id": 34, - "dtype": "bf16", + "id": 27, + "dtype": "f32", "shape": [ - 1, + 2, 2, 8, - 384, - 64 + 128, + 128 ], "stride": [ - 393216, - 196608, - 24576, - 64, + 262144, + 131072, + 16384, + 128, 1 ], "layout_type": "strided", @@ -866,20 +1033,20 @@ ], "outputs": [ { - "id": 36, + "id": 30, "dtype": "bf16", "shape": [ - 1, 2, - 1, - 384, - 64 + 2, + 8, + 128, + 128 ], "stride": [ - 49152, - 24576, - 24576, - 64, + 262144, + 131072, + 16384, + 128, 1 ], "layout_type": "strided", @@ -888,36 +1055,58 @@ ] }, { - "id": 52, - "name": "reduce_dk", - "kind": "ReduceSum", + "id": 34, + "name": "bmm_dk", + "kind": "MatMul", "attrs": { - "keep_dims": { + "transpose_b": { "type": "bool", - "value": 1 + "value": 0 }, - "axes": { - "type": "s64[]", - "value": [ - 2 - ] + "accumulation_mode": { + "type": "string", + "value": "strict" + }, + "transpose_a": { + "type": "bool", + "value": 1 } }, "inputs": [ { - "id": 49, + "id": 30, "dtype": "bf16", "shape": [ - 1, + 2, 2, 8, - 384, + 128, + 128 + ], + "stride": [ + 262144, + 131072, + 16384, + 128, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 100, + "dtype": "bf16", + "shape": [ + 2, + 2, + 8, + 128, 64 ], "stride": [ - 393216, - 196608, - 24576, + 131072, + 65536, + 8192, 64, 1 ], @@ -927,19 +1116,19 @@ ], "outputs": [ { - "id": 51, + "id": 33, "dtype": "bf16", "shape": [ - 1, 2, - 1, - 384, + 2, + 8, + 128, 64 ], "stride": [ - 49152, - 24576, - 24576, + 131072, + 65536, + 8192, 64, 1 ], @@ -949,26 +1138,59 @@ ] }, 
{ - "id": 31, - "name": "typecast", - "kind": "TypeCast", - "attrs": {}, + "id": 32, + "name": "bmm_dq", + "kind": "MatMul", + "attrs": { + "transpose_a": { + "type": "bool", + "value": 0 + }, + "transpose_b": { + "type": "bool", + "value": 0 + }, + "accumulation_mode": { + "type": "string", + "value": "strict" + } + }, "inputs": [ { - "id": 29, - "dtype": "f32", + "id": 30, + "dtype": "bf16", "shape": [ - 1, + 2, 2, 8, - 384, - 384 + 128, + 128 + ], + "stride": [ + 262144, + 131072, + 16384, + 128, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 101, + "dtype": "bf16", + "shape": [ + 2, + 2, + 1, + 128, + 64 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 16384, + 8192, + 8192, + 64, 1 ], "layout_type": "strided", @@ -977,20 +1199,20 @@ ], "outputs": [ { - "id": 32, + "id": 31, "dtype": "bf16", "shape": [ - 1, + 2, 2, 8, - 384, - 384 + 128, + 64 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 131072, + 65536, + 8192, + 64, 1 ], "layout_type": "strided", @@ -999,26 +1221,37 @@ ] }, { - "id": 45, - "name": "typecast", - "kind": "TypeCast", - "attrs": {}, + "id": 36, + "name": "reduce_dk", + "kind": "ReduceSum", + "attrs": { + "keep_dims": { + "type": "bool", + "value": 1 + }, + "axes": { + "type": "s64[]", + "value": [ + 2 + ] + } + }, "inputs": [ { - "id": 43, - "dtype": "f32", + "id": 33, + "dtype": "bf16", "shape": [ - 1, + 2, 2, 8, - 384, - 384 + 128, + 64 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 131072, + 65536, + 8192, + 64, 1 ], "layout_type": "strided", @@ -1027,20 +1260,20 @@ ], "outputs": [ { - "id": 46, + "id": 35, "dtype": "bf16", "shape": [ - 1, 2, - 8, - 384, - 384 + 2, + 1, + 128, + 64 ], "stride": [ - 2359296, - 1179648, - 147456, - 384, + 16384, + 8192, + 8192, + 64, 1 ], "layout_type": "strided", diff --git a/tests/benchdnn/inputs/graph/complex_fusion/mha/gqa-plain-training-backward-f32.json b/tests/benchdnn/inputs/graph/complex_fusion/mha/gqa-plain-training-backward-f32.json index 
0b7c360006e..38ebdf6b358 100644 --- a/tests/benchdnn/inputs/graph/complex_fusion/mha/gqa-plain-training-backward-f32.json +++ b/tests/benchdnn/inputs/graph/complex_fusion/mha/gqa-plain-training-backward-f32.json @@ -1,40 +1,42 @@ { - "version": "3.10.0", - "engine_kind": "cpu", + "version": "3.12.0", + "engine_kind": "gpu", "fpmath_mode": "strict", "fpmath_mode_apply_to_int": "false", "input_ports": [ - 16, - 17, - 20, - 23, - 26, - 32, - 32, - 37, - 20, - 17, - 16 + 100, + 101, + 102, + 103, + 8, + 105, + 105, + 104, + 10, + 105, + 102, + 100, + 101 ], "output_ports": [ - 45, - 35, - 49 + 14, + 29, + 33 ], "graph": [ { - "id": 19, + "id": 2, "name": "bmm1", "kind": "MatMul", "attrs": { - "accumulation_mode": { - "type": "string", - "value": "strict" - }, "transpose_a": { "type": "bool", "value": 0 }, + "accumulation_mode": { + "type": "string", + "value": "strict" + }, "transpose_b": { "type": "bool", "value": 1 @@ -42,10 +44,10 @@ }, "inputs": [ { - "id": 16, + "id": 100, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -62,10 +64,10 @@ "property_type": "undef" }, { - "id": 17, + "id": 101, "dtype": "f32", "shape": [ - 1, + 2, 2, 1, 128, @@ -84,10 +86,10 @@ ], "outputs": [ { - "id": 18, + "id": 1, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -106,7 +108,7 @@ ] }, { - "id": 22, + "id": 4, "name": "scale_mul", "kind": "Multiply", "attrs": { @@ -117,10 +119,10 @@ }, "inputs": [ { - "id": 18, + "id": 1, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -137,7 +139,7 @@ "property_type": "undef" }, { - "id": 20, + "id": 102, "dtype": "f32", "shape": [ 1 @@ -151,10 +153,10 @@ ], "outputs": [ { - "id": 21, + "id": 3, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -173,7 +175,7 @@ ] }, { - "id": 25, + "id": 6, "name": "mask_add", "kind": "Add", "attrs": { @@ -184,10 +186,10 @@ }, "inputs": [ { - "id": 21, + "id": 3, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -204,10 +206,10 @@ "property_type": "undef" }, { - "id": 23, + "id": 103, "dtype": "f32", 
"shape": [ - 1, + 2, 2, 8, 128, @@ -226,10 +228,10 @@ ], "outputs": [ { - "id": 24, + "id": 5, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -248,7 +250,7 @@ ] }, { - "id": 28, + "id": 8, "name": "subtract", "kind": "Subtract", "attrs": { @@ -259,10 +261,10 @@ }, "inputs": [ { - "id": 24, + "id": 5, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -279,10 +281,10 @@ "property_type": "undef" }, { - "id": 26, + "id": 8, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -301,10 +303,10 @@ ], "outputs": [ { - "id": 27, + "id": 7, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -323,16 +325,16 @@ ] }, { - "id": 30, + "id": 10, "name": "exp", "kind": "Exp", "attrs": {}, "inputs": [ { - "id": 27, + "id": 7, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -351,10 +353,10 @@ ], "outputs": [ { - "id": 29, + "id": 9, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -373,18 +375,18 @@ ] }, { - "id": 34, + "id": 13, "name": "bmm_dv", "kind": "MatMul", "attrs": { - "accumulation_mode": { - "type": "string", - "value": "strict" - }, "transpose_b": { "type": "bool", "value": 0 }, + "accumulation_mode": { + "type": "string", + "value": "strict" + }, "transpose_a": { "type": "bool", "value": 1 @@ -392,10 +394,10 @@ }, "inputs": [ { - "id": 29, + "id": 9, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -412,10 +414,10 @@ "property_type": "undef" }, { - "id": 32, + "id": 105, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -434,10 +436,49 @@ ], "outputs": [ { - "id": 33, + "id": 12, + "dtype": "f32", + "shape": [ + 2, + 2, + 8, + 128, + 64 + ], + "stride": [ + 131072, + 65536, + 8192, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 15, + "name": "reduce_dv", + "kind": "ReduceSum", + "attrs": { + "keep_dims": { + "type": "bool", + "value": 1 + }, + "axes": { + "type": "s64[]", + "value": [ + 2 + ] + } + }, + "inputs": [ + { + "id": 12, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -453,21 +494,43 @@ 
"layout_type": "strided", "property_type": "undef" } + ], + "outputs": [ + { + "id": 14, + "dtype": "f32", + "shape": [ + 2, + 2, + 1, + 128, + 64 + ], + "stride": [ + 16384, + 8192, + 8192, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } ] }, { - "id": 39, + "id": 17, "name": "bmm_dprobs", "kind": "MatMul", "attrs": { - "accumulation_mode": { - "type": "string", - "value": "strict" - }, "transpose_a": { "type": "bool", "value": 0 }, + "accumulation_mode": { + "type": "string", + "value": "strict" + }, "transpose_b": { "type": "bool", "value": 1 @@ -475,10 +538,10 @@ }, "inputs": [ { - "id": 32, + "id": 105, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -495,10 +558,10 @@ "property_type": "undef" }, { - "id": 37, + "id": 104, "dtype": "f32", "shape": [ - 1, + 2, 2, 1, 128, @@ -517,10 +580,10 @@ ], "outputs": [ { - "id": 38, + "id": 16, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -539,41 +602,41 @@ ] }, { - "id": 41, - "name": "softmax_bwd", - "kind": "SoftMaxBackward", + "id": 19, + "name": "mul_o_do", + "kind": "Multiply", "attrs": { - "axis": { - "type": "s64", - "value": -1 + "auto_broadcast": { + "type": "string", + "value": "numpy" } }, "inputs": [ { - "id": 38, + "id": 10, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, - 128 + 64 ], "stride": [ - 262144, 131072, - 16384, - 128, + 65536, + 8192, + 64, 1 ], "layout_type": "strided", "property_type": "undef" }, { - "id": 29, + "id": 105, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -592,10 +655,10 @@ ], "outputs": [ { - "id": 40, + "id": 18, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -614,9 +677,70 @@ ] }, { - "id": 43, - "name": "scale_mul", - "kind": "Multiply", + "id": 21, + "name": "reducesum_correction", + "kind": "ReduceSum", + "attrs": { + "keep_dims": { + "type": "bool", + "value": 1 + }, + "axes": { + "type": "s64[]", + "value": [ + 4 + ] + } + }, + "inputs": [ + { + "id": 18, + "dtype": "f32", + "shape": [ + 2, + 2, + 8, + 128, + 64 + ], + "stride": 
[ + 131072, + 65536, + 8192, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 20, + "dtype": "f32", + "shape": [ + 2, + 2, + 8, + 128, + 1 + ], + "stride": [ + 2048, + 1024, + 128, + 1, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 23, + "name": "sub_dp_corrected", + "kind": "Subtract", "attrs": { "auto_broadcast": { "type": "string", @@ -625,10 +749,10 @@ }, "inputs": [ { - "id": 40, + "id": 16, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -648,9 +772,17 @@ "id": 20, "dtype": "f32", "shape": [ + 2, + 2, + 8, + 128, 1 ], "stride": [ + 2048, + 1024, + 128, + 1, 1 ], "layout_type": "strided", @@ -659,10 +791,10 @@ ], "outputs": [ { - "id": 42, + "id": 22, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -681,29 +813,21 @@ ] }, { - "id": 46, - "name": "bmm_dq", - "kind": "MatMul", + "id": 25, + "name": "mul_softmax_bwd", + "kind": "Multiply", "attrs": { - "accumulation_mode": { + "auto_broadcast": { "type": "string", - "value": "strict" - }, - "transpose_a": { - "type": "bool", - "value": 0 - }, - "transpose_b": { - "type": "bool", - "value": 0 + "value": "numpy" } }, "inputs": [ { - "id": 42, + "id": 9, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -720,20 +844,20 @@ "property_type": "undef" }, { - "id": 17, + "id": 22, "dtype": "f32", "shape": [ - 1, 2, - 1, + 2, + 8, 128, - 64 + 128 ], "stride": [ + 262144, + 131072, 16384, - 8192, - 8192, - 64, + 128, 1 ], "layout_type": "strided", @@ -742,20 +866,20 @@ ], "outputs": [ { - "id": 45, + "id": 24, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, - 64 + 128 ], "stride": [ + 262144, 131072, - 65536, - 8192, - 64, + 16384, + 128, 1 ], "layout_type": "strided", @@ -764,18 +888,85 @@ ] }, { - "id": 48, - "name": "bmm_dk", - "kind": "MatMul", + "id": 27, + "name": "scale_mul", + "kind": "Multiply", "attrs": { - "accumulation_mode": { + "auto_broadcast": { "type": "string", - "value": "strict" + "value": "numpy" + 
} + }, + "inputs": [ + { + "id": 24, + "dtype": "f32", + "shape": [ + 2, + 2, + 8, + 128, + 128 + ], + "stride": [ + 262144, + 131072, + 16384, + 128, + 1 + ], + "layout_type": "strided", + "property_type": "undef" }, + { + "id": 102, + "dtype": "f32", + "shape": [ + 1 + ], + "stride": [ + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 26, + "dtype": "f32", + "shape": [ + 2, + 2, + 8, + 128, + 128 + ], + "stride": [ + 262144, + 131072, + 16384, + 128, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 32, + "name": "bmm_dk", + "kind": "MatMul", + "attrs": { "transpose_b": { "type": "bool", "value": 0 }, + "accumulation_mode": { + "type": "string", + "value": "strict" + }, "transpose_a": { "type": "bool", "value": 1 @@ -783,10 +974,10 @@ }, "inputs": [ { - "id": 42, + "id": 26, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -803,10 +994,10 @@ "property_type": "undef" }, { - "id": 16, + "id": 100, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -825,10 +1016,10 @@ ], "outputs": [ { - "id": 47, + "id": 31, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -847,35 +1038,57 @@ ] }, { - "id": 36, - "name": "reduce_dv", - "kind": "ReduceSum", + "id": 30, + "name": "bmm_dq", + "kind": "MatMul", "attrs": { - "keep_dims": { + "transpose_a": { "type": "bool", - "value": 1 + "value": 0 }, - "axes": { - "type": "s64[]", - "value": [ - 2 - ] + "transpose_b": { + "type": "bool", + "value": 0 + }, + "accumulation_mode": { + "type": "string", + "value": "strict" } }, "inputs": [ { - "id": 33, + "id": 26, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, - 64 + 128 ], "stride": [ + 262144, 131072, - 65536, + 16384, + 128, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 101, + "dtype": "f32", + "shape": [ + 2, + 2, + 1, + 128, + 64 + ], + "stride": [ + 16384, + 8192, 8192, 64, 1 @@ -886,18 +1099,18 @@ ], "outputs": [ { - "id": 35, + "id": 29, "dtype": "f32", 
"shape": [ - 1, 2, - 1, + 2, + 8, 128, 64 ], "stride": [ - 16384, - 8192, + 131072, + 65536, 8192, 64, 1 @@ -908,7 +1121,7 @@ ] }, { - "id": 50, + "id": 34, "name": "reduce_dk", "kind": "ReduceSum", "attrs": { @@ -925,10 +1138,10 @@ }, "inputs": [ { - "id": 47, + "id": 31, "dtype": "f32", "shape": [ - 1, + 2, 2, 8, 128, @@ -947,10 +1160,10 @@ ], "outputs": [ { - "id": 49, + "id": 33, "dtype": "f32", "shape": [ - 1, + 2, 2, 1, 128, diff --git a/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-training-backward-bf16-f32.json b/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-training-backward-bf16-f32.json index 5a34789873f..f2c5862946e 100644 --- a/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-training-backward-bf16-f32.json +++ b/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-training-backward-bf16-f32.json @@ -1,30 +1,32 @@ { - "version": "3.9.0", + "version": "3.12.0", "engine_kind": "gpu", "fpmath_mode": "strict", "fpmath_mode_apply_to_int": "false", "input_ports": [ - 16, - 17, - 20, - 23, - 26, - 33, - 33, - 36, - 20, - 17, - 16 + 100, + 101, + 102, + 103, + 8, + 105, + 105, + 104, + 10, + 105, + 102, + 100, + 101 ], "output_ports": [ - 34, - 45, - 47 + 13, + 29, + 31 ], "graph": [ { - "id": 19, - "name": "bmm_qk", + "id": 2, + "name": "bmm1", "kind": "MatMul", "attrs": { "transpose_a": { @@ -34,11 +36,15 @@ "transpose_b": { "type": "bool", "value": 1 + }, + "accumulation_mode": { + "type": "string", + "value": "strict" } }, "inputs": [ { - "id": 16, + "id": 100, "dtype": "bf16", "shape": [ 1, @@ -56,7 +62,7 @@ "property_type": "undef" }, { - "id": 17, + "id": 101, "dtype": "bf16", "shape": [ 1, @@ -76,7 +82,7 @@ ], "outputs": [ { - "id": 18, + "id": 1, "dtype": "f32", "shape": [ 1, @@ -96,9 +102,9 @@ ] }, { - "id": 22, - "name": "scale_div", - "kind": "Divide", + "id": 4, + "name": "scale_mul", + "kind": "Multiply", "attrs": { "auto_broadcast": { "type": "string", @@ -107,7 +113,7 @@ }, "inputs": [ { - "id": 18, + 
"id": 1, "dtype": "f32", "shape": [ 1, @@ -125,8 +131,8 @@ "property_type": "undef" }, { - "id": 20, - "dtype": "f32", + "id": 102, + "dtype": "bf16", "shape": [ 1 ], @@ -139,7 +145,7 @@ ], "outputs": [ { - "id": 21, + "id": 3, "dtype": "f32", "shape": [ 1, @@ -159,7 +165,7 @@ ] }, { - "id": 25, + "id": 6, "name": "mask_add", "kind": "Add", "attrs": { @@ -170,7 +176,7 @@ }, "inputs": [ { - "id": 21, + "id": 3, "dtype": "f32", "shape": [ 1, @@ -188,8 +194,8 @@ "property_type": "undef" }, { - "id": 23, - "dtype": "f32", + "id": 103, + "dtype": "bf16", "shape": [ 1, 16, @@ -208,7 +214,7 @@ ], "outputs": [ { - "id": 24, + "id": 5, "dtype": "f32", "shape": [ 1, @@ -228,8 +234,8 @@ ] }, { - "id": 28, - "name": "softmax_sub", + "id": 8, + "name": "subtract", "kind": "Subtract", "attrs": { "auto_broadcast": { @@ -239,7 +245,7 @@ }, "inputs": [ { - "id": 24, + "id": 5, "dtype": "f32", "shape": [ 1, @@ -257,7 +263,7 @@ "property_type": "undef" }, { - "id": 26, + "id": 8, "dtype": "f32", "shape": [ 1, @@ -277,7 +283,7 @@ ], "outputs": [ { - "id": 27, + "id": 7, "dtype": "f32", "shape": [ 1, @@ -297,13 +303,13 @@ ] }, { - "id": 30, - "name": "softmax_exp", + "id": 10, + "name": "exp", "kind": "Exp", "attrs": {}, "inputs": [ { - "id": 27, + "id": 7, "dtype": "f32", "shape": [ 1, @@ -323,7 +329,7 @@ ], "outputs": [ { - "id": 29, + "id": 9, "dtype": "f32", "shape": [ 1, @@ -343,7 +349,53 @@ ] }, { - "id": 35, + "id": 11, + "name": "typecast", + "kind": "TypeCast", + "attrs": {}, + "inputs": [ + { + "id": 9, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 12, + "dtype": "bf16", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 14, "name": "bmm_dv", "kind": "MatMul", "attrs": { @@ -351,6 +403,10 @@ "type": 
"bool", "value": 0 }, + "accumulation_mode": { + "type": "string", + "value": "strict" + }, "transpose_a": { "type": "bool", "value": 1 @@ -358,7 +414,7 @@ }, "inputs": [ { - "id": 32, + "id": 12, "dtype": "bf16", "shape": [ 1, @@ -376,7 +432,7 @@ "property_type": "undef" }, { - "id": 33, + "id": 105, "dtype": "bf16", "shape": [ 1, @@ -396,7 +452,7 @@ ], "outputs": [ { - "id": 34, + "id": 13, "dtype": "bf16", "shape": [ 1, @@ -416,14 +472,18 @@ ] }, { - "id": 38, - "name": "bmm_do_v", + "id": 16, + "name": "bmm_dprobs", "kind": "MatMul", "attrs": { "transpose_a": { "type": "bool", "value": 0 }, + "accumulation_mode": { + "type": "string", + "value": "strict" + }, "transpose_b": { "type": "bool", "value": 1 @@ -431,7 +491,7 @@ }, "inputs": [ { - "id": 33, + "id": 105, "dtype": "bf16", "shape": [ 1, @@ -449,7 +509,7 @@ "property_type": "undef" }, { - "id": 36, + "id": 104, "dtype": "bf16", "shape": [ 1, @@ -469,7 +529,7 @@ ], "outputs": [ { - "id": 37, + "id": 15, "dtype": "f32", "shape": [ 1, @@ -489,47 +549,104 @@ ] }, { - "id": 40, - "name": "softmax_bwd", - "kind": "SoftMaxBackward", + "id": 18, + "name": "mul_o_do", + "kind": "Multiply", "attrs": { - "axis": { - "type": "s64", - "value": -1 + "auto_broadcast": { + "type": "string", + "value": "numpy" } }, "inputs": [ { - "id": 37, - "dtype": "f32", + "id": 10, + "dtype": "bf16", "shape": [ 1, 16, 384, - 384 + 64 ], "stride": [ - 2359296, - 147456, - 384, + 393216, + 24576, + 64, 1 ], "layout_type": "strided", "property_type": "undef" }, { - "id": 29, + "id": 105, + "dtype": "bf16", + "shape": [ + 1, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 17, "dtype": "f32", "shape": [ 1, 16, 384, - 384 + 64 ], "stride": [ - 2359296, - 147456, + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 20, + "name": "reducesum_correction", + "kind": 
"ReduceSum", + "attrs": { + "keep_dims": { + "type": "bool", + "value": 1 + }, + "axes": { + "type": "s64[]", + "value": [ + 3 + ] + } + }, + "inputs": [ + { + "id": 17, + "dtype": "f32", + "shape": [ + 1, + 16, 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, 1 ], "layout_type": "strided", @@ -538,18 +655,18 @@ ], "outputs": [ { - "id": 39, + "id": 19, "dtype": "f32", "shape": [ 1, 16, 384, - 384 + 1 ], "stride": [ - 2359296, - 147456, + 6144, 384, + 1, 1 ], "layout_type": "strided", @@ -558,9 +675,9 @@ ] }, { - "id": 42, - "name": "scale_div", - "kind": "Divide", + "id": 22, + "name": "sub_dp_corrected", + "kind": "Subtract", "attrs": { "auto_broadcast": { "type": "string", @@ -569,7 +686,7 @@ }, "inputs": [ { - "id": 39, + "id": 15, "dtype": "f32", "shape": [ 1, @@ -587,12 +704,18 @@ "property_type": "undef" }, { - "id": 20, + "id": 19, "dtype": "f32", "shape": [ + 1, + 16, + 384, 1 ], "stride": [ + 6144, + 384, + 1, 1 ], "layout_type": "strided", @@ -601,7 +724,7 @@ ], "outputs": [ { - "id": 41, + "id": 21, "dtype": "f32", "shape": [ 1, @@ -621,23 +744,19 @@ ] }, { - "id": 46, - "name": "bmm_dq", - "kind": "MatMul", + "id": 24, + "name": "mul_softmax_bwd", + "kind": "Multiply", "attrs": { - "transpose_a": { - "type": "bool", - "value": 0 - }, - "transpose_b": { - "type": "bool", - "value": 0 + "auto_broadcast": { + "type": "string", + "value": "numpy" } }, "inputs": [ { - "id": 44, - "dtype": "bf16", + "id": 9, + "dtype": "f32", "shape": [ 1, 16, @@ -654,18 +773,18 @@ "property_type": "undef" }, { - "id": 17, - "dtype": "bf16", + "id": 21, + "dtype": "f32", "shape": [ 1, 16, 384, - 64 + 384 ], "stride": [ - 393216, - 24576, - 64, + 2359296, + 147456, + 384, 1 ], "layout_type": "strided", @@ -674,18 +793,18 @@ ], "outputs": [ { - "id": 45, - "dtype": "bf16", + "id": 23, + "dtype": "f32", "shape": [ 1, 16, 384, - 64 + 384 ], "stride": [ - 393216, - 24576, - 64, + 2359296, + 147456, + 384, 1 ], "layout_type": "strided", @@ -694,23 +813,19 @@ ] }, { - "id": 
48, - "name": "bmm_dk", - "kind": "MatMul", + "id": 26, + "name": "scale_mul", + "kind": "Multiply", "attrs": { - "transpose_b": { - "type": "bool", - "value": 0 - }, - "transpose_a": { - "type": "bool", - "value": 1 + "auto_broadcast": { + "type": "string", + "value": "numpy" } }, "inputs": [ { - "id": 44, - "dtype": "bf16", + "id": 23, + "dtype": "f32", "shape": [ 1, 16, @@ -727,18 +842,12 @@ "property_type": "undef" }, { - "id": 16, + "id": 102, "dtype": "bf16", "shape": [ - 1, - 16, - 384, - 64 + 1 ], "stride": [ - 393216, - 24576, - 64, 1 ], "layout_type": "strided", @@ -747,18 +856,18 @@ ], "outputs": [ { - "id": 47, - "dtype": "bf16", + "id": 25, + "dtype": "f32", "shape": [ 1, 16, 384, - 64 + 384 ], "stride": [ - 393216, - 24576, - 64, + 2359296, + 147456, + 384, 1 ], "layout_type": "strided", @@ -767,13 +876,13 @@ ] }, { - "id": 31, + "id": 27, "name": "typecast", "kind": "TypeCast", "attrs": {}, "inputs": [ { - "id": 29, + "id": 25, "dtype": "f32", "shape": [ 1, @@ -793,7 +902,7 @@ ], "outputs": [ { - "id": 32, + "id": 28, "dtype": "bf16", "shape": [ 1, @@ -813,14 +922,27 @@ ] }, { - "id": 43, - "name": "typecast", - "kind": "TypeCast", - "attrs": {}, + "id": 32, + "name": "bmm_dk", + "kind": "MatMul", + "attrs": { + "transpose_b": { + "type": "bool", + "value": 0 + }, + "accumulation_mode": { + "type": "string", + "value": "strict" + }, + "transpose_a": { + "type": "bool", + "value": 1 + } + }, "inputs": [ { - "id": 41, - "dtype": "f32", + "id": 28, + "dtype": "bf16", "shape": [ 1, 16, @@ -835,11 +957,68 @@ ], "layout_type": "strided", "property_type": "undef" + }, + { + "id": 100, + "dtype": "bf16", + "shape": [ + 1, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" } ], "outputs": [ { - "id": 44, + "id": 31, + "dtype": "bf16", + "shape": [ + 1, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + 
}, + { + "id": 30, + "name": "bmm_dq", + "kind": "MatMul", + "attrs": { + "transpose_a": { + "type": "bool", + "value": 0 + }, + "accumulation_mode": { + "type": "string", + "value": "strict" + }, + "transpose_b": { + "type": "bool", + "value": 0 + } + }, + "inputs": [ + { + "id": 28, "dtype": "bf16", "shape": [ 1, @@ -855,6 +1034,44 @@ ], "layout_type": "strided", "property_type": "undef" + }, + { + "id": 101, + "dtype": "bf16", + "shape": [ + 1, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 29, + "dtype": "bf16", + "shape": [ + 1, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" } ] } diff --git a/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-training-backward-f32.json b/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-training-backward-f32.json index fcdb018a7c3..68f5905a08c 100644 --- a/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-training-backward-f32.json +++ b/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-training-backward-f32.json @@ -1,29 +1,31 @@ { - "version": "3.9.0", + "version": "3.12.0", "engine_kind": "gpu", "fpmath_mode": "strict", "fpmath_mode_apply_to_int": "false", "input_ports": [ - 16, - 17, - 20, - 23, - 26, - 32, - 32, - 35, - 20, - 17, - 16 + 100, + 101, + 102, + 103, + 8, + 105, + 105, + 104, + 10, + 105, + 102, + 100, + 101 ], "output_ports": [ - 33, - 43, - 45 + 12, + 27, + 29 ], "graph": [ { - "id": 19, + "id": 2, "name": "bmm1", "kind": "MatMul", "attrs": { @@ -34,11 +36,15 @@ "transpose_b": { "type": "bool", "value": 1 + }, + "accumulation_mode": { + "type": "string", + "value": "strict" } }, "inputs": [ { - "id": 16, + "id": 100, "dtype": "f32", "shape": [ 1, @@ -56,7 +62,7 @@ "property_type": "undef" }, { - "id": 17, + "id": 101, "dtype": "f32", "shape": [ 1, @@ -76,7 +82,7 @@ ], "outputs": [ { - 
"id": 18, + "id": 1, "dtype": "f32", "shape": [ 1, @@ -96,9 +102,9 @@ ] }, { - "id": 22, - "name": "scale_div", - "kind": "Divide", + "id": 4, + "name": "scale_mul", + "kind": "Multiply", "attrs": { "auto_broadcast": { "type": "string", @@ -107,7 +113,7 @@ }, "inputs": [ { - "id": 18, + "id": 1, "dtype": "f32", "shape": [ 1, @@ -125,7 +131,7 @@ "property_type": "undef" }, { - "id": 20, + "id": 102, "dtype": "f32", "shape": [ 1 @@ -139,7 +145,7 @@ ], "outputs": [ { - "id": 21, + "id": 3, "dtype": "f32", "shape": [ 1, @@ -159,7 +165,7 @@ ] }, { - "id": 25, + "id": 6, "name": "mask_add", "kind": "Add", "attrs": { @@ -170,7 +176,7 @@ }, "inputs": [ { - "id": 21, + "id": 3, "dtype": "f32", "shape": [ 1, @@ -188,7 +194,7 @@ "property_type": "undef" }, { - "id": 23, + "id": 103, "dtype": "f32", "shape": [ 1, @@ -208,7 +214,7 @@ ], "outputs": [ { - "id": 24, + "id": 5, "dtype": "f32", "shape": [ 1, @@ -228,8 +234,8 @@ ] }, { - "id": 28, - "name": "softmax_sub", + "id": 8, + "name": "subtract", "kind": "Subtract", "attrs": { "auto_broadcast": { @@ -239,7 +245,7 @@ }, "inputs": [ { - "id": 24, + "id": 5, "dtype": "f32", "shape": [ 1, @@ -257,7 +263,7 @@ "property_type": "undef" }, { - "id": 26, + "id": 8, "dtype": "f32", "shape": [ 1, @@ -277,7 +283,7 @@ ], "outputs": [ { - "id": 27, + "id": 7, "dtype": "f32", "shape": [ 1, @@ -297,13 +303,13 @@ ] }, { - "id": 30, - "name": "softmax_exp", + "id": 10, + "name": "exp", "kind": "Exp", "attrs": {}, "inputs": [ { - "id": 27, + "id": 7, "dtype": "f32", "shape": [ 1, @@ -323,7 +329,7 @@ ], "outputs": [ { - "id": 29, + "id": 9, "dtype": "f32", "shape": [ 1, @@ -343,7 +349,7 @@ ] }, { - "id": 34, + "id": 13, "name": "bmm_dv", "kind": "MatMul", "attrs": { @@ -351,6 +357,10 @@ "type": "bool", "value": 0 }, + "accumulation_mode": { + "type": "string", + "value": "strict" + }, "transpose_a": { "type": "bool", "value": 1 @@ -358,7 +368,7 @@ }, "inputs": [ { - "id": 29, + "id": 9, "dtype": "f32", "shape": [ 1, @@ -376,7 +386,7 @@ 
"property_type": "undef" }, { - "id": 32, + "id": 105, "dtype": "f32", "shape": [ 1, @@ -396,7 +406,7 @@ ], "outputs": [ { - "id": 33, + "id": 12, "dtype": "f32", "shape": [ 1, @@ -416,14 +426,18 @@ ] }, { - "id": 37, - "name": "bmm_do_v", + "id": 15, + "name": "bmm_dprobs", "kind": "MatMul", "attrs": { "transpose_a": { "type": "bool", "value": 0 }, + "accumulation_mode": { + "type": "string", + "value": "strict" + }, "transpose_b": { "type": "bool", "value": 1 @@ -431,7 +445,114 @@ }, "inputs": [ { - "id": 32, + "id": 105, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 104, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 14, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 17, + "name": "mul_o_do", + "kind": "Multiply", + "attrs": { + "auto_broadcast": { + "type": "string", + "value": "numpy" + } + }, + "inputs": [ + { + "id": 10, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 105, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 64 + ], + "stride": [ + 393216, + 24576, + 64, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 16, "dtype": "f32", "shape": [ 1, @@ -447,9 +568,28 @@ ], "layout_type": "strided", "property_type": "undef" + } + ] + }, + { + "id": 19, + "name": "reducesum_correction", + "kind": "ReduceSum", + "attrs": { + "keep_dims": { + "type": "bool", + "value": 1 }, + "axes": { + "type": "s64[]", + "value": [ + 3 + ] + } + }, + "inputs": [ { - "id": 35, + 
"id": 16, "dtype": "f32", "shape": [ 1, @@ -469,7 +609,76 @@ ], "outputs": [ { - "id": 36, + "id": 18, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 1 + ], + "stride": [ + 6144, + 384, + 1, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ] + }, + { + "id": 21, + "name": "sub_dp_corrected", + "kind": "Subtract", + "attrs": { + "auto_broadcast": { + "type": "string", + "value": "numpy" + } + }, + "inputs": [ + { + "id": 14, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 384 + ], + "stride": [ + 2359296, + 147456, + 384, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + }, + { + "id": 18, + "dtype": "f32", + "shape": [ + 1, + 16, + 384, + 1 + ], + "stride": [ + 6144, + 384, + 1, + 1 + ], + "layout_type": "strided", + "property_type": "undef" + } + ], + "outputs": [ + { + "id": 20, "dtype": "f32", "shape": [ 1, @@ -489,18 +698,18 @@ ] }, { - "id": 39, - "name": "softmax_bwd", - "kind": "SoftMaxBackward", + "id": 23, + "name": "mul_softmax_bwd", + "kind": "Multiply", "attrs": { - "axis": { - "type": "s64", - "value": -1 + "auto_broadcast": { + "type": "string", + "value": "numpy" } }, "inputs": [ { - "id": 36, + "id": 9, "dtype": "f32", "shape": [ 1, @@ -518,7 +727,7 @@ "property_type": "undef" }, { - "id": 29, + "id": 20, "dtype": "f32", "shape": [ 1, @@ -538,7 +747,7 @@ ], "outputs": [ { - "id": 38, + "id": 22, "dtype": "f32", "shape": [ 1, @@ -558,9 +767,9 @@ ] }, { - "id": 41, - "name": "scale_div", - "kind": "Divide", + "id": 25, + "name": "scale_mul", + "kind": "Multiply", "attrs": { "auto_broadcast": { "type": "string", @@ -569,7 +778,7 @@ }, "inputs": [ { - "id": 38, + "id": 22, "dtype": "f32", "shape": [ 1, @@ -587,7 +796,7 @@ "property_type": "undef" }, { - "id": 20, + "id": 102, "dtype": "f32", "shape": [ 1 @@ -601,7 +810,7 @@ ], "outputs": [ { - "id": 40, + "id": 24, "dtype": "f32", "shape": [ 1, @@ -621,22 +830,26 @@ ] }, { - "id": 44, - "name": "bmm_dq", + "id": 30, + "name": "bmm_dk", "kind": "MatMul", 
"attrs": { - "transpose_a": { + "transpose_b": { "type": "bool", "value": 0 }, - "transpose_b": { + "accumulation_mode": { + "type": "string", + "value": "strict" + }, + "transpose_a": { "type": "bool", - "value": 0 + "value": 1 } }, "inputs": [ { - "id": 40, + "id": 24, "dtype": "f32", "shape": [ 1, @@ -654,7 +867,7 @@ "property_type": "undef" }, { - "id": 17, + "id": 100, "dtype": "f32", "shape": [ 1, @@ -674,7 +887,7 @@ ], "outputs": [ { - "id": 43, + "id": 29, "dtype": "f32", "shape": [ 1, @@ -694,22 +907,26 @@ ] }, { - "id": 46, - "name": "bmm_dk", + "id": 28, + "name": "bmm_dq", "kind": "MatMul", "attrs": { - "transpose_b": { + "transpose_a": { "type": "bool", "value": 0 }, - "transpose_a": { + "accumulation_mode": { + "type": "string", + "value": "strict" + }, + "transpose_b": { "type": "bool", - "value": 1 + "value": 0 } }, "inputs": [ { - "id": 40, + "id": 24, "dtype": "f32", "shape": [ 1, @@ -727,7 +944,7 @@ "property_type": "undef" }, { - "id": 16, + "id": 101, "dtype": "f32", "shape": [ 1, @@ -747,7 +964,7 @@ ], "outputs": [ { - "id": 45, + "id": 27, "dtype": "f32", "shape": [ 1, From 8948d9e8ae961aade35db1c55e0812f49c444db4 Mon Sep 17 00:00:00 2001 From: "Bao, Yixin" Date: Thu, 12 Mar 2026 03:53:26 +0000 Subject: [PATCH 22/23] tests: benchdnn: graph: remove legacy sdpa bwd cases --- .../graph/complex_fusion/harness_mha_all | 4 - .../graph/complex_fusion/harness_mha_ci | 5 - ...in-training-backward-w-dmask-bf16-f32.json | 1102 --------------- ...plain-training-bwd-w-dropout-bf16-f32.json | 1226 ----------------- ...in-training-backward-w-dmask-bf16-f32.json | 910 ------------ 5 files changed, 3247 deletions(-) delete mode 100644 tests/benchdnn/inputs/graph/complex_fusion/mha/gqa-plain-training-backward-w-dmask-bf16-f32.json delete mode 100644 tests/benchdnn/inputs/graph/complex_fusion/mha/gqa-plain-training-bwd-w-dropout-bf16-f32.json delete mode 100644 tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-training-backward-w-dmask-bf16-f32.json 
diff --git a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all index fed646237b0..957cc308430 100644 --- a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all +++ b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_all @@ -33,8 +33,6 @@ --reset --dt=0:f16+1:f16+7:f16+9:f16+10:f16 --case=complex_fusion/mha/sdpa-plain-wo-mask-f32.json --reset --dt=0:f16+1:f16+7:f16+8:f16+9:f16 --case=complex_fusion/mha/sdpa-plain-wo-scale-f32.json --reset --in-shapes=0:2x1x1024x128*131072x128x128x1+1:2x1x1024x128*131072x128x128x1+2:1*1+5:2x1x1024x128*131072x128x128x1 --case=complex_fusion/mha/sdpa-training-fwd-no-mask-f16-f32.json ---reset --dt=16:f16+17:f16+32:f16+33:f16+34:f16+36:f16+44:f16+45:f16+47:f16 --case=complex_fusion/mha/sdpa-plain-training-backward-w-dmask-bf16-f32.json ---reset --dt=16:f16+17:f16+20:f16+32:f16+33:f16+34:f16+36:f16+38:f16+46:f16+47:f16+49:f16+51:f16 --case=complex_fusion/mha/gqa-plain-training-backward-w-dmask-bf16-f32.json # bf16 inputs + f32 intermediates + bf16 outputs --reset --op-kind=1:Multiply,1:Divide --dt=1:bf16+2:bf16+3:bf16+4:bf16+5:bf16+6:bf16+104:bf16 --case=complex_fusion/mha/sdpa-plain-simplified-f16-f32.json @@ -53,8 +51,6 @@ --reset --dt=1:bf16+2:bf16+3:bf16+4:bf16+6:bf16+104:bf16 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json --reset --dt=0:bf16+1:bf16+7:bf16+9:bf16+10:bf16 --case=complex_fusion/mha/sdpa-plain-wo-mask-f32.json --reset --dt=0:bf16+1:bf16+7:bf16+8:bf16+9:bf16 --case=complex_fusion/mha/sdpa-plain-wo-scale-f32.json ---reset --case=complex_fusion/mha/sdpa-plain-training-backward-w-dmask-bf16-f32.json ---reset --case=complex_fusion/mha/gqa-plain-training-backward-w-dmask-bf16-f32.json # int8 graphs --reset --case=complex_fusion/mha/MHA-GPT-inf-int8-bs1.json diff --git a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci index 50ed39e941e..77550e28ed0 
100644 --- a/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci +++ b/tests/benchdnn/inputs/graph/complex_fusion/harness_mha_ci @@ -34,9 +34,7 @@ --reset --dt=0:f16+1:f16+7:f16+9:f16+10:f16 --case=complex_fusion/mha/sdpa-plain-wo-mask-f32.json --reset --dt=0:f16+1:f16+7:f16+8:f16+9:f16 --case=complex_fusion/mha/sdpa-plain-wo-scale-f32.json --reset --in-shapes=0:2x1x1024x128*131072x128x128x1+1:2x1x1024x128*131072x128x128x1+2:1*1+5:2x1x1024x128*131072x128x128x1 --case=complex_fusion/mha/sdpa-training-fwd-no-mask-f16-f32.json ---reset --dt=16:f16+17:f16+20:f16+32:f16+33:f16+34:f16+36:f16+38:f16+46:f16+47:f16+49:f16+51:f16 --case=complex_fusion/mha/gqa-plain-training-backward-w-dmask-bf16-f32.json --reset --dt=200:f16+201:f16+202:f16+203:f16+204:f16+7:f16+10:f16+12:f16 --case=complex_fusion/mha/gqa-plain-training-fwd-w-dropout-bf16-f32.json ---reset --dt=200:f16+201:f16+202:f16+203:f16+204:f16+14:f16+15:f16+17:f16+105:f16+28:f16+29:f16+31:f16+33:f16 --case=complex_fusion/mha/gqa-plain-training-bwd-w-dropout-bf16-f32.json # bf16 inputs + f32 intermediates + bf16 outputs --reset --op-kind=1:Multiply,1:Divide --dt=1:bf16+2:bf16+3:bf16+4:bf16+5:bf16+6:bf16+104:bf16 --case=complex_fusion/mha/sdpa-plain-simplified-f16-f32.json @@ -48,7 +46,6 @@ --reset --dt=0:bf16+1:bf16+3:bf16+7:bf16+2:bf16+8:bf16 --case=complex_fusion/mha/MHA-stable_diffusion-inf-fp32-bs1.json --reset --case=complex_fusion/mha/sdpa-plain-training-forward-bf16-f32.json --reset --case=complex_fusion/mha/sdpa-plain-training-backward-bf16-f32.json ---reset --case=complex_fusion/mha/sdpa-plain-training-backward-w-dmask-bf16-f32.json --reset --case=complex_fusion/mha/gqa-plain-training-forward-bf16-f32.json --reset --case=complex_fusion/mha/gqa-plain-training-backward-bf16-f32.json --reset --case=complex_fusion/mha/codegemma-bf16-f32.json @@ -56,9 +53,7 @@ --reset --dt=1:bf16+2:bf16+3:bf16+4:bf16+6:bf16+104:bf16 --case=complex_fusion/mha/sdpa-plain-implicit-causal-mask-fp32-bs1.json --reset 
--dt=0:bf16+1:bf16+7:bf16+9:bf16+10:bf16 --case=complex_fusion/mha/sdpa-plain-wo-mask-f32.json --reset --dt=0:bf16+1:bf16+7:bf16+8:bf16+9:bf16 --case=complex_fusion/mha/sdpa-plain-wo-scale-f32.json ---reset --case=complex_fusion/mha/gqa-plain-training-backward-w-dmask-bf16-f32.json --reset --case=complex_fusion/mha/gqa-plain-training-fwd-w-dropout-bf16-f32.json ---reset --case=complex_fusion/mha/gqa-plain-training-bwd-w-dropout-bf16-f32.json # int8 graphs --reset --case=complex_fusion/mha/MHA-GPT-inf-int8-bs1.json diff --git a/tests/benchdnn/inputs/graph/complex_fusion/mha/gqa-plain-training-backward-w-dmask-bf16-f32.json b/tests/benchdnn/inputs/graph/complex_fusion/mha/gqa-plain-training-backward-w-dmask-bf16-f32.json deleted file mode 100644 index 17b37139d15..00000000000 --- a/tests/benchdnn/inputs/graph/complex_fusion/mha/gqa-plain-training-backward-w-dmask-bf16-f32.json +++ /dev/null @@ -1,1102 +0,0 @@ -{ - "version": "3.11.0", - "engine_kind": "cpu", - "fpmath_mode": "strict", - "fpmath_mode_apply_to_int": "false", - "input_ports": [ - 16, - 17, - 20, - 23, - 26, - 33, - 33, - 38, - 20, - 17, - 16 - ], - "output_ports": [ - 36, - 41, - 47, - 51 - ], - "graph": [ - { - "id": 19, - "name": "bmm1", - "kind": "MatMul", - "attrs": { - "transpose_b": { - "type": "bool", - "value": 1 - }, - "transpose_a": { - "type": "bool", - "value": 0 - }, - "accumulation_mode": { - "type": "string", - "value": "strict" - } - }, - "inputs": [ - { - "id": 16, - "dtype": "bf16", - "shape": [ - 1, - 2, - 8, - 384, - 64 - ], - "stride": [ - 393216, - 196608, - 24576, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 17, - "dtype": "bf16", - "shape": [ - 1, - 2, - 1, - 384, - 64 - ], - "stride": [ - 49152, - 24576, - 24576, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 18, - "dtype": "f32", - "shape": [ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 
1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 22, - "name": "scale_div", - "kind": "Divide", - "attrs": { - "auto_broadcast": { - "type": "string", - "value": "numpy" - } - }, - "inputs": [ - { - "id": 18, - "dtype": "f32", - "shape": [ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 20, - "dtype": "f32", - "shape": [ - 1 - ], - "stride": [ - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 21, - "dtype": "f32", - "shape": [ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 25, - "name": "mask_add", - "kind": "Add", - "attrs": { - "auto_broadcast": { - "type": "string", - "value": "numpy" - } - }, - "inputs": [ - { - "id": 21, - "dtype": "f32", - "shape": [ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 23, - "dtype": "f32", - "shape": [ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 24, - "dtype": "f32", - "shape": [ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 28, - "name": "subtract", - "kind": "Subtract", - "attrs": { - "auto_broadcast": { - "type": "string", - "value": "numpy" - } - }, - "inputs": [ - { - "id": 24, - "dtype": "f32", - "shape": [ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 26, - "dtype": "f32", - "shape": [ - 1, - 
2, - 8, - 384, - 1 - ], - "stride": [ - 6144, - 3072, - 384, - 1, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 27, - "dtype": "f32", - "shape": [ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 30, - "name": "exp", - "kind": "Exp", - "attrs": {}, - "inputs": [ - { - "id": 27, - "dtype": "f32", - "shape": [ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 29, - "dtype": "f32", - "shape": [ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 31, - "name": "typecast", - "kind": "TypeCast", - "attrs": {}, - "inputs": [ - { - "id": 29, - "dtype": "f32", - "shape": [ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 32, - "dtype": "bf16", - "shape": [ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 35, - "name": "bmm_dv", - "kind": "MatMul", - "attrs": { - "transpose_a": { - "type": "bool", - "value": 1 - }, - "transpose_b": { - "type": "bool", - "value": 0 - }, - "accumulation_mode": { - "type": "string", - "value": "strict" - } - }, - "inputs": [ - { - "id": 32, - "dtype": "bf16", - "shape": [ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 33, - "dtype": "bf16", - "shape": [ - 1, - 2, - 8, - 384, - 64 - ], - "stride": [ - 393216, - 196608, - 24576, - 64, - 1 - 
], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 34, - "dtype": "bf16", - "shape": [ - 1, - 2, - 8, - 384, - 64 - ], - "stride": [ - 393216, - 196608, - 24576, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 37, - "name": "reduce_dv", - "kind": "ReduceSum", - "attrs": { - "keep_dims": { - "type": "bool", - "value": 1 - }, - "axes": { - "type": "s64[]", - "value": [ - 2 - ] - } - }, - "inputs": [ - { - "id": 34, - "dtype": "bf16", - "shape": [ - 1, - 2, - 8, - 384, - 64 - ], - "stride": [ - 393216, - 196608, - 24576, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 36, - "dtype": "bf16", - "shape": [ - 1, - 2, - 1, - 384, - 64 - ], - "stride": [ - 49152, - 24576, - 24576, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 40, - "name": "bmm_dprobs", - "kind": "MatMul", - "attrs": { - "transpose_b": { - "type": "bool", - "value": 1 - }, - "transpose_a": { - "type": "bool", - "value": 0 - }, - "accumulation_mode": { - "type": "string", - "value": "strict" - } - }, - "inputs": [ - { - "id": 33, - "dtype": "bf16", - "shape": [ - 1, - 2, - 8, - 384, - 64 - ], - "stride": [ - 393216, - 196608, - 24576, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 38, - "dtype": "bf16", - "shape": [ - 1, - 2, - 1, - 384, - 64 - ], - "stride": [ - 49152, - 24576, - 24576, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 39, - "dtype": "f32", - "shape": [ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 42, - "name": "softmax_bwd", - "kind": "SoftMaxBackward", - "attrs": { - "axis": { - "type": "s64", - "value": -1 - } - }, - "inputs": [ - { - "id": 39, - "dtype": "f32", - "shape": 
[ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 29, - "dtype": "f32", - "shape": [ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 41, - "dtype": "f32", - "shape": [ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 47, - "name": "end", - "kind": "End", - "attrs": {}, - "inputs": [ - { - "id": 41, - "dtype": "f32", - "shape": [ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [] - }, - { - "id": 44, - "name": "scale_div", - "kind": "Divide", - "attrs": { - "auto_broadcast": { - "type": "string", - "value": "numpy" - } - }, - "inputs": [ - { - "id": 41, - "dtype": "f32", - "shape": [ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 20, - "dtype": "f32", - "shape": [ - 1 - ], - "stride": [ - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 43, - "dtype": "f32", - "shape": [ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 45, - "name": "typecast", - "kind": "TypeCast", - "attrs": {}, - "inputs": [ - { - "id": 43, - "dtype": "f32", - "shape": [ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 46, - "dtype": "bf16", - "shape": [ - 1, - 2, - 8, - 384, 
- 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 48, - "name": "bmm_dq", - "kind": "MatMul", - "attrs": { - "transpose_b": { - "type": "bool", - "value": 0 - }, - "transpose_a": { - "type": "bool", - "value": 0 - }, - "accumulation_mode": { - "type": "string", - "value": "strict" - } - }, - "inputs": [ - { - "id": 46, - "dtype": "bf16", - "shape": [ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 17, - "dtype": "bf16", - "shape": [ - 1, - 2, - 1, - 384, - 64 - ], - "stride": [ - 49152, - 24576, - 24576, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 47, - "dtype": "bf16", - "shape": [ - 1, - 2, - 8, - 384, - 64 - ], - "stride": [ - 393216, - 196608, - 24576, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 50, - "name": "bmm_dk", - "kind": "MatMul", - "attrs": { - "transpose_a": { - "type": "bool", - "value": 1 - }, - "transpose_b": { - "type": "bool", - "value": 0 - }, - "accumulation_mode": { - "type": "string", - "value": "strict" - } - }, - "inputs": [ - { - "id": 46, - "dtype": "bf16", - "shape": [ - 1, - 2, - 8, - 384, - 384 - ], - "stride": [ - 2359296, - 1179648, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 16, - "dtype": "bf16", - "shape": [ - 1, - 2, - 8, - 384, - 64 - ], - "stride": [ - 393216, - 196608, - 24576, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 49, - "dtype": "bf16", - "shape": [ - 1, - 2, - 8, - 384, - 64 - ], - "stride": [ - 393216, - 196608, - 24576, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 52, - "name": "reduce_dk", - "kind": "ReduceSum", - "attrs": { 
- "keep_dims": { - "type": "bool", - "value": 1 - }, - "axes": { - "type": "s64[]", - "value": [ - 2 - ] - } - }, - "inputs": [ - { - "id": 49, - "dtype": "bf16", - "shape": [ - 1, - 2, - 8, - 384, - 64 - ], - "stride": [ - 393216, - 196608, - 24576, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 51, - "dtype": "bf16", - "shape": [ - 1, - 2, - 1, - 384, - 64 - ], - "stride": [ - 49152, - 24576, - 24576, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - } - ] -} diff --git a/tests/benchdnn/inputs/graph/complex_fusion/mha/gqa-plain-training-bwd-w-dropout-bf16-f32.json b/tests/benchdnn/inputs/graph/complex_fusion/mha/gqa-plain-training-bwd-w-dropout-bf16-f32.json deleted file mode 100644 index 8df966283c1..00000000000 --- a/tests/benchdnn/inputs/graph/complex_fusion/mha/gqa-plain-training-bwd-w-dropout-bf16-f32.json +++ /dev/null @@ -1,1226 +0,0 @@ -{ - "version": "3.11.0", - "engine_kind": "cpu", - "fpmath_mode": "strict", - "fpmath_mode_apply_to_int": "false", - "input_ports": [ - 200, - 201, - 202, - 203, - 8, - 205, - 206, - 207, - 105, - 105, - 204, - 205, - 206, - 207, - 202, - 201, - 200 - ], - "output_ports": [ - 17, - 29, - 33 - ], - "graph": [ - { - "id": 2, - "name": "bmm1", - "kind": "MatMul", - "attrs": { - "accumulation_mode": { - "type": "string", - "value": "strict" - }, - "transpose_b": { - "type": "bool", - "value": 1 - }, - "transpose_a": { - "type": "bool", - "value": 0 - } - }, - "inputs": [ - { - "id": 200, - "dtype": "bf16", - "shape": [ - 2, - 2, - 8, - 128, - 64 - ], - "stride": [ - 131072, - 65536, - 8192, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 201, - "dtype": "bf16", - "shape": [ - 2, - 2, - 1, - 128, - 64 - ], - "stride": [ - 16384, - 8192, - 8192, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 1, - "dtype": "f32", - "shape": [ - 2, - 2, - 8, - 
128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 4, - "name": "scale_div", - "kind": "Divide", - "attrs": { - "auto_broadcast": { - "type": "string", - "value": "numpy" - } - }, - "inputs": [ - { - "id": 1, - "dtype": "f32", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 202, - "dtype": "bf16", - "shape": [ - 1 - ], - "stride": [ - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 3, - "dtype": "f32", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 6, - "name": "mask_add", - "kind": "Add", - "attrs": { - "auto_broadcast": { - "type": "string", - "value": "numpy" - } - }, - "inputs": [ - { - "id": 3, - "dtype": "f32", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 203, - "dtype": "bf16", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 5, - "dtype": "f32", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 8, - "name": "subtract", - "kind": "Subtract", - "attrs": { - "auto_broadcast": { - "type": "string", - "value": "numpy" - } - }, - "inputs": [ - { - "id": 5, - "dtype": "f32", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - 
"id": 8, - "dtype": "f32", - "shape": [ - 2, - 2, - 8, - 128, - 1 - ], - "stride": [ - 2048, - 1024, - 128, - 1, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 7, - "dtype": "f32", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 10, - "name": "exp", - "kind": "Exp", - "attrs": {}, - "inputs": [ - { - "id": 7, - "dtype": "f32", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 9, - "dtype": "f32", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 12, - "name": "dropout", - "kind": "Dropout", - "attrs": {}, - "inputs": [ - { - "id": 9, - "dtype": "f32", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 205, - "dtype": "s64", - "shape": [], - "stride": [], - "layout_type": "strided", - "property_type": "host_scalar" - }, - { - "id": 206, - "dtype": "s64", - "shape": [], - "stride": [], - "layout_type": "strided", - "property_type": "host_scalar" - }, - { - "id": 207, - "dtype": "f32", - "shape": [], - "stride": [], - "layout_type": "strided", - "property_type": "host_scalar" - } - ], - "outputs": [ - { - "id": 11, - "dtype": "f32", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 13, - "name": "typecast", - "kind": "TypeCast", - "attrs": {}, - "inputs": [ - { - "id": 11, - "dtype": "f32", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ 
- 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 14, - "dtype": "bf16", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 16, - "name": "bmm_dv", - "kind": "MatMul", - "attrs": { - "accumulation_mode": { - "type": "string", - "value": "strict" - }, - "transpose_b": { - "type": "bool", - "value": 0 - }, - "transpose_a": { - "type": "bool", - "value": 1 - } - }, - "inputs": [ - { - "id": 14, - "dtype": "bf16", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 105, - "dtype": "bf16", - "shape": [ - 2, - 2, - 8, - 128, - 64 - ], - "stride": [ - 131072, - 65536, - 8192, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 15, - "dtype": "bf16", - "shape": [ - 2, - 2, - 8, - 128, - 64 - ], - "stride": [ - 131072, - 65536, - 8192, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 18, - "name": "reduce_dv", - "kind": "ReduceSum", - "attrs": { - "keep_dims": { - "type": "bool", - "value": 1 - }, - "axes": { - "type": "s64[]", - "value": [ - 2 - ] - } - }, - "inputs": [ - { - "id": 15, - "dtype": "bf16", - "shape": [ - 2, - 2, - 8, - 128, - 64 - ], - "stride": [ - 131072, - 65536, - 8192, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 17, - "dtype": "bf16", - "shape": [ - 2, - 2, - 1, - 128, - 64 - ], - "stride": [ - 16384, - 8192, - 8192, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 20, - "name": "bmm_dprobs", - "kind": "MatMul", - "attrs": { - "accumulation_mode": { - "type": "string", - "value": "strict" - }, - "transpose_b": 
{ - "type": "bool", - "value": 1 - }, - "transpose_a": { - "type": "bool", - "value": 0 - } - }, - "inputs": [ - { - "id": 105, - "dtype": "bf16", - "shape": [ - 2, - 2, - 8, - 128, - 64 - ], - "stride": [ - 131072, - 65536, - 8192, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 204, - "dtype": "bf16", - "shape": [ - 2, - 2, - 1, - 128, - 64 - ], - "stride": [ - 16384, - 8192, - 8192, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 19, - "dtype": "f32", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 22, - "name": "dropout", - "kind": "Dropout", - "attrs": {}, - "inputs": [ - { - "id": 19, - "dtype": "f32", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 205, - "dtype": "s64", - "shape": [], - "stride": [], - "layout_type": "strided", - "property_type": "host_scalar" - }, - { - "id": 206, - "dtype": "s64", - "shape": [], - "stride": [], - "layout_type": "strided", - "property_type": "host_scalar" - }, - { - "id": 207, - "dtype": "f32", - "shape": [], - "stride": [], - "layout_type": "strided", - "property_type": "host_scalar" - } - ], - "outputs": [ - { - "id": 21, - "dtype": "f32", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 24, - "name": "softmax_bwd", - "kind": "SoftMaxBackward", - "attrs": { - "axis": { - "type": "s64", - "value": -1 - } - }, - "inputs": [ - { - "id": 21, - "dtype": "f32", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 9, 
- "dtype": "f32", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 23, - "dtype": "f32", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 26, - "name": "scale_div", - "kind": "Divide", - "attrs": { - "auto_broadcast": { - "type": "string", - "value": "numpy" - } - }, - "inputs": [ - { - "id": 23, - "dtype": "f32", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 202, - "dtype": "bf16", - "shape": [ - 1 - ], - "stride": [ - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 25, - "dtype": "f32", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 27, - "name": "typecast", - "kind": "TypeCast", - "attrs": {}, - "inputs": [ - { - "id": 25, - "dtype": "f32", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 28, - "dtype": "bf16", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 30, - "name": "bmm_dq", - "kind": "MatMul", - "attrs": { - "transpose_b": { - "type": "bool", - "value": 0 - }, - "transpose_a": { - "type": "bool", - "value": 0 - }, - "accumulation_mode": { - "type": "string", - "value": "strict" - } - }, - "inputs": [ - { - "id": 28, - "dtype": "bf16", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - 
"stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 201, - "dtype": "bf16", - "shape": [ - 2, - 2, - 1, - 128, - 64 - ], - "stride": [ - 16384, - 8192, - 8192, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 29, - "dtype": "bf16", - "shape": [ - 2, - 2, - 8, - 128, - 64 - ], - "stride": [ - 131072, - 65536, - 8192, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 32, - "name": "bmm_dk", - "kind": "MatMul", - "attrs": { - "accumulation_mode": { - "type": "string", - "value": "strict" - }, - "transpose_b": { - "type": "bool", - "value": 0 - }, - "transpose_a": { - "type": "bool", - "value": 1 - } - }, - "inputs": [ - { - "id": 28, - "dtype": "bf16", - "shape": [ - 2, - 2, - 8, - 128, - 128 - ], - "stride": [ - 262144, - 131072, - 16384, - 128, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 200, - "dtype": "bf16", - "shape": [ - 2, - 2, - 8, - 128, - 64 - ], - "stride": [ - 131072, - 65536, - 8192, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 31, - "dtype": "bf16", - "shape": [ - 2, - 2, - 8, - 128, - 64 - ], - "stride": [ - 131072, - 65536, - 8192, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 34, - "name": "reduce_dk", - "kind": "ReduceSum", - "attrs": { - "keep_dims": { - "type": "bool", - "value": 1 - }, - "axes": { - "type": "s64[]", - "value": [ - 2 - ] - } - }, - "inputs": [ - { - "id": 31, - "dtype": "bf16", - "shape": [ - 2, - 2, - 8, - 128, - 64 - ], - "stride": [ - 131072, - 65536, - 8192, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 33, - "dtype": "bf16", - "shape": [ - 2, - 2, - 1, - 128, - 64 - ], - "stride": [ - 16384, - 8192, - 8192, - 64, - 1 - ], - "layout_type": 
"strided", - "property_type": "undef" - } - ] - } - ] -} diff --git a/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-training-backward-w-dmask-bf16-f32.json b/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-training-backward-w-dmask-bf16-f32.json deleted file mode 100644 index 5c0a07e72e5..00000000000 --- a/tests/benchdnn/inputs/graph/complex_fusion/mha/sdpa-plain-training-backward-w-dmask-bf16-f32.json +++ /dev/null @@ -1,910 +0,0 @@ -{ - "version": "3.11.0", - "engine_kind": "cpu", - "fpmath_mode": "strict", - "fpmath_mode_apply_to_int": "false", - "input_ports": [ - 16, - 17, - 20, - 23, - 26, - 33, - 33, - 36, - 20, - 17, - 16 - ], - "output_ports": [ - 34, - 39, - 45, - 47 - ], - "graph": [ - { - "id": 19, - "name": "bmm_qk", - "kind": "MatMul", - "attrs": { - "transpose_b": { - "type": "bool", - "value": 1 - }, - "transpose_a": { - "type": "bool", - "value": 0 - }, - "accumulation_mode": { - "type": "string", - "value": "strict" - } - }, - "inputs": [ - { - "id": 16, - "dtype": "bf16", - "shape": [ - 1, - 16, - 384, - 64 - ], - "stride": [ - 393216, - 24576, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 17, - "dtype": "bf16", - "shape": [ - 1, - 16, - 384, - 64 - ], - "stride": [ - 393216, - 24576, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 18, - "dtype": "f32", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 22, - "name": "scale_div", - "kind": "Divide", - "attrs": { - "auto_broadcast": { - "type": "string", - "value": "numpy" - } - }, - "inputs": [ - { - "id": 18, - "dtype": "f32", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 20, - "dtype": "f32", - "shape": [ - 1 - ], - "stride": [ - 1 - ], - 
"layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 21, - "dtype": "f32", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 25, - "name": "mask_add", - "kind": "Add", - "attrs": { - "auto_broadcast": { - "type": "string", - "value": "numpy" - } - }, - "inputs": [ - { - "id": 21, - "dtype": "f32", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 23, - "dtype": "f32", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 24, - "dtype": "f32", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 28, - "name": "softmax_sub", - "kind": "Subtract", - "attrs": { - "auto_broadcast": { - "type": "string", - "value": "numpy" - } - }, - "inputs": [ - { - "id": 24, - "dtype": "f32", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 26, - "dtype": "f32", - "shape": [ - 1, - 16, - 384, - 1 - ], - "stride": [ - 6144, - 384, - 1, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 27, - "dtype": "f32", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 30, - "name": "softmax_exp", - "kind": "Exp", - "attrs": {}, - "inputs": [ - { - "id": 27, - "dtype": "f32", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - 
"property_type": "undef" - } - ], - "outputs": [ - { - "id": 29, - "dtype": "f32", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 31, - "name": "typecast", - "kind": "TypeCast", - "attrs": {}, - "inputs": [ - { - "id": 29, - "dtype": "f32", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 32, - "dtype": "bf16", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 35, - "name": "bmm_dv", - "kind": "MatMul", - "attrs": { - "transpose_a": { - "type": "bool", - "value": 1 - }, - "transpose_b": { - "type": "bool", - "value": 0 - }, - "accumulation_mode": { - "type": "string", - "value": "strict" - } - }, - "inputs": [ - { - "id": 32, - "dtype": "bf16", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 33, - "dtype": "bf16", - "shape": [ - 1, - 16, - 384, - 64 - ], - "stride": [ - 393216, - 24576, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 34, - "dtype": "bf16", - "shape": [ - 1, - 16, - 384, - 64 - ], - "stride": [ - 393216, - 24576, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 38, - "name": "bmm_do_v", - "kind": "MatMul", - "attrs": { - "transpose_b": { - "type": "bool", - "value": 1 - }, - "transpose_a": { - "type": "bool", - "value": 0 - }, - "accumulation_mode": { - "type": "string", - "value": "strict" - } - }, - "inputs": [ - { - "id": 33, - "dtype": "bf16", - "shape": [ - 1, - 16, - 384, - 64 - ], - "stride": [ - 393216, - 24576, - 64, - 1 - ], - "layout_type": 
"strided", - "property_type": "undef" - }, - { - "id": 36, - "dtype": "bf16", - "shape": [ - 1, - 16, - 384, - 64 - ], - "stride": [ - 393216, - 24576, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 37, - "dtype": "f32", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 40, - "name": "softmax_bwd", - "kind": "SoftMaxBackward", - "attrs": { - "axis": { - "type": "s64", - "value": -1 - } - }, - "inputs": [ - { - "id": 37, - "dtype": "f32", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 29, - "dtype": "f32", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 39, - "dtype": "f32", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 45, - "name": "end", - "kind": "End", - "attrs": {}, - "inputs": [ - { - "id": 39, - "dtype": "f32", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [] - }, - { - "id": 42, - "name": "scale_div", - "kind": "Divide", - "attrs": { - "auto_broadcast": { - "type": "string", - "value": "numpy" - } - }, - "inputs": [ - { - "id": 39, - "dtype": "f32", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 20, - "dtype": "f32", - "shape": [ - 1 - ], - "stride": [ - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 41, - "dtype": 
"f32", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 43, - "name": "typecast", - "kind": "TypeCast", - "attrs": {}, - "inputs": [ - { - "id": 41, - "dtype": "f32", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 44, - "dtype": "bf16", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 46, - "name": "bmm_dq", - "kind": "MatMul", - "attrs": { - "transpose_b": { - "type": "bool", - "value": 0 - }, - "transpose_a": { - "type": "bool", - "value": 0 - }, - "accumulation_mode": { - "type": "string", - "value": "strict" - } - }, - "inputs": [ - { - "id": 44, - "dtype": "bf16", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 17, - "dtype": "bf16", - "shape": [ - 1, - 16, - 384, - 64 - ], - "stride": [ - 393216, - 24576, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 45, - "dtype": "bf16", - "shape": [ - 1, - 16, - 384, - 64 - ], - "stride": [ - 393216, - 24576, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - }, - { - "id": 48, - "name": "bmm_dk", - "kind": "MatMul", - "attrs": { - "transpose_a": { - "type": "bool", - "value": 1 - }, - "transpose_b": { - "type": "bool", - "value": 0 - }, - "accumulation_mode": { - "type": "string", - "value": "strict" - } - }, - "inputs": [ - { - "id": 44, - "dtype": "bf16", - "shape": [ - 1, - 16, - 384, - 384 - ], - "stride": [ - 2359296, - 147456, - 384, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - }, - { - "id": 16, - "dtype": "bf16", 
- "shape": [ - 1, - 16, - 384, - 64 - ], - "stride": [ - 393216, - 24576, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ], - "outputs": [ - { - "id": 47, - "dtype": "bf16", - "shape": [ - 1, - 16, - 384, - 64 - ], - "stride": [ - 393216, - 24576, - 64, - 1 - ], - "layout_type": "strided", - "property_type": "undef" - } - ] - } - ] -} From 5638f5f3326428f4b7c48941f380f640f6e47631 Mon Sep 17 00:00:00 2001 From: "Bao, Yixin" Date: Thu, 12 Mar 2026 08:57:17 +0000 Subject: [PATCH 23/23] graph: backend: dnnl: fix code format --- src/graph/backend/dnnl/executables/sdpa.cpp | 7 +- .../dnnl/kernels/sdp_bwd_primitive.cpp | 9 +- .../dnnl/kernels/sdp_primitive_config.cpp | 4 +- src/graph/backend/dnnl/layout_propagator.cpp | 24 ++---- src/graph/backend/dnnl/passes/compile_ops.cpp | 2 +- src/graph/backend/dnnl/passes/insert_ops.cpp | 43 +++++----- src/graph/backend/dnnl/passes/transform.cpp | 84 ++++++++----------- src/graph/backend/dnnl/patterns/sdp.cpp | 4 +- 8 files changed, 74 insertions(+), 103 deletions(-) diff --git a/src/graph/backend/dnnl/executables/sdpa.cpp b/src/graph/backend/dnnl/executables/sdpa.cpp index 7a6a5e1188f..c8834a1494f 100644 --- a/src/graph/backend/dnnl/executables/sdpa.cpp +++ b/src/graph/backend/dnnl/executables/sdpa.cpp @@ -404,7 +404,8 @@ sdpa_bwd_executable_t::sdpa_bwd_executable_t(std::shared_ptr &op, status_t s = create_sdpa_pd(hint_fwd_pd, p_engine.get(), md_q.get(), md_k.get(), md_v.get(), md_dst.get(), md_attn_mask.get(), md_scale.get(), is_invert_scale_, kv_head_number, mask_type_, - softmax_alg, impl::prop_kind::forward_training, attr.get(), qk_attr.get(), vs_attr.get()); + softmax_alg, impl::prop_kind::forward_training, attr.get(), + qk_attr.get(), vs_attr.get()); if (s != dnnl::impl::status::success) { is_initialized_ = false; return; @@ -423,7 +424,6 @@ sdpa_bwd_executable_t::sdpa_bwd_executable_t(std::shared_ptr &op, s = sdpa_bwd_pd_->create_primitive(sdpa_bwd_prim_, p_engine.get()); is_initialized_ = s == 
status::success; } - } void sdpa_bwd_executable_t::execute(const stream &stream, @@ -466,8 +466,7 @@ void sdpa_bwd_executable_t::execute(const stream &stream, // Set up scratchpad grantor required by the primitive's execute const memory_storage_t *mem_storage = nullptr; memory_t *scratchpad_memory = ctx.output(DNNL_ARG_SCRATCHPAD); - if (scratchpad_memory) - mem_storage = scratchpad_memory->memory_storage(); + if (scratchpad_memory) mem_storage = scratchpad_memory->memory_storage(); const void *host_ptr = ctx.host_ptr(mem_storage, /* require_host_ptr = */ true); auto *scratchpad_grantor diff --git a/src/graph/backend/dnnl/kernels/sdp_bwd_primitive.cpp b/src/graph/backend/dnnl/kernels/sdp_bwd_primitive.cpp index 822c2b2285a..a9aead77a55 100644 --- a/src/graph/backend/dnnl/kernels/sdp_bwd_primitive.cpp +++ b/src/graph/backend/dnnl/kernels/sdp_bwd_primitive.cpp @@ -51,8 +51,7 @@ status_t sdp_bwd_primitive_kernel_t::initial_check( const std::vector &inputs, const std::vector &outputs) { const bool is_f32 = inputs[0].data_type == data_type::f32; - VCHECK_SDP_BWD_PRIMITIVE(!is_f32, - status::unimplemented, + VCHECK_SDP_BWD_PRIMITIVE(!is_f32, status::unimplemented, "SDPA bwd primitive doesn't support f32 because of performance"); bool has_dropout = false; @@ -63,8 +62,7 @@ status_t sdp_bwd_primitive_kernel_t::initial_check( break; } } - VCHECK_SDP_BWD_PRIMITIVE(!has_dropout, - status::unimplemented, + VCHECK_SDP_BWD_PRIMITIVE(!has_dropout, status::unimplemented, "SDPA bwd primitive doesn't support Dropout for now"); bool has_host_scalar = false; @@ -74,8 +72,7 @@ status_t sdp_bwd_primitive_kernel_t::initial_check( break; } } - VCHECK_SDP_BWD_PRIMITIVE(!has_host_scalar, - status::unimplemented, + VCHECK_SDP_BWD_PRIMITIVE(!has_host_scalar, status::unimplemented, "SDPA bwd primitive doesn't support host scalar inputs for now"); return status::success; diff --git a/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp 
b/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp index ebd9e988ca3..d0356fc2978 100644 --- a/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp +++ b/src/graph/backend/dnnl/kernels/sdp_primitive_config.cpp @@ -56,8 +56,8 @@ status_t sdp_primitive_config_t::initial_check( && opk != graph::op_kind::Quantize, status::unimplemented, "Not support quantized SDPA"); // SDPA with Dropout is currently unsupported in the ukernel. - VCHECK_SDP_PRIMITIVE(opk != graph::op_kind::Dropout, status::unimplemented, - "Not support SDPA with Dropout"); + VCHECK_SDP_PRIMITIVE(opk != graph::op_kind::Dropout, + status::unimplemented, "Not support SDPA with Dropout"); if (opk == graph::op_kind::GenIndex) { has_genindex = true; } } diff --git a/src/graph/backend/dnnl/layout_propagator.cpp b/src/graph/backend/dnnl/layout_propagator.cpp index 62a2fcfdef0..47516b3504a 100644 --- a/src/graph/backend/dnnl/layout_propagator.cpp +++ b/src/graph/backend/dnnl/layout_propagator.cpp @@ -1836,8 +1836,7 @@ status_t layout_propagator_for_sdpa_bwd(std::shared_ptr &op, // forward input logical tensor. If the input layout is already fixed, reuse // it; otherwise fall back to the canonical acbd format used by sdpa. auto get_md_for_diff = [](const logical_tensor_t <) { - if (!ltw(lt).is_any()) - return make_dnnl_memory_desc(lt); + if (!ltw(lt).is_any()) return make_dnnl_memory_desc(lt); return dnnl::memory::desc {ltw(lt).vdims(), static_cast(ltw(lt).data_type()), dnnl::memory::format_tag::acbd}; @@ -1872,12 +1871,10 @@ status_t layout_propagator_for_sdpa_bwd(std::shared_ptr &op, const bool with_scale = op->get_attr(op_attr::with_scale); const auto mask_type = static_cast( op->get_attr(op_attr::mask_type)); - const bool is_invert_scale - = op->has_attr(op_attr::is_invert_scale) + const bool is_invert_scale = op->has_attr(op_attr::is_invert_scale) ? 
op->get_attr(op_attr::is_invert_scale) : false; - const bool with_explicit_mask - = mask_type == attn_mask_type::buffer; + const bool with_explicit_mask = mask_type == attn_mask_type::buffer; auto md_q = make_dnnl_memory_desc(op->get_input_logical_tensor(0)); auto md_k = make_dnnl_memory_desc(op->get_input_logical_tensor(1)); @@ -1898,8 +1895,7 @@ status_t layout_propagator_for_sdpa_bwd(std::shared_ptr &op, md_attn_mask = make_dnnl_memory_desc( op->get_input_logical_tensor(idx++)); if (op->num_outputs() > 4) - md_dS = make_dnnl_memory_desc( - op->get_output_logical_tensor(4)); + md_dS = make_dnnl_memory_desc(op->get_output_logical_tensor(4)); } const auto &sdpa_fusion_info = op->has_attr(op_attr::fusion_info) @@ -1917,18 +1913,17 @@ status_t layout_propagator_for_sdpa_bwd(std::shared_ptr &op, vs_attr.set_accumulation_mode(str2accumulation_mode( op->get_attr(op_attr::vs_acc_mode))); attr.set_scratchpad_mode(dnnl::scratchpad_mode::user); - attr.set_fpmath_mode( - static_cast(fpmath.mode_)); + attr.set_fpmath_mode(static_cast(fpmath.mode_)); dim_t kv_head_number = op->get_input_logical_tensor(1).dims[1]; - const alg_kind_t softmax_alg - = alg_kind::softmax_accurate_inf_as_zero; + const alg_kind_t softmax_alg = alg_kind::softmax_accurate_inf_as_zero; std::shared_ptr hint_fwd_pd; status = create_sdpa_pd(hint_fwd_pd, p_engine.get(), md_q.get(), md_k.get(), md_v.get(), md_dst.get(), md_attn_mask.get(), md_scale.get(), is_invert_scale, kv_head_number, mask_type, - softmax_alg, impl::prop_kind::forward_training, attr.get(), qk_attr.get(), vs_attr.get()); + softmax_alg, impl::prop_kind::forward_training, attr.get(), + qk_attr.get(), vs_attr.get()); VCHECK_LAYOUT_PROPAGATOR(status == status::success, status, "failed to create hint fwd pd for sdpa_bwd scratchpad"); @@ -1938,8 +1933,7 @@ status_t layout_propagator_for_sdpa_bwd(std::shared_ptr &op, md_diff_k.get(), md_diff_v.get(), md_diff_dst.get(), md_dS.get(), md_attn_mask.get(), md_scale.get(), is_invert_scale, 
kv_head_number, mask_type, softmax_alg, - attr.get(), hint_fwd_pd.get(), qk_attr.get(), - vs_attr.get()); + attr.get(), hint_fwd_pd.get(), qk_attr.get(), vs_attr.get()); VCHECK_LAYOUT_PROPAGATOR(status == status::success, status, "failed to create pd for sdpa_bwd scratchpad"); diff --git a/src/graph/backend/dnnl/passes/compile_ops.cpp b/src/graph/backend/dnnl/passes/compile_ops.cpp index f5aa642fe58..82142a5c956 100644 --- a/src/graph/backend/dnnl/passes/compile_ops.cpp +++ b/src/graph/backend/dnnl/passes/compile_ops.cpp @@ -69,7 +69,7 @@ status_t compile_ops(std::shared_ptr &sg) { "failed to create executable for op %s", op->get_name().c_str()); } - + sg->execs_.emplace_back(exec); sg->is_constant_.push_back(op->has_attr(op_attr::is_constant) diff --git a/src/graph/backend/dnnl/passes/insert_ops.cpp b/src/graph/backend/dnnl/passes/insert_ops.cpp index 641dc3c7ec4..029df32b103 100644 --- a/src/graph/backend/dnnl/passes/insert_ops.cpp +++ b/src/graph/backend/dnnl/passes/insert_ops.cpp @@ -666,9 +666,8 @@ status_t insert_reshape_for_sdpa(std::shared_ptr &sg) { // Insert reshape for optional stats output (output 2) if (cur_op->get_attr(op_attr::is_training)) { auto stats_dims = ltw(cur_op->get_output_logical_tensor(2)).vdims(); - dims expected_stats_dims = stats_dims; - op_ptr reshape_stats - = std::make_shared(op_kind::_reshape); + const dims &expected_stats_dims = stats_dims; + op_ptr reshape_stats = std::make_shared(op_kind::_reshape); reshape_stats->set_attr(op_attr::special_zero, false); reshape_stats->set_attr>( op_attr::shape, expected_stats_dims); @@ -728,8 +727,7 @@ status_t insert_reshape_for_sdpa_bwd(std::shared_ptr &sg) { size_t index = 6; // Insert reshape for scale (optional) if (cur_op->get_attr(op_attr::with_scale)) { - int32_t scale_ndims - = cur_op->get_input_logical_tensor(index).ndims; + int32_t scale_ndims = cur_op->get_input_logical_tensor(index).ndims; if (scale_ndims == 5) { auto scale_dims = 
ltw(cur_op->get_input_logical_tensor(index)).vdims(); @@ -741,9 +739,8 @@ status_t insert_reshape_for_sdpa_bwd(std::shared_ptr &sg) { // Insert reshape for mask (optional) if (cur_op->get_attr(op_attr::mask_type) == static_cast(attn_mask_type::buffer)) { - int32_t mask_ndims - = cur_op->get_input_logical_tensor(index).ndims; - + int32_t mask_ndims = cur_op->get_input_logical_tensor(index).ndims; + if (mask_ndims == 5) { auto mask_dims = ltw(cur_op->get_input_logical_tensor(index)).vdims(); @@ -753,10 +750,10 @@ status_t insert_reshape_for_sdpa_bwd(std::shared_ptr &sg) { } // Insert reshape for diff_query output (output 0) -> 4D to 5D - auto diff_query_dims = ltw(cur_op->get_output_logical_tensor(0)).vdims(); + auto diff_query_dims + = ltw(cur_op->get_output_logical_tensor(0)).vdims(); const dims &expected_diff_query_dims = diff_query_dims; - op_ptr reshape_diff_query - = std::make_shared(op_kind::_reshape); + op_ptr reshape_diff_query = std::make_shared(op_kind::_reshape); reshape_diff_query->set_attr(op_attr::special_zero, false); reshape_diff_query->set_attr>( op_attr::shape, expected_diff_query_dims); @@ -772,10 +769,10 @@ status_t insert_reshape_for_sdpa_bwd(std::shared_ptr &sg) { rewriter.insert_op_after(reshape_diff_key, cur_op, 1); // Insert reshape for diff_value output (output 2) -> 4D to 5D - auto diff_value_dims = ltw(cur_op->get_output_logical_tensor(2)).vdims(); + auto diff_value_dims + = ltw(cur_op->get_output_logical_tensor(2)).vdims(); const dims &expected_diff_value_dims = diff_value_dims; - op_ptr reshape_diff_value - = std::make_shared(op_kind::_reshape); + op_ptr reshape_diff_value = std::make_shared(op_kind::_reshape); reshape_diff_value->set_attr(op_attr::special_zero, false); reshape_diff_value->set_attr>( op_attr::shape, expected_diff_value_dims); @@ -783,15 +780,15 @@ status_t insert_reshape_for_sdpa_bwd(std::shared_ptr &sg) { // Insert reshape for diff_mask output (output 4) -> 4D to 5D if (cur_op->num_outputs() > 4) { - auto 
diff_mask_dims - = ltw(cur_op->get_output_logical_tensor(4)).vdims(); - const dims &expected_diff_mask_dims = diff_mask_dims; - op_ptr reshape_diff_mask - = std::make_shared(op_kind::_reshape); - reshape_diff_mask->set_attr(op_attr::special_zero, false); - reshape_diff_mask->set_attr>( - op_attr::shape, expected_diff_mask_dims); - rewriter.insert_op_after(reshape_diff_mask, cur_op, 4); + auto diff_mask_dims + = ltw(cur_op->get_output_logical_tensor(4)).vdims(); + const dims &expected_diff_mask_dims = diff_mask_dims; + op_ptr reshape_diff_mask + = std::make_shared(op_kind::_reshape); + reshape_diff_mask->set_attr(op_attr::special_zero, false); + reshape_diff_mask->set_attr>( + op_attr::shape, expected_diff_mask_dims); + rewriter.insert_op_after(reshape_diff_mask, cur_op, 4); } } diff --git a/src/graph/backend/dnnl/passes/transform.cpp b/src/graph/backend/dnnl/passes/transform.cpp index 109bc0b4ac1..02d3162d61b 100644 --- a/src/graph/backend/dnnl/passes/transform.cpp +++ b/src/graph/backend/dnnl/passes/transform.cpp @@ -3280,8 +3280,7 @@ status_t decompose_softmax(std::shared_ptr &sg) { auto reduce_dst_op_out_val = f32_dst; if (need_reduction) { // create reduce_src op - auto reduce_src_op - = std::make_shared(op_kind::_reduction); + auto reduce_src_op = std::make_shared(op_kind::_reduction); reduce_src_op->set_attr>( op_attr::axes, {cur_op->get_attr(op_attr::axis)}); reduce_src_op->set_attr(op_attr::keep_dims, true); @@ -3299,8 +3298,7 @@ status_t decompose_softmax(std::shared_ptr &sg) { insert_empty_scratchpad(reduce_src_op); // create reduce_dst op - auto reduce_dst_op - = std::make_shared(op_kind::_reduction); + auto reduce_dst_op = std::make_shared(op_kind::_reduction); reduce_dst_op->set_attr>( op_attr::axes, {cur_op->get_attr(op_attr::axis)}); reduce_dst_op->set_attr(op_attr::keep_dims, true); @@ -4957,7 +4955,7 @@ status_t fuse_sdpa_bwd(std::shared_ptr &sg) { if (cur_op->get_kind() != op_kind::_matmul) continue; // Step 1 – walk matmul_qk → [scale_pre] → 
[mask] → sub → exp - op_ptr matmul_qk = cur_op; + const op_ptr &matmul_qk = cur_op; op_ptr scale_pre = nullptr, mask_op = nullptr; op_ptr sub_op = nullptr, exp_op = nullptr; @@ -4981,8 +4979,7 @@ status_t fuse_sdpa_bwd(std::shared_ptr &sg) { w = sole_consumer(w); } } - if (!exp_op || !sub_op) - continue; + if (!exp_op || !sub_op) continue; // stats tensor feeds Subtract at input 1 value_ptr stats_val = sub_op->get_input_value(1); @@ -5023,13 +5020,11 @@ status_t fuse_sdpa_bwd(std::shared_ptr &sg) { permute_p = sole_consumer(permute_p); } } - if (!permute_p || permute_p->get_kind() != op_kind::_permute) - continue; + if (!permute_p || permute_p->get_kind() != op_kind::_permute) continue; matmul_dv = sole_consumer(permute_p); - if (!matmul_dv || !softmax_bwd) - continue; + if (!matmul_dv || !softmax_bwd) continue; // Optional reduce after matmul_dv (e.g., GQA dV accumulation) op_ptr reduce_dv = nullptr; @@ -5045,17 +5040,14 @@ status_t fuse_sdpa_bwd(std::shared_ptr &sg) { // correction = ReduceSum(o_do) // o_do = Mul(O, dO) value_ptr dp_corr_val = softmax_bwd->get_input_value(1); - if (!dp_corr_val->has_producer()) - continue; + if (!dp_corr_val->has_producer()) continue; op_ptr dp_corrected_op = dp_corr_val->get_producer().shared_from_this(); - if (!is_binary(dp_corrected_op, dnnl::algorithm::binary_sub)) - continue; + if (!is_binary(dp_corrected_op, dnnl::algorithm::binary_sub)) continue; // dP side (input 0): matmul_v_do, optionally via Dropout value_ptr dP_val = dp_corrected_op->get_input_value(0); - if (!dP_val->has_producer()) - continue; + if (!dP_val->has_producer()) continue; op_ptr dP_prod = dP_val->get_producer().shared_from_this(); @@ -5065,12 +5057,10 @@ status_t fuse_sdpa_bwd(std::shared_ptr &sg) { } else if (dP_prod->get_kind() == op_kind::_dropout) { dropout_bwd = dP_prod; value_ptr mm_out = dropout_bwd->get_input_value(0); - if (!mm_out->has_producer()) - continue; + if (!mm_out->has_producer()) continue; auto mm_prod = 
mm_out->get_producer().shared_from_this(); - if (mm_prod->get_kind() != op_kind::_matmul) - continue; + if (mm_prod->get_kind() != op_kind::_matmul) continue; matmul_vt_do = mm_prod; } else { @@ -5080,26 +5070,23 @@ status_t fuse_sdpa_bwd(std::shared_ptr &sg) { // get permute before matmul_vt_do op_ptr permute_v = nullptr; if (matmul_vt_do->get_input_value(1)->has_producer()) { - permute_v = matmul_vt_do->get_input_value(1)->get_producer().shared_from_this(); - if (permute_v->get_kind() != op_kind::_permute) - continue; + permute_v = matmul_vt_do->get_input_value(1) + ->get_producer() + .shared_from_this(); + if (permute_v->get_kind() != op_kind::_permute) continue; } // correction side (input 1): ReduceSum → o_do = Mul(O, dO) value_ptr corr_val = dp_corrected_op->get_input_value(1); - if (!corr_val->has_producer()) - continue; + if (!corr_val->has_producer()) continue; op_ptr correction_op = corr_val->get_producer().shared_from_this(); - if (correction_op->get_kind() != op_kind::_reduction) - continue; + if (correction_op->get_kind() != op_kind::_reduction) continue; value_ptr o_do_out = correction_op->get_input_value(0); - if (!o_do_out->has_producer()) - continue; + if (!o_do_out->has_producer()) continue; op_ptr o_do_op = o_do_out->get_producer().shared_from_this(); - if (!is_binary(o_do_op, dnnl::algorithm::binary_mul)) - continue; + if (!is_binary(o_do_op, dnnl::algorithm::binary_mul)) continue; value_ptr O_val = o_do_op->get_input_value(0); // forward output O value_ptr dO_val = o_do_op->get_input_value(1); // diff_dst dO @@ -5139,12 +5126,10 @@ status_t fuse_sdpa_bwd(std::shared_ptr &sg) { if (permute_ds && !matmul_dk) { auto next = sole_consumer(permute_ds); - if (next && next->get_kind() == op_kind::_matmul) - matmul_dk = next; + if (next && next->get_kind() == op_kind::_matmul) matmul_dk = next; } - - if (!matmul_dq || !matmul_dk) - continue; + + if (!matmul_dq || !matmul_dk) continue; // Detect and handle the permute of K that feeds matmul_dq input 1 // 
(matmul_dq computes dS * permute(K), where permute transposes K) @@ -5153,8 +5138,7 @@ status_t fuse_sdpa_bwd(std::shared_ptr &sg) { auto dq_in1 = matmul_dq->get_input_value(1); if (dq_in1->has_producer()) { auto prod = dq_in1->get_producer().shared_from_this(); - if (prod->get_kind() == op_kind::_permute) - permute_k = prod; + if (prod->get_kind() == op_kind::_permute) permute_k = prod; } } @@ -5169,7 +5153,8 @@ status_t fuse_sdpa_bwd(std::shared_ptr &sg) { // Optional reduce after matmul_dk (e.g., GQA dK accumulation) op_ptr reduce_dk = nullptr; { - auto next = transpose_dk? sole_consumer(transpose_dk): sole_consumer(matmul_dk); + auto next = transpose_dk ? sole_consumer(transpose_dk) + : sole_consumer(matmul_dk); if (next && next->get_kind() == op_kind::_reduction) reduce_dk = next; } @@ -5217,12 +5202,9 @@ status_t fuse_sdpa_bwd(std::shared_ptr &sg) { // 1: K (original transposed key from matmul_qk) auto Kv = matmul_qk->get_input_value(1); Kv->remove_consumer(*matmul_qk, 1); - if (permute_k) { - Kv->remove_consumer(*permute_k, 0); - } + if (permute_k) { Kv->remove_consumer(*permute_k, 0); } bwd_op->connect_input(1, Kv); - // 2: V V_val->remove_consumer(*permute_v, 0); bwd_op->connect_input(2, V_val); @@ -5263,7 +5245,8 @@ status_t fuse_sdpa_bwd(std::shared_ptr &sg) { // 1: dK (possibly through optional reduce) auto dK_val = reduce_dk ? reduce_dk->get_output_value(0) - : transpose_dk ? transpose_dk->get_output_value(0) : matmul_dk->get_output_value(0); + : transpose_dk ? 
transpose_dk->get_output_value(0) + : matmul_dk->get_output_value(0); dK_val->set_producer(*bwd_op); bwd_op->connect_output(1, dK_val); @@ -5287,12 +5270,13 @@ status_t fuse_sdpa_bwd(std::shared_ptr &sg) { } // Remove all pattern ops - std::vector to_remove = {matmul_qk, sub_op, exp_op, matmul_dv, - matmul_vt_do, o_do_op, correction_op, dp_corrected_op, - softmax_bwd, matmul_dq, matmul_dk, permute_p, permute_ds, permute_v}; + std::vector to_remove + = {matmul_qk, sub_op, exp_op, matmul_dv, matmul_vt_do, o_do_op, + correction_op, dp_corrected_op, softmax_bwd, matmul_dq, + matmul_dk, permute_p, permute_ds, permute_v}; for (auto *opt : {&scale_pre, &mask_op, &dropout_fwd, &tc_fwd, - &dropout_bwd, &scale_post, &end_op, &tc_bwd, - &reduce_dv, &reduce_dk, &permute_k, &transpose_dk}) + &dropout_bwd, &scale_post, &end_op, &tc_bwd, &reduce_dv, + &reduce_dk, &permute_k, &transpose_dk}) if (*opt) to_remove.push_back(*opt); for (auto &op : to_remove) diff --git a/src/graph/backend/dnnl/patterns/sdp.cpp b/src/graph/backend/dnnl/patterns/sdp.cpp index 531dd0f1604..8138c6b3dc0 100644 --- a/src/graph/backend/dnnl/patterns/sdp.cpp +++ b/src/graph/backend/dnnl/patterns/sdp.cpp @@ -267,7 +267,7 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_sdp_backward_fusion) auto matmul_dk = pgraph->append_op( graph::op_kind::MatMul, {in_edge(0, tc2, 0)}); auto matmul_dq = pgraph->append_op( - graph::op_kind::MatMul, {in_edge(0, tc2, 0)}); + graph::op_kind::MatMul, {in_edge(0, tc2, 0)}); // Q is a shared input for matmul_qk and matmul_dk pgraph->create_input_port(0, matmul_qk, 0); pgraph->create_input_port(0, matmul_dk, 1); @@ -630,7 +630,7 @@ DNNL_BACKEND_REGISTER_PATTERN_MATCHER_PASS(dnnl, float_gqa_backward_fusion) graph::op_kind::MatMul, {in_edge(0, tc2, 0)}); auto matmul_dq = pgraph->append_op( graph::op_kind::MatMul, {in_edge(0, tc2, 0)}); - + // reduction_dk pgraph->append_op(graph::op_kind::ReduceSum, {in_edge(0, matmul_dk, 0)});