[Feature] [RFC] Add a step-disaggregated draft model for speculative decoding, PRISM #444

@Akemiiii

Motivation

TL;DR

We implemented a new speculative decoding method, PRISM, built on EAGLE and SGLang. It offers better scaling performance and lower training cost, which could make speculative decoding more scalable and more widely available.

What we want from maintainers (this issue):

  • Does PRISM fit the SpecForge framework?
  • What is the best way to integrate PRISM into SpecForge?
  • What correctness tests + performance benchmarks are required before a PR can be merged?
  • Could you provide resources or help to validate the scaling performance of PRISM with larger datasets (more than 800K samples)?

As mentioned in this LMSYS blog and the Meta tech report, large training corpora and multi-layer draft models provide better end-to-end performance than the original EAGLE-3 thanks to higher speculation accuracy. SGLang has also implemented a multi-layer EAGLE worker. However, these changes can increase both draft-model inference overhead and training cost.

PRISM

Overview

PRISM disaggregates the computation of each predictive step across different parameter sets, refactoring the computational pathways of draft models to decouple model capacity from inference cost. This yields exceptional acceptance lengths while keeping draft latency minimal, for superior end-to-end speedup. We also re-examine the scaling laws of draft models and find that PRISM scales more effectively with growing data volumes than other draft architectures (EAGLE-2, HASS, EAGLE-3). Moreover, PRISM's training cost is much lower than EAGLE-3's.

We will present more results in the forthcoming paper.

Architecture

We refactor the inference computational path by assigning the computation of each draft step to a distinct parameter set, analogous to how a prism disperses white light into its spectrum. Consequently, the drafter's total parameter count grows while the number of parameters activated per draft step remains constant.
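To make the routing idea concrete, here is a minimal PyTorch sketch assuming a generic transformer block; the class and parameter names (`PrismDrafter`, `num_step_blocks`) are hypothetical and this is not the actual SpecForge implementation:

```python
import torch
import torch.nn as nn

class PrismDrafter(nn.Module):
    """Sketch of a step-disaggregated drafter: one parameter set per draft step."""

    def __init__(self, hidden_size: int, num_heads: int,
                 num_step_blocks: int, vocab_size: int):
        super().__init__()
        # Total parameters grow with num_step_blocks, but each draft step
        # activates exactly one block, so per-step compute stays constant.
        self.step_blocks = nn.ModuleList(
            nn.TransformerEncoderLayer(
                d_model=hidden_size, nhead=num_heads, batch_first=True
            )
            for _ in range(num_step_blocks)
        )
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden: torch.Tensor, step: int) -> torch.Tensor:
        # Route through the parameter set dedicated to this draft step;
        # later steps reuse the last block -- our reading of "train three
        # steps to speculate six" below (an assumption).
        block = self.step_blocks[min(step, len(self.step_blocks) - 1)]
        return self.lm_head(block(hidden))  # logits for the token drafted at `step`
```

Unrolling the draft loop then simply calls `forward` with step = 0, 1, ..., so the drafter's capacity grows with the number of blocks while per-step latency stays flat.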

Performance

Acceptance Length

For clarity, we report only the LLaMA-3-8B-Instruct (temperature = 0, step = 6, topk = 10, num-draft-token = 60, 800K training samples) results here:

| Method | MT-Bench | HumanEval | GSM8K | Alpaca | CNN/DM | Natural Ques. |
|---|---|---|---|---|---|---|
| EAGLE-2 | 4.60 | 5.37 | 5.14 | 4.65 | 4.21 | 3.88 |
| HASS | 5.23 | 6.20 | 6.00 | 5.40 | 5.09 | 4.42 |
| PRISM w/o 3 hidden states (EAGLE-3-style) | 5.29 | 6.28 | 6.05 | 5.58 | 5.01 | 4.53 |
| EAGLE-3 | 5.41 | 6.25 | 6.08 | 5.64 | 4.83 | 4.50 |
| PRISM | 5.55 | 6.46 | 6.24 | 5.72 | 5.42 | 4.72 |

SGLang E2E Performance

We do not compare PRISM against EAGLE-3 here yet, because EAGLE-3 uses a smaller LM head; we are training a PRISM variant with a small LM head to align with EAGLE-3. Theoretically, though, PRISM's forward overhead should let it outperform EAGLE-3 end to end, since PRISM has fewer parameters and a similar transformer architecture.

LLaMA-3-8B-Instruct (temperature = 0, step = 6, topk = 4, num-draft-token = 16, 800K training samples, NVIDIA A800 80GB) results:

| Method (tokens/s) | MT-Bench | HumanEval | GSM8K | Alpaca | CNN/DM | Natural Ques. |
|---|---|---|---|---|---|---|
| Standalone | 164.13 | 180.66 | 173.28 | 144.51 | 143.18 | 136.81 |
| EAGLE-2 | 168.26 | 212.36 | 197.66 | 171.89 | 163.34 | 147.74 |
| HASS | 180.50 | 235.38 | 216.75 | 184.39 | 176.16 | 156.67 |
| PRISM | 213.61 | 269.14 | 254.09 | 209.06 | 205.62 | 182.33 |

Scaling Performance

PRISM demonstrates superior scaling compared to EAGLE-2, HASS, and EAGLE-3. The figure below shows the acceptance length of PRISM, EAGLE-2, HASS, and EAGLE-3 with varying numbers of training samples (step = 4, topk = 4, num-draft-token = 16; datasets: ShareGPT + UltraChat + OpenThoughts2).

[Figure: acceptance length vs. number of training samples for PRISM, EAGLE-2, HASS, and EAGLE-3]

Training Cost

We leverage a novel two-phase training scheme: we first run a warm-up stage that trains a single module shared by every draft step; we then replicate its weights and train a distinct module for each draft step. Compared to EAGLE-3, the training cost of PRISM is reduced by ~50% (offline) or ~40% (online), because we only need to train three steps to speculate six steps.
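Below is a hedged sketch of the hand-off between the two phases, assuming a PyTorch drafter like the one sketched above; the helper name is hypothetical and the real SpecForge training loop will differ:

```python
import copy
import torch.nn as nn

def replicate_warmed_up_block(shared_block: nn.Module,
                              num_step_blocks: int) -> nn.ModuleList:
    """Phase-2 initialization (hypothetical helper): copy the phase-1
    shared weights into per-step blocks that are then trained independently."""
    return nn.ModuleList(copy.deepcopy(shared_block)
                         for _ in range(num_step_blocks))

# Phase 1: train a single `shared_block`, tied across every draft step.
# Phase 2: step_blocks = replicate_warmed_up_block(shared_block, 3)
#          and keep training so each block specializes to its draft step.
```

As we read the description above, starting every per-step block from the same warmed-up weights is what allows training only three step modules while speculating six steps.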

