Support multimodal in logit checker + match gemma3 logits with HF #2203

Open · wants to merge 1 commit into main
1 change: 1 addition & 0 deletions MaxText/configs/base.yml
@@ -593,6 +593,7 @@ use_untrainable_positional_embedding: False
trainable_position_size: -1 # enable gpt3 position embedding with a positive trainable_position_size
# RoPE parameters
rope_type: "default" # one of "default", "llama3.1" or "yarn"
rope_linear_scaling_factor: 1.0 # linear scaling factor for "default" RoPE (see class `RotaryEmbedding` for more)
rope_use_scale: True # apply rope scaling for llama3.1 (see class `LLaMARotaryEmbedding` for more)
rope_min_timescale: 1
rope_max_timescale: 10_000 # Timescale For global Attention
3 changes: 2 additions & 1 deletion MaxText/configs/models/gemma3-12b.yml
@@ -21,7 +21,7 @@ base_num_kv_heads: 8
base_mlp_dim: 15360
head_dim: 256
mlp_activations: ["gelu","linear"]
vocab_size: 262_144
vocab_size: 262_208
Collaborator:
Why do we need this change here and in the other model configs?

Doesn't this change impact checkpoint conversion?

The embedding lookup and unembed layers depend on the vocab size.
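
One practical consequence of the vocab resize, sketched below: a checkpoint that stores a 262_144-row embedding table would need to be padded (or re-exported) to 262_208 rows, and since logits_via_embedding is True the unembed reuses the same table. This is an illustrative sketch only; the parameter path and the zero-padding choice are assumptions, not the repo's conversion code.

import numpy as np

OLD_VOCAB, NEW_VOCAB = 262_144, 262_208  # vocab sizes before and after this change

def pad_embedding_table(embedding: np.ndarray, new_vocab: int = NEW_VOCAB) -> np.ndarray:
  """Pads the vocab axis of an embedding table with zero rows.

  With logits_via_embedding=True the same table is reused for the unembed,
  so the logits also gain 64 extra entries per position.
  """
  old_vocab, emb_dim = embedding.shape
  if old_vocab >= new_vocab:
    return embedding
  pad = np.zeros((new_vocab - old_vocab, emb_dim), dtype=embedding.dtype)
  return np.concatenate([embedding, pad], axis=0)

# Hypothetical usage inside a conversion script:
# params["token_embedder"]["embedding"] = pad_embedding_table(params["token_embedder"]["embedding"])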

decoder_block: "gemma3"
normalization_layer_epsilon: 1e-6
logits_via_embedding: True
@@ -30,3 +30,4 @@ use_post_attn_norm: true
use_post_ffw_norm: true
local_rope_max_timescale: 10_000
rope_max_timescale: 1_000_000
rope_linear_scaling_factor: 8.0
3 changes: 2 additions & 1 deletion MaxText/configs/models/gemma3-27b.yml
@@ -21,7 +21,7 @@ base_num_kv_heads: 16
base_mlp_dim: 21504
head_dim: 128
mlp_activations: ["gelu","linear"]
vocab_size: 262_144
vocab_size: 262_208
decoder_block: "gemma3"
normalization_layer_epsilon: 1e-6
logits_via_embedding: True
@@ -30,3 +30,4 @@ use_post_attn_norm: true
use_post_ffw_norm: true
local_rope_max_timescale: 10_000
rope_max_timescale: 1_000_000
rope_linear_scaling_factor: 8.0
3 changes: 2 additions & 1 deletion MaxText/configs/models/gemma3-4b.yml
@@ -21,7 +21,7 @@ base_num_kv_heads: 4
base_mlp_dim: 10240
head_dim: 256
mlp_activations: ["gelu","linear"]
vocab_size: 262_144
vocab_size: 262_208
decoder_block: "gemma3"
normalization_layer_epsilon: 1e-6
logits_via_embedding: True
@@ -30,3 +30,4 @@ use_post_attn_norm: true
use_post_ffw_norm: true
local_rope_max_timescale: 10_000
rope_max_timescale: 1_000_000
rope_linear_scaling_factor: 8.0
7 changes: 7 additions & 0 deletions MaxText/layers/attentions.py
@@ -694,11 +694,18 @@ def init_rotary_embedding(self):
# For local attention, use local_rope_max_timescale if it is positive
if self.attention_type == AttentionType.LOCAL_SLIDING and self.config.local_rope_max_timescale > 0:
max_timescale = self.config.local_rope_max_timescale

rope_linear_scaling_factor = self.config.rope_linear_scaling_factor
# In gemma3, linear scaling factor does not apply to local sliding layers.
if self.config.model_name.startswith("gemma3") and self.attention_type == AttentionType.LOCAL_SLIDING:
rope_linear_scaling_factor = 1.0

rotary_embedding = RotaryEmbedding(
min_timescale=self.config.rope_min_timescale,
max_timescale=max_timescale,
embedding_dims=rope_embedding_dims,
fprop_dtype=self.dtype,
rope_linear_scaling_factor=rope_linear_scaling_factor,
rngs=self.rngs,
)
return rotary_embedding
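
For illustration, the selection above can be restated as a standalone rule (plain strings stand in for the AttentionType enum; this is a sketch, not part of the module):

def effective_rope_scaling(model_name: str, attention_type: str, configured_factor: float) -> float:
  """Gemma 3 local sliding-window layers keep an unscaled RoPE; everything else uses the config value."""
  if model_name.startswith("gemma3") and attention_type == "local_sliding":
    return 1.0
  return configured_factor

assert effective_rope_scaling("gemma3-12b", "local_sliding", 8.0) == 1.0
assert effective_rope_scaling("gemma3-12b", "global", 8.0) == 8.0
assert effective_rope_scaling("llama3.1-8b", "global", 1.0) == 1.0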
15 changes: 8 additions & 7 deletions MaxText/layers/embeddings.py
@@ -242,6 +242,7 @@ def __init__(
fprop_dtype: DType = jnp.bfloat16,
# Not used in RotaryEmbedding but passed in by nnx.bridge.to_linen.
# TODO: Remove when bridge no longer needed
rope_linear_scaling_factor: float = 1.0,
rngs: nnx.Rngs = None,
):
"""Initializes the RotaryEmbedding module.
@@ -261,6 +262,7 @@ def __init__(
self.embedding_dims = embedding_dims
self.cast_as_fprop_dtype = cast_as_fprop_dtype
self.fprop_dtype = fprop_dtype
self.rope_linear_scaling_factor = rope_linear_scaling_factor

if self.embedding_dims % 2:
raise ValueError("Embedding dim for rotary position embedding must be a multiple of 2.")
@@ -270,7 +272,10 @@ def timescale(self):
"""Returns the timescale for the rotary embedding."""
half_embedding_dim = self.embedding_dims // 2
fraction = 2 * jnp.arange(0, half_embedding_dim) / self.embedding_dims
return self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
timescale = self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction
if self.rope_linear_scaling_factor != 1.0:
timescale = timescale * self.rope_linear_scaling_factor
return timescale
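
Since the rotary angle is position / timescale, multiplying every timescale by rope_linear_scaling_factor is equivalent to evaluating the unscaled embedding at position / factor, i.e. linear position interpolation. A small numerical check of that equivalence (illustrative values only; the factor 8.0 matches the Gemma 3 configs above):

import jax.numpy as jnp

min_ts, max_ts, dims, factor = 1.0, 1_000_000.0, 8, 8.0
fraction = 2 * jnp.arange(dims // 2) / dims
timescale = min_ts * (max_ts / min_ts) ** fraction

position = 1024.0
scaled_angle = position / (timescale * factor)  # what the scaled timescale produces
interp_angle = (position / factor) / timescale  # the position-interpolation view
assert jnp.allclose(scaled_angle, interp_angle)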

def __call__(
self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks
@@ -448,9 +453,7 @@ def __call__(self, inputs: jax.Array, position: None | jax.Array = None) -> jax.
if len(inputs.shape) != 4:
raise ValueError("Input is assumed to be a rank 4 tensor of shape [B, S, N, H].")
if self.embedding_dims != inputs.shape[3]:
raise ValueError(
"The embedding dims of the rotary position embedding must match the hidden dimension of the inputs."
)
raise ValueError("The embedding dims of the rotary position embedding must match the hidden dimension of the inputs.")

# Shift the inputs left and right as per LLaMA's specific behavior
inputs_shifted_left = jnp.concatenate([inputs[..., 1:], inputs[..., :1]], axis=-1)
@@ -649,9 +652,7 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array:
if len(inputs.shape) != 4:
raise ValueError("Input is assumed to be a rank 4 tensor of shape [batch, sequence, heads, dims].")
if self.embedding_dims != inputs.shape[3]:
raise ValueError(
"The embedding dims of the rotary position embedding must match the hidden dimension of the inputs."
)
raise ValueError("The embedding dims of the rotary position embedding must match the hidden dimension of the inputs.")

# Determine positions if not provided
if position is None:
4 changes: 2 additions & 2 deletions MaxText/multimodal_utils.py
@@ -35,9 +35,9 @@
GEMMA_IMAGE_STD = (127.5,) * 3
GEMMA_IMAGE_PLACEHOLDER_IN_PROMPT = "<start_of_image>"
GEMMA_BEGIN_IMAGE_TOKEN = 255999
GEMMA_END_IMAGE_TOKEN = 262144
GEMMA_END_IMAGE_TOKEN = 256000
GEMMA_NEW_LINE_TOKEN = 108
GEMMA_TOKEN_PLACEHOLDER = -2
GEMMA_TOKEN_PLACEHOLDER = 262144
# The number of GEMMA_TOKEN_PLACEHOLDER tokens per image in Gemma3
GEMMA_NUM_PLACEHOLDER_TOKENS_PER_IMAGE = 256
# +4 means 4 extra tokens to pad around image: \n\n, <start_of_image>, <end_of_image>, \n\n
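
With these constants, one image expands into the 256 placeholder ids framed by the four padding tokens listed in the comment above. A rough sketch of that block (the exact ordering is an assumption from the comment; the real insertion logic lives elsewhere in multimodal_utils):

GEMMA_BEGIN_IMAGE_TOKEN = 255999
GEMMA_END_IMAGE_TOKEN = 256000
GEMMA_NEW_LINE_TOKEN = 108
GEMMA_TOKEN_PLACEHOLDER = 262144
GEMMA_NUM_PLACEHOLDER_TOKENS_PER_IMAGE = 256

def image_token_block() -> list[int]:
  # \n\n, <start_of_image>, 256 placeholders, <end_of_image>, \n\n
  return (
      [GEMMA_NEW_LINE_TOKEN, GEMMA_BEGIN_IMAGE_TOKEN]
      + [GEMMA_TOKEN_PLACEHOLDER] * GEMMA_NUM_PLACEHOLDER_TOKENS_PER_IMAGE
      + [GEMMA_END_IMAGE_TOKEN, GEMMA_NEW_LINE_TOKEN]
  )

assert len(image_token_block()) == GEMMA_NUM_PLACEHOLDER_TOKENS_PER_IMAGE + 4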
110 changes: 95 additions & 15 deletions MaxText/scratch_code/generate_hf_golden_logits.py
@@ -15,16 +15,34 @@
Usage:

python3 -m MaxText.scratch_code.generate_hf_golden_logits --model-id=deepseek-ai/DeepSeek-V2-Lite \
--output-path=golden_DeepSeek-V2-Lite.jsonl --prompts='I love to,Today is a,What is the' \
--output-path=golden_DeepSeek-V2-Lite.jsonl --prompts='I love to;Today is a;What is the' \
--gcs-bucket=my-gcs-bucket

For large models, you can use an m1 CPU. Calling the script directly instead of the MaxText module \
avoids importing unnecessary dependencies.
For large Hugging Face checkpoints, you can use pre-downloaded checkpoints with the --hf-model-path argument.
For multimodal models, use the --image-paths argument to provide image path(s), \
and use --apply-chat-template=true to format image+prompt with the HF chat template. \
When using the chat template, the prompt should not contain image placeholders.

More examples:
python3 MaxText/scratch_code/generate_hf_golden_logits.py --model-id=meta-llama/Llama-4-Scout-17B-16E \
--output-path=golden_Llama-4-Scout-17B-16E_vision.jsonl --prompts='Describe this image.' \
--apply-chat-template=true --gcs-bucket=<bucket> --hf-model-path=<hf_checkpoint_path> \
--image-paths=MaxText/test_assets/test_image.jpg

python3 MaxText/scratch_code/generate_hf_golden_logits.py --model-id=google/gemma-3-4b-it \
--output-path=golden_gemma-3-4b-it_vision.jsonl --prompts='<start_of_image>' \
--apply-chat-template=false --gcs-bucket=<bucket> --hf-model-path=<hf_checkpoint_path> \
--image-paths=MaxText/test_assets/test_image.jpg
"""

import torch
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoTokenizer, AutoProcessor, AutoModelForCausalLM
import jsonlines
from google.cloud import storage
from PIL import Image

# Load the tokenizer and model from Hugging Face

@@ -37,32 +55,74 @@ def upload_blob(bucket_name, source_file_name, destination_blob_name):
blob.upload_from_filename(source_file_name)


def save_golden_logits(model_id, output_path, prompt_texts, gcs_bucket):
def save_golden_logits(model_id, output_path, prompt_texts, apply_chat_template, gcs_bucket, hf_model_path, image_paths):
"""save golden logits"""
tokenizer = AutoTokenizer.from_pretrained(model_id)
if hf_model_path is None:
hf_model_path = model_id
tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
model = AutoModelForCausalLM.from_pretrained(
model_id,
hf_model_path,
torch_dtype=torch.float32,
trust_remote_code=True,
)

all_data_to_save = []
for prompt_text in prompt_texts:
for i, prompt_text in enumerate(prompt_texts):
# Encode the prompt text
input_ids = tokenizer.encode(prompt_text, return_tensors="pt")
if image_paths:
try:
image = Image.open(image_paths[i])
except Exception as e:
raise e
image = image.convert("RGB")
# TODO (aireenmei): remove this when Llama-4 supports dynamic image shapes.
if model_id.startswith("meta-llama/Llama-4"):
image = image.resize((336, 336))
processor = AutoProcessor.from_pretrained(model_id, token=True)
if apply_chat_template:
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": prompt_text},
],
},
]
formatted_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=formatted_prompt, images=image, return_tensors="pt")
else:
formatted_prompt = prompt_text
inputs = processor(text=formatted_prompt, images=image, return_tensors="pt", add_special_tokens=False)
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits.cpu().numpy().astype("float32")

# Get the logits for the prompt + completion
with torch.no_grad():
outputs = model(input_ids)
logits = outputs.logits.cpu().numpy().astype("float32")
data_to_save = {
"prompt": prompt_text,
"formatted_prompt": formatted_prompt,
"tokens": inputs["input_ids"].tolist()[0],
"attention_mask": inputs["attention_mask"].tolist()[0],
"image_path": image_paths[i],
"pixel_values": inputs["pixel_values"].tolist()[0],
"logits": logits.tolist()[0],
}
else:
input_ids = tokenizer.encode(prompt_text, return_tensors="pt")
# Get the logits for the prompt + completion
with torch.no_grad():
outputs = model(input_ids)
logits = outputs.logits.cpu().numpy().astype("float32")

# Prepare data to be saved
data_to_save = {
"prompt": prompt_text,
"tokens": input_ids.tolist()[0],
"logits": logits.tolist()[0], # Convert numpy array to list for JSON serialization
}
all_data_to_save.append(data_to_save)
print(f"Token length is {len(data_to_save['tokens'])} for prompt: {prompt_text}")
print(f"raw ids: {data_to_save['tokens']}")
all_data_to_save.append(data_to_save)

with jsonlines.open(output_path, "w") as f:
f.write_all(all_data_to_save)
@@ -77,13 +137,33 @@ def main(raw_args=None) -> None:
parser = argparse.ArgumentParser()
parser.add_argument("--model-id", type=str, required=True, help="The identifier of the model to use.")
parser.add_argument("--output-path", type=str, required=True, help="The path to save the generated golden logits.")
parser.add_argument("--prompts", type=str, required=True, help="A comma-separated list of prompts.")
parser.add_argument("--prompts", type=str, required=True, help="A semicolon-separated list of prompts.")
parser.add_argument(
"--apply-chat-template",
type=bool,
required=False,
default=False,
help="Whether to apply chat template from the HF processor. Used for image+text input.",
)
parser.add_argument(
"--gcs-bucket", type=str, required=False, default=None, help="A GCS bucket to store logits, without gs://."
)
parser.add_argument("--hf-model-path", type=str, required=False, default=None, help="local path to checkpoint if exists.")
parser.add_argument(
"--image-paths", type=str, required=False, default=None, help="A semicolon-separated list of image_paths."
)
args = parser.parse_args(raw_args)
prompts = args.prompts.split(",")
save_golden_logits(args.model_id, args.output_path, prompts, args.gcs_bucket)
prompts = args.prompts.split(";")
image_paths = args.image_paths.split(";") if args.image_paths else []
if image_paths:
assert len(image_paths) == len(
prompts
), "when image paths are provided, image_paths and prompts must have the same length."
if args.apply_chat_template:
assert image_paths, "apply_chat_template is only used for image+text input, so image_paths must be provided."
save_golden_logits(
args.model_id, args.output_path, prompts, args.apply_chat_template, args.gcs_bucket, args.hf_model_path, image_paths
)


if __name__ == "__main__":
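
The jsonl written by this script is what the logit checker consumes on the MaxText side. A minimal sketch of reading a record back and measuring the gap against logits from another implementation (record keys match the fields written above; the helper names and tolerance handling are hypothetical):

import jsonlines
import numpy as np

def load_golden(path: str) -> list[dict]:
  with jsonlines.open(path) as reader:
    return list(reader)

def max_logit_diff(record: dict, candidate_logits: np.ndarray) -> float:
  golden = np.asarray(record["logits"], dtype=np.float32)
  candidate = np.asarray(candidate_logits, dtype=np.float32)
  assert golden.shape == candidate.shape, (golden.shape, candidate.shape)
  return float(np.max(np.abs(golden - candidate)))

# Hypothetical usage:
# records = load_golden("golden_gemma-3-4b-it_vision.jsonl")
# print(max_logit_diff(records[0], my_maxtext_logits))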