-
-
Notifications
You must be signed in to change notification settings - Fork 10.4k
Support encoder_only attention for FlexAttention #22273
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
51ca8e1
Add encoder-only support to FlexAttention
maxdebayser d68e60b
Fix several small problems
maxdebayser 88e388d
Merge branch 'upstream_main' into flex_encoder_attn
maxdebayser 16f18c6
use vllm_runner
maxdebayser ed55ccc
Merge branch 'upstream_main' into flex_encoder_attn
maxdebayser File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -148,6 +148,7 @@ def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, | |
|
||
@dataclass | ||
class FlexAttentionMetadata: | ||
causal: bool | ||
num_actual_tokens: int # Number of tokens excluding padding. | ||
max_query_len: int | ||
query_start_loc: torch.Tensor | ||
|
@@ -177,10 +178,9 @@ class FlexAttentionMetadata: | |
num_blocks = 0 | ||
block_mask: Optional[BlockMask] = None | ||
score_mod: Optional[_score_mod_signature] = None | ||
mask_mod: Optional[_mask_mod_signature] = None | ||
logical_mask_mod: _mask_mod_signature = causal_mask_mod | ||
|
||
def get_mask_mod(self) -> _mask_mod_signature: | ||
def get_causal_mask_mod(self) -> _mask_mod_signature: | ||
"""Creates the mask_mod function for FlexAttention. | ||
|
||
This function creates the combined mask mod function that handles: | ||
|
@@ -233,14 +233,39 @@ def final_mask_mod( | |
|
||
return final_mask_mod | ||
|
||
def get_bidirectional_mask_mod(self) -> _mask_mod_signature: | ||
"""Creates the encoder mask_mod function for FlexAttention. | ||
|
||
Since the encoder bidirectional attention doesn't run with | ||
KV cache, this function creates a mask based on the | ||
packed query sequences. | ||
""" | ||
# Create a lookup mapping from query indices -> request number | ||
request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc) | ||
|
||
def final_mask_mod( | ||
b: torch.Tensor, | ||
h: torch.Tensor, | ||
q_idx: torch.Tensor, | ||
kv_idx: torch.Tensor, | ||
) -> torch.Tensor: | ||
return request_lookup[q_idx] == request_lookup[kv_idx] | ||
|
||
return final_mask_mod | ||
|
||
def build_block_mask(self) -> BlockMask: | ||
assert self.mask_mod is not None | ||
if self.causal: | ||
mask_mod = self.get_causal_mask_mod() | ||
kv_len = self.total_cache_tokens | ||
else: | ||
mask_mod = self.get_bidirectional_mask_mod() | ||
kv_len = self.num_actual_tokens | ||
return create_block_mask_compiled( | ||
self.mask_mod, | ||
mask_mod, | ||
None, | ||
None, | ||
self.num_actual_tokens, | ||
self.total_cache_tokens, | ||
kv_len, | ||
device=self.block_table.device, | ||
) | ||
|
||
|
@@ -251,7 +276,6 @@ def __post_init__(self): | |
assert self.prefix_kv_lens is None, "Not implemented yet." | ||
assert self.suffix_kv_lens is None, "Not implemented yet." | ||
self.num_blocks = self.total_cache_tokens // self.block_size | ||
self.mask_mod = self.get_mask_mod() | ||
self.block_mask = self.build_block_mask() | ||
|
||
|
||
|
@@ -306,6 +330,7 @@ def build(self, | |
self.device, non_blocking=True) | ||
|
||
out = FlexAttentionMetadata( | ||
causal=common_attn_metadata.causal, | ||
num_actual_tokens=num_actual_tokens, | ||
max_query_len=max_query_len, | ||
query_start_loc=query_start_loc, | ||
|
@@ -350,6 +375,12 @@ def __init__( | |
self.head_size = head_size | ||
self.scale = float(scale) | ||
self.num_kv_heads = num_kv_heads | ||
self.attn_type = attn_type | ||
|
||
if attn_type not in (AttentionType.ENCODER_ONLY, | ||
AttentionType.DECODER): | ||
raise NotImplementedError( | ||
f"FlexAttention does not support {attn_type} attention") | ||
|
||
if alibi_slopes is not None: | ||
raise NotImplementedError( | ||
|
@@ -425,26 +456,38 @@ def forward( | |
|
||
num_actual_tokens = attn_metadata.num_actual_tokens | ||
|
||
key_cache, value_cache = kv_cache.unbind(0) | ||
|
||
torch.ops._C_cache_ops.reshape_and_cache_flash( | ||
key, | ||
value, | ||
key_cache, | ||
value_cache, | ||
attn_metadata.slot_mapping, | ||
self.kv_cache_dtype, | ||
layer._k_scale, | ||
layer._v_scale, | ||
) | ||
if not attn_metadata.causal: | ||
assert self.attn_type == AttentionType.ENCODER_ONLY | ||
|
||
query, key_tensor, value_tensor = map( | ||
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3), | ||
(query, key, value), | ||
) | ||
|
||
else: | ||
assert self.attn_type == AttentionType.DECODER | ||
key_cache, value_cache = kv_cache.unbind(0) | ||
|
||
torch.ops._C_cache_ops.reshape_and_cache_flash( | ||
key, | ||
value, | ||
key_cache, | ||
value_cache, | ||
attn_metadata.slot_mapping, | ||
self.kv_cache_dtype, | ||
layer._k_scale, | ||
layer._v_scale, | ||
) | ||
|
||
# View out the block_size dim | ||
key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size) | ||
value_cache = value_cache.view(-1, self.num_kv_heads, | ||
self.head_size) | ||
query, key_tensor, value_tensor = map( | ||
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3), | ||
(query, key_cache, value_cache), | ||
) | ||
maxdebayser marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# View out the block_size dim | ||
key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size) | ||
value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size) | ||
query, key_cache, value_cache = map( | ||
lambda x: self.view_as_4d(x).permute(0, 2, 1, 3), | ||
(query, key_cache, value_cache), | ||
) | ||
query = query[:, :, :num_actual_tokens, :] | ||
# Doesn't work for now -> constraint violation | ||
# torch._dynamo.try_mark_dynamic(query, 2) | ||
|
@@ -465,8 +508,8 @@ def forward( | |
|
||
out = flex_attention_compiled( | ||
query, | ||
key_cache, | ||
value_cache, | ||
key_tensor, | ||
value_tensor, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've renamed the variables here because in one case they are key and value and in the other case key_cache and key_value. |
||
attn_metadata.score_mod, | ||
attn_metadata.block_mask, | ||
self.scale, | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've reduced the max tokens a bit because as the sequence length growth the chance of divergence increases. On the A100 where I'm testing this, I get the following output on the main branch: