Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/common/memory_tracking.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,10 @@ enum {
key_rnn_ptrs_wei_iter,
key_rnn_ptrs_wei_projection,
key_shuffle_precompute_transpose,
key_sdpa_Di,
key_sdpa_dQ_reduction,
key_sdpa_dK_reduction,
key_sdpa_dV_reduction,
key_softmax_dst_scales,
key_softmax_reduction,
key_softmax_interim_store,
Expand Down
6 changes: 6 additions & 0 deletions src/common/primitive_hashing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,7 @@ size_t get_desc_hash(const sdpa_desc_t &desc) {
size_t seed = 0;
// Kinds
seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
// Memory descriptors
seed = hash_combine(seed, get_md_hash(desc.q_desc));
seed = hash_combine(seed, get_md_hash(desc.k_desc));
Expand All @@ -742,7 +743,12 @@ size_t get_desc_hash(const sdpa_desc_t &desc) {
seed = hash_combine(seed, desc.kq_zero_points.get_hash());
seed = hash_combine(seed, desc.vs_scales.get_hash());
seed = hash_combine(seed, desc.vs_zero_points.get_hash());
seed = hash_combine(seed, get_md_hash(desc.dS_desc));
seed = hash_combine(seed, get_md_hash(desc.dst_desc));
seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc));
seed = hash_combine(seed, get_md_hash(desc.diff_q_desc));
seed = hash_combine(seed, get_md_hash(desc.diff_k_desc));
seed = hash_combine(seed, get_md_hash(desc.diff_v_desc));
seed = hash_combine(seed, get_md_hash(desc.attn_mask_desc));
seed = hash_combine(seed, get_md_hash(desc.scale_desc));
// Scale type
Expand Down
6 changes: 6 additions & 0 deletions src/common/primitive_serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -581,14 +581,20 @@ void serialize(serialization_stream_t &sstream, const sum_desc_t &desc) {
void serialize(serialization_stream_t &sstream, const sdpa_desc_t &desc) {
// Kind
sstream.append(desc.primitive_kind);
sstream.append(desc.prop_kind);
serialize(sstream, desc.q_desc);
serialize(sstream, desc.k_desc);
serialize(sstream, desc.v_desc);
desc.kq_scales.serialize(sstream);
desc.kq_zero_points.serialize(sstream);
desc.vs_scales.serialize(sstream);
desc.vs_zero_points.serialize(sstream);
serialize(sstream, desc.dS_desc);
serialize(sstream, desc.dst_desc);
serialize(sstream, desc.diff_dst_desc);
serialize(sstream, desc.diff_q_desc);
serialize(sstream, desc.diff_k_desc);
serialize(sstream, desc.diff_v_desc);
serialize(sstream, desc.attn_mask_desc);
serialize(sstream, desc.scale_desc);
sstream.append(desc.kq_acc_dt);
Expand Down
Loading
Loading