
Support multimodal in logit checker + match gemma3 logits with HF #2203


Open
wants to merge 1 commit into base: main

Conversation


@aireenmei (Collaborator) commented Aug 19, 2025

Description

I found some mismatches in the gemma3 logits compared with HF (b/437988753).
Key changes to close the gap:

  • Match the vocab size: HF uses a larger vocab size that includes the special image tokens.
  • Match the special image token ID with HF.
  • Use linear RoPE scaling with a factor of 8 for the global attention layers in the language model (see the sketch after this list).
  • Allow setting the gemma3 vision encoder precision through config.matmul_precision (this will be a separate PR due to potential conflicts with the ongoing NNX migration).
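For reference, here is a minimal NumPy sketch of what linear RoPE scaling with a factor of 8 means in practice. The function name, the max_timescale default, and the table layout are illustrative rather than MaxText's actual RoPE code; the point is only that positions are divided by the scaling factor before the rotary tables are built.

import numpy as np

def rope_tables(positions, head_dim, scaling_factor=8.0, max_timescale=10_000):
    # Linear RoPE scaling: divide positions by the factor before building the
    # sin/cos tables. Per this PR it applies only to the global attention
    # layers of the gemma3 language model.
    scaled = np.asarray(positions, dtype=np.float32) / scaling_factor
    # Standard RoPE inverse frequencies, one per pair of head dimensions.
    inv_freq = 1.0 / (max_timescale ** (np.arange(0, head_dim, 2) / head_dim))
    angles = scaled[:, None] * inv_freq[None, :]  # shape: [seq_len, head_dim // 2]
    return np.sin(angles), np.cos(angles)

sin_table, cos_table = rope_tables(np.arange(16), head_dim=256)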

Tests

  • Add a unit test in check_gemma3_layers.py to check the RoPE implementation.
  • Logits test (a sketch of the comparison is shown after this list):
    Use generate_hf_golden_logits.py to generate HF golden logits on an M1 CPU, dtype=float32
    Run forward_pass_logit_checker on a v5p TPU, with all dtypes and precision set to float32
    Results: see b/437988753
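Roughly, the logits check boils down to the comparison below. The file names and tolerance are illustrative only; the real flow goes through generate_hf_golden_logits.py and forward_pass_logit_checker rather than standalone .npy files.

import numpy as np

# Hypothetical file names, for illustration only.
hf_logits = np.load("hf_golden_logits.npy")      # HF forward pass, M1 CPU, float32
maxtext_logits = np.load("maxtext_logits.npy")   # MaxText forward pass, v5p TPU, float32

max_abs_diff = np.max(np.abs(hf_logits - maxtext_logits))
print(f"max absolute logit diff: {max_abs_diff:.6f}")
assert np.allclose(hf_logits, maxtext_logits, atol=1e-2), "gemma3 logits do not match HF"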

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@aireenmei aireenmei changed the title Support multimodal in logit checker + match gemma3 logit to HF Support multimodal in logit checker + match gemma3 logits wit HF Aug 20, 2025
@aireenmei aireenmei changed the title Support multimodal in logit checker + match gemma3 logits wit HF Support multimodal in logit checker + match gemma3 logits with HF Aug 20, 2025
@aireenmei aireenmei marked this pull request as ready for review August 20, 2025 06:30
Collaborator commented:

Thank you Aireen for digging deep and finding this precision issue!

Is there any chance you could split out the "precision" changes in gemma3.py into a separate PR? Asking because we are doing the NNX migration, and all of these layers will be rewritten from nn to nnx (example). I guess we can plug in these "precision" changes after the migration? Either way works for me.

@aireenmei force-pushed the aireen/logits branch 2 times, most recently from 0a94059 to 5ea4e1a on August 20, 2025 20:57
@@ -21,7 +21,7 @@ base_num_kv_heads: 8
 base_mlp_dim: 15360
 head_dim: 256
 mlp_activations: ["gelu","linear"]
-vocab_size: 262_144
+vocab_size: 262_208
Collaborator commented:

Why do we need this change here and in other model configs?

Doesn't this change impact checkpoint conversion?

The embedding lookup and unembed layers depend on the vocab size.
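To make that concrete, here is an illustrative sketch (shapes are made up and this is not the actual checkpoint conversion code) of why the conversion has to account for the new vocab size: the embedding table, and any tied unembedding, gains 64 rows for the special image tokens.

import numpy as np

old_vocab, new_vocab = 262_144, 262_208
embed_dim = 8  # tiny stand-in; the real model dimension is much larger
old_embedding = np.zeros((old_vocab, embed_dim), dtype=np.float32)  # stand-in for converted weights

resized = np.zeros((new_vocab, embed_dim), dtype=np.float32)
resized[:old_vocab] = old_embedding  # the 64 new image-token rows stay zero-initialized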
