23 commits
cbe271f  common: sdpa: adds backwards sdpa primitives (syurkevi, Feb 24, 2026)
fa30129  common: sdpa: prepares internal primitive for bwd testing (syurkevi, Feb 24, 2026)
c622268  gtests: internals: adds sdpa training tests (syurkevi, Feb 24, 2026)
8b5c39b  xe: sdpa: prepares config structs for bwd pass (syurkevi, Feb 24, 2026)
88ed3bc  gpu: gemm: jit: allow LLR gemms (syurkevi, Mar 7, 2026)
868dd36  gpu: gemm: jit: enables ukernel with multiple kernel sources (syurkevi, Mar 7, 2026)
9170719  xe: sdpa: splits fwd/bwd gpu training primitives (syurkevi, Feb 25, 2026)
52114cc  xe: sdpa: split gemm setup for backwards pass (syurkevi, Feb 25, 2026)
05eee81  xe: sdpa: calculates forward logsumexp to ws (syurkevi, Feb 25, 2026)
1fbb9a9  xe: sdpa: updates tile_ops packed load/stores (syurkevi, Feb 25, 2026)
9299438  xe: sdpa: adds bwd kernel implementation (syurkevi, Feb 25, 2026)
2fa06da  xe: sdpa: adds create_sdpa_pd for backwards pass (syurkevi, Feb 26, 2026)
124cf6c  xe: sdpa: rename fwd/bwd primitives (syurkevi, Mar 6, 2026)
7b2daeb  gtests: internals: separate sdpa from internals (syurkevi, Mar 7, 2026)
8d8ab57  common: sdpa: move prop_kind param before attrs (syurkevi, Mar 9, 2026)
55684e0  xe: sdpa: enable transpose_k for training (syurkevi, Mar 11, 2026)
59dd346  common: sdpa: refactors pd accessors, misc cleanup (syurkevi, Mar 12, 2026)
65748a1  graph: backend: dnnl: move softmax decomposition to a transform pass (ElaineBao, Feb 26, 2026)
bd12f2d  graph: backend: dnnl: enable sdpa microkernel for training fwd (ElaineBao, Feb 26, 2026)
218f0e9  graph: backend: dnnl: enable sdpa microkernel for training bwd (ElaineBao, Mar 6, 2026)
4798b55  tests, examples: update sdpa training bwd cases (ElaineBao, Mar 13, 2026)
8948d9e  tests: benchdnn: graph: remove legacy sdpa bwd cases (ElaineBao, Mar 12, 2026)
5638f5f  graph: backend: dnnl: fix code format (ElaineBao, Mar 12, 2026)
37 changes: 30 additions & 7 deletions examples/graph/gqa_training.cpp
@@ -317,13 +317,33 @@ bool bench_gqa_backward(engine::kind ekind, logical_tensor::data_type dt,
     bmm_do_v.add_inputs({doutput, value});
     bmm_do_v.add_outputs({dprobs});
 
-    // compute dmasked_score = dsoftmax(dprobs)
+    // compute dmasked_score = P * (dprobs - ReduceSum(O * dO))
+    // decomposed softmax backward: dS = P * (dP - rowsum(O * dO))
+    auto o_do_out
+            = logical_tensor(id++, dt_inter, output_sz, layout_type::strided);
+    auto o_do_mul = op(id++, op::kind::Multiply, "mul_o_do");
+    o_do_mul.add_inputs({output, doutput});
+    o_do_mul.add_outputs({o_do_out});
+
+    auto correction_out
+            = logical_tensor(id++, dt_inter, stats_sz, layout_type::strided);
+    auto correction = op(id++, op::kind::ReduceSum, "reducesum_correction");
+    correction.set_attr<std::vector<int64_t>>(op::attr::axes, {4});
+    correction.set_attr<bool>(op::attr::keep_dims, true);
+    correction.add_inputs({o_do_out});
+    correction.add_outputs({correction_out});
+
+    auto dp_corrected_out
+            = logical_tensor(id++, dt_inter, score_sz, layout_type::strided);
+    auto dp_corrected_op = op(id++, op::kind::Subtract, "sub_dp_corrected");
+    dp_corrected_op.add_inputs({dprobs, correction_out});
+    dp_corrected_op.add_outputs({dp_corrected_out});
+
     auto dmasked_score
             = logical_tensor(id++, dt_inter, score_sz, layout_type::strided);
-    auto softmax_grad = op(id++, op::kind::SoftMaxBackward, "softmax_bwd");
-    softmax_grad.set_attr<int64_t>(op::attr::axis, -1);
-    softmax_grad.add_inputs({dprobs, probs});
-    softmax_grad.add_outputs({dmasked_score});
+    auto softmax_bwd_mul = op(id++, op::kind::Multiply, "mul_softmax_bwd");
+    softmax_bwd_mul.add_inputs({probs, dp_corrected_out});
+    softmax_bwd_mul.add_outputs({dmasked_score});
 
     // compute dscored_score = dmasked_score / scale
     auto dscaled_score
@@ -372,10 +392,13 @@
     gqa_bwd.add_op(exp);
     gqa_bwd.add_op(bmm_p_do);
     gqa_bwd.add_op(bmm_do_v);
-    gqa_bwd.add_op(softmax_grad);
+    gqa_bwd.add_op(o_do_mul);
+    gqa_bwd.add_op(correction);
+    gqa_bwd.add_op(dp_corrected_op);
+    gqa_bwd.add_op(softmax_bwd_mul);
     gqa_bwd.add_op(scale_div2);
-    gqa_bwd.add_op(bmm_dscaled_score_k);
     gqa_bwd.add_op(bmm_dscaled_score_q);
+    gqa_bwd.add_op(bmm_dscaled_score_k);
     gqa_bwd.add_op(reduce_dv);
     gqa_bwd.add_op(reduce_dk);
     if (dt != dt_inter) {
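For reference, the identity behind this rewrite (standard softmax/attention calculus, stated in the example's names; not part of the diff): with P = softmax(S) applied row-wise, O = P V, and dP = dO V^T, the softmax backward is

    dS_{ij} = P_{ij} \Big( dP_{ij} - \sum_k dP_{ik} P_{ik} \Big)

and the row-wise correction term can be computed from O and dO instead of P and dP:

    \sum_k dP_{ik} P_{ik} = \sum_k (dO\,V^{\top})_{ik} P_{ik}
                          = \sum_m dO_{im} (P\,V)_{im}
                          = \sum_m dO_{im} O_{im}

This is the rowsum(O * dO) that the ReduceSum op above produces (one scalar per row, hence stats_sz), the same trick FlashAttention-style backward kernels use, and it matches the key_sdpa_Di scratchpad key added in the next file.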
4 changes: 4 additions & 0 deletions src/common/memory_tracking.hpp
@@ -325,6 +325,10 @@ enum {
     key_rnn_ptrs_wei_iter,
     key_rnn_ptrs_wei_projection,
     key_shuffle_precompute_transpose,
+    key_sdpa_Di,
+    key_sdpa_dQ_reduction,
+    key_sdpa_dK_reduction,
+    key_sdpa_dV_reduction,
     key_softmax_dst_scales,
     key_softmax_reduction,
     key_softmax_interim_store,
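These keys plug into oneDNN's scratchpad registry: an implementation books sizes against a key at pd-creation time and retrieves the storage at execution time. A rough sketch of how a backward SDPA pd might book the new keys (the key names come from this diff; the booking site, the sizes, and the exact book() arguments are illustrative assumptions, not code from this PR):

    // Sketch only: booking backward-SDPA scratchpad (sizes are assumptions).
    auto scratchpad = scratchpad_registry().registrar();
    // Di = rowsum(O * dO): one f32 per query row, per head, per batch.
    scratchpad.book(memory_tracking::names::key_sdpa_Di,
            batch * heads * seq_len, sizeof(float));
    // f32 accumulator reduced into dQ when the kernel splits the key dim.
    scratchpad.book(memory_tracking::names::key_sdpa_dQ_reduction,
            dq_accum_elems, sizeof(float));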
6 changes: 6 additions & 0 deletions src/common/primitive_hashing.cpp
@@ -735,6 +735,7 @@ size_t get_desc_hash(const sdpa_desc_t &desc) {
     size_t seed = 0;
     // Kinds
     seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
+    seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
     // Memory descriptors
     seed = hash_combine(seed, get_md_hash(desc.q_desc));
     seed = hash_combine(seed, get_md_hash(desc.k_desc));
@@ -743,7 +744,12 @@
     seed = hash_combine(seed, desc.kq_zero_points.get_hash());
     seed = hash_combine(seed, desc.vs_scales.get_hash());
     seed = hash_combine(seed, desc.vs_zero_points.get_hash());
+    seed = hash_combine(seed, get_md_hash(desc.dS_desc));
     seed = hash_combine(seed, get_md_hash(desc.dst_desc));
+    seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc));
+    seed = hash_combine(seed, get_md_hash(desc.diff_q_desc));
+    seed = hash_combine(seed, get_md_hash(desc.diff_k_desc));
+    seed = hash_combine(seed, get_md_hash(desc.diff_v_desc));
     seed = hash_combine(seed, get_md_hash(desc.attn_mask_desc));
     seed = hash_combine(seed, get_md_hash(desc.scale_desc));
     // Scale type
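get_desc_hash folds each descriptor field into a single seed through hash_combine. For readers outside the codebase, a minimal boost-style stand-in with the same shape (an assumption about the helper's behavior, not code copied from oneDNN):

    #include <cstddef>
    #include <functional>

    // Mix the hash of v into an accumulated seed (boost::hash_combine recipe).
    template <typename T>
    size_t hash_combine(size_t seed, const T &v) {
        return seed ^ (std::hash<T> {}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
    }

Folding prop_kind into the seed is what keeps a forward and a backward descriptor with otherwise identical memory descriptors from landing on the same primitive-cache entry.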
6 changes: 6 additions & 0 deletions src/common/primitive_serialization.cpp
@@ -581,14 +581,20 @@ void serialize(serialization_stream_t &sstream, const sum_desc_t &desc) {
 void serialize(serialization_stream_t &sstream, const sdpa_desc_t &desc) {
     // Kind
     sstream.append(desc.primitive_kind);
+    sstream.append(desc.prop_kind);
     serialize(sstream, desc.q_desc);
     serialize(sstream, desc.k_desc);
     serialize(sstream, desc.v_desc);
     desc.kq_scales.serialize(sstream);
     desc.kq_zero_points.serialize(sstream);
     desc.vs_scales.serialize(sstream);
     desc.vs_zero_points.serialize(sstream);
+    serialize(sstream, desc.dS_desc);
     serialize(sstream, desc.dst_desc);
+    serialize(sstream, desc.diff_dst_desc);
+    serialize(sstream, desc.diff_q_desc);
+    serialize(sstream, desc.diff_k_desc);
+    serialize(sstream, desc.diff_v_desc);
     serialize(sstream, desc.attn_mask_desc);
     serialize(sstream, desc.scale_desc);
     sstream.append(desc.kq_acc_dt);
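Hashing and serialization together form the primitive-cache key, so prop_kind must flow into both. A hypothetical check of the invariant this pair of changes establishes (make_sdpa_desc is an illustrative helper, not a real API):

    // Two descriptors identical except for prop_kind must not collide:
    sdpa_desc_t fwd = make_sdpa_desc(prop_kind::forward_inference);
    sdpa_desc_t bwd = make_sdpa_desc(prop_kind::backward);
    // Hashes should differ (up to generic hash collisions), and the
    // serialized byte streams now differ at the prop_kind field.
    assert(get_desc_hash(fwd) != get_desc_hash(bwd));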