Commit 52027ff

Voxtral Realtime: enable bf16 for Metal backend with quantization
The Metal AOTI backend already handles bf16 correctly (fp32 attention masks, fp32 RoPE upcast, dtype-agnostic KV caches and SDPA). Enable --dtype bf16 as the default recipe for Metal CI and update all documentation to recommend bf16 with fpa4w quantization.

Fix a Metal shader compilation bug in the streaming encoder where bool.to(bf16) generates `bfloat tmp = 0.0;` — Metal Shading Language doesn't support implicit float-to-bfloat literal conversion. Use .float() instead and let mul_ handle type promotion.
1 parent 6db7f4c commit 52027ff

File tree

5 files changed: +21 -13 lines changed

.ci/scripts/export_model_artifact.sh

Lines changed: 1 addition & 0 deletions

@@ -262,6 +262,7 @@ if [ "$MODEL_NAME" = "voxtral_realtime" ]; then
     VR_QUANT_ARGS="--qlinear-encoder 8da4w --qlinear 8da4w --qlinear-group-size 32 --qembedding 8w"
 elif [ "$QUANT_NAME" = "quantized-int4-metal" ]; then
     VR_QUANT_ARGS="--qlinear-encoder fpa4w --qlinear fpa4w"
+    VR_DTYPE_ARGS="--dtype bf16"
 elif [ "$QUANT_NAME" = "quantized-int4-tile-packed" ]; then
     VR_QUANT_ARGS="--qlinear-encoder 4w --qlinear-encoder-packing-format tile_packed_to_4d --qlinear 4w --qlinear-packing-format tile_packed_to_4d --qembedding 8w"
     VR_DTYPE_ARGS="--dtype bf16"
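
For illustration, the recipe-to-flags mapping in the hunk above can be mirrored in Python (a sketch only; the real CI logic lives in the shell script, and only the two branches visible in this diff are included — the function name is ours):

```python
def voxtral_rt_args(quant_name: str) -> str:
    """Return exporter flags for a CI quantization recipe (sketch)."""
    if quant_name == "quantized-int4-metal":
        # Metal 4-bit: fpa4w linears plus bf16 activations/KV caches.
        return "--qlinear-encoder fpa4w --qlinear fpa4w --dtype bf16"
    if quant_name == "quantized-int4-tile-packed":
        return ("--qlinear-encoder 4w "
                "--qlinear-encoder-packing-format tile_packed_to_4d "
                "--qlinear 4w --qlinear-packing-format tile_packed_to_4d "
                "--qembedding 8w --dtype bf16")
    raise ValueError(f"unknown recipe: {quant_name}")

print(voxtral_rt_args("quantized-int4-metal"))
```

Note that after this change both Metal and tile-packed CUDA recipes carry `--dtype bf16`; only the quantization scheme differs.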

examples/models/voxtral_realtime/README.md

Lines changed: 5 additions & 3 deletions

@@ -87,7 +87,7 @@ python export_voxtral_rt.py \
 | Backend | Offline | Streaming | Quantization |
 |---------|---------|-----------|--------------|
 | `xnnpack` ||| `4w`, `8w`, `8da4w`, `8da8w` |
-| `metal` ||| none (fp32) or `fpa4w` (Metal-specific 4-bit) |
+| `metal` ||| none or `fpa4w` (Metal-specific 4-bit); bf16 recommended with quantization |
 | `cuda` ||| `4w`, `8w` |
 
 Metal backend provides Apple GPU acceleration. CUDA backend provides NVIDIA GPU
@@ -128,23 +128,25 @@ python export_voxtral_rt.py \
 
 #### Metal export examples
 
-Offline:
+Offline with fpa4w quantization and bf16:
 
 ```bash
 python export_voxtral_rt.py \
   --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \
   --backend metal \
+  --dtype bf16 \
   --output-dir ./voxtral_rt_exports \
   --qlinear-encoder fpa4w \
   --qlinear fpa4w
 ```
 
-Streaming:
+Streaming with fpa4w quantization and bf16:
 
 ```bash
 python export_voxtral_rt.py \
   --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \
   --backend metal \
+  --dtype bf16 \
   --streaming \
   --output-dir ./voxtral_rt_exports \
   --qlinear-encoder fpa4w \
examples/models/voxtral_realtime/export_voxtral_rt.py

Lines changed: 1 addition & 1 deletion

@@ -30,7 +30,7 @@
 Usage:
     python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602
    python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --streaming
-    python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --backend metal
+    python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --backend metal --dtype bf16 --qlinear-encoder fpa4w --qlinear fpa4w
     python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --backend cuda --qlinear 4w
 """

examples/models/voxtral_realtime/model.md

Lines changed: 11 additions & 8 deletions

@@ -74,13 +74,15 @@ or masked-scatter like the original non-realtime Voxtral).
 
 ## Memory Footprint
 
-Decoder KV cache: 26 layers × 2 (K, V) × 4096 × 8 × 128 × 4 bytes
-≈ 832 MB. Encoder KV caches (streaming): 32 layers × 2 × 1500 × 32 ×
-64 × 4 bytes ≈ 786 MB.
+Decoder KV cache: 26 layers × 2 (K, V) × 4096 × 8 × 128 × bytes_per_elem.
+fp32: ≈ 832 MB, bf16: ≈ 416 MB. Encoder KV caches (streaming):
+32 layers × 2 × 1500 × 32 × 64 × bytes_per_elem. fp32: ≈ 786 MB,
+bf16: ≈ 393 MB.
 
 Runtime memory = model weights (from `.pte`) + KV caches + working
-memory. Weight sizes depend on quantization: ~16 GB (fp32), ~4 GB
-(8w), ~2 GB (4w/8da4w).
+memory. Weight sizes depend on quantization: ~16 GB (fp32), ~8 GB
+(bf16), ~4 GB (8w), ~2 GB (4w/8da4w). Metal and CUDA backends use
+bf16 (`--dtype bf16`) by default when quantization is enabled.
 
 ## Class Hierarchy
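The cache arithmetic above is easy to verify. A small sketch (dimension values taken from the text; the helper name is ours — the quoted ≈ 832 MB figure is binary MiB, while ≈ 786 MB is decimal MB):

```python
def kv_cache_bytes(layers: int, seq_len: int, n_heads: int,
                   head_dim: int, bytes_per_elem: int) -> int:
    # Separate K and V caches per layer, hence the factor of 2.
    return layers * 2 * seq_len * n_heads * head_dim * bytes_per_elem

decoder_fp32 = kv_cache_bytes(26, 4096, 8, 128, 4)  # 872,415,232 B = 832 MiB
decoder_bf16 = kv_cache_bytes(26, 4096, 8, 128, 2)  # exactly half of fp32
encoder_fp32 = kv_cache_bytes(32, 1500, 32, 64, 4)  # 786,432,000 B ≈ 786 MB
```

Switching `--dtype bf16` halves `bytes_per_elem` from 4 to 2, which is where the 832 → 416 MB and 786 → 393 MB numbers come from.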

@@ -274,7 +276,7 @@ enabling streaming of arbitrary length audio.
   5-8, giving query 5 full access to its window.
 - Default `max_enc_len=750` (matching the model's trained
   sliding window). Configurable via `--max-enc-len`.
-- Memory: 32 layers × 2 × 1500 × 32 × 64 × 4 bytes ≈ 786 MB (fp32)
+- Memory: 32 layers × 2 × 1500 × 32 × 64 × bytes_per_elem ≈ 786 MB (fp32), 393 MB (bf16)
 - Duration: unlimited (ring buffer overwrites old entries, RoPE computed on-the-fly)
 
**Naming note:** `max_enc_len` in `StreamingAudioEncoderExport` (default
@@ -364,7 +366,7 @@ Parakeet pattern), allowing different configs for encoder vs decoder:
 --qlinear 8da4w      # decoder linear layers
 --qembedding 8w      # embedding layer
 
-# Metal
+# Metal (use --dtype bf16 for reduced memory and improved throughput)
 --qlinear-encoder fpa4w   # encoder linear layers
 --qlinear fpa4w           # decoder linear layers
 
@@ -422,7 +424,8 @@ of ~34 GB for the full-size model):
 1. **Meta device construction** — `with torch.device("meta"):` builds the
    model with zero-storage parameter tensors (shape/dtype metadata only).
 2. **safetensors lazy access** — `safe_open` loads tensors on demand, cast
-   to the configured dtype (`--dtype`, default fp32; CUDA uses bf16).
+   to the configured dtype (`--dtype`, default fp32; Metal and CUDA use bf16
+   with quantization).
 3. **`assign=True` state dict loading** — replaces meta tensors by reference
    instead of copying into pre-allocated storage. No duplication.
 4. **Post-load fixups** — re-tie `output.weight = tok_embeddings.weight`

examples/models/voxtral_realtime/model.py

Lines changed: 3 additions & 1 deletion

@@ -1049,7 +1049,9 @@ def forward(
 ) -> torch.Tensor:
     # Auto-reset conv states at the start of each new session (enc_input_pos[0] == 0).
     # Using tensor ops (not .item()) avoids constant-folding during export.
-    is_start = (enc_input_pos[:1] == 0).view(1, 1, 1).to(self.conv1_state.dtype)
+    # .float() instead of .to(conv1_state.dtype) — Metal shader codegen
+    # doesn't support implicit float-to-bfloat literal conversion.
+    is_start = (enc_input_pos[:1] == 0).view(1, 1, 1).float()
     self.conv1_state.mul_(1.0 - is_start)
     self.conv2_state.mul_(1.0 - is_start)
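
The fixed reset pattern can be exercised in isolation (a standalone sketch; the buffer shape is arbitrary, and `conv1_state` here is a plain tensor standing in for the module's registered buffer):

```python
import torch

conv1_state = torch.ones(1, 3, 4, dtype=torch.bfloat16)  # stand-in buffer
enc_input_pos = torch.tensor([0])  # position 0 means a new session

# Boolean mask cast to fp32 rather than bf16: per the commit message,
# .to(bf16) made the Metal shader compiler emit `bfloat tmp = 0.0;`,
# an unsupported implicit float-to-bfloat literal conversion.
# mul_ then handles the bf16 * fp32 promotion in-place.
is_start = (enc_input_pos[:1] == 0).view(1, 1, 1).float()
conv1_state.mul_(1.0 - is_start)  # zeroed when a session starts
assert conv1_state.abs().sum().item() == 0.0

# Mid-session (pos > 0): the state passes through unchanged.
conv1_state.fill_(2.0)
is_start = (torch.tensor([5])[:1] == 0).view(1, 1, 1).float()
conv1_state.mul_(1.0 - is_start)
assert conv1_state[0, 0, 0].item() == 2.0
```

The in-place `mul_` with a higher-precision operand is legal because float32 → bfloat16 is a same-kind cast; the multiplier is either exactly 0.0 or 1.0, so no precision is lost.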
