From 52027ff488be467553d901002f2d14d4518fd2a6 Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Wed, 4 Mar 2026 09:26:37 -0500 Subject: [PATCH] Voxtral Realtime: enable bf16 for Metal backend with quantization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The Metal AOTI backend already handles bf16 correctly (fp32 attention masks, fp32 RoPE upcast, dtype-agnostic KV caches and SDPA). Enable --dtype bf16 as the default recipe for Metal CI and update all documentation to recommend bf16 with fpa4w quantization. Fix a Metal shader compilation bug in the streaming encoder where bool.to(bf16) generates `bfloat tmp = 0.0;` — Metal Shading Language doesn't support implicit float-to-bfloat literal conversion. Use .float() instead and let mul_ handle type promotion. --- .ci/scripts/export_model_artifact.sh | 1 + examples/models/voxtral_realtime/README.md | 8 +++++--- .../voxtral_realtime/export_voxtral_rt.py | 2 +- examples/models/voxtral_realtime/model.md | 19 +++++++++++-------- examples/models/voxtral_realtime/model.py | 4 +++- 5 files changed, 21 insertions(+), 13 deletions(-) diff --git a/.ci/scripts/export_model_artifact.sh b/.ci/scripts/export_model_artifact.sh index 3c0848475a8..ee9b9695df4 100755 --- a/.ci/scripts/export_model_artifact.sh +++ b/.ci/scripts/export_model_artifact.sh @@ -262,6 +262,7 @@ if [ "$MODEL_NAME" = "voxtral_realtime" ]; then VR_QUANT_ARGS="--qlinear-encoder 8da4w --qlinear 8da4w --qlinear-group-size 32 --qembedding 8w" elif [ "$QUANT_NAME" = "quantized-int4-metal" ]; then VR_QUANT_ARGS="--qlinear-encoder fpa4w --qlinear fpa4w" + VR_DTYPE_ARGS="--dtype bf16" elif [ "$QUANT_NAME" = "quantized-int4-tile-packed" ]; then VR_QUANT_ARGS="--qlinear-encoder 4w --qlinear-encoder-packing-format tile_packed_to_4d --qlinear 4w --qlinear-packing-format tile_packed_to_4d --qembedding 8w" VR_DTYPE_ARGS="--dtype bf16" diff --git a/examples/models/voxtral_realtime/README.md b/examples/models/voxtral_realtime/README.md index 7d29ba8c11b..e54874cccf0 100644 --- a/examples/models/voxtral_realtime/README.md +++ b/examples/models/voxtral_realtime/README.md @@ -87,7 +87,7 @@ python export_voxtral_rt.py \ | Backend | Offline | Streaming | Quantization | |---------|---------|-----------|--------------| | `xnnpack` | ✓ | ✓ | `4w`, `8w`, `8da4w`, `8da8w` | -| `metal` | ✓ | ✓ | none (fp32) or `fpa4w` (Metal-specific 4-bit) | +| `metal` | ✓ | ✓ | none or `fpa4w` (Metal-specific 4-bit); bf16 recommended with quantization | | `cuda` | ✓ | ✓ | `4w`, `8w` | Metal backend provides Apple GPU acceleration. CUDA backend provides NVIDIA GPU @@ -128,23 +128,25 @@ python export_voxtral_rt.py \ #### Metal export examples -Offline: +Offline with fpa4w quantization and bf16: ```bash python export_voxtral_rt.py \ --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \ --backend metal \ + --dtype bf16 \ --output-dir ./voxtral_rt_exports \ --qlinear-encoder fpa4w \ --qlinear fpa4w ``` -Streaming: +Streaming with fpa4w quantization and bf16: ```bash python export_voxtral_rt.py \ --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \ --backend metal \ + --dtype bf16 \ --streaming \ --output-dir ./voxtral_rt_exports \ --qlinear-encoder fpa4w \ diff --git a/examples/models/voxtral_realtime/export_voxtral_rt.py b/examples/models/voxtral_realtime/export_voxtral_rt.py index c813f0ecc34..4b73421c587 100644 --- a/examples/models/voxtral_realtime/export_voxtral_rt.py +++ b/examples/models/voxtral_realtime/export_voxtral_rt.py @@ -30,7 +30,7 @@ Usage: python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --streaming - python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --backend metal + python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --backend metal --dtype bf16 --qlinear-encoder fpa4w --qlinear fpa4w python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --backend cuda --qlinear 4w """ diff --git a/examples/models/voxtral_realtime/model.md b/examples/models/voxtral_realtime/model.md index fe240b03d8c..173fc741a40 100644 --- a/examples/models/voxtral_realtime/model.md +++ b/examples/models/voxtral_realtime/model.md @@ -74,13 +74,15 @@ or masked-scatter like the original non-realtime Voxtral). ## Memory Footprint -Decoder KV cache: 26 layers × 2 (K, V) × 4096 × 8 × 128 × 4 bytes -≈ 832 MB. Encoder KV caches (streaming): 32 layers × 2 × 1500 × 32 × -64 × 4 bytes ≈ 786 MB. +Decoder KV cache: 26 layers × 2 (K, V) × 4096 × 8 × 128 × bytes_per_elem. +fp32: ≈ 832 MB, bf16: ≈ 416 MB. Encoder KV caches (streaming): +32 layers × 2 × 1500 × 32 × 64 × bytes_per_elem. fp32: ≈ 786 MB, +bf16: ≈ 393 MB. Runtime memory = model weights (from `.pte`) + KV caches + working -memory. Weight sizes depend on quantization: ~16 GB (fp32), ~4 GB -(8w), ~2 GB (4w/8da4w). +memory. Weight sizes depend on quantization: ~16 GB (fp32), ~8 GB +(bf16), ~4 GB (8w), ~2 GB (4w/8da4w). Metal and CUDA backends use +bf16 (`--dtype bf16`) by default when quantization is enabled. ## Class Hierarchy @@ -274,7 +276,7 @@ enabling streaming of arbitrary length audio. 5-8, giving query 5 full access to its window. - Default `max_enc_len=750` (matching the model's trained sliding window). Configurable via `--max-enc-len`. -- Memory: 32 layers × 2 × 1500 × 32 × 64 × 4 bytes ≈ 786 MB (fp32) +- Memory: 32 layers × 2 × 1500 × 32 × 64 × bytes_per_elem ≈ 786 MB (fp32), 393 MB (bf16) - Duration: unlimited (ring buffer overwrites old entries, RoPE computed on-the-fly) **Naming note:** `max_enc_len` in `StreamingAudioEncoderExport` (default @@ -364,7 +366,7 @@ Parakeet pattern), allowing different configs for encoder vs decoder: --qlinear 8da4w # decoder linear layers --qembedding 8w # embedding layer -# Metal +# Metal (use --dtype bf16 for reduced memory and improved throughput) --qlinear-encoder fpa4w # encoder linear layers --qlinear fpa4w # decoder linear layers @@ -422,7 +424,8 @@ of ~34 GB for the full-size model): 1. **Meta device construction** — `with torch.device("meta"):` builds the model with zero-storage parameter tensors (shape/dtype metadata only). 2. **safetensors lazy access** — `safe_open` loads tensors on demand, cast - to the configured dtype (`--dtype`, default fp32; CUDA uses bf16). + to the configured dtype (`--dtype`, default fp32; Metal and CUDA use bf16 + with quantization). 3. **`assign=True` state dict loading** — replaces meta tensors by reference instead of copying into pre-allocated storage. No duplication. 4. **Post-load fixups** — re-tie `output.weight = tok_embeddings.weight` diff --git a/examples/models/voxtral_realtime/model.py b/examples/models/voxtral_realtime/model.py index 26778413834..6925c56f3b1 100644 --- a/examples/models/voxtral_realtime/model.py +++ b/examples/models/voxtral_realtime/model.py @@ -1049,7 +1049,9 @@ def forward( ) -> torch.Tensor: # Auto-reset conv states at the start of each new session (enc_input_pos[0] == 0). # Using tensor ops (not .item()) avoids constant-folding during export. - is_start = (enc_input_pos[:1] == 0).view(1, 1, 1).to(self.conv1_state.dtype) + # .float() instead of .to(conv1_state.dtype) — Metal shader codegen + # doesn't support implicit float-to-bfloat literal conversion. + is_start = (enc_input_pos[:1] == 0).view(1, 1, 1).float() self.conv1_state.mul_(1.0 - is_start) self.conv2_state.mul_(1.0 - is_start)