Skip to content

[AMD] Fix aiter backend missing ENCODER_ONLY attention support#20102

Open
nathanrchn wants to merge 6 commits intosgl-project:mainfrom
nathanrchn:fix/aiter-encoder-only-attention
Open

[AMD] Fix aiter backend missing ENCODER_ONLY attention support#20102
nathanrchn wants to merge 6 commits intosgl-project:mainfrom
nathanrchn:fix/aiter-encoder-only-attention

Conversation

@nathanrchn
Copy link
Contributor

Motivation

The aiter attention backend hardcoded causal=True in forward_extend, unlike other backends (flashinfer, triton, torch_native) which check for AttentionType.ENCODER_ONLY and set causal=False + save_kv_cache=False.

This caused bidirectional embedding models (e.g. perplexity-ai/pplx-embed-v1) to produce NaN hidden states when batching multiple requests on AMD GPUs, as the causal mask prevented proper bidirectional attention computation.

Modifications

  • Import AttentionType at runtime (was only under TYPE_CHECKING)
  • Add ENCODER_ONLY check at the start of forward_extend to set causal=False and save_kv_cache=False, matching the flashinfer backend behavior
  • Replace hardcoded causal=True with the computed causal variable in both the MLA and non-MLA prefill paths

Checklist

  • Format: pre-commit run --all-files
  • Matches behavior of flashinfer/triton/torch_native backends

The aiter attention backend hardcoded `causal=True` in forward_extend,
unlike other backends (flashinfer, triton, torch_native) which set
`causal=False` and `save_kv_cache=False` for ENCODER_ONLY attention.

This caused bidirectional embedding models (e.g. pplx-embed-v1) to
produce NaN hidden states when batching multiple requests, as the
causal mask prevented proper bidirectional attention computation.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical issue in the aiter attention backend on AMD GPUs where bidirectional embedding models produced NaN hidden states due to an incorrectly hardcoded causal mask. The changes introduce dynamic handling of causal attention and KV cache saving based on the attention type, ensuring compatibility and correct behavior for encoder-only models and aligning the aiter backend with other existing attention backends.

Highlights

  • Runtime AttentionType Import: The AttentionType enum is now imported at runtime, moving it from a TYPE_CHECKING block to enable its use in production code.
  • ENCODER_ONLY Attention Support: Added logic in forward_extend to correctly handle AttentionType.ENCODER_ONLY by setting causal=False and save_kv_cache=False, aligning behavior with other backends.
  • Dynamic Causal Masking: Replaced hardcoded causal=True with a dynamically computed causal variable in both MLA and non-MLA prefill paths, ensuring proper attention computation for different model types.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • python/sglang/srt/layers/attention/aiter_backend.py
    • Imported AttentionType for runtime use.
    • Added conditional logic in forward_extend to set causal and save_kv_cache based on layer.is_cross_attention and layer.attn_type.
    • Replaced hardcoded causal=True with the new causal variable in two _attn_prefill_wrapper calls.
Activity
  • The code has been formatted using pre-commit run --all-files.
  • The changes have been verified to match the behavior of flashinfer, triton, and torch_native backends.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request addresses a bug in the aiter attention backend where causal=True was hardcoded, leading to incorrect behavior for bidirectional models on AMD GPUs. The fix correctly sets causal=False and disables KV cache saving for ENCODER_ONLY attention, aligning its behavior with other backends. The implementation is sound, and I've offered a minor suggestion to simplify the conditional logic for improved readability and efficiency.

nathanrchn and others added 5 commits March 7, 2026 18:32
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
…ites

Remove `save_kv_cache = False` for ENCODER_ONLY attention. The non-MLA
prefill path reads K/V from the cache buffer via kv_indices, so skipping
the cache write causes attention to read uninitialized GPU memory,
producing NaN hidden states that nan_to_num silently converts to zeros.
With disable_radix_cache=True, cache entries are freed after each request
anyway, so there is no memory waste from writing them.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant