
[Performance]: Better MTP Support (decode optimized) #21984

@LucasWilkinson

Description

Problem

FlashMLA currently assumes decode operations have query_len = 1, forcing multi-token prediction (MTP) speculation requests onto the memory-inefficient prefill path.

Current Behavior

Normal Decode (works great! ✅):
┌─────────────┐
│ Request A   │ query_len = 1
│ Request B   │ query_len = 1  ──> FlashMLA Decode Path
│ Request C   │ query_len = 1      (Memory Efficient)
│ Request D   │ query_len = 1
└─────────────┘

MTP with Speculation (problem! ❌):
┌─────────────┐
│ Request A   │ query_len = 4
│ Request B   │ query_len = 4  ──> Forced to Prefill Path
│ Request C   │ query_len = 4      (Memory-inefficient; compute-optimized)
│ Request D   │ query_len = 4
└─────────────┘
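For illustration, today's routing rule boils down to a single check. This is a hedged sketch, not vLLM's actual dispatch code; `routed_to_decode` is a made-up name:

```python
# Illustrative only: FlashMLA's decode path currently requires
# query_len == 1; anything larger falls back to the prefill path.
def routed_to_decode(query_len: int) -> bool:
    return query_len == 1

assert routed_to_decode(1)       # normal decode       -> decode path
assert not routed_to_decode(4)   # MTP, 3 spec tokens  -> prefill path
```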

Options:

  1. varlen decode kernels; examples:
    FlashAttn MLA: [WIP][Attention] FlashAttn MLA #14258
    FlashInfer MLA: [WIP][Kernel] Flashinfer MLA support #13630
    Benefits: clean implementation
    Downside: can't use SOTA kernels like FlashMLA and the upcoming TrtLLM-MLA

  2. Padded speculation (see below)
    Benefits: can use SOTA kernels like FlashMLA and the upcoming TrtLLM-MLA
    Downside: more complicated implementation

Padded speculation proposal

1. Smart Decode Classification

Increase the decode threshold and classify by the minimum qualifying query_len; this ensures the decode path only serves small query_lens and uniform batches (a code sketch follows the examples below).

Before: decode_threshold = 1 (only single tokens)
After:  decode_threshold = 4 (support up to 3 speculative tokens)

Classification Logic:
─────────────────────────────────────
Batch: [2, 2, 2, 2] 
Min qualifying len = 2  ──> ALL decode

Batch: [1, 2, 2, 1]
Min qualifying len = 1  ──> Only [1,_,_,1] are decode
                            (FlashMLA needs uniform batches)
─────────────────────────────────────
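A minimal sketch of this classification, assuming plain Python lists of per-request query lens (`decode_mask` is an illustrative name, not vLLM's actual metadata-builder API):

```python
def decode_mask(query_lens: list[int], decode_threshold: int = 4) -> list[bool]:
    """Mark requests eligible for the FlashMLA decode path.

    Among requests at or below the threshold, only those matching the
    minimum qualifying query_len join the decode batch (FlashMLA needs
    a uniform batch); everything else takes the prefill path.
    """
    qualifying = [q for q in query_lens if q <= decode_threshold]
    if not qualifying:
        return [False] * len(query_lens)
    min_len = min(qualifying)
    return [q == min_len for q in query_lens]

assert decode_mask([2, 2, 2, 2]) == [True, True, True, True]    # all decode
assert decode_mask([1, 2, 2, 1]) == [True, False, False, True]  # only the 1s
```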

2. Handle Token Rejection Without Breaking Uniformity

Instead of removing rejected tokens (which creates non-uniform batches), adjust the sampling position:

Token Rejection Example (3 speculative tokens):
═══════════════════════════════════════════════════════════════

Initial: Each request has 1 actual + 3 speculative tokens
Model Output:
┌─────────────────────────┐
│ Req 0: [A₀][S₁][S₂][S₃] │ reject 2 
│ Req 1: [A₀][S₁][S₂][S₃] │ reject 0
│ Req 2: [A₀][S₁][S₂][S₃] │ reject 1
│ Req 3: [A₀][S₁][S₂][S₃] │ reject 3
└─────────────────────────┘
Traditional Approach (❌ Non-uniform), passed to the speculator after trimming rejected tokens:
┌─────────────────────────┐
│ Req 0: [A₀][S₁]         │ len=2
│ Req 1: [A₀][S₁][S₂][S₃] │ len=4 
│ Req 2: [A₀][S₁][S₂]     │ len=3 
│ Req 3: [A₀]             │ len=1 
└─────────────────────────┘

Speculator output (X = unused positions):
┌─────────────────────────┐
│ Req 0: [X][N₁]          │ len=2
│ Req 1: [X][X][X][N₃]    │ len=4 
│ Req 2: [X][X][N₂]       │ len=3 
│ Req 3: [N₀]             │ len=1 
└─────────────────────────┘

Drafted tokens: [N₁, N₃, N₂, N₀]

Padded Approach (✅ Uniform):
Passed to speculator (the model output can be passed through directly, since the rejected tokens serve as padding):
┌─────────────────────────┐
│ Req 0: [A₀][S₁][X ][X ] │ len=4 (2 padding since we rejected 2)
│ Req 1: [A₀][S₁][S₂][S₃] │ len=4 (0 padding since we rejected 0)
│ Req 2: [A₀][S₁][S₂][X ] │ len=4 (1 padding since we rejected 1)
│ Req 3: [A₀][X ][X ][X ] │ len=4 (3 padding since we rejected 3)
└─────────────────────────┘

Speculator output (X = unused positions):
┌─────────────────────────┐
│ Req 0: [X][N₁][X][X]    │ sample from position 1 (since we rejected 2)
│ Req 1: [X][X][X][N₃]    │ sample from position 3 (since we rejected 0)
│ Req 2: [X][X][N₂][X]    │ sample from position 2 (since we rejected 1)
│ Req 3: [N₀][X][X][X]    │ sample from position 0 (since we rejected 3)
└─────────────────────────┘

Drafted tokens: [N₁, N₃, N₂, N₀]

Every request stays at len=4, so the batch remains uniform!
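A small sketch of the bookkeeping, assuming PyTorch tensors and the four example requests above (token ids are made up; the key relation from the diagrams is sample position = num_spec - num_rejected):

```python
import torch

NUM_SPEC = 3  # speculative tokens per request (decode_threshold - 1)

# Rejections observed for the four example requests above.
num_rejected = torch.tensor([2, 0, 1, 3])

# Keep the batch uniform: rejected slots stay in place as padding, and
# each request just records which position its next draft comes from.
sample_pos = NUM_SPEC - num_rejected  # tensor([1, 3, 2, 0])

# Speculator output on the padded (uniform, len=4) batch; ids are fake.
spec_out = torch.tensor([
    [11, 12, 13, 14],  # Req 0: N1 lives at position 1
    [21, 22, 23, 24],  # Req 1: N3 lives at position 3
    [31, 32, 33, 34],  # Req 2: N2 lives at position 2
    [41, 42, 43, 44],  # Req 3: N0 lives at position 0
])
drafted = spec_out.gather(1, sample_pos.unsqueeze(1)).squeeze(1)
assert drafted.tolist() == [12, 24, 33, 41]  # [N1, N3, N2, N0]
```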

cc @benchislett @WoosukKwon @mgoin
