From a79751c1cd4a503b4ac64f15de0ba759bbb3d071 Mon Sep 17 00:00:00 2001 From: Gagik Amirkhanyan Date: Tue, 19 Aug 2025 06:54:58 -0700 Subject: [PATCH] Disable depth scaling for qk_norm and query_pre_attn_scalar in Attention --- MaxText/layers/attentions.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/MaxText/layers/attentions.py b/MaxText/layers/attentions.py index bbaddabf74..1325c1a600 100644 --- a/MaxText/layers/attentions.py +++ b/MaxText/layers/attentions.py @@ -505,7 +505,12 @@ def init_query_w(self, inputs_q_shape: Tuple) -> nnx.Module: # NOTE: T5 does not explicitly rescale the attention logits by # 1/sqrt(depth_kq)! This is folded into the initializers of the # linear transformations, which is equivalent under Adafactor. - depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) + # Disable depth_scaling when using qk_norm, or when a query_pre_attn_scalar + # other than 1.0 is set, to avoid applying the attention scaling twice. + if self.config.use_qk_norm or (self.query_pre_attn_scalar is not None and self.query_pre_attn_scalar != 1.0): + depth_scaling = 1.0 + else: + depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) def query_init(*args): # pylint: disable=no-value-for-parameter