
Commit f8b4331

Support multimodal in logit checker and match gemma3 logits with HF
1 parent 8def32a commit f8b4331

12 files changed: +380 -73 lines changed


MaxText/configs/base.yml

Lines changed: 1 addition & 0 deletions
@@ -593,6 +593,7 @@ use_untrainable_positional_embedding: False
 trainable_position_size: -1 # enable gpt3 position embedding with a positive trainable_position_size
 # RoPE parameters
 rope_type: "default" # one of "default", "llama3.1" or "yarn"
+rope_linear_scaling_factor: 1.0 # linear scaling factor for "default" RoPE (see class `RotaryEmbedding` for more)
 rope_use_scale: True # apply rope scaling for llama3.1 (see class `LLaMARotaryEmbedding` for more)
 rope_min_timescale: 1
 rope_max_timescale: 10_000 # Timescale For global Attention

MaxText/configs/models/gemma3-12b.yml

Lines changed: 2 additions & 1 deletion
@@ -21,7 +21,7 @@ base_num_kv_heads: 8
 base_mlp_dim: 15360
 head_dim: 256
 mlp_activations: ["gelu","linear"]
-vocab_size: 262_144
+vocab_size: 262_208
 decoder_block: "gemma3"
 normalization_layer_epsilon: 1e-6
 logits_via_embedding: True
@@ -30,3 +30,4 @@ use_post_attn_norm: true
 use_post_ffw_norm: true
 local_rope_max_timescale: 10_000
 rope_max_timescale: 1_000_000
+rope_linear_scaling_factor: 8.0

MaxText/configs/models/gemma3-27b.yml

Lines changed: 2 additions & 1 deletion
@@ -21,7 +21,7 @@ base_num_kv_heads: 16
 base_mlp_dim: 21504
 head_dim: 128
 mlp_activations: ["gelu","linear"]
-vocab_size: 262_144
+vocab_size: 262_208
 decoder_block: "gemma3"
 normalization_layer_epsilon: 1e-6
 logits_via_embedding: True
@@ -30,3 +30,4 @@ use_post_attn_norm: true
 use_post_ffw_norm: true
 local_rope_max_timescale: 10_000
 rope_max_timescale: 1_000_000
+rope_linear_scaling_factor: 8.0

MaxText/configs/models/gemma3-4b.yml

Lines changed: 2 additions & 1 deletion
@@ -21,7 +21,7 @@ base_num_kv_heads: 4
 base_mlp_dim: 10240
 head_dim: 256
 mlp_activations: ["gelu","linear"]
-vocab_size: 262_144
+vocab_size: 262_208
 decoder_block: "gemma3"
 normalization_layer_epsilon: 1e-6
 logits_via_embedding: True
@@ -30,3 +30,4 @@ use_post_attn_norm: true
 use_post_ffw_norm: true
 local_rope_max_timescale: 10_000
 rope_max_timescale: 1_000_000
+rope_linear_scaling_factor: 8.0
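The vocabulary bump from 262_144 to 262_208 matches the vocab_size used by the Hugging Face Gemma 3 configs and also leaves room for the image soft-token id that MaxText/multimodal_utils.py (further down in this commit) now assigns to GEMMA_TOKEN_PLACEHOLDER. A minimal sanity check, using only values that appear elsewhere in this commit:

# Sketch: why the embedding table has to grow past the base vocabulary.
OLD_VOCAB_SIZE = 262_144
NEW_VOCAB_SIZE = 262_208
GEMMA_TOKEN_PLACEHOLDER = 262_144  # image soft-token id, see multimodal_utils.py below

assert GEMMA_TOKEN_PLACEHOLDER >= OLD_VOCAB_SIZE  # would index past the old embedding table
assert GEMMA_TOKEN_PLACEHOLDER < NEW_VOCAB_SIZE   # fits once vocab_size is 262_208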

MaxText/layers/attentions.py

Lines changed: 7 additions & 0 deletions
@@ -694,11 +694,18 @@ def init_rotary_embedding(self):
     # For local attention use local_rope_max_timescale if it is positive
     if self.attention_type == AttentionType.LOCAL_SLIDING and self.config.local_rope_max_timescale > 0:
       max_timescale = self.config.local_rope_max_timescale
+
+    rope_linear_scaling_factor = self.config.rope_linear_scaling_factor
+    # In gemma3, linear scaling factor does not apply to local sliding layers.
+    if self.config.model_name.startswith("gemma3") and self.attention_type == AttentionType.LOCAL_SLIDING:
+      rope_linear_scaling_factor = 1.0
+
     rotary_embedding = RotaryEmbedding(
         min_timescale=self.config.rope_min_timescale,
         max_timescale=max_timescale,
         embedding_dims=rope_embedding_dims,
         fprop_dtype=self.dtype,
+        rope_linear_scaling_factor=rope_linear_scaling_factor,
         rngs=self.rngs,
     )
     return rotary_embedding
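Net effect for Gemma 3: only global-attention layers (rope_max_timescale = 1_000_000) pick up the 8.0 factor set in the model configs above, while local sliding-window layers keep unscaled 10_000-timescale RoPE. A small standalone sketch of that selection logic (the helper is hypothetical, not part of the module):

# Hypothetical helper mirroring the branch added above: which RoPE settings a layer gets.
def rope_settings_for_layer(
    is_local_sliding: bool,
    rope_max_timescale: int = 1_000_000,
    local_rope_max_timescale: int = 10_000,
    rope_linear_scaling_factor: float = 8.0,
):
  if is_local_sliding:
    # Local sliding layers use the local timescale and, for gemma3, no linear scaling.
    return local_rope_max_timescale, 1.0
  return rope_max_timescale, rope_linear_scaling_factor

assert rope_settings_for_layer(is_local_sliding=True) == (10_000, 1.0)
assert rope_settings_for_layer(is_local_sliding=False) == (1_000_000, 8.0)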

MaxText/layers/embeddings.py

Lines changed: 8 additions & 7 deletions
@@ -242,6 +242,7 @@ def __init__(
       fprop_dtype: DType = jnp.bfloat16,
       # Not used in RotaryEmbedding but passed in by nnx.bridge.to_linen.
       # TODO: Remove when bridge no longer needed
+      rope_linear_scaling_factor: float = 1.0,
       rngs: nnx.Rngs = None,
   ):
     """Initializes the RotaryEmbedding module.
@@ -261,6 +262,7 @@ def __init__(
     self.embedding_dims = embedding_dims
     self.cast_as_fprop_dtype = cast_as_fprop_dtype
     self.fprop_dtype = fprop_dtype
+    self.rope_linear_scaling_factor = rope_linear_scaling_factor

     if self.embedding_dims % 2:
       raise ValueError("Embedding dim for rotary position embedding must be a multiple of 2.")
@@ -270,7 +272,10 @@ def timescale(self):
     """Returns the timescale for the rotary embedding."""
     half_embedding_dim = self.embedding_dims // 2
     fraction = 2 * jnp.arange(0, half_embedding_dim) / self.embedding_dims
-    return self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
+    timescale = self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
+    if self.rope_linear_scaling_factor != 1.0:
+      timescale = timescale * self.rope_linear_scaling_factor
+    return timescale

   def __call__(
       self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
@@ -448,9 +453,7 @@ def __call__(self, inputs: jax.Array, position: None | jax.Array = None) -> jax.
     if len(inputs.shape) != 4:
       raise ValueError("Input is assumed to be a rank 4 tensor of shape [B, S, N, H].")
     if self.embedding_dims != inputs.shape[3]:
-      raise ValueError(
-          "The embedding dims of the rotary position embedding must match the hidden dimension of the inputs."
-      )
+      raise ValueError("The embedding dims of the rotary position embedding must match the hidden dimension of the inputs.")

     # Shift the inputs left and right as per LLaMA's specific behavior
     inputs_shifted_left = jnp.concatenate([inputs[..., 1:], inputs[..., :1]], axis=-1)
@@ -649,9 +652,7 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array:
     if len(inputs.shape) != 4:
       raise ValueError("Input is assumed to be a rank 4 tensor of shape [batch, sequence, heads, dims].")
     if self.embedding_dims != inputs.shape[3]:
-      raise ValueError(
-          "The embedding dims of the rotary position embedding must match the hidden dimension of the inputs."
-      )
+      raise ValueError("The embedding dims of the rotary position embedding must match the hidden dimension of the inputs.")

     # Determine positions if not provided
     if position is None:
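Multiplying every timescale by the factor is the usual linear (position-interpolation) form of RoPE scaling: each rotation angle is position / timescale, so scaling all timescales by f is equivalent to evaluating the unscaled embedding at position / f. A minimal numpy sketch of that property, with an arbitrary head dimension:

import numpy as np

def rope_timescale(embedding_dims, min_timescale=1.0, max_timescale=1_000_000.0, linear_scaling_factor=1.0):
  # Mirrors RotaryEmbedding.timescale: a geometric ladder of timescales, then linear scaling.
  half = embedding_dims // 2
  fraction = 2 * np.arange(half) / embedding_dims
  timescale = min_timescale * (max_timescale / min_timescale) ** fraction
  return timescale * linear_scaling_factor

dims, position, factor = 256, 1000, 8.0
angles_scaled = position / rope_timescale(dims, linear_scaling_factor=factor)
angles_interp = (position / factor) / rope_timescale(dims, linear_scaling_factor=1.0)
assert np.allclose(angles_scaled, angles_interp)  # timescale scaling == position interpolation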

MaxText/layers/gemma3.py

Lines changed: 23 additions & 4 deletions
@@ -297,18 +297,22 @@ class MlpBlockViT(nn.Module):
   dtype_mm: str
   mlp_dim: int | None = None # Defaults to 4x input dim
   dropout: float = 0.0
+  precision: str = "default"

   @nn.compact
   def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
     """Applies Transformer MlpBlock module."""
     inits = {"kernel_init": nn.initializers.xavier_uniform(), "bias_init": nn.initializers.normal(stddev=1e-6)}

     d = x.shape[-1]
-    x = nn.Dense(features=self.mlp_dim or 4 * d, dtype=self.dtype_mm, **inits)(x)
+    x = nn.Dense(features=self.mlp_dim or 4 * d, precision=jax.lax.Precision(self.precision), dtype=self.dtype_mm, **inits)(
+        x
+    )
     x = nn.gelu(x)
     x = nn.Dropout(rate=self.dropout)(x, deterministic)
     x = nn.Dense(
         features=d,
+        precision=jax.lax.Precision(self.precision),
         dtype=self.dtype_mm,
         **inits,
     )(x)
@@ -323,6 +327,7 @@ class Encoder1DBlock(nn.Module):
   mlp_dim: int | None = None # Defaults to 4x input dim
   num_heads: int = 12
   dropout: float = 0.0
+  precision: str = "default"

   @nn.compact
   def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
@@ -331,6 +336,7 @@ def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
     y = nn.MultiHeadDotProductAttention(
         num_heads=self.num_heads,
         kernel_init=nn.initializers.xavier_uniform(),
+        precision=jax.lax.Precision(self.precision),
         deterministic=deterministic,
         dtype=self.dtype_mm,
     )(y, y)
@@ -343,6 +349,7 @@ def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
         mlp_dim=self.mlp_dim,
         dropout=self.dropout,
         dtype_mm=self.dtype_mm,
+        precision=self.precision,
     )(y, deterministic)
     y = nn.Dropout(rate=self.dropout)(y, deterministic)
     x = x + y
@@ -358,7 +365,8 @@ class Encoder(nn.Module):
   mlp_dim: int | None = None # Defaults to 4x input dim
   num_heads: int = 12
   dropout: float = 0.0
   scan: bool = False
+  precision: str = "default"

   @nn.compact
   def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
@@ -383,6 +391,7 @@ def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
           mlp_dim=self.mlp_dim,
           num_heads=self.num_heads,
           dropout=self.dropout,
+          precision=self.precision,
       )(
           x, deterministic
       )
@@ -396,6 +405,7 @@ def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
           mlp_dim=self.mlp_dim,
           num_heads=self.num_heads,
           dropout=self.dropout,
+          precision=self.precision,
       )
       x = block_cur(x, deterministic)
     x: jax.Array = nn.LayerNorm(name="encoder_norm")(x)
@@ -430,7 +440,7 @@ class VisionEmbedder(nn.Module):

   def setup(self):
     if self.vision_proj_dim:
-      self.mm_soft_embedding_norm = rms_norm(self.vision_proj_dim)
+      self.mm_soft_embedding_norm = rms_norm(self.vision_proj_dim, dtype=self.config.dtype_mm)
       self.mm_input_projection = Einsum((self.vision_proj_dim, self.config.emb_dim))

   def encode_vision(self, x: jax.Array) -> jax.Array:
@@ -524,7 +534,15 @@ def __call__(self, inputs, deterministic, train=False):
     b, n, h, w, c = inputs.shape
     x = jnp.reshape(inputs, [b * n, h, w, c])
     # Gemma3 uses conv2d with stride 14 and kernel size 14 to extract patches.
-    x = nn.Conv(features=1152, kernel_size=(14, 14), strides=14, padding="VALID", name="embedding")(x)
+    x = nn.Conv(
+        features=1152,
+        kernel_size=(14, 14),
+        strides=14,
+        padding="VALID",
+        name="embedding",
+        dtype=cfg.dtype_mm,
+        precision=jax.lax.Precision(cfg.matmul_precision),
+    )(x)
     bn, h, w, c = x.shape
     x = jnp.reshape(x, [bn, h * w, c])

@@ -549,6 +567,7 @@ def __call__(self, inputs, deterministic, train=False):
         remat_policy=cfg.remat_policy_for_vit,
         dtype_mm=cfg.dtype_mm,
         name="Transformer",
+        precision=cfg.matmul_precision,
     )(x, deterministic=deterministic)

     # Gemma3 uses a vision exit layer to downsample the soft tokens to a required output length.
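The vision-tower edits above mostly thread cfg.matmul_precision (and dtype_mm) into nn.Dense, nn.Conv, and nn.MultiHeadDotProductAttention in the SigLIP encoder; in Flax that means passing a jax.lax.Precision value per layer, presumably so the vision path can be run at a higher matmul precision when its logits are checked against HF. A minimal self-contained sketch of the same pattern (module name and sizes are made up):

import jax
import jax.numpy as jnp
from flax import linen as nn

class TinyBlock(nn.Module):
  precision: str = "default"  # e.g. "default", "high", or "highest"

  @nn.compact
  def __call__(self, x):
    # Precision is resolved per layer, the same way the ViT blocks above do it.
    return nn.Dense(features=16, precision=jax.lax.Precision(self.precision))(x)

x = jnp.ones((2, 8), dtype=jnp.bfloat16)
module = TinyBlock(precision="highest")
params = module.init(jax.random.PRNGKey(0), x)
y = module.apply(params, x)
print(y.shape)  # (2, 16)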

MaxText/multimodal_utils.py

Lines changed: 2 additions & 2 deletions
@@ -35,9 +35,9 @@
 GEMMA_IMAGE_STD = (127.5,) * 3
 GEMMA_IMAGE_PLACEHOLDER_IN_PROMPT = "<start_of_image>"
 GEMMA_BEGIN_IMAGE_TOKEN = 255999
-GEMMA_END_IMAGE_TOKEN = 262144
+GEMMA_END_IMAGE_TOKEN = 256000
 GEMMA_NEW_LINE_TOKEN = 108
-GEMMA_TOKEN_PLACEHOLDER = -2
+GEMMA_TOKEN_PLACEHOLDER = 262144
 # The number of GEMMA_TOKEN_PLACEHOLDER tokens per image in Gemma3
 GEMMA_NUM_PLACEHOLDER_TOKENS_PER_IMAGE = 256
 # +4 means 4 extra tokens to pad around image: \n\n, <start_of_image>, <end_of_image>, \n\n
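With the corrected ids, each image in a Gemma 3 prompt occupies GEMMA_NUM_PLACEHOLDER_TOKENS_PER_IMAGE soft-token placeholders plus the four surrounding tokens named in the comment above, i.e. 260 positions in total. A small illustrative sketch (the helper is hypothetical; only the constants come from this file):

GEMMA_BEGIN_IMAGE_TOKEN = 255999
GEMMA_END_IMAGE_TOKEN = 256000
GEMMA_NEW_LINE_TOKEN = 108
GEMMA_TOKEN_PLACEHOLDER = 262144
GEMMA_NUM_PLACEHOLDER_TOKENS_PER_IMAGE = 256

def image_token_block():
  """Hypothetical layout of the token span one image occupies in the prompt."""
  return (
      [GEMMA_NEW_LINE_TOKEN, GEMMA_BEGIN_IMAGE_TOKEN]
      + [GEMMA_TOKEN_PLACEHOLDER] * GEMMA_NUM_PLACEHOLDER_TOKENS_PER_IMAGE
      + [GEMMA_END_IMAGE_TOKEN, GEMMA_NEW_LINE_TOKEN]
  )

assert len(image_token_block()) == GEMMA_NUM_PLACEHOLDER_TOKENS_PER_IMAGE + 4  # 260 tokens per image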

MaxText/scratch_code/generate_hf_golden_logits.py

Lines changed: 95 additions & 15 deletions
@@ -15,16 +15,34 @@
 Usage:

 python3 -m MaxText.scratch_code.generate_hf_golden_logits --model-id=deepseek-ai/DeepSeek-V2-Lite \
-    --output-path=golden_DeepSeek-V2-Lite.jsonl --prompts='I love to,Today is a,What is the' \
+    --output-path=golden_DeepSeek-V2-Lite.jsonl --prompts='I love to;Today is a;What is the' \
     --gcs-bucket=my-gcs-bucket

+For large models, you can use an m1 CPU machine. Calling the script directly instead of the MaxText module \
+skips importing unnecessary dependencies.
+For large Hugging Face checkpoints, you can use pre-downloaded checkpoints with the --hf-model-path argument.
+For multimodal models, use the --image-paths argument to provide image path(s), \
+and use --apply-chat-template=true to format image+prompt with the HF chat template. \
+When using the chat template, the prompt should not contain image placeholders.
+
+More examples:
+python3 MaxText/scratch_code/generate_hf_golden_logits.py --model-id=meta-llama/Llama-4-Scout-17B-16E \
+    --output-path=golden_Llama-4-Scout-17B-16E_vision.jsonl --prompts='Describe this image.' \
+    --apply-chat-template=true --gcs-bucket=<bucket> --hf-model-path=<hf_checkpoint_path> \
+    --image-paths=MaxText/test_assets/test_image.jpg
+
+python3 MaxText/scratch_code/generate_hf_golden_logits.py --model-id=google/gemma-3-4b-it \
+    --output-path=golden_gemma-3-4b-it_vision.jsonl --prompts='<start_of_image>' \
+    --apply-chat-template=false --gcs-bucket=<bucket> --hf-model-path=<hf_checkpoint_path> \
+    --image-paths=MaxText/test_assets/test_image.jpg
 """

 import torch
 import argparse
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoProcessor, AutoModelForCausalLM
 import jsonlines
 from google.cloud import storage
+from PIL import Image

 # Load the tokenizer and model from Hugging Face

@@ -37,32 +55,74 @@ def upload_blob(bucket_name, source_file_name, destination_blob_name):
   blob.upload_from_filename(source_file_name)


-def save_golden_logits(model_id, output_path, prompt_texts, gcs_bucket):
+def save_golden_logits(model_id, output_path, prompt_texts, apply_chat_template, gcs_bucket, hf_model_path, image_paths):
   """save golden logits"""
-  tokenizer = AutoTokenizer.from_pretrained(model_id)
+  if hf_model_path is None:
+    hf_model_path = model_id
+  tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
   model = AutoModelForCausalLM.from_pretrained(
-      model_id,
+      hf_model_path,
       torch_dtype=torch.float32,
       trust_remote_code=True,
   )

   all_data_to_save = []
-  for prompt_text in prompt_texts:
+  for i, prompt_text in enumerate(prompt_texts):
     # Encode the prompt text
-    input_ids = tokenizer.encode(prompt_text, return_tensors="pt")
+    if image_paths:
+      try:
+        image = Image.open(image_paths[i])
+      except Exception as e:
+        raise e
+      image = image.convert("RGB")
+      # TODO (aireenmei): remove this when Llama-4 supports dynamic image shapes.
+      if model_id.startswith("meta-llama/Llama-4"):
+        image = image.resize((336, 336))
+      processor = AutoProcessor.from_pretrained(model_id, token=True)
+      if apply_chat_template:
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": prompt_text},
+                ],
+            },
+        ]
+        formatted_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        inputs = processor(text=formatted_prompt, images=image, return_tensors="pt")
+      else:
+        formatted_prompt = prompt_text
+        inputs = processor(text=formatted_prompt, images=image, return_tensors="pt", add_special_tokens=False)
+      with torch.no_grad():
+        outputs = model(**inputs)
+        logits = outputs.logits.cpu().numpy().astype("float32")

-    # Get the logits for the prompt + completion
-    with torch.no_grad():
-      outputs = model(input_ids)
-      logits = outputs.logits.cpu().numpy().astype("float32")
+      data_to_save = {
+          "prompt": prompt_text,
+          "formatted_prompt": formatted_prompt,
+          "tokens": inputs["input_ids"].tolist()[0],
+          "attention_mask": inputs["attention_mask"].tolist()[0],
+          "image_path": image_paths[i],
+          "pixel_values": inputs["pixel_values"].tolist()[0],
+          "logits": logits.tolist()[0],
+      }
+    else:
+      input_ids = tokenizer.encode(prompt_text, return_tensors="pt")
+      # Get the logits for the prompt + completion
+      with torch.no_grad():
+        outputs = model(input_ids)
+        logits = outputs.logits.cpu().numpy().astype("float32")

       # Prepare data to be saved
       data_to_save = {
           "prompt": prompt_text,
           "tokens": input_ids.tolist()[0],
           "logits": logits.tolist()[0], # Convert numpy array to list for JSON serialization
       }
-    all_data_to_save.append(data_to_save)
+    print(f"Token length is {len(data_to_save['tokens'])} for prompt: {prompt_text}")
+    print(f"raw ids: {data_to_save['tokens']}")
+    all_data_to_save.append(data_to_save)

   with jsonlines.open(output_path, "w") as f:
     f.write_all(all_data_to_save)
@@ -77,13 +137,33 @@ def main(raw_args=None) -> None:
   parser = argparse.ArgumentParser()
   parser.add_argument("--model-id", type=str, required=True, help="The identifier of the model to use.")
   parser.add_argument("--output-path", type=str, required=True, help="The path to save the generated golden logits.")
-  parser.add_argument("--prompts", type=str, required=True, help="A comma-separated list of prompts.")
+  parser.add_argument("--prompts", type=str, required=True, help="A semicolon-separated list of prompts.")
+  parser.add_argument(
+      "--apply-chat-template",
+      type=bool,
+      required=False,
+      default=False,
+      help="Whether to apply chat template from the HF processor. Used for image+text input.",
+  )
   parser.add_argument(
       "--gcs-bucket", type=str, required=False, default=None, help="A GCS bucket to store logits, without gs://."
   )
+  parser.add_argument("--hf-model-path", type=str, required=False, default=None, help="local path to checkpoint if exists.")
+  parser.add_argument(
+      "--image-paths", type=str, required=False, default=None, help="A semicolon-separated list of image_paths."
+  )
   args = parser.parse_args(raw_args)
-  prompts = args.prompts.split(",")
-  save_golden_logits(args.model_id, args.output_path, prompts, args.gcs_bucket)
+  prompts = args.prompts.split(";")
+  image_paths = args.image_paths.split(";") if args.image_paths else []
+  if image_paths:
+    assert len(image_paths) == len(
+        prompts
+    ), "when image paths are provided, image_paths and prompts must have the same length."
+  if args.apply_chat_template:
+    assert image_paths, "apply_chat_template is only used for image+text input, so image_paths must be provided."
+  save_golden_logits(
+      args.model_id, args.output_path, prompts, args.apply_chat_template, args.gcs_bucket, args.hf_model_path, image_paths
+  )


 if __name__ == "__main__":
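For multimodal prompts, each record written by this script now carries the formatted prompt, attention mask, and pixel values alongside the logits, which is what lets a logit checker on the MaxText side rebuild the exact HF input. A hedged sketch of reading one record back (the field names are the ones written above; the commented-out comparison and tolerances are illustrative, not the project's actual test):

import numpy as np
import jsonlines

with jsonlines.open("golden_gemma-3-4b-it_vision.jsonl") as reader:
  record = next(iter(reader))

tokens = np.array(record["tokens"])              # prompt token ids, image placeholders included
golden_logits = np.array(record["logits"])       # [seq_len, vocab] float32 logits from HF
pixel_values = np.array(record["pixel_values"])  # preprocessed image, present for multimodal runs

# Illustrative comparison against logits from another implementation (e.g. MaxText):
# my_logits = run_my_model(tokens, pixel_values)
# np.testing.assert_allclose(my_logits, golden_logits, atol=0.1, rtol=0.05)
print(tokens.shape, golden_logits.shape, pixel_values.shape)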
