Conversation

@vaibverm (Author):

Objective:

This PR introduces the KV blocking technique for CausalLM models, where the K/V cache is read and processed block by block in the attention computation. The number of desired KV blocks is defined at model initialization in the "from_pretrained" call, so that the ONNX graph is exported with the required number of KV blocks. As a result, the following changes are introduced:

Changes:

  1. The SoftMax is changed from a regular SoftMax to an online SoftMax, where a running maximum and cumulative denominator are tracked and updated as each block is processed, so that the result retains mathematical accuracy relative to the regular SoftMax (a sketch follows this list).
  2. Changes to the CTXGather and CTXGatherCB custom ops to read only one block's worth of data in each cache gather/read.
  3. Changes to the read_only function in QEffDynamicCache to allow reading the cache block by block rather than the full K/V cache.
  4. Generation of the attention mask per block.
  5. Changes to the eager_attention_forward implementation in the Llama model to support BlockedKV attention and the online SoftMax.
  6. Wrapping the num_kv_blocks variable inside qaic_config to keep a consistent calling style (see the usage sketch at the end of this description).
  7. A new PyTorch transform to pass the num_kv_blocks variable to the QEffLlamaAttention block.
  8. A new constant added for num_kv_blocks.
  9. Added tests covering the BlockedKV feature both switched on and off.
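
A minimal sketch of the online SoftMax in item 1 is shown below. The function name, shapes, and blocking loop are illustrative assumptions, not the PR's actual code:

import torch

def blocked_attention_online_softmax(query, key_blocks, value_blocks, scaling):
    # query: [batch, heads, q_len, head_dim]; each K/V block: [batch, heads, block_len, head_dim].
    # m is the running row maximum and l the running softmax denominator; both are
    # corrected after every block so the final result matches a regular softmax.
    out = torch.zeros_like(query)
    m = torch.full(query.shape[:-1] + (1,), float("-inf"), dtype=query.dtype, device=query.device)
    l = torch.zeros(query.shape[:-1] + (1,), dtype=query.dtype, device=query.device)
    for k_blk, v_blk in zip(key_blocks, value_blocks):
        scores = torch.matmul(query, k_blk.transpose(-2, -1)) * scaling
        m_new = torch.maximum(m, scores.max(dim=-1, keepdim=True).values)
        p = torch.exp(scores - m_new)       # block probabilities shifted by the new running max
        correction = torch.exp(m - m_new)   # rescales everything accumulated so far
        l = l * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + torch.matmul(p, v_blk)
        m = m_new
    return out / l

Running all blocks through this loop yields the same output as a regular softmax over the full score matrix, which is what item 1 means by retaining mathematical accuracy.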

Please review and feel free to suggest changes and tests.
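
For reference, here is a hedged usage sketch of item 6, passing num_kv_blocks through qaic_config in the "from_pretrained" call; the model id and exact key spelling are assumptions based on the description above:

from QEfficient import QEFFAutoModelForCausalLM

# The num_kv_blocks entry controls how many K/V blocks the exported ONNX
# attention iterates over (key name taken from this PR's description).
model = QEFFAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",          # placeholder model id
    qaic_config={"num_kv_blocks": 4},
)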

@vbaddi (Contributor) commented Nov 14, 2025:

Thanks @vaibverm
Could you please address the conflicts and run the lint/format?

@vaibverm force-pushed the main branch 3 times, most recently from 5997515 to 4e817c2 on November 14, 2025 08:05

@vaibverm (Author):

Hi @vbaddi,
I have addressed the conflicts, but some workflows need approval. Would you be able to approve those?

K_block_states = repeat_kv(K_block, module.num_key_value_groups)
V_block_states = repeat_kv(V_block, module.num_key_value_groups)
past_seen_tokens_start = start_index
past_seen_tokens_end = torch.where(

Contributor:

If we are comparing ints, do we need torch.where? Can we use min()?

@vaibverm (Author), Nov 21, 2025:

torch.min() requires both inputs to be tensors. I tried it earlier, but using torch.min() here leads to an ONNX export-time error.
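
For context, a small, purely illustrative sketch of the trade-off (the variable names are placeholders, not the PR's code): torch.where stays entirely tensor-based and exports to the ONNX Where op, whereas the two-input torch.min() needs both arguments to be tensors.

import torch

start_index = 4             # python int for the current block
block_size = 8
seq_len = torch.tensor(12)  # tensor coming from the traced graph

candidate = torch.tensor(start_index + block_size)

# torch.min(start_index + block_size, seq_len) mixes a python int with a tensor
# and is rejected; per the comment above, the min() route also failed at ONNX
# export time. torch.where keeps the clamp as pure tensor ops:
past_seen_tokens_end = torch.where(candidate < seq_len, candidate, seq_len)
print(past_seen_tokens_end)  # tensor(12)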


# Compute attention scores for the block
attn_weights_block = torch.matmul(query, K_block_states.transpose(2, 3)) * scaling
if attention_mask is not None:

Contributor:

If we are not using the attention_mask, do we need this condition?

@vaibverm (Author):

This is a causal model, so we do use masking. I kept attention_mask rather than causal_mask_block in the condition because I was originally sharing one method between eager and blockedKV attention. I would suggest keeping the original condition test, for compatibility with regular eager attention, as long as the overhead is not too high.
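
For illustration, a hedged sketch of the per-block masking being discussed, mirroring the usual Hugging Face eager-attention pattern; start_index, end_index, and the mask shape are assumptions:

# attn_weights_block: [batch, heads, q_len, block_len];
# attention_mask: full additive causal mask, e.g. [batch, 1, q_len, kv_len].
if attention_mask is not None:
    # Slice the causal mask down to the key columns covered by this block.
    causal_mask_block = attention_mask[:, :, :, start_index:end_index]
    attn_weights_block = attn_weights_block + causal_mask_block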

repl_module = type(module)
module.__class__ = repl_module
module.forward = MethodType(partial(repl_module.forward, num_kv_blocks=num_kv_blocks), module)
transformed = True # Set to True if at least one transformation occurs

Contributor:

Can we add a warning if the architecture doesn't support blocked KV?

@vaibverm (Author):

I wanted to have a broader discussion on this. I implemented blockedKV attention for the Qwen2.5_VL model as well, and the norm there was to use an environment variable to switch between blocking techniques. Is that the norm we want to keep across QEff? If so, we would not really need this transform any more, although I think the PyTorch transform is a cleaner way to switch between blocking techniques and is consistent with current transform usage in QEff.
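
To make the two options concrete, a hypothetical sketch; the environment-variable name, config key, and transform class below are invented for illustration only:

import os

# (a) Environment-variable switch, in the style used for the Qwen2.5_VL work:
use_blocked_kv = os.environ.get("QEFF_BLOCKED_KV_ATTN", "0") == "1"

# (b) Transform-based switch, driven by qaic_config at from_pretrained time:
# if num_kv_blocks := qaic_config.get("num_kv_blocks"):
#     model, transformed = BlockedKVAttentionTransform.apply(model, num_kv_blocks=num_kv_blocks)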

@vaibverm (Author):

@quic-rishinr - Would you suggest we go the route of environment variables, or would you prefer PyTorch transforms like the one above to implement the blocking?

)


@pytest.mark.parametrize("model_name", test_models_blockedKV)

Contributor:

Please add @pytest.mark.on_qaic to the test, as both tests would be using the QAIC cards.
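
A sketch of the suggested markers; the test name and body are placeholders, and only the decorators reflect this comment:

import pytest

@pytest.mark.on_qaic
@pytest.mark.parametrize("model_name", test_models_blockedKV)
def test_blocked_kv_causal_lm(model_name):
    ...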
