1 change: 1 addition & 0 deletions .ci/scripts/export_model_artifact.sh
@@ -262,6 +262,7 @@ if [ "$MODEL_NAME" = "voxtral_realtime" ]; then
VR_QUANT_ARGS="--qlinear-encoder 8da4w --qlinear 8da4w --qlinear-group-size 32 --qembedding 8w"
elif [ "$QUANT_NAME" = "quantized-int4-metal" ]; then
VR_QUANT_ARGS="--qlinear-encoder fpa4w --qlinear fpa4w"
VR_DTYPE_ARGS="--dtype bf16"
elif [ "$QUANT_NAME" = "quantized-int4-tile-packed" ]; then
VR_QUANT_ARGS="--qlinear-encoder 4w --qlinear-encoder-packing-format tile_packed_to_4d --qlinear 4w --qlinear-packing-format tile_packed_to_4d --qembedding 8w"
VR_DTYPE_ARGS="--dtype bf16"
8 changes: 5 additions & 3 deletions examples/models/voxtral_realtime/README.md
@@ -87,7 +87,7 @@ python export_voxtral_rt.py \
| Backend | Offline | Streaming | Quantization |
|---------|---------|-----------|--------------|
| `xnnpack` | ✓ | ✓ | `4w`, `8w`, `8da4w`, `8da8w` |
| `metal` | ✓ | ✓ | none (fp32) or `fpa4w` (Metal-specific 4-bit) |
| `metal` | ✓ | ✓ | none or `fpa4w` (Metal-specific 4-bit); bf16 recommended with quantization |
| `cuda` | ✓ | ✓ | `4w`, `8w` |

Metal backend provides Apple GPU acceleration. CUDA backend provides NVIDIA GPU
@@ -128,23 +128,25 @@ python export_voxtral_rt.py \

#### Metal export examples

Offline:
Offline with fpa4w quantization and bf16:

```bash
python export_voxtral_rt.py \
--model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \
--backend metal \
--dtype bf16 \
--output-dir ./voxtral_rt_exports \
--qlinear-encoder fpa4w \
--qlinear fpa4w
```

Streaming:
Streaming with fpa4w quantization and bf16:

```bash
python export_voxtral_rt.py \
--model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \
--backend metal \
--dtype bf16 \
--streaming \
--output-dir ./voxtral_rt_exports \
--qlinear-encoder fpa4w \
  --qlinear fpa4w
```
2 changes: 1 addition & 1 deletion examples/models/voxtral_realtime/export_voxtral_rt.py
@@ -30,7 +30,7 @@
Usage:
python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602
python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --streaming
python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --backend metal
python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --backend metal --dtype bf16 --qlinear-encoder fpa4w --qlinear fpa4w
python export_voxtral_rt.py --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 --backend cuda --qlinear 4w
"""

19 changes: 11 additions & 8 deletions examples/models/voxtral_realtime/model.md
@@ -74,13 +74,15 @@ or masked-scatter like the original non-realtime Voxtral).

## Memory Footprint

Decoder KV cache: 26 layers × 2 (K, V) × 4096 × 8 × 128 × 4 bytes
≈ 832 MB. Encoder KV caches (streaming): 32 layers × 2 × 1500 × 32 ×
64 × 4 bytes ≈ 786 MB.
Decoder KV cache: 26 layers × 2 (K, V) × 4096 × 8 × 128 × bytes_per_elem.
fp32: ≈ 832 MB, bf16: ≈ 416 MB. Encoder KV caches (streaming):
32 layers × 2 × 1500 × 32 × 64 × bytes_per_elem. fp32: ≈ 786 MB,
bf16: ≈ 393 MB.

Runtime memory = model weights (from `.pte`) + KV caches + working
memory. Weight sizes depend on quantization: ~16 GB (fp32), ~4 GB
(8w), ~2 GB (4w/8da4w).
memory. Weight sizes depend on quantization: ~16 GB (fp32), ~8 GB
(bf16), ~4 GB (8w), ~2 GB (4w/8da4w). Metal and CUDA backends use
bf16 (`--dtype bf16`) by default when quantization is enabled.
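
The figures above follow from straightforward arithmetic. A minimal sketch to reproduce them, using the layer counts and head dimensions stated in this section; the ~4B parameter count is an assumption based on the "Mini-4B" model name:

```python
# Recompute the memory figures quoted above. Shape parameters come from
# this section of model.md; the 4e9 parameter count is an assumption
# inferred from the "Voxtral-Mini-4B" model name.

def kv_cache_bytes(layers, seq_len, n_kv_heads, head_dim, bytes_per_elem):
    # Factor of 2 covers the separate K and V caches per layer.
    return layers * 2 * seq_len * n_kv_heads * head_dim * bytes_per_elem

def weight_bytes(n_params, bits_per_weight):
    return n_params * bits_per_weight // 8

decoder_fp32 = kv_cache_bytes(26, 4096, 8, 128, 4)   # 872,415,232 B ~ 832 MiB
decoder_bf16 = kv_cache_bytes(26, 4096, 8, 128, 2)   # 436,207,616 B ~ 416 MiB
encoder_fp32 = kv_cache_bytes(32, 1500, 32, 64, 4)   # 786,432,000 B ~ 786 MB
weights_bf16 = weight_bytes(4_000_000_000, 16)       # 8e9 B ~ 8 GB
```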

## Class Hierarchy

@@ -274,7 +276,7 @@ enabling streaming of arbitrary length audio.
5-8, giving query 5 full access to its window.
- Default `max_enc_len=750` (matching the model's trained
sliding window). Configurable via `--max-enc-len`.
- Memory: 32 layers × 2 × 1500 × 32 × 64 × 4 bytes ≈ 786 MB (fp32)
- Memory: 32 layers × 2 × 1500 × 32 × 64 × bytes_per_elem ≈ 786 MB (fp32), 393 MB (bf16)
- Duration: unlimited (ring buffer overwrites old entries, RoPE computed on-the-fly)
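
The ring-buffer behavior in the last bullet can be sketched with modulo indexing. This is a simplified illustration, not the actual implementation: the real cache holds K/V tensors per layer, and `RingKVCache` is a hypothetical name.

```python
# Hypothetical, simplified sketch of a ring-buffer KV cache: the entry for
# position p lands at slot p % capacity, so once the buffer fills, new
# entries overwrite the oldest ones and audio of arbitrary length can
# stream through a fixed-size cache.
class RingKVCache:
    def __init__(self, capacity):
        self.capacity = capacity
        self.slots = [None] * capacity

    def write(self, pos, value):
        self.slots[pos % self.capacity] = value

    def read(self, pos):
        return self.slots[pos % self.capacity]
```

With `capacity=1500`, position 1500 would overwrite slot 0, matching the "ring buffer overwrites old entries" note above; RoPE is then recomputed per position rather than stored.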

**Naming note:** `max_enc_len` in `StreamingAudioEncoderExport` (default
@@ -364,7 +366,7 @@ Parakeet pattern), allowing different configs for encoder vs decoder:
--qlinear 8da4w # decoder linear layers
--qembedding 8w # embedding layer

# Metal
# Metal (use --dtype bf16 for reduced memory and improved throughput)
--qlinear-encoder fpa4w # encoder linear layers
--qlinear fpa4w # decoder linear layers

@@ -422,7 +424,8 @@ of ~34 GB for the full-size model):
1. **Meta device construction** — `with torch.device("meta"):` builds the
model with zero-storage parameter tensors (shape/dtype metadata only).
2. **safetensors lazy access** — `safe_open` loads tensors on demand, cast
to the configured dtype (`--dtype`, default fp32; CUDA uses bf16).
to the configured dtype (`--dtype`, default fp32; Metal and CUDA use bf16
with quantization).
3. **`assign=True` state dict loading** — replaces meta tensors by reference
instead of copying into pre-allocated storage. No duplication.
4. **Post-load fixups** — re-tie `output.weight = tok_embeddings.weight`
4 changes: 3 additions & 1 deletion examples/models/voxtral_realtime/model.py
@@ -1049,7 +1049,9 @@ def forward(
) -> torch.Tensor:
# Auto-reset conv states at the start of each new session (enc_input_pos[0] == 0).
# Using tensor ops (not .item()) avoids constant-folding during export.
is_start = (enc_input_pos[:1] == 0).view(1, 1, 1).to(self.conv1_state.dtype)
# .float() instead of .to(conv1_state.dtype) — Metal shader codegen
# doesn't support implicit float-to-bfloat literal conversion.
is_start = (enc_input_pos[:1] == 0).view(1, 1, 1).float()
self.conv1_state.mul_(1.0 - is_start)
self.conv2_state.mul_(1.0 - is_start)
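
A pure-Python analogue of this branch-free reset, with scalar stand-ins for the conv state tensors (the real code keeps everything as tensor ops precisely so export does not constant-fold the condition; `reset_conv_states` is a hypothetical helper name):

```python
# Simplified, framework-free analogue of the reset above: is_start is 1.0
# only when the first position is 0 (a new session); multiplying each state
# by (1 - is_start) zeroes it on session start and leaves it untouched
# otherwise, with no data-dependent branch on the state values themselves.
def reset_conv_states(conv_states, enc_input_pos):
    is_start = 1.0 if enc_input_pos[0] == 0 else 0.0
    return [s * (1.0 - is_start) for s in conv_states]
```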
