From 16ad76e94a98fed897ad0487ff6e15ecb8bf4de5 Mon Sep 17 00:00:00 2001
From: hsuan-lun
Date: Tue, 19 Aug 2025 05:23:26 +0000
Subject: [PATCH] Migrate DotProductAttention to NNX

---
 MaxText/layers/attention_op.py | 86 +++++++++++++++++++++-------------
 1 file changed, 54 insertions(+), 32 deletions(-)

diff --git a/MaxText/layers/attention_op.py b/MaxText/layers/attention_op.py
index 1e223f5ee8..185b60ed9c 100644
--- a/MaxText/layers/attention_op.py
+++ b/MaxText/layers/attention_op.py
@@ -516,6 +516,55 @@ def maybe_create_nnx(einsum, *args):
       self.AqtEinsum_2 = jnp.einsum
       self.AqtEinsum_3 = jnp.einsum

+    if self.attention_kernel == "cudnn_flash_te":
+      # These imports are only meant to work in a GPU build.
+      # pylint: disable=import-outside-toplevel
+      from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error
+
+      using_context_parallelism = self.mesh.shape["context"] > 1
+
+      if self.attention_type == AttentionType.LOCAL_SLIDING and using_context_parallelism:
+        raise AssertionError("Sliding window attention is not supported when context parallelism is enabled")
+
+      sliding_window_size = None
+
+      if self.attention_type == AttentionType.LOCAL_SLIDING or not self.config.enable_padding_causal_mask:
+        sliding_window_size = [self.sliding_window_size, 0]
+
+      if self.attention_type == AttentionType.LOCAL_SLIDING or using_context_parallelism:
+        mask_type = "causal"  # SWA and Context Parallelism only work with causal masking
+        dummy_attn_mask = None
+      else:
+        # generate attn_mask
+        mask_type = "padding_causal"  # only padding_causal mask type can take a created mask
+        dummy_attn_mask = jnp.zeros((1, 1, 1, self.max_target_length, self.max_target_length), dtype=jnp.uint8)
+
+      dpa_layer = DotProductAttention(
+          head_dim=config.head_dim,
+          num_attention_heads=self.num_query_heads,
+          num_gqa_groups=self.num_kv_heads,
+          attn_mask_type=mask_type,  # 'no_mask', 'padding', 'causal', or 'padding_causal'
+          attn_bias_type="no_bias",  # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
+          attention_dropout=self.dropout_rate,
+          dropout_rng_name="aqt",
+          dtype=self.dtype,
+          float32_logits=self.float32_logits,
+          qkv_layout="BSHD_BSHD_BSHD",  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
+          scale_factor=1.0,
+          transpose_batch_sequence=False,
+          window_size=sliding_window_size,
+          context_parallel_causal_load_balanced=config.context_parallel_load_balance,
+          context_parallel_axis="context",
+      )
+
+      dpa_layer = nnx_wrappers.ToNNX(dpa_layer, rngs=rngs)
+      dummy_query_prefill = jnp.zeros((1, self.max_target_length, self.num_query_heads, config.head_dim), dtype=self.dtype)
+      dummy_key_prefill = jnp.zeros((1, self.max_target_length, self.num_kv_heads, config.head_dim), dtype=self.dtype)
+      dummy_value_prefill = jnp.zeros((1, self.max_target_length, self.num_kv_heads, config.head_dim), dtype=self.dtype)
+
+      dpa_layer.lazy_init(dummy_query_prefill, dummy_key_prefill, dummy_value_prefill, mask=dummy_attn_mask)
+      self.dpa_layer = dpa_layer
+
   def check_attention_inputs(self, query: Array, key: Array | KVTensor, value: Array | KVTensor) -> None:
     """Check attention inputs."""

@@ -1096,10 +1145,6 @@ def cudnn_flash_attention(
     1. Stable API, supports GQA, SWA (only with causal masking)
     2. Head_dim = 256 is also supported from TE-1.12 stable release with CUDNN 12.6
     """
-    # These imports are only meant to work in a GPU build.
-    # pylint: disable=import-outside-toplevel
-    from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error
-
     _, _, _, head_dim = query.shape  # pylint: disable=unused-variable

     using_context_parallelism = self.mesh.shape["context"] > 1
@@ -1107,37 +1152,16 @@ def cudnn_flash_attention(
     if self.attention_type == AttentionType.LOCAL_SLIDING and using_context_parallelism:
       raise AssertionError("Sliding window attention is not supported when context parallelism is enabled")

-    sliding_window_size = None
-
-    if self.attention_type == AttentionType.LOCAL_SLIDING or not self.config.enable_padding_causal_mask:
-      sliding_window_size = [self.sliding_window_size, 0]
-
     if self.attention_type == AttentionType.LOCAL_SLIDING or using_context_parallelism:
-      mask_type = "causal"  # SWA and Context Parallelism only work with causal masking
+      # SWA and Context Parallelism only work with causal masking
       attn_mask = None
     else:
       # generate attn_mask
-      mask_type = "padding_causal"  # only padding_causal mask type can take a created mask
       attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)

-    dpa_layer = DotProductAttention(
-        head_dim=head_dim,
-        num_attention_heads=self.num_query_heads,
-        num_gqa_groups=self.num_kv_heads,
-        attn_mask_type=mask_type,  # 'no_mask', 'padding', 'causal', or 'padding_causal'
-        attn_bias_type="no_bias",  # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
-        attention_dropout=self.dropout_rate,
-        dropout_rng_name="aqt",
-        dtype=self.dtype,
-        float32_logits=self.float32_logits,
-        qkv_layout="BSHD_BSHD_BSHD",  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
-        scale_factor=1.0,
-        transpose_batch_sequence=False,
-        window_size=sliding_window_size,
-        context_parallel_causal_load_balanced=self.config.context_parallel_load_balance,
-        context_parallel_axis="context",
-    )
-    return dpa_layer(query, key, value, mask=attn_mask)
+    if attn_mask is not None:
+      attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1)
+    return self.dpa_layer(query, key, value, mask=attn_mask)

   def cudnn_jax_flash_attention(
       self,
@@ -1354,9 +1378,7 @@ def qk_product(
       raise NotImplementedError(self.compute_axis_order)
     return result

-  def wv_product(
-      self, attn_weights: Array, value: Array | KVTensor, model_mode: str, einsum: Callable[..., Array]
-  ) -> Array:
+  def wv_product(self, attn_weights: Array, value: Array | KVTensor, model_mode: str, einsum: Callable[..., Array]) -> Array:
    """weighted value product.

     Args:
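
Notes: the pattern this patch applies is Flax's Linen-to-NNX bridge. The Linen DotProductAttention is wrapped in ToNNX, lazy_init'ed once with dummy prefill-shaped inputs so its state is materialized at construction time rather than on the first forward call, and at call time MaxText's additive attention mask is converted into the 0/1 mask the wrapped layer expects. A minimal, self-contained sketch of the same pattern follows, assuming Flax's flax.nnx.bridge API; MaxText's nnx_wrappers.ToNNX is a local variant of that bridge, and TinyAttention below is a hypothetical stand-in for TE's layer, not its real implementation.

# Sketch of the Linen-to-NNX bridge pattern used by this patch.
# Assumptions: flax.nnx.bridge API; TinyAttention is a hypothetical
# stand-in for transformer_engine's DotProductAttention.
import jax.numpy as jnp
from flax import linen as nn
from flax import nnx
from flax.nnx import bridge

DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.float32).max)


class TinyAttention(nn.Module):
  """Hypothetical Linen attention layer with a trainable output projection."""

  @nn.compact
  def __call__(self, q, k, v, mask=None):
    scores = jnp.einsum("bqd,bkd->bqk", q, k) / jnp.sqrt(q.shape[-1])
    if mask is not None:
      # Same convention as the patch: mask value 1 marks a dropped position.
      scores = jnp.where(mask.astype(bool), -1e9, scores)
    weights = nn.softmax(scores, axis=-1)
    return nn.Dense(q.shape[-1])(jnp.einsum("bqk,bkd->bqd", weights, v))


# Wrap the Linen module so it behaves as an NNX module, then materialize
# its parameters once from dummy inputs -- mirroring the patch's move of
# layer construction out of cudnn_flash_attention and into __init__.
layer = bridge.ToNNX(TinyAttention(), rngs=nnx.Rngs(params=0))
dummy = jnp.zeros((1, 16, 8), dtype=jnp.float32)
bridge.lazy_init(layer, dummy, dummy, dummy, mask=None)

# At call time the patch converts MaxText's additive mask (0 = keep,
# DEFAULT_MASK_VALUE = large negative = drop) into a 0/1 mask:
additive = jnp.full((1, 16, 16), DEFAULT_MASK_VALUE).at[:, :, :4].set(0.0)
binary = jnp.where(additive >= DEFAULT_MASK_VALUE * 0.5, 0, 1)
out = layer(dummy, dummy, dummy, mask=binary)

Keeping lazy_init in __init__ moves the expensive layer construction out of the forward pass, which is why cudnn_flash_attention above shrinks to a mask conversion plus a call to self.dpa_layer.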