Gemma 4 batch mode: staggered-arrival padding causes SDPA precision divergence #93

@scouzi1966

Description

Summary

When Gemma 4 requests arrive at different times in batch mode, later arrivals are left-padded in the batch cache. The padding forces SDPA onto a different Metal kernel path (.array(padMask)) than serial mode's .none; the resulting float precision differences accumulate across 42 layers and degrade tool-calling quality.
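To make the padding concrete, here is a small standalone sketch (NumPy, with illustrative names only — not the project's code) of how a staggered arrival ends up left-padded and why the merged batch can no longer take the mask-free SDPA path:

```python
import numpy as np

# Hypothetical illustration: two requests merged into one batch cache.
# Request A has 6 cached tokens; request B arrived later and has 4, so
# B is left-padded with 2 slots to align the caches.
lengths = [6, 4]
max_len = max(lengths)

# True = real token, False = left-padding slot.
pad_mask = np.array(
    [[i >= max_len - n for i in range(max_len)] for n in lengths]
)

# As soon as any slot is padding, an explicit mask must be passed to
# SDPA (.array(padMask)) instead of the mask-free .none path.
needs_array_mask = not pad_mask.all()
print(pad_mask)
print(needs_array_mask)
```

With simultaneous arrivals all rows have equal length, the mask is all-true, and the batch stays on the same kernel path as serial mode — matching the 100% pass rate reported below for zero-padding scenarios.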

Root Cause

MLX scaledDotProductAttention dispatches to different Metal kernel specializations for .none (no mask) vs .array(mask). Even when the mask is functionally correct, the kernel produces slightly different intermediate float values. With Gemma 4's 42 layers (28 KV-shared), these differences compound and can change the argmax, causing the model to enter degenerate thinking loops instead of producing tool calls.
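The compounding effect can be sketched with a toy model (not the real kernels): treat the kernel-path discrepancy as a tiny per-layer perturbation, run it through 42 layers, and note that the accumulated drift is enough to flip argmax whenever the top-2 logits are close. All numbers here are illustrative.

```python
import numpy as np

# Two copies of the same hidden state through 42 toy layers; one path
# gets a ~1e-6 per-layer wobble standing in for the .none vs
# .array(mask) kernel difference.
rng = np.random.default_rng(0)
x_ref = rng.standard_normal(64).astype(np.float32)
x_pad = x_ref.copy()
layers = [rng.standard_normal((64, 64)).astype(np.float32) / 8.0
          for _ in range(42)]

for w in layers:
    x_ref = np.tanh(w @ x_ref)
    noise = rng.uniform(-1e-6, 1e-6, 64).astype(np.float32)
    x_pad = np.tanh(w @ x_pad) * (1.0 + noise)  # kernel-path wobble

drift = float(np.max(np.abs(x_ref - x_pad)))
assert drift > 0.0  # the two paths have diverged

# A few 1e-6 of drift flips argmax when the top-2 logits are close,
# which is the mechanism that turns a tool call into a degenerate loop.
logits = np.array([3.1415926, 3.1415910], dtype=np.float32)
flipped = logits + np.float32(2e-6) * np.array([-1.0, 1.0], dtype=np.float32)
print(int(np.argmax(logits)), int(np.argmax(flipped)))
```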

Current State (after fix at 080da66)

| Scenario | Tool call pass rate | Notes |
| --- | --- | --- |
| B=1 serial | 15/15 (100%) | Baseline |
| B=2–15 simultaneous arrival | 15/15 (100%) | Zero padding; identical to serial |
| B=15 staggered arrival | 9/15 (60%) | First 9 pass; last 6 (with padding) fail |

The fix in 080da66 resolved the major divergence (wrong RoPE in 67% of layers), raising same-length sequences from a 0% to a 100% pass rate.

Potential Approaches

  1. Batch cache compaction — eliminate padding after merge
  2. Request coalescing — hold incoming requests briefly so more arrive together
  3. MLX upstream — unified SDPA kernel for .none vs .array(allTrue)
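A minimal sketch of approach 1, assuming a `[batch, seq, dim]` left-padded cache (the function name and layout are hypothetical, not the project's API): padding that is common to every row can simply be sliced off, returning the batch to the mask-free path; rows with unequal valid lengths would still need per-row handling.

```python
import numpy as np

def compact(cache: np.ndarray, lengths: list[int]):
    """Trim left-padding shared by all rows of a merged batch cache.

    cache: [batch, seq, dim], left-padded; lengths: valid tokens per row.
    Returns the trimmed cache and the remaining per-row padding counts.
    """
    seq = cache.shape[1]
    pads = [seq - n for n in lengths]
    common = min(pads)
    if common == 0:
        return cache, pads  # nothing shared to trim
    return cache[:, common:, :], [p - common for p in pads]

# Both rows carry 2 slots of padding -> fully removable, so SDPA can
# dispatch with .none again instead of .array(padMask).
cache = np.zeros((2, 8, 4), dtype=np.float32)
trimmed, pads = compact(cache, lengths=[6, 6])
print(trimmed.shape, pads)
```

Approach 2 (request coalescing) attacks the same problem from the other side: if requests are held briefly so they enter the batch together, lengths match, no padding is created, and compaction is never needed.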

Impact

  • Gemma 4 only (Qwen3.5 unaffected — no KV-sharing)
  • Tool calling most sensitive; general text works at all batch sizes
  • Same-arrival requests unaffected
