Adding support for BlockedKV attention in CausalLM models #618
base: main
Conversation
Thanks @vaibverm
Force-pushed from 5997515 to 4e817c2
Hi @vbaddi,
    K_block_states = repeat_kv(K_block, module.num_key_value_groups)
    V_block_states = repeat_kv(V_block, module.num_key_value_groups)
    past_seen_tokens_start = start_index
    past_seen_tokens_end = torch.where(
If we are comparing ints, do we need torch.where? Can we use min()?
torch.min() requires both inputs to be tensors. I tried torch.min() before, but using it here leads to an ONNX export-time error.
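For context, a minimal sketch of the pattern under discussion, with illustrative names and values (start_index, block_size, and the cache-length tensor are stand-ins, not the PR's exact variables): torch.where emulates a min() between a block's end index and the cached sequence length using tensor operations only, which traces cleanly through ONNX export.

```python
import torch

# Minimal sketch, not the PR's exact code: emulate min() between a block end
# index and a cached-length tensor with torch.where, since torch.min()/
# torch.minimum() expect tensor inputs on both sides.
start_index = 1024                     # illustrative block start
block_size = 512                       # illustrative block length
past_seen_tokens = torch.tensor(1536)  # hypothetical cache-length tensor

block_end = torch.tensor(start_index + block_size)
past_seen_tokens_end = torch.where(
    past_seen_tokens < block_end, past_seen_tokens, block_end
)
print(past_seen_tokens_end)  # tensor(1536)
```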
    # Compute attention scores for the block
    attn_weights_block = torch.matmul(query, K_block_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
If we are not using the attention_mask, do we need this condition?
This is a causal model, so we use masking. I kept attention_mask rather than causal_mask_block in the condition because I was previously sharing a common method between eager and blocked-KV attention. I would suggest keeping the original condition check, for compatibility with regular eager attention, as long as the overhead is not too high.
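For reference, a minimal sketch of the per-block score computation being discussed, assuming the shapes noted in the comments; the function name block_attention_scores and the causal_mask_block argument are illustrative, not the PR's exact signature:

```python
import torch

def block_attention_scores(query, K_block_states, scaling, causal_mask_block=None):
    # query:          [batch, heads, q_len, head_dim]
    # K_block_states: [batch, heads, block_len, head_dim]
    attn_weights_block = torch.matmul(query, K_block_states.transpose(2, 3)) * scaling
    if causal_mask_block is not None:
        # Additive mask restricted to this block's causal window.
        attn_weights_block = attn_weights_block + causal_mask_block
    return attn_weights_block
```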
    repl_module = type(module)
    module.__class__ = repl_module
    module.forward = MethodType(partial(repl_module.forward, num_kv_blocks=num_kv_blocks), module)
    transformed = True  # Set to True if at least one transformation occurs
Can we add a warning if the architecture doesn't support blocked KV?
I wanted to have a broader discussion on this. I implemented blocked-KV attention for the Qwen2.5_VL model as well, and the norm there was to use an environment variable to switch between different blocking techniques. Is that the norm we want to keep across QEff? If so, we will not really need this transform anymore, although I think the PyTorch transform is a cleaner way to switch between blocking techniques and is consistent with current transform usage in QEff.
@quic-rishinr - Would you suggest we go the route of environment variables, or would you prefer PyTorch transforms like the one above to implement the blocking?
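To make the transform-based option concrete, here is a minimal sketch under assumed names (apply_blocked_kv_transform and the mapping dict are illustrative, not the PR's API); it also shows where a warning for unsupported architectures could go, as requested above:

```python
from functools import partial
from types import MethodType
import warnings

import torch.nn as nn

def apply_blocked_kv_transform(model: nn.Module, mapping: dict, num_kv_blocks: int) -> bool:
    """Rebind forward() of mapped attention modules to a blocked-KV variant."""
    transformed = False
    for module in model.modules():
        if type(module) in mapping:
            repl_module = mapping[type(module)]
            module.__class__ = repl_module
            # Bind the replacement forward with the requested number of KV blocks.
            module.forward = MethodType(
                partial(repl_module.forward, num_kv_blocks=num_kv_blocks), module
            )
            transformed = True
    if not transformed:
        warnings.warn(
            "Blocked-KV attention is not supported for this architecture; "
            "keeping regular eager attention."
        )
    return transformed
```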
… number indices Signed-off-by: Vaibhav Verma <[email protected]>
    @pytest.mark.parametrize("model_name", test_models_blockedKV)
Please add @pytest.mark.on_qaic on the test, as both tests would be using the QAIC cards.
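A small sketch of the requested marker usage, with a hypothetical model list and test name standing in for the PR's actual test (the test body is omitted):

```python
import pytest

# Hypothetical stand-in for the PR's parametrized model list.
test_models_blockedKV = ["meta-llama/Llama-3.2-1B"]

@pytest.mark.on_qaic  # requested marker: the test needs a QAIC card
@pytest.mark.parametrize("model_name", test_models_blockedKV)
def test_causal_lm_blocked_kv(model_name):
    ...  # export, compile, and run with blocked-KV attention (omitted)
```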
Signed-off-by: Vaibhav Verma <[email protected]>
…llama.py Signed-off-by: Vaibhav Verma <[email protected]>
…ling_llama.py Signed-off-by: Vaibhav Verma <[email protected]>
Objective:
This PR introduces a KV-blocking technique for CausalLM models, where the K/V cache is read and processed block by block in the attention computation. The desired number of KV blocks is defined at model initialization, in the from_pretrained call, so that the ONNX is exported with the required number of KV blocks. As a result, the following changes are introduced:
Changes:
Please review and feel free to suggest changes and tests.
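A hedged usage sketch of the flow described in the objective; the num_kv_blocks keyword name is an assumption based on the PR description, not confirmed API:

```python
from QEfficient import QEFFAutoModelForCausalLM

# Assumed usage: the number of KV blocks is fixed at initialization so the
# exported ONNX reads and processes the K/V cache in that many blocks.
model = QEFFAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    num_kv_blocks=4,  # assumed kwarg introduced by this PR
)
onnx_path = model.export()
```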