Disable depth scaling in query projection if qk_norm or query_pre_attn_scalar are set #2204
Description
This PR corrects the scaling of query projections in the attention mechanism to prevent applying the scaling factor twice.
Background
To ensure numerical stability, attention logits are typically scaled by 1/sqrt(head_dim). We had followed the T5 convention of folding this factor directly into the weight initializer of the query projection, as sketched below.
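A rough standalone illustration (plain JAX, not the project's actual code) of why the two placements are interchangeable: dividing the logits at attention time and shrinking the query weights at initialization produce the same logits.

```python
import jax
import jax.numpy as jnp

head_dim = 64
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (8, head_dim))            # token activations
wq = jax.random.normal(k2, (head_dim, head_dim))    # query projection weight
wk = jax.random.normal(k3, (head_dim, head_dim))    # key projection weight

# Option A: scale the logits explicitly at attention time.
q_a = x @ wq
k = x @ wk
logits_a = (q_a @ k.T) / jnp.sqrt(head_dim)

# Option B (T5 convention): fold the same factor into the query path once,
# e.g. by shrinking the query weights at initialization, and skip the
# explicit division later.
wq_scaled = wq / jnp.sqrt(head_dim)
q_b = x @ wq_scaled
logits_b = q_b @ k.T

print(jnp.allclose(logits_a, logits_b, rtol=1e-4))  # True: same logits
```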
The Problem
Newer features like use_qk_norm (which normalizes the query/key tensors directly) and query_pre_attn_scalar (which applies an explicit scale after the query projection) provide alternative ways to achieve the same stability. When either is enabled, the scaling baked into the weight initializer becomes redundant, and the logits end up scaled down twice, which can harm performance.
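A back-of-the-envelope check of the double-scaling effect, assuming for illustration that query_pre_attn_scalar also resolves to 1/sqrt(head_dim): with both mechanisms active, the logits are divided by head_dim instead of the intended sqrt(head_dim).

```python
import jax.numpy as jnp

head_dim = 64
depth_scaling = 1.0 / jnp.sqrt(head_dim)           # folded into the q initializer
query_pre_attn_scale = 1.0 / jnp.sqrt(head_dim)    # applied after the q projection

effective_scale = depth_scaling * query_pre_attn_scale
print(effective_scale)            # 1/64  == 1/head_dim (too small)
print(1.0 / jnp.sqrt(head_dim))   # 1/8   == intended scale
```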
The Solution
This change disables the default initializer scaling (by setting depth_scaling to 1.0) whenever use_qk_norm or query_pre_attn_scalar is active, so that only one scaling method is applied at a time. It is a correctness fix with no impact on configurations that do not use these features.
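A minimal sketch of the intended condition; the function and parameter names here are illustrative and not necessarily the identifiers used in the diff.

```python
def query_init_depth_scaling(head_dim: int,
                             use_qk_norm: bool,
                             query_pre_attn_scalar: float | None) -> float:
    """Return the factor folded into the query-projection initializer."""
    if use_qk_norm or query_pre_attn_scalar is not None:
        # Another mechanism already stabilizes the logits; keep the
        # initializer unscaled so the factor is not applied twice.
        return 1.0
    return head_dim ** -0.5

# Example: with qk-norm enabled the initializer scaling is a no-op.
print(query_init_depth_scaling(64, use_qk_norm=True, query_pre_attn_scalar=None))   # 1.0
print(query_init_depth_scaling(64, use_qk_norm=False, query_pre_attn_scalar=None))  # 0.125
```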
Checklist
Before submitting this PR, please make sure (put an X in the square brackets):