diff --git a/src/common/memory_tracking.hpp b/src/common/memory_tracking.hpp
index 10f07e9292d..4a9e7eea35a 100644
--- a/src/common/memory_tracking.hpp
+++ b/src/common/memory_tracking.hpp
@@ -325,6 +325,10 @@ enum {
     key_rnn_ptrs_wei_iter,
     key_rnn_ptrs_wei_projection,
     key_shuffle_precompute_transpose,
+    key_sdpa_Di,
+    key_sdpa_dQ_reduction,
+    key_sdpa_dK_reduction,
+    key_sdpa_dV_reduction,
     key_softmax_dst_scales,
     key_softmax_reduction,
     key_softmax_interim_store,
diff --git a/src/common/primitive_hashing.cpp b/src/common/primitive_hashing.cpp
index db03637f6cd..48dfcde38b9 100644
--- a/src/common/primitive_hashing.cpp
+++ b/src/common/primitive_hashing.cpp
@@ -734,6 +734,7 @@ size_t get_desc_hash(const sdpa_desc_t &desc) {
     size_t seed = 0;
     // Kinds
     seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
+    seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
     // Memory descriptors
     seed = hash_combine(seed, get_md_hash(desc.q_desc));
     seed = hash_combine(seed, get_md_hash(desc.k_desc));
@@ -742,7 +743,12 @@ size_t get_desc_hash(const sdpa_desc_t &desc) {
     seed = hash_combine(seed, desc.kq_zero_points.get_hash());
     seed = hash_combine(seed, desc.vs_scales.get_hash());
     seed = hash_combine(seed, desc.vs_zero_points.get_hash());
+    seed = hash_combine(seed, get_md_hash(desc.dS_desc));
     seed = hash_combine(seed, get_md_hash(desc.dst_desc));
+    seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc));
+    seed = hash_combine(seed, get_md_hash(desc.diff_q_desc));
+    seed = hash_combine(seed, get_md_hash(desc.diff_k_desc));
+    seed = hash_combine(seed, get_md_hash(desc.diff_v_desc));
     seed = hash_combine(seed, get_md_hash(desc.attn_mask_desc));
     seed = hash_combine(seed, get_md_hash(desc.scale_desc));
     // Scale type
diff --git a/src/common/primitive_serialization.cpp b/src/common/primitive_serialization.cpp
index c2d3f85293f..fa8ad97af4d 100644
--- a/src/common/primitive_serialization.cpp
+++ b/src/common/primitive_serialization.cpp
@@ -581,6 +581,7 @@ void serialize(serialization_stream_t &sstream, const sum_desc_t &desc) {
 void serialize(serialization_stream_t &sstream, const sdpa_desc_t &desc) {
     // Kind
     sstream.append(desc.primitive_kind);
+    sstream.append(desc.prop_kind);
     serialize(sstream, desc.q_desc);
     serialize(sstream, desc.k_desc);
     serialize(sstream, desc.v_desc);
@@ -588,7 +589,12 @@ void serialize(serialization_stream_t &sstream, const sdpa_desc_t &desc) {
     desc.kq_zero_points.serialize(sstream);
     desc.vs_scales.serialize(sstream);
     desc.vs_zero_points.serialize(sstream);
+    serialize(sstream, desc.dS_desc);
     serialize(sstream, desc.dst_desc);
+    serialize(sstream, desc.diff_dst_desc);
+    serialize(sstream, desc.diff_q_desc);
+    serialize(sstream, desc.diff_k_desc);
+    serialize(sstream, desc.diff_v_desc);
     serialize(sstream, desc.attn_mask_desc);
     serialize(sstream, desc.scale_desc);
     sstream.append(desc.kq_acc_dt);
diff --git a/src/common/sdpa_pd.hpp b/src/common/sdpa_pd.hpp
index dbccef487c1..c5bc1a06b48 100644
--- a/src/common/sdpa_pd.hpp
+++ b/src/common/sdpa_pd.hpp
@@ -37,82 +37,36 @@ namespace impl {
             this->info(engine), ##__VA_ARGS__)
 
 // NOLINTBEGIN(google-default-arguments)
+
+struct sdpa_fwd_pd_t;
+
 struct sdpa_pd_t : public primitive_desc_t {
     static constexpr auto base_pkind = primitive_kind::sdpa;
 
-    using base_class = sdpa_pd_t;
-    using hint_class = sdpa_pd_t;
+    static constexpr int mask_q_index = 2;
+    static constexpr int mask_k_index = 3;
+    static constexpr int ndims = 4;
 
     const sdpa_desc_t *desc() const { return &desc_; }
     const op_desc_t *op_desc() const override {
         return reinterpret_cast<const op_desc_t *>(this->desc());
     }
 
-    arg_usage_t arg_usage(int arg) const override {
-        // TODO: this is broken for cases when the user passes quantization
-        // memories unconditionally but the primitive desc is not set up for
-        // quantization.
-        if (utils::one_of(arg, DNNL_ARG_QUERIES, DNNL_ARG_KEYS, DNNL_ARG_VALUES,
-                    DNNL_ARG_ATTN_MASK, DNNL_ARG_SCALE,
-                    DNNL_ARG_ATTR_SCALES | DNNL_ARG_KEYS,
-                    DNNL_ARG_ATTR_SCALES | DNNL_ARG_VALUES,
-                    DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_KEYS,
-                    DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_VALUES))
-            return arg_usage_t::input;
-
-        if (arg == DNNL_ARG_DST) return arg_usage_t::output;
-
-        return primitive_desc_t::arg_usage(arg);
+    bool is_fwd() const {
+        return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+                prop_kind::forward_inference);
     }
 
-    const memory_desc_t *arg_md(
-            int arg, bool user_input = false) const override {
-        switch (arg) {
-            case DNNL_ARG_QUERIES: return src_md(0);
-            case DNNL_ARG_KEYS: return src_md(1);
-            case DNNL_ARG_VALUES: return src_md(2);
-            case DNNL_ARG_ATTN_MASK: return src_md(3);
-            case DNNL_ARG_DST: return dst_md(0, user_input);
-            default: return primitive_desc_t::arg_md(arg);
-        }
-    }
-
-    const memory_desc_t *src_md(
-            int index = 0, bool user_input = false) const override {
-        switch (index) {
-            case 0: return &desc_.q_desc;
-            case 1: return &desc_.k_desc;
-            case 2: return &desc_.v_desc;
-            case 3: return &desc_.attn_mask_desc;
-            default: return &glob_zero_md;
-        }
-    }
-    const memory_desc_t *dst_md(
-            int index = 0, bool user_input = false) const override {
-        return index == 0 ? &desc_.dst_desc : &glob_zero_md;
-    }
-
-    const memory_desc_t *qry_md() const { return &desc_.q_desc; }
-    const memory_desc_t *key_md() const { return &desc_.k_desc; }
-    const memory_desc_t *val_md() const { return &desc_.v_desc; }
-    const memory_desc_t *attn_mask_md() const { return &desc_.attn_mask_desc; }
-    const memory_desc_t *scale_md() const { return &desc_.scale_desc; }
-
-    int n_inputs() const override {
-        return 3 + int(with_attn_mask()) + int(with_attn_scale());
-    }
-    int n_outputs() const override { return 1; }
-
     bool with_attn_scale() const {
-        return (scale_md()->data_type != data_type::undef);
+        return (desc()->scale_md()->data_type != data_type::undef);
     }
 
     bool with_host_scale() const {
-        return (scale_md()->format_kind == format_kind::host_scalar);
+        return (desc()->scale_md()->format_kind == format_kind::host_scalar);
     }
 
     bool with_attn_mask() const {
-        return (attn_mask_md()->data_type != data_type::undef);
+        return (desc()->attn_mask_md()->data_type != data_type::undef);
    }
 
     /// Returns the accumulation data type of the KQ matmul
@@ -171,9 +125,9 @@ struct sdpa_pd_t : public primitive_desc_t {
     int key_group_size() const {
         int out = 0;
         if (with_key_scales())
-            out = group_size(desc()->kq_scales, *key_md());
+            out = group_size(desc()->kq_scales, *desc()->key_md());
         else if (with_key_zp()) {
-            out = group_size(desc()->kq_zero_points, *key_md());
+            out = group_size(desc()->kq_zero_points, *desc()->key_md());
         }
         return out;
     }
@@ -182,20 +136,33 @@ struct sdpa_pd_t : public primitive_desc_t {
     int value_group_size() const {
         int out = 0;
         if (with_value_scales())
-            out = group_size(desc()->vs_scales, *val_md());
+            out = group_size(desc()->vs_scales, *desc()->val_md());
         else if (with_value_zp()) {
-            out = group_size(desc()->vs_zero_points, *val_md());
+            out = group_size(desc()->vs_zero_points, *desc()->val_md());
         }
         return out;
     }
 
 protected:
     sdpa_desc_t desc_;
+    const sdpa_fwd_pd_t *hint_fwd_pd_;
+
+    memory_desc_t ws_md_;
 
     sdpa_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr,
-            const hint_class *hint_fwd_pd)
+            const sdpa_fwd_pd_t *hint_fwd_pd)
         : primitive_desc_t(attr, base_pkind)
-        , desc_(*op_desc_t::to_desc(adesc)) {}
+        , desc_(*op_desc_t::to_desc(adesc))
+        , hint_fwd_pd_(hint_fwd_pd) {}
+
+    status_t init_default_ws() {
+        dims_t d;
+        d[0] = desc()->batch() * desc()->num_q_heads()
+                * desc()->queries(); // (logsumexp) per query
+
+        return memory_desc_init_by_tag(
+                ws_md_, 1, d, data_type::f32, format_tag::a);
+    }
 
     bool set_default_format(memory_desc_t *md) {
         memory_desc_wrapper mdw(md);
@@ -204,20 +171,6 @@ struct sdpa_pd_t : public primitive_desc_t {
         return true;
     }
 
-    bool set_default_formats() {
-        bool ok = true;
-
-        for (auto md : {&desc_.q_desc, &desc_.k_desc, &desc_.v_desc,
-                     &desc_.dst_desc}) {
-            ok = ok && set_default_format(md);
-        }
-
-        auto status = attr_.post_ops_.set_default_formats(&desc_.dst_desc);
-        ok = ok && (status == status::success);
-
-        return ok;
-    }
-
 private:
     static int group_size(
             const quant_entry_t &scales, const memory_desc_t &desc) {
@@ -239,6 +192,193 @@
         return static_cast<int>(out);
     }
 };
+
+struct sdpa_fwd_pd_t : public sdpa_pd_t {
+    using base_class = sdpa_fwd_pd_t;
+    using hint_class = sdpa_fwd_pd_t;
+
+    arg_usage_t arg_usage(int arg) const override {
+        // TODO: this is broken for cases when the user passes quantization
+        // memories unconditionally but the primitive desc is not set up for
+        // quantization.
+        if (utils::one_of(arg, DNNL_ARG_QUERIES, DNNL_ARG_KEYS, DNNL_ARG_VALUES,
+                    DNNL_ARG_ATTN_MASK, DNNL_ARG_SCALE,
+                    DNNL_ARG_ATTR_SCALES | DNNL_ARG_KEYS,
+                    DNNL_ARG_ATTR_SCALES | DNNL_ARG_VALUES,
+                    DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_KEYS,
+                    DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_VALUES))
+            return arg_usage_t::input;
+
+        if (arg == DNNL_ARG_DST) return arg_usage_t::output;
+
+        if (arg == DNNL_ARG_WORKSPACE)
+            return !types::is_zero_md(workspace_md()) ? arg_usage_t::output
+                                                      : arg_usage_t::unused;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    const memory_desc_t *arg_md(
+            int arg, bool user_input = false) const override {
+        switch (arg) {
+            case DNNL_ARG_QUERIES: return src_md(0);
+            case DNNL_ARG_KEYS: return src_md(1);
+            case DNNL_ARG_VALUES: return src_md(2);
+            case DNNL_ARG_ATTN_MASK: return src_md(3);
+            case DNNL_ARG_DST: return dst_md(0, user_input);
+            default: return primitive_desc_t::arg_md(arg);
+        }
+    }
+
+    const memory_desc_t *src_md(
+            int index = 0, bool user_input = false) const override {
+        switch (index) {
+            case 0: return &desc_.q_desc;
+            case 1: return &desc_.k_desc;
+            case 2: return &desc_.v_desc;
+            case 3: return &desc_.attn_mask_desc;
+            default: return &glob_zero_md;
+        }
+    }
+    const memory_desc_t *dst_md(
+            int index = 0, bool user_input = false) const override {
+        return index == 0 ? &desc_.dst_desc : &glob_zero_md;
+    }
+    const memory_desc_t *workspace_md(int index = 0) const override {
+        return index == 0 && !types::is_zero_md(&ws_md_) ?
&ws_md_ + : &glob_zero_md; + } + + int n_inputs() const override { + return 3 + int(with_attn_mask()) + int(with_attn_scale()); + } + int n_outputs() const override { + return 1 + (!types::is_zero_md(workspace_md())); + } + +protected: + sdpa_fwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr, + const hint_class *hint_fwd_pd) + : sdpa_pd_t(adesc, attr, hint_fwd_pd) {} + + bool set_default_formats() { + bool ok = true; + + for (auto md : {&desc_.q_desc, &desc_.k_desc, &desc_.v_desc, + &desc_.dst_desc}) { + ok = ok && set_default_format(md); + } + + auto status = attr_.post_ops_.set_default_formats(&desc_.dst_desc); + ok = ok && (status == status::success); + + return ok; + } +}; + +struct sdpa_bwd_pd_t : public sdpa_pd_t { + using base_class = sdpa_bwd_pd_t; + using hint_class = sdpa_fwd_pd_t; + + arg_usage_t arg_usage(int arg) const override { + if (utils::one_of(arg, DNNL_ARG_QUERIES, DNNL_ARG_KEYS, DNNL_ARG_VALUES, + DNNL_ARG_DST, DNNL_ARG_DIFF_DST, DNNL_ARG_ATTN_MASK, + DNNL_ARG_SCALE)) + return arg_usage_t::input; + + if (utils::one_of(arg, DNNL_ARG_DIFF_QUERIES, DNNL_ARG_DIFF_KEYS, + DNNL_ARG_DIFF_VALUES)) + return arg_usage_t::output; + + if (arg == DNNL_ARG_DS) + return with_dS() ? arg_usage_t::output : arg_usage_t::unused; + + if (arg == DNNL_ARG_WORKSPACE) return arg_usage_t::input; + + return primitive_desc_t::arg_usage(arg); + } + + const memory_desc_t *arg_md( + int arg, bool user_input = false) const override { + switch (arg) { + case DNNL_ARG_QUERIES: return src_md(0); + case DNNL_ARG_KEYS: return src_md(1); + case DNNL_ARG_VALUES: return src_md(2); + case DNNL_ARG_ATTN_MASK: return src_md(3); + case DNNL_ARG_DST: return src_md(4); + case DNNL_ARG_DIFF_DST: return src_md(5); + case DNNL_ARG_DIFF_QUERIES: return dst_md(0, user_input); + case DNNL_ARG_DIFF_KEYS: return dst_md(1, user_input); + case DNNL_ARG_DIFF_VALUES: return dst_md(2, user_input); + case DNNL_ARG_DS: return dst_md(3, user_input); + default: return primitive_desc_t::arg_md(arg); + } + } + + const memory_desc_t *src_md( + int index = 0, bool user_input = false) const override { + switch (index) { + case 0: return &desc_.q_desc; + case 1: return &desc_.k_desc; + case 2: return &desc_.v_desc; + case 3: return &desc_.attn_mask_desc; + case 4: return &desc_.dst_desc; + case 5: return &desc_.diff_dst_desc; + default: return &glob_zero_md; + } + } + const memory_desc_t *dst_md( + int index = 0, bool user_input = false) const override { + switch (index) { + case 0: return &desc_.diff_q_desc; + case 1: return &desc_.diff_k_desc; + case 2: return &desc_.diff_v_desc; + case 3: return &desc_.dS_desc; + default: return &glob_zero_md; + } + } + const memory_desc_t *workspace_md(int index = 0) const override { + return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ + : &glob_zero_md; + } + + bool with_dS() const { + return (desc_.dS_desc.data_type != data_type::undef); + } + + int n_inputs() const override { + // Q, K, V, O, dO + return 5 + int(with_attn_mask()) + int(with_attn_scale()) + + int(!types::is_zero_md(workspace_md())); + } + int n_outputs() const override { return 3 + int(with_dS()); } + + const memory_desc_t *diff_dst_md( + int index = 0, bool user_input = false) const override { + return index == 0 ? 
&desc_.diff_dst_desc : &glob_zero_md;
+    }
+
+protected:
+    sdpa_bwd_pd_t(const op_desc_t *adesc, const primitive_attr_t *attr,
+            const hint_class *hint_fwd_pd)
+        : sdpa_pd_t(adesc, attr, hint_fwd_pd) {}
+
+    bool set_default_formats() {
+        bool ok = true;
+
+        for (auto md : {&desc_.q_desc, &desc_.k_desc, &desc_.v_desc,
+                     &desc_.dst_desc, &desc_.diff_dst_desc, &desc_.diff_q_desc,
+                     &desc_.diff_k_desc, &desc_.diff_v_desc}) {
+            ok = ok && set_default_format(md);
+        }
+
+        auto status = attr_.post_ops_.set_default_formats(&desc_.dst_desc);
+        ok = ok && (status == status::success);
+
+        return ok;
+    }
+};
+
 // NOLINTEND(google-default-arguments)
 
 } // namespace impl
diff --git a/src/common/sdpa_test_iface.cpp b/src/common/sdpa_test_iface.cpp
index e04b0dcd5ff..6b6b33e1359 100644
--- a/src/common/sdpa_test_iface.cpp
+++ b/src/common/sdpa_test_iface.cpp
@@ -30,8 +30,8 @@ dnnl_status_t DNNL_API sdpa_primitive_desc_create(
         const_dnnl_memory_desc_t value_desc, const_dnnl_memory_desc_t dst_desc,
         const_dnnl_memory_desc_t mask_desc, const_dnnl_memory_desc_t scale_desc,
         bool invert_scale, dnnl_dim_t kv_head_number, int attn_mask_type,
-        dnnl_alg_kind_t softmax_alg, const_dnnl_primitive_attr_t attr,
-        const_dnnl_primitive_attr_t kq_attr,
+        dnnl_alg_kind_t softmax_alg, prop_kind_t prop,
+        const_dnnl_primitive_attr_t attr, const_dnnl_primitive_attr_t kq_attr,
         const_dnnl_primitive_attr_t vs_attr) {
     CHECK(sdpa_desc_check(query_desc, key_desc, value_desc, dst_desc, mask_desc,
             engine, attr, kq_attr, vs_attr));
@@ -41,7 +41,34 @@ dnnl_status_t DNNL_API sdpa_primitive_desc_create(
     dnnl::impl::sdpa_desc_t sdpa_desc = dnnl::impl::create_sdpa_desc(query_desc,
             key_desc, value_desc, dst_desc, mask_desc, scale_desc,
             invert_scale, kv_head_number,
             static_cast<dnnl::impl::attn_mask_type_t>(attn_mask_type),
-            softmax_alg, kq_attr, vs_attr);
+            softmax_alg, prop, kq_attr, vs_attr);
     return dnnl::impl::primitive_desc_create(primitive_desc_iface, engine,
             (const dnnl::impl::op_desc_t *)&sdpa_desc, nullptr, attr);
 }
+
+dnnl_status_t DNNL_API sdpa_primitive_desc_create(
+        dnnl_primitive_desc_t *primitive_desc_iface, dnnl_engine_t engine,
+        const_dnnl_memory_desc_t query_desc, const_dnnl_memory_desc_t key_desc,
+        const_dnnl_memory_desc_t value_desc, const_dnnl_memory_desc_t dst_desc,
+        const_dnnl_memory_desc_t mask_desc, const_dnnl_memory_desc_t scale_desc,
+        const_dnnl_memory_desc_t diff_query_desc,
+        const_dnnl_memory_desc_t diff_key_desc,
+        const_dnnl_memory_desc_t diff_value_desc,
+        const_dnnl_memory_desc_t diff_dst_desc,
+        const_dnnl_memory_desc_t dS_desc, bool invert_scale,
+        dnnl_dim_t kv_head_number, int attn_mask_type,
+        dnnl_alg_kind_t softmax_alg, const_dnnl_primitive_attr_t attr,
+        const_dnnl_primitive_desc_t hint_fwd_pd = nullptr) {
+    CHECK(sdpa_desc_check(query_desc, key_desc, value_desc, dst_desc, mask_desc,
+            diff_query_desc, diff_key_desc, diff_value_desc, diff_dst_desc,
+            engine, attr));
+    CHECK(sdpa_attr_check(engine, attr));
+
+    dnnl::impl::sdpa_desc_t sdpa_desc = dnnl::impl::create_sdpa_desc(query_desc,
+            key_desc, value_desc, dst_desc, mask_desc, scale_desc,
+            diff_query_desc, diff_key_desc, diff_value_desc, diff_dst_desc,
+            dS_desc, invert_scale, kv_head_number,
+            static_cast<dnnl::impl::attn_mask_type_t>(attn_mask_type),
+            softmax_alg);
+    return dnnl::impl::primitive_desc_create(primitive_desc_iface, engine,
+            (const dnnl::impl::op_desc_t *)&sdpa_desc, hint_fwd_pd, attr);
+}
diff --git a/src/common/sdpa_types.hpp b/src/common/sdpa_types.hpp
index 7e836812f33..4c70c8c6903 100644
--- a/src/common/sdpa_types.hpp
+++ b/src/common/sdpa_types.hpp
@@ -33,6 +33,11 @@ namespace impl {
 #define DNNL_ARG_VALUES DNNL_ARG_SRC_2
 #define DNNL_ARG_ATTN_MASK DNNL_ARG_SHIFT
 
+#define DNNL_ARG_DIFF_QUERIES DNNL_ARG_DIFF_SRC_0
+#define DNNL_ARG_DIFF_KEYS DNNL_ARG_DIFF_SRC_1
+#define DNNL_ARG_DIFF_VALUES DNNL_ARG_DIFF_SRC_2
+#define DNNL_ARG_DS DNNL_ARG_DIFF_SRC_3
+
 // NOLINTBEGIN(modernize-use-using)
 /// Types of attention mask
 typedef enum {
@@ -66,6 +71,8 @@ struct sdpa_desc_t : public op_desc_t {
         return utils::make_unique<sdpa_desc_t>(*this);
     }
 
+    prop_kind_t prop_kind {};
+
     memory_desc_t q_desc; /* queries */
     memory_desc_t k_desc; /* keys */
     memory_desc_t v_desc; /* values */
@@ -77,7 +84,13 @@ struct sdpa_desc_t : public op_desc_t {
     quant_entry_t vs_scales;
     quant_entry_t vs_zero_points;
 
+    memory_desc_t dS_desc;
+
     memory_desc_t dst_desc;
+    memory_desc_t diff_dst_desc;
+    memory_desc_t diff_q_desc;
+    memory_desc_t diff_k_desc;
+    memory_desc_t diff_v_desc;
     memory_desc_t attn_mask_desc;
     memory_desc_t scale_desc;
     data_type_t kq_acc_dt {};
@@ -98,13 +111,20 @@ struct sdpa_desc_t : public op_desc_t {
     dnnl_dim_t keys() const { return k_desc.dims[k_desc.ndims - 1]; }
     // Number of values.
     dnnl_dim_t values() const { return v_desc.dims[v_desc.ndims - 1]; }
-    // Total batch size.
-    dnnl_dim_t batch_size() const {
-        dnnl_dim_t batch = 1;
-        for (int i = 0; i < dst_desc.ndims - 2; i++)
-            batch *= dst_desc.dims[i];
-        return batch;
-    }
+    dim_t num_q_heads() const { return q_desc.dims[1]; }
+    dim_t num_kv_heads() const { return kv_head_number; }
+    // Batch size (outer batch dimension, excluding heads).
+    dnnl_dim_t batch() const { return dst_desc.dims[0]; }
+
+    // Memory descriptors
+    const memory_desc_t *qry_md() const { return &q_desc; }
+    const memory_desc_t *key_md() const { return &k_desc; }
+    const memory_desc_t *val_md() const { return &v_desc; }
+    const memory_desc_t *attn_mask_md() const { return &attn_mask_desc; }
+    const memory_desc_t *scale_md() const { return &scale_desc; }
+    const memory_desc_t *diff_qry_md() const { return &diff_q_desc; }
+    const memory_desc_t *diff_key_md() const { return &diff_k_desc; }
+    const memory_desc_t *diff_val_md() const { return &diff_v_desc; }
 };
 
 } // namespace impl
diff --git a/src/common/sdpa_utils.hpp b/src/common/sdpa_utils.hpp
index 3354c09b943..089d0a2a7ae 100644
--- a/src/common/sdpa_utils.hpp
+++ b/src/common/sdpa_utils.hpp
@@ -81,6 +81,61 @@ static inline status_t sdpa_desc_check(const memory_desc_t *q_desc,
     return status::success;
 }
 
+static inline status_t sdpa_desc_check(const memory_desc_t *q_desc,
+        const memory_desc_t *k_desc, const memory_desc_t *v_desc,
+        const memory_desc_t *dst_desc, const memory_desc_t *attn_mask_md,
+        const memory_desc_t *diff_q_desc, const memory_desc_t *diff_k_desc,
+        const memory_desc_t *diff_v_desc, const memory_desc_t *diff_dst_desc,
+        const engine_t *engine, const primitive_attr_t *attr) {
+    int ndims = dst_desc->ndims;
+    int r = ndims - 2, c = ndims - 1;
+    VCHECK_SDPA_COND(
+            utils::everyone_is(ndims, q_desc->ndims, k_desc->ndims,
+                    v_desc->ndims, diff_q_desc->ndims, diff_k_desc->ndims,
+                    diff_v_desc->ndims, diff_dst_desc->ndims),
+            "number of dimensions have to match. 
expected: %d q: %d k: %d v: " + "%d dQ: %d dK: %d dV: %d dO: %d", + ndims, q_desc->ndims, k_desc->ndims, v_desc->ndims, + diff_q_desc->ndims, diff_k_desc->ndims, diff_v_desc->ndims, + diff_dst_desc->ndims); + + VCHECK_SDPA_COND(q_desc->dims[c] == k_desc->dims[r], + "q_desc->dims[%d](%s) must match k_desc->dims[%d](%s)", c, + md2dim_str(q_desc).c_str(), r, md2dim_str(k_desc).c_str()); + VCHECK_SDPA_COND(k_desc->dims[c] == v_desc->dims[r], + "k_desc->dims[%d](%s) must match v_desc->dims[%d](%s)", c, + md2dim_str(k_desc).c_str(), r, md2dim_str(v_desc).c_str()); + VCHECK_SDPA_COND(dst_desc->dims[r] == q_desc->dims[r], + "dst_desc->dims[%d](%s) == q_desc->dims[%d](%s)", r, + md2dim_str(dst_desc).c_str(), r, md2dim_str(q_desc).c_str()); + VCHECK_SDPA_COND(dst_desc->dims[c] == v_desc->dims[c], + "dst_desc->dims[%d](%s) == v_desc->dims[%d](%s)", c, + md2dim_str(dst_desc).c_str(), c, md2dim_str(v_desc).c_str()); + + for (int i = 0; i < ndims; i++) { + VCHECK_SDPA_COND(diff_q_desc->dims[i] == q_desc->dims[i], + "diff_q_desc->dims[%d](%s) must match q_desc->dims[%d](%s)", i, + md2dim_str(diff_q_desc).c_str(), i, md2dim_str(q_desc).c_str()); + VCHECK_SDPA_COND(diff_k_desc->dims[i] == k_desc->dims[i], + "diff_k_desc->dims[%d](%s) must match k_desc->dims[%d](%s)", i, + md2dim_str(diff_k_desc).c_str(), i, md2dim_str(k_desc).c_str()); + VCHECK_SDPA_COND(diff_v_desc->dims[i] == v_desc->dims[i], + "diff_v_desc->dims[%d](%s) must match v_desc->dims[%d](%s)", i, + md2dim_str(diff_v_desc).c_str(), i, md2dim_str(v_desc).c_str()); + VCHECK_SDPA_COND(diff_dst_desc->dims[i] == dst_desc->dims[i], + "diff_dst_desc->dims[%d](%s) must match dst_desc->dims[%d](%s)", + i, md2dim_str(diff_dst_desc).c_str(), i, + md2dim_str(dst_desc).c_str()); + } + + VCHECK_SDPA_COND(!any_memory_desc_host_scalar(q_desc, k_desc, v_desc, + dst_desc, attn_mask_md, diff_q_desc, diff_k_desc, + diff_v_desc, diff_dst_desc), + VERBOSE_UNSUPPORTED_FORMAT_KIND); + + return status::success; +} + static inline status_t sdpa_attr_check(const memory_desc_t *q_desc, const memory_desc_t *k_desc, const memory_desc_t *v_desc, const engine_t *engine, const primitive_attr_t *attr, @@ -148,12 +203,27 @@ static inline status_t sdpa_attr_check(const memory_desc_t *q_desc, return status::success; } +static inline status_t sdpa_attr_check( + const engine_t *engine, const primitive_attr_t *attr) { + using smask_t = primitive_attr_t::skip_mask_t; + + if (attr == nullptr) return status::success; + if (attr && attr->has_default_values()) { return status::success; } + + if (attr) { + smask_t attr_mask = smask_t::none; + VCHECK_SDPA_UNIMPL( + attr->has_default_values(attr_mask), VERBOSE_UNSUPPORTED_ATTR); + } + return status::success; +} static inline sdpa_desc_t create_sdpa_desc(const memory_desc_t *q_md, const memory_desc_t *k_md, const memory_desc_t *v_md, const memory_desc_t *dst_md, const memory_desc_t *attn_mask_md, const memory_desc_t *scale_md, bool invert_scale, dim_t kv_head_number, attn_mask_type_t attn_mask_type, alg_kind_t softmax_alg, - const primitive_attr_t *kq_attr, const primitive_attr_t *vs_attr) { + prop_kind_t prop, const primitive_attr_t *kq_attr, + const primitive_attr_t *vs_attr) { auto sdpa_desc = sdpa_desc_t(); sdpa_desc.primitive_kind = primitive_kind::sdpa; sdpa_desc.q_desc = *q_md; @@ -182,6 +252,40 @@ static inline sdpa_desc_t create_sdpa_desc(const memory_desc_t *q_md, sdpa_desc.kv_head_number = kv_head_number; sdpa_desc.mask_type = attn_mask_type; sdpa_desc.softmax_alg = softmax_alg; + sdpa_desc.prop_kind = prop; + return sdpa_desc; +} 
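+
+// Note on the backward flavor added below: the extra memory descriptors
+// follow the conventional attention backward (a sketch of the math only,
+// not a statement about any one kernel's decomposition):
+//
+//   forward:  S = scale * (Q x K^T) + mask,  P = softmax(S),  O = P x V
+//   backward: dV  = P^T x dO
+//             dP  = dO x V^T
+//             D_i = rowsum(dO * O)          // one f32 value per query row
+//             dS  = P * (dP - D_i)          // softmax backward
+//             dQ  = scale * (dS x K),  dK = scale * (dS^T x Q)
+//
+// P is not stored by the forward pass; it can be recomputed from the per-row
+// logsumexp kept in the workspace (see sdpa_pd_t::init_default_ws()) as
+// P = exp(S - logsumexp).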
+
+static inline sdpa_desc_t create_sdpa_desc(const memory_desc_t *q_md,
+        const memory_desc_t *k_md, const memory_desc_t *v_md,
+        const memory_desc_t *dst_md, const memory_desc_t *attn_mask_md,
+        const memory_desc_t *scale_md, const memory_desc_t *diff_q_md,
+        const memory_desc_t *diff_k_md, const memory_desc_t *diff_v_md,
+        const memory_desc_t *diff_dst_md, const memory_desc_t *dS_md,
+        bool invert_scale, dim_t kv_head_number,
+        attn_mask_type_t attn_mask_type, alg_kind_t softmax_alg) {
+    auto sdpa_desc = sdpa_desc_t();
+    sdpa_desc.primitive_kind = primitive_kind::sdpa;
+    sdpa_desc.q_desc = *q_md;
+    sdpa_desc.k_desc = *k_md;
+    sdpa_desc.v_desc = *v_md;
+
+    sdpa_desc.dst_desc = *dst_md;
+    sdpa_desc.diff_dst_desc = *diff_dst_md;
+
+    if (dS_md) sdpa_desc.dS_desc = *dS_md;
+
+    sdpa_desc.diff_q_desc = *diff_q_md;
+    sdpa_desc.diff_k_desc = *diff_k_md;
+    sdpa_desc.diff_v_desc = *diff_v_md;
+
+    if (attn_mask_md) sdpa_desc.attn_mask_desc = *attn_mask_md;
+    sdpa_desc.scale_desc = *scale_md;
+    sdpa_desc.invert_scale = invert_scale;
+    sdpa_desc.kv_head_number = kv_head_number;
+    sdpa_desc.mask_type = attn_mask_type;
+    sdpa_desc.softmax_alg = softmax_alg;
+    sdpa_desc.prop_kind = prop_kind::backward;
     return sdpa_desc;
 }
 
@@ -192,7 +296,8 @@ static inline status_t create_sdpa_pd(
         const memory_desc_t *attn_mask_md, const memory_desc_t *scale_md,
         bool invert_scale, dim_t kv_head_number,
         attn_mask_type_t attn_mask_type, alg_kind_t softmax_alg,
-        const primitive_attr_t *attr, const primitive_attr_t *kq_attr = nullptr,
+        prop_kind_t prop, const primitive_attr_t *attr,
+        const primitive_attr_t *kq_attr = nullptr,
         const primitive_attr_t *vs_attr = nullptr) {
     CHECK(sdpa_attr_check(q_md, k_md, v_md, engine, attr, kq_attr, vs_attr));
     CHECK(sdpa_desc_check(q_md, k_md, v_md, dst_md, attn_mask_md, engine, attr,
@@ -200,7 +305,7 @@ static inline status_t create_sdpa_pd(
 
     auto sdpa_desc = create_sdpa_desc(q_md, k_md, v_md, dst_md, attn_mask_md,
             scale_md, invert_scale, kv_head_number, attn_mask_type, softmax_alg,
-            kq_attr, vs_attr);
+            prop, kq_attr, vs_attr);
 
     primitive_attr_t sdpa_attr = attr ? *attr : default_attr();
 
@@ -213,6 +318,37 @@ static inline status_t create_sdpa_pd(
     return status::success;
 }
 
+static inline status_t create_sdpa_pd(
+        std::shared_ptr<primitive_desc_t> &sdpa_pd_, engine_t *engine,
+        const memory_desc_t *q_md, const memory_desc_t *k_md,
+        const memory_desc_t *v_md, const memory_desc_t *dst_md,
+        const memory_desc_t *diff_q_md, const memory_desc_t *diff_k_md,
+        const memory_desc_t *diff_v_md, const memory_desc_t *diff_dst_md,
+        const memory_desc_t *dS_md, const memory_desc_t *attn_mask_md,
+        const memory_desc_t *scale_md, bool invert_scale, dim_t kv_head_number,
+        attn_mask_type_t attn_mask_type, alg_kind_t softmax_alg,
+        const primitive_attr_t *attr, const primitive_desc_t *hint_fwd_pd,
+        const primitive_attr_t *kq_attr = nullptr,
+        const primitive_attr_t *vs_attr = nullptr) {
+    CHECK(sdpa_attr_check(q_md, k_md, v_md, engine, attr, kq_attr, vs_attr));
+    CHECK(sdpa_desc_check(q_md, k_md, v_md, dst_md, attn_mask_md, engine, attr,
+            kq_attr, vs_attr));
+
+    auto sdpa_desc = create_sdpa_desc(q_md, k_md, v_md, dst_md, attn_mask_md,
+            scale_md, diff_q_md, diff_k_md, diff_v_md, diff_dst_md, dS_md,
+            invert_scale, kv_head_number, attn_mask_type, softmax_alg);
+
+    primitive_attr_t sdpa_attr = attr ? *attr : default_attr();
+
+    primitive_desc_iterator_t it(
+            engine, (op_desc_t *)&sdpa_desc, &sdpa_attr, hint_fwd_pd);
+
+    sdpa_pd_ = *(++it);
+    VCHECK_SDPA_COND(sdpa_pd_, "failed to create the backward SDPA primitive");
+
+    return status::success;
+}
+
 } // namespace impl
 } // namespace dnnl
diff --git a/src/common/type_helpers.hpp b/src/common/type_helpers.hpp
index e78ced9dd33..f281e03e6f8 100644
--- a/src/common/type_helpers.hpp
+++ b/src/common/type_helpers.hpp
@@ -1013,6 +1013,7 @@ inline bool operator==(const zero_pad_desc_t &lhs, const zero_pad_desc_t &rhs) {
 inline bool operator==(const sdpa_desc_t &lhs, const sdpa_desc_t &rhs) {
     bool ret = COMPARE_DESC_MEMBERS(primitive_kind)
+            && COMPARE_DESC_MEMBERS(prop_kind)
             && COMPARE_DESC_MEMBERS(q_desc)
             && COMPARE_DESC_MEMBERS(k_desc)
             && COMPARE_DESC_MEMBERS(v_desc)
@@ -1020,7 +1021,12 @@ inline bool operator==(const sdpa_desc_t &lhs, const sdpa_desc_t &rhs) {
             && COMPARE_DESC_MEMBERS(kq_zero_points)
             && COMPARE_DESC_MEMBERS(vs_scales)
             && COMPARE_DESC_MEMBERS(vs_zero_points)
+            && COMPARE_DESC_MEMBERS(dS_desc)
             && COMPARE_DESC_MEMBERS(dst_desc)
+            && COMPARE_DESC_MEMBERS(diff_dst_desc)
+            && COMPARE_DESC_MEMBERS(diff_q_desc)
+            && COMPARE_DESC_MEMBERS(diff_k_desc)
+            && COMPARE_DESC_MEMBERS(diff_v_desc)
             && COMPARE_DESC_MEMBERS(attn_mask_desc)
             && COMPARE_DESC_MEMBERS(scale_desc)
             && COMPARE_DESC_MEMBERS(kq_acc_dt)
diff --git a/src/common/verbose.cpp b/src/common/verbose.cpp
index db93ef19741..a65bd0e7476 100644
--- a/src/common/verbose.cpp
+++ b/src/common/verbose.cpp
@@ -1577,19 +1577,21 @@ std::string init_info_sum(const engine_t *e, const pd_t *pd) {
 template <typename pd_t>
 std::string init_info_sdpa(const engine_t *e, const pd_t *pd) {
     stringstream_t ss;
-    ss << e << "," << pd->kind() << "," << pd->name() << "," << prop_kind::undef
-       << ",";
+    ss << e << "," << pd->kind() << "," << pd->name() << ","
+       << pd->desc()->prop_kind << ",";
 
     const sdpa_desc_t *desc = pd->desc();
     ss << md2fmt_str(
-            "query", pd->qry_md(), pd->invariant_src_user_format_kind(0))
+            "query", desc->qry_md(), pd->invariant_src_user_format_kind(0))
        << " ";
-    ss << md2fmt_str("key", pd->key_md(), pd->invariant_src_user_format_kind(1))
+    ss << md2fmt_str(
+            "key", desc->key_md(), pd->invariant_src_user_format_kind(1))
        << " ";
-    ss << md2fmt_str("val", pd->val_md(), pd->invariant_src_user_format_kind(2))
+    ss << md2fmt_str(
+            "val", desc->val_md(), pd->invariant_src_user_format_kind(2))
        << " ";
     if (pd->with_attn_mask())
-        ss << md2fmt_str("msk", pd->attn_mask_md(),
+        ss << md2fmt_str("msk", desc->attn_mask_md(),
                 pd->invariant_src_user_format_kind(3))
            << " ";
     ss << md2fmt_str("dst", pd->dst_md(), pd->invariant_dst_user_format_kind())
@@ -1627,7 +1629,7 @@ std::string init_info_sdpa(const engine_t *e, const pd_t *pd) {
     delimiter = " ";
     ss << ",alg:" << desc->softmax_alg;
     if (pd->with_attn_mask()) {
-        auto *md = pd->attn_mask_md();
+        auto *md = desc->attn_mask_md();
         ss << delimiter << "msk:" << (md->dims[2] == 1 ? 1 : 2) << 'd';
     } else if (pd->with_causal_mask()) {
         ss << delimiter;
@@ -1642,15 +1644,15 @@ std::string init_info_sdpa(const engine_t *e, const pd_t *pd) {
             ss << "div:";
         else
             ss << "mul:";
-        ss << dnnl_dt2str(pd->scale_md()->data_type) << ":";
+        ss << dnnl_dt2str(desc->scale_md()->data_type) << ":";
         if (pd->with_host_scale())
             ss << "host";
         else
             ss << "device";
     }
 
-    ss << "," << md2dim_str(pd->qry_md()) << ":" << md2dim_str(pd->key_md())
-       << ":" << md2dim_str(pd->val_md());
+    ss << "," << md2dim_str(desc->qry_md()) << ":" << md2dim_str(desc->key_md())
+       << ":" << md2dim_str(desc->val_md());
 
     return ss.str();
 }
diff --git a/src/gpu/gpu_sdpa_list.cpp b/src/gpu/gpu_sdpa_list.cpp
index 0b9d63bc43a..13a4b8f310d 100644
--- a/src/gpu/gpu_sdpa_list.cpp
+++ b/src/gpu/gpu_sdpa_list.cpp
@@ -28,19 +28,34 @@ namespace impl {
 namespace gpu {
 
 namespace {
+using namespace dnnl::impl::prop_kind;
+
 // clang-format off
-constexpr impl_list_item_t impl_list[] = REG_SDPA_P({
-        GPU_INSTANCE_INTEL(intel::sdpa::micro_t)
-        GPU_INSTANCE_INTEL_DEVMODE(intel::sdpa::ref_t)
+const std::map<pk_impl_key_t, std::vector<impl_list_item_t>>
+        impl_list_map REG_SDPA_P({
+    {{forward}, {
+        GPU_INSTANCE_INTEL(intel::sdpa::micro_fwd_t)
+        GPU_INSTANCE_INTEL_DEVMODE(intel::sdpa::ref_fwd_t)
         nullptr,
+    }},
+    {{backward}, REG_BWD_PK({
+        GPU_INSTANCE_INTEL(intel::sdpa::micro_bwd_t)
+        nullptr,
+    })},
 });
 // clang-format on
 } // namespace
 
 const impl_list_item_t *get_sdpa_impl_list(const sdpa_desc_t *desc) {
-    UNUSED(desc);
-    return impl_list;
+    static const impl_list_item_t empty_list[] = {nullptr};
+
+    const bool is_fwd = utils::one_of(
+            desc->prop_kind, forward_training, forward_inference);
+    prop_kind_t prop_kind = is_fwd ? forward : backward;
+
+    const auto impl_list_it = impl_list_map.find({prop_kind});
+    return impl_list_it != impl_list_map.cend() ?
impl_list_it->second.data() + : empty_list; } } // namespace gpu diff --git a/src/gpu/intel/gemm/jit/generator/microkernel_provider.cpp b/src/gpu/intel/gemm/jit/generator/microkernel_provider.cpp index d496b77f8ed..e31f0dff19b 100644 --- a/src/gpu/intel/gemm/jit/generator/microkernel_provider.cpp +++ b/src/gpu/intel/gemm/jit/generator/microkernel_provider.cpp @@ -129,10 +129,6 @@ Package selectGEMMMicrokernel(GEMMProtocol protocol, HWInformation hwInfo, SizeP evalParams.beta = 0; evalParams.euCount = hwInfo.euCount; - /* Locate appropriate kernel catalog */ - if (localA && localB) - stub("Unsupported protocol"); - kcatalog::Catalog catalog = [&]() { if (localA) return kcatalog::Catalog(CatalogLMR); diff --git a/src/gpu/intel/include/generic_vector_ops.h b/src/gpu/intel/include/generic_vector_ops.h index 2ed6bcb4d41..0621d4d6718 100644 --- a/src/gpu/intel/include/generic_vector_ops.h +++ b/src/gpu/intel/include/generic_vector_ops.h @@ -73,6 +73,23 @@ float16 __attribute__((overloadable)) native_vexp2(float16 x) { return native_exp2(x); } +float1 __attribute__((overloadable)) native_vlog2(float1 x) { + x[0] = native_log2(x[0]); + return x; +} +float2 __attribute__((overloadable)) native_vlog2(float2 x) { + return native_log2(x); +} +float4 __attribute__((overloadable)) native_vlog2(float4 x) { + return native_log2(x); +} +float8 __attribute__((overloadable)) native_vlog2(float8 x) { + return native_log2(x); +} +float16 __attribute__((overloadable)) native_vlog2(float16 x) { + return native_log2(x); +} + float1 __attribute__((overloadable)) vselect(float1 x, float1 y, int1 c) { x[0] = select(x[0], y[0], c[0]); return x; diff --git a/src/gpu/intel/include/tile_ops.h b/src/gpu/intel/include/tile_ops.h index 758d80e0419..ab7207ab490 100644 --- a/src/gpu/intel/include/tile_ops.h +++ b/src/gpu/intel/include/tile_ops.h @@ -21,20 +21,18 @@ #include "gpu/intel/include/types.h" float __builtin_IB_atomic_max_local_f32(__local float *, float); +float __builtin_IB_atomic_add_local_f32(__local float *, float); +float __builtin_IB_atomic_add_global_f32(__global float *, float); +half __builtin_IB_atomic_add_global_f16(__global half *, half); __attribute__((overloadable)) float local_atomic_max(local float *p, float v) { return __builtin_IB_atomic_max_local_f32(p, v); } -__attribute__((overloadable)) half local_atomic_max( - local half *p, half v) { /* not implemented */ - return v; -} - +/* not implemented */ +__attribute__((overloadable)) half local_atomic_max(local half *p, half v); __attribute__((overloadable)) ushort local_atomic_max( - local ushort *p, ushort v) { /* not implemented */ - return v; -} + local ushort *p, ushort v); __attribute__((overloadable)) uint local_atomic_max(local uint *p, uint v) { return atomic_max(p, v); @@ -44,6 +42,44 @@ __attribute__((overloadable)) int local_atomic_max(local int *p, int v) { return atomic_max(p, v); } +__attribute__((overloadable)) float local_atomic_add(local float *p, float v) { + return __builtin_IB_atomic_add_local_f32(p, v); +} + +/* not implemented */ +__attribute__((overloadable)) half local_atomic_add(local half *p, half v); +__attribute__((overloadable)) ushort local_atomic_add( + local ushort *p, ushort v); + +__attribute__((overloadable)) uint local_atomic_add(local uint *p, uint v) { + return atomic_add(p, v); +} + +__attribute__((overloadable)) int local_atomic_add(local int *p, int v) { + return atomic_add(p, v); +} + +__attribute__((overloadable)) float global_atomic_add( + global float *p, float v) { + return 
__builtin_IB_atomic_add_global_f32(p, v); +} + +__attribute__((overloadable)) half global_atomic_add(global half *p, half v) { + return __builtin_IB_atomic_add_global_f16(p, v); +} + +/* not implemented */ +__attribute__((overloadable)) ushort global_atomic_add( + global ushort *p, ushort v); + +__attribute__((overloadable)) uint global_atomic_add(global uint *p, uint v) { + return atomic_add(p, v); +} + +__attribute__((overloadable)) int global_atomic_add(global int *p, int v) { + return atomic_add(p, v); +} + #define DEF_BLOCK_LOAD_STORE(type, itype, suffix, n) \ __attribute__((overloadable)) type##n block_load( \ const global type *p, int vlen) \ @@ -311,6 +347,21 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) } \ } while (0) +#define tile_predicated_assignment( \ + t, sg_offset_r, sg_offset_c, predicate, value, sg, br, bc, nbr, nbc) \ + do { \ + for (int j = 0; j < (bc * nbc); j++) { \ + for (int i0 = 0; i0 < (br * nbr); i0 += sg) { \ + int i = i0 + get_sub_group_local_id(); \ + int offset_r = sg_offset_r + i; \ + int offset_c = sg_offset_c + j; \ + if (predicate(offset_r, offset_c)) { \ + tile_access(t, i0, j, sg, br, bc, nbr) = value; \ + } \ + } \ + } \ + } while (0) + #define DECLARE_2D_TILE_OPS(tile_type, element_type, sg, br, bc, nbr, nbc) \ __attribute__((overloadable)) void tile_load_full(tile_type *t, \ const global element_type *ptr, int ld, int offset_r, \ @@ -334,6 +385,24 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) } \ } \ } \ + __attribute__((overloadable)) void tile_load(tile_type *t, \ + const local element_type *ptr, int m, int n, int ld, int offset_r, \ + int offset_c) { \ + if (m >= offset_r + br * nbr && n >= offset_c + bc * nbc) { \ + tile_load_full(t, ptr, ld, offset_r, offset_c); \ + return; \ + } \ + ptr += ld * offset_c + offset_r; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++, ptr += ld) { \ + if (offset_c + j < n) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + int i = i0 + get_sub_group_local_id(); \ + if (offset_r + i < m) \ + tile_access(*t, i0, j, sg, br, bc, nbr) = ptr[i]; \ + } \ + } \ + } \ + } \ __attribute__((overloadable)) void tile_load(tile_type *t, \ const global element_type *ptr, int m, int n, int ld, \ int offset_r, int offset_c) { \ @@ -357,6 +426,38 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) int offset_c) { \ tile_load(t, ptr, m, n, m, offset_r, offset_c); \ } \ + __attribute__((overloadable)) void tile_load_t_full(tile_type *t, \ + const local element_type *ptr, int ld, int offset_r, \ + int offset_c) { \ + ptr += ld * offset_r + offset_c; \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; \ + i0 += sg, ptr += ld * sg) { \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ + tile_access(*t, i0, j, sg, br, bc, nbr) \ + = ptr[get_sub_group_local_id() * ld + j]; \ + } \ + } \ + } \ + __attribute__((overloadable)) void tile_load_t(tile_type *t, \ + const local element_type *ptr, int m, int n, int ld, int offset_r, \ + int offset_c) { \ + if (m >= offset_r + br * nbr && n >= offset_c + bc * nbc) { \ + tile_load_t_full(t, ptr, ld, offset_r, offset_c); \ + return; \ + } \ + ptr += ld * offset_r + offset_c; \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; \ + i0 += sg, ptr += ld * sg) { \ + int i = i0 + get_sub_group_local_id(); \ + if (offset_r + i < m) \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ + if (offset_c + j < n) { \ + tile_access(*t, i0, j, sg, br, bc, nbr) \ + = ptr[get_sub_group_local_id() * ld + j]; \ + } \ 
+ } \ + } \ + } \ __attribute__((overloadable)) void tile_load_t_full(tile_type *t, \ const global element_type *ptr, int ld, int offset_r, \ int offset_c) { \ @@ -394,6 +495,34 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) int offset_c) { \ tile_load_t(t, ptr, m, n, n, offset_r, offset_c); \ } \ + __attribute__((overloadable)) void tile_store_t_full(tile_type t, \ + local element_type *ptr, int ld, int offset_r, int offset_c) { \ + ptr += ld * offset_r + offset_c; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++, ptr++) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + int i = ld * (i0 + get_sub_group_local_id()); \ + ptr[i] = tile_access(t, i0, j, sg, br, bc, nbr); \ + } \ + } \ + } \ + __attribute__((overloadable)) void tile_store_t(tile_type t, \ + local element_type *ptr, int m, int n, int ld, int offset_r, \ + int offset_c) { \ + if (m >= offset_r + br * nbr && n >= offset_c + bc * nbc) { \ + tile_store_t_full(t, ptr, ld, offset_r, offset_c); \ + return; \ + } \ + ptr += ld * offset_r + offset_c; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++, ptr++) { \ + if (offset_c + j < n) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + int i = ld * (i0 + get_sub_group_local_id()); \ + if ((offset_r + i0 + get_sub_group_local_id()) < m) \ + ptr[i] = tile_access(t, i0, j, sg, br, bc, nbr); \ + } \ + } \ + } \ + } \ __attribute__((overloadable)) void tile_store_full(tile_type t, \ local element_type *ptr, int ld, int offset_r, int offset_c) { \ ptr += ld * offset_c + offset_r; \ @@ -404,6 +533,24 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) } \ } \ } \ + __attribute__((overloadable)) void tile_store(tile_type t, \ + local element_type *ptr, int m, int n, int ld, int offset_r, \ + int offset_c) { \ + if (m >= offset_r + br * nbr && n >= offset_c + bc * nbc) { \ + tile_store_full(t, ptr, ld, offset_r, offset_c); \ + return; \ + } \ + ptr += ld * offset_c + offset_r; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++, ptr += ld) { \ + if (offset_c + j < n) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + int i = i0 + get_sub_group_local_id(); \ + if (offset_r + i < m) \ + ptr[i] = tile_access(t, i0, j, sg, br, bc, nbr); \ + } \ + } \ + } \ + } \ __attribute__((overloadable)) void tile_store_full(tile_type t, \ global element_type *ptr, int ld, int offset_r, int offset_c) { \ ptr += ld * offset_c + offset_r; \ @@ -437,6 +584,46 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) int offset_c) { \ tile_store(t, ptr, m, n, m, offset_r, offset_c); \ } \ + __attribute__((overloadable)) void tile_load_t_packed_src1(tile_type *t, \ + local element_type *ptr, int panel, int ld, int offset_r, \ + int offset_c) { \ + offset_c += get_sub_group_local_id(); \ + int offset_r0 = offset_r % panel; \ + int offset_r1 = offset_r - offset_r0; \ + ptr += offset_r0 + panel * offset_c + ld * offset_r1; \ + _Pragma("unroll") for (int j0 = 0; j0 < br * nbr; \ + j0 += sg, ptr += sg * panel) { \ + _Pragma("unroll") for (int i = 0; i < bc * nbc; i++) \ + tile_access(*(t), j0, i, sg, br, bc, nbr) \ + = ptr[i]; \ + } \ + } \ + __attribute__((overloadable)) void tile_load_packed_src1(tile_type *t, \ + local element_type *ptr, int panel, int ld, int offset_r, \ + int offset_c) { \ + ptr += offset_c * panel; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++, ptr += panel) \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + int i = i0 + get_sub_group_local_id(); \ + 
int offset_r0 = (offset_r + i) % panel; \ + int offset_r1 = (offset_r + i) - offset_r0; \ + tile_access(*(t), i0, j, sg, br, bc, nbr) \ + = ptr[offset_r0 + offset_r1 * ld]; \ + } \ + } \ + __attribute__((overloadable)) void tile_store_packed_src1(tile_type t, \ + local element_type *ptr, int panel, int ld, int offset_r, \ + int offset_c) { \ + ptr += offset_c * panel; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++, ptr += panel) \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + int i = i0 + get_sub_group_local_id(); \ + int offset_r0 = (offset_r + i) % panel; \ + int offset_r1 = (offset_r + i) - offset_r0; \ + ptr[offset_r0 + offset_r1 * ld] \ + = tile_access(t, i0, j, sg, br, bc, nbr); \ + } \ + } \ __attribute__((overloadable)) void tile_store_t_packed_src1(tile_type t, \ local element_type *ptr, int panel, int ld, int offset_r, \ int offset_c) { \ @@ -451,6 +638,82 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) = tile_access(t, j0, i, sg, br, bc, nbr); \ } \ } \ + __attribute__((overloadable)) void tile_store_sys_src1(tile_type t, \ + local element_type *ptr, int tileR, int tileC, int wg_tile_m, \ + int wg_tile_n, int offset_r, int offset_c) { \ + const int crosspack = 2; \ + const int tile_panel_size = tileR * tileC; \ + const int num_row_panels = wg_tile_m / tileR; \ + const int num_col_panels = wg_tile_n / tileC; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + const int in_r = offset_r + i0 + get_sub_group_local_id(); \ + const int in_c = offset_c + j; \ + /* Compute 2D panel grid position: */ \ + const int row_panel = in_r \ + / tileR; /* Which vertical panel (every sg rows) */ \ + const int col_panel = in_c \ + / tileC; /* Which horizontal panel (every tile_n columns) */ \ + const int panel_base \ + = (col_panel * num_row_panels + row_panel) \ + * tile_panel_size; \ + /*const int panel_base = (row_panel * num_col_panels + col_panel) * tile_panel_size;*/ \ + /* Within-panel offsets using crosspack layout: */ \ + const int in_panel_row = in_r \ + & (tileR - 1); /* Row within panel (in_r % sg) */ \ + const int in_panel_col = in_c \ + & (tileC \ + - 1); /* Column within panel (in_c % tile_n) */ \ + const int col_group_offset = (in_panel_col >> 1) \ + * (crosspack * tileR); /* Column pair group */ \ + const int sg_lane_offset = in_panel_row \ + * crosspack; /* Subgroup lane position */ \ + const int crosspack_offset = (in_panel_col \ + & 1); /* Position within column pair */ \ + const int out_idx = panel_base + col_group_offset \ + + sg_lane_offset + crosspack_offset; \ + ptr[out_idx] = tile_access(t, i0, j, sg, br, bc, nbr); \ + } \ + } \ + } \ + __attribute__((overloadable)) void tile_load_sys_src1(tile_type *t, \ + local element_type *ptr, int tileR, int tileC, int wg_tile_m, \ + int wg_tile_n, int offset_r, int offset_c) { \ + const int crosspack = 2; \ + const int tile_panel_size = tileR * tileC; \ + const int num_row_panels = wg_tile_m / tileR; \ + const int num_col_panels = wg_tile_n / tileC; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + const int in_r = offset_r + i0 + get_sub_group_local_id(); \ + const int in_c = offset_c + j; \ + /* Compute 2D panel grid position: */ \ + const int row_panel = in_r \ + / tileR; /* Which vertical panel (every sg rows) */ \ + const int col_panel = in_c \ + / tileC; /* Which horizontal panel (every tile_n columns) */ \ + const int 
panel_base \ + = (col_panel * num_row_panels + row_panel) \ + * tile_panel_size; \ + /*const int panel_base = (row_panel * num_col_panels + col_panel) * tile_panel_size;*/ \ + /* Within-panel offsets using crosspack layout: */ \ + const int in_panel_row = in_r \ + & (tileR - 1); /* Row within panel (in_r % sg) */ \ + const int in_panel_col = in_c \ + & (tileC \ + - 1); /* Column within panel (in_c % tile_n) */ \ + const int col_group_offset = (in_panel_col >> 1) \ + * (crosspack * tileR); /* Column pair group */ \ + const int sg_lane_offset = in_panel_row \ + * crosspack; /* Subgroup lane position */ \ + const int crosspack_offset = (in_panel_col \ + & 1); /* Position within column pair */ \ + const int out_idx = panel_base + col_group_offset \ + + sg_lane_offset + crosspack_offset; \ + tile_access(*t, i0, j, sg, br, bc, nbr) = ptr[out_idx]; \ + } \ + } \ + } \ __attribute__((overloadable)) void tile_store_t_sys_src1(tile_type t, \ local element_type *ptr, int ld, int offset_r, int offset_c) { \ offset_c += get_sub_group_local_id(); \ @@ -463,6 +726,91 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) = tile_access(t, j0, i, sg, br, bc, nbr); \ } \ } \ + __attribute__((overloadable)) void tile_store_t_sys_src11(tile_type t, \ + local element_type *ptr, int tileR, int tileC, int wg_tile_m, \ + int wg_tile_n, int offset_r, int offset_c) { \ + const int crosspack = 2; \ + const int tile_panel_size = tileR * tileC; \ + const int num_row_panels \ + = wg_tile_m / tileR; /* is correct when _t? */ \ + const int num_col_panels = wg_tile_n / tileC; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + const int in_r = offset_r + i0 + get_sub_group_local_id(); \ + const int in_c = offset_c + j; \ + /* Compute 2D panel grid position: */ \ + const int row_panel = in_c / tileR; \ + const int col_panel = in_r / tileC; \ + const int panel_base \ + = (col_panel * num_row_panels + row_panel) \ + * tile_panel_size; \ + /* Within-panel offsets using crosspack layout: */ \ + const int in_panel_row = in_c \ + & (tileR - 1); /* Row within panel (in_c % tileR) */ \ + const int in_panel_col = in_r \ + & (tileC \ + - 1); /* Column within panel (in_r % tileC) */ \ + const int col_group_offset = (in_panel_col >> 1) \ + * (crosspack * tileR); /* Column pair group */ \ + const int sg_lane_offset = in_panel_row \ + * crosspack; /* Subgroup lane position */ \ + const int crosspack_offset = (in_panel_col \ + & 1); /* Position within column pair */ \ + const int out_idx = panel_base + col_group_offset \ + + sg_lane_offset + crosspack_offset; \ + ptr[out_idx] = tile_access(t, i0, j, sg, br, bc, nbr); \ + } \ + } \ + } \ + __attribute__((overloadable)) void tile_store_sys_src22(tile_type t, \ + local element_type *ptr, int panel_n, int wg_tile_m, \ + int wg_tile_n, int offset_r, int offset_c) { \ + const int crosspack = 16; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + const int in_r = offset_r + i0 + get_sub_group_local_id(); \ + const int in_c = offset_c + j; \ + /* Panel-based addressing: panel_n cols per panel */ \ + const int col_panel = in_c / panel_n; \ + const int in_panel_c = in_c & (panel_n - 1); \ + /* Within-panel offsets using crosspack layout: */ \ + const int col_group_offset = (in_r / crosspack) \ + * (crosspack * panel_n); /* Column pair group */ \ + const int sg_lane_offset \ + = in_panel_c * crosspack; /* Subgroup lane position */ \ + 
const int crosspack_offset = (in_r & (crosspack - 1)); \ + const int out_idx = col_panel * (panel_n * wg_tile_m) \ + + col_group_offset + sg_lane_offset \ + + crosspack_offset; \ + ptr[out_idx] = tile_access(t, i0, j, sg, br, bc, nbr); \ + } \ + } \ + } \ + __attribute__((overloadable)) void tile_store_t_sys_src22(tile_type t, \ + local element_type *ptr, int panel_n, int wg_tile_m, \ + int wg_tile_n, int offset_r, int offset_c) { \ + const int crosspack = 16; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + const int in_r = offset_r + i0 + get_sub_group_local_id(); \ + const int in_c = offset_c + j; \ + /* Panel-based addressing: panel_n cols per panel */ \ + const int col_panel = in_r / panel_n; \ + const int in_panel_c = in_r & (panel_n - 1); \ + /* Within-panel offsets using crosspack layout: */ \ + const int col_group_offset = (in_c / crosspack) \ + * (crosspack * panel_n); /* Column pair group */ \ + const int sg_lane_offset \ + = in_panel_c * crosspack; /* Subgroup lane position */ \ + const int crosspack_offset = (in_c \ + & (crosspack - 1)); /* Position within column pair */ \ + const int out_idx = col_panel * (panel_n * wg_tile_n) \ + + col_group_offset + sg_lane_offset \ + + crosspack_offset; \ + ptr[out_idx] = tile_access(t, i0, j, sg, br, bc, nbr); \ + } \ + } \ + } \ __attribute__((overloadable)) void tile_store_t_sys_src2(tile_type t, \ local element_type *ptr, int tile_n, int ld, int offset_r, \ int offset_c) { \ @@ -483,6 +831,26 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) } \ } \ } \ + __attribute__((overloadable)) void tile_load_t_sys_src2(tile_type *t, \ + local element_type *ptr, int tile_n, int ld, int offset_r, \ + int offset_c) { \ + const int cp = 32 / sizeof(element_type); \ + offset_c += get_sub_group_local_id(); \ + int offset_r0 = offset_r & (cp - 1); \ + int offset_r1 = offset_r & ~(cp - 1); \ + ptr += offset_r0 + tile_n * offset_r1; \ + _Pragma("unroll") for (int j0 = 0; j0 < br * nbr; \ + j0 += sg, offset_c += sg) { \ + int offset_c0 = offset_c & (tile_n - 1); \ + int offset_c1 = offset_c & ~(tile_n - 1); \ + local element_type *ptr_j = ptr + cp * offset_c0 + ld * offset_c1; \ + _Pragma("unroll") for (int i = 0; i < bc * nbc; i++) { \ + tile_access(*t, j0, i, sg, br, bc, nbr) = *ptr_j; \ + ptr_j++; \ + if ((~i & (cp - 1)) == 0) ptr_j += cp * (tile_n - 1); \ + } \ + } \ + } \ __attribute__((overloadable)) void tile_atomic_max_full(tile_type t, \ local element_type *ptr, int ld, int offset_r, int offset_c) { \ ptr += ld * offset_c + offset_r; \ @@ -493,6 +861,47 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) ptr + i, tile_access(t, i0, j, sg, br, bc, nbr)); \ } \ } \ + } \ + __attribute__((overloadable)) void tile_atomic_add_full(tile_type t, \ + local element_type *ptr, int ld, int offset_r, int offset_c) { \ + ptr += ld * offset_c + offset_r; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++, ptr += ld) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + int i = i0 + get_sub_group_local_id(); \ + (void)local_atomic_add( \ + ptr + i, tile_access(t, i0, j, sg, br, bc, nbr)); \ + } \ + } \ + } \ + __attribute__((overloadable)) void tile_atomic_add_full(tile_type t, \ + global element_type *ptr, int ld, int offset_r, int offset_c) { \ + ptr += ld * offset_c + offset_r; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++, ptr += ld) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + int i = i0 
+ get_sub_group_local_id(); \ + (void)global_atomic_add( \ + ptr + i, tile_access(t, i0, j, sg, br, bc, nbr)); \ + } \ + } \ + } \ + __attribute__((overloadable)) void tile_atomic_add(tile_type t, \ + global element_type *ptr, int m, int n, int ld, int offset_r, \ + int offset_c) { \ + if (m >= (offset_r + (br * nbr)) && n >= (offset_c + (bc * nbc))) { \ + tile_atomic_add_full(t, ptr, ld, offset_r, offset_c); \ + return; \ + } \ + ptr += ld * offset_c + offset_r; \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++, ptr += ld) { \ + if (offset_c + j < n) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + int i = i0 + get_sub_group_local_id(); \ + if (offset_r + i < m) \ + (void)global_atomic_add(ptr + i, \ + tile_access(t, i0, j, sg, br, bc, nbr)); \ + } \ + } \ + } \ } #define DECLARE_2D_TILE_VREDUCE(tile_type, sg, br, bc, nbr, nbc, rtile_type, \ @@ -516,6 +925,15 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) } \ } \ } \ + __attribute__((overloadable)) void tile_vbroadcast_add( \ + tile_type *t, rtile_type tr) { \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + tile_access(*t, i0, j, sg, br, bc, nbr) \ + += tile_access(tr, i0, 0, rsg, rbr, rbc, rnbr); \ + } \ + } \ + } \ __attribute__((overloadable)) void tile_vbroadcast_sub( \ tile_type *t, rtile_type tr) { \ _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ @@ -533,10 +951,29 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) *= tile_access(tr, i0, 0, rsg, rbr, rbc, rnbr); \ } \ } \ + } \ + __attribute__((overloadable)) void tile_vbroadcast_min( \ + tile_type *t, rtile_type tr) { \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + tile_access(*t, i0, j, sg, br, bc, nbr) \ + = min(tile_access(*t, i0, j, sg, br, bc, nbr), \ + tile_access(tr, i0, 0, rsg, rbr, rbc, rnbr)); \ + } \ + } \ } #define DECLARE_2D_TILE_HREDUCE(tile_type, sg, br, bc, nbr, nbc, rtile_type, \ rsg, rbr, rbc, rnbr, rnbc) \ + __attribute__((overloadable)) void tile_hreduce_add( \ + tile_type t, rtile_type *tr) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ + tile_access(*tr, i0, j, rsg, rbr, rbc, rnbr) \ + += tile_access(t, i0, j, sg, br, bc, nbr); \ + } \ + } \ + } \ __attribute__((overloadable)) void tile_hbroadcast_add( \ tile_type *t, rtile_type tr) { \ _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ @@ -546,6 +983,15 @@ DEF_BLOCK2D_LOAD_STORE(float, uint, 8, 16, u32_m8k16v1, 16, 8) } \ } \ } \ + __attribute__((overloadable)) void tile_hbroadcast_sub( \ + tile_type *t, rtile_type tr) { \ + _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ + _Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \ + tile_access(*t, i0, j, sg, br, bc, nbr) \ + -= xlane_tile_access(tr, j, 0, rsg, rbr, rbc, rnbr); \ + } \ + } \ + } \ __attribute__((overloadable)) void tile_hbroadcast_mul( \ tile_type *t, rtile_type tr) { \ _Pragma("unroll") for (int j = 0; j < bc * nbc; j++) { \ @@ -840,4 +1286,33 @@ __attribute__((overloadable)) void cooperative_prefetch_2d_internal( } } +// inplace load-add-store to SLM, avoids allocating a full intermediate +// accumulator tile. 
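+//
+// Illustrative use (tile/parameter names here are hypothetical, not taken
+// from an actual kernel):
+//
+//   DECLARE_2D_TILE_SLM_ADD(acc_tile_t, float, SG_SIZE, BR, BC, NBR, NBC)
+//   ...
+//   tile_slm_add(acc, slm_buf, ld, offset_r, offset_c);
+//
+// tile_slm_add accumulates into a column-major SLM buffer
+// (ptr[ld * col + row]); the _T variant below addresses the buffer through
+// the transposed view (ptr[ld * row + col]).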
+#define DECLARE_2D_TILE_SLM_ADD(tile_type, element_type, sg, br, bc, nbr, nbc) \
+    __attribute__((overloadable)) inline void tile_slm_add(tile_type addend, \
+            local element_type *ptr, int ld, int offset_r, int offset_c) { \
+        ptr += ld * offset_c + offset_r; \
+        _Pragma("unroll") for (int j = 0; j < (bc) * (nbc); j++, ptr += ld) { \
+            _Pragma("unroll") for (int i0 = 0; i0 < (br) * (nbr); \
+                    i0 += (sg)) { \
+                int i = i0 + get_sub_group_local_id(); \
+                ptr[i] += tile_access(addend, i0, j, sg, br, bc, nbr); \
+            } \
+        } \
+    }
+
+#define DECLARE_2D_TILE_SLM_ADD_T( \
+        tile_type, element_type, sg, br, bc, nbr, nbc) \
+    __attribute__((overloadable)) inline void tile_slm_add_t(tile_type addend, \
+            local element_type *ptr, int ld, int offset_r, int offset_c) { \
+        ptr += ld * offset_r + offset_c; \
+        _Pragma("unroll") for (int j = 0; j < (bc) * (nbc); j++, ptr++) { \
+            _Pragma("unroll") for (int i0 = 0; i0 < (br) * (nbr); \
+                    i0 += (sg)) { \
+                int i = ld * (i0 + get_sub_group_local_id()); \
+                ptr[i] += tile_access(addend, i0, j, sg, br, bc, nbr); \
+            } \
+        } \
+    }
+
 #endif
diff --git a/src/gpu/intel/microkernels/fuser.cpp b/src/gpu/intel/microkernels/fuser.cpp
index c03be4a9ca1..1262e4c8a6f 100644
--- a/src/gpu/intel/microkernels/fuser.cpp
+++ b/src/gpu/intel/microkernels/fuser.cpp
@@ -18,6 +18,7 @@
 #include
 #include
 #include
+#include
 #include
 
 #include "elf.hpp"
@@ -53,9 +54,8 @@ void fuseMicrokernel(std::vector<uint8_t> &binary,
             "IGC did not generate a valid zebin program binary");
 
     bool foundZeInfo = false;
-    SectionHeader *text = nullptr;
     const char *snames = nullptr;
-    int textSectionID = -1, relSectionID = -1;
+    std::vector<std::pair<SectionHeader *, int>> textSections;
 
     auto *sheaders = reinterpret_cast<SectionHeader *>(
             base + fheaderPtr->sectionTableOff);
@@ -73,104 +73,111 @@ void fuseMicrokernel(std::vector<uint8_t> &binary,
                     continue;
                 if (sname.substr(0, 6) != ".text.") continue;
             }
-            if (text)
-                throw std::runtime_error("Multiple kernels in program");
-            text = sheaders + s;
-            textSectionID = s;
+            textSections.emplace_back(sheaders + s, s);
             break;
         }
         default: break;
     }
 }
 
-    if (!foundZeInfo || !text || text->offset + text->size > bytes)
+    if (!foundZeInfo || textSections.empty())
         throw std::runtime_error(
                 "IGC did not generate a valid zebin program binary");
 
-    std::string rname = ".rel";
-    rname += (snames + text->name);
-    for (int s = 0; s < fheaderPtr->sectionCount; s++) {
-        if (sheaders[s].type != SectionHeader::Type::Relocation) continue;
-        if (rname != (snames + sheaders[s].name)) continue;
-        if (relSectionID >= 0)
-            throw std::runtime_error("Multiple relocation sections for kernel");
-        relSectionID = s;
-    }
-
-    auto *insn = reinterpret_cast<const uint32_t *>(base + text->offset);
-    auto *iend = reinterpret_cast<const uint32_t *>(
-            base + text->offset + text->size);
-
-    const uint8_t *spliceStart = nullptr;
-    const uint8_t *spliceEnd = nullptr;
-
-    for (; insn < iend; insn += 4) {
-        if (insn[0] & (1u << 29))
-            insn -= 2;
-        else if (insn[3] == (sigilStart ^ id))
-            spliceStart = reinterpret_cast<const uint8_t *>(insn);
-        else if (insn[3] == (sigilEnd ^ id)) {
-            spliceEnd = reinterpret_cast<const uint8_t *>(insn);
-            break;
+    for (auto &entry : textSections) {
+        auto *text = entry.first;
+        int textSectionID = entry.second;
+        if (text->offset + text->size > bytes) continue;
+
+        auto *insn = reinterpret_cast<const uint32_t *>(base + text->offset);
+        auto *iend = reinterpret_cast<const uint32_t *>(
+                base + text->offset + text->size);
+
+        const uint8_t *spliceStart = nullptr;
+        const uint8_t *spliceEnd = nullptr;
+
+        for (; insn < iend; insn += 4) {
+            if (insn[0] & (1u << 29))
+                insn -= 2;
+            else if (insn[3] == (sigilStart ^ id))
+                spliceStart = reinterpret_cast<const uint8_t *>(insn);
+            else if (insn[3] == (sigilEnd ^ id)) {
+                spliceEnd = reinterpret_cast<const uint8_t *>(insn);
+                break;
+            }
         }
-    }
 
-    if (!spliceStart || !spliceEnd) return;
+        if (!spliceStart || !spliceEnd) continue;
+
+        int relSectionID = -1;
+        std::string rname = ".rel";
+        rname += (snames + text->name);
+        for (int s = 0; s < fheaderPtr->sectionCount; s++) {
+            if (sheaders[s].type != SectionHeader::Type::Relocation) continue;
+            if (rname != (snames + sheaders[s].name)) continue;
+            if (relSectionID >= 0)
+                throw std::runtime_error(
+                        "Multiple relocation sections for kernel");
+            relSectionID = s;
+        }
 
-    auto removeBytes = spliceEnd - spliceStart + 16;
+        auto removeBytes = spliceEnd - spliceStart + 16;
 
-    size_t before = spliceStart - base;
-    auto after = bytes - before - removeBytes;
-    ptrdiff_t sizeAdjust = microkernel.size() - removeBytes;
+        size_t before = spliceStart - base;
+        auto after = bytes - before - removeBytes;
+        ptrdiff_t sizeAdjust = microkernel.size() - removeBytes;
 
-    auto kbefore = before - text->offset;
-    auto kafter = text->size - kbefore - removeBytes;
+        auto kbefore = before - text->offset;
+        auto kafter = text->size - kbefore - removeBytes;
 
-    std::vector<uint8_t> newBinary(bytes + sizeAdjust);
-    auto newBase = newBinary.data();
+        std::vector<uint8_t> newBinary(bytes + sizeAdjust);
+        auto newBase = newBinary.data();
 
-    memmove(newBase, base, before);
-    memmove(newBase + before, microkernel.data(), microkernel.size());
-    memmove(newBase + before + microkernel.size(), spliceStart + removeBytes,
-            after);
+        memmove(newBase, base, before);
+        memmove(newBase + before, microkernel.data(), microkernel.size());
+        memmove(newBase + before + microkernel.size(),
+                spliceStart + removeBytes, after);
 
-    fixupJumpTargets(newBase + text->offset, kbefore, +sizeAdjust);
-    fixupJumpTargets(
-            newBase + before + microkernel.size(), kafter, -sizeAdjust);
+        fixupJumpTargets(newBase + text->offset, kbefore, +sizeAdjust);
+        fixupJumpTargets(
+                newBase + before + microkernel.size(), kafter, -sizeAdjust);
 
-    fheaderPtr = reinterpret_cast<FileHeader *>(newBase);
+        fheaderPtr = reinterpret_cast<FileHeader *>(newBase);
 
-    if (fheaderPtr->sectionTableOff > before)
-        fheaderPtr->sectionTableOff += sizeAdjust;
+        if (fheaderPtr->sectionTableOff > before)
+            fheaderPtr->sectionTableOff += sizeAdjust;
 
-    sheaders = reinterpret_cast<SectionHeader *>(
-            newBase + fheaderPtr->sectionTableOff);
-    sheaders[textSectionID].size += sizeAdjust;
-    for (int s = 0; s < fheaderPtr->sectionCount; s++)
-        if (sheaders[s].offset > before) sheaders[s].offset += sizeAdjust;
+        sheaders = reinterpret_cast<SectionHeader *>(
+                newBase + fheaderPtr->sectionTableOff);
+        sheaders[textSectionID].size += sizeAdjust;
+        for (int s = 0; s < fheaderPtr->sectionCount; s++)
+            if (sheaders[s].offset > before) sheaders[s].offset += sizeAdjust;
 
-    if (relSectionID >= 0) {
-        auto relSection = sheaders + relSectionID;
-        auto rel = reinterpret_cast(newBase + relSection->offset);
-        auto relEnd = reinterpret_cast(
-                newBase + relSection->offset + relSection->size);
-        for (; rel < relEnd; rel++) {
-            if (rel->offset >= kbefore) rel->offset += sizeAdjust;
+        if (relSectionID >= 0) {
+            auto relSection = sheaders + relSectionID;
+            auto rel = reinterpret_cast(
+                    newBase + relSection->offset);
+            auto relEnd = reinterpret_cast(
+                    newBase + relSection->offset + relSection->size);
+            for (; rel < relEnd; rel++) {
+                if (rel->offset >= kbefore) rel->offset += sizeAdjust;
+            }
         }
-    }
 
 #ifdef SPLICE_DEBUG
-    std::ofstream dump0("original.bin");
-    dump0.write((const char *)binary.data(), binary.size());
+    std::ofstream dump0("original." 
+ std::to_string(id) + ".bin"); + dump0.write((const char *)binary.data(), binary.size()); - std::ofstream dump("patched.bin"); - dump.write((const char *)newBinary.data(), newBinary.size()); + std::ofstream dump("patched." + std::to_string(id) + ".bin"); + dump.write((const char *)newBinary.data(), newBinary.size()); #endif - std::swap(binary, newBinary); + std::swap(binary, newBinary); - // Tail-recurse to handle any further instances of this microkernel - fuseMicrokernel(binary, microkernel, id); + // Tail-recurse to handle any further instances of this microkernel + fuseMicrokernel(binary, microkernel, id); + return; + } } void fuseMicrokernels(std::vector &binary, const char *source) { diff --git a/src/gpu/intel/sdpa/config.hpp b/src/gpu/intel/sdpa/config.hpp index 9ddf3cf6944..fe9d3bb7a50 100644 --- a/src/gpu/intel/sdpa/config.hpp +++ b/src/gpu/intel/sdpa/config.hpp @@ -25,7 +25,8 @@ namespace gpu { namespace intel { namespace sdpa { -using pd_t = sdpa_pd_t; +using fwd_pd_t = sdpa_fwd_pd_t; +using bwd_pd_t = sdpa_bwd_pd_t; } // namespace sdpa } // namespace intel diff --git a/src/gpu/intel/sdpa/configs.cpp b/src/gpu/intel/sdpa/configs.cpp index 6dfbe984fda..a70d9b6a501 100644 --- a/src/gpu/intel/sdpa/configs.cpp +++ b/src/gpu/intel/sdpa/configs.cpp @@ -71,13 +71,21 @@ std::string to_string(const config_criteria_t &c) { << ((bool)(c.prop & property::f16_accumulate) ? " f16_acc" : ""); return s.str(); } -std::string to_string(const config_t &c) { +std::string to_string(const fwd_config_t &c) { std::stringstream s; s << c.unroll_m_kq << "," << c.unroll_n_kq << "," << c.unroll_m_vs << "," << c.unroll_n_vs << "," << c.wg_m_kq << "," << c.wg_n_kq << "," << c.wg_m_vs << "," << c.wg_n_vs; return s.str(); } +std::string to_string(const bwd_config_t &c) { + std::stringstream s; + s << c.unroll_m_BcBr << "," << c.unroll_n_BcBr << "," << c.unroll_m_DBc + << "," << c.unroll_n_DBc << "," << c.unroll_m_DBr << "," << c.unroll_n_DBr + << "," << c.wg_m_BcBr << "," << c.wg_n_BcBr << "," << c.wg_m_DBc << "," + << c.wg_n_DBc << "," << c.wg_m_DBr << "," << c.wg_n_DBr; + return s.str(); +} // A matching config is a combination of mandatory and optional requirements // it is assumed the query criteria are specific whereas key criteria may be approximate @@ -86,32 +94,35 @@ std::string to_string(const config_t &c) { // head size and sequence length must strictly match the inequality with a caveat for // the more general key criteria "any = -1" // properties must match exactly if they are specified in the key criteria -bool operator==(const config_record_t &key, const config_query_t &query) { - bool result = ((query.arch == key.criteria.arch) - && (query.head_size <= key.criteria.head_size) - && ((key.criteria.seq_len == -1) - || (key.criteria.seq_len != -1 - && query.seq_len <= key.criteria.seq_len)) +bool criteria_matches( + const config_criteria_t &key, const config_query_t &query) { + return ((query.arch == key.arch) && (query.head_size <= key.head_size) + && ((key.seq_len == -1) + || (key.seq_len != -1 && query.seq_len <= key.seq_len)) && ((((query.prop & property::second_token) - == (key.criteria.prop & property::second_token))) + == (key.prop & property::second_token))) && (((query.prop & property::quantized) - == (key.criteria.prop & property::quantized))) + == (key.prop & property::quantized))) && (((query.prop & property::fma) - == (key.criteria.prop & property::fma))) - && (((key.criteria.prop & property::f32) == property::none) + == (key.prop & property::fma))) + && (((key.prop & property::f32) == 
property::none) || ((query.prop & property::f32) - == (key.criteria.prop & property::f32))) - && (((key.criteria.prop & property::f16_accumulate) + == (key.prop & property::f32))) + && (((key.prop & property::f16_accumulate) == property::none) || ((query.prop & property::f16_accumulate) - == (key.criteria.prop - & property::f16_accumulate))) - && (((key.criteria.prop & property::integrated) - == property::none) + == (key.prop & property::f16_accumulate))) + && (((key.prop & property::integrated) == property::none) || ((query.prop & property::integrated) - == (key.criteria.prop - & property::integrated))))); - return result; + == (key.prop & property::integrated))))); +} + +bool operator==(const fwd_config_record_t &key, const config_query_t &query) { + return criteria_matches(key.criteria, query); +} + +bool operator==(const bwd_config_record_t &key, const config_query_t &query) { + return criteria_matches(key.criteria, query); } template @@ -181,7 +192,11 @@ bool operator<(const config_criteria_t &lhs, const config_criteria_t &rhs) { return false; } -bool operator<(const config_record_t &lhs, const config_record_t &rhs) { +bool operator<(const fwd_config_record_t &lhs, const fwd_config_record_t &rhs) { + return lhs.criteria < rhs.criteria; +} + +bool operator<(const bwd_config_record_t &lhs, const bwd_config_record_t &rhs) { return lhs.criteria < rhs.criteria; } @@ -193,13 +208,13 @@ static auto constexpr f32 = property::f32; static auto constexpr f16_accumulate = property::f16_accumulate; // Kernel configurations: [ arch, head_size, {sequence length}, {properties} ] -> config -static std::vector sorted_configs = []() { +static std::vector sorted_configs = []() { // clang-format off - std::vector configs = { + std::vector configs = { // xe_hpg {{compute::gpu_arch_t::xe_hpg, 16, fma | f32}, {8, 16, 8, 16, 2, 4, 2, 4}}, - {{compute::gpu_arch_t::xe_hpg, 32}, {32, 16, 16, 16, 2, 16, 2, 16}}, + {{compute::gpu_arch_t::xe_hpg, 32}, {16, 16, 16, 16, 2, 8, 2, 8}}, {{compute::gpu_arch_t::xe_hpg, 32, 256}, {16, 16, 16, 16, 2, 8, 2, 8}}, {{compute::gpu_arch_t::xe_hpg, 32, 64}, {16, 16, 16, 8, 4, 4, 2, 8}}, {{compute::gpu_arch_t::xe_hpg, 32, 32}, {8, 8, 8, 8, 4, 4, 4, 4}}, @@ -609,9 +624,9 @@ property set_properties(bool is_thin_q, bool is_quantized, bool is_integrated, return properties; } -config_t *choose_config(compute::gpu_arch_t arch, dim_t head_size, dim_t seq, - bool is_thin_q, bool is_quantized, bool is_integrated, bool is_fma, - bool is_f32, bool is_f16_accumulate) { +fwd_config_t *choose_config(compute::gpu_arch_t arch, dim_t head_size, + dim_t seq, bool is_thin_q, bool is_quantized, bool is_integrated, + bool is_fma, bool is_f32, bool is_f16_accumulate) { // quantized FMA for f16 on MTL not implemented in gemmstone if (arch == compute::gpu_arch_t::xe_hpg && is_fma && !is_f32 && is_quantized) @@ -669,12 +684,184 @@ dim_t nearest_conf_seq_interval(compute::gpu_arch_t arch, dim_t head_size, return utils::rnd_up_pow2(seq); } +// Backward pass kernel configurations: +// [ arch, head_size, {sequence length}, {properties} ] -> bwd_config +// bwd_config_t fields: {unroll_m_BcBr, unroll_n_BcBr, +// unroll_m_DBc, unroll_n_DBc, +// unroll_m_DBr, unroll_n_DBr, +// wg_m_BcBr, wg_n_BcBr, +// wg_m_DBc, wg_n_DBc, +// wg_m_DBr, wg_n_DBr} +static std::vector sorted_bwd_configs = []() { + // clang-format off + std::vector configs = { + // xe_hpc + {{compute::gpu_arch_t::xe_hpc, 32}, { 16, 16, 16, 16, 16, 16, 2, 16, 2, 2, 2, 16 }}, + {{compute::gpu_arch_t::xe_hpc, 32, 128}, { 16, 16, 16, 16, 16, 16, 2, 4, 2, 
2, 2, 4 }}, + + {{compute::gpu_arch_t::xe_hpc, 64}, { 16, 32, 16, 16, 32, 32, 2, 16, 4, 2, 2, 16 }}, + {{compute::gpu_arch_t::xe_hpc, 64, 64}, { 16, 16, 16, 16, 16, 32, 2, 4, 4, 2, 4, 2 }}, + {{compute::gpu_arch_t::xe_hpc, 64, 77}, { 16, 16, 16, 16, 32, 32, 1, 8, 4, 1, 2, 4 }}, + {{compute::gpu_arch_t::xe_hpc, 64, 128}, { 16, 16, 16, 16, 16, 16, 4, 4, 4, 4, 4, 4 }}, + + {{compute::gpu_arch_t::xe_hpc, 128}, { 16, 16, 16, 16, 32, 32, 2, 8, 8, 2, 4, 4 }}, + //{{compute::gpu_arch_t::xe_hpc, 256}, { 16, 32, 16, 16, 32, 32, 4, 8, 8, 4, 4, 8 }}, + + {{compute::gpu_arch_t::xe_hpc, 32, second_token}, { 16, 16, 16, 16, 32, 16, 1, 2, 2, 1, 1, 2 }}, + {{compute::gpu_arch_t::xe_hpc, 64, second_token}, { 32, 16, 16, 32, 32, 32, 1, 4, 4, 1, 2, 2 }}, + {{compute::gpu_arch_t::xe_hpc, 128, second_token}, { 16, 16, 16, 16, 32, 32, 2, 8, 8, 2, 4, 4 }}, + //{{compute::gpu_arch_t::xe_hpc, 256, second_token}, { 16, 16, 16, 16, 32, 32, 2, 8, 8, 2, 4, 4 }}, + + {{compute::gpu_arch_t::xe_hpc, 32, f32 | fma}, { 32, 32, 16, 16, 32, 32, 1, 4, 2, 2, 1, 4 }}, + {{compute::gpu_arch_t::xe_hpc, 64, f32 | fma}, { 16, 32, 16, 16, 16, 32, 4, 4, 4, 4, 4, 4 }}, + {{compute::gpu_arch_t::xe_hpc, 128, f32 | fma}, { 16, 16, 16, 32, 32, 32, 2, 4, 8, 1, 4, 2 }}, + + {{compute::gpu_arch_t::xe2, 32, integrated}, { 16, 64, 16, 16, 32, 32, 2, 2, 2, 2, 1, 4 }}, + {{compute::gpu_arch_t::xe2, 64, integrated}, { 16, 32, 16, 16, 32, 32, 2, 4, 4, 2, 2, 4 }}, + {{compute::gpu_arch_t::xe2, 128, integrated}, { 16, 32, 16, 16, 32, 32, 4, 8, 8, 4, 4, 8 }}, + + {{compute::gpu_arch_t::xe2, 32}, { 16, 64, 16, 16, 32, 32, 2, 2, 2, 2, 1, 4 }}, + {{compute::gpu_arch_t::xe2, 64}, { 16, 32, 16, 16, 32, 32, 2, 4, 4, 2, 2, 4 }}, + {{compute::gpu_arch_t::xe2, 128}, { 16, 32, 16, 16, 32, 32, 4, 8, 8, 4, 4, 8 }}, + }; + // clang-format on + + // ensures configs appear in order of most to least defined/desirable + std::sort(std::begin(configs), std::end(configs)); + return configs; +}(); + +bwd_config_t *choose_bwd_config(compute::gpu_arch_t arch, dim_t head_size, + dim_t seq, bool is_thin_q, bool is_quantized, bool is_integrated, + bool is_fma, bool is_f32) { + const bool is_f16_accumulate = false; + compute::gpu_arch_t arch_query = (arch >= compute::gpu_arch_t::xe3) + ? 
compute::gpu_arch_t::xe2 + : arch; + property query_properties = set_properties(is_thin_q, is_quantized, + is_integrated, is_fma, is_f32, is_f16_accumulate); + + config_query_t query(arch_query, static_cast(head_size), + static_cast(seq), query_properties); + auto it = find(begin(sorted_bwd_configs), end(sorted_bwd_configs), query); + if (it != end(sorted_bwd_configs)) { + VDEBUGINFO(4, primitive, sdpa, + "bwd config search: {query %s} -> {%s config:%s},", + to_string(query).c_str(), to_string(it->criteria).c_str(), + to_string(it->config).c_str()); + return &it->config; + } + return nullptr; +} + void deserialize_config_to_gemmstone(gemmstone::HWInformation &hwInfo, gemmstone::GEMMProblem &problem_kq, gemmstone::GEMMProblem &problem_vs, micro::GEMMProtocol::Options &opts_kq, micro::GEMMProtocol::Options &opts_vs, gemmstone::SizeParams &sizes_kq, gemmstone::SizeParams &sizes_vs, - const micro_ukernel_params_t &ukernel_config) { + const micro_fwd_ukernel_params_t &ukernel_config) { + + // hardware info + hwInfo.gmdid = ukernel_config.hwinfo.gmdid; + hwInfo.euCount = ukernel_config.hwinfo.euCount; + hwInfo.systolicAvailable = ukernel_config.hwinfo.systolicAvailable; + + // options kq, vs + auto deserialize_options + = [](micro::GEMMProtocol::Options &gemmstone_opts, + const ukernel_serialized_opts_t &serialized_opts) { + gemmstone_opts.localA = serialized_opts.localA; + gemmstone_opts.localB = serialized_opts.localB; + gemmstone_opts.slmPtr = serialized_opts.slmPtr; + gemmstone_opts.scaleA = serialized_opts.scaleA; + gemmstone_opts.offsetA = serialized_opts.offsetA; + }; + + deserialize_options(opts_kq, ukernel_config.opts_kq); + deserialize_options(opts_vs, ukernel_config.opts_vs); + + // problems kq, vs + auto deserialize_problem + = [](gemmstone::GEMMProblem &problem, + const ukernel_serialized_problem_t &serialized_problem) { + problem.Ta_ext = { + static_cast(serialized_problem.Ta_ext)}; + problem.Tb_ext = { + static_cast(serialized_problem.Tb_ext)}; + problem.Ta + = {static_cast(serialized_problem.Ta)}; + problem.Tb + = {static_cast(serialized_problem.Tb)}; + problem.Tc_ext = { + static_cast(serialized_problem.Tc_ext)}; + problem.Tc + = {static_cast(serialized_problem.Tc)}; + problem.Ts + = {static_cast(serialized_problem.Ts)}; + problem.A.layout = static_cast( + serialized_problem.A_layout); + + problem.Ta_scale = {static_cast( + serialized_problem.Ta_scale)}; + problem.A_scale.setAlignment(serialized_problem.A_scale_alignment); + problem.A_scale.layout = static_cast( + serialized_problem.A_scale_layout); + problem.asPtrDims = serialized_problem.asPtrDims; + problem.Tao + = {static_cast(serialized_problem.Tao)}; + problem.AO.setAlignment(serialized_problem.AO_alignment); + problem.AO.layout = static_cast( + serialized_problem.AO_layout); + problem.aoPtrDims = serialized_problem.aoPtrDims; + problem.aOffset + = static_cast(serialized_problem.aOffset); + problem.aqGroupM = serialized_problem.aqGroupM; + problem.aqGroupK = serialized_problem.aqGroupK; + + problem.B.layout = static_cast( + serialized_problem.B_layout); + problem.C.layout = static_cast( + serialized_problem.C_layout); + problem.A.setAlignment(serialized_problem.A_alignment); + problem.A.crosspack = serialized_problem.A_crosspack; + problem.A.tileR = serialized_problem.A_tileR; + problem.A.tileC = serialized_problem.A_tileC; + + problem.B.setAlignment(serialized_problem.B_alignment); + problem.B.crosspack = serialized_problem.B_crosspack; + problem.B.tileR = serialized_problem.B_tileR; + problem.B.tileC = 
serialized_problem.B_tileC; + }; + deserialize_problem(problem_kq, ukernel_config.problem_kq); + deserialize_problem(problem_vs, ukernel_config.problem_vs); + + // sizes kq, vs + auto deserialize_sizes + = [](gemmstone::SizeParams &sizes, + const ukernel_serialized_sizes_t &serialized_sizes) { + sizes.m = serialized_sizes.m; + sizes.n = serialized_sizes.n; + sizes.k = serialized_sizes.k; + sizes.batch = serialized_sizes.batch; + }; + deserialize_sizes(sizes_kq, ukernel_config.sizes_kq); + deserialize_sizes(sizes_vs, ukernel_config.sizes_vs); +} + +void deserialize_config_to_gemmstone(gemmstone::HWInformation &hwInfo, + gemmstone::GEMMProblem &problem_kq, gemmstone::GEMMProblem &problem_vs, + gemmstone::GEMMProblem &problem_vtdA, + gemmstone::GEMMProblem &problem_ktq, + gemmstone::GEMMProblem &problem_qdSt, + micro::GEMMProtocol::Options &opts_kq, + micro::GEMMProtocol::Options &opts_vs, + micro::GEMMProtocol::Options &opts_vtdA, + micro::GEMMProtocol::Options &opts_ktq, + micro::GEMMProtocol::Options &opts_qdSt, + gemmstone::SizeParams &sizes_kq, gemmstone::SizeParams &sizes_vs, + gemmstone::SizeParams &sizes_vtdA, gemmstone::SizeParams &sizes_ktq, + gemmstone::SizeParams &sizes_qdSt, + const micro_bwd_ukernel_params_t &ukernel_config) { // hardware info hwInfo.gmdid = ukernel_config.hwinfo.gmdid; @@ -685,6 +872,7 @@ void deserialize_config_to_gemmstone(gemmstone::HWInformation &hwInfo, auto deserialize_options = [](micro::GEMMProtocol::Options &gemmstone_opts, const ukernel_serialized_opts_t &serialized_opts) { + gemmstone_opts.localA = serialized_opts.localA; gemmstone_opts.localB = serialized_opts.localB; gemmstone_opts.slmPtr = serialized_opts.slmPtr; gemmstone_opts.scaleA = serialized_opts.scaleA; @@ -692,6 +880,9 @@ void deserialize_config_to_gemmstone(gemmstone::HWInformation &hwInfo, }; deserialize_options(opts_kq, ukernel_config.opts_kq); deserialize_options(opts_vs, ukernel_config.opts_vs); + deserialize_options(opts_vtdA, ukernel_config.opts_vtdA); + deserialize_options(opts_ktq, ukernel_config.opts_ktq); + deserialize_options(opts_qdSt, ukernel_config.opts_qdSt); // problems kq, vs auto deserialize_problem @@ -736,6 +927,10 @@ void deserialize_config_to_gemmstone(gemmstone::HWInformation &hwInfo, problem.C.layout = static_cast( serialized_problem.C_layout); problem.A.setAlignment(serialized_problem.A_alignment); + problem.A.crosspack = serialized_problem.A_crosspack; + problem.A.tileR = serialized_problem.A_tileR; + problem.A.tileC = serialized_problem.A_tileC; + problem.B.setAlignment(serialized_problem.B_alignment); problem.B.crosspack = serialized_problem.B_crosspack; problem.B.tileR = serialized_problem.B_tileR; @@ -743,6 +938,9 @@ void deserialize_config_to_gemmstone(gemmstone::HWInformation &hwInfo, }; deserialize_problem(problem_kq, ukernel_config.problem_kq); deserialize_problem(problem_vs, ukernel_config.problem_vs); + deserialize_problem(problem_vtdA, ukernel_config.problem_vtdA); + deserialize_problem(problem_ktq, ukernel_config.problem_ktq); + deserialize_problem(problem_qdSt, ukernel_config.problem_qdSt); // sizes kq, vs auto deserialize_sizes @@ -755,6 +953,9 @@ void deserialize_config_to_gemmstone(gemmstone::HWInformation &hwInfo, }; deserialize_sizes(sizes_kq, ukernel_config.sizes_kq); deserialize_sizes(sizes_vs, ukernel_config.sizes_vs); + deserialize_sizes(sizes_vtdA, ukernel_config.sizes_vtdA); + deserialize_sizes(sizes_ktq, ukernel_config.sizes_ktq); + deserialize_sizes(sizes_qdSt, ukernel_config.sizes_qdSt); } } // namespace sdpa diff --git 
a/src/gpu/intel/sdpa/configs.hpp b/src/gpu/intel/sdpa/configs.hpp index f108c002357..6ac4c502ab4 100644 --- a/src/gpu/intel/sdpa/configs.hpp +++ b/src/gpu/intel/sdpa/configs.hpp @@ -30,13 +30,22 @@ namespace gpu { namespace intel { namespace sdpa { -struct config_t { +struct fwd_config_t { int unroll_m_kq, unroll_n_kq; // Subgroup tile sizes for K*Q GEMM int unroll_m_vs, unroll_n_vs; // Subgroup tile sizes for V*S GEMM int wg_m_kq, wg_n_kq; // Workgroup configuration for K*Q GEMM int wg_m_vs, wg_n_vs; // Workgroup configuration for V*S GEMM }; +struct bwd_config_t { + int unroll_m_BcBr, unroll_n_BcBr; // Subgroup tile sizes for Br*Bc GEMMs + int unroll_m_DBc, unroll_n_DBc; // Subgroup tile sizes for Bc*D GEMMs + int unroll_m_DBr, unroll_n_DBr; // Subgroup tile sizes for Br*D GEMMs + int wg_m_BcBr, wg_n_BcBr; // Workgroup configuration for Br*Bc GEMMs + int wg_m_DBc, wg_n_DBc; // Workgroup configuration for Bc*D GEMMs + int wg_m_DBr, wg_n_DBr; // Workgroup configuration for Br*D GEMMs +}; + enum class property : int { none = 0x0, second_token = 0x1, @@ -85,22 +94,36 @@ struct config_criteria_t { : arch(a), head_size(hs), seq_len(sq), prop(prop) {} }; -struct config_record_t { +struct fwd_config_record_t { config_criteria_t criteria; - config_t config; + fwd_config_t config; }; +struct bwd_config_record_t { + config_criteria_t criteria; + bwd_config_t config; +}; + +// Common criteria matching: returns true if query matches key criteria +bool criteria_matches( + const config_criteria_t &key, const config_query_t &query); + std::ostream &operator<<(std::ostream &s, const config_query_t &q); std::ostream &operator<<(std::ostream &s, const config_criteria_t &c); -std::ostream &operator<<(std::ostream &s, const config_t &c); +std::ostream &operator<<(std::ostream &s, const fwd_config_t &c); -bool operator==(const config_record_t &key, const config_query_t &query); +bool operator==(const fwd_config_record_t &key, const config_query_t &query); +bool operator==(const bwd_config_record_t &key, const config_query_t &query); bool operator<(const config_criteria_t &lhs, const config_criteria_t &rhs); -bool operator<(const config_record_t &lhs, const config_record_t &rhs); +bool operator<(const fwd_config_record_t &lhs, const fwd_config_record_t &rhs); +bool operator<(const bwd_config_record_t &lhs, const bwd_config_record_t &rhs); -config_t *choose_config(compute::gpu_arch_t arch, dim_t head_size, dim_t seq, - bool is_thin_q, bool is_quantized, bool is_integrated, bool is_fma, - bool is_f32, bool is_f16_accumulate); +fwd_config_t *choose_config(compute::gpu_arch_t arch, dim_t head_size, + dim_t seq, bool is_thin_q, bool is_quantized, bool is_integrated, + bool is_fma, bool is_f32, bool is_f16_accumulate); +bwd_config_t *choose_bwd_config(compute::gpu_arch_t arch, dim_t head_size, + dim_t seq, bool is_thin_q, bool is_quantized, bool is_integrated, + bool is_fma, bool is_f32); dim_t round_up_seq_interval(dim_t seq, compute::gpu_arch_t arch); dim_t nearest_conf_seq_interval(compute::gpu_arch_t arch, dim_t head_size, @@ -115,12 +138,13 @@ struct ukernel_serialized_opts_t ukernel_serialized_opts_t() = default; ukernel_serialized_opts_t(micro::GEMMProtocol::Options opts) - : localB(opts.localB) + : localA(opts.localA) + , localB(opts.localB) , slmPtr(opts.slmPtr) , scaleA(opts.scaleA) , offsetA(opts.offsetA) {} - bool localB, slmPtr, scaleA, offsetA; - uint8_t padding[4] = {0}; + bool localA, localB, slmPtr, scaleA, offsetA; + uint8_t padding[3] = {0}; }; 
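+// With localA added there are now five bools; the explicit padding shrinks
+// from 4 to 3 bytes so the serialized footprint stays at 8 bytes, as verified
+// by the static_assert below.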
DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(ukernel_serialized_opts_t); static_assert(sizeof(ukernel_serialized_opts_t) == 8, @@ -154,6 +178,8 @@ struct ukernel_serialized_sizes_t int64_t m = 0, n = 0, k = 0; }; DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(ukernel_serialized_sizes_t); +static_assert(sizeof(ukernel_serialized_sizes_t) == 32, + "Expected sizeof(ukernel_serialized_sizes_t) == 32"); struct ukernel_serialized_problem_t : trivially_serializable_t { @@ -170,8 +196,11 @@ struct ukernel_serialized_problem_t , A_layout(static_cast(problem.A.layout)) , B_layout(static_cast(problem.B.layout)) , C_layout(static_cast(problem.C.layout)) + , A_tileR(problem.A.tileR) , B_tileR(problem.B.tileR) + , A_tileC(problem.A.tileC) , B_tileC(problem.B.tileC) + , A_crosspack(problem.A.crosspack) , B_crosspack(problem.B.crosspack) , A_alignment(problem.A.alignment) , A_scale_alignment(problem.A_scale.alignment) @@ -194,18 +223,18 @@ struct ukernel_serialized_problem_t int B_layout; int C_layout; - uint16_t B_tileR; - uint16_t B_tileC; - uint8_t B_crosspack; + uint16_t A_tileR, B_tileR; + uint16_t A_tileC, B_tileC; + uint8_t A_crosspack, B_crosspack; uint8_t A_alignment; uint8_t A_scale_alignment; uint8_t AO_alignment; uint8_t B_alignment; // trivially serializable classes require alignment to 8-byte boundaries - // padding0 bumps class size from 49->56 bytes so uint8_t arguments + // padding0 bumps class size from 54->56 bytes so uint8_t arguments // related to alignment can be grouped together rather than placed at the end of the struct - uint8_t padding0[7] = {0}; + uint8_t padding0[2] = {0}; int asPtrDims; int aOffset; @@ -218,13 +247,14 @@ struct ukernel_serialized_problem_t int aoPtrDims; int aqGroupM; int aqGroupK; + uint8_t padding1[4] = {0}; }; DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(ukernel_serialized_problem_t); -static_assert(sizeof(ukernel_serialized_problem_t) == 92, - "Expected sizeof(ukernel_serialized_problem_t) == 92"); +static_assert(sizeof(ukernel_serialized_problem_t) == 96, + "Expected sizeof(ukernel_serialized_problem_t) == 96"); -struct micro_ukernel_params_t - : trivially_serializable_t { +struct micro_fwd_ukernel_params_t + : trivially_serializable_t { int unroll_m_kq, unroll_n_kq; int unroll_m_vs, unroll_n_vs; int wg_m_kq, wg_n_kq; @@ -237,14 +267,60 @@ struct micro_ukernel_params_t ukernel_serialized_opts_t opts_vs; ukernel_serialized_sizes_t sizes_kq, sizes_vs; }; -DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(micro_ukernel_params_t); +DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(micro_fwd_ukernel_params_t); + +struct micro_bwd_ukernel_params_t + : trivially_serializable_t { + int unroll_m_BcBr, unroll_n_BcBr; + int unroll_m_DBc, unroll_n_DBc; + int unroll_m_DBr, unroll_n_DBr; + + int wg_m_BcBr, wg_n_BcBr; + int wg_m_DBc, wg_n_DBc; + int wg_m_DBr, wg_n_DBr; + + ukernel_serialized_hwinfo_t hwinfo; + + ukernel_serialized_problem_t problem_kq; + ukernel_serialized_problem_t problem_vs; + ukernel_serialized_problem_t problem_vtdA; + ukernel_serialized_problem_t problem_ktq; + ukernel_serialized_problem_t problem_qdSt; + + ukernel_serialized_opts_t opts_kq; + ukernel_serialized_opts_t opts_vs; + ukernel_serialized_opts_t opts_vtdA; + ukernel_serialized_opts_t opts_ktq; + ukernel_serialized_opts_t opts_qdSt; + + ukernel_serialized_sizes_t sizes_kq, sizes_vs; + ukernel_serialized_sizes_t sizes_vtdA; + ukernel_serialized_sizes_t sizes_ktq; + ukernel_serialized_sizes_t sizes_qdSt; +}; +DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(micro_bwd_ukernel_params_t); void deserialize_config_to_gemmstone(gemmstone::HWInformation &hwInfo, 
gemmstone::GEMMProblem &problem_kq, gemmstone::GEMMProblem &problem_vs, micro::GEMMProtocol::Options &opts_kq, micro::GEMMProtocol::Options &opts_vs, gemmstone::SizeParams &sizes_kq, gemmstone::SizeParams &sizes_vs, - const micro_ukernel_params_t &ukernel_config); + const micro_fwd_ukernel_params_t &ukernel_config); + +void deserialize_config_to_gemmstone(gemmstone::HWInformation &hwInfo, + gemmstone::GEMMProblem &problem_kq, gemmstone::GEMMProblem &problem_vs, + gemmstone::GEMMProblem &problem_vtdA, + gemmstone::GEMMProblem &problem_ktq, + gemmstone::GEMMProblem &problem_qdSt, + micro::GEMMProtocol::Options &opts_kq, + micro::GEMMProtocol::Options &opts_vs, + micro::GEMMProtocol::Options &opts_vtdA, + micro::GEMMProtocol::Options &opts_ktq, + micro::GEMMProtocol::Options &opts_qdSt, + gemmstone::SizeParams &sizes_kq, gemmstone::SizeParams &sizes_vs, + gemmstone::SizeParams &sizes_vtdA, gemmstone::SizeParams &sizes_ktq, + gemmstone::SizeParams &sizes_qdSt, + const micro_bwd_ukernel_params_t &ukernel_config); } // namespace sdpa } // namespace intel diff --git a/src/gpu/intel/sdpa/micro.cl b/src/gpu/intel/sdpa/micro.cl index d53312cabed..09adc499722 100644 --- a/src/gpu/intel/sdpa/micro.cl +++ b/src/gpu/intel/sdpa/micro.cl @@ -370,7 +370,7 @@ inline void tile_store_t_slm_src1(q_tile_type *Q_tile, local QRY_DATA_T *Q_slm, __attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) kernel void micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, - const global VAL_DATA_T *V, global DST_DATA_T *A, + const global VAL_DATA_T *V, global float *ws, global DST_DATA_T *A, #if WITH_HOST_SCALE float scalar_scale, float inv_scalar_scale, #else @@ -541,7 +541,6 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, #endif #endif #endif - scale *= 1.442695f; // log2(e) } #if PREFETCH_K0 @@ -684,6 +683,7 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, ldkq #endif ); + #if KQ_F16_ACC s_tile_type_float S_tile; tile_copy_reblock(S_tile_f16, &S_tile); @@ -778,11 +778,10 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, /* cache */ LSC_LDCC_L1C_L3C); #endif #endif -#ifndef ALT_MAX + /* Read back WG-wide maxima */ intel_work_group_barrier_wait(CLK_LOCAL_MEM_FENCE); tile_load_full(&S_max_tile, S_max_slm, ugemm_kq_wg_tile_n, sg_j0_kq, 0); -#endif #if SOFTMAX_INF_AS_ZERO #define set_zeros(v) vselect(-FLT_MAX, v, visfinite(v)) @@ -792,39 +791,23 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, tile_vbroadcast_sub(&S_tile, S_max_tile); /* Scale + exponentiate */ -#define scaled_exp(x) native_vexp2(x *scale) +#define scaled_exp(x) native_vexp2(x *scale * 1.442695f) tile_elementwise(S_tile, scaled_exp); - -#ifdef ALT_MAX - /* Read back WG-wide maxima and adjust S to match */ - intel_work_group_barrier_wait(CLK_LOCAL_MEM_FENCE); - s_sum_tile_type S_max_tile1; - tile_copy(S_max_tile, S_max_tile1); - tile_load_full(&S_max_tile, S_max_slm, ugemm_kq_wg_tile_n, sg_j0_kq, 0); - -#define binary_exp_neg(x, y) native_vexp2(scale *((x) - (y))) - tile_binary(S_max_tile1, S_max_tile, binary_exp_neg); - tile_vbroadcast_mul(&S_tile, S_max_tile1); -#endif +#undef scaled_exp /* Accumulate sums. S tile is transposed for easy summation. 
*/ s_sum_tile_type S_sum_tile1; tile_fill(S_sum_tile1, 0.0f); tile_vreduce_add(S_tile, &S_sum_tile1); -#if USE_SYSTOLIC_UKERNEL - /* Convert to half or bf16, VNNI format */ - s_tile_type_packed S_tile_packed; - tile_copy_to_vec2(S_tile, S_tile_packed, VEC_TYPE2); - - /* Store to SLM, in packed format */ - tile_store_t_sys_src2(S_tile_packed, (local uint *)S_slm, - ugemm_vs_sg_tile_n, ugemm_kq_wg_tile_m / 2, sg_i0_kq / 2, - sg_j0_kq); -#else /* Reblock and store to SLM */ s_tile_type_reblock S_tile_reblock; tile_copy_reblock(S_tile, &S_tile_reblock); + +#if USE_SYSTOLIC_UKERNEL + tile_store_t_sys_src2(S_tile_reblock, (local FMA_TYPE *)S_slm, + ugemm_vs_sg_tile_n, ugemm_kq_wg_tile_m, sg_i0_kq, sg_j0_kq); +#else tile_store_block_packed(S_tile_reblock, S_slm, ugemm_vs_sg_tile_n, ugemm_kq_wg_tile_m, sg_j0_kq, sg_i0_kq); #endif @@ -833,7 +816,7 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, /* Rescale existing accumulator and sums to match new maxima */ if (!first) { -#define binary_exp_sub(x, y) native_vexp2(scale *((x) - (y))) +#define binary_exp_sub(x, y) native_vexp2(scale * 1.442695f * ((x) - (y))) #define binary_mul(x, y) ((x) * (y)) tile_binary(S_max_tile_old, S_max_tile, binary_exp_sub); tile_binary(S_sum_tile, S_max_tile_old, binary_mul); @@ -998,6 +981,37 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, if (need_sum_barrier) intel_work_group_barrier_wait(CLK_LOCAL_MEM_FENCE); +#if IS_TRAINING + s_sum_tile_type S_sum_total, S_sum_load; + tile_fill(S_sum_total, 0.f); +#pragma unroll + for (uint sg1 = 0; sg1 < ugemm_kq_sg_per_wg_m; sg1++) { + tile_load_full(&S_sum_load, S_sum_slm, ugemm_kq_wg_tile_n, + ugemm_kq_sg_tile_n * sg_j_kq, sg1); + tile_binary(S_sum_total, S_sum_load, binary_add); + } + +#define log2(x) (native_vlog2(x) * 0.6931471805f) + tile_elementwise(S_sum_total, log2); +#define scale_op(x) ((x) * scale) + tile_elementwise(S_max_tile_old, scale_op); + tile_binary(S_max_tile_old, S_sum_total, binary_add); + +#if SOFTMAX_INF_AS_ZERO +#define lse_set_zeros(v) vselect(0.f, v, visfinite(v)) + tile_elementwise(S_max_tile_old, lse_set_zeros); +#undef lse_set_zeros +#endif + + // save columns logsumexp to workspace for training pass + const uint preprocess_batch = b1 * (DST_D1 * q) + b0 * q; + + global float *ws_logsumexp = ws + preprocess_batch; + tile_store(S_max_tile_old, ws_logsumexp, q, 1, q, sg_j0_kq + wg_j0, + sg_i0_kq); + // sg_i0 specified to avoid OOB subgroups from aliasing +#endif + /* Load column sums from SLM + reduce in registers */ a_scale_tile_type A_scale_tile, A_scale_tile_load; tile_fill(A_scale_tile, 0.0f); diff --git a/src/gpu/intel/sdpa/micro.cpp b/src/gpu/intel/sdpa/micro.cpp index be2ce7e88c1..465c7357c0a 100644 --- a/src/gpu/intel/sdpa/micro.cpp +++ b/src/gpu/intel/sdpa/micro.cpp @@ -60,7 +60,8 @@ bool with_quantize_common(const quant_entry_t &entry) { } /* anonymous namespace */ -status_t update_config_from_devenv_values(config_t *config, bool quantized) { +status_t update_config_from_devenv_values( + fwd_config_t *config, bool quantized) { std::string q_config_str = gpu_utils::dev_getenv("QUANTIZED_SDPA_CONFIG", std::string("")); std::string config_str @@ -98,7 +99,48 @@ status_t update_config_from_devenv_values(config_t *config, bool quantized) { return status::success; } -status_t micro_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { +status_t update_config_from_devenv_values(bwd_config_t *config) { + std::string bwd_config_str + = gpu_utils::dev_getenv("BWD_SDPA_CONFIG", std::string("")); + if 
(!bwd_config_str.empty()) {
+        std::array<int, 12> config_values;
+        int i;
+        int num_values = 0;
+
+        stringstream_t ss(bwd_config_str);
+        while (ss >> i) {
+            config_values[num_values++] = i;
+            if (ss.peek() == ',') ss.ignore();
+        }
+        VCHECK_SDPA_COND(num_values == 12,
+                "BWD_SDPA_CONFIG(%s) is invalid. Must be 12 integers "
+                "separated by commas: "
+                "<unroll_m_BcBr>,<unroll_n_BcBr>,"
+                "<unroll_m_DBc>,<unroll_n_DBc>,"
+                "<unroll_m_DBr>,<unroll_n_DBr>,"
+                "<wg_m_BcBr>,<wg_n_BcBr>,"
+                "<wg_m_DBc>,<wg_n_DBc>,"
+                "<wg_m_DBr>,<wg_n_DBr>",
+                bwd_config_str.c_str());
+        if (num_values == 12) {
+            config->unroll_m_BcBr = config_values[0];
+            config->unroll_n_BcBr = config_values[1];
+            config->unroll_m_DBc = config_values[2];
+            config->unroll_n_DBc = config_values[3];
+            config->unroll_m_DBr = config_values[4];
+            config->unroll_n_DBr = config_values[5];
+            config->wg_m_BcBr = config_values[6];
+            config->wg_n_BcBr = config_values[7];
+            config->wg_m_DBc = config_values[8];
+            config->wg_n_DBc = config_values[9];
+            config->wg_m_DBr = config_values[10];
+            config->wg_n_DBr = config_values[11];
+        }
+    }
+    return status::success;
+}
+
+status_t micro_fwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) {
     using namespace jit;
     using gemm::jit::convert_dnnl_to_kernel_type;
@@ -112,16 +154,16 @@ status_t micro_t::pd_t::init_conf_microkernels(impl::engine_t *engine) {
             "Microkernels not supported by the OpenCL driver.");
 
     /* Retrieve pre-tuned kernel configuration */
-    config_t *config = nullptr;
+    fwd_config_t *config = nullptr;
     const dim_t thin_q_threshold = 16;
     auto queries = d->queries();
-    if (queries == 1) { queries = (d->q_desc.dims[1] / d->kv_head_number); }
+    if (queries == 1) { queries = (d->q_desc.dims[1] / d->num_kv_heads()); }
 
     bool thin_q = (queries <= thin_q_threshold);
     bool quantized = with_key_scales() || with_key_zp() || with_value_scales()
             || with_value_zp();
     bool is_integrated = intel_engine->device_info()->is_integrated();
-    bool is_f32 = (qry_md()->data_type == data_type::f32);
+    bool is_f32 = (desc()->qry_md()->data_type == data_type::f32);
 
     use_systolic_ukernel_
             = intel_engine->mayiuse(compute::device_ext_t::
                             intel_subgroup_matrix_multiply_accumulate)
@@ -175,11 +217,12 @@ status_t micro_t::pd_t::init_conf_microkernels(impl::engine_t *engine) {
 
     // serializable minimal set of configuration params for ukernels
     // will be used to generate shim ukernels in reusable kernel_ctx
-    micro_ukernel_params_t ukernel_params;
+    micro_fwd_ukernel_params_t ukernel_params;
     ukernel_params.unroll_m_kq = config->unroll_m_kq;
     ukernel_params.unroll_n_kq = config->unroll_n_kq;
     ukernel_params.unroll_m_vs = config->unroll_m_vs;
     ukernel_params.unroll_n_vs = config->unroll_n_vs;
+
     ukernel_params.wg_m_kq = config->wg_m_kq;
     ukernel_params.wg_n_kq = config->wg_n_kq;
     ukernel_params.wg_m_vs = config->wg_m_vs;
@@ -191,7 +234,8 @@ status_t micro_t::pd_t::init_conf_microkernels(impl::engine_t *engine) {
     hw_info.gmdid = dev_info->ip_version();
     hw_info.systolicAvailable = use_systolic_ukernel_;
 
-    if (hw_info.gmdid == 0) return status::unimplemented;
+    VDISPATCH_SDPA(
+            hw_info.gmdid != 0, "gmdid is 0, microkernels not supported.");
 
     ukernel_params.hwinfo = {hw_info};
 
@@ -207,27 +251,28 @@ status_t micro_t::pd_t::init_conf_microkernels(impl::engine_t *engine) {
 
     /* Set up GEMMProblem structure for first GEMM: K^T * Q */
     GEMMProblem problem;
-    problem.Ta_ext = convert_dnnl_to_kernel_type(key_md()->data_type);
-    problem.Tb_ext = convert_dnnl_to_kernel_type(qry_md()->data_type);
-    if (qry_md()->data_type == data_type::f16) {
+    problem.Ta_ext = convert_dnnl_to_kernel_type(desc()->key_md()->data_type);
+    problem.Tb_ext = convert_dnnl_to_kernel_type(desc()->qry_md()->data_type);
+    if 
(desc()->qry_md()->data_type == data_type::f16) { problem.Ta = problem.Tb = Type::f16; - } else if (qry_md()->data_type == data_type::bf16) { + } else if (desc()->qry_md()->data_type == data_type::bf16) { problem.Ta = problem.Tb = Type::bf16; - } else if (qry_md()->data_type == data_type::f32) { + } else if (desc()->qry_md()->data_type == data_type::f32) { problem.Ta = problem.Tb = Type::f32; } else { - VCHECK_SDPA_COND(utils::one_of(qry_md()->data_type, data_type::f16, - data_type::bf16), + VCHECK_SDPA_COND(utils::one_of(desc()->qry_md()->data_type, + data_type::f16, data_type::bf16), "Q tensor's data type must be bf16 or f16"); } problem.Tc = problem.Tc_ext = Type::f32; problem.Ts = problem.Tc; auto problem_kq = problem; + problem_kq.Tc = problem_kq.Ts = (kq_acc_dt() == data_type::f16) ? Type::f16 : Type::f32; - problem_kq.A.layout = convert_dnnl_to_kernel_layout(key_md()); + problem_kq.A.layout = convert_dnnl_to_kernel_layout(desc()->key_md()); if (with_key_scales() && !kq_common_scales) { auto scale_dt = key_scales_dt(); @@ -256,10 +301,10 @@ status_t micro_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { problem_kq.B.layout = MatrixLayout::Pr; problem_kq.C.layout = MatrixLayout::T; - const memory_desc_wrapper key_mdw(key_md()); + const memory_desc_wrapper key_mdw(desc()->key_md()); auto ldk = static_cast( - gemm_desc_t::get_ld(*key_md()) * key_mdw.data_type_size()); - problem_kq.A.setAlignment(alignmentForLD(ldk)); + gemm_desc_t::get_ld(*desc()->key_md()) * key_mdw.data_type_size()); + problem_kq.A.setAlignment(alignmentForLD(int(ldk))); problem_kq.B.setAlignment(64); // Q is packed in VNNI format in SLM if (use_systolic_ukernel()) { problem_kq.B.crosspack = 2; @@ -280,7 +325,7 @@ status_t micro_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { /* Set up problem size information */ SizeParams heuristic_sizes; - // quanatizing sizes to large intervals allows kernel + // quantizing sizes to large intervals allows kernel // selection search while avoiding recompilation for every new size heuristic_sizes.m = nearest_conf_seq_interval(arch_, d->head_size(), d->keys(), thin_q, quantized, is_integrated, use_fma_config, is_f32, @@ -291,7 +336,7 @@ status_t micro_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { : utils::rnd_up_pow2(queries); heuristic_sizes.k = d->head_size(); // baked into kernel regardless, no quantization - heuristic_sizes.batch = utils::rnd_up_pow2(d->batch_size()); + heuristic_sizes.batch = utils::rnd_up_pow2(d->batch() * d->num_q_heads()); ukernel_params.sizes_kq = {heuristic_sizes}; @@ -303,8 +348,9 @@ status_t micro_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { bool vs_common_scales = with_quantize_common(d->vs_scales); bool vs_common_zp = with_quantize_common(d->vs_zero_points); - problem_vs.Ta_ext = convert_dnnl_to_kernel_type(val_md()->data_type); - problem_vs.A.layout = convert_dnnl_to_kernel_layout(val_md()); + problem_vs.Ta_ext + = convert_dnnl_to_kernel_type(desc()->val_md()->data_type); + problem_vs.A.layout = convert_dnnl_to_kernel_layout(desc()->val_md()); if (with_value_scales() && !vs_common_scales) { auto scale_dt = value_scales_dt(); problem_vs.Ta_scale = convert_dnnl_to_kernel_type(scale_dt); @@ -332,10 +378,10 @@ status_t micro_t::pd_t::init_conf_microkernels(impl::engine_t *engine) { problem_vs.B.layout = MatrixLayout::Pr; problem_vs.C.layout = MatrixLayout::N; - const memory_desc_wrapper val_mdw(val_md()); + const memory_desc_wrapper val_mdw(desc()->val_md()); auto ldv = static_cast( - 
gemm_desc_t::get_ld(*val_md()) * val_mdw.data_type_size());
-    problem_vs.A.setAlignment(alignmentForLD(ldv));
+            gemm_desc_t::get_ld(*desc()->val_md()) * val_mdw.data_type_size());
+    problem_vs.A.setAlignment(alignmentForLD(int(ldv)));
     problem_vs.B.setAlignment(64); // S is packed in SLM
     if (use_systolic_ukernel()) { problem_vs.B.crosspack = 16; }
@@ -365,51 +411,404 @@ status_t micro_t::pd_t::init_conf_microkernels(impl::engine_t *engine) {
     return status::success;
 }
 
-status_t micro_t::init(impl::engine_t *engine) {
+status_t micro_bwd_t::pd_t::init_conf_microkernels(impl::engine_t *engine) {
+    using namespace jit;
+    using gemm::jit::convert_dnnl_to_kernel_type;
+
+    assert(engine->kind() == engine_kind::gpu);
+    auto *intel_engine = utils::downcast(engine);
+    auto *dev_info = intel_engine->device_info();
+    arch_ = dev_info->gpu_arch();
+    auto *d = desc();
+
+    VCHECK_SDPA_COND(compute::mayiuse_microkernels(intel_engine),
+            "Microkernels not supported by the OpenCL driver.");
+
+    /* Retrieve pre-tuned kernel configuration */
+    bwd_config_t *config = nullptr;
+    const dim_t thin_q_threshold = 16;
+    auto queries = d->queries();
+    // TODO: q=1 batch group optimizations
+    // if (queries == 1) { queries = (d->q_desc.dims[1] / d->num_kv_heads()); }
+
+    bool thin_q = (queries <= thin_q_threshold);
+    bool quantized = false;
+    bool is_integrated = intel_engine->device_info()->is_integrated();
+    bool is_f32 = (desc()->qry_md()->data_type == data_type::f32);
+    use_systolic_ukernel_
+            = intel_engine->mayiuse(compute::device_ext_t::
+                            intel_subgroup_matrix_multiply_accumulate)
+            && !is_f32; // f32 -> non-systolic kernel only
+
+    bool use_fma_config = !use_systolic_ukernel_;
+    config = choose_bwd_config(arch_, d->head_size(), d->keys(), thin_q,
+            quantized, is_integrated, use_fma_config, is_f32);
+
+    VCHECK_SDPA_COND(config != nullptr,
+            "No suitable kernel configuration found for the given problem "
+            "size and attributes.");
+
+    CHECK(update_config_from_devenv_values(config));
+
+    VDEBUGINFO(4, primitive, sdpa,
+            "D=%d,K=%d,%s%s%s"
+            "BcBr_tile(%d, %d): unroll_m=%d unroll_n=%d wg_m=%d wg_n=%d,"
+            "DBc_tile(%d, %d): unroll_m=%d unroll_n=%d wg_m=%d wg_n=%d,"
+            "DBr_tile(%d, %d): unroll_m=%d unroll_n=%d wg_m=%d wg_n=%d",
+            static_cast<int>(d->head_size()), static_cast<int>(d->keys()),
+            thin_q ? "thin_q," : "", quantized ? "quant," : "",
+            is_integrated ? "integrated" : "",
+            config->unroll_m_BcBr * config->wg_m_BcBr,
+            config->unroll_n_BcBr * config->wg_n_BcBr, config->unroll_m_BcBr,
+            config->unroll_n_BcBr, config->wg_m_BcBr, config->wg_n_BcBr,
+            config->unroll_m_DBc * config->wg_m_DBc,
+            config->unroll_n_DBc * config->wg_n_DBc, config->unroll_m_DBc,
+            config->unroll_n_DBc, config->wg_m_DBc, config->wg_n_DBc,
+            config->unroll_m_DBr * config->wg_m_DBr,
+            config->unroll_n_DBr * config->wg_n_DBr, config->unroll_m_DBr,
+            config->unroll_n_DBr, config->wg_m_DBr, config->wg_n_DBr);
+
+    // Bc(Br) == (D)Bc
+    VCHECK_SDPA_COND(
+            ((config->unroll_m_BcBr * config->wg_m_BcBr
+                     == config->unroll_n_DBc * config->wg_n_DBc)
+                    && ((config->wg_m_DBc * config->wg_n_DBc)
+                            <= (config->wg_m_BcBr * config->wg_n_BcBr))),
+            "[CONFIG] The config BcBr work_group tile M(%d) axis must equal "
+            "DBc work_group tile N(%d) axis and number of total subgroups "
+            "should be less than or equal to BcBr subgroups (%d ?<= %d)",
+            config->unroll_m_BcBr * config->wg_m_BcBr,
+            config->unroll_n_DBc * config->wg_n_DBc,
+            config->wg_m_DBc * config->wg_n_DBc,
+            config->wg_m_BcBr * config->wg_n_BcBr);
+
+    // D(Bc) >= head size
+    VCHECK_SDPA_COND(config->unroll_m_DBc * config->wg_m_DBc >= d->head_size(),
+            "The DBc matmul config work_group tile M(%d*%d=%d) axis must be "
+            "greater than or equal to head size(%ld)",
+            config->unroll_m_DBc, config->wg_m_DBc,
+            config->unroll_m_DBc * config->wg_m_DBc,
+            static_cast<long>(d->head_size()));
+
+    // (Bc)Br == (D)Br, ngroups <= BcBr ngroups
+    VCHECK_SDPA_COND(((config->unroll_n_BcBr * config->wg_n_BcBr
+                              == config->unroll_n_DBr * config->wg_n_DBr)
+                             && (config->wg_m_DBr * config->wg_n_DBr
+                                     <= config->wg_m_BcBr * config->wg_n_BcBr)),
+            "[CONFIG] The config BcBr work_group tile N(%d) axis must equal "
+            "DBr work_group tile N(%d) axis and number of total subgroups "
+            "should be less than or equal to BcBr subgroups (%d ?<= %d)",
+            config->unroll_n_BcBr * config->wg_n_BcBr,
+            config->unroll_n_DBr * config->wg_n_DBr,
+            config->wg_m_DBr * config->wg_n_DBr,
+            config->wg_m_BcBr * config->wg_n_BcBr);
+
+    // D(Br) >= head size
+    VCHECK_SDPA_COND(config->unroll_m_DBr * config->wg_m_DBr >= d->head_size(),
+            "The DBr matmul config work_group tile M(%d*%d=%d) axis must be "
+            "greater than or equal to head size(%ld)",
+            config->unroll_m_DBr, config->wg_m_DBr,
+            config->unroll_m_DBr * config->wg_m_DBr,
+            static_cast<long>(d->head_size()));
+
+    // serializable minimal set of configuration params for ukernels
+    // will be used to generate shim ukernels in reusable kernel_ctx
+    micro_bwd_ukernel_params_t ukernel_params;
+
+    ukernel_params.unroll_m_BcBr = config->unroll_m_BcBr;
+    ukernel_params.unroll_n_BcBr = config->unroll_n_BcBr;
+
+    ukernel_params.unroll_m_DBc = config->unroll_m_DBc;
+    ukernel_params.unroll_n_DBc = config->unroll_n_DBc;
+
+    ukernel_params.unroll_m_DBr = config->unroll_m_DBr;
+    ukernel_params.unroll_n_DBr = config->unroll_n_DBr;
+
+    ukernel_params.wg_m_BcBr = config->wg_m_BcBr;
+    ukernel_params.wg_n_BcBr = config->wg_n_BcBr;
+
+    ukernel_params.wg_m_DBc = config->wg_m_DBc;
+    ukernel_params.wg_n_DBc = config->wg_n_DBc;
+
+    ukernel_params.wg_m_DBr = config->wg_m_DBr;
+    ukernel_params.wg_n_DBr = config->wg_n_DBr;
+
+    /* Get device information */
+    HWInformation hw_info;
+    hw_info.euCount = dev_info->eu_count();
+    hw_info.gmdid = dev_info->ip_version();
+    hw_info.systolicAvailable = use_systolic_ukernel_;
+
+    VDISPATCH_SDPA(
+            hw_info.gmdid != 0, "gmdid is 0, microkernels not supported.");
+
+    ukernel_params.hwinfo = {hw_info};
+
+    sg_size_ = dev_info->min_subgroup_size();
auto convert_dnnl_to_kernel_layout = [](const memory_desc_t *md) { + return (gemm_desc_t::get_trans(*md) == dnnl_trans) ? MatrixLayout::T + : MatrixLayout::N; + }; + auto transpose_layout = [](const gemmstone::MatrixLayout l) { + switch (l) { + case MatrixLayout::N: return MatrixLayout::T; + case MatrixLayout::T: return MatrixLayout::N; + case MatrixLayout::Pr: return MatrixLayout::Pc; + case MatrixLayout::Pc: return MatrixLayout::Pr; + default: return l; + } + }; + + /* Set up GEMMProblem structure for first GEMM: K^T * Q */ + GEMMProblem problem; + problem.Ta_ext = convert_dnnl_to_kernel_type(desc()->key_md()->data_type); + problem.Tb_ext = convert_dnnl_to_kernel_type(desc()->qry_md()->data_type); + if (desc()->qry_md()->data_type == data_type::f16) { + problem.Ta = problem.Tb = Type::f16; + } else if (desc()->qry_md()->data_type == data_type::bf16) { + problem.Ta = problem.Tb = Type::bf16; + } else if (desc()->qry_md()->data_type == data_type::f32) { + problem.Ta = problem.Tb = Type::f32; + } else { + VCHECK_SDPA_COND( + utils::one_of(desc()->qry_md()->data_type, data_type::f16, + data_type::bf16, data_type::f32), + "Q tensor's data type must be bf16, f16, or f32"); + } + problem.Tc = problem.Tc_ext = Type::f32; + problem.Ts = problem.Tc; + + const int wg_tile_m_BcBr = config->wg_m_BcBr * config->unroll_m_BcBr; + const int wg_tile_n_BcBr = config->wg_n_BcBr * config->unroll_n_BcBr; + + auto problem_kq = problem; + + problem_kq.A.layout = MatrixLayout::Pc; + problem_kq.B.layout = MatrixLayout::N; + problem_kq.C.layout = MatrixLayout::N; + const memory_desc_wrapper key_mdw(desc()->key_md()); + const memory_desc_wrapper qry_mdw(desc()->qry_md()); + auto ldk = static_cast( + gemm_desc_t::get_ld(*desc()->key_md()) * key_mdw.data_type_size()); + auto ldq = static_cast( + gemm_desc_t::get_ld(*desc()->qry_md()) * qry_mdw.data_type_size()); + problem_kq.A.setAlignment(64); // Q is packed in VNNI format in SLM + if (use_systolic_ukernel()) { + problem_kq.A.crosspack = 2; + problem_kq.A.tileR = into(sg_size_); + problem_kq.A.tileC = into(d_max()); + } + problem_kq.B.setAlignment(alignmentForLD(int(ldq))); + + ukernel_params.problem_kq = {problem_kq}; + + /* Set up microkernel options */ + micro::GEMMProtocol::Options opts_kq; + opts_kq.localA = true; + opts_kq.slmPtr = true; + opts_kq.scaleA = false; + opts_kq.offsetA = false; + + ukernel_params.opts_kq = {opts_kq}; + + /* Set up problem size information */ + SizeParams heuristic_sizes; + heuristic_sizes.m = wg_tile_m_BcBr; + heuristic_sizes.n = wg_tile_n_BcBr; + heuristic_sizes.k = d->head_size(); + heuristic_sizes.batch = 1; + + ukernel_params.sizes_kq = {heuristic_sizes}; + + /* Set up GEMMProblem structure for second GEMM: V * S */ + auto problem_vs = std::move(problem); + problem_vs.Tc = problem_vs.Ts + = (vs_acc_dt() == data_type::f16) ? 
Type::f16 : Type::f32; + + problem_vs.Ta_ext + = convert_dnnl_to_kernel_type(desc()->val_md()->data_type); + problem_vs.A.layout = convert_dnnl_to_kernel_layout(diff_dst_md()); + problem_vs.B.layout = MatrixLayout::Pr; + problem_vs.C.layout = MatrixLayout::N; + const memory_desc_wrapper diff_dst_mdw(diff_dst_md()); + auto lda = static_cast(gemm_desc_t::get_ld(*diff_dst_md()) + * diff_dst_mdw.data_type_size()); + problem_vs.A.setAlignment(alignmentForLD(int(lda))); + problem_vs.B.setAlignment(64); // S is packed in SLM + if (use_systolic_ukernel()) { problem_vs.B.crosspack = 16; } + + ukernel_params.problem_vs = {problem_vs}; + + // directly tied to config, will recompile w/head size and config updates + // no need for interval quantization + heuristic_sizes.m = d->head_size(); + heuristic_sizes.n = wg_tile_m_BcBr; + heuristic_sizes.k = wg_tile_n_BcBr; + + ukernel_params.sizes_vs = {heuristic_sizes}; + + /* Set up microkernel options */ + micro::GEMMProtocol::Options opts_vs; + opts_vs.localA = false; + opts_vs.localB = true; + opts_vs.slmPtr = true; + + ukernel_params.opts_vs = {opts_vs}; + + //////// Vt * dA + auto problem_vtdA = problem; + problem_vtdA.Ta_ext + = convert_dnnl_to_kernel_type(desc()->val_md()->data_type); + + problem_vtdA.A.layout + = transpose_layout(convert_dnnl_to_kernel_layout(desc()->val_md())); + problem_vtdA.B.layout = convert_dnnl_to_kernel_layout(diff_dst_md()); + problem_vtdA.C.layout = MatrixLayout::N; + const memory_desc_wrapper val_mdw(desc()->val_md()); + auto ldv + = gemm_desc_t::get_ld(*desc()->val_md()) * val_mdw.data_type_size(); + problem_vtdA.A.setAlignment(alignmentForLD(int(ldv))); + problem_vtdA.B.setAlignment(alignmentForLD(int(lda))); + + ukernel_params.problem_vtdA = {problem_vtdA}; + + heuristic_sizes.m = wg_tile_m_BcBr; + heuristic_sizes.n = wg_tile_n_BcBr; + heuristic_sizes.k = d->head_size(); + + ukernel_params.sizes_vtdA = {heuristic_sizes}; + + /* Set up microkernel options */ + micro::GEMMProtocol::Options opts_vtdA; + opts_vtdA.localA = false; + opts_vtdA.localB = false; + opts_vtdA.slmPtr = true; + ukernel_params.opts_vtdA = {opts_vtdA}; + + //////// Q * dS^t + auto problem_qdSt = problem; + problem_qdSt.Ta_ext + = convert_dnnl_to_kernel_type(desc()->qry_md()->data_type); + problem_qdSt.A.layout = MatrixLayout::Pc; + problem_qdSt.B.layout + = transpose_layout(convert_dnnl_to_kernel_layout(desc()->qry_md())); + problem_qdSt.C.layout = MatrixLayout::N; + + problem_qdSt.A.setAlignment(64); + problem_qdSt.B.setAlignment(alignmentForLD(int(ldq))); + if (use_systolic_ukernel()) { + problem_qdSt.A.crosspack = 2; + problem_qdSt.A.tileR = into( + sg_size_); // tile will be transposed (dS^t -> n x m) + problem_qdSt.A.tileC = into(wg_tile_n_BcBr); + } + + ukernel_params.problem_qdSt = {problem_qdSt}; + + heuristic_sizes.m = wg_tile_m_BcBr; + heuristic_sizes.n = d->values(); + heuristic_sizes.k = wg_tile_n_BcBr; + + ukernel_params.sizes_qdSt = {heuristic_sizes}; + + /* Set up microkernel options */ + micro::GEMMProtocol::Options opts_qdSt; + opts_qdSt.localA = true; + opts_qdSt.localB = false; + opts_qdSt.slmPtr = true; + ukernel_params.opts_qdSt = {opts_qdSt}; + + // dS * K + auto problem_ktq = problem; + problem_ktq.Ta_ext + = convert_dnnl_to_kernel_type(desc()->key_md()->data_type); + + problem_ktq.A.layout + = transpose_layout(convert_dnnl_to_kernel_layout(desc()->key_md())); + problem_ktq.B.layout = MatrixLayout::Pr; + problem_ktq.C.layout = MatrixLayout::N; + + problem_ktq.A.setAlignment(alignmentForLD(int(ldk))); + 
problem_ktq.B.setAlignment(64); // S is packed in SLM + if (use_systolic_ukernel()) { problem_ktq.B.crosspack = 16; } + + ukernel_params.problem_ktq = {problem_ktq}; + + heuristic_sizes.m = d->head_size(); + heuristic_sizes.n = wg_tile_n_BcBr; + heuristic_sizes.k = wg_tile_m_BcBr; + + ukernel_params.sizes_ktq = {heuristic_sizes}; + + /* Set up microkernel options */ + micro::GEMMProtocol::Options opts_ktq; + opts_ktq.localA = false; + opts_ktq.localB = true; + opts_ktq.slmPtr = true; + ukernel_params.opts_ktq = {opts_ktq}; + + conf.ukernel_config = ukernel_params; + + return status::success; +} + +status_t micro_fwd_t::init(impl::engine_t *engine) { CHECK(create_kernel( engine, kernel_, pd()->conf.get_kernel_names()[0], pd()->conf)); + if (!kernel_) return status::runtime_error; return status::success; } -status_t micro_t::pd_t::init_conf(impl::engine_t *engine) { - using namespace micro; +status_t micro_bwd_t::init(impl::engine_t *engine) { + std::vector kernel_names = pd()->conf.get_kernel_names(); - auto *pd = this; + std::vector kernels; + CHECK(create_kernels(engine, kernels, kernel_names, pd()->conf)); + + preprocess_ = kernels[0]; + kernel_ = kernels[1]; + postprocess_ = kernels[2]; + + if (!preprocess_) return status::runtime_error; + if (!kernel_) return status::runtime_error; + if (!postprocess_) return status::runtime_error; + return status::success; +} + +template +static void init_conf_common(conf_t &conf, pd_type *pd) { + using pd_t = sdpa_pd_t; auto *d = pd->desc(); data_type_t data_t = pd->dst_md()->data_type; conf.data_t = data_t; conf.ndims = pd_t::ndims; - const memory_desc_wrapper qry_mdw(pd->qry_md()); - const memory_desc_wrapper key_mdw(pd->key_md()); - const memory_desc_wrapper val_mdw(pd->val_md()); + const memory_desc_wrapper qry_mdw(pd->desc()->qry_md()); + const memory_desc_wrapper key_mdw(pd->desc()->key_md()); + const memory_desc_wrapper val_mdw(pd->desc()->val_md()); const memory_desc_wrapper dst_mdw(pd->dst_md()); - const memory_desc_wrapper msk_mdw(pd->attn_mask_md()); + const memory_desc_wrapper msk_mdw(pd->desc()->attn_mask_md()); conf.key_data_t = key_mdw.data_type(); conf.qry_data_t = qry_mdw.data_type(); conf.val_data_t = val_mdw.data_type(); conf.dst_data_t = dst_mdw.data_type(); - conf.require_stateless_addressing = has_large_buffers(); - conf.msk_data_t = data_type::undef; if (pd->with_attn_mask()) { conf.msk_data_t = msk_mdw.data_type(); } - conf.key_scales_data_t = pd->key_scales_dt(); - conf.value_scales_data_t = pd->value_scales_dt(); - - conf.key_zp_data_t = pd->key_zp_dt(); - conf.value_zp_data_t = pd->value_zp_dt(); - auto Q_num_heads_dim = qry_mdw.dims()[1]; - conf.kv_group_size = static_cast(Q_num_heads_dim / d->kv_head_number); - - auto ldq = gemm_desc_t::get_ld(*pd->qry_md()) * qry_mdw.data_type_size(); - auto ldk = gemm_desc_t::get_ld(*pd->key_md()) * key_mdw.data_type_size(); - auto ldv = gemm_desc_t::get_ld(*pd->val_md()) * val_mdw.data_type_size(); + conf.kv_group_size = static_cast(Q_num_heads_dim / d->num_kv_heads()); + + auto ldq = gemm_desc_t::get_ld(*pd->desc()->qry_md()) + * qry_mdw.data_type_size(); + auto ldk = gemm_desc_t::get_ld(*pd->desc()->key_md()) + * key_mdw.data_type_size(); + auto ldv = gemm_desc_t::get_ld(*pd->desc()->val_md()) + * val_mdw.data_type_size(); auto lda = gemm_desc_t::get_ld(*pd->dst_md()) * dst_mdw.data_type_size(); conf.q_align = alignmentForLD(int(ldq)); @@ -417,22 +816,68 @@ status_t micro_t::pd_t::init_conf(impl::engine_t *engine) { conf.v_align = alignmentForLD(int(ldv)); conf.a_align = 
alignmentForLD(int(lda)); - conf.transpose_k = gemm_desc_t::get_trans(*pd->key_md()) == dnnl_trans; + conf.transpose_k + = gemm_desc_t::get_trans(*pd->desc()->key_md()) == dnnl_trans; + + conf.scale_data_t = pd->desc()->scale_md()->data_type; - int kq_scale_mask = (static_cast(pd->with_key_scales()) << 1) - | static_cast(with_quantize_common(d->kq_scales)); + conf.attn_mask_undef = attn_mask_type::undef; + conf.attn_mask_buffer = attn_mask_type::buffer; + conf.attn_mask_top_left = attn_mask_type::top_left; + conf.attn_mask_bottom_right = attn_mask_type::bottom_right; + + conf.invert_scale = d->invert_scale; + conf.with_attn_scale = pd->with_attn_scale(); + conf.with_host_scale = pd->with_host_scale(); + conf.with_attn_mask = (pd->with_attn_mask() && !pd->with_causal_mask()); + conf.broadcast_mask_q = (msk_mdw.dims()[pd_t::mask_q_index] == 1); + conf.with_causal_mask = pd->with_causal_mask(); + + conf.subgroup_size = pd->sg_size(); + conf.d_max = pd->d_max(); + + bool d_full = (d->head_size() == pd->d_max()); + conf.d_full = d_full; + conf.arch_gte_hpc = (pd->arch() >= compute::gpu_arch_t::xe_hpc); + + conf.use_systolic_ukernel = pd->use_systolic_ukernel(); +} + +status_t micro_fwd_t::pd_t::init_conf(impl::engine_t *engine) { + using namespace micro; + init_conf_common(conf, this); + + conf.require_stateless_addressing = has_large_buffers(); + + const memory_desc_wrapper qry_mdw(desc()->qry_md()); + const memory_desc_wrapper key_mdw(desc()->key_md()); + const memory_desc_wrapper val_mdw(desc()->val_md()); + const memory_desc_wrapper dst_mdw(dst_md()); + + conf.key_scales_data_t = key_scales_dt(); + conf.value_scales_data_t = value_scales_dt(); + + conf.key_zp_data_t = key_zp_dt(); + conf.value_zp_data_t = value_zp_dt(); + + auto ldq + = gemm_desc_t::get_ld(*desc()->qry_md()) * qry_mdw.data_type_size(); + auto lda = gemm_desc_t::get_ld(*dst_md()) * dst_mdw.data_type_size(); + + int kq_scale_mask = (static_cast(with_key_scales()) << 1) + | static_cast(with_quantize_common(desc()->kq_scales)); conf.kq_scale_mask = kq_scale_mask; - int vs_scale_mask = (static_cast(pd->with_value_scales()) << 1) - | static_cast(with_quantize_common(d->vs_scales)); + int vs_scale_mask = (static_cast(with_value_scales()) << 1) + | static_cast(with_quantize_common(desc()->vs_scales)); conf.vs_scale_mask = vs_scale_mask; - int kq_zp_mask = (static_cast(pd->with_key_zp()) << 1) - | static_cast(with_quantize_common(d->kq_zero_points)); + int kq_zp_mask = (static_cast(with_key_zp()) << 1) + | static_cast(with_quantize_common(desc()->kq_zero_points)); conf.kq_zp_mask = kq_zp_mask; - int vs_zp_mask = (static_cast(pd->with_value_zp()) << 1) - | static_cast(with_quantize_common(d->vs_zero_points)); + int vs_zp_mask = (static_cast(with_value_zp()) << 1) + | static_cast(with_quantize_common(desc()->vs_zero_points)); conf.vs_zp_mask = vs_zp_mask; using namespace data_type; @@ -445,36 +890,19 @@ status_t micro_t::pd_t::init_conf(impl::engine_t *engine) { }; conf.key_elements_per_byte = elems_per_byte(key_mdw.data_type()); - conf.key_zp_elements_per_byte = elems_per_byte(pd->key_zp_dt()); + conf.key_zp_elements_per_byte = elems_per_byte(key_zp_dt()); conf.val_elements_per_byte = elems_per_byte(val_mdw.data_type()); - conf.val_zp_elements_per_byte = elems_per_byte(pd->value_zp_dt()); + conf.val_zp_elements_per_byte = elems_per_byte(value_zp_dt()); conf.key_group_size = 1; conf.val_group_size = 1; - if (pd->with_key_scales() || pd->with_key_zp()) - conf.key_group_size = pd->key_group_size(); - if (pd->with_value_scales() || 
pd->with_value_zp()) - conf.val_group_size = pd->value_group_size(); - - conf.scale_data_t = pd->scale_md()->data_type; - - conf.attn_mask_undef = attn_mask_type::undef; - conf.attn_mask_buffer = attn_mask_type::buffer; - conf.attn_mask_top_left = attn_mask_type::top_left; - conf.attn_mask_bottom_right = attn_mask_type::bottom_right; - - conf.invert_scale = d->invert_scale; - conf.with_attn_scale = pd->with_attn_scale(); - conf.with_host_scale = pd->with_host_scale(); - conf.with_attn_mask = (pd->with_attn_mask() && !pd->with_causal_mask()); - conf.broadcast_mask_q = (msk_mdw.dims()[pd_t::mask_q_index] == 1); - conf.with_causal_mask = pd->with_causal_mask(); - - conf.subgroup_size = pd->sg_size(); - conf.d_max = pd->d_max(); + if (with_key_scales() || with_key_zp()) + conf.key_group_size = key_group_size(); + if (with_value_scales() || with_value_zp()) + conf.val_group_size = value_group_size(); /* Set up microkernel strategy */ - const config_t config = {conf.ukernel_config.unroll_m_kq, + const fwd_config_t config = {conf.ukernel_config.unroll_m_kq, conf.ukernel_config.unroll_n_kq, conf.ukernel_config.unroll_m_vs, conf.ukernel_config.unroll_n_vs, conf.ukernel_config.wg_m_kq, conf.ukernel_config.wg_n_kq, conf.ukernel_config.wg_m_vs, @@ -486,52 +914,129 @@ status_t micro_t::pd_t::init_conf(impl::engine_t *engine) { int tile_k = kq_wg_tile_m; int tile_v = vs_wg_tile_m; - bool d_full = (d->head_size() == pd->d_max()); - bool v_full = (d->head_size() == tile_v); + bool d_full = conf.d_full; + bool v_full = (desc()->head_size() == tile_v); - auto Q = d->queries(); + auto Q = desc()->queries(); const dim_t Q_per_kv_group = (Q == 1 ? Q * conf.kv_group_size : Q); bool q_full = ((Q_per_kv_group % kq_wg_tile_n) != 0); conf.remainder_q = d_full && q_full; - conf.d_full = d_full; - conf.arch_gte_hpc = (pd->arch() >= compute::gpu_arch_t::xe_hpc); - conf.block_q = conf.block_a = conf.block_2d_a = false; if (d_full) { conf.block_q = (ldq % 4 == 0); - conf.block_a = (lda % 16 == 0 && v_full); - } else if (pd->arch() >= compute::gpu_arch_t::xe_hpc - && (config.unroll_m_vs * dst_mdw.data_type_size()) <= 64) { - auto vbytes = d->values() * val_mdw.data_type_size(); + conf.block_a = (lda % 4 == 0 && v_full); + } else if (arch() >= compute::gpu_arch_t::xe_hpc + && config.unroll_m_vs < 64) { + auto vbytes = desc()->values() * val_mdw.data_type_size(); if (lda % 16 == 0 && vbytes % 4 == 0) conf.block_2d_a = true; } - if (pd->arch() >= compute::gpu_arch_t::xe_hpc) { + if (arch() >= compute::gpu_arch_t::xe_hpc) { conf.prefetch_mask = true; conf.prefetch_k0 = true; conf.prefetch_k = true; conf.prefetch_v = true; - bool no_rem = d_full && v_full && (d->keys() % tile_k == 0); + conf.prefetch_d_max = nstl::min(d_max(), 64); + bool no_rem = d_full && v_full && (desc()->keys() % tile_k == 0); conf.prefetch_remainder = !no_rem; - conf.prefetch_d_max = nstl::min(pd->d_max(), 64); } else { conf.prefetch_mask = conf.prefetch_k0 = conf.prefetch_k = conf.prefetch_v = conf.prefetch_remainder = false; conf.prefetch_d_max = 0; } - const bool arch_gte_xe2 = pd->arch() >= compute::gpu_arch_t::xe2; - conf.q_arrive_await_barrier = (Q > 1) && !arch_gte_xe2; + conf.q_arrive_await_barrier = (Q > 1); conf.softmax_inf_as_zero - = (d->softmax_alg == alg_kind::softmax_accurate_inf_as_zero); - conf.use_systolic_ukernel = pd->use_systolic_ukernel(); + = (desc()->softmax_alg == alg_kind::softmax_accurate_inf_as_zero); conf.kq_f16_accumulate = (kq_acc_dt() == data_type::f16); conf.vs_f16_accumulate = (vs_acc_dt() == data_type::f16); + + bool 
is_training = desc()->prop_kind == prop_kind::forward_training; + conf.is_training = is_training; + if (is_training) { init_default_ws(); } + + return status::success; +} + +status_t micro_bwd_t::pd_t::init_conf(impl::engine_t *engine) { + init_conf_common(conf, this); + + conf.require_stateless_addressing = has_large_buffers(); + conf.with_dS = with_dS(); + + const memory_desc_wrapper key_mdw(desc()->key_md()); + const memory_desc_wrapper val_mdw(desc()->val_md()); + + auto ldk + = gemm_desc_t::get_ld(*desc()->key_md()) * key_mdw.data_type_size(); + auto ldv + = gemm_desc_t::get_ld(*desc()->val_md()) * val_mdw.data_type_size(); + + /* Set up microkernel strategy */ + const bwd_config_t config = {conf.ukernel_config.unroll_m_BcBr, + conf.ukernel_config.unroll_n_BcBr, conf.ukernel_config.unroll_m_DBc, + conf.ukernel_config.unroll_n_DBc, conf.ukernel_config.unroll_m_DBr, + conf.ukernel_config.unroll_n_DBr, conf.ukernel_config.wg_m_BcBr, + conf.ukernel_config.wg_n_BcBr, conf.ukernel_config.wg_m_DBc, + conf.ukernel_config.wg_n_DBc, conf.ukernel_config.wg_m_DBr, + conf.ukernel_config.wg_n_DBr}; + + const int kq_wg_tile_m = config.wg_m_BcBr * config.unroll_m_BcBr; + const int tile_k = kq_wg_tile_m; + + const int tile_dv = config.wg_n_DBc * config.unroll_n_DBc; + + bool d_full = conf.d_full; + bool dv_full = (desc()->head_size() == tile_dv); + + conf.block_k = conf.block_dK = conf.block_dV = false; + if (d_full) { + bool can_block_load_k + = (ldk % 4 == 0) && (desc()->keys() % tile_k == 0); + conf.block_k = can_block_load_k; + conf.block_dK = can_block_load_k && !conf.transpose_k; + conf.block_dV = (ldv % 4 == 0) && (dv_full); + } + + return status::success; +} + +status_t micro_bwd_t::pd_t::init_scratchpad(impl::engine_t *engine) { + auto scratchpad = scratchpad_registry().registrar(); + auto gpu_align + = utils::downcast(engine)->get_buffer_alignment(); + size_t wspace_size = memory_desc_wrapper(desc()->diff_qry_md()).nelems(); + // f32 can directly atomic add to output + // others need intermediate scratchpad before conversion + if (conf.data_t != data_type::f32) { + scratchpad.book(memory_tracking::names::key_sdpa_dQ_reduction, + wspace_size, sizeof(float), gpu_align); + } + + // for GQA cases multiple Q heads atomic add into shared dK/dV + const bool needs_intermediate_dKV + = (conf.kv_group_size > 1 && conf.data_t != data_type::f32); + if (needs_intermediate_dKV) { + size_t dK_size = memory_desc_wrapper(desc()->diff_key_md()).nelems(); + scratchpad.book(memory_tracking::names::key_sdpa_dK_reduction, dK_size, + sizeof(float), gpu_align); + + size_t dV_size = memory_desc_wrapper(desc()->diff_val_md()).nelems(); + scratchpad.book(memory_tracking::names::key_sdpa_dV_reduction, dV_size, + sizeof(float), gpu_align); + } + + // space for D_i preprocess result + size_t Di_size + = desc()->batch() * desc()->num_q_heads() * desc()->queries(); + scratchpad.book(memory_tracking::names::key_sdpa_Di, Di_size, sizeof(float), + gpu_align); + return status::success; } -status_t micro_params_t::get_kernel_ctx( +status_t micro_fwd_params_t::get_kernel_ctx( compute::kernel_ctx_t &kernel_ctx) const { using namespace micro; kernel_ctx.require_stateless_addressing(require_stateless_addressing); @@ -543,7 +1048,6 @@ status_t micro_params_t::get_kernel_ctx( def_data_type(kernel_ctx, qry_data_t, "QRY"); def_data_type(kernel_ctx, val_data_t, "VAL"); def_data_type(kernel_ctx, dst_data_t, "DST"); - def_data_type(kernel_ctx, scale_data_t, "SCALE", !with_host_scale); if (with_attn_mask) { def_data_type(kernel_ctx, 
msk_data_t, "MSK"); } @@ -575,6 +1079,7 @@ status_t micro_params_t::get_kernel_ctx( kernel_ctx.define_int("KEY_GROUP_SIZE", key_group_size); kernel_ctx.define_int("VAL_GROUP_SIZE", val_group_size); + def_data_type(kernel_ctx, scale_data_t, "SCALE", !with_host_scale); kernel_ctx.define_int("INVERT_SCALE", invert_scale); kernel_ctx.define_int("WITH_ATTN_SCALE", with_attn_scale); kernel_ctx.define_int("WITH_HOST_SCALE", with_host_scale); @@ -607,6 +1112,7 @@ status_t micro_params_t::get_kernel_ctx( kernel_ctx.define_int("USE_SYSTOLIC_UKERNEL", use_systolic_ukernel); kernel_ctx.define_int("KQ_F16_ACC", kq_f16_accumulate); kernel_ctx.define_int("VS_F16_ACC", vs_f16_accumulate); + kernel_ctx.define_int("IS_TRAINING", is_training); gemmstone::HWInformation hw_info; gemmstone::GEMMProblem problem_kq, problem_vs; @@ -619,7 +1125,7 @@ status_t micro_params_t::get_kernel_ctx( micro::Package gemm_kq, gemm_vs; /* Set up microkernel strategy */ - const config_t config + const fwd_config_t config = {ukernel_config.unroll_m_kq, ukernel_config.unroll_n_kq, ukernel_config.unroll_m_vs, ukernel_config.unroll_n_vs, ukernel_config.wg_m_kq, ukernel_config.wg_n_kq, @@ -665,6 +1171,7 @@ status_t micro_params_t::get_kernel_ctx( "gemm_vs microkernel generation failure with message: %s", ex.what()); } + VDEBUGINFO(4, primitive, sdpa, "kq_gemm: %s, vs_gemm: %s,", problem_kq.toString().c_str(), problem_vs.toString().c_str()); @@ -689,13 +1196,222 @@ status_t micro_params_t::get_kernel_ctx( return status::success; } -status_t micro_t::execute(const exec_ctx_t &ctx) const { +status_t micro_bwd_params_t::get_kernel_ctx( + compute::kernel_ctx_t &kernel_ctx) const { + using namespace micro; + kernel_ctx.require_stateless_addressing(require_stateless_addressing); + + kernel_ctx.define_int("NDIMS", ndims); + kernel_ctx.set_data_type(data_t); + + def_data_type(kernel_ctx, key_data_t, "KEY"); + def_data_type(kernel_ctx, qry_data_t, "QRY"); + def_data_type(kernel_ctx, val_data_t, "VAL"); + def_data_type(kernel_ctx, dst_data_t, "DST"); + + if (with_attn_mask) { def_data_type(kernel_ctx, msk_data_t, "MSK"); } + + kernel_ctx.define_int("KV_GROUP_SIZE", kv_group_size); + + kernel_ctx.define_int("Q_ALIGN", q_align); + kernel_ctx.define_int("K_ALIGN", k_align); + kernel_ctx.define_int("V_ALIGN", v_align); + kernel_ctx.define_int("A_ALIGN", a_align); + + kernel_ctx.define_int("TRANSPOSE_K", transpose_k); + + def_data_type(kernel_ctx, scale_data_t, "SCALE", !with_host_scale); + kernel_ctx.define_int("INVERT_SCALE", invert_scale); + kernel_ctx.define_int("WITH_ATTN_SCALE", with_attn_scale); + kernel_ctx.define_int("WITH_HOST_SCALE", with_host_scale); + kernel_ctx.define_int("ATTN_MASK_UNDEF", attn_mask_undef); + kernel_ctx.define_int("ATTN_MASK_BUFFER", attn_mask_buffer); + kernel_ctx.define_int("ATTN_MASK_TOP_LEFT", attn_mask_top_left); + kernel_ctx.define_int("ATTN_MASK_BOTTOM_RIGHT", attn_mask_bottom_right); + + kernel_ctx.define_int("WITH_ATTN_MASK", with_attn_mask); + kernel_ctx.define_int("BROADCAST_MASK_Q", broadcast_mask_q); + kernel_ctx.define_int("WITH_CAUSAL_MASK", with_causal_mask); + kernel_ctx.define_int("WITH_DS", with_dS); + + kernel_ctx.define_int("SUBGROUP_SIZE", subgroup_size); + kernel_ctx.define_int("D_MAX", d_max); + + kernel_ctx.define_int("BLOCK_K", block_k); + kernel_ctx.define_int("BLOCK_DK", block_dK); + kernel_ctx.define_int("BLOCK_DV", block_dV); + + kernel_ctx.define_int("USE_SYSTOLIC_UKERNEL", use_systolic_ukernel); + + HWInformation hw_info; + gemmstone::GEMMProblem problem_kq, problem_vs; + 
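+    /* Five microkernels are generated for the backward pass. Following the
+       tile comments in micro_bwd.cl: kq recomputes the Bc x Br score tile,
+       vtdA (V^T * dA, per its name) yields a Br x Bc tile, vs a D x Bc dV
+       tile, qdSt a Bc x D dK tile, and ktq a D x Br dQ tile. */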
    micro::GEMMProtocol::Options opts_kq, opts_vs;
+    gemmstone::SizeParams sizes_kq, sizes_vs;
+
+    gemmstone::GEMMProblem problem_vtdA, problem_ktq, problem_qdSt;
+    micro::GEMMProtocol::Options opts_vtdA, opts_ktq, opts_qdSt;
+    gemmstone::SizeParams sizes_vtdA, sizes_ktq, sizes_qdSt;
+
+    deserialize_config_to_gemmstone(hw_info, problem_kq, problem_vs,
+            problem_vtdA, problem_ktq, problem_qdSt, opts_kq, opts_vs,
+            opts_vtdA, opts_ktq, opts_qdSt, sizes_kq, sizes_vs, sizes_vtdA,
+            sizes_ktq, sizes_qdSt, ukernel_config);
+
+    micro::Package gemm_kq, gemm_vs, gemm_vtdA, gemm_ktq, gemm_qdSt;
+
+    /* Set up microkernel strategy */
+    const bwd_config_t config
+            = {ukernel_config.unroll_m_BcBr, ukernel_config.unroll_n_BcBr,
+                    ukernel_config.unroll_m_DBc, ukernel_config.unroll_n_DBc,
+                    ukernel_config.unroll_m_DBr, ukernel_config.unroll_n_DBr,
+                    ukernel_config.wg_m_BcBr, ukernel_config.wg_n_BcBr,
+                    ukernel_config.wg_m_DBc, ukernel_config.wg_n_DBc,
+                    ukernel_config.wg_m_DBr, ukernel_config.wg_n_DBr};
+
+    std::vector<StrategyRequirement> reqs_kq;
+    reqs_kq.push_back(StrategyRequirement::UnrollM == config.unroll_m_BcBr);
+    reqs_kq.push_back(StrategyRequirement::UnrollN == config.unroll_n_BcBr);
+    reqs_kq.push_back(StrategyRequirement::WGM == config.wg_m_BcBr);
+    reqs_kq.push_back(StrategyRequirement::WGN == config.wg_n_BcBr);
+
+    std::vector<StrategyRequirement> reqs_vs;
+    reqs_vs.push_back(StrategyRequirement::UnrollM == config.unroll_m_DBc);
+    reqs_vs.push_back(StrategyRequirement::UnrollN == config.unroll_n_DBc);
+    reqs_vs.push_back(StrategyRequirement::WGM == config.wg_m_DBc);
+    reqs_vs.push_back(StrategyRequirement::WGN == config.wg_n_DBc);
+
+    std::vector<StrategyRequirement> reqs_vtdA;
+    reqs_vtdA.push_back(StrategyRequirement::UnrollM == config.unroll_m_BcBr);
+    reqs_vtdA.push_back(StrategyRequirement::UnrollN == config.unroll_n_BcBr);
+    reqs_vtdA.push_back(StrategyRequirement::WGM == config.wg_m_BcBr);
+    reqs_vtdA.push_back(StrategyRequirement::WGN == config.wg_n_BcBr);
+
+    std::vector<StrategyRequirement> reqs_ktq;
+    reqs_ktq.push_back(StrategyRequirement::UnrollM == config.unroll_m_DBr);
+    reqs_ktq.push_back(StrategyRequirement::UnrollN == config.unroll_n_DBr);
+    reqs_ktq.push_back(StrategyRequirement::WGM == config.wg_m_DBr);
+    reqs_ktq.push_back(StrategyRequirement::WGN == config.wg_n_DBr);
+
+    std::vector<StrategyRequirement> reqs_qdSt;
+    reqs_qdSt.push_back(StrategyRequirement::UnrollM == config.unroll_n_DBc);
+    reqs_qdSt.push_back(StrategyRequirement::UnrollN == config.unroll_m_DBc);
+    reqs_qdSt.push_back(StrategyRequirement::WGM == config.wg_n_DBc);
+    reqs_qdSt.push_back(StrategyRequirement::WGN == config.wg_m_DBc);
+
+    /* Ask microkernel provider for microkernel */
+    try {
+        gemm_kq = selectGEMMMicrokernel(
+                opts_kq, hw_info, sizes_kq, problem_kq, reqs_kq);
+    } catch (const std::runtime_error &ex) {
+        VCHECK_SDPA_COND(false,
+                "gemm_kq microkernel generation failure with message: %s",
+                ex.what());
+    }
+
+    try {
+        if (use_systolic_ukernel) {
+            auto adjust_vs = [](GEMMStrategy &strategy) {
+                /* Enable dpasw */
+                strategy.dpasw |= strategy.fused;
+            };
+            gemm_vs = selectGEMMMicrokernel(
+                    opts_vs, hw_info, sizes_vs, problem_vs, reqs_vs, adjust_vs);
+        } else {
+            gemm_vs = selectGEMMMicrokernel(
+                    opts_vs, hw_info, sizes_vs, problem_vs, reqs_vs);
+        }
+    } catch (const std::runtime_error &ex) {
+        VCHECK_SDPA_COND(false,
+                "gemm_vs microkernel generation failure with message: %s",
+                ex.what());
+    }
+
+    VDEBUGINFO(4, primitive, sdpa,
+            "kq_gemm: %s, vs_gemm: %s, vtdA_gemm: %s, ktq_gemm: %s, qdSt: %s\n",
+            problem_kq.toString().c_str(), problem_vs.toString().c_str(),
+            problem_vtdA.toString().c_str(),
problem_ktq.toString().c_str(), + problem_qdSt.toString().c_str()); + + /* Generate microkernel shims */ + micro::ShimOptions shimOptions; + shimOptions.subgroupSize = subgroup_size; + shimOptions.useTileOps = true; + shimOptions.decorator = "kq"; + + std::string gemm_kq_header + = micro::generateShim(gemm_kq, HostLanguage::OpenCL_C, shimOptions); + kernel_ctx.add_custom_header("gemm_kq.h", std::move(gemm_kq_header)); + + shimOptions.microkernelID++; + shimOptions.decorator = "vs"; + + std::string gemm_vs_header + = micro::generateShim(gemm_vs, HostLanguage::OpenCL_C, shimOptions); + kernel_ctx.add_custom_header("gemm_vs.h", std::move(gemm_vs_header)); + + try { + gemm_vtdA = selectGEMMMicrokernel( + opts_vtdA, hw_info, sizes_vtdA, problem_vtdA, reqs_vtdA); + } catch (const std::runtime_error &ex) { + VCHECK_SDPA_COND(false, + "gemm_vtdA microkernel generation failure with message: %s", + ex.what()); + } + + shimOptions.microkernelID++; + shimOptions.decorator = "vtdA"; + + std::string gemm_vtdA_header = micro::generateShim( + gemm_vtdA, HostLanguage::OpenCL_C, shimOptions); + kernel_ctx.add_custom_header("gemm_vtdA.h", std::move(gemm_vtdA_header)); + + try { + gemm_ktq = selectGEMMMicrokernel( + opts_ktq, hw_info, sizes_ktq, problem_ktq, reqs_ktq); + } catch (const std::runtime_error &ex) { + VCHECK_SDPA_COND(false, + "gemm_ktq microkernel generation failure with message: %s", + ex.what()); + } + + shimOptions.microkernelID++; + shimOptions.decorator = "ktq"; + + std::string gemm_ktq_header = micro::generateShim( + gemm_ktq, HostLanguage::OpenCL_C, shimOptions); + kernel_ctx.add_custom_header("gemm_ktq.h", std::move(gemm_ktq_header)); + + try { + gemm_qdSt = selectGEMMMicrokernel( + opts_qdSt, hw_info, sizes_qdSt, problem_qdSt, reqs_qdSt); + } catch (const std::runtime_error &ex) { + VCHECK_SDPA_COND(false, + "gemm_qdSt microkernel generation failure with message: %s", + ex.what()); + } + + shimOptions.microkernelID++; + shimOptions.decorator = "qdSt"; + + std::string gemm_qdSt_header = micro::generateShim( + gemm_qdSt, HostLanguage::OpenCL_C, shimOptions); + kernel_ctx.add_custom_header("gemm_qdSt.h", std::move(gemm_qdSt_header)); + + if (gemm_kq.grfMin > 128 || gemm_vs.grfMin > 128 || gemm_vtdA.grfMin > 128 + || gemm_ktq.grfMin > 128 || gemm_qdSt.grfMin > 128) + kernel_ctx.add_option("-cl-intel-256-GRF-per-thread"); + + return status::success; +} + +status_t micro_fwd_t::execute_forward(const exec_ctx_t &ctx) const { const auto &conf = pd()->conf; const auto &qry = CTX_IN_STORAGE(DNNL_ARG_QUERIES); const auto &key = CTX_IN_STORAGE(DNNL_ARG_KEYS); const auto &val = CTX_IN_STORAGE(DNNL_ARG_VALUES); auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST); + auto &ws = CTX_OUT_STORAGE(DNNL_ARG_WORKSPACE); const auto &scale = CTX_IN_STORAGE(DNNL_ARG_SCALE); const auto &attn_mask = CTX_IN_STORAGE(DNNL_ARG_ATTN_MASK); @@ -714,7 +1430,7 @@ status_t micro_t::execute(const exec_ctx_t &ctx) const { const dim_t D = pd()->desc()->head_size(); const dim_t Q_per_kv_group = (Q == 1 ? 
Q * kv_group_size : Q); - const config_t config = {conf.ukernel_config.unroll_m_kq, + const fwd_config_t config = {conf.ukernel_config.unroll_m_kq, conf.ukernel_config.unroll_n_kq, conf.ukernel_config.unroll_m_vs, conf.ukernel_config.unroll_n_vs, conf.ukernel_config.wg_m_kq, conf.ukernel_config.wg_n_kq, conf.ukernel_config.wg_m_vs, @@ -724,11 +1440,11 @@ status_t micro_t::execute(const exec_ctx_t &ctx) const { auto wg_tile_q = kq_wg_tile_n; auto sg_per_wg = config.wg_m_kq * config.wg_n_kq; - const memory_desc_wrapper qry_mdw(pd()->qry_md()); - const memory_desc_wrapper key_mdw(pd()->key_md()); - const memory_desc_wrapper val_mdw(pd()->val_md()); + const memory_desc_wrapper qry_mdw(pd()->desc()->qry_md()); + const memory_desc_wrapper key_mdw(pd()->desc()->key_md()); + const memory_desc_wrapper val_mdw(pd()->desc()->val_md()); const memory_desc_wrapper dst_mdw(pd()->dst_md()); - const memory_desc_wrapper msk_mdw(pd()->attn_mask_md()); + const memory_desc_wrapper msk_mdw(pd()->desc()->attn_mask_md()); using offset_t = decltype(offsets_t().src_off); offset_t key_off, qry_off, val_off, dst_off, msk_off; @@ -742,20 +1458,45 @@ status_t micro_t::execute(const exec_ctx_t &ctx) const { //TODO: change arg_list type based on large_idx //bool use_int32_offset = conf.use_int32_offset; - auto append_offs + // pass only the individual stride/dim values + // actually consumed by the kernel to minimize register pressure + auto append_key_offs = [](compute::kernel_arg_list_t &arg_list, const offset_t &offs) { - compute::int64x4_t dims4 - = {offs[3][0], offs[3][1], offs[3][2], offs[3][3]}; - compute::int64x4_t strides4 - = {offs[1][0], offs[1][1], offs[1][2], offs[1][3]}; - arg_list.append(dims4); - arg_list.append(strides4); + arg_list.append((int64_t)offs[1][0]); // KEY_S0 + arg_list.append((int64_t)offs[1][1]); // KEY_S1 + arg_list.append((int64_t)offs[1][2]); // KEY_S2 + arg_list.append((int64_t)offs[1][3]); // KEY_S3 + arg_list.append((int64_t)offs[3][3]); // KEY_D3 + }; + auto append_qry_offs + = [](compute::kernel_arg_list_t &arg_list, const offset_t &offs) { + arg_list.append((int64_t)offs[1][0]); // QRY_S0 + arg_list.append((int64_t)offs[1][1]); // QRY_S1 + arg_list.append((int64_t)offs[1][2]); // QRY_S2 + }; + auto append_val_offs + = [](compute::kernel_arg_list_t &arg_list, const offset_t &offs) { + arg_list.append((int64_t)offs[1][0]); // VAL_S0 + arg_list.append((int64_t)offs[1][1]); // VAL_S1 + arg_list.append((int64_t)offs[1][2]); // VAL_S2 + }; + auto append_dst_offs + = [](compute::kernel_arg_list_t &arg_list, const offset_t &offs) { + arg_list.append((int64_t)offs[1][0]); // DST_S0 + arg_list.append((int64_t)offs[1][1]); // DST_S1 + arg_list.append((int64_t)offs[1][2]); // DST_S2 + arg_list.append((int64_t)offs[3][1]); // DST_D1 + }; + auto append_msk_offs + = [](compute::kernel_arg_list_t &arg_list, const offset_t &offs) { + arg_list.append((int64_t)offs[1][0]); // MSK_S0 + arg_list.append((int64_t)offs[1][1]); // MSK_S1 + arg_list.append((int64_t)offs[1][2]); // MSK_S2 + arg_list.append((int64_t)offs[3][0]); // MSK_D0 + arg_list.append((int64_t)offs[3][1]); // MSK_D1 }; - int mask_type = static_cast(pd()->desc()->mask_type); - compute::kernel_arg_list_t arg_list; - - const memory_desc_wrapper scale_mdw(pd()->scale_md()); + const memory_desc_wrapper scale_mdw(pd()->desc()->scale_md()); float scalar_scale = 1.f; float inv_scalar_scale = 1.f; if (pd()->with_host_scale()) { @@ -766,13 +1507,16 @@ status_t micro_t::execute(const exec_ctx_t &ctx) const { assert(status == status::success); if 
(status != status::success) return status;
         scalar_scale = dnnl::impl::cpu::io::load_float_value(
-                pd()->scale_md()->data_type, &scalar_scale, 0);
+                pd()->desc()->scale_md()->data_type, &scalar_scale, 0);
         inv_scalar_scale = 1. / scalar_scale;
     }
 
+    int mask_type = static_cast<int>(pd()->desc()->mask_type);
+    compute::kernel_arg_list_t arg_list;
     arg_list.append(key);
     arg_list.append(qry);
     arg_list.append(val);
+    arg_list.append(ws);
     arg_list.append(dst);
     if (pd()->with_host_scale()) {
         arg_list.append(scalar_scale);
@@ -790,12 +1534,12 @@ status_t micro_t::execute(const exec_ctx_t &ctx) const {
     arg_list.append(mask_type);
     if (pd()->with_attn_mask()) arg_list.append(attn_mask);
 
-    append_offs(arg_list, key_off);
-    append_offs(arg_list, qry_off);
-    append_offs(arg_list, val_off);
-    append_offs(arg_list, dst_off);
+    append_key_offs(arg_list, key_off);
+    append_qry_offs(arg_list, qry_off);
+    append_val_offs(arg_list, val_off);
+    append_dst_offs(arg_list, dst_off);
 
-    if (pd()->with_attn_mask()) { append_offs(arg_list, msk_off); }
+    if (pd()->with_attn_mask()) { append_msk_offs(arg_list, msk_off); }
 
     const int remainder_k = (K % kq_wg_tile_m) != 0;
     arg_list.append(remainder_k);
@@ -813,12 +1557,312 @@ status_t micro_t::execute(const exec_ctx_t &ctx) const {
         gws[0] *= utils::div_up(Q, wg_tile_q);
         gws[1] *= pd()->dst_md()->dims[1];
     }
-    gws[2] *= pd()->dst_md()->dims[0];
+    gws[2] *= pd()->desc()->batch();
 
     auto nd_range = compute::nd_range_t(gws, lws);
     return parallel_for(ctx, nd_range, kernel_, arg_list);
 }
 
+status_t micro_bwd_t::execute_backward(const exec_ctx_t &ctx) const {
+    const auto &qry = CTX_IN_STORAGE(DNNL_ARG_QUERIES);
+    const auto &key = CTX_IN_STORAGE(DNNL_ARG_KEYS);
+    const auto &val = CTX_IN_STORAGE(DNNL_ARG_VALUES);
+    const auto &ws = CTX_IN_STORAGE(DNNL_ARG_WORKSPACE);
+    const auto &dst = CTX_IN_STORAGE(DNNL_ARG_DST);
+    const auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
+    auto &diff_q = CTX_OUT_STORAGE(DNNL_ARG_DIFF_QUERIES);
+    auto &diff_k = CTX_OUT_STORAGE(DNNL_ARG_DIFF_KEYS);
+    auto &diff_v = CTX_OUT_STORAGE(DNNL_ARG_DIFF_VALUES);
+    const auto &scale = CTX_IN_STORAGE(DNNL_ARG_SCALE);
+    const auto &attn_mask = CTX_IN_STORAGE(DNNL_ARG_ATTN_MASK);
+    auto Di_scratch = ctx.get_scratchpad_grantor().get_memory_storage(
+            memory_tracking::names::key_sdpa_Di);
+    auto diff_q_scratch = ctx.get_scratchpad_grantor().get_memory_storage(
+            memory_tracking::names::key_sdpa_dQ_reduction);
+    auto diff_k_scratch = ctx.get_scratchpad_grantor().get_memory_storage(
+            memory_tracking::names::key_sdpa_dK_reduction);
+    auto diff_v_scratch = ctx.get_scratchpad_grantor().get_memory_storage(
+            memory_tracking::names::key_sdpa_dV_reduction);
+
+    const bool with_dS = pd()->with_dS();
+
+    const int kv_group_size = pd()->conf.kv_group_size;
+    const dim_t Q = pd()->desc()->queries();
+    const dim_t K = pd()->desc()->keys();
+    const dim_t D = pd()->desc()->head_size();
+
+    const data_type_t data_t = pd()->dst_md()->data_type;
+    const bool needs_intermediate_dQ = (data_t != data_type::f32);
+    const bool needs_intermediate_dKV
+            = (kv_group_size > 1 && data_t != data_type::f32);
+    const bool needs_zero_dKV = (kv_group_size > 1);
+
+    const auto &conf = pd()->conf;
+
+    const bwd_config_t config = {conf.ukernel_config.unroll_m_BcBr,
+            conf.ukernel_config.unroll_n_BcBr, conf.ukernel_config.unroll_m_DBc,
+            conf.ukernel_config.unroll_n_DBc, conf.ukernel_config.unroll_m_DBr,
+            conf.ukernel_config.unroll_n_DBr, conf.ukernel_config.wg_m_BcBr,
+            conf.ukernel_config.wg_n_BcBr, conf.ukernel_config.wg_m_DBc,
+            conf.ukernel_config.wg_n_DBc, conf.ukernel_config.wg_m_DBr,
+            conf.ukernel_config.wg_n_DBr};
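+
+    // Naming convention (cf. the tile typedefs in micro_bwd.cl): the BcBr
+    // ukernels work on Bc x Br (key-tile x query-tile) shaped S/dS blocks,
+    // while the DBc/DBr ukernels produce the head_size-wide dK/dV and dQ
+    // blocks.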
+
+    auto wg_tile_k = config.unroll_m_BcBr * config.wg_m_BcBr;
+    auto wg_tile_q = config.unroll_n_BcBr * config.wg_n_BcBr;
+
+    auto sg_per_wg_BcBr = config.wg_m_BcBr * config.wg_n_BcBr;
+    auto sg_per_wg_DBc = config.wg_m_DBc * config.wg_n_DBc;
+    auto sg_per_wg_DBr = config.wg_m_DBr * config.wg_n_DBr;
+
+    auto sg_per_wg
+            = std::max(std::max(sg_per_wg_BcBr, sg_per_wg_DBc), sg_per_wg_DBr);
+
+    const memory_desc_wrapper qry_mdw(pd()->desc()->qry_md());
+    const memory_desc_wrapper key_mdw(pd()->desc()->key_md());
+    const memory_desc_wrapper val_mdw(pd()->desc()->val_md());
+    const memory_desc_wrapper dst_mdw(pd()->dst_md());
+    const memory_desc_wrapper msk_mdw(pd()->desc()->attn_mask_md());
+    const memory_desc_wrapper diff_dst_mdw(pd()->diff_dst_md());
+    const memory_desc_wrapper diff_qry_mdw(pd()->desc()->diff_qry_md());
+    const memory_desc_wrapper diff_key_mdw(pd()->desc()->diff_key_md());
+    const memory_desc_wrapper diff_val_mdw(pd()->desc()->diff_val_md());
+    using offset_t = decltype(offsets_t().src_off);
+
+    offset_t qry_off, key_off, val_off, dst_off, msk_off;
+
+    set_offsets(qry_mdw, qry_off);
+    set_offsets(key_mdw, key_off);
+    set_offsets(val_mdw, val_off);
+    set_offsets(dst_mdw, dst_off);
+    set_offsets(msk_mdw, msk_off);
+
+    // pass only the individual stride/dim values
+    // actually consumed by the kernel to minimize register pressure
+    auto append_key_offs
+            = [](compute::kernel_arg_list_t &arg_list, const offset_t &offs) {
+                  arg_list.append((int64_t)offs[1][0]); // KEY_S0
+                  arg_list.append((int64_t)offs[1][1]); // KEY_S1
+                  arg_list.append((int64_t)offs[1][2]); // KEY_S2
+                  arg_list.append((int64_t)offs[1][3]); // KEY_S3
+                  arg_list.append((int64_t)offs[3][3]); // KEY_D3
+              };
+    auto append_qry_offs
+            = [](compute::kernel_arg_list_t &arg_list, const offset_t &offs) {
+                  arg_list.append((int64_t)offs[1][0]); // QRY_S0
+                  arg_list.append((int64_t)offs[1][1]); // QRY_S1
+                  arg_list.append((int64_t)offs[1][2]); // QRY_S2
+              };
+    auto append_val_offs
+            = [](compute::kernel_arg_list_t &arg_list, const offset_t &offs) {
+                  arg_list.append((int64_t)offs[1][0]); // VAL_S0
+                  arg_list.append((int64_t)offs[1][1]); // VAL_S1
+                  arg_list.append((int64_t)offs[1][2]); // VAL_S2
+              };
+    auto append_dst_offs
+            = [](compute::kernel_arg_list_t &arg_list, const offset_t &offs) {
+                  arg_list.append((int64_t)offs[1][0]); // DST_S0
+                  arg_list.append((int64_t)offs[1][1]); // DST_S1
+                  arg_list.append((int64_t)offs[1][2]); // DST_S2
+                  arg_list.append((int64_t)offs[3][1]); // DST_D1
+              };
+    auto append_msk_offs
+            = [](compute::kernel_arg_list_t &arg_list, const offset_t &offs) {
+                  arg_list.append((int64_t)offs[1][0]); // MSK_S0
+                  arg_list.append((int64_t)offs[1][1]); // MSK_S1
+                  arg_list.append((int64_t)offs[1][2]); // MSK_S2
+                  arg_list.append((int64_t)offs[3][0]); // MSK_D0
+                  arg_list.append((int64_t)offs[3][1]); // MSK_D1
+              };
+
+    int mask_type = static_cast<int>(pd()->desc()->mask_type);
+
+    const memory_desc_wrapper scale_mdw(pd()->desc()->scale_md());
+    float scalar_scale = 1.f;
+    float inv_scalar_scale = 1.f;
+    if (pd()->with_host_scale()) {
+        auto scalar_storage = utils::downcast<
+                const dnnl::impl::host_scalar_memory_storage_t *>(&scale);
+        auto status = scalar_storage->get_scalar_value(
+                &scalar_scale, scale_mdw.data_type_size());
+        assert(status == status::success);
+        if (status != status::success) return status;
+        scalar_scale = dnnl::impl::cpu::io::load_float_value(
+                pd()->desc()->scale_md()->data_type, &scalar_scale, 0);
+        inv_scalar_scale = 1. / scalar_scale;
+    }
+
+    /// preprocess kernel
+    // computes the D_i values from dst and diff_dst (dQ/dK/dV zeroing is
+    // done separately below)
+    compute::range_t lws = {(size_t)pd()->sg_size(), (size_t)sg_per_wg, 1};
+    compute::range_t gws_preprocess = lws;
+
+    gws_preprocess[0] *= utils::div_up(Q, wg_tile_q);
+    gws_preprocess[1] *= pd()->dst_md()->dims[1];
+    gws_preprocess[2] *= pd()->desc()->batch();
+
+    auto nd_range_preprocess = compute::nd_range_t(gws_preprocess, lws);
+
+    compute::kernel_arg_list_t preprocess_arg_list;
+    preprocess_arg_list.append(*Di_scratch);
+    preprocess_arg_list.append(dst);
+    preprocess_arg_list.append(diff_dst);
+    preprocess_arg_list.append((int)D);
+    preprocess_arg_list.append((int)K);
+    preprocess_arg_list.append((int)Q);
+
+    append_qry_offs(preprocess_arg_list, qry_off);
+    append_dst_offs(preprocess_arg_list, dst_off);
+
+    CHECK(parallel_for(
+            ctx, nd_range_preprocess, preprocess_, preprocess_arg_list));
+
+    auto *d = pd()->desc();
+    // zero f32 intermediates before atomic adds in the main kernel
+    // dQ always needs atomics, dK/dV only for GQA cases
+    {
+        auto compute_stream
+                = utils::downcast<compute::compute_stream_t *>(ctx.stream());
+        auto &fill_deps = compute_stream->ctx().get_deps();
+
+        const dim_t batch = pd()->dst_md()->dims[0];
+        const dim_t num_kv_heads = d->num_kv_heads();
+        const dim_t num_q_heads = d->num_q_heads();
+
+        auto zero_fill
+                = [&](const memory_storage_t &buf, size_t bytes) -> status_t {
+            return compute_stream->fill(buf, 0, bytes, fill_deps, fill_deps);
+        };
+
+        // always zero dQ
+        auto &dQ_buf = needs_intermediate_dQ ? *diff_q_scratch : diff_q;
+        const size_t dQ_bytes = needs_intermediate_dQ
+                ? size_t(batch * num_q_heads * Q * D) * sizeof(float)
+                : diff_qry_mdw.size();
+        CHECK(zero_fill(dQ_buf, dQ_bytes));
+
+        // zero dK/dV for GQA cases
+        if (needs_zero_dKV) {
+            auto &dK_buf = needs_intermediate_dKV ? *diff_k_scratch : diff_k;
+            auto &dV_buf = needs_intermediate_dKV ? *diff_v_scratch : diff_v;
+            const size_t scratch_kv_bytes
+                    = size_t(batch * num_kv_heads * K * D) * sizeof(float);
+            const size_t dK_bytes = needs_intermediate_dKV
+                    ? scratch_kv_bytes
+                    : diff_key_mdw.size();
+            const size_t dV_bytes = needs_intermediate_dKV
+                    ? scratch_kv_bytes
+                    : diff_val_mdw.size();
+            CHECK(zero_fill(dK_buf, dK_bytes));
+            CHECK(zero_fill(dV_buf, dV_bytes));
+        }
+    }
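+
+    // The main kernel is parallelized over K tiles (gws[0] below), so all
+    // workgroups of one head accumulate into the same dQ rows and dQ always
+    // needs atomic adds; dK/dV are only shared between workgroups when
+    // several Q heads map onto one KV head (kv_group_size > 1, i.e. GQA),
+    // hence the weaker needs_zero_dKV condition above.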
+
+    /// backwards pass kernel, calculates dK, dV, dQ(float)
+    compute::kernel_arg_list_t arg_list;
+    arg_list.append(key);
+    arg_list.append(qry);
+    arg_list.append(val);
+    arg_list.append(ws);
+    arg_list.append(*Di_scratch);
+    arg_list.append(dst);
+    arg_list.append(diff_dst);
+    if (with_dS) arg_list.append(CTX_OUT_STORAGE(DNNL_ARG_DS));
+    arg_list.append(needs_intermediate_dKV ? *diff_k_scratch : diff_k);
+    arg_list.append(needs_intermediate_dQ ? *diff_q_scratch : diff_q);
+    arg_list.append(needs_intermediate_dKV ? *diff_v_scratch : diff_v);
+    if (pd()->with_host_scale()) {
+        arg_list.append(scalar_scale);
+        arg_list.append(inv_scalar_scale);
+    } else {
+        arg_list.append(scale);
+    }
+    arg_list.append((int)D);
+    arg_list.append((int)K);
+    arg_list.append((int)Q);
+    arg_list.append(mask_type);
+    if (pd()->with_attn_mask()) arg_list.append(attn_mask);
+
+    append_key_offs(arg_list, key_off);
+    append_qry_offs(arg_list, qry_off);
+    append_val_offs(arg_list, val_off);
+    append_dst_offs(arg_list, dst_off);
+
+    if (pd()->with_attn_mask()) { append_msk_offs(arg_list, msk_off); }
+    const int remainder_k = (K % wg_tile_k) != 0;
+
+    const bool d_full = (d->head_size() == pd()->d_max());
+    const int remainder_q = d_full && ((Q % wg_tile_q) != 0);
+
+    arg_list.append(remainder_k);
+    arg_list.append(remainder_q);
+
+    compute::range_t gws = lws;
+
+    gws[0] *= utils::div_up(K, wg_tile_k);
+    gws[1] *= pd()->dst_md()->dims[1];
+    gws[2] *= pd()->desc()->batch();
+    auto nd_range = compute::nd_range_t(gws, lws);
+
+    CHECK(parallel_for(ctx, nd_range, kernel_, arg_list));
+
+    /// postprocessing kernels
+    // will cast dQ/dK/dV to lower precision outputs if needed
+    if (needs_intermediate_dQ) {
+        static constexpr size_t lws_pp = 256;
+        compute::range_t lws_p = {(size_t)lws_pp, 1, 1};
+        compute::range_t gws_p = lws_p;
+        gws_p[0] *= utils::div_up(Q * D, lws_pp);
+        gws_p[1] *= pd()->dst_md()->dims[1]; // Q heads
+        gws_p[2] *= pd()->desc()->batch();
+
+        compute::kernel_arg_list_t pp;
+        pp.append(diff_q);
+        pp.append(*diff_q_scratch);
+        pp.append((int)(Q * D));
+        append_qry_offs(pp, qry_off);
+        CHECK(parallel_for(
+                ctx, compute::nd_range_t(gws_p, lws_p), postprocess_, pp));
+    }
+
+    if (needs_intermediate_dKV) {
+        const dim_t num_kv_heads = d->num_kv_heads();
+        static constexpr size_t lws_pp = 256;
+        compute::range_t lws_p = {(size_t)lws_pp, 1, 1};
+
+        // dK
+        {
+            compute::range_t gws_p = lws_p;
+            gws_p[0] *= utils::div_up(K * D, lws_pp);
+            gws_p[1] *= num_kv_heads; // KV heads
+            gws_p[2] *= pd()->desc()->batch();
+
+            compute::kernel_arg_list_t pp;
+            pp.append(diff_k);
+            pp.append(*diff_k_scratch);
+            pp.append((int)(K * D));
+            append_qry_offs(pp, key_off);
+            CHECK(parallel_for(
+                    ctx, compute::nd_range_t(gws_p, lws_p), postprocess_, pp));
+        }
+        // dV
+        {
+            compute::range_t gws_p = lws_p;
+            gws_p[0] *= utils::div_up(K * D, lws_pp);
+            gws_p[1] *= num_kv_heads;
+            gws_p[2] *= pd()->desc()->batch();
+
+            compute::kernel_arg_list_t pp;
+            pp.append(diff_v);
+            pp.append(*diff_v_scratch);
+            pp.append((int)(K * D));
+            append_qry_offs(pp, val_off);
+            CHECK(parallel_for(
+                    ctx, compute::nd_range_t(gws_p, lws_p), postprocess_, pp));
+        }
+    }
+
+    return status::success;
+}
+
 } // namespace sdpa
 } // namespace intel
 } // namespace gpu
diff --git a/src/gpu/intel/sdpa/micro.hpp b/src/gpu/intel/sdpa/micro.hpp
index 33a14f27819..f109db67c32 100644
--- a/src/gpu/intel/sdpa/micro.hpp
+++ b/src/gpu/intel/sdpa/micro.hpp
@@ -36,11 +36,12 @@ namespace gpu {
 namespace intel {
 namespace sdpa {
 
-struct micro_params_t : trivially_serializable_t<micro_params_t> {
+struct micro_fwd_params_t : trivially_serializable_t<micro_fwd_params_t> {
 
     const std::vector<const char *> &get_kernel_names() const {
-        static const std::vector<const char *> kernel_names = {"micro_sdpa"};
-        return kernel_names;
+        static const std::vector<const char *> kernel_names_fwd
+                = {"micro_sdpa"};
+        return kernel_names_fwd;
     }
 
     status_t create_generator(const intel::engine_t &engine,
@@ -91,34 +92,82 @@ struct micro_params_t : trivially_serializable_t<micro_params_t> {
     bool use_systolic_ukernel;
     bool kq_f16_accumulate, vs_f16_accumulate;
     bool require_stateless_addressing;
-    uint8_t padding3[6] = {0};
+    bool is_training;
+    uint8_t padding3[5] = {0};
 
-    micro_ukernel_params_t ukernel_config;
+    micro_fwd_ukernel_params_t ukernel_config;
 };
-DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(micro_params_t);
+DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(micro_fwd_params_t);
 
-struct micro_t : public primitive_t {
+struct micro_bwd_params_t : trivially_serializable_t<micro_bwd_params_t> {
+
+    const std::vector<const char *> &get_kernel_names() const {
+        static const std::vector<const char *> kernel_names_bwd
+                = {"preprocess_Di", "micro_sdpa_bwd", "postprocess_dQ"};
+        return kernel_names_bwd;
+    }
+
+    status_t create_generator(const intel::engine_t &engine,
+            compute::kernel_bundle_t &bundle) const {
+        compute::kernel_ctx_t kernel_ctx;
+        CHECK(get_kernel_ctx(kernel_ctx));
+        auto status = engine.create_kernel_bundle(
+                bundle, get_kernel_names(), kernel_ctx);
+        return status;
+    }
+
+    status_t get_kernel_ctx(compute::kernel_ctx_t &) const;
+
+    int ndims;
+    int kv_group_size;
+    data_type_t data_t;
+    data_type_t dst_data_t, key_data_t, qry_data_t, val_data_t, msk_data_t;
+
+    int q_align, k_align, v_align, a_align;
+    bool transpose_k;
+    uint8_t padding0[3] = {0};
+
+    int key_group_size, val_group_size;
+    data_type_t scale_data_t;
+
+    int attn_mask_undef, attn_mask_buffer, attn_mask_top_left,
+            attn_mask_bottom_right;
+    bool invert_scale, with_attn_scale, with_host_scale, with_attn_mask,
+            broadcast_mask_q, with_causal_mask;
+    uint8_t padding1[2] = {0};
+    int subgroup_size, d_max;
+
+    bool d_full, arch_gte_hpc;
+    bool block_k, block_dK, block_dV;
+    bool remainder_q;
+    bool use_systolic_ukernel;
+    bool with_dS;
+    bool require_stateless_addressing;
+    uint8_t padding2[7] = {0};
+
+    micro_bwd_ukernel_params_t ukernel_config;
+};
+DNNL_ASSERT_TRIVIALLY_SERIALIZABLE(micro_bwd_params_t);
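+
+// Note: the explicit paddingN arrays keep these params structs free of
+// implicit compiler padding, so their raw bytes serialize and hash
+// deterministically (see trivially_serializable_t); observe how is_training
+// above consumed one byte of the fwd struct's former padding3[6].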
+
+struct micro_fwd_t : public primitive_t {
     using primitive_t::primitive_t;
-    struct pd_t : public sdpa::pd_t {
-        using sdpa::pd_t::pd_t;
-        static constexpr int mask_mb_index = 0;
-        static constexpr int mask_q_index = 2;
-        static constexpr int mask_k_index = 3;
-        static constexpr int ndims = 4;
+    struct pd_t : public sdpa_fwd_pd_t {
+        using sdpa_fwd_pd_t::sdpa_fwd_pd_t;
 
-        DECLARE_COMMON_PD_T("ocl:micro:reusable", micro_t);
+        DECLARE_COMMON_PD_T("ocl:micro:reusable", micro_fwd_t);
 
         status_t init(impl::engine_t *engine) {
             using namespace data_type;
 
-            VCHECK_SDPA_COND(
-                    utils::everyone_is(4, qry_md()->ndims, key_md()->ndims,
-                            val_md()->ndims, dst_md()->ndims),
+            VCHECK_SDPA_COND(is_fwd(), VERBOSE_BAD_PROPKIND);
+            VCHECK_SDPA_COND(utils::everyone_is(4, desc()->qry_md()->ndims,
+                                     desc()->key_md()->ndims,
+                                     desc()->val_md()->ndims, dst_md()->ndims),
                     VERBOSE_UNSUPPORTED_TAG);
 
-            memory_desc_wrapper qry_mdw(qry_md());
-            memory_desc_wrapper key_mdw(key_md());
-            memory_desc_wrapper val_mdw(val_md());
+            memory_desc_wrapper qry_mdw(desc()->qry_md());
+            memory_desc_wrapper key_mdw(desc()->key_md());
+            memory_desc_wrapper val_mdw(desc()->val_md());
             memory_desc_wrapper dst_mdw(dst_md());
             VCHECK_SDPA_COND(utils::everyone_is(true, qry_mdw.is_plain(),
                                      key_mdw.is_plain(), val_mdw.is_plain(),
@@ -126,72 +175,77 @@ struct micro_t : public primitive_t {
                    VERBOSE_UNSUPPORTED_TAG);
 
             if (with_attn_mask()) {
+                VCHECK_SDPA_COND(desc()->attn_mask_md()->ndims == 4,
+                        VERBOSE_UNSUPPORTED_TAG);
                 VCHECK_SDPA_COND(
-                        attn_mask_md()->ndims == 4, VERBOSE_UNSUPPORTED_TAG);
-                VCHECK_SDPA_COND(
-                        utils::one_of(attn_mask_md()->dims[mask_q_index],
+                        utils::one_of(
+                                desc()->attn_mask_md()->dims[mask_q_index],
                                desc()->queries(), 1),
                         VERBOSE_INVALID_BROADCAST, "attn_mask", mask_q_index);
-                VCHECK_SDPA_COND(
-                        attn_mask_md()->dims[mask_k_index] == desc()->keys(),
+                VCHECK_SDPA_COND(desc()->attn_mask_md()->dims[mask_k_index]
+                                == desc()->keys(),
                        VERBOSE_INVALID_BROADCAST, "attn_mask", mask_k_index);
-                if (qry_md()->data_type == data_type::f32) {
-                    VCHECK_SDPA_COND(
-                            attn_mask_md()->data_type == qry_md()->data_type,
+                if (desc()->qry_md()->data_type == data_type::f32) {
+                    VCHECK_SDPA_COND(desc()->attn_mask_md()->data_type
+                                    == desc()->qry_md()->data_type,
                             "Mask data type(%s) should match Qry/Dst data "
                             "type(%s).",
-                            dnnl_dt2str(attn_mask_md()->data_type),
-                            dnnl_dt2str(qry_md()->data_type));
+                            dnnl_dt2str(desc()->attn_mask_md()->data_type),
+                            dnnl_dt2str(desc()->qry_md()->data_type));
                 } else {
-                    VCHECK_SDPA_COND(
-                            (attn_mask_md()->data_type == qry_md()->data_type)
-                                    || (attn_mask_md()->data_type
+                    VCHECK_SDPA_COND((desc()->attn_mask_md()->data_type
+                                             == desc()->qry_md()->data_type)
+                                    || (desc()->attn_mask_md()->data_type
                                             == data_type::f32),
                             "Mask data type(%s) should be xf16 or f32 when "
                             "Qry/Dst(%s) is xf16.",
-                            dnnl_dt2str(attn_mask_md()->data_type),
-                            dnnl_dt2str(qry_md()->data_type));
+                            dnnl_dt2str(desc()->attn_mask_md()->data_type),
+                            dnnl_dt2str(desc()->qry_md()->data_type));
                 }
             }
 
             VCHECK_SDPA_COND(
-                    (utils::everyone_is(data_type::f16, qry_md()->data_type,
-                             dst_md()->data_type)
+                    (utils::everyone_is(data_type::f16,
+                             desc()->qry_md()->data_type, dst_md()->data_type)
                             || utils::everyone_is(data_type::bf16,
-                                    qry_md()->data_type, dst_md()->data_type)
+                                    desc()->qry_md()->data_type,
+                                    dst_md()->data_type)
                            || utils::everyone_is(data_type::f32,
-                                    qry_md()->data_type, dst_md()->data_type)),
+                                    desc()->qry_md()->data_type,
+                                    dst_md()->data_type)),
                     VERBOSE_UNSUPPORTED_DT);
-            VCHECK_SDPA_COND(utils::one_of(key_md()->data_type, f32, bf16, f16,
-                                     u8, s8, u4, s4),
+            VCHECK_SDPA_COND(utils::one_of(desc()->key_md()->data_type, f32,
+                                     bf16, f16, u8, s8, u4, s4),
                     VERBOSE_UNSUPPORTED_DT);
-            VCHECK_SDPA_COND(utils::one_of(val_md()->data_type, f32, bf16, f16,
-                                     u8, s8, u4, s4),
+            VCHECK_SDPA_COND(utils::one_of(desc()->val_md()->data_type, f32,
+                                     bf16, f16, u8, s8, u4, s4),
                     VERBOSE_UNSUPPORTED_DT);
             VCHECK_SDPA_COND(set_default_formats() == status::success,
                     VERBOSE_UNSUPPORTED_TAG);
             VCHECK_SDPA_COND(desc()->values() == desc()->head_size(),
                     "values does not match head size");
-            if (utils::one_of(key_md()->data_type, u4, s4)) {
+            if (utils::one_of(desc()->key_md()->data_type, u4, s4)) {
                 VCHECK_SDPA_COND(desc()->keys() % 2 == 0,
                        "The number of keys must be even when the data type "
                        "is u4 or s4.");
             }
-            if (utils::one_of(val_md()->data_type, u4, s4)) {
+            if (utils::one_of(desc()->val_md()->data_type, u4, s4)) {
                 VCHECK_SDPA_COND(desc()->values() % 2 == 0,
                        "The number of values must be even when the data type "
                        "is u4 or s4.");
             }
-            VCHECK_SDPA_COND(qry_md()->dims[1] >= key_md()->dims[1]
-                            && qry_md()->dims[1] >= val_md()->dims[1],
+            VCHECK_SDPA_COND(
+                    desc()->qry_md()->dims[1] >= desc()->key_md()->dims[1]
+                            && desc()->qry_md()->dims[1]
+                                    >= desc()->val_md()->dims[1],
                    "number of heads in query tensor(%ld) must be greater "
                    "than or equal to the number of heads in the key(%ld) "
                    "and value(%ld) tensors",
-                    static_cast<long>(qry_md()->dims[1]),
-                    static_cast<long>(key_md()->dims[1]),
-                    static_cast<long>(val_md()->dims[1]));
+                    static_cast<long>(desc()->qry_md()->dims[1]),
+                    static_cast<long>(desc()->key_md()->dims[1]),
+                    static_cast<long>(desc()->val_md()->dims[1]));
 
             VCHECK_SDPA_COND(utils::one_of(kq_acc_dt(), f16, f32),
                    "KQ accumulation data type should be f16 or f32");
@@ -249,32 +303,184 @@ struct micro_t : public primitive_t {
            int vgs = value_group_size();
             VCHECK_SDPA_COND(utils::one_of(vs_scales_mask, 0, 1, 3)
                             || (math::is_pow2(vgs)
-                                    || vgs == val_md()->dims[3]),
+                                    || vgs == desc()->val_md()->dims[3]),
                    "the value group size(%d) must be a power of 2 or "
                    "equal to the number of values(%ld).",
-                    vgs, static_cast<long>(val_md()->dims[3]));
+                    vgs, static_cast<long>(desc()->val_md()->dims[3]));
             }
 
             CHECK(init_conf_microkernels(engine));
             CHECK(init_conf(engine));
-
-            VCHECK_SDPA_COND(
-                    IMPLICATION((arch() == compute::gpu_arch_t::xe_hpc)
-                                    && (qry_md()->data_type == data_type::f32),
-                            with_causal_mask()),
+            VCHECK_SDPA_COND(IMPLICATION((arch() == compute::gpu_arch_t::xe_hpc)
+                                    && (desc()->qry_md()->data_type
+                                            == data_type::f32),
+                                    with_causal_mask()),
                    "fused f32 SDPA only optimized for causal mask");
            //TODO: update when performance improved
 
             return status::success;
         }
 
+        status_t set_default_formats() {
+            CHECK(set_default_format(desc_.q_desc, false));
+            CHECK(set_default_format(desc_.k_desc, true));
+            CHECK(set_default_format(desc_.v_desc, false));
+            CHECK(set_default_format(desc_.dst_desc, false));
+            return status::success;
+        }
+
+        int sg_size() const { return sg_size_; }
+        bool use_systolic_ukernel() const { return use_systolic_ukernel_; }
+
+        // Block size for head_size, which must be hard-coded into the kernel.
+        int d_max() const {
+            int head_size = into<int>(desc()->head_size());
+            for (int i = 32; i <= 1024; i *= 2)
+                if (head_size <= i) return i;
+            return head_size;
+        }
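+        // e.g. head_size = 80 is rounded up to d_max = 128; head sizes
+        // above 1024 are returned unchanged.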
+
+        compute::gpu_arch_t arch() const { return arch_; }
+        micro_fwd_params_t conf;
+
+    private:
+        int sg_size_ = 0;
+        bool use_systolic_ukernel_ = true;
+        compute::gpu_arch_t arch_ = compute::gpu_arch_t::unknown;
+
+        status_t init_conf_microkernels(impl::engine_t *engine);
+        status_t init_conf(impl::engine_t *engine);
+
         status_t set_default_format(memory_desc_t &md, bool allow_transpose) {
             using namespace format_tag;
             memory_desc_wrapper mdw(md);
-            if (mdw.format_any()) return status::unimplemented;
-            if (!is_md_gemm_compatible_plain_format(&md))
-                return status::unimplemented;
-            if (gemm_desc_t::get_trans(md) == dnnl_trans && !allow_transpose)
-                return status::unimplemented;
+            VCHECK_SDPA_UNIMPL(!mdw.format_any(), VERBOSE_UNSUPPORTED_TAG);
+            VCHECK_SDPA_UNIMPL(is_md_gemm_compatible_plain_format(&md),
+                    VERBOSE_UNSUPPORTED_TAG);
+            VCHECK_SDPA_UNIMPL(
+                    IMPLICATION(gemm_desc_t::get_trans(md) == dnnl_trans,
+                            allow_transpose),
+                    VERBOSE_UNSUPPORTED_TAG);
             return status::success;
         }
+    };
+
+    status_t init(impl::engine_t *engine) override;
+
+    status_t execute(const exec_ctx_t &ctx) const override {
+        return execute_forward(ctx);
+    }
+
+private:
+    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
+    status_t execute_forward(const exec_ctx_t &ctx) const;
+
+    compute::kernel_t kernel_;
+};
+
+struct micro_bwd_t : public primitive_t {
+
+    using primitive_t::primitive_t;
+    struct pd_t : public sdpa_bwd_pd_t {
+        using sdpa_bwd_pd_t::sdpa_bwd_pd_t;
+
+        DECLARE_COMMON_PD_T("ocl:micro:reusable", micro_bwd_t);
+
+        status_t init(impl::engine_t *engine) {
+            using namespace data_type;
+
+            VCHECK_SDPA_COND(!is_fwd(), VERBOSE_BAD_PROPKIND);
+
+            VCHECK_SDPA_COND(utils::everyone_is(4, desc()->qry_md()->ndims,
+                                     desc()->key_md()->ndims,
+                                     desc()->val_md()->ndims, dst_md()->ndims),
+                    VERBOSE_UNSUPPORTED_TAG);
+            VCHECK_SDPA_COND(
+                    utils::everyone_is(4, desc()->diff_qry_md()->ndims,
+                            desc()->diff_key_md()->ndims,
+                            desc()->diff_val_md()->ndims, diff_dst_md()->ndims),
+                    VERBOSE_UNSUPPORTED_TAG);
+            if (with_attn_mask()) {
+                VCHECK_SDPA_COND(desc()->attn_mask_md()->ndims == 4,
+                        VERBOSE_UNSUPPORTED_TAG);
+                VCHECK_SDPA_COND(
+                        utils::one_of(
+                                desc()->attn_mask_md()->dims[mask_q_index],
+                                desc()->queries(), 1),
+                        VERBOSE_INVALID_BROADCAST, "attn_mask", mask_q_index);
+                VCHECK_SDPA_COND(desc()->attn_mask_md()->dims[mask_k_index]
+                                == desc()->keys(),
+                        VERBOSE_INVALID_BROADCAST, "attn_mask", mask_k_index);
+                VCHECK_SDPA_COND(desc()->attn_mask_md()->data_type
+                                == desc()->qry_md()->data_type,
+                        "Mask data type should match Qry/Dst data type.");
+            }
+            VCHECK_SDPA_COND(
+                    (utils::everyone_is(data_type::f16,
+                             desc()->qry_md()->data_type, dst_md()->data_type)
+                            || utils::everyone_is(data_type::bf16,
+                                    desc()->qry_md()->data_type,
+                                    dst_md()->data_type)
+                            || utils::everyone_is(data_type::f32,
+                                    desc()->qry_md()->data_type,
+                                    dst_md()->data_type)),
+                    VERBOSE_UNSUPPORTED_DT);
+            VCHECK_SDPA_COND(
+                    utils::one_of(desc()->key_md()->data_type, f32, bf16, f16),
+                    VERBOSE_UNSUPPORTED_DT);
+            VCHECK_SDPA_COND(
+                    utils::one_of(desc()->val_md()->data_type, f32, bf16, f16),
+                    VERBOSE_UNSUPPORTED_DT);
+            VCHECK_SDPA_COND(set_default_formats() == status::success,
+                    VERBOSE_UNSUPPORTED_TAG);
+            VCHECK_SDPA_COND(desc()->values() == desc()->head_size(),
+                    "values does not match head size");
+
+            VCHECK_SDPA_COND(
+                    desc()->qry_md()->dims[1] >= desc()->key_md()->dims[1]
+                            && desc()->qry_md()->dims[1]
+                                    >= desc()->val_md()->dims[1],
+                    "number of heads in query tensor(%ld) must be greater "
+                    "than or equal to the number of heads in the key(%ld) "
+                    "and value(%ld) tensors",
+                    static_cast<long>(desc()->qry_md()->dims[1]),
+                    static_cast<long>(desc()->key_md()->dims[1]),
+                    static_cast<long>(desc()->val_md()->dims[1]));
+            {
+                memory_desc_wrapper diff_qry_mdw(desc()->diff_qry_md());
+                memory_desc_wrapper diff_key_mdw(desc()->diff_key_md());
+                memory_desc_wrapper diff_val_mdw(desc()->diff_val_md());
+                memory_desc_wrapper diff_dst_mdw(diff_dst_md());
+                VCHECK_SDPA_COND(
+                        utils::everyone_is(true, diff_qry_mdw.is_plain(),
+                                diff_key_mdw.is_plain(),
+                                diff_val_mdw.is_plain(),
+                                diff_dst_mdw.is_plain()),
+                        VERBOSE_UNSUPPORTED_TAG);
+            }
+
+            VCHECK_SDPA_COND(utils::everyone_is(desc()->qry_md()->data_type,
+                                     desc()->diff_qry_md()->data_type,
+                                     desc()->diff_key_md()->data_type,
+                                     desc()->diff_val_md()->data_type,
+                                     diff_dst_md()->data_type),
+                    "diff tensor data types must match qry data type(%s) "
+                    " ?= dQ(%s), dK(%s), dV(%s), dO(%s)",
+                    dnnl_dt2str(desc()->qry_md()->data_type),
+                    dnnl_dt2str(desc()->diff_qry_md()->data_type),
+                    dnnl_dt2str(desc()->diff_key_md()->data_type),
+                    dnnl_dt2str(desc()->diff_val_md()->data_type),
+                    dnnl_dt2str(diff_dst_md()->data_type));
+
+            CHECK(init_default_ws());
+            VCHECK_SDPA_COND(compare_ws(hint_fwd_pd_), VERBOSE_WS_MISMATCH);
+
+            VCHECK_SDPA_COND(arch() != compute::gpu_arch_t::xe_hpg,
+                    "fused SDPA BWD not supported for xe_hpg");
+
+            CHECK(init_conf_microkernels(engine));
+            CHECK(init_conf(engine));
+            CHECK(init_scratchpad(engine));
+
+            return status::success;
         }
 
@@ -283,6 +489,10 @@ struct micro_t : public primitive_t {
             CHECK(set_default_format(desc_.k_desc, true));
             CHECK(set_default_format(desc_.v_desc, false));
             CHECK(set_default_format(desc_.dst_desc, false));
+            CHECK(set_default_format(desc_.diff_dst_desc, false));
+            CHECK(set_default_format(desc_.diff_q_desc, false));
+            CHECK(set_default_format(desc_.diff_k_desc, true));
+            CHECK(set_default_format(desc_.diff_v_desc, false));
             return status::success;
         }
 
@@ -298,24 +508,42 @@ struct micro_t : public primitive_t {
         }
 
         compute::gpu_arch_t arch() const { return arch_; }
-        micro_params_t conf;
+        micro_bwd_params_t conf;
 
     private:
        int sg_size_ = 0;
        bool
use_systolic_ukernel_ = true; compute::gpu_arch_t arch_ = compute::gpu_arch_t::unknown; + status_t init_scratchpad(impl::engine_t *engine); status_t init_conf_microkernels(impl::engine_t *engine); status_t init_conf(impl::engine_t *engine); + + status_t set_default_format(memory_desc_t &md, bool allow_transpose) { + using namespace format_tag; + memory_desc_wrapper mdw(md); + VCHECK_SDPA_UNIMPL(!mdw.format_any(), VERBOSE_UNSUPPORTED_TAG); + VCHECK_SDPA_UNIMPL(is_md_gemm_compatible_plain_format(&md), + VERBOSE_UNSUPPORTED_TAG); + VCHECK_SDPA_UNIMPL( + IMPLICATION(gemm_desc_t::get_trans(md) == dnnl_trans, + allow_transpose), + VERBOSE_UNSUPPORTED_TAG); + return status::success; + } }; status_t init(impl::engine_t *engine) override; + status_t execute(const exec_ctx_t &ctx) const override { + return execute_backward(ctx); + } + private: const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } - status_t execute(const exec_ctx_t &ctx) const override; + status_t execute_backward(const exec_ctx_t &ctx) const; - compute::kernel_t kernel_; + compute::kernel_t kernel_, preprocess_, postprocess_; }; } // namespace sdpa diff --git a/src/gpu/intel/sdpa/micro_bwd.cl b/src/gpu/intel/sdpa/micro_bwd.cl new file mode 100644 index 00000000000..be00b005edd --- /dev/null +++ b/src/gpu/intel/sdpa/micro_bwd.cl @@ -0,0 +1,989 @@ +/******************************************************************************* +* Copyright 2026 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "gpu/intel/include/tile_ops.h" +#include "gpu/intel/include/types_interop.h" +#include "gpu/intel/sdpa/utils.h" + +/* Microkernel headers -- generated at runtime */ +#include "gemm_kq.h" +#include "gemm_ktq.h" +#include "gemm_qdSt.h" +#include "gemm_vs.h" +#include "gemm_vtdA.h" + +#define MAX(a, b) ((a) > (b) ? 
(a) : (b)) +#define DIV_UP(x, y) (((x) + (y) - 1) / (y)) + +#define sg_per_wg_BcBr \ + (ugemm_kq_sg_per_wg_m * ugemm_kq_sg_per_wg_n) // same for kq, vtdA +#define sg_per_wg_BcD \ + (ugemm_vs_sg_per_wg_m * ugemm_vs_sg_per_wg_n) // same for qdSt and vs +#define sg_per_wg_BrD (ugemm_ktq_sg_per_wg_m * ugemm_ktq_sg_per_wg_n) +#define sg_per_wg MAX(sg_per_wg_BcBr, MAX(sg_per_wg_BcD, sg_per_wg_BrD)) + +#define q_tile_sg_n DIV_UP(ugemm_kq_wg_tile_n, sg_per_wg) + +/* Instantiate tile types and operations */ +typedef ugemm_kq_c_type s_tile_type; // Bc*Br tile +typedef ugemm_qdSt_c_type a_tile_type; // Bc*D tile +typedef ugemm_vtdA_c_type p_tile_type; // Br*Bc tile (.T) +typedef ugemm_vs_c_type dv_tile_type; // D*Bc tile +typedef ugemm_ktq_c_type ktq_tile_type; // D*Br tile + +#ifdef QRY_DT_F32 +#define FMA_TYPE float +#elif QRY_DT_F16 +#define VEC_TYPE2 half2 +#define FMA_TYPE half +#elif defined(QRY_DT_BF16) +#define VEC_TYPE2 ushort2 +#define FMA_TYPE ushort +#else +#error "Data type not supported for VEC_TYPE2" +#endif + +#ifdef SCALE_DT_BF16 +#define SCALES_TO_FLOAT cvt_bf16_to_f32 +#else +#define SCALES_TO_FLOAT convert_float +#endif + +DECLARE_2D_TILE(q_tile_type, FMA_TYPE, SUBGROUP_SIZE, D_MAX, 1, 1, q_tile_sg_n) + +DECLARE_2D_TILE(dq_tile_type, float, SUBGROUP_SIZE, D_MAX, 1, 1, q_tile_sg_n) +DECLARE_2D_TILE_BLOCK_OPS( + dq_tile_type, float, SUBGROUP_SIZE, D_MAX, 1, 1, q_tile_sg_n) +DECLARE_2D_TILE_COPY_REBLOCK(q_tile_type, SUBGROUP_SIZE, D_MAX, 1, 1, + q_tile_sg_n, dq_tile_type, SUBGROUP_SIZE, D_MAX, 1, 1, q_tile_sg_n, + CONVERT_FLOAT_T) + +#if TRANSPOSE_K + +#define k_tile_t_sg_n DIV_UP(ugemm_kq_wg_tile_m, sg_per_wg) +DECLARE_2D_TILE( + k_tile_type, FMA_TYPE, SUBGROUP_SIZE, D_MAX, 1, 1, k_tile_t_sg_n) +#if BLOCK_K +DECLARE_2D_TILE_BLOCK_OPS( + k_tile_type, FMA_TYPE, SUBGROUP_SIZE, D_MAX, 1, 1, k_tile_t_sg_n) +#endif + +#else + +#define dmax_tile_sg_n DIV_UP(D_MAX, sg_per_wg) +DECLARE_2D_TILE(k_tile_type, FMA_TYPE, SUBGROUP_SIZE, ugemm_kq_wg_tile_m, 1, 1, + dmax_tile_sg_n) +#if BLOCK_K +DECLARE_2D_TILE_BLOCK_OPS(k_tile_type, FMA_TYPE, SUBGROUP_SIZE, + ugemm_kq_wg_tile_m, 1, 1, dmax_tile_sg_n) +#endif +#endif + +DECLARE_2D_TILE(s_tile_type_packed, uint, SUBGROUP_SIZE, ugemm_kq_c_type_block0, + ugemm_kq_c_type_block1 / 2, ugemm_kq_c_type_nblock0, + ugemm_kq_c_type_nblock1) +DECLARE_2D_TILE(s_tile_type_packed_t, uint, SUBGROUP_SIZE, + ugemm_kq_c_type_block1, ugemm_kq_c_type_block0 / 2, + ugemm_kq_c_type_nblock1, ugemm_kq_c_type_nblock0) + +DECLARE_2D_TILE(s_tile_type_reblock, FMA_TYPE, SUBGROUP_SIZE, + ugemm_kq_sg_tile_m, 1, 1, ugemm_kq_sg_tile_n) +DECLARE_2D_TILE_BLOCK_OPS(s_tile_type_reblock, FMA_TYPE, SUBGROUP_SIZE, + ugemm_kq_sg_tile_m, 1, 1, ugemm_kq_sg_tile_n) + +DECLARE_2D_TILE(p_tile_type_reblock, FMA_TYPE, SUBGROUP_SIZE, + ugemm_vtdA_c_type_block0, 1, ugemm_vtdA_c_type_nblock0, + ugemm_vtdA_c_type_block1 *ugemm_vtdA_c_type_nblock1) +DECLARE_2D_TILE_BLOCK_OPS(p_tile_type_reblock, FMA_TYPE, SUBGROUP_SIZE, + ugemm_vtdA_c_type_block0, 1, ugemm_vtdA_c_type_nblock0, + ugemm_vtdA_c_type_block1 *ugemm_vtdA_c_type_nblock1) + +DECLARE_2D_TILE( + s_sum_tile_type, float, SUBGROUP_SIZE, ugemm_kq_sg_tile_n, 1, 1, 1) +DECLARE_2D_TILE( + p_sum_tile_type, float, SUBGROUP_SIZE, ugemm_vtdA_sg_tile_n, 1, 1, 1) + +#if BROADCAST_MASK_Q +#define mask_br ugemm_kq_sg_tile_m +#define mask_bc 1 +#define mask_nbr 1 +#define mask_nbc 1 +#else +#define mask_br ugemm_kq_c_type_block0 +#define mask_bc ugemm_kq_c_type_block1 +#define mask_nbr ugemm_kq_c_type_nblock0 +#define mask_nbc ugemm_kq_c_type_nblock1 +#endif + 
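+/* With BROADCAST_MASK_Q the mask is shared by every query column of an S
+   tile, so the mask tile collapses to a single column per subgroup row tile;
+   otherwise it mirrors the kq C-tile blocking so it can be applied to S
+   elementwise. */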
+DECLARE_2D_TILE(qmask_tile_type_float, float, SUBGROUP_SIZE, ugemm_kq_sg_tile_n, + 1, 1, 1) +DECLARE_2D_TILE(kmask_tile_type_float, float, SUBGROUP_SIZE, ugemm_kq_sg_tile_m, + 1, 1, 1) + +#if WITH_ATTN_MASK +DECLARE_2D_TILE(mask_tile_type, MSK_DATA_T, SUBGROUP_SIZE, mask_br, mask_bc, + mask_nbr, mask_nbc) + +#if BROADCAST_MASK_Q +DECLARE_2D_TILE_BLOCK_OPS(mask_tile_type, MSK_DATA_T, SUBGROUP_SIZE, mask_br, + mask_bc, mask_nbr, mask_nbc) +#endif +DECLARE_2D_TILE(mask_tile_type_float, float, SUBGROUP_SIZE, mask_br, mask_bc, + mask_nbr, mask_nbc) +DECLARE_2D_TILE_COPY_REBLOCK(mask_tile_type, SUBGROUP_SIZE, mask_br, mask_bc, + mask_nbr, mask_nbc, mask_tile_type_float, SUBGROUP_SIZE, mask_br, + mask_bc, mask_nbr, mask_nbc, CONVERT_FLOAT_T) +#endif + +DECLARE_2D_TILE(a_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE, + ugemm_qdSt_sg_tile_m, 1, 1, ugemm_qdSt_sg_tile_n) +DECLARE_2D_TILE_BLOCK_OPS(a_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE, + ugemm_qdSt_sg_tile_m, 1, 1, ugemm_qdSt_sg_tile_n) +DECLARE_2D_TILE_COPY_REBLOCK(a_tile_type, SUBGROUP_SIZE, + ugemm_qdSt_c_type_block0, ugemm_qdSt_c_type_block1, + ugemm_qdSt_c_type_nblock0, ugemm_qdSt_c_type_nblock1, a_tile_type_dst, + SUBGROUP_SIZE, ugemm_qdSt_sg_tile_m, 1, 1, ugemm_qdSt_sg_tile_n, + CONVERT_DATA_T) + +DECLARE_2D_TILE(dv_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE, ugemm_vs_sg_tile_m, + 1, 1, ugemm_vs_sg_tile_n) +DECLARE_2D_TILE_BLOCK_OPS(dv_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE, + ugemm_vs_sg_tile_m, 1, 1, ugemm_vs_sg_tile_n) +DECLARE_2D_TILE_COPY_REBLOCK(dv_tile_type, SUBGROUP_SIZE, + ugemm_vs_c_type_block0, ugemm_vs_c_type_block1, ugemm_vs_c_type_nblock0, + ugemm_vs_c_type_nblock1, dv_tile_type_dst, SUBGROUP_SIZE, + ugemm_vs_sg_tile_m, 1, 1, ugemm_vs_sg_tile_n, CONVERT_DATA_T) + +DECLARE_2D_TILE(dq_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE, + ugemm_ktq_sg_tile_m, 1, 1, ugemm_ktq_sg_tile_n) +DECLARE_2D_TILE_BLOCK_OPS(dq_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE, + ugemm_ktq_sg_tile_m, 1, 1, ugemm_ktq_sg_tile_n) +DECLARE_2D_TILE_COPY_REBLOCK(ktq_tile_type, SUBGROUP_SIZE, + ugemm_ktq_c_type_block0, ugemm_ktq_c_type_block1, + ugemm_ktq_c_type_nblock0, ugemm_ktq_c_type_nblock1, dq_tile_type_dst, + SUBGROUP_SIZE, ugemm_ktq_sg_tile_m, 1, 1, ugemm_ktq_sg_tile_n, + CONVERT_DATA_T) + +DECLARE_2D_TILE_COPY_REBLOCK(s_tile_type, SUBGROUP_SIZE, ugemm_kq_c_type_block0, + ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0, + ugemm_kq_c_type_nblock1, s_tile_type_reblock, SUBGROUP_SIZE, + ugemm_kq_sg_tile_m, 1, 1, ugemm_kq_sg_tile_n, CONVERT_DATA_T) +DECLARE_2D_TILE_COPY_REBLOCK(p_tile_type, SUBGROUP_SIZE, + ugemm_vtdA_c_type_block0, ugemm_vtdA_c_type_block1, + ugemm_vtdA_c_type_nblock0, ugemm_vtdA_c_type_nblock1, + p_tile_type_reblock, SUBGROUP_SIZE, ugemm_vtdA_c_type_block0, 1, + ugemm_vtdA_c_type_nblock0, + ugemm_vtdA_c_type_block1 *ugemm_vtdA_c_type_nblock1, CONVERT_DATA_T) +DECLARE_2D_TILE_COPY_REBLOCK(p_tile_type_reblock, SUBGROUP_SIZE, + ugemm_vtdA_c_type_block0, 1, ugemm_vtdA_c_type_nblock0, + ugemm_vtdA_c_type_block1 *ugemm_vtdA_c_type_nblock1, p_tile_type, + SUBGROUP_SIZE, ugemm_vtdA_c_type_block0, ugemm_vtdA_c_type_block1, + ugemm_vtdA_c_type_nblock0, ugemm_vtdA_c_type_nblock1, CONVERT_FLOAT_T) + +DECLARE_2D_TILE_VREDUCE(s_tile_type, SUBGROUP_SIZE, ugemm_kq_c_type_block0, + ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0, + ugemm_kq_c_type_nblock1, s_sum_tile_type, SUBGROUP_SIZE, + ugemm_kq_sg_tile_n, 1, 1, 1) +DECLARE_2D_TILE_VREDUCE(s_tile_type, SUBGROUP_SIZE, ugemm_kq_c_type_block0, + ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0, + 
ugemm_kq_c_type_nblock1, qmask_tile_type_float, SUBGROUP_SIZE, + ugemm_kq_sg_tile_n, 1, 1, 1) +DECLARE_2D_TILE_HREDUCE(s_tile_type, SUBGROUP_SIZE, ugemm_kq_c_type_block0, + ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0, + ugemm_kq_c_type_nblock1, qmask_tile_type_float, SUBGROUP_SIZE, + ugemm_kq_sg_tile_n, 1, 1, 1) +DECLARE_2D_TILE_HREDUCE(p_tile_type, SUBGROUP_SIZE, ugemm_vtdA_c_type_block0, + ugemm_vtdA_c_type_block1, ugemm_vtdA_c_type_nblock0, + ugemm_vtdA_c_type_nblock1, kmask_tile_type_float, SUBGROUP_SIZE, + ugemm_vtdA_sg_tile_m, 1, 1, 1) +DECLARE_2D_TILE_VREDUCE(p_tile_type, SUBGROUP_SIZE, ugemm_vtdA_c_type_block0, + ugemm_vtdA_c_type_block1, ugemm_vtdA_c_type_nblock0, + ugemm_vtdA_c_type_nblock1, kmask_tile_type_float, SUBGROUP_SIZE, + ugemm_vtdA_sg_tile_m, 1, 1, 1) + +DECLARE_2D_TILE_HREDUCE(p_tile_type, SUBGROUP_SIZE, ugemm_vtdA_c_type_block0, + ugemm_vtdA_c_type_block1, ugemm_vtdA_c_type_nblock0, + ugemm_vtdA_c_type_nblock1, p_sum_tile_type, SUBGROUP_SIZE, + ugemm_vtdA_sg_tile_n, 1, 1, 1) + +DECLARE_2D_TILE_HREDUCE(s_tile_type, SUBGROUP_SIZE, ugemm_kq_c_type_block0, + ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0, + ugemm_kq_c_type_nblock1, s_sum_tile_type, SUBGROUP_SIZE, + ugemm_kq_sg_tile_n, 1, 1, 1) +#if WITH_ATTN_MASK +DECLARE_2D_TILE_VREDUCE(s_tile_type, SUBGROUP_SIZE, ugemm_kq_c_type_block0, + ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0, + ugemm_kq_c_type_nblock1, mask_tile_type_float, SUBGROUP_SIZE, mask_br, + mask_bc, mask_nbr, mask_nbc) +#endif + +DECLARE_2D_TILE_SLM_ADD(dv_tile_type, float, SUBGROUP_SIZE, + ugemm_vs_c_type_block0, ugemm_vs_c_type_block1, ugemm_vs_c_type_nblock0, + ugemm_vs_c_type_nblock1) +#if (ugemm_qdSt_c_type_block0 != ugemm_vs_c_type_block0) \ + || (ugemm_qdSt_c_type_block1 != ugemm_vs_c_type_block1) \ + || (ugemm_qdSt_c_type_nblock0 != ugemm_vs_c_type_nblock0) \ + || (ugemm_qdSt_c_type_nblock1 != ugemm_vs_c_type_nblock1) +DECLARE_2D_TILE_SLM_ADD(a_tile_type, float, SUBGROUP_SIZE, + ugemm_qdSt_c_type_block0, ugemm_qdSt_c_type_block1, + ugemm_qdSt_c_type_nblock0, ugemm_qdSt_c_type_nblock1) +#endif +DECLARE_2D_TILE_SLM_ADD_T(a_tile_type, float, SUBGROUP_SIZE, + ugemm_qdSt_c_type_block0, ugemm_qdSt_c_type_block1, + ugemm_qdSt_c_type_nblock0, ugemm_qdSt_c_type_nblock1) + +#define tile_load_block_rem_q(t, ptr, n, ld, off_r, off_c, load_rem) \ + if (load_rem) { \ + tile_load_block(t, ptr, n, ld, off_r, off_c); \ + } else { \ + tile_load_block(t, ptr, ld, off_r, off_c); \ + } + +#define tile_store_block_rem_q(t, ptr, n, ld, off_r, off_c, store_rem) \ + if (store_rem) { \ + tile_store_block(t, ptr, n, ld, off_r, off_c); \ + } else { \ + tile_store_block(t, ptr, ld, off_r, off_c); \ + } + +#define binary_add(x, y) ((x) + (y)) + +inline void tile_load_k(k_tile_type *K_tile, const global KEY_DATA_T *K, + int seq_len, int ldk, int seq_off, int sg_ij, int load_rem) { + +#if TRANSPOSE_K + // Bc / n_sg -- each sg loads k_tile_t_sg_n k-columns + uint k0_copy = k_tile_t_sg_n * sg_ij; + // Coalesced load from d×k column-major memory (d contiguous, k strided) +#if BLOCK_K + tile_load_block(K_tile, K, ldk, 0, seq_off + k0_copy); +#else + tile_load(K_tile, K, D_MAX, seq_len, ldk, 0, seq_off + k0_copy); +#endif + +#else + // D_MAX / n_sg + uint k0_copy = dmax_tile_sg_n * sg_ij; +#if BLOCK_K + // can ignore load_rem due to d_full requirement + tile_load_block(K_tile, K, ldk, seq_off, k0_copy); +#else + tile_load(K_tile, K, seq_len, D_MAX, ldk, seq_off, k0_copy); +#endif + +#endif +} + +inline void tile_store_k_slm( + k_tile_type *K_tile, local KEY_DATA_T *K_slm, 
int sg_ij) { + +#if TRANSPOSE_K + // Bc / n_sg -- tile is D*Bc, write transposed to SLM (Bc*D) + uint k0_copy = k_tile_t_sg_n * sg_ij; +#if USE_SYSTOLIC_UKERNEL + tile_store_t_sys_src11(*K_tile, K_slm, SUBGROUP_SIZE, D_MAX, D_MAX, + ugemm_kq_wg_tile_m, 0, k0_copy); +#else + tile_store_t_packed_src1( + *K_tile, K_slm, ugemm_kq_sg_tile_m, D_MAX, k0_copy, 0); +#endif + +#else + + uint k0_copy = dmax_tile_sg_n * sg_ij; +#if USE_SYSTOLIC_UKERNEL + tile_store_sys_src1(*K_tile, K_slm, SUBGROUP_SIZE, D_MAX, + ugemm_kq_wg_tile_m, D_MAX, 0, k0_copy); +#else + tile_store_packed_src1( + *K_tile, K_slm, ugemm_kq_sg_tile_m, D_MAX, 0, k0_copy); +#endif + +#endif +} + +#if KV_GROUP_SIZE > 1 +#define DST_DATA_T_DKDV float +#else +#define DST_DATA_T_DKDV DST_DATA_T +#endif + +inline void tile_store_dV(dv_tile_type *dV_tile_slm, global DST_DATA_T_DKDV *dV, + int m, int n, int ld, int offset_r, int offset_c, int rem) { + +#if KV_GROUP_SIZE > 1 // GQA update + tile_atomic_add(*dV_tile_slm, dV, m, n, ld, offset_r, offset_c); +#else // MHA update + + dv_tile_type_dst dV_tile_dst; // convert to half + tile_copy_reblock(*dV_tile_slm, &dV_tile_dst); +#if BLOCK_DV + tile_store_block_rem_q(dV_tile_dst, dV, n, ld, offset_r, offset_c, rem) +#else + tile_store(dV_tile_dst, dV, m, n, ld, offset_r, offset_c); +#endif + +#endif +} + +#if TRANSPOSE_K +// uses transposed dv_tile_type (D*Bc) for dK update +inline void tile_store_dK_t(dv_tile_type *dK_tile, global DST_DATA_T_DKDV *dK, + int m, int n, int ld, int offset_r, int offset_c, int rem) { + +#if KV_GROUP_SIZE > 1 // GQA update + tile_atomic_add(*dK_tile, dK, m, n, ld, offset_r, offset_c); +#else // MHA update + dv_tile_type_dst dK_tile_dst; + tile_copy_reblock(*dK_tile, &dK_tile_dst); + tile_store(dK_tile_dst, dK, m, n, ld, offset_r, offset_c); +#endif +} + +#else + +// uses qdSt tile (Bc*D) for dK update +inline void tile_store_dK(a_tile_type *dK_tile, global DST_DATA_T_DKDV *dK, + int m, int n, int ld, int offset_r, int offset_c) { + +#if KV_GROUP_SIZE > 1 // GQA update + tile_atomic_add(*dK_tile, dK, m, n, ld, offset_r, offset_c); +#else // MHA update + + a_tile_type_dst dK_tile_dst; + tile_copy_reblock(*dK_tile, &dK_tile_dst); +#if BLOCK_DK + tile_store_block(dK_tile_dst, dK, ld, offset_r, offset_c); +#else + tile_store(dK_tile_dst, dK, m, n, ld, offset_r, offset_c); +#endif + +#endif +} + +#endif + +#define DO_MM 1 + +__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) kernel void +micro_sdpa_bwd(const global KEY_DATA_T *K, const global QRY_DATA_T *Q, + const global VAL_DATA_T *V, const global float *ws, + const global float *Di, const global DST_DATA_T *A, + const global DST_DATA_T *dA, +#if WITH_DS + global DST_DATA_T *dS, // expensive, optional intermediate +#endif + global DST_DATA_T_DKDV *dK, global float *dQ, + global DST_DATA_T_DKDV *dV, +#if WITH_HOST_SCALE + float scalar_scale, float inv_scalar_scale, +#else + const global SCALE_DATA_T *scale_ptr, +#endif + int d, int k, int q, const int attn_mask_type +#if WITH_ATTN_MASK + , + const global MSK_DATA_T *msk +#endif + , + KEY_OFFSETS, QRY_OFFSETS, VAL_OFFSETS, DST_OFFSETS +#if WITH_ATTN_MASK + , + MSK_OFFSETS +#endif + , + const int remainder_k, const int remainder_q) { + + uint wg_k = get_group_id(0); + + uint sg_ij = sub_group_broadcast(get_local_id(1), 0); + + uint b1 = get_group_id(2); + + // TODO: batch q=1 cases to KV_GROUP_SIZE + uint b0, b0_kv; + b0 = get_group_id(1); + b0_kv = b0 / KV_GROUP_SIZE; + + uint wg_i0 = wg_k * ugemm_kq_wg_tile_m; + + const uint preprocess_batch = b1 * (DST_D1 * q) + 
b0 * q;
+    const global float *ws_logsumexp = ws + preprocess_batch;
+    Di += preprocess_batch;
+
+    /* Calculate the range of queries to process */
+    int q0end = q;
+    int qdiag0 = 0; // potentially offset starting idx in causal mask cases
+#if WITH_CAUSAL_MASK
+    if (attn_mask_type == ATTN_MASK_TOP_LEFT) {
+        qdiag0 = max(0, (int)(wg_i0));
+    } else {
+        qdiag0 = max(0, (int)(wg_i0 + (q - k)));
+    }
+#endif
+
+    /* Leading dimensions for matrices */
+    uint ldk = TRANSPOSE_K ? KEY_S3 : KEY_S2;
+    uint ldq = QRY_S2;
+    uint ldv = VAL_S2;
+    uint lda = DST_S2;
+
+    /* Subgroup IDs for each GEMM; the total number of sg per wg may be
+     * shared, but the ordering may differ due to transposes */
+    uint sg_i_kq = sg_ij % ugemm_kq_sg_per_wg_m;
+    uint sg_j_kq = sg_ij / ugemm_kq_sg_per_wg_m;
+
+    uint sg_i_vtdA = sg_ij % ugemm_vtdA_sg_per_wg_m;
+    uint sg_j_vtdA = sg_ij / ugemm_vtdA_sg_per_wg_m;
+
+    uint sg_i_vs = sg_ij % ugemm_vs_sg_per_wg_m;
+    uint sg_j_vs = sg_ij / ugemm_vs_sg_per_wg_m;
+
+    uint sg_i_qdSt = sg_ij % ugemm_qdSt_sg_per_wg_m;
+    uint sg_j_qdSt = sg_ij / ugemm_qdSt_sg_per_wg_m;
+
+    uint sg_i_ktq = sg_ij % ugemm_ktq_sg_per_wg_m;
+    uint sg_j_ktq = sg_ij / ugemm_ktq_sg_per_wg_m;
+
+    /* SLM allocations -- place in one array to work around compiler bug */
+#define K_slm_size (ugemm_kq_wg_tile_m * D_MAX * sizeof(KEY_DATA_T))
+#define S_slm_size (ugemm_kq_wg_tile_m * ugemm_kq_wg_tile_n * sizeof(FMA_TYPE))
+#if USE_SYSTOLIC_UKERNEL
+#define S2_f32_slm_size \
+    (ugemm_kq_wg_tile_m * ugemm_kq_wg_tile_n * sizeof(float))
+#else
+#define S2_f32_slm_size 0
+#endif
+
+#define dK_slm_size (ugemm_kq_wg_tile_m * D_MAX * sizeof(float))
+#define dV_slm_size (ugemm_kq_wg_tile_m * D_MAX * sizeof(float))
+
+#define ugemm_slm_size \
+    MAX(MAX(MAX(MAX(ugemm_kq_slm_size, ugemm_vs_slm_size), \
+                    ugemm_vtdA_slm_size), \
+                ugemm_qdSt_slm_size), \
+            ugemm_ktq_slm_size)
+
+    local char slm[K_slm_size + S_slm_size + S2_f32_slm_size + ugemm_slm_size
+            + dK_slm_size + dV_slm_size];
+
+    local KEY_DATA_T *K_slm = (local KEY_DATA_T *)&slm[0];
+
+    // S_slm: softmax staging for the ugemm_vs B-operand, later reused for dS
+    local FMA_TYPE *S_slm = (local FMA_TYPE *)&slm[K_slm_size];
+#if USE_SYSTOLIC_UKERNEL
+    // f32 softmax cache, reused for dS^t (systolic only)
+    local float *S2_f32_slm = (local float *)&slm[K_slm_size + S_slm_size];
+#endif
+
+    // ugemm scratch space
+    local uint *ugemm_slm
+            = (local uint *)&slm[K_slm_size + S_slm_size + S2_f32_slm_size];
+
+    // used for accumulation of dV, dK across q-loop
+    local float *dK_slm = (local float *)&slm[K_slm_size + S_slm_size
+            + S2_f32_slm_size + ugemm_slm_size];
+    local float *dV_slm = (local float *)&slm[K_slm_size + S_slm_size
+            + S2_f32_slm_size + ugemm_slm_size + dK_slm_size];
+
+    const size_t k_offset = KEY_BATCH(b1, b0_kv);
+    const size_t v_offset = VAL_BATCH(b1, b0_kv);
+    const size_t q_offset = QRY_BATCH(b1, b0);
+    const size_t a_offset = DST_BATCH(b1, b0);
+
+    /* Locate K/Q/V/A matrices within batch */
+    K += k_offset;
+    Q += q_offset;
+    V += v_offset;
+    A += a_offset;
+
+    dK += k_offset;
+    dQ += q_offset;
+    dV += v_offset;
+    dA += a_offset;
+
+#if WITH_DS
+    dS += b1 * (DST_D1 * q * k) + b0 * (q * k);
+#endif
+
+#if WITH_ATTN_MASK
+    msk += MSK_BATCH(b1 % MSK_D0, b0 % MSK_D1);
+    int mask_aligned = (((size_t)msk) % 4) == 0;
+    bool block_msk = (b1 < MSK_D0 - ceil((float)ugemm_kq_wg_tile_m / MSK_S2))
+            && mask_aligned;
+#endif
+
+    if (qdiag0 < q0end) {
+        /* Load K tile, destined for SLM */
+        k_tile_type K_tile;
+        tile_fill(K_tile, TO_DATA_T(0.f));
+
+        tile_load_k(&K_tile, K, k, ldk, wg_i0, sg_ij,
remainder_k);
+
+        /* Store K tile to SLM */
+        tile_store_k_slm(&K_tile, K_slm, sg_ij);
+    }
+
+    /* Load scale */
+    float scale = 1.f;
+    float iscale = 1.f;
+    if (qdiag0 < q0end) {
+#if WITH_ATTN_SCALE
+#if WITH_HOST_SCALE
+#if INVERT_SCALE
+        iscale = scalar_scale;
+        scale = inv_scalar_scale;
+#else
+        scale = scalar_scale;
+        iscale = inv_scalar_scale;
+#endif
+#else
+#if INVERT_SCALE
+        iscale = SCALES_TO_FLOAT(*scale_ptr);
+        scale = native_recip(iscale);
+#else
+        scale = SCALES_TO_FLOAT(*scale_ptr);
+        iscale = native_recip(scale);
+#endif
+#endif
+#endif
+    }
+
+    /* Initialize dV, dK to zero */
+#pragma unroll
+    for (int i = get_local_id(0); i < ugemm_kq_wg_tile_m * D_MAX;
+            i += get_local_size(0)) {
+        dK_slm[i] = 0.f;
+        dV_slm[i] = 0.f;
+    }
+
+    uint sg_i0_kq = sg_i_kq * ugemm_kq_sg_tile_m;
+    uint sg_j0_kq = sg_j_kq * ugemm_kq_sg_tile_n;
+
+    const int k0 = wg_i0;
+
+    // make sure K_tile in SLM
+    barrier(CLK_LOCAL_MEM_FENCE);
+
+    /* Main loop over q blocks */
+    for (int q0 = qdiag0; q0 < q0end; q0 += ugemm_kq_wg_tile_n) {
+        const bool first = (q0 == qdiag0);
+        const int qnext = q0 + ugemm_kq_wg_tile_n;
+        const bool last = (qnext >= q0end);
+
+        int k_chunk = min(k - k0, ugemm_kq_wg_tile_m);
+        int q_nchunk = min(q0end - q0, ugemm_kq_wg_tile_n);
+        /* Calculate S = (K^T) * Q */
+#if DO_MM
+        s_tile_type S_tile
+                = ugemm_kq(K_slm, D_MAX, Q + q0 * ldq, ldq, k_chunk, q_nchunk,
+                        d, 0, 0, 0, sg_i_kq, sg_j_kq, (local char *)ugemm_slm);
+#else
+        s_tile_type S_tile;
+#endif
+        uint sg_i0_s2 = sg_i_kq * ugemm_kq_sg_tile_m + k0;
+        uint sg_j0_s2 = sg_j_kq * ugemm_kq_sg_tile_n + q0;
+
+        /* Apply attention mask */
+#if WITH_ATTN_MASK
+        mask_tile_type mask_tile;
+#if BROADCAST_MASK_Q
+        if (block_msk) {
+            tile_load_block(&mask_tile, msk, MSK_S2, 0, k0 + sg_i0_kq, 0);
+        } else {
+            tile_load(&mask_tile, msk, k, 1, MSK_S2, k0 + sg_i0_kq, 0);
+        }
+#else
+        tile_load(&mask_tile, msk, k, q, MSK_S2, k0 + sg_i0_kq, q0 + sg_j0_kq);
+#endif
+
+#define unscale(x) ((x) * iscale)
+        mask_tile_type_float mask_tile_float;
+        tile_copy_reblock(mask_tile, &mask_tile_float);
+#if WITH_ATTN_SCALE
+        tile_elementwise(mask_tile_float, unscale);
+#endif
+#undef unscale
+#if BROADCAST_MASK_Q
+        tile_vbroadcast_add(&S_tile, mask_tile_float);
+#else
+        tile_binary(S_tile, mask_tile_float, binary_add);
+#endif
+#endif
+
+        /* Apply q mask */
+        if (remainder_q) {
+#define gte_q(offset_k, offset_q) (offset_q >= q)
+            tile_predicated_assignment(S_tile, k0 + sg_i0_kq, q0 + sg_j0_kq,
+                    gte_q, -INFINITY, SUBGROUP_SIZE, ugemm_kq_c_type_block0,
+                    ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0,
+                    ugemm_kq_c_type_nblock1);
+#undef gte_q
+        }
+
+#if WITH_CAUSAL_MASK
+#define less_than(offset_k, offset_q) (offset_q < offset_k)
+
+        int col_offset = q0 + sg_j0_kq;
+        if (q == 1) col_offset = 1;
+        if (attn_mask_type == ATTN_MASK_BOTTOM_RIGHT) col_offset += k - q;
+
+        /* Apply causal mask; only the first iteration lands on the diagonal
+         * and requires partial masking */
+        const bool is_diag = (q0 == qdiag0);
+        if (is_diag) {
+            tile_predicated_assignment(S_tile, k0 + sg_i0_kq, col_offset,
+                    less_than, -INFINITY, SUBGROUP_SIZE, ugemm_kq_c_type_block0,
+                    ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0,
+                    ugemm_kq_c_type_nblock1);
+        }
+#undef less_than
+#endif
+
+        s_sum_tile_type S_logsumexp_tile;
+        tile_fill(S_logsumexp_tile, 0.f);
+        tile_load(&S_logsumexp_tile, ws_logsumexp, q, 1, ugemm_kq_wg_tile_n,
+                sg_j0_kq + q0, 0);
+#define mulscale(x) ((x) * scale)
+        tile_elementwise(S_tile, mulscale);
+#undef mulscale
+        tile_hbroadcast_sub(&S_tile,
S_logsumexp_tile); // subtract per-query logsumexp (broadcast along layout.N)
+
+        /* Exponentiate: exp(x) computed as exp2(x * log2(e)) */
+#define scaled_exp(x) native_vexp2((x) * 1.44269504089f)
+        tile_elementwise(S_tile, scaled_exp);
+#undef scaled_exp
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+        {
+#if USE_SYSTOLIC_UKERNEL
+            // store softmax in f32 for S2 reload (systolic only)
+            tile_store(S_tile, S2_f32_slm, ugemm_kq_wg_tile_m,
+                    ugemm_kq_wg_tile_n, ugemm_kq_wg_tile_m, sg_i0_kq, sg_j0_kq);
+#endif
+
+            // Store softmax for ugemm_vs B-operand
+            s_tile_type_reblock S_tile_reblock;
+            tile_copy_reblock(S_tile, &S_tile_reblock);
+
+#if USE_SYSTOLIC_UKERNEL
+            tile_store_t_sys_src22(S_tile_reblock, (local FMA_TYPE *)S_slm,
+                    ugemm_vs_sg_tile_n, ugemm_kq_wg_tile_m, ugemm_kq_wg_tile_n,
+                    sg_i0_kq, sg_j0_kq);
+#else
+            tile_store_packed_src1(S_tile_reblock, S_slm, ugemm_vs_sg_tile_n,
+                    ugemm_kq_wg_tile_n, sg_i0_kq, sg_j0_kq);
+#endif
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+
+        {
+#if DO_MM
+            dv_tile_type dV_tile1;
+            dV_tile1 = ugemm_vs(dA + q0 * lda, lda, (local FMA_TYPE *)S_slm,
+                    ugemm_kq_wg_tile_n, d, k_chunk, q_nchunk, 0, 0, 0, sg_i_vs,
+                    sg_j_vs, (local char *)ugemm_slm);
+#else
+            dv_tile_type dV_tile1;
+#endif
+            uint sg_i0_vs = sg_i_vs * ugemm_vs_sg_tile_m;
+            uint sg_j0_vs = sg_j_vs * ugemm_vs_sg_tile_n;
+
+            // accumulate dv tile to slm
+            if (sg_ij < sg_per_wg_BcD) {
+                tile_slm_add(dV_tile1, dV_slm, D_MAX, sg_i0_vs, sg_j0_vs);
+            }
+        }
+
+#if DO_MM
+        p_tile_type dP_tile = ugemm_vtdA(V + k0 * ldv, ldv, dA + q0 * lda, lda,
+                k_chunk, q_nchunk, d, 0, 0, 0, sg_i_kq, sg_j_kq,
+                (local char *)ugemm_slm);
+#else
+        p_tile_type dP_tile;
+#endif
+
+        p_sum_tile_type D_i;
+        tile_fill(D_i, 0.0f);
+        tile_load(&D_i, Di, q0end, 1, q0end, q0 + sg_j0_kq, 0);
+        tile_hbroadcast_sub(&dP_tile,
+                D_i); // needs output to be transposed from vtdA layout.C = N
+
+        // reload softmax since ugemm_vtdA() clobbers registers
+        {
+            p_tile_type S2_tile;
+#if USE_SYSTOLIC_UKERNEL
+            tile_load(&S2_tile, S2_f32_slm, ugemm_kq_wg_tile_m,
+                    ugemm_kq_wg_tile_n, ugemm_kq_wg_tile_m, sg_i0_kq, sg_j0_kq);
+#else
+            // reload from packed S_slm
+            p_tile_type_reblock S2_tile_reblock;
+            tile_load_packed_src1(&S2_tile_reblock, S_slm, ugemm_vs_sg_tile_n,
+                    ugemm_kq_wg_tile_n, sg_i0_kq, sg_j0_kq);
+            tile_copy_reblock(S2_tile_reblock, &S2_tile);
+#endif
+            intel_work_group_barrier_arrive(CLK_LOCAL_MEM_FENCE);
+
+#define binary_mul_scale(x, y) ((x) * (y) * scale)
+            tile_binary(dP_tile, S2_tile, binary_mul_scale);
+#undef binary_mul_scale
+        }
+
+        if (remainder_k) {
+#define gte_k(offset_k, offset_q) (offset_k >= k)
+            tile_predicated_assignment(S_tile, k0 + sg_i0_kq, q0 + sg_j0_kq,
+                    gte_k, 0, SUBGROUP_SIZE, ugemm_kq_c_type_block0,
+                    ugemm_kq_c_type_block1, ugemm_kq_c_type_nblock0,
+                    ugemm_kq_c_type_nblock1);
+#undef gte_k
+        }
+
+#if USE_SYSTOLIC_UKERNEL
+        local FMA_TYPE *dSt_slm = (local FMA_TYPE *)S2_f32_slm;
+#endif
+        {
+            p_tile_type_reblock P_tile_reblock;
+            tile_copy_reblock(dP_tile, &P_tile_reblock);
+#if WITH_DS
+            tile_store(P_tile_reblock, dS, k_chunk, q_nchunk, k, k0 + sg_i0_kq,
+                    q0 + sg_j0_kq);
+#endif
+
+            intel_work_group_barrier_wait(CLK_LOCAL_MEM_FENCE);
+#if USE_SYSTOLIC_UKERNEL
+            // softmax no longer needed, use slm to cache dS
+            tile_store_sys_src22(P_tile_reblock, dSt_slm, ugemm_ktq_sg_tile_n,
+                    ugemm_kq_wg_tile_m, ugemm_kq_wg_tile_n, sg_i0_kq, sg_j0_kq);
+            tile_store_sys_src1(P_tile_reblock, S_slm, SUBGROUP_SIZE,
+                    ugemm_kq_wg_tile_n, ugemm_kq_wg_tile_m, ugemm_kq_wg_tile_n,
+                    sg_i0_kq, sg_j0_kq);
+#else
+            // Store dS to S_slm for ugemm_qdSt
+            tile_store_packed_src1(P_tile_reblock, S_slm, ugemm_qdSt_sg_tile_m,
ugemm_kq_wg_tile_n, sg_i0_kq, sg_j0_kq); +#endif + } + barrier(CLK_LOCAL_MEM_FENCE); + + { +#if DO_MM + a_tile_type dK_tile1; + dK_tile1 = ugemm_qdSt(S_slm, ugemm_kq_wg_tile_n, Q + q0 * ldq, ldq, + k_chunk, d, q_nchunk, 0, 0, 0, sg_i_qdSt, sg_j_qdSt, + (local char *)ugemm_slm); // dS^t * Q -> Bc x d +#else + a_tile_type dK_tile1; +#endif + uint sg_i0_dk = sg_i_qdSt * ugemm_qdSt_sg_tile_m; + uint sg_j0_dk = sg_j_qdSt * ugemm_qdSt_sg_tile_n; + + // dk slm tile + if (sg_ij < sg_per_wg_BcD) { +#if TRANSPOSE_K + tile_slm_add_t(dK_tile1, dK_slm, D_MAX, sg_i0_dk, sg_j0_dk); +#else + tile_slm_add(dK_tile1, dK_slm, ugemm_kq_wg_tile_m, sg_i0_dk, + sg_j0_dk); +#endif + } + } + +#if !USE_SYSTOLIC_UKERNEL + // re-read dS from S_slm and re-store transposed for ugemm_ktq + { + p_tile_type_reblock dS_reblock; + tile_load_packed_src1(&dS_reblock, S_slm, ugemm_qdSt_sg_tile_m, + ugemm_kq_wg_tile_n, sg_i0_kq, sg_j0_kq); + barrier(CLK_LOCAL_MEM_FENCE); + tile_store_t_packed_src1(dS_reblock, S_slm, ugemm_ktq_sg_tile_n, + ugemm_kq_wg_tile_m, sg_j0_kq, sg_i0_kq); + } + barrier(CLK_LOCAL_MEM_FENCE); + local FMA_TYPE *dSt_slm = S_slm; +#endif + + { +#if DO_MM + ktq_tile_type dQ_tile; + + dQ_tile = ugemm_ktq( +#if TRANSPOSE_K + K + k0 * ldk, +#else + K + k0, +#endif + ldk, dSt_slm, ugemm_kq_wg_tile_m, d, q_nchunk, k_chunk, 0, + 0, 0, sg_i_ktq, sg_j_ktq, (local char *)ugemm_slm); +#else + ktq_tile_type dQ_tile; +#endif + uint sg_i0_dq = sg_i_ktq * ugemm_ktq_sg_tile_m; + uint sg_j0_dq = sg_j_ktq * ugemm_ktq_sg_tile_n + q0; + + if (sg_ij < sg_per_wg_BrD) + tile_atomic_add(dQ_tile, dQ, d, q, ldq, sg_i0_dq, sg_j0_dq); + } + } + + //////// update dV + uint sg_i0_vs = sg_i_vs * ugemm_vs_sg_tile_m; + uint sg_j0_vs = sg_j_vs * ugemm_vs_sg_tile_n; + + // ensure all loops done writing to SLM + barrier(CLK_LOCAL_MEM_FENCE); + + dv_tile_type dV_tile_slm; + + if (sg_ij < sg_per_wg_BcD) { + tile_load(&dV_tile_slm, dV_slm, D_MAX, ugemm_kq_wg_tile_m, D_MAX, + sg_i0_vs, sg_j0_vs); + tile_store_dV(&dV_tile_slm, dV, d, k, ldv, sg_i0_vs, wg_i0 + sg_j0_vs, + remainder_k); + } + // /update dV + + //////// update dK +#if TRANSPOSE_K + // transposed dK_slm (D*Bc) matches dV tile layout + dv_tile_type dK_tile_t; + + if (sg_ij < sg_per_wg_BcD) { + tile_load(&dK_tile_t, dK_slm, D_MAX, ugemm_kq_wg_tile_m, D_MAX, + sg_i0_vs, sg_j0_vs); + tile_store_dK_t(&dK_tile_t, dK, d, k, ldk, sg_i0_vs, wg_i0 + sg_j0_vs, + remainder_k); + } +#else + // non-transposed dK_slm uses qdSt layout (Bc*D) and indexing + uint sg_i0_dk = sg_i_qdSt * ugemm_qdSt_sg_tile_m; + uint sg_j0_dk = sg_j_qdSt * ugemm_qdSt_sg_tile_n; + + a_tile_type dK_tile_slm; + int wg_k_chunk = min(k - k0, ugemm_kq_wg_tile_m); + if (sg_ij < sg_per_wg_BcD) { + tile_load(&dK_tile_slm, dK_slm, ugemm_kq_wg_tile_m, D_MAX, + ugemm_kq_wg_tile_m, sg_i0_dk, sg_j0_dk); + tile_store_dK(&dK_tile_slm, dK + wg_i0, wg_k_chunk, d, ldk, sg_i0_dk, + sg_j0_dk); + } +#endif + // /update dK +} + +__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) kernel void +preprocess_Di(global float *Di, const global DST_DATA_T *A, + const global DST_DATA_T *dA, int d, int k, int q, QRY_OFFSETS, + DST_OFFSETS) { + + uint lda = DST_S2; + uint ldq = QRY_S2; + + uint sg_ij = sub_group_broadcast(get_local_id(1), 0); + uint sg_i_kq = sg_ij % ugemm_kq_sg_per_wg_m; + uint sg_j_kq = sg_ij / ugemm_kq_sg_per_wg_m; + + uint b0, b1; + b0 = get_group_id(1); + b1 = get_group_id(2); + + const uint preprocess_batch = b1 * (DST_D1 * q) + b0 * q; + + const size_t q_offset = QRY_BATCH(b1, b0); + const size_t a_offset = DST_BATCH(b1, b0); + 
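+    /* Di is the softmax-backward correction term of the sketch above:
+     * Di[j] = sum_d A[j][d] * dA[j][d] for each query j, computed once here
+     * and broadcast-subtracted from dP in the main kernel. */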
+ /* Locate A/dA matrices within batch */ + A += a_offset; + dA += a_offset; + + Di += preprocess_batch; + + uint wg_q = get_group_id(0); + uint wg_j0 = wg_q * ugemm_kq_wg_tile_n; + +#define Di_slm_size (ugemm_kq_wg_tile_n * sizeof(float)) + local char slm[Di_slm_size]; + + local float *Di_slm = (local float *)&slm[0]; + + uint sg_i0_kq = sg_i_kq * ugemm_kq_sg_tile_m; + uint sg_j0_kq = sg_j_kq * ugemm_kq_sg_tile_n; + + uint q0_copy = q_tile_sg_n * sg_ij; + + if (q > 0) { + // D_i calculation +#if QRY_DT_F32 + dq_tile_type dA_tile, A_tile; + tile_fill(A_tile, 0.f); + tile_fill(dA_tile, 0.f); + tile_load( + &dA_tile, (global FMA_TYPE *)dA, d, q, lda, 0, wg_j0 + q0_copy); + tile_load(&A_tile, (global FMA_TYPE *)A, d, q, lda, 0, wg_j0 + q0_copy); +#else + dq_tile_type dA_tile, A_tile; + q_tile_type dA_tile_reblock, A_tile_reblock; // load native type + tile_fill(A_tile_reblock, TO_DATA_T(0.f)); + tile_fill(dA_tile_reblock, TO_DATA_T(0.f)); + + tile_load(&dA_tile_reblock, (global FMA_TYPE *)dA, d, q, lda, 0, + wg_j0 + q0_copy); + tile_load(&A_tile_reblock, (global FMA_TYPE *)A, d, q, lda, 0, + wg_j0 + q0_copy); + + // convert to float for calculation + tile_copy_reblock(dA_tile_reblock, &dA_tile); + tile_copy_reblock(A_tile_reblock, &A_tile); +#endif + +#define binary_mul(x, y) ((x) * (y)) + tile_binary(A_tile, dA_tile, binary_mul); + + // reduce tile across D_MAX + for (int j = 0; j < q_tile_sg_n; j++) { + float r = 0.f; + for (int i0 = 0; i0 < D_MAX; i0 += SUBGROUP_SIZE) { + r += sub_group_reduce_add( + tile_access(A_tile, i0, j, SUBGROUP_SIZE, D_MAX, 1, 1)); + } + Di_slm[j + q0_copy] = r; + } + barrier(CLK_LOCAL_MEM_FENCE); + + for (int i = get_local_id(0); i < ugemm_kq_wg_tile_n; + i += get_local_size(0)) { + if (get_local_id(1) == 0 && (wg_j0 + i) < q) { + Di[wg_j0 + i] = Di_slm[i]; + } + } + } +} + +__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) kernel void +postprocess_dQ(global DST_DATA_T *dst, global const float *src, int nelems, + QRY_OFFSETS) { + uint b0 = get_group_id(1); + uint b1 = get_group_id(2); + + const size_t offset = QRY_BATCH(b1, b0); + + /* Locate dQ/dV/dK matrices within batch */ + src += offset; + dst += offset; + size_t idx = get_global_id(0); + if (idx < (nelems)) { dst[idx] = TO_DATA_T(src[idx]); } +} diff --git a/src/gpu/intel/sdpa/ref.cpp b/src/gpu/intel/sdpa/ref.cpp index d86a2dce5c0..e979cf8d1d8 100644 --- a/src/gpu/intel/sdpa/ref.cpp +++ b/src/gpu/intel/sdpa/ref.cpp @@ -25,7 +25,7 @@ namespace gpu { namespace intel { namespace sdpa { -status_t ref_t::execute_ref(const exec_ctx_t &ctx) const { +status_t ref_fwd_t::execute_ref(const exec_ctx_t &ctx) const { const auto &qry = CTX_IN_STORAGE(DNNL_ARG_QUERIES); const auto &key = CTX_IN_STORAGE(DNNL_ARG_KEYS); const auto &val = CTX_IN_STORAGE(DNNL_ARG_VALUES); diff --git a/src/gpu/intel/sdpa/ref.hpp b/src/gpu/intel/sdpa/ref.hpp index 07f1705e1ba..34ad410738b 100644 --- a/src/gpu/intel/sdpa/ref.hpp +++ b/src/gpu/intel/sdpa/ref.hpp @@ -27,12 +27,12 @@ namespace gpu { namespace intel { namespace sdpa { -struct ref_t : public primitive_t { +struct ref_fwd_t : public primitive_t { using primitive_t::primitive_t; - struct pd_t : public sdpa::pd_t { - using sdpa::pd_t::pd_t; + struct pd_t : public sdpa_fwd_pd_t { + using sdpa_fwd_pd_t::sdpa_fwd_pd_t; - DECLARE_COMMON_PD_T("ocl:ref:any", ref_t); + DECLARE_COMMON_PD_T("ocl:ref:any", ref_fwd_t); status_t init(impl::engine_t *engine) { using namespace data_type; @@ -44,13 +44,13 @@ struct ref_t : public primitive_t { 
VDISPATCH_SDPA(attr()->has_default_values(smask_t::scales), VERBOSE_UNSUPPORTED_ATTR); - VDISPATCH_SDPA( - utils::everyone_is(4, qry_md()->ndims, key_md()->ndims, - val_md()->ndims, dst_md()->ndims), + VDISPATCH_SDPA(utils::everyone_is(4, desc()->qry_md()->ndims, + desc()->key_md()->ndims, + desc()->val_md()->ndims, dst_md()->ndims), VERBOSE_UNSUPPORTED_TAG); if (with_attn_mask()) { - VDISPATCH_SDPA( - attn_mask_md()->ndims == 4, VERBOSE_UNSUPPORTED_TAG); + VDISPATCH_SDPA(desc()->attn_mask_md()->ndims == 4, + VERBOSE_UNSUPPORTED_TAG); } VDISPATCH_SDPA(set_default_formats(), VERBOSE_UNSUPPORTED_TAG); @@ -66,11 +66,11 @@ struct ref_t : public primitive_t { int ndims = 4; - const memory_desc_wrapper qry_mdw(pd()->qry_md()); - const memory_desc_wrapper key_mdw(pd()->key_md()); - const memory_desc_wrapper val_mdw(pd()->val_md()); + const memory_desc_wrapper qry_mdw(pd()->desc()->qry_md()); + const memory_desc_wrapper key_mdw(pd()->desc()->key_md()); + const memory_desc_wrapper val_mdw(pd()->desc()->val_md()); const memory_desc_wrapper dst_mdw(pd()->dst_md()); - const memory_desc_wrapper msk_mdw(pd()->attn_mask_md()); + const memory_desc_wrapper msk_mdw(pd()->desc()->attn_mask_md()); using offset_t = decltype(offsets_t().src_off); offset_t qry_off, key_off, val_off, dst_off, msk_off; set_offsets(qry_mdw, qry_off); @@ -90,12 +90,13 @@ struct ref_t : public primitive_t { kernel_ctx.define_int("WITH_ATTN_SCALE", pd()->with_attn_scale()); kernel_ctx.define_int("WITH_ATTN_MASK", pd()->with_attn_mask()); - def_data_type(kernel_ctx, pd()->qry_md()->data_type, "QRY"); - def_data_type(kernel_ctx, pd()->key_md()->data_type, "KEY"); - def_data_type(kernel_ctx, pd()->val_md()->data_type, "VAL"); + def_data_type(kernel_ctx, pd()->desc()->qry_md()->data_type, "QRY"); + def_data_type(kernel_ctx, pd()->desc()->key_md()->data_type, "KEY"); + def_data_type(kernel_ctx, pd()->desc()->val_md()->data_type, "VAL"); def_data_type(kernel_ctx, pd()->dst_md()->data_type, "DST"); - def_data_type(kernel_ctx, pd()->attn_mask_md()->data_type, "MSK"); - def_data_type(kernel_ctx, pd()->scale_md()->data_type, "SCALE"); + def_data_type( + kernel_ctx, pd()->desc()->attn_mask_md()->data_type, "MSK"); + def_data_type(kernel_ctx, pd()->desc()->scale_md()->data_type, "SCALE"); CHECK(create_kernel(engine, &kernel_, "ref_sdpa", kernel_ctx)); if (!kernel_) return status::runtime_error; return status::success; diff --git a/src/gpu/intel/sdpa/utils.h b/src/gpu/intel/sdpa/utils.h index d5b41ab295e..6f466760ec2 100644 --- a/src/gpu/intel/sdpa/utils.h +++ b/src/gpu/intel/sdpa/utils.h @@ -28,8 +28,7 @@ #define VAL_OFF(x0, x1, x2, x3) _4D_OFF(VAL, x0, x1, x2, x3) #define MSK_OFF(x0, x1, x2, x3) _4D_OFF(MSK, x0, x1, x2, x3) -#define _BATCH_OFF(tag, x0, x1) \ - ((x0) * tag##_S.array[0] + (x1) * tag##_S.array[1]) +#define _BATCH_OFF(tag, x0, x1) ((x0) * tag##_S0 + (x1) * tag##_S1) #define QRY_BATCH(x0, x1) _BATCH_OFF(QRY, x0, x1) #define KEY_BATCH(x0, x1) _BATCH_OFF(KEY, x0, x1) @@ -37,27 +36,15 @@ #define DST_BATCH(x0, x1) _BATCH_OFF(DST, x0, x1) #define MSK_BATCH(x0, x1) _BATCH_OFF(MSK, x0, x1) -#define JOIN_COMMA(x, y) x, y - -#define RT_DIM4(varname) const int64x4_t varname -#define RT_OFFSETS(basename) \ - JOIN_COMMA(RT_DIM4(basename##_D), RT_DIM4(basename##_S)) - -#define KEY_OFFSETS RT_OFFSETS(KEY) -#define QRY_OFFSETS RT_OFFSETS(QRY) -#define VAL_OFFSETS RT_OFFSETS(VAL) -#define DST_OFFSETS RT_OFFSETS(DST) -#define MSK_OFFSETS RT_OFFSETS(MSK) - -// helper shorthands for accessing individual dimensions -#define KEY_D3 KEY_D.array[3] 
-#define KEY_S3 KEY_S.array[3] -#define KEY_S2 KEY_S.array[2] -#define QRY_S2 QRY_S.array[2] -#define VAL_S2 VAL_S.array[2] -#define DST_S2 DST_S.array[2] -#define MSK_D0 MSK_D.array[0] -#define MSK_D1 MSK_D.array[1] -#define MSK_S2 MSK_S.array[2] +#define KEY_OFFSETS \ + const long KEY_S0, const long KEY_S1, const long KEY_S2, \ + const long KEY_S3, const long KEY_D3 +#define QRY_OFFSETS const long QRY_S0, const long QRY_S1, const long QRY_S2 +#define VAL_OFFSETS const long VAL_S0, const long VAL_S1, const long VAL_S2 +#define DST_OFFSETS \ + const long DST_S0, const long DST_S1, const long DST_S2, const long DST_D1 +#define MSK_OFFSETS \ + const long MSK_S0, const long MSK_S1, const long MSK_S2, \ + const long MSK_D0, const long MSK_D1 #endif diff --git a/src/graph/backend/dnnl/executables/sdpa.cpp b/src/graph/backend/dnnl/executables/sdpa.cpp index 3ec9ce405e9..7b84f7feca2 100644 --- a/src/graph/backend/dnnl/executables/sdpa.cpp +++ b/src/graph/backend/dnnl/executables/sdpa.cpp @@ -76,7 +76,8 @@ sdpa_executable_t::sdpa_executable_t(std::shared_ptr &op, status_t s = create_sdpa_pd(sdpa_pd_, p_engine.get(), md_q.get(), md_k.get(), md_v.get(), md_dst.get(), md_mask.get(), md_scale.get(), is_invert_scale_, kv_head_number, mask_type_, softmax_alg, - attr.get(), qk_attr.get(), vs_attr.get()); + impl::prop_kind::forward_inference, attr.get(), qk_attr.get(), + vs_attr.get()); if (s != dnnl::impl::status::success) { is_initialized_ = false; } else { diff --git a/tests/gtests/internals/CMakeLists.txt b/tests/gtests/internals/CMakeLists.txt index 7b980801434..194c575d665 100644 --- a/tests/gtests/internals/CMakeLists.txt +++ b/tests/gtests/internals/CMakeLists.txt @@ -50,4 +50,10 @@ register_exe(${TEST_EXE}_env_vars_onednn "test" "dnnl_gtest") list(REMOVE_ITEM TEST_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/test_env_vars_onednn.cpp) +# Register SDPA tests as a separate executable due to their runtime +register_exe(${TEST_EXE}_sdpa + "${MAIN_SRC_GTEST};${CMAKE_CURRENT_SOURCE_DIR}/test_sdpa.cpp;${CMAKE_CURRENT_SOURCE_DIR}/test_utils.cpp" + "test" "dnnl_gtest") +list(REMOVE_ITEM TEST_SOURCES ${CMAKE_CURRENT_SOURCE_DIR}/test_sdpa.cpp) + register_exe(${TEST_EXE} "${TEST_SOURCES}" "test" "dnnl_gtest") diff --git a/tests/gtests/internals/sdpa_internal.hpp b/tests/gtests/internals/sdpa_internal.hpp index 8ab35a07884..cbd388802a8 100644 --- a/tests/gtests/internals/sdpa_internal.hpp +++ b/tests/gtests/internals/sdpa_internal.hpp @@ -42,10 +42,24 @@ dnnl_status_t DNNL_API sdpa_primitive_desc_create( const_dnnl_memory_desc_t value_desc, const_dnnl_memory_desc_t dst_desc, const_dnnl_memory_desc_t mask_desc, const_dnnl_memory_desc_t scale_desc, bool invert_scale, dnnl_dim_t kv_head_number, int attn_mask_type, - dnnl_alg_kind_t softmax_alg, const_dnnl_primitive_attr_t attr, - const_dnnl_primitive_attr_t kq_attr, + dnnl_alg_kind_t softmax_alg, dnnl_prop_kind_t prop, + const_dnnl_primitive_attr_t attr, const_dnnl_primitive_attr_t kq_attr, const_dnnl_primitive_attr_t vs_attr); +dnnl_status_t DNNL_API sdpa_primitive_desc_create( + dnnl_primitive_desc_t *primitive_desc_iface, dnnl_engine_t engine, + const_dnnl_memory_desc_t query_desc, const_dnnl_memory_desc_t key_desc, + const_dnnl_memory_desc_t value_desc, const_dnnl_memory_desc_t dst_desc, + const_dnnl_memory_desc_t diff_query_desc, + const_dnnl_memory_desc_t diff_key_desc, + const_dnnl_memory_desc_t diff_value_desc, + const_dnnl_memory_desc_t diff_dst_desc, + const_dnnl_memory_desc_t dS_desc, const_dnnl_memory_desc_t mask_desc, + const_dnnl_memory_desc_t scale_desc, 
bool invert_scale,
+        dnnl_dim_t kv_head_number, int attn_mask_type,
+        dnnl_alg_kind_t softmax_alg, const_dnnl_primitive_attr_t attr,
+        const dnnl_primitive_desc *hint_fwd_pd);
+
 namespace dnnl {
 namespace impl {
 
@@ -63,18 +77,19 @@ struct sdpa : public dnnl::primitive {
                 const memory::desc &scale_desc, const memory::desc &output_desc,
                 bool invert_scale, memory::dim kv_head_number, int attn_mask_type,
                 int softmax_alg,
+                prop_kind_t prop_kind = prop_kind::forward_inference,
                 const primitive_attr &attr = default_attr(),
                 const primitive_attr &kq_attr = default_attr(),
                 const primitive_attr &vs_attr = default_attr()) {
             dnnl_primitive_desc_t pd = nullptr;
-            dnnl_status_t status
-                    = sdpa_primitive_desc_create(&pd, aengine.get(),
-                            query_desc.get(), key_desc.get(), value_desc.get(),
-                            output_desc.get(), optional_arg(attn_mask_desc),
-                            scale_desc.get(), invert_scale, kv_head_number,
-                            attn_mask_type, (dnnl_alg_kind_t)softmax_alg,
-                            attr.get(), kq_attr.get(), vs_attr.get());
+            dnnl_status_t status = sdpa_primitive_desc_create(&pd,
+                    aengine.get(), query_desc.get(), key_desc.get(),
+                    value_desc.get(), output_desc.get(),
+                    optional_arg(attn_mask_desc), scale_desc.get(),
+                    invert_scale, kv_head_number, attn_mask_type,
+                    (dnnl_alg_kind_t)softmax_alg, (prop_kind_t)prop_kind,
+                    attr.get(), kq_attr.get(), vs_attr.get());
 
             dnnl::error::wrap_c_api(status,
                     "could not create a primitive descriptor for a sdpa "
@@ -90,6 +105,55 @@ struct sdpa : public dnnl::primitive {
     /// @param pd Primitive descriptor for a sdpa primitive.
     sdpa(const primitive_desc &pd) : primitive(pd) {}
 };
+
+/// Scaled Dot Product Attention (sdpa) backward propagation internal primitive.
+/// Implemented internally to allow more flexible validation.
+struct sdpa_backward : public dnnl::primitive {
+    /// Primitive descriptor for a sdpa_backward primitive.
+    struct primitive_desc : public dnnl::primitive_desc {
+        /// Default constructor. Produces an empty object.
+        primitive_desc() = default;
+
+        primitive_desc(const engine &aengine, const memory::desc &query_desc,
+                const memory::desc &key_desc, const memory::desc &value_desc,
+                const memory::desc *attn_mask_desc,
+                const memory::desc &scale_desc, const memory::desc &output_desc,
+                const memory::desc &diff_query_desc,
+                const memory::desc &diff_key_desc,
+                const memory::desc &diff_value_desc,
+                const memory::desc &diff_output_desc,
+                const memory::desc *dS_desc, bool invert_scale,
+                memory::dim kv_head_number, int attn_mask_type, int softmax_alg,
+                const sdpa::primitive_desc &hint_fwd_pd,
+                const primitive_attr &attr = default_attr()) {
+
+            dnnl_primitive_desc_t pd = nullptr;
+            dnnl_status_t status = sdpa_primitive_desc_create(&pd,
+                    aengine.get(), query_desc.get(), key_desc.get(),
+                    value_desc.get(), output_desc.get(),
+                    optional_arg(attn_mask_desc), scale_desc.get(),
+                    diff_query_desc.get(), diff_key_desc.get(),
+                    diff_value_desc.get(), diff_output_desc.get(),
+                    dS_desc ? dS_desc->get() : nullptr, invert_scale,
+                    kv_head_number, attn_mask_type,
+                    (dnnl_alg_kind_t)softmax_alg, attr.get(),
+                    hint_fwd_pd.get());
+
+            dnnl::error::wrap_c_api(status,
+                    "could not create a primitive descriptor for a sdpa "
+                    "backward primitive");
+            reset(pd);
+        }
+    };
+
+    /// Default constructor. Produces an empty object.
+    sdpa_backward() = default;
+
+    /// Constructs a sdpa_backward primitive.
+    /// @param pd Primitive descriptor for a sdpa_backward primitive.
+ sdpa_backward(const primitive_desc &pd) : primitive(pd) {} +}; + } // namespace impl } // namespace dnnl diff --git a/tests/gtests/internals/test_sdpa.cpp b/tests/gtests/internals/test_sdpa.cpp index a0c576a0edb..34ac2533517 100644 --- a/tests/gtests/internals/test_sdpa.cpp +++ b/tests/gtests/internals/test_sdpa.cpp @@ -28,6 +28,8 @@ #include #include +#define DEBUG_PRINT_MEM 0 + using mdt = memory::data_type; using dnnl::accumulation_mode; @@ -267,6 +269,9 @@ struct sdpa_tensors_t { memory m_scale; // tested sdpa arg, can be host-side scalar memory m_scale_prim; // reference (prim) sdpa arg + memory m_diff_query, m_diff_key, m_diff_value, m_diff_output, m_dS; + memory m_diff_query_quantized, m_diff_key_quantized, m_diff_value_quantized; + memory m_key_scales, m_key_zp, m_value_scales, m_value_zp; dnnl::primitive_attr sdpa_attr_quantized, sdpa_kq_attr_quantized, sdpa_vs_attr_quantized; @@ -519,6 +524,7 @@ sdpa_tensors_t get_descriptors(dnnl::engine &eng, dnnl::stream &strm, = {p.mb, p.heads.kv, p.seq_len.kv * 2, p.head_group.head_size}; const memory::dims v_sz = {p.mb, p.heads.kv, p.seq_len.kv, p.head_group.head_size}; + const memory::dims dS_sz = {p.mb, p.heads.q, p.seq_len.q, p.seq_len.kv}; const memory::dims scale_sz = {1, 1, 1, 1}; bool with_host_scale = p.stype == scale_type::host_side; @@ -602,24 +608,45 @@ sdpa_tensors_t get_descriptors(dnnl::engine &eng, dnnl::stream &strm, auto mask_md = memory::desc(mask_sz, p.mask.dt != mdt::undef ? p.mask.dt : p.dt.dt, abcd); auto output_md = memory::desc(q_sz, p.dt.dt, abcd); auto output_quantized_md = memory::desc(q_sz, p.dt.dt, abcd); + + auto dS_md = memory::desc(dS_sz, p.dt.dt, abcd); // clang-format on // Create memory objects out.m_query = double_and_resize(query_md, eng, strm, doubled_memory); + out.m_diff_query = double_and_resize(query_md, eng, strm, doubled_memory); + out.m_diff_query_quantized + = double_and_resize(query_md, eng, strm, doubled_memory); + out.m_key_quantized = double_and_resize(key_quantized_md, eng, strm, doubled_memory); + out.m_diff_key + = double_and_resize(key_quantized_md, eng, strm, doubled_memory); + out.m_diff_key_quantized + = double_and_resize(key_quantized_md, eng, strm, doubled_memory); + out.m_key_scales = double_and_resize(key_scales_md, eng, strm, doubled_memory); out.m_key_zp = double_and_resize(key_zp_md, eng, strm, doubled_memory); + out.m_value_quantized = double_and_resize(val_quantized_md, eng, strm, doubled_memory); + out.m_diff_value + = double_and_resize(val_quantized_md, eng, strm, doubled_memory); + out.m_diff_value_quantized + = double_and_resize(val_quantized_md, eng, strm, doubled_memory); + out.m_value_scales = double_and_resize(val_scales_md, eng, strm, doubled_memory); out.m_value_zp = double_and_resize(val_zp_md, eng, strm, doubled_memory); out.m_mask = double_and_resize(mask_md, eng, strm, doubled_memory); + out.m_output = double_and_resize(output_md, eng, strm, doubled_memory); out.m_output_quantized = double_and_resize(output_quantized_md, eng, strm, doubled_memory); + out.m_diff_output = double_and_resize(output_md, eng, strm, doubled_memory); + + out.m_dS = double_and_resize(dS_md, eng, strm, doubled_memory); // Allocate user data. 
std::vector query_data(product(q_sz), 0.f); @@ -628,6 +655,7 @@ sdpa_tensors_t get_descriptors(dnnl::engine &eng, dnnl::stream &strm, std::vector val_quantized_data(product(v_sz), 0); std::vector key_scale_data(product(key_scales_sz), std::nanf("1")); std::vector val_scale_data(product(val_scales_sz), std::nanf("1")); + std::vector diff_output_data(product(q_sz), 0.f); std::vector key_zp_data_signed(product(key_scales_sz), INT_MAX); std::vector val_zp_data_signed(product(val_scales_sz), INT_MAX); @@ -638,6 +666,8 @@ sdpa_tensors_t get_descriptors(dnnl::engine &eng, dnnl::stream &strm, std::vector mask_data(product(mask_sz), NAN); std::vector output_data(product(q_sz), NAN); + std::vector dS_data(product(dS_sz), 0); + out.sdpa_attr_quantized.set_scratchpad_mode(dnnl::scratchpad_mode::library); out.kq_mask = 0; @@ -700,6 +730,7 @@ sdpa_tensors_t get_descriptors(dnnl::engine &eng, dnnl::stream &strm, } fill_random(query_data, query_md); + fill_random(diff_output_data, output_md); fill_random_quantized(key_quantized_data, key_quantized_md, (p.key.dt == mdt::u4 || p.key.dt == mdt::u8)); fill_random_quantized(val_quantized_data, val_quantized_md, @@ -912,6 +943,8 @@ sdpa_tensors_t get_descriptors(dnnl::engine &eng, dnnl::stream &strm, // Write data to tensor object's handle. write_to_dnnl_memory(query_data.data(), out.m_query, eng, strm); + write_to_dnnl_memory(diff_output_data.data(), out.m_diff_output, eng, strm); + write_to_dnnl_memory( key_quantized_data.data(), out.m_key_quantized, eng, strm); @@ -962,6 +995,7 @@ sdpa_tensors_t get_descriptors(dnnl::engine &eng, dnnl::stream &strm, setup_device_scale(&out.m_scale); setup_device_scale(&out.m_scale_prim); + write_to_dnnl_memory(dS_data.data(), out.m_dS, eng, strm); return out; } @@ -1163,7 +1197,7 @@ void prim_sdpa_quant(const sdpa_dims_t &p, const sdpa_tensors_t &t, = dnnl::reorder::primitive_desc(eng, score_f16_md, eng, score_md); auto f16_to_f32_prim = dnnl::reorder(f16_to_f32_pd); - // binary primitive for scaling (f32) + // binary primitive for scaling primitive_attr binary_attr; auto scale_algo = invert_scale ? 
algorithm::binary_div : algorithm::binary_mul; @@ -1309,6 +1343,7 @@ void prim_sdpa_quant(const sdpa_dims_t &p, const sdpa_tensors_t &t, loop(); strm.wait(); + void *output_ptr_ = (void *)output.map_data(); void *grouped_output_ptr_ = (void *)grouped_output.map_data(); memcpy(output_ptr_, grouped_output_ptr_, grouped_query_md.get_size()); @@ -1317,6 +1352,452 @@ void prim_sdpa_quant(const sdpa_dims_t &p, const sdpa_tensors_t &t, strm.wait(); } +std::vector timeit( + const std::function &func, dnnl::stream &str, int iterations) { + using namespace std::chrono; + func(); + func(); + std::vector times; + for (int j = 0; j < 5; j++) { + auto e = steady_clock::now(); + str.wait(); + auto s = steady_clock::now(); + for (int i = 0; i < iterations; i++) { + func(); + } + str.wait(); + e = steady_clock::now(); + times.push_back(std::chrono::duration_cast(e - s)); + } + return times; +} + +std::chrono::nanoseconds prim_sdpa_quant_bwd(const sdpa_dims_t &p, + const sdpa_tensors_t &t, dnnl::engine &eng, dnnl::stream &strm, + dnnl::memory &query, dnnl::memory &key, + dnnl::memory::data_type scale_dt, dnnl::memory &scale, + dnnl::memory &mask, dnnl::memory &value, dnnl::memory &output, + dnnl::memory &diff_output, bool invert_scale, + std::vector &doubled_memory, dnnl::memory &diff_query, + dnnl::memory &diff_key, dnnl::memory &diff_value, + bool with_timing = false) { + + using namespace dnnl; + + primitive_attr bmm1_attr; + bmm1_attr.set_scratchpad_mode(dnnl::scratchpad_mode::library); + + post_ops bmm1_po; + auto scale_f32 = as(strm, scale, mdt::f32); + auto mask_f32 = as(strm, mask, mdt::f32); + auto mask_sz = mask.get_desc().get_dims(); + + if (scale_dt != mdt::undef) { + scale_f32 = reshape(strm, scale_f32, + {{1, 1, 1, 1, 1}, mdt::f32, memory::format_tag::abcde}); + if (invert_scale) + bmm1_po.append_binary(algorithm::binary_div, scale_f32.get_desc()); + else + bmm1_po.append_binary(algorithm::binary_mul, scale_f32.get_desc()); + } + if (p.mask.type != mask_type::no_mask) { + mask_f32 = reshape(strm, mask_f32, + {{mask_sz[0], 1, 1, mask_sz[2], mask_sz[3]}, mdt::f32, + memory::format_tag::abcde}); + bmm1_po.append_binary(algorithm::binary_add, mask_f32.get_desc()); + } + bmm1_attr.set_post_ops(bmm1_po); + // Keep a copy of the 5D scale for the backward pass. + // bmm1_args will std::move scale_f32, so save it before that happens. + memory scale_f32_bwd = scale_f32; + + int head_kv_group_size = 0; + int head_q_group_size = 0; + int head_group_batches = 0; + if (p.heads.kv == p.heads.q) { + head_kv_group_size = p.heads.kv; + head_q_group_size = p.heads.q; + head_group_batches = 1; + } else { + head_kv_group_size = 1; + head_q_group_size = p.heads.q / p.heads.kv; + head_group_batches = p.heads.kv; + } + + auto original_k_sz = key.get_desc().get_dims(); + const memory::dims k_sz {p.mb, head_group_batches, head_kv_group_size, + original_k_sz[2], original_k_sz[3]}; + const memory::dims v_sz {p.mb, head_group_batches, head_kv_group_size, + p.seq_len.kv, p.head_group.head_size}; + const memory::dims q_sz {p.mb, head_group_batches, head_q_group_size, + p.seq_len.q, p.head_group.head_size}; + + memory::desc grouped_key_md(k_sz, p.dt.dt, memory::format_tag::abcde); + memory::desc grouped_value_md(v_sz, p.dt.dt, memory::format_tag::abcde); + memory::desc grouped_query_md(q_sz, p.dt.dt, memory::format_tag::abcde); + + memory key_dequantized; + auto keytmp = as(strm, key, p.dt.dt); + grouped_key_md = p.key_format_tag == memory::format_tag::abcd + ? 
memory::desc(k_sz, p.dt.dt, memory::format_tag::abcde) + : memory::desc(k_sz, p.dt.dt, memory::format_tag::abced); + + key_dequantized = reshape(strm, keytmp, grouped_key_md); + + memory value_dequantized; + auto value32 = as(strm, value, p.dt.dt); + value_dequantized = reshape(strm, value32, grouped_value_md); + + memory grouped_query = reshape(strm, query, grouped_query_md); + + const memory::dims score_sz = {p.mb, head_group_batches, head_q_group_size, + p.seq_len.q, p.seq_len.kv}; + memory::desc score_md {score_sz, mdt::f32, memory::format_tag::abcde}; + memory::desc score_f16_md {score_sz, mdt::f16, memory::format_tag::abcde}; + + auto score = memory(score_md, eng); + auto score_f16 = memory(score_f16_md, eng); + auto score2 = memory(score_md, eng); + auto score2_f16 = memory(score_f16_md, eng); + + // matmul primitive for QK + auto bmm1_pd = matmul::primitive_desc(eng, grouped_query_md, + key_dequantized.get_desc(), score_md, bmm1_attr); + auto bmm1_prim = matmul(bmm1_pd); + + // softmax forward + primitive_attr softmax_attr; + softmax_attr.set_scratchpad_mode(scratchpad_mode::library); + auto softmax_fwd_pd = softmax_forward::primitive_desc(eng, + prop_kind::forward_training, + (algorithm)dnnl::impl::alg_kind::softmax_accurate_inf_as_zero, + score.get_desc(), score.get_desc(), 4, softmax_attr); + auto softmax_prim = softmax_forward(softmax_fwd_pd); + + // attention_output = attention_probs x value + primitive_attr bmm2_attr; + + bmm2_attr.set_scratchpad_mode(scratchpad_mode::library); + auto grouped_output + = double_and_resize(grouped_query_md, eng, strm, doubled_memory); + auto bmm2_pd = matmul::primitive_desc( + eng, score_md, grouped_value_md, grouped_query_md, bmm2_attr); + auto bmm2_prim = matmul(bmm2_pd); + + // setup args + std::unordered_map bmm1_args = {{DNNL_ARG_SRC, grouped_query}, + {DNNL_ARG_WEIGHTS, key_dequantized}, {DNNL_ARG_DST, score}}; + + if (scale_dt != mdt::undef) { + bmm1_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1] + = std::move(scale_f32); + if (p.mask.type != mask_type::no_mask) { + bmm1_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(1) | DNNL_ARG_SRC_1] + = std::move(mask_f32); + } + } else { + if (p.mask.type != mask_type::no_mask) { + bmm1_args[DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1] + = std::move(mask_f32); + } + } + + std::unordered_map bmm2_args + = {{DNNL_ARG_SRC, score2}, {DNNL_ARG_WEIGHTS, value_dequantized}, + {DNNL_ARG_DST, grouped_output}}; + + const auto fwd_loop = [&]() { + // QK + bmm1_prim.execute(strm, bmm1_args); + + // softmax + softmax_prim.execute(strm, + { + {DNNL_ARG_SRC, score}, + {DNNL_ARG_DST, score2}, + }); + + // SV + bmm2_prim.execute(strm, bmm2_args); + }; + + // Warmup run. + // Execute primitives of sdpa. 
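+    // The reference path is kept deliberately unfused so the fused kernel
+    // can be validated against stock primitives: bmm1 (S = Q x K^T with
+    // scale/mask applied as post-ops) -> softmax -> bmm2 (O = P x V),
+    // followed by the matching unfused backward matmuls further below.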
+ fwd_loop(); + strm.wait(); +#if DEBUG_PRINT_MEM + print_mem(grouped_query, "FWD grouped_query"); + print_mem(key_dequantized, "FWD keq_deq"); + print_mem(scale, "FWD scale"); + if (p.mask.type != mask_type::no_mask) { print_mem(mask, "FWD mask"); } + print_mem(score, "FWD intermediate score"); + print_mem(score2, "FWD intermediate score2"); +#endif + + void *output_ptr_ = (void *)output.map_data(); + void *grouped_output_ptr_ = (void *)grouped_output.map_data(); + memcpy(output_ptr_, grouped_output_ptr_, grouped_query_md.get_size()); + grouped_output.unmap_data(grouped_output_ptr_); + output.unmap_data(output_ptr_); + + strm.wait(); + + // end forward portion + // backwards pass of SDPA + + // init incoming gradient, dO + memory::desc grouped_output_md = grouped_output.get_desc(); + memory dO_mem = reshape(strm, diff_output, grouped_output_md); + +#if DEBUG_PRINT_MEM + printf("\n\n================\n\n\n"); + print_mem(dO_mem, "BWD incoming dO"); +#endif + + ///////// final MM + // matmul forward, O = s2 * v + + // init v, score2 gradients + memory::desc diff_score2_md(score_sz, mdt::f32, memory::format_tag::abcde); + memory dV_mem(grouped_value_md, eng); + memory diff_score2_mem(diff_score2_md, eng); + + // backwards pass gradient of s2 (dS2 = dO * v^t) + memory::desc v_t_md + = memory::desc({v_sz[0], v_sz[1], v_sz[2], v_sz[4], v_sz[3]}, + /*dO.get_dt*/ p.dt.dt, memory::format_tag::abced); + matmul::primitive_desc mm_bwd_dS2_pd( + eng, grouped_output_md, v_t_md, diff_score2_md); + matmul mm_bwd_dS2(mm_bwd_dS2_pd); + mm_bwd_dS2.execute(strm, + {{DNNL_ARG_SRC, dO_mem}, {DNNL_ARG_WEIGHTS, value_dequantized}, + {DNNL_ARG_DST, diff_score2_mem}}); + strm.wait(); + +#if DEBUG_PRINT_MEM + print_mem(diff_score2_mem, "BWD dS2"); +#endif + + // backwards pass gradient of v (dV = s2^t * dO) + + // downcast score to p.dt.dt + memory::desc s2_t_md = memory::desc( + {score_sz[0], score_sz[1], score_sz[2], score_sz[4], score_sz[3]}, + p.dt.dt, memory::format_tag::abced); + + memory score2_downcast; + auto score32 = as(strm, score2, p.dt.dt); + score2_downcast = reshape(strm, score32, s2_t_md); + + const bool is_gqa = (head_q_group_size != head_kv_group_size); + memory::dims dV_full_dims = {v_sz[0], v_sz[1], + is_gqa ? head_q_group_size : head_kv_group_size, v_sz[3], v_sz[4]}; + memory::desc dV_full_md(dV_full_dims, p.dt.dt, memory::format_tag::abcde); + memory dV_full_mem = is_gqa ? 
memory(dV_full_md, eng) : dV_mem; + matmul::primitive_desc mm_bwd_dV_pd( + eng, s2_t_md, grouped_output_md, dV_full_mem.get_desc()); + matmul mm_bwd_dV(mm_bwd_dV_pd); + mm_bwd_dV.execute(strm, + {{DNNL_ARG_SRC, score2_downcast}, {DNNL_ARG_WEIGHTS, dO_mem}, + {DNNL_ARG_DST, dV_full_mem}}); + strm.wait(); + // reduce [mb, groups, hq_group, kv, D] -> [mb, groups, 1, kv, D] for GQA + dnnl::reduction dV_reduce; + if (is_gqa) { + dnnl::reduction::primitive_desc dV_reduce_pd(eng, + algorithm::reduction_sum, dV_full_md, grouped_value_md, 0.f, + 0.f); + dV_reduce = dnnl::reduction(dV_reduce_pd); + dV_reduce.execute( + strm, {{DNNL_ARG_SRC, dV_full_mem}, {DNNL_ARG_DST, dV_mem}}); + strm.wait(); + } + +#if DEBUG_PRINT_MEM + print_mem(dV_mem, "BWD dV"); +#endif + + ///////// intermediate softmax(QK) + + // compute softmax backward and scale in f32 to match kernel/graph + memory::desc diff_score_f32_md( + score_sz, mdt::f32, memory::format_tag::abcde); + memory diff_score_f32_mem(diff_score_f32_md, eng); + + // backwards pass gradient of softmax + softmax_backward::primitive_desc softmax_bwd_pd(eng, + algorithm::softmax_accurate, diff_score_f32_md, diff_score2_md, + score_md, 4, softmax_fwd_pd); + softmax_backward softmax_bwd(softmax_bwd_pd); + softmax_bwd.execute(strm, + {{DNNL_ARG_DST, score2}, {DNNL_ARG_DIFF_DST, diff_score2_mem}, + {DNNL_ARG_DIFF_SRC, diff_score_f32_mem}}); + + binary scale_bwd_prim; + if (scale_dt != mdt::undef) { + auto scale_algo_bwd + = invert_scale ? algorithm::binary_div : algorithm::binary_mul; + auto scale_bwd_pd = binary::primitive_desc(eng, scale_algo_bwd, + diff_score_f32_md, scale_f32_bwd.get_desc(), diff_score_f32_md); + scale_bwd_prim = binary(scale_bwd_pd); + scale_bwd_prim.execute(strm, + {{DNNL_ARG_SRC_0, diff_score_f32_mem}, + {DNNL_ARG_SRC_1, scale_f32_bwd}, + {DNNL_ARG_DST, diff_score_f32_mem}}); + } + + // downcast dS from f32 to p.dt.dt for dQ/dK matmuls + memory::desc diff_score_md(score_sz, p.dt.dt, memory::format_tag::abcde); + memory diff_score_mem(diff_score_md, eng); + dnnl::reorder diff_score_cast(diff_score_f32_mem, diff_score_mem); + diff_score_cast.execute(strm, diff_score_f32_mem, diff_score_mem); + + strm.wait(); + // print softmax gradient +#if DEBUG_PRINT_MEM + print_mem(diff_score_mem, "BWD dS"); +#endif + + ///////// first MM + // matmul forward, s = q * k + + // init q,k gradients + memory dQ_mem(grouped_query_md, eng); + memory dK_mem(key_dequantized.get_desc(), eng); + + // backwards pass gradient of q (dQ = dS * k^t) + // k^t requires transposed format and dims + auto k_t_fmt = p.key_format_tag == memory::format_tag::abcd + ? memory::format_tag::abced + : memory::format_tag::abcde; + memory::desc k_t_md = memory::desc( + {k_sz[0], k_sz[1], k_sz[2], k_sz[4], k_sz[3]}, p.dt.dt, k_t_fmt); + matmul::primitive_desc mm_bwd_dq_pd( + eng, diff_score_md, k_t_md, grouped_query_md); + matmul mm_bwd_dq(mm_bwd_dq_pd); + mm_bwd_dq.execute(strm, + {{DNNL_ARG_SRC, diff_score_mem}, + {DNNL_ARG_WEIGHTS, key_dequantized}, + {DNNL_ARG_DST, dQ_mem}}); + strm.wait(); + + // backwards pass gradient of k (dK = q^t * dS) + memory::desc q_t_md({q_sz[0], q_sz[1], q_sz[2], q_sz[4], q_sz[3]}, p.dt.dt, + memory::format_tag::abced); + + // for GQA cases, perform additional reduction of dK along headq_group + auto key_fmt = p.key_format_tag == memory::format_tag::abcd + ? memory::format_tag::abcde + : memory::format_tag::abced; + memory::dims dK_full_dims = {k_sz[0], k_sz[1], + is_gqa ? 
head_q_group_size : head_kv_group_size, k_sz[3], k_sz[4]}; + memory::desc dK_full_md(dK_full_dims, p.dt.dt, key_fmt); + memory dK_full_mem = is_gqa ? memory(dK_full_md, eng) : dK_mem; + matmul::primitive_desc mm_bwd_dk_pd( + eng, q_t_md, diff_score_md, dK_full_mem.get_desc()); + matmul mm_bwd_dk(mm_bwd_dk_pd); + mm_bwd_dk.execute(strm, + {{DNNL_ARG_SRC, grouped_query}, {DNNL_ARG_WEIGHTS, diff_score_mem}, + {DNNL_ARG_DST, dK_full_mem}}); + strm.wait(); + // reduce [mb, groups, hq_group, D, K] -> [mb, groups, 1, D, K] for GQA cases + dnnl::reduction dK_reduce; + if (is_gqa) { + dnnl::reduction::primitive_desc dK_reduce_pd(eng, + algorithm::reduction_sum, dK_full_md, + key_dequantized.get_desc(), 0.f, 0.f); + dK_reduce = dnnl::reduction(dK_reduce_pd); + dK_reduce.execute( + strm, {{DNNL_ARG_SRC, dK_full_mem}, {DNNL_ARG_DST, dK_mem}}); + strm.wait(); + } + + const auto bwd_loop = [&]() { + strm.wait(); + mm_bwd_dS2.execute(strm, + {{DNNL_ARG_SRC, dO_mem}, {DNNL_ARG_WEIGHTS, value_dequantized}, + {DNNL_ARG_DST, diff_score2_mem}}); + mm_bwd_dV.execute(strm, + {{DNNL_ARG_SRC, score2_downcast}, {DNNL_ARG_WEIGHTS, dO_mem}, + {DNNL_ARG_DST, dV_full_mem}}); + if (is_gqa) + dV_reduce.execute(strm, + {{DNNL_ARG_SRC, dV_full_mem}, {DNNL_ARG_DST, dV_mem}}); + + softmax_bwd.execute(strm, + {{DNNL_ARG_DST, score2}, {DNNL_ARG_DIFF_DST, diff_score2_mem}, + {DNNL_ARG_DIFF_SRC, diff_score_f32_mem}}); + + if (scale_dt != mdt::undef) { + scale_bwd_prim.execute(strm, + {{DNNL_ARG_SRC_0, diff_score_f32_mem}, + {DNNL_ARG_SRC_1, scale_f32_bwd}, + {DNNL_ARG_DST, diff_score_f32_mem}}); + } + + diff_score_cast.execute(strm, diff_score_f32_mem, diff_score_mem); + + mm_bwd_dq.execute(strm, + {{DNNL_ARG_SRC, diff_score_mem}, + {DNNL_ARG_WEIGHTS, key_dequantized}, + {DNNL_ARG_DST, dQ_mem}}); + mm_bwd_dk.execute(strm, + {{DNNL_ARG_SRC, grouped_query}, + {DNNL_ARG_WEIGHTS, diff_score_mem}, + {DNNL_ARG_DST, dK_full_mem}}); + if (is_gqa) + dK_reduce.execute(strm, + {{DNNL_ARG_SRC, dK_full_mem}, {DNNL_ARG_DST, dK_mem}}); + strm.wait(); + }; + bwd_loop(); + strm.wait(); + + using namespace std::chrono; + nanoseconds qtime_bwd {}; + + if (with_timing) { + auto loop_bwd_prim = [&] { bwd_loop(); }; + + int iterations = 20; + auto quantized_bwd_time = timeit(loop_bwd_prim, strm, iterations); + + auto min_time = [](const std::vector &a) { + return *std::min_element(a.begin(), a.end()); + }; + + qtime_bwd = min_time(quantized_bwd_time) / iterations; + printf("qtime_training_backwards_prim %f\n", (float)qtime_bwd.count()); + } + + // print q,k gradients +#if DEBUG_PRINT_MEM + print_mem(dQ_mem, "BWD dQ"); + print_mem(dK_mem, "BWD dK"); +#endif + + // reshape gradients from grouped layout back to 4d + void *diff_query_ptr = (void *)diff_query.map_data(); + void *dQ_ptr = (void *)dQ_mem.map_data(); + memcpy(diff_query_ptr, dQ_ptr, grouped_query_md.get_size()); + dQ_mem.unmap_data(dQ_ptr); + diff_query.unmap_data(diff_query_ptr); + + void *diff_key_ptr = (void *)diff_key.map_data(); + void *dK_ptr = (void *)dK_mem.map_data(); + memcpy(diff_key_ptr, dK_ptr, key_dequantized.get_desc().get_size()); + dK_mem.unmap_data(dK_ptr); + diff_key.unmap_data(diff_key_ptr); + + void *diff_value_ptr = (void *)diff_value.map_data(); + void *dV_ptr = (void *)dV_mem.map_data(); + memcpy(diff_value_ptr, dV_ptr, grouped_value_md.get_size()); + dV_mem.unmap_data(dV_ptr); + diff_value.unmap_data(diff_value_ptr); + + return qtime_bwd; +} + template void check_memory(dnnl::stream &strm, memory &gold, memory &test, float max_diff_threshold = 0.03f, float 
fthreshold = 0.001466) { @@ -1380,7 +1861,7 @@ void check_memory(dnnl::stream &strm, memory &gold, memory &test, gold.unmap_data(mapped_ptr_gold); test.unmap_data(mapped_ptr_test); - int threshold = total * 0.0006; + int threshold = total * 0.002; ASSERT_LE(mismatches, threshold) << mismatches << " out of: " << total; ASSERT_LE(max_diff, max_diff_threshold); @@ -1397,26 +1878,6 @@ int to_attn_mask_type(mask_type t) { return static_cast(attn_mask); } -std::vector timeit( - const std::function &func, dnnl::stream &str, int iterations) { - using namespace std::chrono; - func(); - func(); - std::vector times; - for (int j = 0; j < 5; j++) { - auto e = steady_clock::now(); - str.wait(); - auto s = steady_clock::now(); - for (int i = 0; i < iterations; i++) { - func(); - } - str.wait(); - e = steady_clock::now(); - times.push_back(std::chrono::duration_cast(e - s)); - } - return times; -} - template O magnitude_cast(I input) { using ratio = std::ratio_divide; @@ -1553,8 +2014,8 @@ class sdpa_test_t : public ::testing::TestWithParam { t.m_scale.get_desc(), t.m_output_quantized.get_desc(), invert_scale, p.heads.kv, to_attn_mask_type(p.mask.type), dnnl::impl::alg_kind::softmax_accurate_inf_as_zero, - t.sdpa_attr_quantized, t.sdpa_kq_attr_quantized, - t.sdpa_vs_attr_quantized); + prop_kind::forward_inference, t.sdpa_attr_quantized, + t.sdpa_kq_attr_quantized, t.sdpa_vs_attr_quantized); sdpa_quantized_p = sdpa(sdpa_quantized_pd); } catch (const dnnl::error &e) { if (e.status == dnnl_unimplemented) @@ -1642,6 +2103,184 @@ class sdpa_test_t : public ::testing::TestWithParam { #endif } + void compare_bwd() { + using namespace dnnl::impl; + auto mask = t.m_mask.get_desc(); + + memory::desc *mask_ptr = nullptr; + + switch (p.mask.type) { + case mask_type::no_mask: + case mask_type::causal_tl: + case mask_type::causal_br: mask_ptr = nullptr; break; + case mask_type::oneD: + case mask_type::twoD: mask_ptr = &mask; break; + } + + auto dS_desc = t.m_dS.get_desc(); + memory::desc *dS_ptr = nullptr; + //dS_ptr = &dS_desc; // uncomment for optional dS output + + // fwd sdpa primitive to populate dst, col_maxes + sdpa::primitive_desc sdpa_fwd_pd; + sdpa sdpa_fwd; + + sdpa_backward::primitive_desc sdpa_bwd_pd; + sdpa_backward sdpa_bwd; + try { + sdpa_fwd_pd = sdpa::primitive_desc(eng, t.m_query.get_desc(), + t.m_key_quantized.get_desc(), + t.m_value_quantized.get_desc(), mask_ptr, + t.m_scale.get_desc(), t.m_output_quantized.get_desc(), + invert_scale, p.heads.kv, to_attn_mask_type(p.mask.type), + dnnl::impl::alg_kind::softmax_accurate_inf_as_zero, + prop_kind::forward_training, t.sdpa_attr_quantized, + t.sdpa_kq_attr_quantized, t.sdpa_vs_attr_quantized); + sdpa_fwd = sdpa(sdpa_fwd_pd); + + sdpa_bwd_pd = sdpa_backward::primitive_desc(eng, + t.m_query.get_desc(), t.m_key_quantized.get_desc(), + t.m_value_quantized.get_desc(), mask_ptr, + t.m_scale.get_desc(), t.m_output_quantized.get_desc(), + t.m_diff_query_quantized.get_desc(), + t.m_diff_key_quantized.get_desc(), + t.m_diff_value_quantized.get_desc(), + t.m_diff_output.get_desc(), dS_ptr, invert_scale, + p.heads.kv, to_attn_mask_type(p.mask.type), + dnnl::impl::alg_kind::softmax_accurate_inf_as_zero, + sdpa_fwd_pd, t.sdpa_attr_quantized); + sdpa_bwd = sdpa_backward(sdpa_bwd_pd); + } catch (const dnnl::error &e) { + if (e.status == dnnl_unimplemented) + GTEST_SKIP() << "Unimplemented: " << e.what(); + else + throw; + } + + auto sdpa_fwd_workspace_memory + = memory(sdpa_fwd_pd.workspace_desc(), eng); + + std::unordered_map sdpa_fwd_args + = {{DNNL_ARG_QUERIES, 
@@ -1553,8 +2014,8 @@ class sdpa_test_t : public ::testing::TestWithParam {
                     t.m_scale.get_desc(), t.m_output_quantized.get_desc(),
                     invert_scale, p.heads.kv, to_attn_mask_type(p.mask.type),
                     dnnl::impl::alg_kind::softmax_accurate_inf_as_zero,
-                    t.sdpa_attr_quantized, t.sdpa_kq_attr_quantized,
-                    t.sdpa_vs_attr_quantized);
+                    prop_kind::forward_inference, t.sdpa_attr_quantized,
+                    t.sdpa_kq_attr_quantized, t.sdpa_vs_attr_quantized);
             sdpa_quantized_p = sdpa(sdpa_quantized_pd);
         } catch (const dnnl::error &e) {
             if (e.status == dnnl_unimplemented)
@@ -1642,6 +2103,184 @@ class sdpa_test_t : public ::testing::TestWithParam {
 #endif
     }
 
+    void compare_bwd() {
+        using namespace dnnl::impl;
+        auto mask = t.m_mask.get_desc();
+
+        memory::desc *mask_ptr = nullptr;
+
+        switch (p.mask.type) {
+            case mask_type::no_mask:
+            case mask_type::causal_tl:
+            case mask_type::causal_br: mask_ptr = nullptr; break;
+            case mask_type::oneD:
+            case mask_type::twoD: mask_ptr = &mask; break;
+        }
+
+        auto dS_desc = t.m_dS.get_desc();
+        memory::desc *dS_ptr = nullptr;
+        // dS_ptr = &dS_desc; // uncomment to request the optional dS output
+
+        // fwd sdpa primitive to populate dst, col_maxes
+        sdpa::primitive_desc sdpa_fwd_pd;
+        sdpa sdpa_fwd;
+
+        sdpa_backward::primitive_desc sdpa_bwd_pd;
+        sdpa_backward sdpa_bwd;
+        try {
+            sdpa_fwd_pd = sdpa::primitive_desc(eng, t.m_query.get_desc(),
+                    t.m_key_quantized.get_desc(),
+                    t.m_value_quantized.get_desc(), mask_ptr,
+                    t.m_scale.get_desc(), t.m_output_quantized.get_desc(),
+                    invert_scale, p.heads.kv, to_attn_mask_type(p.mask.type),
+                    dnnl::impl::alg_kind::softmax_accurate_inf_as_zero,
+                    prop_kind::forward_training, t.sdpa_attr_quantized,
+                    t.sdpa_kq_attr_quantized, t.sdpa_vs_attr_quantized);
+            sdpa_fwd = sdpa(sdpa_fwd_pd);
+
+            sdpa_bwd_pd = sdpa_backward::primitive_desc(eng,
+                    t.m_query.get_desc(), t.m_key_quantized.get_desc(),
+                    t.m_value_quantized.get_desc(), mask_ptr,
+                    t.m_scale.get_desc(), t.m_output_quantized.get_desc(),
+                    t.m_diff_query_quantized.get_desc(),
+                    t.m_diff_key_quantized.get_desc(),
+                    t.m_diff_value_quantized.get_desc(),
+                    t.m_diff_output.get_desc(), dS_ptr, invert_scale,
+                    p.heads.kv, to_attn_mask_type(p.mask.type),
+                    dnnl::impl::alg_kind::softmax_accurate_inf_as_zero,
+                    sdpa_fwd_pd, t.sdpa_attr_quantized);
+            sdpa_bwd = sdpa_backward(sdpa_bwd_pd);
+        } catch (const dnnl::error &e) {
+            if (e.status == dnnl_unimplemented)
+                GTEST_SKIP() << "Unimplemented: " << e.what();
+            else
+                throw;
+        }
+
+        auto sdpa_fwd_workspace_memory
+                = memory(sdpa_fwd_pd.workspace_desc(), eng);
+
+        std::unordered_map<int, memory> sdpa_fwd_args
+                = {{DNNL_ARG_QUERIES, t.m_query},
+                        {DNNL_ARG_KEYS, t.m_key_quantized},
+                        {DNNL_ARG_VALUES, t.m_value_quantized},
+                        {DNNL_ARG_DST, t.m_output_quantized},
+                        {DNNL_ARG_WORKSPACE, sdpa_fwd_workspace_memory}};
+
+        if (scale_dt != mdt::undef) {
+            sdpa_fwd_args[DNNL_ARG_SCALE] = t.m_scale;
+        }
+        if (mask_ptr) { sdpa_fwd_args[DNNL_ARG_ATTN_MASK] = t.m_mask; }
+
+        strm.wait();
+        sdpa_fwd.execute(strm, sdpa_fwd_args);
+        strm.wait();
+
+        std::unordered_map<int, memory> sdpa_bwd_args
+                = {{DNNL_ARG_QUERIES, t.m_query},
+                        {DNNL_ARG_KEYS, t.m_key_quantized},
+                        {DNNL_ARG_VALUES, t.m_value_quantized},
+                        {DNNL_ARG_DST, t.m_output_quantized},
+                        {DNNL_ARG_DIFF_DST, t.m_diff_output},
+                        {DNNL_ARG_DIFF_QUERIES, t.m_diff_query_quantized},
+                        {DNNL_ARG_DIFF_KEYS, t.m_diff_key_quantized},
+                        {DNNL_ARG_DIFF_VALUES, t.m_diff_value_quantized},
+                        {DNNL_ARG_DS, t.m_dS},
+                        {DNNL_ARG_WORKSPACE, sdpa_fwd_workspace_memory}};
+
+#if DEBUG_PRINT_MEM
+        print_mem(t.m_query, "input t_m_query");
+        print_mem(t.m_key_quantized, "input t_m_key");
+        print_mem(t.m_value_quantized, "input t_m_value");
+        print_mem(t.m_diff_output, "input dA");
+#endif
+
+        if (scale_dt != mdt::undef) {
+            sdpa_bwd_args[DNNL_ARG_SCALE] = t.m_scale;
+        }
+        if (mask_ptr) { sdpa_bwd_args[DNNL_ARG_ATTN_MASK] = t.m_mask; }
+
+        sdpa_bwd.execute(strm, sdpa_bwd_args);
+        strm.wait();
+
+#if DEBUG_PRINT_MEM
+        print_mem(t.m_dS, "computed dS");
+        print_mem(t.m_diff_value_quantized, "dV bwd out");
+        printf("-------------- Primitives based implementation --------------"
+               "\n");
+#endif
+
+        // run the primitive-based backward SDPA pass to generate the "gold"
+        // gradient outputs
+        prim_sdpa_quant_bwd(p, t, eng, strm, t.m_query, t.m_key_quantized,
+                scale_dt, t.m_scale, t.m_mask, t.m_value_quantized, t.m_output,
+                t.m_diff_output, invert_scale, doubled_memory, t.m_diff_query,
+                t.m_diff_key, t.m_diff_value);
+
+        float max_diff_threshold = 0.3f;
+        // Backward thresholds are looser than forward ones: the gradients
+        // chain several matmuls, the softmax backward term S * (dP - Di) can
+        // cancel catastrophically, and dQ is accumulated via atomic adds.
+        float fthreshold = 0.f;
+        if (p.dt.dt == mdt::bf16) {
+            fthreshold = 0.1f;
+        } else if (p.dt.dt == mdt::f16) {
+            fthreshold = 0.035f;
+        } else {
+            fthreshold = 0.002f;
+        }
+
+        strm.wait();
+
+        if (t.m_output.get_desc().get_data_type() == mdt::f16) {
+            check_memory<float16_t>(strm, t.m_output, t.m_output_quantized,
+                    max_diff_threshold, fthreshold);
+            check_memory<float16_t>(strm, t.m_diff_query,
+                    t.m_diff_query_quantized, max_diff_threshold, fthreshold);
+            check_memory<float16_t>(strm, t.m_diff_key,
+                    t.m_diff_key_quantized, max_diff_threshold, fthreshold);
+            check_memory<float16_t>(strm, t.m_diff_value,
+                    t.m_diff_value_quantized, max_diff_threshold, fthreshold);
+
+        } else if (t.m_output.get_desc().get_data_type() == mdt::bf16) {
+            check_memory<bfloat16_t>(strm, t.m_output, t.m_output_quantized,
+                    max_diff_threshold, fthreshold);
+            check_memory<bfloat16_t>(strm, t.m_diff_query,
+                    t.m_diff_query_quantized, max_diff_threshold, fthreshold);
+            check_memory<bfloat16_t>(strm, t.m_diff_key,
+                    t.m_diff_key_quantized, max_diff_threshold, fthreshold);
+            check_memory<bfloat16_t>(strm, t.m_diff_value,
+                    t.m_diff_value_quantized, max_diff_threshold, fthreshold);
+
+        } else if (t.m_output.get_desc().get_data_type() == mdt::f32) {
+            check_memory<float>(strm, t.m_output, t.m_output_quantized,
+                    max_diff_threshold, fthreshold);
+            check_memory<float>(strm, t.m_diff_query,
+                    t.m_diff_query_quantized, max_diff_threshold, fthreshold);
+            check_memory<float>(strm, t.m_diff_key, t.m_diff_key_quantized,
+                    max_diff_threshold, fthreshold);
+            check_memory<float>(strm, t.m_diff_value,
+                    t.m_diff_value_quantized, max_diff_threshold, fthreshold);
+        }
+
+#if DEBUG_PRINT_MEM
+        print_mem(t.m_output, "gold m_output");
+        print_mem(t.m_output_quantized, "test m_output");
+        printf("----------\n\n");
+
+        print_mem(t.m_diff_query, "gold m_diff_query");
+        print_mem(t.m_diff_query_quantized, "test m_diff_query");
+        printf("----------\n\n");
+
+        print_mem(t.m_diff_key, "gold m_diff_key");
+        print_mem(t.m_diff_key_quantized, "test m_diff_key");
+        printf("----------\n\n");
+
+        print_mem(t.m_diff_value, "gold m_diff_value");
+        print_mem(t.m_diff_value_quantized, "test m_diff_value");
+        printf("----------\n\n");
+#endif
+    }
+
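The three branches above differ only in the element type passed to check_memory. A hypothetical dispatcher could collapse them; sketch only, assuming the float16_t / bfloat16_t element types instantiated by check_memory are visible in this translation unit:

    // Sketch, not part of the patch: dtype-to-tag dispatch for check_memory.
    template <typename F>
    void for_each_checked_dt(dnnl::memory::data_type dt, F &&f) {
        using dtt = dnnl::memory::data_type;
        switch (dt) {
            case dtt::f16: f(float16_t {}); break;   // assumed tag type
            case dtt::bf16: f(bfloat16_t {}); break; // assumed tag type
            case dtt::f32: f(float {}); break;
            default: break;
        }
    }
    // Usage:
    //   for_each_checked_dt(dt, [&](auto tag) {
    //       check_memory<decltype(tag)>(strm, gold, test, max_diff, fthr);
    //   });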
     void perf() {
         using namespace dnnl::impl;
         auto mask = t.m_mask.get_desc();
 
@@ -1665,8 +2304,8 @@ class sdpa_test_t : public ::testing::TestWithParam {
                     t.m_scale.get_desc(), t.m_output_quantized.get_desc(),
                     invert_scale, p.heads.kv, to_attn_mask_type(p.mask.type),
                     alg_kind::softmax_accurate_inf_as_zero,
-                    t.sdpa_attr_quantized, t.sdpa_kq_attr_quantized,
-                    t.sdpa_vs_attr_quantized);
+                    prop_kind::forward_inference, t.sdpa_attr_quantized,
+                    t.sdpa_kq_attr_quantized, t.sdpa_vs_attr_quantized);
             sdpa_quantized_p = sdpa(sdpa_quantized_pd);
         } catch (const dnnl::error &e) {
             if (e.status == dnnl_unimplemented)
@@ -1788,6 +2427,194 @@ class sdpa_test_t : public ::testing::TestWithParam {
                 << "|" << std::endl;
     }
 
+    void perf_bwd(bool time_reference = false) {
+        using namespace dnnl::impl;
+        auto mask = t.m_mask.get_desc();
+
+        memory::desc *mask_ptr = nullptr;
+
+        switch (p.mask.type) {
+            case mask_type::no_mask:
+            case mask_type::causal_tl:
+            case mask_type::causal_br: mask_ptr = nullptr; break;
+            case mask_type::oneD:
+            case mask_type::twoD: mask_ptr = &mask; break;
+        }
+
+        // Forward training pass (needed for workspace)
+        sdpa::primitive_desc sdpa_fwd_pd;
+        sdpa sdpa_fwd;
+
+        sdpa_backward::primitive_desc sdpa_bwd_pd;
+        sdpa_backward sdpa_bwd;
+        try {
+            sdpa_fwd_pd = sdpa::primitive_desc(eng, t.m_query.get_desc(),
+                    t.m_key_quantized.get_desc(),
+                    t.m_value_quantized.get_desc(), mask_ptr,
+                    t.m_scale.get_desc(), t.m_output_quantized.get_desc(),
+                    invert_scale, p.heads.kv, to_attn_mask_type(p.mask.type),
+                    dnnl::impl::alg_kind::softmax_accurate_inf_as_zero,
+                    prop_kind::forward_training, t.sdpa_attr_quantized,
+                    t.sdpa_kq_attr_quantized, t.sdpa_vs_attr_quantized);
+            sdpa_fwd = sdpa(sdpa_fwd_pd);
+
+            sdpa_bwd_pd = sdpa_backward::primitive_desc(eng,
+                    t.m_query.get_desc(), t.m_key_quantized.get_desc(),
+                    t.m_value_quantized.get_desc(), mask_ptr,
+                    t.m_scale.get_desc(), t.m_output_quantized.get_desc(),
+                    t.m_diff_query_quantized.get_desc(),
+                    t.m_diff_key_quantized.get_desc(),
+                    t.m_diff_value_quantized.get_desc(),
+                    t.m_diff_output.get_desc(), nullptr, invert_scale,
+                    p.heads.kv, to_attn_mask_type(p.mask.type),
+                    dnnl::impl::alg_kind::softmax_accurate_inf_as_zero,
+                    sdpa_fwd_pd, t.sdpa_attr_quantized);
+            sdpa_bwd = sdpa_backward(sdpa_bwd_pd);
+        } catch (const dnnl::error &e) {
+            if (e.status == dnnl_unimplemented)
+                GTEST_SKIP() << "Unimplemented: " << e.what();
+            else
+                throw;
+        }
+
+        // Execute forward pass to populate workspace
+        auto sdpa_fwd_workspace_memory
+                = memory(sdpa_fwd_pd.workspace_desc(), eng);
+
+        std::unordered_map<int, memory> sdpa_fwd_args
+                = {{DNNL_ARG_QUERIES, t.m_query},
+                        {DNNL_ARG_KEYS, t.m_key_quantized},
+                        {DNNL_ARG_VALUES, t.m_value_quantized},
+                        {DNNL_ARG_DST, t.m_output_quantized},
+                        {DNNL_ARG_WORKSPACE, sdpa_fwd_workspace_memory}};
+
+        if (scale_dt != mdt::undef) {
+            sdpa_fwd_args[DNNL_ARG_SCALE] = t.m_scale;
+        }
+        if (mask_ptr) { sdpa_fwd_args[DNNL_ARG_ATTN_MASK] = t.m_mask; }
+
+        sdpa_fwd.execute(strm, sdpa_fwd_args);
+        strm.wait();
+
+        // Build backward args
+        std::unordered_map<int, memory> sdpa_bwd_args
+                = {{DNNL_ARG_QUERIES, t.m_query},
+                        {DNNL_ARG_KEYS, t.m_key_quantized},
+                        {DNNL_ARG_VALUES, t.m_value_quantized},
+                        {DNNL_ARG_DST, t.m_output_quantized},
+                        {DNNL_ARG_DIFF_DST, t.m_diff_output},
+                        {DNNL_ARG_DIFF_QUERIES, t.m_diff_query_quantized},
+                        {DNNL_ARG_DIFF_KEYS, t.m_diff_key_quantized},
+                        {DNNL_ARG_DIFF_VALUES, t.m_diff_value_quantized},
+                        {DNNL_ARG_WORKSPACE, sdpa_fwd_workspace_memory}};
+
+        if (scale_dt != mdt::undef) {
+            sdpa_bwd_args[DNNL_ARG_SCALE] = t.m_scale;
+        }
+        if (mask_ptr) { sdpa_bwd_args[DNNL_ARG_ATTN_MASK] = t.m_mask; }
+
+        // Time backward pass
+        auto loop_bwd = [&] { sdpa_bwd.execute(strm, sdpa_bwd_args); };
+
+        int iterations = 20;
+        auto bwd_time = timeit(loop_bwd, strm, iterations);
+
+        using namespace std::chrono;
+        auto min_time = [](const std::vector<nanoseconds> &a) {
+            return *std::min_element(a.begin(), a.end());
+        };
+
+        auto qtime = min_time(bwd_time) / iterations;
+        printf("qtimebwd %f\n", (float)qtime.count() / 1e6); // ns -> ms
+
+        // Backward reads: Q, K, V, O, dO, workspace (logsumexp)
+        // Backward writes: dQ, dK, dV
+        byte_t<> total_bytes = t.m_query.get_desc().get_size()
+                + t.m_key_quantized.get_desc().get_size()
+                + t.m_value_quantized.get_desc().get_size()
+                + t.m_output_quantized.get_desc().get_size()
+                + t.m_diff_output.get_desc().get_size()
+                + t.m_diff_query_quantized.get_desc().get_size()
+                + t.m_diff_key_quantized.get_desc().get_size()
+                + t.m_diff_value_quantized.get_desc().get_size()
+                + (mask_ptr ? t.m_mask.get_desc().get_size() : 0);
+
+        size_t kv_slice_tensor_elements
+                = (p.head_group.head_size * p.seq_len.kv);
+        size_t mask_slice_elements = 0;
+        switch (p.mask.type) {
+            case mask_type::twoD:
+                mask_slice_elements = p.seq_len.kv * p.seq_len.q;
+                break;
+            case mask_type::oneD: mask_slice_elements = p.seq_len.kv; break;
+            default: mask_slice_elements = 0; break;
+        }
+        size_t batch_elements = p.mb * std::max(p.heads.q, p.heads.kv);
+
+        // Effective bytes: expand broadcast tensors to full batch size.
+        // Reads: Q, K, V, O, dO (all expanded); writes: dQ, dK, dV
+        byte_t<> total_bytes_effective = (batch_elements
+                * (byte_t<>(p.key.dt) * kv_slice_tensor_elements
+                        + byte_t<>(p.value.dt) * kv_slice_tensor_elements
+                        + byte_t<>(p.dt.dt)
+                                * (5 * p.head_group.head_size * p.seq_len.q)
+                        + byte_t<>(p.dt.dt) * (2 * kv_slice_tensor_elements)
+                        + (mask_ptr ? byte_t<>(p.mask.dt) * mask_slice_elements
+                                    : 0)));
+
+        // Backward has 5 matmuls: S = K*Q, dS2 = V*dO, dV = S*dO, dK = dS*Q,
+        // dQ = dS*K, plus softmax_bwd, scale, and the Di reduction.
+        float causal_divisor = (p.mask.type == mask_type::causal_tl
+                                       || p.mask.type == mask_type::causal_br)
+                ? 2.f
+                : 1.f;
+
+        // 5 matmuls; the flash-attention paper scales FWD flops by 5/2,
+        // plus softmax_bwd (~5*K*Q), scale, and Di
+        num_ops_t<> total_flops = std::max(p.heads.kv, p.heads.q) * p.mb
+                * (2.f
+                                * (2.5f * 2.f * p.head_group.head_size
+                                        * p.seq_len.kv * p.seq_len.q)
+                        + (scale_dt != mdt::undef
+                                        ? (2 * p.seq_len.kv * p.seq_len.q)
+                                        : 0)
+                        + (p.mask.type != mask_type::no_mask
+                                        ? (p.seq_len.kv * p.seq_len.q)
+                                        : 0)
+                        + (5 * p.seq_len.kv * p.seq_len.q))
+                / causal_divisor;
+
+        // Just the 5 matmul flops
+        num_ops_t<> flash_flops
+                = (2.5f * 2.f * p.mb * p.heads.q * p.seq_len.kv * p.seq_len.q
+                          * p.head_group.head_size)
+                / causal_divisor;
+
+        std::call_once(header_flag, print_table_header);
+        std::cout << "BWD" << print_row(p) << "|" << qtime.count() << "|"
+                  << bandwidth(magnitude_cast(total_bytes_effective), qtime)
+                  << "/"
+                  << bandwidth(magnitude_cast(total_bytes), qtime)
+                  << "|" << compute(magnitude_cast(flash_flops), qtime)
+                  << "/" << compute(magnitude_cast(total_flops), qtime)
+                  << "|" << std::endl;
+
+        if (time_reference) {
+            auto ref_time = prim_sdpa_quant_bwd(p, t, eng, strm, t.m_query,
+                    t.m_key_quantized, scale_dt, t.m_scale, t.m_mask,
+                    t.m_value_quantized, t.m_output, t.m_diff_output,
+                    invert_scale, doubled_memory, t.m_diff_query,
+                    t.m_diff_key, t.m_diff_value, /*with_timing=*/true);
+
+            float speedup = ref_time.count() > 0
+                    ? (float)ref_time.count() / (float)qtime.count()
+                    : 0.f;
+            std::cout << "REF" << print_row(p) << "|" << ref_time.count()
+                      << "|" << speedup << "x speedup|" << std::endl;
+        }
+    }
+
 protected:
     dnnl::engine eng;
     dnnl::stream strm;
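To make the accounting above concrete, a standalone sketch that evaluates the same byte and FLOP formulas for one bwd_perf configuration below (mb = 4, heads = 4, seq = 4096, head_size = 64, f16, no mask, scale term ignored); the constants mirror the expressions in perf_bwd(), but the names are illustrative:

    // Sketch of perf_bwd()'s roofline accounting for a single configuration.
    #include <cstdio>

    int main() {
        const double mb = 4, heads = 4, sq = 4096, skv = 4096, d = 64,
                     elem = 2; // f16: 2 bytes per element
        // Five query-shaped tensors (matching the 5 * head_size * seq_q term)
        // plus four kv-shaped ones (K, V reads and dK, dV writes).
        double bytes = mb * heads
                * (5 * sq * d * elem + 4 * skv * d * elem);
        // Five matmuls at 2*d*sq*skv flops each (2.5x the forward pair),
        // plus ~5 ops per score element for softmax backward.
        double flops = mb * heads * (5 * 2.0 * d * sq * skv + 5.0 * sq * skv);
        std::printf("bytes ~ %.3f GiB, flops ~ %.3f TFLOP\n",
                bytes / (1024.0 * 1024 * 1024), flops / 1e12);
        return 0;
    }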
@@ -1804,6 +2631,8 @@ memory::format_tag no_key_transposed = memory::format_tag::abcd;
 
 using sdpa_test = sdpa_test_t;
 using sdpa_test_datatypes = sdpa_test_t;
+using sdpa_bwd_test = sdpa_test_t;
+using sdpa_bwd_test_datatypes = sdpa_test_t;
 
 // clang-format off
 
@@ -2071,6 +2900,21 @@ INSTANTIATE_TEST_SUITE_P(phi3_mini_4k_instruct,
         sdpa_dims_t{ 1, 32, 32, 2049,    1,  96,  96,  96, mdt::f16, mdt::s8, mdt::f16, mdt::s8, mdt::s8, mdt::f16, mdt::s8, mdt::f16, quantize_type::per_token_with_groups, with_key_transposed, mask_type::twoD }
     ), &print_to_string);
 
+INSTANTIATE_TEST_SUITE_P(bwd_perf,
+        sdpa_bwd_test,
+        // mb,hd_num,kv_hd_num,seq_len,qry_num,hd_size, kg_sz, vgrp_sz, dt, kdt, ksdt, kzpdt, vdt, vsdt, vzpdt, mskdt, qtype
+        testing::Values(
+                sdpa_dims_t{ 4, 4, 4, 4096, 4096,  32,  32,  32, mdt::f32, mdt::f32, mdt::undef, mdt::undef, mdt::f32, mdt::undef, mdt::undef, mdt::f32, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask },
+                sdpa_dims_t{ 4, 4, 4, 4096, 4096,  64,  64,  64, mdt::f32, mdt::f32, mdt::undef, mdt::undef, mdt::f32, mdt::undef, mdt::undef, mdt::f32, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask },
+                sdpa_dims_t{ 4, 4, 4, 4096, 4096, 128, 128, 128, mdt::f32, mdt::f32, mdt::undef, mdt::undef, mdt::f32, mdt::undef, mdt::undef, mdt::f32, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask },
+                sdpa_dims_t{ 4, 4, 4, 4096, 4096,  32,  32,  32, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask },
+                sdpa_dims_t{ 4, 4, 4, 4096, 4096,  64,  64,  64, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask },
+                sdpa_dims_t{ 4, 4, 4, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::no_mask },
+                sdpa_dims_t{ 4, 4, 4, 4096, 4096,  32,  32,  32, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::causal_tl },
+                sdpa_dims_t{ 4, 4, 4, 4096, 4096,  64,  64,  64, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::causal_tl },
+                sdpa_dims_t{ 4, 4, 4, 4096, 4096, 128, 128, 128, mdt::f16, mdt::f16, mdt::undef, mdt::undef, mdt::f16, mdt::undef, mdt::undef, mdt::f16, quantize_type::no_quantization, no_key_transposed, mask_type::causal_tl }
+    ), &print_to_string);
+
 // clang-format on
 
 GPU_TEST_P(sdpa_test, compare) {
@@ -2081,12 +2925,161 @@ GPU_TEST_P(sdpa_test_datatypes, compare) {
     compare();
 }
 
+GPU_TEST_P(sdpa_bwd_test_datatypes, compare_bwd) {
+    compare_bwd();
+}
+
 GPU_TEST_P(sdpa_test, perf) {
     perf();
 }
 
-/*
-GPU_TEST_P(sdpa_test_datatypes, perf) {
-    perf();
+GPU_TEST_P(sdpa_bwd_test, perf_bwd) {
+    /// long-running benchmark, commented out to avoid timeouts in CI
+    // const bool time_reference = true;
+    // perf_bwd(time_reference);
 }
-*/
+
+// clang-format off
+
+// backward pass: f16
+INSTANTIATE_TEST_SUITE_P(bwd_f16, sdpa_bwd_test_datatypes,
+        testing::Combine(testing::Values(1, 2), // mb
+                testing::Values(num_heads_t {1, 1}, num_heads_t {2, 2}), // heads
+                testing::Values(seq_len_size_t {64, 64},
+                        seq_len_size_t {1024, 1024}), // seq_len
+                testing::Values(head_group_size_t {32, 32, 32},
+                        head_group_size_t {64, 64, 64},
+                        head_group_size_t {128, 128, 128}), // head_size
+                testing::Values(tensor_type_t("Q", mdt::f16)), // dt
+                testing::Values(tensor_type_t("K", mdt::f16)), // kdt
+                testing::Values(tensor_type_t("V", mdt::f16)), // vdt
+                testing::Values(quantize_type::no_quantization), // qtype
+                testing::Values(dnnl::memory::format_tag::abcd,
+                        dnnl::memory::format_tag::abdc), // key_format_tag
+                testing::Values(mask_config_t {mask_type::no_mask},
+                        mask_config_t {mask_type::causal_tl},
+                        mask_config_t {mask_type::causal_br},
+                        mask_config_t {mask_type::twoD},
+                        mask_config_t {mask_type::oneD}), // mask_type
+                testing::Values(default_scale_type), // scale_type
+                testing::Values(
+                        accumulation_t {accumulation_mode::f32,
+                                accumulation_mode::f32}) // accumulation_mode
+                ),
+        &print_to_string2);
+
+// backward pass: bf16
+INSTANTIATE_TEST_SUITE_P(bwd_bf16, sdpa_bwd_test_datatypes,
+        testing::Combine(testing::Values(1, 2), // mb
+                testing::Values(num_heads_t {1, 1}, num_heads_t {2, 2}), // heads
+                testing::Values(seq_len_size_t {32, 32}, seq_len_size_t {128, 128},
+                        seq_len_size_t {384, 384}), // seq_len
+                testing::Values(head_group_size_t {32, 32, 32},
+                        head_group_size_t {64, 64, 64},
+                        head_group_size_t {128, 128, 128}), // head_size
+                testing::Values(tensor_type_t("Q", mdt::bf16)), // dt
+                testing::Values(tensor_type_t("K", mdt::bf16)), // kdt
+                testing::Values(tensor_type_t("V", mdt::bf16)), // vdt
+                testing::Values(quantize_type::no_quantization), // qtype
+                testing::Values(dnnl::memory::format_tag::abcd,
+                        dnnl::memory::format_tag::abdc), // key_format_tag
+                testing::Values(mask_config_t {mask_type::causal_tl}), // mask_type
+                testing::Values(default_scale_type), // scale_type
+                testing::Values(
+                        accumulation_t {accumulation_mode::f32,
+                                accumulation_mode::f32}) // accumulation_mode
+                ),
+        &print_to_string2);
+
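The GQA suite below exercises q-heads > kv-heads, where dK and dV must be summed over each query-head group, matching the [mb, groups, hq_group, D, K] -> [mb, groups, 1, D, K] reduction in the reference implementation above. A scalar sketch of that fold (illustrative names, not part of the patch):

    // Sketch: sum a grouped dK over its hq_group query heads for one kv head.
    #include <vector>

    std::vector<float> reduce_head_group(const std::vector<float> &dk_full,
            int hq_group, int slice_elems) { // slice_elems = D * seq_kv
        std::vector<float> dk(slice_elems, 0.f);
        for (int g = 0; g < hq_group; ++g)
            for (int e = 0; e < slice_elems; ++e)
                dk[e] += dk_full[g * slice_elems + e];
        return dk;
    }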
testing::Values(tensor_type_t("K", mdt::f16)), // kdt + testing::Values(tensor_type_t("V", mdt::f16)), // vdt + testing::Values(quantize_type::no_quantization), // qtype + testing::Values(dnnl::memory::format_tag::abcd, + dnnl::memory::format_tag::abdc), // key_format_tag + testing::Values(mask_config_t {mask_type::no_mask}, + mask_config_t {mask_type::causal_tl}), // mask_type + testing::Values(default_scale_type), // scale_type + testing::Values( + accumulation_t {accumulation_mode::f32, + accumulation_mode::f32}) // accumulation_mode + ), + &print_to_string2); + +// backward pass: non-uniform sequence lengths (q != kv) +INSTANTIATE_TEST_SUITE_P(bwd_nonuniform_seq, sdpa_bwd_test_datatypes, + testing::Combine(testing::Values(1), // mb + testing::Values(num_heads_t {2, 2}), // heads + testing::Values(seq_len_size_t {64, 513}, + seq_len_size_t {513, 64}), + testing::Values(head_group_size_t {32, 32, 32}, + head_group_size_t {64, 64, 64}), // head_size + testing::Values(tensor_type_t("Q", mdt::f16)), // dt + testing::Values(tensor_type_t("K", mdt::f16)), // kdt + testing::Values(tensor_type_t("V", mdt::f16)), // vdt + testing::Values(quantize_type::no_quantization), // qtype + testing::Values(dnnl::memory::format_tag::abcd), // key_format_tag + testing::Values(mask_config_t {mask_type::no_mask}), // mask_type + testing::Values(default_scale_type), // scale_type + testing::Values( + accumulation_t {accumulation_mode::f32, + accumulation_mode::f32}) // accumulation_mode + ), + &print_to_string2); + +// backward pass: f32 +INSTANTIATE_TEST_SUITE_P(bwd_f32, sdpa_bwd_test_datatypes, + testing::Combine(testing::Values(1, 2), // mb + testing::Values(num_heads_t {1, 1}, num_heads_t {2, 2}), // heads + testing::Values(seq_len_size_t {32, 32}, + seq_len_size_t {384, 384}, + seq_len_size_t {4096, 4096}), // seq_len + testing::Values( + head_group_size_t {32, 32, 32}, + head_group_size_t {64, 64, 64}, + head_group_size_t {128, 128, 128}), // head_size + testing::Values(tensor_type_t("Q", mdt::f32)), // dt + testing::Values(tensor_type_t("K", mdt::f32)), // kdt + testing::Values(tensor_type_t("V", mdt::f32)), // vdt + testing::Values(quantize_type::no_quantization), // qtype + testing::Values(dnnl::memory::format_tag::abcd), // key_format_tag + testing::Values(mask_config_t {mask_type::no_mask}, + mask_config_t {mask_type::causal_tl}), // mask_type + testing::Values(default_scale_type), // scale_type + testing::Values( + accumulation_t {accumulation_mode::f32, + accumulation_mode::f32}) // accumulation_mode + ), + &print_to_string2); + + +// backward pass: large batch and head counts +INSTANTIATE_TEST_SUITE_P(bwd_large_batch, sdpa_bwd_test_datatypes, + testing::Combine(testing::Values(4), // mb + testing::Values(num_heads_t {4, 4}), // heads + testing::Values(seq_len_size_t {4096, 4096}), // seq_len + testing::Values(head_group_size_t {32, 32, 32}), // head_size + testing::Values(tensor_type_t("Q", mdt::f16)), // dt + testing::Values(tensor_type_t("K", mdt::f16)), // kdt + testing::Values(tensor_type_t("V", mdt::f16)), // vdt + testing::Values(quantize_type::no_quantization), // qtype + testing::Values(dnnl::memory::format_tag::abcd), // key_format_tag + testing::Values(mask_config_t {mask_type::no_mask}), // mask_type + testing::Values(default_scale_type), // scale_type + testing::Values( + accumulation_t {accumulation_mode::f32, + accumulation_mode::f32}) // accumulation_mode + ), + &print_to_string2); + +// clang-format on