Skip to content

Disable depth scaling in query projection if qk_norm or query_pre_attn_scalar are set #2204

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion MaxText/layers/attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,12 @@ def init_query_w(self, inputs_q_shape: Tuple) -> nnx.Module:
# NOTE: T5 does not explicitly rescale the attention logits by
# 1/sqrt(depth_kq)! This is folded into the initializers of the
# linear transformations, which is equivalent under Adafactor.
depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
# We disable depth_scaling when using qk_norm or a query_pre_attn_scalar
# to avoid applying scaling twice.
if self.config.use_qk_norm or (self.query_pre_attn_scalar is not None and self.query_pre_attn_scalar != 1.0):
depth_scaling = 1.0
else:
depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)

def query_init(*args):
# pylint: disable=no-value-for-parameter
Expand Down
Loading