Speculative Decoding for MLX-Swift

Native speculative decoding implementation for fast LLM inference on Apple Silicon using MLX-Swift.

Overview

Speculative decoding accelerates LLM inference by using a smaller "draft" model to propose multiple tokens, which a larger "target" model then verifies in parallel. This typically yields 2-3x speedups while producing exactly the same output distribution as running the target model alone.

Features

  • Native Swift implementation optimized for Apple Silicon
  • Simple high-level API and streaming support
  • Compatible with any MLX-Swift LLM model
  • Full Swift concurrency support (async/await)
  • Greedy and temperature-based sampling
  • Comprehensive statistics and benchmarking

Requirements

  • macOS 14.0+ / iOS 16.0+
  • Xcode 15.0+
  • Swift 5.9+
  • Apple Silicon (M1/M2/M3/M4)

Installation

Add to your Package.swift:

dependencies: [
    .package(url: "https://github.com/lulzx/speculative-decoding", branch: "main"),
]

Then add the dependency to your target:

.target(
    name: "YourApp",
    dependencies: [
        .product(name: "SpeculativeDecoding", package: "speculative-decoding"),
    ]
),

Quick Start

Simple Generation

import SpeculativeDecoding

let output = try await SpeculativeDecoding.generate(
    prompt: "Explain quantum computing:",
    draftModelId: "mlx-community/Qwen2.5-0.5B-Instruct-4bit",
    targetModelId: "mlx-community/Qwen2.5-3B-Instruct-4bit"
)
print(output)

Streaming Generation

let stream = try await SpeculativeDecoding.generateStream(
    prompt: "Write a haiku about Swift:",
    draftModelId: "mlx-community/Qwen2.5-0.5B-Instruct-4bit",
    targetModelId: "mlx-community/Qwen2.5-3B-Instruct-4bit"
)

for await event in stream {
    switch event {
    case .text(let chunk):
        print(chunk, terminator: "")
    case .result(let result):
        print("\n\(result.summary())")
    case .error(let error):
        print("Error: \(error)")
    }
}

Advanced Usage

// Load models once, generate multiple times
let modelPair = try await DraftTargetPair.load(
    draftModelId: "mlx-community/Qwen2.5-0.5B-Instruct-4bit",
    targetModelId: "mlx-community/Qwen2.5-3B-Instruct-4bit"
)

let parameters = SpeculativeParameters(
    numDraftTokens: 6,
    draftTemperature: 0.7,
    targetTemperature: 0.7,
    maxTokens: 512
)

let generator = SpeculativeGenerator(modelPair: modelPair, parameters: parameters)
let input = try modelPair.prepare(prompt: "Hello, world!")
let result = try await generator.generate(input: input) { tokens in
    return .more  // Continue generating
}

print(result.summary())
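The closure receives the tokens generated so far and controls whether generation continues. A hedged sketch of stopping early on a token budget; the `.stop` case name is an assumption, since only `.more` appears in the examples above:

// Sketch: stop after 128 generated tokens.
// Assumes the callback enum also exposes a `.stop` case alongside `.more`.
let bounded = try await generator.generate(input: input) { tokens in
    tokens.count < 128 ? .more : .stop
}
print(bounded.summary())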

CLI Tool

Build and run the CLI:

swift build -c release
.build/release/speculative-cli generate \
    --draft-model mlx-community/Qwen2.5-0.5B-Instruct-4bit \
    --target-model mlx-community/Qwen2.5-3B-Instruct-4bit \
    --prompt "Explain neural networks:" \
    --max-tokens 256 \
    --stats

Commands

# Generate text
speculative-cli generate --prompt "Your prompt" --stats

# Benchmark speculative vs standard decoding
speculative-cli benchmark --prompt "Test prompt" --iterations 3

# List recommended model pairs
speculative-cli list-models

Recommended Model Pairs

| Draft Model | Target Model | Use Case |
|---|---|---|
| Qwen2.5-0.5B-Instruct-4bit | Qwen2.5-3B-Instruct-4bit | General purpose |
| Llama-3.2-1B-Instruct-4bit | Llama-3.2-3B-Instruct-4bit | Llama family |
| SmolLM2-135M-Instruct-4bit | SmolLM2-1.7B-Instruct-4bit | Lightweight |

Configuration Options

SpeculativeParameters(
    numDraftTokens: 5,        // Tokens to draft per iteration (4-8 typical)
    draftTemperature: 0.6,    // Draft model temperature
    targetTemperature: 0.6,   // Target model temperature
    draftTopP: 0.9,           // Top-p sampling for draft
    targetTopP: 0.9,          // Top-p sampling for target
    maxTokens: nil,           // Max tokens (nil = unlimited)
    prefillStepSize: 512      // Prompt processing chunk size
)

// Presets
SpeculativeParameters.default      // Balanced
SpeculativeParameters.greedy       // Deterministic (temp=0)
SpeculativeParameters.creative     // Higher temperature
SpeculativeParameters.conservative // Lower draft count
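For example, a preset can stand in for a hand-built parameter set when constructing the generator shown in Advanced Usage:

// Greedy (temperature 0) decoding with a pre-loaded model pair.
let greedyGenerator = SpeculativeGenerator(
    modelPair: modelPair,
    parameters: .greedy
)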

How It Works

  1. Draft Phase: Small model generates K candidate tokens autoregressively
  2. Verify Phase: Large model processes all K tokens in a single forward pass
  3. Accept/Reject: Rejection sampling validates each token against target distribution
  4. Repeat: Continue with accepted tokens plus one new token from target

The key insight is that verification is parallelizable: the target model can check all draft tokens simultaneously, amortizing the cost of its larger size.
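A minimal Swift sketch of one draft-verify-accept step, to make the loop above concrete. This is illustrative only, not the library's internal implementation; the model calls are injected as closures, and a greedy stand-in replaces proper categorical sampling:

/// Illustrative sketch of one speculative decoding step.
/// - draftStep: given a context, samples one token from the draft model and
///   returns it with its probability.
/// - targetDistributions: given context + draft tokens, returns k+1 next-token
///   distributions from a single batched target forward pass.
func speculativeStep(
    context: [Int],
    k: Int,
    draftStep: ([Int]) -> (token: Int, prob: Double),
    targetDistributions: ([Int], [Int]) -> [[Double]]
) -> [Int] {
    // 1. Draft phase: the small model proposes k tokens autoregressively.
    var draft: [Int] = []
    var draftProbs: [Double] = []
    for _ in 0..<k {
        let (token, prob) = draftStep(context + draft)
        draft.append(token)
        draftProbs.append(prob)
    }

    // 2. Verify phase: one target forward pass scores all k positions.
    let target = targetDistributions(context, draft)

    // 3. Accept/reject: keep draft[i] with probability min(1, pTarget / pDraft).
    var output: [Int] = []
    for i in 0..<k {
        let pTarget = target[i][draft[i]]
        guard Double.random(in: 0..<1) <= min(1.0, pTarget / draftProbs[i]) else {
            // Rejection: a full implementation resamples one token from the
            // residual distribution max(0, pTarget(x) - pDraft(x)), renormalized,
            // appends it, and ends the step here.
            return output
        }
        output.append(draft[i])
    }

    // 4. All k tokens accepted: take one bonus token from the target's final
    //    distribution, so the full algorithm always emits at least one new token.
    output.append(greedyToken(target[k]))
    return output
}

/// Greedy stand-in for categorical sampling, to keep the sketch short.
func greedyToken(_ distribution: [Double]) -> Int {
    distribution.indices.max(by: { distribution[$0] < distribution[$1] }) ?? 0
}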

Performance

Results on a MacBook Pro (M4 Pro) with Qwen2.5 models:

| Drafter | Target | Speed | Acceptance | Tokens/Step |
|---|---|---|---|---|
| Qwen2.5-0.5B | Qwen2.5-3B | 40 tok/s | 28% | 2.4 |

Speedup depends on several factors (a rough estimate appears after this list):

  • Draft/target model size ratio
  • Task similarity between models
  • Temperature settings
  • Hardware capabilities
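Under the simplifying assumption that each drafted token is accepted independently with probability α, the expected number of tokens emitted per target forward pass is (1 - α^(K+1)) / (1 - α) (Leviathan et al., 2023). A small Swift sketch of that estimate; the numbers are hypothetical, and the acceptance figures reported in the tables here may be measured differently:

import Foundation

// Expected tokens emitted per target forward pass, assuming each drafted token
// is accepted independently with probability `alpha` and `k` tokens are drafted
// per step (Leviathan et al., 2023).
func expectedTokensPerTargetStep(acceptance alpha: Double, draftTokens k: Int) -> Double {
    guard alpha < 1.0 else { return Double(k + 1) }
    return (1.0 - pow(alpha, Double(k + 1))) / (1.0 - alpha)
}

// Hypothetical example: alpha = 0.6 with k = 5 gives roughly 2.4 tokens per
// target pass, before accounting for the cost of running the draft model.
let estimate = expectedTokensPerTargetStep(acceptance: 0.6, draftTokens: 5)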

Mamba Drafters (Experimental)

This library includes experimental support for using Mamba (state-space models) as draft models instead of transformers.

Why Mamba for Drafting?

| Property | Transformer | Mamba |
|---|---|---|
| Memory per token | O(n) KV-cache | O(1) constant |
| Inference complexity | O(n²) attention | O(n) linear |
| Model size for comparable quality | ~500M | ~130M |

Supported Mamba Models

speculative-cli list-models

# Mamba Drafters:
# - state-spaces/mamba-130m-hf (768 hidden, 24 layers)
# - state-spaces/mamba-370m-hf (1024 hidden, 48 layers)
# - state-spaces/mamba-790m-hf (1536 hidden, 48 layers)

Mamba API

import SpeculativeDecoding

// Load Mamba drafter with transformer target
let pair = try await MambaDraftTargetPair.load(
    draftModelId: "state-spaces/mamba-130m-hf",
    targetModelId: "mlx-community/Qwen2.5-3B-Instruct-4bit"
)

let generator = MambaSpeculativeGenerator(modelPair: pair)
let result = try await generator.generate(prompt: "Hello") { _ in .more }

Memory Comparison at Different Sequence Lengths

| Sequence Length | Transformer (135M) | Mamba (130M) |
|---|---|---|
| 512 tokens | ~334 MB | ~260 MB |
| 2048 tokens | ~526 MB | ~260 MB |
| 8192 tokens | ~1.3 GB | ~260 MB |
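The transformer footprint grows because the KV-cache stores keys and values for every layer and every past token, while a Mamba drafter carries a fixed-size recurrent state per layer. A rough sketch of that growth; the dimensions below are hypothetical placeholders, not the actual SmolLM2-135M configuration:

// Rough KV-cache size estimate for a transformer drafter.
// Keys and values are both cached at every layer for every position.
func kvCacheBytes(layers: Int, kvHeads: Int, headDim: Int,
                  sequenceLength: Int, bytesPerElement: Int = 2) -> Int {
    2 * layers * kvHeads * headDim * sequenceLength * bytesPerElement
}

// Hypothetical dimensions: the cache grows linearly with sequence length,
// on top of the fixed model weights; a Mamba state would stay constant.
let shortContext = kvCacheBytes(layers: 24, kvHeads: 8, headDim: 64, sequenceLength: 512)
let longContext = kvCacheBytes(layers: 24, kvHeads: 8, headDim: 64, sequenceLength: 8192)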

Benchmark Results

Comparing Mamba (130M) and Transformer (0.5B) draft models with a Qwen2.5-3B target on a MacBook Pro (M4 Pro):

| Drafter | Speed | Acceptance | Tokens/Step |
|---|---|---|---|
| Transformer (0.5B) | 40 tok/s | 28% | 2.4 |
| Mamba (130M) | 86 tok/s | 96%* | 5.8* |

*Note: The high acceptance rate with Mamba is an artifact of the vocabulary mismatch between the Mamba drafter (GPT-NeoX tokenizer) and the Qwen target (its own tokenizer). This configuration is experimental, and the results may not reflect true speculative decoding behavior.

Key Advantages

  1. Faster drafting: A 130M Mamba model is much smaller than a 0.5B transformer drafter and uses O(1) memory per token
  2. Constant memory: No KV-cache growth with sequence length
  3. Efficient verification: More accepted tokens per target model forward pass

Benchmark Command

.build/release/speculative-cli benchmark \
    --transformer-draft mlx-community/Qwen2.5-0.5B-Instruct-4bit \
    --target-model mlx-community/Qwen2.5-3B-Instruct-4bit \
    --tokens 128 --runs 3

License

MIT License - see LICENSE for details.

Acknowledgments

  • MLX - Apple's ML framework
  • MLX-Swift - Swift bindings
  • MLX-Swift-LM - LLM implementations
  • Mamba - State-space models
  • Original paper: "Fast Inference from Transformers via Speculative Decoding" (Leviathan et al., 2023)
  • Mamba paper: "Mamba: Linear-Time Sequence Modeling with Selective State Spaces" (Gu & Dao, 2023)
