Skip to content

Migrate DotProductAttention to NNX #2198

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 54 additions & 32 deletions MaxText/layers/attention_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,55 @@ def maybe_create_nnx(einsum, *args):
self.AqtEinsum_2 = jnp.einsum
self.AqtEinsum_3 = jnp.einsum

if self.attention_kernel == "cudnn_flash_te":
# These imports are only meant to work in a GPU build.
# pylint: disable=import-outside-toplevel
from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disabatch_sizeatch_sizele=import-error

using_context_parallelism = self.mesh.shape["context"] > 1

if self.attention_type == AttentionType.LOCAL_SLIDING and using_context_parallelism:
raise AssertionError("Sliding window attention is not supported when context parallelism is enabled")

sliding_window_size = None

if self.attention_type == AttentionType.LOCAL_SLIDING or not self.config.enable_padding_causal_mask:
sliding_window_size = [self.sliding_window_size, 0]

if self.attention_type == AttentionType.LOCAL_SLIDING or using_context_parallelism:
mask_type = "causal" # SWA and Context Parallelism only work with causal masking
dummy_attn_mask = None
else:
# generate attn_mask
mask_type = "padding_causal" # only padding_causal mask type can take a created mask
dummy_attn_mask = jnp.zeros((1, 1, 1, self.max_target_length, self.max_target_length), dtype=jnp.uint8)

dpa_layer = DotProductAttention(
head_dim=config.head_dim,
num_attention_heads=self.num_query_heads,
num_gqa_groups=self.num_kv_heads,
attn_mask_type=mask_type, # 'no_mask', 'padding', 'causal', or 'padding_causal'
attn_bias_type="no_bias", # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
attention_dropout=self.dropout_rate,
dropout_rng_name="aqt",
dtype=self.dtype,
float32_logits=self.float32_logits,
qkv_layout="BSHD_BSHD_BSHD", # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
scale_factor=1.0,
transpose_batch_sequence=False,
window_size=sliding_window_size,
context_parallel_causal_load_balanced=config.context_parallel_load_balance,
context_parallel_axis="context",
)

dpa_layer = nnx_wrappers.ToNNX(dpa_layer, rngs=rngs)
dummy_query_prefill = jnp.zeros((1, self.max_target_length, self.num_query_heads, config.head_dim), dtype=self.dtype)
dummy_key_prefill = jnp.zeros((1, self.max_target_length, self.num_kv_heads, config.head_dim), dtype=self.dtype)
dummy_value_prefill = jnp.zeros((1, self.max_target_length, self.num_kv_heads, config.head_dim), dtype=self.dtype)
Comment on lines +561 to +563
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cgarciae are zeros the right value here?


dpa_layer.lazy_init(dummy_query_prefill, dummy_key_prefill, dummy_value_prefill, mask=dummy_attn_mask)
self.dpa_layer = dpa_layer

def check_attention_inputs(self, query: Array, key: Array | KVTensor, value: Array | KVTensor) -> None:
"""Check attention inputs."""

Expand Down Expand Up @@ -1096,48 +1145,23 @@ def cudnn_flash_attention(
1. Stable API, supports GQA, SWA (only with causal masking)
2. Head_dim = 256 is also supported from TE-1.12 stable release with CUDNN 12.6
"""
# These imports are only meant to work in a GPU build.
# pylint: disable=import-outside-toplevel
from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error

_, _, _, head_dim = query.shape # pylint: disable=unused-variable

using_context_parallelism = self.mesh.shape["context"] > 1

if self.attention_type == AttentionType.LOCAL_SLIDING and using_context_parallelism:
raise AssertionError("Sliding window attention is not supported when context parallelism is enabled")

sliding_window_size = None

if self.attention_type == AttentionType.LOCAL_SLIDING or not self.config.enable_padding_causal_mask:
sliding_window_size = [self.sliding_window_size, 0]

if self.attention_type == AttentionType.LOCAL_SLIDING or using_context_parallelism:
mask_type = "causal" # SWA and Context Parallelism only work with causal masking
# SWA and Context Parallelism only work with causal masking
attn_mask = None
else:
# generate attn_mask
mask_type = "padding_causal" # only padding_causal mask type can take a created mask
attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)

dpa_layer = DotProductAttention(
head_dim=head_dim,
num_attention_heads=self.num_query_heads,
num_gqa_groups=self.num_kv_heads,
attn_mask_type=mask_type, # 'no_mask', 'padding', 'causal', or 'padding_causal'
attn_bias_type="no_bias", # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
attention_dropout=self.dropout_rate,
dropout_rng_name="aqt",
dtype=self.dtype,
float32_logits=self.float32_logits,
qkv_layout="BSHD_BSHD_BSHD", # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
scale_factor=1.0,
transpose_batch_sequence=False,
window_size=sliding_window_size,
context_parallel_causal_load_balanced=self.config.context_parallel_load_balance,
context_parallel_axis="context",
)
return dpa_layer(query, key, value, mask=attn_mask)
if attn_mask is not None:
attn_mask = jnp.where((attn_mask >= DEFAULT_MASK_VALUE * 0.5), 0, 1)
return self.dpa_layer(query, key, value, mask=attn_mask)

def cudnn_jax_flash_attention(
self,
Expand Down Expand Up @@ -1354,9 +1378,7 @@ def qk_product(
raise NotImplementedError(self.compute_axis_order)
return result

def wv_product(
self, attn_weights: Array, value: Array | KVTensor, model_mode: str, einsum: Callable[..., Array]
) -> Array:
def wv_product(self, attn_weights: Array, value: Array | KVTensor, model_mode: str, einsum: Callable[..., Array]) -> Array:
"""weighted value product.

Args:
Expand Down
Loading