Description
Problem
FlashMLA currently assumes decode operations have query_len=1, forcing multi-token speculation requests to use the inefficient prefill path.
Current Behavior
Normal Decode (works great! ✅):
┌─────────────┐
│ Request A │ query_len = 1
│ Request B │ query_len = 1 ──> FlashMLA Decode Path
│ Request C │ query_len = 1 (Memory Efficient)
│ Request D │ query_len = 1
└─────────────┘
MTP with Speculation (problem! ❌):
┌─────────────┐
│ Request A │ query_len = 4
│ Request B │ query_len = 4 ──> Forced to Prefill Path
│ Request C │ query_len = 4 (Compute-optimized, but memory inefficient!)
│ Request D │ query_len = 4
└─────────────┘
Options:
- Varlen decode kernels; examples:
  - FlashAttn MLA: [WIP][Attention] FlashAttn MLA #14258
  - FlashInfer MLA: [WIP][Kernel] Flashinfer MLA support #13630
  - Benefits: clean implementation
  - Downside: can't use SOTA kernels like FlashMLA and the upcoming TrtLLM-MLA
- Padded speculation (see below)
  - Benefits: can use SOTA kernels like FlashMLA and the upcoming TrtLLM-MLA
  - Downside: more complicated implementation
Padded speculation proposal
1. Smart Decode Classification
Increase the decode threshold and use the minimum query_len for uniform batches; this ensures we only use the decode path for small query_lens and uniform batches (a sketch follows the diagram below).
Before: decode_threshold = 1 (only single tokens)
After: decode_threshold = 4 (support up to 3 speculative tokens)
Classification Logic:
─────────────────────────────────────
Batch: [2, 2, 2, 2]
Min qualifying len = 2 ──> ALL decode
Batch: [1, 2, 2, 1]
Min qualifying len = 1 ──> Only [1,_,_,1] are decode
(FlashMLA needs uniform batches)
─────────────────────────────────────
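A minimal sketch of this classification idea (illustrative only; `query_lens`, `decode_threshold`, and the function name are assumptions, not vLLM's actual scheduler interface):

```python
# Minimal sketch of the classification idea; not vLLM's actual scheduler code.
def split_decode_prefill(query_lens: list[int], decode_threshold: int = 4):
    """Return (decode_indices, prefill_indices). The decode set is uniform in
    query_len and every qualifying length is <= decode_threshold."""
    qualifying = [q for q in query_lens if q <= decode_threshold]
    if not qualifying:
        return [], list(range(len(query_lens)))

    # FlashMLA's decode path wants a uniform batch, so only requests whose
    # query_len equals the minimum qualifying length take the decode path.
    min_qualifying = min(qualifying)
    decode = [i for i, q in enumerate(query_lens) if q == min_qualifying]
    prefill = [i for i, q in enumerate(query_lens) if q != min_qualifying]
    return decode, prefill

# Examples from the diagram above:
split_decode_prefill([2, 2, 2, 2])  # -> ([0, 1, 2, 3], [])   ALL decode
split_decode_prefill([1, 2, 2, 1])  # -> ([0, 3], [1, 2])     only [1,_,_,1] are decode
```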
2. Handle Token Rejection Without Breaking Uniformity
Instead of removing rejected tokens (which creates non-uniform batches), adjust the sampling position (a sketch follows at the end of this section):
Token Rejection Example (3 speculative tokens):
═══════════════════════════════════════════════════════════════
Initial: Each request has 1 actual + 3 speculative tokens
Model Output:
┌─────────────────────────┐
│ Req 0: [A₀][S₁][S₂][S₃] │ reject 2
│ Req 1: [A₀][S₁][S₂][S₃] │ reject 0
│ Req 2: [A₀][S₁][S₂][S₃] │ reject 1
│ Req 3: [A₀][S₁][S₂][S₃] │ reject 3
└─────────────────────────┘
Passed to speculator
Traditional Approach (❌ Non-uniform):
┌─────────────────────────┐
│ Req 0: [A₀][S₁] │ len=2
│ Req 1: [A₀][S₁][S₂][S₃] │ len=4
│ Req 2: [A₀][S₁][S₂] │ len=3
│ Req 3: [A₀] │ len=1
└─────────────────────────┘
Speculator output:
┌─────────────────────────┐
│ Req 0: [X][N₁] │ len=2
│ Req 1: [X][X][X][N₃] │ len=4
│ Req 2: [X][X][N₂] │ len=3
│ Req 3: [N₀] │ len=1
└─────────────────────────┘
Drafted tokens: [N₁, N₃, N₂, N₀]
Padded Approach (✅ Uniform):
Passed to speculator (we can pass the model output directly since we pad with the outputted tokens):
┌─────────────────────────┐
│ Req 0: [A₀][S₁][X ][X ] │ len=4 (2 padding since we rejected 2)
│ Req 1: [A₀][S₁][S₂][S₃] │ len=4 (0 padding since we rejected 0)
│ Req 2: [A₀][S₁][S₂][X ] │ len=4 (1 padding since we rejected 1)
│ Req 3: [A₀][X ][X ][X ] │ len=4 (3 padding since we rejected 3)
└─────────────────────────┘
Speculator output:
┌─────────────────────────┐
│ Req 0: [X][N₁][X][X] │ sample from position 1 (since we rejected 2)
│ Req 1: [X][X][X][N₃] │ sample from position 3 (since we rejected 0)
│ Req 2: [X][X][N₂][X] │ sample from position 2 (since we rejected 1)
│ Req 3: [N₀][X][X][X] │ sample from position 0 (since we rejected 3)
└─────────────────────────┘
Drafted tokens: [N₁, N₃, N₂, N₀]
All batches stay len=4!
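A minimal sketch of the position-shifted sampling, assuming `num_rejected` comes from the rejection sampler and `speculator_logits` holds the speculator's per-position logits (both names are illustrative, not vLLM's actual API):

```python
import torch

# Illustrative sketch; tensor names are assumptions, not vLLM's actual API.
num_spec_tokens = 3                        # speculative tokens per request
num_rejected = torch.tensor([2, 0, 1, 3])  # per-request, from the rejection sampler

# Query lengths stay uniform at 1 + num_spec_tokens; instead of trimming the
# rejected tokens we shift each request's sampling position left.
sample_pos = num_spec_tokens - num_rejected            # -> [1, 3, 2, 0]

# speculator_logits: [batch, 1 + num_spec_tokens, vocab_size]
batch_size, vocab_size = num_rejected.numel(), 32_000
speculator_logits = torch.randn(batch_size, 1 + num_spec_tokens, vocab_size)

# Gather logits at each request's sampling position and draft the next tokens.
gathered = speculator_logits[torch.arange(batch_size), sample_pos]  # [batch, vocab]
drafted_tokens = gathered.argmax(dim=-1)                            # [N₁, N₃, N₂, N₀]
```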