
Provide MaxText axes to cudnn_flash_te to correctly perform dbias reduction if required #2221


Open · wants to merge 1 commit into main

Conversation

jberchtold-nvidia

Description

TransformerEngine, which backs the attention=cudnn_flash_te attention backend, requires the physical mesh axis names in order to correctly perform a dbias reduction across multiple devices.

Currently, MaxText never performs a dbias reduction, since the attention type of TE's DotProductAttention is always set to "no_bias" (usage in attention_op.py). However, if this changes in the future, or a user modifies it on their fork, an error will be raised because MaxText does not provide TE with the mesh axis names associated with its parallelism types, e.g. fsdp_resource = "fsdp".
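
For reference, this is roughly how the axis mapping is handed to TE. A minimal sketch, assuming TE's JAX sharding API (MeshResource and global_shard_guard); the axis names below are illustrative and the exact MeshResource fields may differ across TE versions:

```python
# Sketch: tell TE which physical mesh axes correspond to which
# parallelism types, so it can all-reduce dbias over the right axes.
from transformer_engine.jax.sharding import MeshResource, global_shard_guard

# Map parallelism concepts to physical mesh axis names (illustrative).
mesh_resource = MeshResource(
    dp_resource="data",    # data-parallel axis
    fsdp_resource="fsdp",  # fully-sharded data-parallel axis
    tp_resource="tensor",  # tensor-parallel axis
)

# Any TE op run inside this context (e.g. DotProductAttention) can look
# up which physical axes a dbias reduction must span.
with global_shard_guard(mesh_resource):
    pass  # call into the model / attention layer here
```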

For every case except cudnn_flash_te used with a bias, this change is a no-op.
All cases:

  1. No-op. The transformer_engine package is not present, in which case this context does not set any state (a sketch of this fallback follows the list).
  2. The transformer_engine package is present. The context sets state in TE mapping physical axes to parallelism concepts, e.g. fsdp.
    a. No-op. cudnn_flash_te is NOT used with "pre_scale_bias". The axis mapping from the context is unused and does not affect any other operations.
    b. cudnn_flash_te IS used with "pre_scale_bias". The axis info from the context is used and TE does not raise an error. Without the change in this PR, TE would raise an error here due to the missing axis info.
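
A minimal sketch of the case-1 fallback above, using contextlib.nullcontext when transformer_engine cannot be imported. The helper name maxtext_te_mesh_context is hypothetical, not MaxText's actual API:

```python
import contextlib

def maxtext_te_mesh_context(mesh_axes):
  """Return a TE sharding context if transformer_engine is importable,
  otherwise a null context that sets no state (case 1)."""
  try:
    from transformer_engine.jax.sharding import MeshResource, global_shard_guard
  except ImportError:
    return contextlib.nullcontext()  # case 1: TE not present, no-op
  # Case 2: TE present; record the physical-axis mapping. The mapping is
  # unused unless cudnn_flash_te actually runs with a bias (case 2b).
  return global_shard_guard(MeshResource(**mesh_axes))

# Usage (axis names illustrative):
with maxtext_te_mesh_context({"fsdp_resource": "fsdp", "tp_resource": "tensor"}):
  pass  # run the training step here
```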

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.


google-cla bot commented Aug 21, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.
