From c06496eaf22b7865950fff0a4895d99dba3145cd Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 14 Oct 2025 10:01:11 -0700 Subject: [PATCH 1/5] Allow present_key to be empty when past_key is provided in Attention Signed-off-by: Justin Chu --- onnxruntime/core/providers/cpu/llm/attention.cc | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/cpu/llm/attention.cc b/onnxruntime/core/providers/cpu/llm/attention.cc index 7ac19bf67fb8a..fc7acb20dc0ff 100644 --- a/onnxruntime/core/providers/cpu/llm/attention.cc +++ b/onnxruntime/core/providers/cpu/llm/attention.cc @@ -199,13 +199,10 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs, T* output_qk, // Q*K output ThreadPool* tp, AllocatorPtr allocator) const { - // The case past_key != nullptr and present_key == nullptr is not supported. - // We use the fact present_key is requested to avoid any extra allocation. - // However, if present_key is not requested, we should avoid allocated more memory than needed but that mean - // allocating one buffer per thread. That's why the implementation is not done. - // The user should define a model with a present_key even if not used if past_key is not null. 
- ORT_ENFORCE((past_key == nullptr) == (present_key == nullptr), - "The implementation only supports past_key and present_key both null or both not null."); + if (present_key != nullptr) { + ORT_ENFORCE(past_key != nullptr, "past_key must be provided when present_key is requested."); + } + const size_t past_chunk_length = static_cast<size_t>(parameters.past_sequence_length) * parameters.head_size; // P x H const size_t q_input_chunk_length = static_cast<size_t>(parameters.q_sequence_length) * parameters.head_size; // S x H const size_t k_input_chunk_length = static_cast<size_t>(parameters.kv_sequence_length) * parameters.head_size; // L x H From f6b7bffd95f77cb0c2e67f8b5503a8a33ec70df7 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 14 Oct 2025 11:32:34 -0700 Subject: [PATCH 2/5] Remove check Signed-off-by: Justin Chu --- onnxruntime/core/providers/cpu/llm/attention.cc | 4 ---- 1 file changed, 4 deletions(-) diff --git a/onnxruntime/core/providers/cpu/llm/attention.cc b/onnxruntime/core/providers/cpu/llm/attention.cc index fc7acb20dc0ff..aa2d30cfff8f2 100644 --- a/onnxruntime/core/providers/cpu/llm/attention.cc +++ b/onnxruntime/core/providers/cpu/llm/attention.cc @@ -199,10 +199,6 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs, T* output_qk, // Q*K output ThreadPool* tp, AllocatorPtr allocator) const { - if (present_key != nullptr) { - ORT_ENFORCE(past_key != nullptr, "past_key must be provided when present_key is requested."); - } - const size_t past_chunk_length = static_cast<size_t>(parameters.past_sequence_length) * parameters.head_size; // P x H const size_t q_input_chunk_length = static_cast<size_t>(parameters.q_sequence_length) * parameters.head_size; // S x H const size_t k_input_chunk_length = static_cast<size_t>(parameters.kv_sequence_length) * parameters.head_size; // L x H From 348887afe44077cd25a4b91af3e81c593e9dd836 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 14 Oct 2025 11:38:59 -0700 Subject: [PATCH 3/5] Updated Signed-off-by: Justin Chu ---
onnxruntime/core/providers/cpu/llm/attention.cc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/onnxruntime/core/providers/cpu/llm/attention.cc b/onnxruntime/core/providers/cpu/llm/attention.cc index aa2d30cfff8f2..e3ddc6461a301 100644 --- a/onnxruntime/core/providers/cpu/llm/attention.cc +++ b/onnxruntime/core/providers/cpu/llm/attention.cc @@ -199,6 +199,13 @@ void AttentionBase::ComputeAttentionProbs(T* attention_probs, T* output_qk, // Q*K output ThreadPool* tp, AllocatorPtr allocator) const { + // The case past_key != nullptr and present_key == nullptr is not supported. + // We use the fact present_key is requested to avoid any extra allocation. + // However, if present_key is not requested, we should avoid allocating more memory than needed, but that means + // allocating one buffer per thread. That's why the implementation is not done. + // The user should define a model with a present_key even if not used if past_key is not null. + ORT_ENFORCE(!((past_key != nullptr) && (present_key == nullptr)), + "The implementation does not support past_key provided and present_key being null."); const size_t past_chunk_length = static_cast<size_t>(parameters.past_sequence_length) * parameters.head_size; // P x H const size_t q_input_chunk_length = static_cast<size_t>(parameters.q_sequence_length) * parameters.head_size; // S x H const size_t k_input_chunk_length = static_cast<size_t>(parameters.kv_sequence_length) * parameters.head_size; // L x H From 9df401e1d9c91b1db4b94a5b3921ad2450929991 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 14 Oct 2025 13:21:43 -0700 Subject: [PATCH 4/5] try this Signed-off-by: Justin Chu --- onnxruntime/core/providers/cpu/llm/attention.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cpu/llm/attention.cc b/onnxruntime/core/providers/cpu/llm/attention.cc index e3ddc6461a301..687dd5546865b 100644 --- a/onnxruntime/core/providers/cpu/llm/attention.cc +++ b/onnxruntime/core/providers/cpu/llm/attention.cc
@@ -529,8 +529,8 @@ void AttentionBase::ComputeVxAttentionScore(T* output, // bu T* present_value, // present value only (if not using present state) bool transpose_output, // whether to transpose the output (0, 2, 1, 3) ThreadPool* tp) const { - ORT_ENFORCE((past_value == nullptr) == (present_value == nullptr), - "The implementation only supports past_value and present_value both null or both not null."); + ORT_ENFORCE(!((past_key != nullptr) && (present_key == nullptr)), + "The implementation does not support past_key provided and present_key being null."); const ptrdiff_t past_chunk_length = SafeInt<ptrdiff_t>(past_sequence_length) * v_head_size; // P x H_v const ptrdiff_t v_input_chunk_length = SafeInt<ptrdiff_t>(kv_sequence_length) * v_head_size; // L x H_v const ptrdiff_t present_chunk_length = past_chunk_length + v_input_chunk_length; // T x H_v From f03d6d111b85c04721b98f170f0edf0439ca93fd Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Tue, 14 Oct 2025 13:29:31 -0700 Subject: [PATCH 5/5] past_value Signed-off-by: Justin Chu --- onnxruntime/core/providers/cpu/llm/attention.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cpu/llm/attention.cc b/onnxruntime/core/providers/cpu/llm/attention.cc index 687dd5546865b..4238624c9e48d 100644 --- a/onnxruntime/core/providers/cpu/llm/attention.cc +++ b/onnxruntime/core/providers/cpu/llm/attention.cc @@ -529,8 +529,8 @@ void AttentionBase::ComputeVxAttentionScore(T* output, // bu T* present_value, // present value only (if not using present state) bool transpose_output, // whether to transpose the output (0, 2, 1, 3) ThreadPool* tp) const { - ORT_ENFORCE(!((past_key != nullptr) && (present_key == nullptr)), - "The implementation does not support past_key provided and present_key being null."); + ORT_ENFORCE(!((past_value != nullptr) && (present_value == nullptr)), + "The implementation does not support past_value provided and present_value being null."); const ptrdiff_t past_chunk_length =
SafeInt<ptrdiff_t>(past_sequence_length) * v_head_size; // P x H_v const ptrdiff_t v_input_chunk_length = SafeInt<ptrdiff_t>(kv_sequence_length) * v_head_size; // L x H_v const ptrdiff_t present_chunk_length = past_chunk_length + v_input_chunk_length; // T x H_v