
Commit 16ad76e

Migrate DotProductAttention to NNX
1 parent aada945 commit 16ad76e
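
This commit builds Transformer Engine's `DotProductAttention` once at setup time, bridges it into Flax NNX via `nnx_wrappers.ToNNX`, initializes it with `lazy_init` on dummy prefill-shaped inputs, and stores it as `self.dpa_layer`; `cudnn_flash_attention` then simply calls the stored layer. The sketch below illustrates that Linen-to-NNX bridging pattern using the upstream `flax.nnx.bridge` API, under the assumption that MaxText's `nnx_wrappers.ToNNX` is a thin wrapper over the same mechanism; `TinyLinenAttention` and the shapes are made-up stand-ins, not code from this repository.

```python
# Minimal sketch of the Linen -> NNX bridging pattern this commit relies on.
# Assumes flax.nnx.bridge.ToNNX / lazy_init; nnx_wrappers.ToNNX in MaxText is
# assumed to behave the same way. TinyLinenAttention is a hypothetical stand-in
# for transformer_engine's DotProductAttention.
import jax.numpy as jnp
from flax import linen as nn
from flax import nnx


class TinyLinenAttention(nn.Module):  # stand-in for TE's DotProductAttention
  features: int = 8

  @nn.compact
  def __call__(self, q, k, v, mask=None):
    proj = nn.Dense(self.features)  # parameterized op so there is state to initialize
    return proj(q) + k + v


rngs = nnx.Rngs(params=0)
layer = nnx.bridge.ToNNX(TinyLinenAttention(), rngs=rngs)

# NNX cannot see the Linen parameters until the module runs once, so we call
# lazy_init with concretely shaped dummy inputs, mirroring the dummy prefill
# query/key/value tensors in the diff.
dummy = jnp.zeros((1, 16, 8), dtype=jnp.bfloat16)
layer.lazy_init(dummy, dummy, dummy, mask=None)

# After lazy_init the wrapped layer is a regular NNX module and can be called directly.
out = layer(dummy, dummy, dummy, mask=None)
print(out.shape)  # (1, 16, 8)
```

The dummy query/key/value tensors in the diff serve the same purpose as `dummy` here: `lazy_init` needs inputs of the real shapes to materialize the Linen variables inside the NNX wrapper.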

File tree

1 file changed: +54 -32 lines changed


MaxText/layers/attention_op.py

Lines changed: 54 additions & 32 deletions
@@ -516,6 +516,55 @@ def maybe_create_nnx(einsum, *args):
     self.AqtEinsum_2 = jnp.einsum
     self.AqtEinsum_3 = jnp.einsum
 
+    if self.attention_kernel == "cudnn_flash_te":
+      # These imports are only meant to work in a GPU build.
+      # pylint: disable=import-outside-toplevel
+      from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error
+
+      using_context_parallelism = self.mesh.shape["context"] > 1
+
+      if self.attention_type == AttentionType.LOCAL_SLIDING and using_context_parallelism:
+        raise AssertionError("Sliding window attention is not supported when context parallelism is enabled")
+
+      sliding_window_size = None
+
+      if self.attention_type == AttentionType.LOCAL_SLIDING or not self.config.enable_padding_causal_mask:
+        sliding_window_size = [self.sliding_window_size, 0]
+
+      if self.attention_type == AttentionType.LOCAL_SLIDING or using_context_parallelism:
+        mask_type = "causal"  # SWA and Context Parallelism only work with causal masking
+        dummy_attn_mask = None
+      else:
+        # generate attn_mask
+        mask_type = "padding_causal"  # only padding_causal mask type can take a created mask
+        dummy_attn_mask = jnp.zeros((1, 1, 1, self.max_target_length, self.max_target_length), dtype=jnp.uint8)
+
+      dpa_layer = DotProductAttention(
+          head_dim=config.head_dim,
+          num_attention_heads=self.num_query_heads,
+          num_gqa_groups=self.num_kv_heads,
+          attn_mask_type=mask_type,  # 'no_mask', 'padding', 'causal', or 'padding_causal'
+          attn_bias_type="no_bias",  # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
+          attention_dropout=self.dropout_rate,
+          dropout_rng_name="aqt",
+          dtype=self.dtype,
+          float32_logits=self.float32_logits,
+          qkv_layout="BSHD_BSHD_BSHD",  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
+          scale_factor=1.0,
+          transpose_batch_sequence=False,
+          window_size=sliding_window_size,
+          context_parallel_causal_load_balanced=config.context_parallel_load_balance,
+          context_parallel_axis="context",
+      )
+
+      dpa_layer = nnx_wrappers.ToNNX(dpa_layer, rngs=rngs)
+      dummy_query_prefill = jnp.zeros((1, self.max_target_length, self.num_query_heads, config.head_dim), dtype=self.dtype)
+      dummy_key_prefill = jnp.zeros((1, self.max_target_length, self.num_kv_heads, config.head_dim), dtype=self.dtype)
+      dummy_value_prefill = jnp.zeros((1, self.max_target_length, self.num_kv_heads, config.head_dim), dtype=self.dtype)
+
+      dpa_layer.lazy_init(dummy_query_prefill, dummy_key_prefill, dummy_value_prefill, mask=dummy_attn_mask)
+      self.dpa_layer = dpa_layer
+
   def check_attention_inputs(self, query: Array, key: Array | KVTensor, value: Array | KVTensor) -> None:
     """Check attention inputs."""
 
@@ -1096,48 +1145,23 @@ def cudnn_flash_attention(
     1. Stable API, supports GQA, SWA (only with causal masking)
     2. Head_dim = 256 is also supported from TE-1.12 stable release with CUDNN 12.6
     """
-    # These imports are only meant to work in a GPU build.
-    # pylint: disable=import-outside-toplevel
-    from transformer_engine.jax.flax.transformer import DotProductAttention  # pytype: disable=import-error
-
     _, _, _, head_dim = query.shape  # pylint: disable=unused-variable
 
     using_context_parallelism = self.mesh.shape["context"] > 1
 
     if self.attention_type == AttentionType.LOCAL_SLIDING and using_context_parallelism:
       raise AssertionError("Sliding window attention is not supported when context parallelism is enabled")
 
-    sliding_window_size = None
-
-    if self.attention_type == AttentionType.LOCAL_SLIDING or not self.config.enable_padding_causal_mask:
-      sliding_window_size = [self.sliding_window_size, 0]
-
     if self.attention_type == AttentionType.LOCAL_SLIDING or using_context_parallelism:
-      mask_type = "causal"  # SWA and Context Parallelism only work with causal masking
+      # SWA and Context Parallelism only work with causal masking
       attn_mask = None
     else:
       # generate attn_mask
-      mask_type = "padding_causal"  # only padding_causal mask type can take a created mask
       attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
 
-    dpa_layer = DotProductAttention(
-        head_dim=head_dim,
-        num_attention_heads=self.num_query_heads,
-        num_gqa_groups=self.num_kv_heads,
-        attn_mask_type=mask_type,  # 'no_mask', 'padding', 'causal', or 'padding_causal'
-        attn_bias_type="no_bias",  # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
-        attention_dropout=self.dropout_rate,
-        dropout_rng_name="aqt",
-        dtype=self.dtype,
-        float32_logits=self.float32_logits,
-        qkv_layout="BSHD_BSHD_BSHD",  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
-        scale_factor=1.0,
-        transpose_batch_sequence=False,
-        window_size=sliding_window_size,
-        context_parallel_causal_load_balanced=self.config.context_parallel_load_balance,
-        context_parallel_axis="context",
-    )
-    return dpa_layer(query, key, value, mask=attn_mask)
+    if attn_mask is not None:
+      attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1)
+    return self.dpa_layer(query, key, value, mask=attn_mask)
 
   def cudnn_jax_flash_attention(
       self,
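
In the new call path, the additive mask returned by `generate_attention_mask` (zeros at visible positions, a large negative `DEFAULT_MASK_VALUE` at blocked ones) is converted into a 0/1 mask before being handed to the stored `self.dpa_layer`. The toy example below walks through that `jnp.where` conversion; the exact `DEFAULT_MASK_VALUE` constant and the 2x2 mask are illustrative assumptions, and it presumes the Transformer Engine convention that nonzero mask entries mark positions to drop.

```python
# Toy illustration of the additive -> 0/1 mask conversion in the new code path.
# DEFAULT_MASK_VALUE here and the 2x2 mask are made-up stand-ins; in MaxText the
# additive mask comes from generate_attention_mask and DEFAULT_MASK_VALUE is a
# large negative constant.
import jax.numpy as jnp

DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.float32).max)

# Additive mask: 0.0 where attention is allowed, DEFAULT_MASK_VALUE where it is not.
additive_mask = jnp.array([[0.0, DEFAULT_MASK_VALUE],
                           [0.0, 0.0]])

# Entries at or above half of DEFAULT_MASK_VALUE were 0.0 (allowed), so they map
# to 0; everything else maps to 1, which is treated as masked out.
te_mask = jnp.where(additive_mask >= DEFAULT_MASK_VALUE * 0.5, 0, 1)
print(te_mask)  # [[0 1]
                #  [0 0]]
```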
@@ -1354,9 +1378,7 @@ def qk_product(
       raise NotImplementedError(self.compute_axis_order)
     return result
 
-  def wv_product(
-      self, attn_weights: Array, value: Array | KVTensor, model_mode: str, einsum: Callable[..., Array]
-  ) -> Array:
+  def wv_product(self, attn_weights: Array, value: Array | KVTensor, model_mode: str, einsum: Callable[..., Array]) -> Array:
     """weighted value product.
 
     Args:
