Summary
When Gemma 4 requests arrive at different times in batch mode, later arrivals get left-padding in the batch cache. This padding forces a different SDPA Metal kernel path (.array(padMask)) compared to serial mode's .none, causing accumulated float precision differences across 42 layers that degrade tool calling quality.
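For illustration, a minimal pure-Swift sketch (the request lengths are hypothetical) of how a late arrival ends up left-padded and picks up a key mask that same-length batches never build:

```swift
// Pure-Swift illustration (lengths are made up): a request that arrives late
// has fewer cached tokens, so it is left-padded up to the batch's longest
// entry and needs a key mask that same-length batches never require.
let cachedLengths = [37, 37, 12]      // third request arrived late
let maxLen = cachedLengths.max()!     // 37

// `false` over the left-padded prefix, `true` over real tokens.
// Serial mode never builds this mask; it runs SDPA with no mask at all.
let padMask: [[Bool]] = cachedLengths.map { len in
    (0..<maxLen).map { $0 >= maxLen - len }
}
// Rows 0 and 1 are all `true`; row 2 has 25 leading `false` slots.
```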
Root Cause
MLX scaledDotProductAttention dispatches to different Metal kernel specializations for .none (no mask) vs .array(mask). Even when the mask is functionally correct, the kernel produces slightly different intermediate float values. With Gemma 4's 42 layers (28 KV-shared), these differences compound and can change the argmax, causing the model to enter degenerate thinking loops instead of producing tool calls.
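A rough repro sketch of the per-call divergence, assuming the mask-mode overload of MLXFast.scaledDotProductAttention referenced above (.none vs .array); the shapes, random inputs, and no-op additive mask are illustrative only, and exact signatures may differ across mlx-swift versions:

```swift
import MLX
import MLXFast
import MLXRandom

// One SDPA call with no mask vs. a functionally equivalent "keep everything"
// mask. Any nonzero difference here is the per-layer drift that compounds
// across Gemma 4's 42 layers and can eventually flip an argmax.
let (B, H, L, D) = (1, 8, 128, 64)
let q = MLXRandom.normal([B, H, L, D]).asType(.float16)
let k = MLXRandom.normal([B, H, L, D]).asType(.float16)
let v = MLXRandom.normal([B, H, L, D]).asType(.float16)
let scale = Float(1.0 / Double(D).squareRoot())

// Serial-mode path: no mask at all.
let outNoMask = MLXFast.scaledDotProductAttention(
    queries: q, keys: k, values: v, scale: scale, mask: .none)

// Padded-batch path: an explicit mask array. An all-zero additive mask keeps
// every position, so it is functionally identical to passing no mask.
let noopMask = MLXArray.zeros([B, H, L, L]).asType(.float16)
let outMasked = MLXFast.scaledDotProductAttention(
    queries: q, keys: k, values: v, scale: scale, mask: .array(noopMask))

print("max |Δ| from a single SDPA call:", abs(outNoMask - outMasked).max())
```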
Current State (after fix at 080da66)
| Scenario | Tool Call Pass Rate | Notes |
| --- | --- | --- |
| B=1 serial | 15/15 (100%) | Baseline |
| B=2-15 simultaneous arrival | 15/15 (100%) | Zero padding, identical to serial |
| B=15 staggered arrival | 9/15 (60%) | First 9 pass, last 6 (with padding) fail |
The fix in 080da66 resolved the major divergence (wrong RoPE in 67% of layers), raising the tool call pass rate for same-length sequences from 0% to 100%.
Potential Approaches
- Batch cache compaction — eliminate padding after merge
- Request coalescing — hold incoming requests briefly so more arrive together (see the sketch after this list)
- MLX upstream — unified SDPA kernel for .none vs .array(allTrue)
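For the coalescing option, a minimal sketch with hypothetical types (not the server's actual scheduler): requests that arrive within a short window are drained as one batch, so they enter the cache with identical lengths and zero padding.

```swift
// Hypothetical coalescer sketch: collect requests, then drain whatever
// arrived within the window as a single same-length batch.
actor RequestCoalescer {
    private var pending: [String] = []   // placeholder request type

    func submit(_ prompt: String) {
        pending.append(prompt)
    }

    /// Called by the scheduling loop: wait out the coalescing window, then
    /// take everything that arrived in the meantime as one batch.
    func drainAfterWindow(_ window: Duration = .milliseconds(20)) async -> [String] {
        try? await Task.sleep(for: window)
        defer { pending.removeAll() }
        return pending
    }
}
```

A scheduling loop would await drainAfterWindow() repeatedly and hand each non-empty batch to the engine; the trade-off is up to one window of added latency per request.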
Impact
- Gemma 4 only (Qwen3.5 unaffected — no KV-sharing)
- Tool calling most sensitive; general text works at all batch sizes
- Same-arrival requests unaffected