Commit 0a94059

Support multimodal in logit checker and match gemma3 logits with HF

Parent: 8def32a

12 files changed: +423 -89 lines changed

MaxText/configs/base.yml

Lines changed: 1 addition & 0 deletions
@@ -593,6 +593,7 @@ use_untrainable_positional_embedding: False
 trainable_position_size: -1 # enable gpt3 position embedding with a positive trainable_position_size
 # RoPE parameters
 rope_type: "default" # one of "default", "llama3.1" or "yarn"
+rope_linear_scaling_factor: 1.0 # linear scaling factor for "default" RoPE (see class `RotaryEmbedding` for more)
 rope_use_scale: True # apply rope scaling for llama3.1 (see class `LLaMARotaryEmbedding` for more)
 rope_min_timescale: 1
 rope_max_timescale: 10_000 # Timescale For global Attention

MaxText/configs/models/gemma3-12b.yml

Lines changed: 2 additions & 1 deletion
@@ -21,7 +21,7 @@ base_num_kv_heads: 8
 base_mlp_dim: 15360
 head_dim: 256
 mlp_activations: ["gelu","linear"]
-vocab_size: 262_144
+vocab_size: 262_208
 decoder_block: "gemma3"
 normalization_layer_epsilon: 1e-6
 logits_via_embedding: True
@@ -30,3 +30,4 @@ use_post_attn_norm: true
 use_post_ffw_norm: true
 local_rope_max_timescale: 10_000
 rope_max_timescale: 1_000_000
+rope_linear_scaling_factor: 8.0

MaxText/configs/models/gemma3-27b.yml

Lines changed: 2 additions & 1 deletion
@@ -21,7 +21,7 @@ base_num_kv_heads: 16
 base_mlp_dim: 21504
 head_dim: 128
 mlp_activations: ["gelu","linear"]
-vocab_size: 262_144
+vocab_size: 262_208
 decoder_block: "gemma3"
 normalization_layer_epsilon: 1e-6
 logits_via_embedding: True
@@ -30,3 +30,4 @@ use_post_attn_norm: true
 use_post_ffw_norm: true
 local_rope_max_timescale: 10_000
 rope_max_timescale: 1_000_000
+rope_linear_scaling_factor: 8.0

MaxText/configs/models/gemma3-4b.yml

Lines changed: 2 additions & 1 deletion
@@ -21,7 +21,7 @@ base_num_kv_heads: 4
 base_mlp_dim: 10240
 head_dim: 256
 mlp_activations: ["gelu","linear"]
-vocab_size: 262_144
+vocab_size: 262_208
 decoder_block: "gemma3"
 normalization_layer_epsilon: 1e-6
 logits_via_embedding: True
@@ -30,3 +30,4 @@ use_post_attn_norm: true
 use_post_ffw_norm: true
 local_rope_max_timescale: 10_000
 rope_max_timescale: 1_000_000
+rope_linear_scaling_factor: 8.0
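
The vocab_size bump from 262_144 to 262_208 appears in all three gemma3 configs above. It lines up with the padded Hugging Face Gemma3 vocabulary and, in particular, leaves room for the image soft-token id 262144 that this commit sets in MaxText/multimodal_utils.py below. A minimal bounds check; the embedding width is a toy value for illustration, not MaxText code:

    import jax.numpy as jnp

    GEMMA_TOKEN_PLACEHOLDER = 262_144  # image soft-token id set in multimodal_utils.py below
    EMB_DIM = 16  # toy width for illustration; the real configs use base_emb_dim

    old_table = jnp.zeros((262_144, EMB_DIM))  # old vocab: placeholder id is one past the last row
    new_table = jnp.zeros((262_208, EMB_DIM))  # new vocab: placeholder id is a valid row
    assert GEMMA_TOKEN_PLACEHOLDER >= old_table.shape[0]
    assert GEMMA_TOKEN_PLACEHOLDER < new_table.shape[0]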

MaxText/layers/attentions.py

Lines changed: 7 additions & 0 deletions
@@ -694,11 +694,18 @@ def init_rotary_embedding(self):
     # For local attention use local_rope_max_timescale if it's positive
     if self.attention_type == AttentionType.LOCAL_SLIDING and self.config.local_rope_max_timescale > 0:
       max_timescale = self.config.local_rope_max_timescale
+
+    rope_linear_scaling_factor = self.config.rope_linear_scaling_factor
+    # In gemma3, linear scaling factor does not apply to local sliding layers.
+    if self.config.model_name.startswith("gemma3") and self.attention_type == AttentionType.LOCAL_SLIDING:
+      rope_linear_scaling_factor = 1.0
+
     rotary_embedding = RotaryEmbedding(
         min_timescale=self.config.rope_min_timescale,
         max_timescale=max_timescale,
         embedding_dims=rope_embedding_dims,
         fprop_dtype=self.dtype,
+        rope_linear_scaling_factor=rope_linear_scaling_factor,
         rngs=self.rngs,
     )
     return rotary_embedding
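
For gemma3 this leaves two RoPE configurations per model: global-attention layers pick up rope_max_timescale: 1_000_000 and rope_linear_scaling_factor: 8.0 from the model configs above, while local sliding layers keep local_rope_max_timescale: 10_000 and a factor of 1.0. A simplified restatement of that selection; the plain strings and the dict are stand-ins for the real AttentionType enum and config object:

    def gemma3_rope_params(attention_type: str, cfg: dict) -> tuple[int, float]:
      """Returns (max_timescale, linear_scaling_factor) for a gemma3 attention layer."""
      if attention_type == "local_sliding" and cfg["local_rope_max_timescale"] > 0:
        # Local sliding layers use the short RoPE period and skip linear scaling.
        return cfg["local_rope_max_timescale"], 1.0
      return cfg["rope_max_timescale"], cfg["rope_linear_scaling_factor"]

    cfg = {"rope_max_timescale": 1_000_000, "local_rope_max_timescale": 10_000, "rope_linear_scaling_factor": 8.0}
    assert gemma3_rope_params("global", cfg) == (1_000_000, 8.0)
    assert gemma3_rope_params("local_sliding", cfg) == (10_000, 1.0)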

MaxText/layers/embeddings.py

Lines changed: 8 additions & 7 deletions
@@ -242,6 +242,7 @@ def __init__(
       fprop_dtype: DType = jnp.bfloat16,
       # Not used in RotaryEmbedding but passed in by nnx.bridge.to_linen.
       # TODO: Remove when bridge no longer needed
+      rope_linear_scaling_factor: float = 1.0,
       rngs: nnx.Rngs = None,
   ):
     """Initializes the RotaryEmbedding module.
@@ -261,6 +262,7 @@ def __init__(
     self.embedding_dims = embedding_dims
     self.cast_as_fprop_dtype = cast_as_fprop_dtype
     self.fprop_dtype = fprop_dtype
+    self.rope_linear_scaling_factor = rope_linear_scaling_factor
 
     if self.embedding_dims % 2:
       raise ValueError("Embedding dim for rotary position embedding must be a multiple of 2.")
@@ -270,7 +272,10 @@ def timescale(self):
     """Returns the timescale for the rotary embedding."""
     half_embedding_dim = self.embedding_dims // 2
     fraction = 2 * jnp.arange(0, half_embedding_dim) / self.embedding_dims
-    return self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
+    timescale = self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
+    if self.rope_linear_scaling_factor != 1.0:
+      timescale = timescale * self.rope_linear_scaling_factor
+    return timescale
 
   def __call__(
       self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
@@ -448,9 +453,7 @@ def __call__(self, inputs: jax.Array, position: None | jax.Array = None) -> jax.Array:
     if len(inputs.shape) != 4:
       raise ValueError("Input is assumed to be a rank 4 tensor of shape [B, S, N, H].")
     if self.embedding_dims != inputs.shape[3]:
-      raise ValueError(
-          "The embedding dims of the rotary position embedding must match the hidden dimension of the inputs."
-      )
+      raise ValueError("The embedding dims of the rotary position embedding must match the hidden dimension of the inputs.")
 
     # Shift the inputs left and right as per LLaMA's specific behavior
     inputs_shifted_left = jnp.concatenate([inputs[..., 1:], inputs[..., :1]], axis=-1)
@@ -649,9 +652,7 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array:
     if len(inputs.shape) != 4:
       raise ValueError("Input is assumed to be a rank 4 tensor of shape [batch, sequence, heads, dims].")
     if self.embedding_dims != inputs.shape[3]:
-      raise ValueError(
-          "The embedding dims of the rotary position embedding must match the hidden dimension of the inputs."
-      )
+      raise ValueError("The embedding dims of the rotary position embedding must match the hidden dimension of the inputs.")
 
     # Determine positions if not provided
     if position is None:
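
The new scaling is the standard linear (position-interpolation) trick: the rotation angle is position / timescale, so multiplying every timescale by the factor is the same as evaluating the original embedding at position / factor, which stretches the usable context window. Below is a standalone recomputation of the timescale property above with gemma3 global-attention values (head_dim 256, max timescale 1_000_000, factor 8.0, taken from the configs in this commit); rope_timescale is an illustrative helper, not the module itself:

    import jax.numpy as jnp

    def rope_timescale(embedding_dims, min_timescale=1, max_timescale=1_000_000, factor=1.0):
      # Mirrors the `timescale` property above.
      half_embedding_dim = embedding_dims // 2
      fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dims
      return min_timescale * (max_timescale / min_timescale) ** fraction * factor

    base = rope_timescale(256)
    scaled = rope_timescale(256, factor=8.0)

    # Scaling every timescale by 8 is equivalent to rotating at position / 8:
    # position / (8 * t) == (position / 8) / t for every timescale t.
    position = 4096.0
    assert jnp.allclose(position / scaled, (position / 8.0) / base)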

MaxText/layers/gemma3.py

Lines changed: 34 additions & 9 deletions
@@ -277,15 +277,16 @@ def _posemb_sincos_2d(
     width: int,
     temperature: float = 10_000.0,
     dtype: jnp.dtype = jnp.float32,
+    precision: str = "default",
 ):
   """Follows the MoCo v3 logic."""
   y, x = jnp.mgrid[:h, :w]  # pylint: disable=unpacking-non-sequence
 
   assert width % 4 == 0, "Width must be mult of 4 for sincos posemb"
   omega = jnp.arange(width // 4) / (width // 4 - 1)
   omega = 1.0 / (temperature**omega)
-  y = jnp.einsum("m,d->md", y.flatten(), omega)
-  x = jnp.einsum("m,d->md", x.flatten(), omega)
+  y = jnp.einsum("m,d->md", y.flatten(), omega, precision=jax.lax.Precision(precision))
+  x = jnp.einsum("m,d->md", x.flatten(), omega, precision=jax.lax.Precision(precision))
   pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1)
   return jnp.asarray(pe, dtype)[None, :, :]
 
@@ -297,18 +298,22 @@ class MlpBlockViT(nn.Module):
   dtype_mm: str
   mlp_dim: int | None = None  # Defaults to 4x input dim
   dropout: float = 0.0
+  precision: str = "default"
 
   @nn.compact
   def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
     """Applies Transformer MlpBlock module."""
     inits = {"kernel_init": nn.initializers.xavier_uniform(), "bias_init": nn.initializers.normal(stddev=1e-6)}
 
     d = x.shape[-1]
-    x = nn.Dense(features=self.mlp_dim or 4 * d, dtype=self.dtype_mm, **inits)(x)
+    x = nn.Dense(features=self.mlp_dim or 4 * d, precision=jax.lax.Precision(self.precision), dtype=self.dtype_mm, **inits)(
+        x
+    )
     x = nn.gelu(x)
     x = nn.Dropout(rate=self.dropout)(x, deterministic)
     x = nn.Dense(
         features=d,
+        precision=jax.lax.Precision(self.precision),
         dtype=self.dtype_mm,
         **inits,
     )(x)
@@ -323,6 +328,7 @@ class Encoder1DBlock(nn.Module):
   mlp_dim: int | None = None  # Defaults to 4x input dim
   num_heads: int = 12
   dropout: float = 0.0
+  precision: str = "default"
 
   @nn.compact
   def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
@@ -331,6 +337,7 @@ def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
     y = nn.MultiHeadDotProductAttention(
         num_heads=self.num_heads,
         kernel_init=nn.initializers.xavier_uniform(),
+        precision=jax.lax.Precision(self.precision),
         deterministic=deterministic,
         dtype=self.dtype_mm,
     )(y, y)
@@ -343,6 +350,7 @@ def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
         mlp_dim=self.mlp_dim,
         dropout=self.dropout,
         dtype_mm=self.dtype_mm,
+        precision=self.precision,
     )(y, deterministic)
     y = nn.Dropout(rate=self.dropout)(y, deterministic)
     x = x + y
@@ -358,7 +366,8 @@ class Encoder(nn.Module):
   mlp_dim: int | None = None  # Defaults to 4x input dim
   num_heads: int = 12
   dropout: float = 0.0
-  scan: bool = False
+  scan: bool = False
+  precision: str = "default"
 
   @nn.compact
   def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
@@ -383,6 +392,7 @@ def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
           mlp_dim=self.mlp_dim,
          num_heads=self.num_heads,
          dropout=self.dropout,
+          precision=self.precision,
       )(
           x, deterministic
       )
@@ -396,6 +406,7 @@ def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
           mlp_dim=self.mlp_dim,
           num_heads=self.num_heads,
           dropout=self.dropout,
+          precision=self.precision,
       )
       x = block_cur(x, deterministic)
     x: jax.Array = nn.LayerNorm(name="encoder_norm")(x)
@@ -409,6 +420,7 @@ class Einsum(nn.Module):
   weight_name: str = "w"
   initializer: nn.initializers.Initializer = nn.initializers.normal()
   dtype: jnp.dtype | None = None
+  precision: str = "default"
 
   @nn.compact
   def __call__(self, eqn: str, x: jax.Array) -> jax.Array:
@@ -418,7 +430,7 @@ def __call__(self, eqn: str, x: jax.Array) -> jax.Array:
         self.shape,
         self.dtype if self.dtype is not None else None,
     )
-    return jnp.einsum(eqn, x, w)
+    return jnp.einsum(eqn, x, w, precision=jax.lax.Precision(self.precision))
 
 
 class VisionEmbedder(nn.Module):
@@ -430,8 +442,10 @@ class VisionEmbedder(nn.Module):
 
   def setup(self):
     if self.vision_proj_dim:
-      self.mm_soft_embedding_norm = rms_norm(self.vision_proj_dim)
-      self.mm_input_projection = Einsum((self.vision_proj_dim, self.config.emb_dim))
+      self.mm_soft_embedding_norm = rms_norm(self.vision_proj_dim, dtype=self.config.dtype_mm)
+      self.mm_input_projection = Einsum(
+          (self.vision_proj_dim, self.config.emb_dim), dtype=self.config.dtype_mm, precision=self.config.matmul_precision
+      )
 
   def encode_vision(self, x: jax.Array) -> jax.Array:
     x = self.mm_soft_embedding_norm(x)
@@ -494,6 +508,7 @@ def _get_posemb(
     width: int,
     name: str,
     dtype: jnp.dtype = jnp.float32,
+    precision: str = "default",
 ):
   """Returns the position embedding."""
   if typ == "learn":
@@ -505,7 +520,7 @@ def _get_posemb(
         dtype,
     )
   elif typ == "sincos2d":
-    return _posemb_sincos_2d(*seqshape, width=width, dtype=dtype)
+    return _posemb_sincos_2d(*seqshape, width=width, dtype=dtype, precision=precision)
   else:
     raise ValueError(f"Unknown posemb type: {typ}")
 
@@ -524,7 +539,15 @@ def __call__(self, inputs, deterministic, train=False):
     b, n, h, w, c = inputs.shape
     x = jnp.reshape(inputs, [b * n, h, w, c])
     # Gemma3 uses conv2d with stride 14 and kernel size 14 to extract patches.
-    x = nn.Conv(features=1152, kernel_size=(14, 14), strides=14, padding="VALID", name="embedding")(x)
+    x = nn.Conv(
+        features=1152,
+        kernel_size=(14, 14),
+        strides=14,
+        padding="VALID",
+        name="embedding",
+        dtype=cfg.dtype_mm,
+        precision=jax.lax.Precision(cfg.matmul_precision),
+    )(x)
     bn, h, w, c = x.shape
     x = jnp.reshape(x, [bn, h * w, c])
 
@@ -535,6 +558,7 @@ def __call__(self, inputs, deterministic, train=False):
         width=c,
         name="pos_embedding",
         dtype=x.dtype,
+        precision=cfg.matmul_precision,
     )
 
     x = nn.Dropout(rate=self.dropout)(x, not train)
@@ -549,6 +573,7 @@ def __call__(self, inputs, deterministic, train=False):
         remat_policy=cfg.remat_policy_for_vit,
         dtype_mm=cfg.dtype_mm,
         name="Transformer",
+        precision=cfg.matmul_precision,
     )(x, deterministic=deterministic)
 
     # Gemma3 use a vision exit layer to downsample the soft tokens to a required output length.
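
The common thread in this file is threading a precision string into every matmul of the vision tower (patch-embedding conv, attention, MLP, and the einsum projection) so it can be forced to full float32 accumulation when comparing logits against the HF reference. A small standalone sketch of the jax.lax.Precision strings being plumbed through, not MaxText code:

    import jax
    import jax.numpy as jnp

    # The same strings the commit passes through: "default", "high", "highest".
    # On TPU, "default" runs float32 matmuls as bfloat16 passes, while "highest"
    # keeps full float32 accumulation, which matters when matching logits closely.
    a = jax.random.normal(jax.random.PRNGKey(0), (256, 256), dtype=jnp.float32)
    b = jax.random.normal(jax.random.PRNGKey(1), (256, 256), dtype=jnp.float32)

    fast = jnp.einsum("ij,jk->ik", a, b, precision=jax.lax.Precision("default"))
    exact = jnp.einsum("ij,jk->ik", a, b, precision=jax.lax.Precision("highest"))
    print(jnp.max(jnp.abs(fast - exact)))  # ~0 on CPU, noticeably larger on TPU defaults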

MaxText/multimodal_utils.py

Lines changed: 2 additions & 2 deletions
@@ -35,9 +35,9 @@
 GEMMA_IMAGE_STD = (127.5,) * 3
 GEMMA_IMAGE_PLACEHOLDER_IN_PROMPT = "<start_of_image>"
 GEMMA_BEGIN_IMAGE_TOKEN = 255999
-GEMMA_END_IMAGE_TOKEN = 262144
+GEMMA_END_IMAGE_TOKEN = 256000
 GEMMA_NEW_LINE_TOKEN = 108
-GEMMA_TOKEN_PLACEHOLDER = -2
+GEMMA_TOKEN_PLACEHOLDER = 262144
 # The number of GEMMA_TOKEN_PLACEHOLDER tokens per image in Gemma3
 GEMMA_NUM_PLACEHOLDER_TOKENS_PER_IMAGE = 256
 # +4 means 4 extra tokens to pad around image: \n\n, <start_of_image>, <end_of_image>, \n\n
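
With the corrected ids, <end_of_image> (256000) sits right after <start_of_image> (255999) in the tokenizer, and 262144 becomes the id of the image soft token that the 256 placeholders stand in for. A sketch of how one image expands in the token stream, based only on the constants and the "+4 extra tokens" comment above; expand_image_tokens_sketch is illustrative, not the real helper in multimodal_utils.py:

    GEMMA_BEGIN_IMAGE_TOKEN = 255999
    GEMMA_END_IMAGE_TOKEN = 256000
    GEMMA_NEW_LINE_TOKEN = 108
    GEMMA_TOKEN_PLACEHOLDER = 262144
    GEMMA_NUM_PLACEHOLDER_TOKENS_PER_IMAGE = 256

    def expand_image_tokens_sketch() -> list[int]:
      # \n\n, <start_of_image>, 256 soft-token placeholders, <end_of_image>, \n\n
      # (the comment above counts each \n\n as a single token)
      return (
          [GEMMA_NEW_LINE_TOKEN, GEMMA_BEGIN_IMAGE_TOKEN]
          + [GEMMA_TOKEN_PLACEHOLDER] * GEMMA_NUM_PLACEHOLDER_TOKENS_PER_IMAGE
          + [GEMMA_END_IMAGE_TOKEN, GEMMA_NEW_LINE_TOKEN]
      )

    assert len(expand_image_tokens_sketch()) == GEMMA_NUM_PLACEHOLDER_TOKENS_PER_IMAGE + 4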
