Skip to content

Disable depth scaling in query projection if qk_norm or query_pre_attn_scalar are set #2204

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

gagika
Copy link
Collaborator

@gagika gagika commented Aug 19, 2025

Description

This PR corrects the scaling of query projections in the attention mechanism to prevent applying the scaling factor twice.

Background

To ensure numerical stability, attention logits are typically scaled by 1/sqrt(head_dim). We had followed T5 convention, by scaling directly into the weight initializers of the query projection.

The Problem

Newer features like use_qk_norm (which normalizes the query/key tensors directly) and query_pre_attn_scalar (which applies an explicit scale after query projection) provide alternative ways to achieve the same stability. When these features are enabled, the original scaling in the weight initializer becomes redundant and results in the logits being scaled down twice, which can harm performance.

The Solution

This change disables the default initializer scaling (by setting depth_scaling to 1.0) whenever use_qk_norm or query_pre_attn_scalar is active. This ensures that only one scaling method is applied at a time. This is a correctness fix that has no impact on configurations not using these new features.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@gagika gagika marked this pull request as ready for review August 19, 2025 14:06
@gagika gagika changed the title Disable depth scaling for qk_norm and query_pre_attn_scalar in Attention Disable depth scaling in query projection if qk_norm or query_pre_attn_scalar are set Aug 20, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants