@@ -516,6 +516,55 @@ def maybe_create_nnx(einsum, *args):
    self.AqtEinsum_2 = jnp.einsum
    self.AqtEinsum_3 = jnp.einsum

+    if self.attention_kernel == "cudnn_flash_te":
+      # These imports are only meant to work in a GPU build.
+      # pylint: disable=import-outside-toplevel
+      from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error
+
+      using_context_parallelism = self.mesh.shape["context"] > 1
+
+      if self.attention_type == AttentionType.LOCAL_SLIDING and using_context_parallelism:
+        raise AssertionError("Sliding window attention is not supported when context parallelism is enabled")
+
+      sliding_window_size = None
+
+      if self.attention_type == AttentionType.LOCAL_SLIDING or not self.config.enable_padding_causal_mask:
+        sliding_window_size = [self.sliding_window_size, 0]
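+        # TE's window_size is (left, right); 0 on the right keeps attention causal within the window.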
+
+      if self.attention_type == AttentionType.LOCAL_SLIDING or using_context_parallelism:
+        mask_type = "causal"  # SWA and Context Parallelism only work with causal masking
+        dummy_attn_mask = None
+      else:
+        # generate attn_mask
+        mask_type = "padding_causal"  # only padding_causal mask type can take a created mask
+        dummy_attn_mask = jnp.zeros((1, 1, 1, self.max_target_length, self.max_target_length), dtype=jnp.uint8)
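+        # all-zeros means "nothing masked"; this dummy only has to match the shape/dtype expected by lazy_init below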
+
+      dpa_layer = DotProductAttention(
+          head_dim=config.head_dim,
+          num_attention_heads=self.num_query_heads,
+          num_gqa_groups=self.num_kv_heads,
+          attn_mask_type=mask_type,  # 'no_mask', 'padding', 'causal', or 'padding_causal'
+          attn_bias_type="no_bias",  # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
+          attention_dropout=self.dropout_rate,
+          dropout_rng_name="aqt",
+          dtype=self.dtype,
+          float32_logits=self.float32_logits,
+          qkv_layout="BSHD_BSHD_BSHD",  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
+          scale_factor=1.0,
+          transpose_batch_sequence=False,
+          window_size=sliding_window_size,
+          context_parallel_causal_load_balanced=config.context_parallel_load_balance,
+          context_parallel_axis="context",
+      )
+
+      dpa_layer = nnx_wrappers.ToNNX(dpa_layer, rngs=rngs)
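+      # dummy prefill-shaped Q/K/V, used only to give lazy_init concrete shapes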
+      dummy_query_prefill = jnp.zeros((1, self.max_target_length, self.num_query_heads, config.head_dim), dtype=self.dtype)
+      dummy_key_prefill = jnp.zeros((1, self.max_target_length, self.num_kv_heads, config.head_dim), dtype=self.dtype)
+      dummy_value_prefill = jnp.zeros((1, self.max_target_length, self.num_kv_heads, config.head_dim), dtype=self.dtype)
+
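+      # lazy_init builds the wrapped module's variables from the dummy shapes so self.dpa_layer can be reused at call time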
+      dpa_layer.lazy_init(dummy_query_prefill, dummy_key_prefill, dummy_value_prefill, mask=dummy_attn_mask)
+      self.dpa_layer = dpa_layer
+
  def check_attention_inputs(self, query: Array, key: Array | KVTensor, value: Array | KVTensor) -> None:
    """Check attention inputs."""
@@ -1096,48 +1145,23 @@ def cudnn_flash_attention(
    1. Stable API, supports GQA, SWA (only with causal masking)
    2. Head_dim = 256 is also supported from TE-1.12 stable release with CUDNN 12.6
    """
-    # These imports are only meant to work in a GPU build.
-    # pylint: disable=import-outside-toplevel
-    from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error
-
    _, _, _, head_dim = query.shape  # pylint: disable=unused-variable

    using_context_parallelism = self.mesh.shape["context"] > 1

    if self.attention_type == AttentionType.LOCAL_SLIDING and using_context_parallelism:
      raise AssertionError("Sliding window attention is not supported when context parallelism is enabled")

-    sliding_window_size = None
-
-    if self.attention_type == AttentionType.LOCAL_SLIDING or not self.config.enable_padding_causal_mask:
-      sliding_window_size = [self.sliding_window_size, 0]
-
    if self.attention_type == AttentionType.LOCAL_SLIDING or using_context_parallelism:
-      mask_type = "causal"  # SWA and Context Parallelism only work with causal masking
+      # SWA and Context Parallelism only work with causal masking
      attn_mask = None
    else:
      # generate attn_mask
-      mask_type = "padding_causal"  # only padding_causal mask type can take a created mask
      attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)

-    dpa_layer = DotProductAttention(
-        head_dim=head_dim,
-        num_attention_heads=self.num_query_heads,
-        num_gqa_groups=self.num_kv_heads,
-        attn_mask_type=mask_type,  # 'no_mask', 'padding', 'causal', or 'padding_causal'
-        attn_bias_type="no_bias",  # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
-        attention_dropout=self.dropout_rate,
-        dropout_rng_name="aqt",
-        dtype=self.dtype,
-        float32_logits=self.float32_logits,
-        qkv_layout="BSHD_BSHD_BSHD",  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
-        scale_factor=1.0,
-        transpose_batch_sequence=False,
-        window_size=sliding_window_size,
-        context_parallel_causal_load_balanced=self.config.context_parallel_load_balance,
-        context_parallel_axis="context",
-    )
-    return dpa_layer(query, key, value, mask=attn_mask)
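+    # the generated mask is additive (0 = keep, DEFAULT_MASK_VALUE = masked); convert to the 0/1 form (1 = masked) the TE layer consumes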
+    if attn_mask is not None:
+      attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1)
+    return self.dpa_layer(query, key, value, mask=attn_mask)

  def cudnn_jax_flash_attention(
      self,
@@ -1354,9 +1378,7 @@ def qk_product(
      raise NotImplementedError(self.compute_axis_order)
    return result

-  def wv_product(
-      self, attn_weights: Array, value: Array | KVTensor, model_mode: str, einsum: Callable[..., Array]
-  ) -> Array:
+  def wv_product(self, attn_weights: Array, value: Array | KVTensor, model_mode: str, einsum: Callable[..., Array]) -> Array:
    """weighted value product.

    Args: