@@ -148,6 +148,7 @@ def causal_mask_mod(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor,
 
 @dataclass
 class FlexAttentionMetadata:
+    causal: bool
     num_actual_tokens: int  # Number of tokens excluding padding.
     max_query_len: int
     query_start_loc: torch.Tensor
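Side note, not part of the diff: a FlexAttention mask_mod such as the causal_mask_mod referenced in the hunk header takes (b, h, q_idx, kv_idx) index tensors and returns a boolean tensor marking allowed query/key pairs. A minimal sketch, assuming a recent PyTorch (2.5+) with torch.nn.attention.flex_attention; simple_causal_mask_mod is an illustrative stand-in, not the backend's own implementation:

import torch
from torch.nn.attention.flex_attention import create_block_mask

def simple_causal_mask_mod(b: torch.Tensor, h: torch.Tensor,
                           q_idx: torch.Tensor,
                           kv_idx: torch.Tensor) -> torch.Tensor:
    # A query token may only attend to keys at the same or earlier positions.
    return q_idx >= kv_idx

# None for B and H broadcasts the mask over batch and heads, mirroring the
# create_block_mask_compiled(...) call later in this file.
block_mask = create_block_mask(simple_causal_mask_mod, None, None,
                               128, 128, device="cpu")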
@@ -177,10 +178,9 @@ class FlexAttentionMetadata:
     num_blocks = 0
     block_mask: Optional[BlockMask] = None
     score_mod: Optional[_score_mod_signature] = None
-    mask_mod: Optional[_mask_mod_signature] = None
     logical_mask_mod: _mask_mod_signature = causal_mask_mod
 
-    def get_mask_mod(self) -> _mask_mod_signature:
+    def get_causal_mask_mod(self) -> _mask_mod_signature:
         """Creates the mask_mod function for FlexAttention.
 
         This function creates the combined mask mod function that handles:
@@ -233,14 +233,39 @@ def final_mask_mod(
 
         return final_mask_mod
 
+    def get_bidirectional_mask_mod(self) -> _mask_mod_signature:
+        """Creates the encoder mask_mod function for FlexAttention.
+
+        Since the encoder bidirectional attention doesn't run with
+        KV cache, this function creates a mask based on the
+        packed query sequences.
+        """
+        # Create a lookup mapping from query indices -> request number
+        request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)
+
+        def final_mask_mod(
+            b: torch.Tensor,
+            h: torch.Tensor,
+            q_idx: torch.Tensor,
+            kv_idx: torch.Tensor,
+        ) -> torch.Tensor:
+            return request_lookup[q_idx] == request_lookup[kv_idx]
+
+        return final_mask_mod
+
     def build_block_mask(self) -> BlockMask:
-        assert self.mask_mod is not None
+        if self.causal:
+            mask_mod = self.get_causal_mask_mod()
+            kv_len = self.total_cache_tokens
+        else:
+            mask_mod = self.get_bidirectional_mask_mod()
+            kv_len = self.num_actual_tokens
         return create_block_mask_compiled(
-            self.mask_mod,
+            mask_mod,
             None,
             None,
             self.num_actual_tokens,
-            self.total_cache_tokens,
+            kv_len,
             device=self.block_table.device,
         )
 
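Illustrative sketch, not part of the diff: the bidirectional path above builds a "document mask" over the packed query sequences. offsets_to_doc_ids below is a hypothetical stand-in for _offsets_to_doc_ids_tensor (whose implementation is not shown here); it expands cumulative sequence offsets into a per-token request id, and the mask_mod then only allows attention between tokens of the same request, in both directions. Assumes PyTorch 2.5+.

import torch
from torch.nn.attention.flex_attention import create_block_mask

def offsets_to_doc_ids(query_start_loc: torch.Tensor) -> torch.Tensor:
    # e.g. offsets [0, 3, 7] (two packed requests of lengths 3 and 4)
    #      -> per-token ids [0, 0, 0, 1, 1, 1, 1]
    lengths = query_start_loc[1:] - query_start_loc[:-1]
    return torch.repeat_interleave(
        torch.arange(lengths.numel(), device=query_start_loc.device), lengths)

# Two packed requests of lengths 48 and 80 -> 128 tokens total.
query_start_loc = torch.tensor([0, 48, 128])
doc_ids = offsets_to_doc_ids(query_start_loc)

def bidirectional_mask_mod(b, h, q_idx, kv_idx):
    # No causal constraint: only "same request" matters.
    return doc_ids[q_idx] == doc_ids[kv_idx]

num_tokens = int(query_start_loc[-1])
block_mask = create_block_mask(bidirectional_mask_mod, None, None,
                               num_tokens, num_tokens, device="cpu")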
@@ -251,7 +276,6 @@ def __post_init__(self):
         assert self.prefix_kv_lens is None, "Not implemented yet."
         assert self.suffix_kv_lens is None, "Not implemented yet."
         self.num_blocks = self.total_cache_tokens // self.block_size
-        self.mask_mod = self.get_mask_mod()
         self.block_mask = self.build_block_mask()
 
 
@@ -306,6 +330,7 @@ def build(self,
             self.device, non_blocking=True)
 
         out = FlexAttentionMetadata(
+            causal=common_attn_metadata.causal,
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
             query_start_loc=query_start_loc,
@@ -350,6 +375,12 @@ def __init__(
         self.head_size = head_size
         self.scale = float(scale)
         self.num_kv_heads = num_kv_heads
+        self.attn_type = attn_type
+
+        if attn_type not in (AttentionType.ENCODER_ONLY,
+                             AttentionType.DECODER):
+            raise NotImplementedError(
+                f"FlexAttention does not support {attn_type} attention")
 
         if alibi_slopes is not None:
             raise NotImplementedError(
@@ -425,26 +456,38 @@ def forward(
 
         num_actual_tokens = attn_metadata.num_actual_tokens
 
-        key_cache, value_cache = kv_cache.unbind(0)
-
-        torch.ops._C_cache_ops.reshape_and_cache_flash(
-            key,
-            value,
-            key_cache,
-            value_cache,
-            attn_metadata.slot_mapping,
-            self.kv_cache_dtype,
-            layer._k_scale,
-            layer._v_scale,
-        )
+        if not attn_metadata.causal:
+            assert self.attn_type == AttentionType.ENCODER_ONLY
+
+            query, key_tensor, value_tensor = map(
+                lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
+                (query, key, value),
+            )
+
+        else:
+            assert self.attn_type == AttentionType.DECODER
+            key_cache, value_cache = kv_cache.unbind(0)
+
+            torch.ops._C_cache_ops.reshape_and_cache_flash(
+                key,
+                value,
+                key_cache,
+                value_cache,
+                attn_metadata.slot_mapping,
+                self.kv_cache_dtype,
+                layer._k_scale,
+                layer._v_scale,
+            )
+
+            # View out the block_size dim
+            key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
+            value_cache = value_cache.view(-1, self.num_kv_heads,
+                                           self.head_size)
+            query, key_tensor, value_tensor = map(
+                lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
+                (query, key_cache, value_cache),
+            )
 
-        # View out the block_size dim
-        key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
-        value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
-        query, key_cache, value_cache = map(
-            lambda x: self.view_as_4d(x).permute(0, 2, 1, 3),
-            (query, key_cache, value_cache),
-        )
         query = query[:, :, :num_actual_tokens, :]
         # Doesn't work for now -> constraint violation
         # torch._dynamo.try_mark_dynamic(query, 2)
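Illustrative sketch, not part of the diff: both branches above end by reshaping the packed tensors into the (batch, heads, tokens, head_size) layout that FlexAttention expects. to_flex_layout below is a hypothetical stand-in for self.view_as_4d(x).permute(0, 2, 1, 3), assuming the packed (num_tokens, num_heads, head_size) layout; the decoder branch applies the same permute to the flattened KV cache, while the encoder branch applies it to the layer's own key/value inputs.

import torch

num_tokens, num_heads, head_size = 128, 8, 64

# Packed per-token layout: (num_tokens, num_heads, head_size).
query = torch.randn(num_tokens, num_heads, head_size)
key = torch.randn(num_tokens, num_heads, head_size)
value = torch.randn(num_tokens, num_heads, head_size)

def to_flex_layout(x: torch.Tensor) -> torch.Tensor:
    # (tokens, heads, head_size) -> (1, tokens, heads, head_size)
    #                            -> (1, heads, tokens, head_size)
    return x.unsqueeze(0).permute(0, 2, 1, 3)

q4, k4, v4 = map(to_flex_layout, (query, key, value))
assert q4.shape == (1, num_heads, num_tokens, head_size)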
@@ -465,8 +508,8 @@ def forward(
 
         out = flex_attention_compiled(
             query,
-            key_cache,
-            value_cache,
+            key_tensor,
+            value_tensor,
             attn_metadata.score_mod,
             attn_metadata.block_mask,
             self.scale,
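Illustrative sketch, not part of the diff: the K/V arguments are renamed here because they now come either from the KV cache (decoder) or straight from the layer inputs (encoder), but both feed the same underlying call. A minimal end-to-end example of torch's flex_attention with a block mask, assuming PyTorch 2.5+; the backend wraps it in torch.compile, while this sketch runs it in eager mode for simplicity.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

num_heads, num_tokens, head_size = 8, 128, 64
q = torch.randn(1, num_heads, num_tokens, head_size)
k = torch.randn(1, num_heads, num_tokens, head_size)
v = torch.randn(1, num_heads, num_tokens, head_size)

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

block_mask = create_block_mask(causal, None, None,
                               num_tokens, num_tokens, device="cpu")

# score_mod is omitted (None) here; the backend also threads through
# attn_metadata.score_mod and an explicit softmax scale.
out = flex_attention(q, k, v, block_mask=block_mask,
                     scale=head_size ** -0.5)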