|
| 1 | +--- |
| 2 | +title: "PD-Multiplexing: Unlocking High-Goodput LLM Serving with GreenContext" |
| 3 | +author: "Weihao Cui, Yukang Chen, Xiaoze Fan, Han Zhao, Ziyi Xu, Xusheng Chen, Bingsheng He, Quan Chen" |
| 4 | +date: "September 28, 2025" |
| 5 | +previewImg: /images/blog/pdmux/logo.png |
| 6 | +--- |
| 7 | + |
| 8 | +This post highlights our initial efforts to support **a new serving paradigm, PD-Multiplexing, in** **SGLang.** It is designed to deliver higher goodput in LLM serving. PD-Multiplexing leverages [**GreenContext**](https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GREEN__CONTEXTS.html), a new NVIDIA GPU capability that allows lightweight and fine-grained partitioning of GPU resources across tasks within the same process. We envision this paradigm as a promising new approach to LLM service deployment, delivering stronger SLO guarantees and higher goodput for Model-as-a-Service (MaaS). |
| 9 | + |
| 10 | +## Goodput in LLM Serving: A Persistent Challenge |
| 11 | + |
| 12 | +Delivering MaaS at scale demands that LLM serving systems consistently meet stringent Service Level Objectives (SLOs) without sacrificing throughput. In practice, this translates into satisfying the well-established latency SLOs for both stages of inference: Time-to-First-Token (TTFT) during the prefill phase, and Inter-Token Latency (ITL)—also referred to as Time-Between-Tokens (TBT)—during the decode phase. The challenge arises because prefill and decode interleave on the same serving instance, creating contention for GPU resources. Two common approaches have emerged to enforce SLO compliance: |
| 13 | +1. **Instance-level PD-disaggregation** – separating prefill and decode into different instances. However, this comes with trade-offs: GPU resources are statically partitioned, and **KV cache migration** between instances introduces additional complexity, requiring high-performance interconnects and communication libraries. |
| 14 | +2. **Sequence-level chunked-prefill** – splitting long sequences into smaller chunks and fusing each chunk with a decode iteration to control ITL. This too introduces a **delicate trade-off**: the chunk size must strike a balance between tight ITL guarantees and high GPU utilization. |
| 15 | + |
| 16 | +In particular, when targeting a tight SLO threshold for practical LLM services, the shortcomings of both disaggregation and chunking become increasingly pronounced. |
| 17 | + |
| 18 | +## PD-Multiplexing: A New Serving Paradigm |
| 19 | + |
| 20 | +<img src="/images/blog/pdmux/1-overview.png" style="display:block; margin-top: auto; margin-left: auto; margin-right: auto; margin-bottom: auto; width: 80%; image-orientation: none;"></img> |
| 21 | +<div style="text-align:center"><strong>Figure 1. Overview of PD-Multiplexing</strong></div> |
| 22 | + |
| 23 | +To this end, we propose a new serving paradigm, **PD-Multiplexing**, for achieving higher goodput. It multiplexes the prefill and decode phases within the same instance through intra-GPU spatial sharing. It offers several important benefits: |
| 24 | +1. Prefill and decode share a common KV cache pool within the same instance, **removing the need for costly cross-instance migration.** |
| 25 | +2. Intra-GPU spatial sharing allows GPU compute resources, SMs, to **flow dynamically** between prefill and decode as workloads vary. |
| 26 | +3. This sharing also **decouples the execution** of prefill and decode, ensuring that prefill performance is not compromised when meeting stringent ITL SLOs. |
| 27 | + |
| 28 | +Figure 1 presents an overview of PD-Multiplexing, which consists of two core modules: a bubble-less multiplex engine that independently and efficiently executes prefill and decode phases, and an SLO-aware dispatcher that iteratively generates multiplexing plans compliant with SLOs. |
| 29 | + |
| 30 | +### Realizing Bubble-less Multiplex Engine with GreenContext |
| 31 | + |
| 32 | +We built the new paradigm on top of **GreenContext**, a capability introduced in NVIDIA GPUs starting with CUDA 12.4. GreenContext enables lightweight **intra-process** spatial sharing. Briefly, we can create multiple CUDA streams with dedicated SM allocations for concurrent GPU kernels since CUDA 12.6. With GreenContext, GPU resources can be dynamically partitioned between the prefill and decode phases, adapting to SLO requirements, workload patterns, and other serving needs in real time. |
| 33 | + |
| 34 | +To preserve the existing serving architecture, we adopt single-thread scheduling for multiplexing prefill and decode, rather than creating separate threads for each. This choice is also motivated by the fact that Python’s Global Interpreter Lock (GIL) still prevents true parallel execution, and will remain the default for upcoming versions. Fortunately, both prefill and decode are asynchronous, which makes this design feasible. By switching the dispatch between prefill and decode using their dedicated GreenContext streams, we can enable multiplexing. |
| 35 | + |
| 36 | +<img src="/images/blog/pdmux/2-mux-engine.png" style="display:block; margin-top: auto; margin-left: auto; margin-right: auto; margin-bottom: auto; width: 100%; image-orientation: none;"></img> |
| 37 | +<div style="text-align:center"><strong>Figure 2. Removing bubbles for efficient prefill-decode multiplexing</strong></div> |
| 38 | + |
| 39 | +However, such an integration of GreenContext introduces GPU bubbles. As illustrated in Figure 2(a), these bubbles arise for two reasons: (1) The launch time of prefill phases is significantly longer than that of decode phases (which involve only a single CUDA graph). In some cases, launching a prefill phase takes longer than executing an entire decode phase, leaving GPU resources idle. (2) The number of iterations in the decode phase is non-deterministic. When all requests in a decode batch finish early, the pre-allocated SMs may remain underutilized if a prefill has already been launched. |
| 40 | + |
| 41 | +To address this, we split the prefill phase into smaller prefill blocks, as shown in Figure 2(b). Since prefill is typically far more compute-intensive than decode, this block-level splitting incurs negligible overhead while effectively eliminating GPU bubbles during multiplexing. |
| 42 | + |
| 43 | +### Profiling and Crafting Scheduling Policies |
| 44 | + |
| 45 | +With the bubble-less multiplex engine in place, the next challenge is scheduling prefill blocks and decode batches. Offline profiling shows that the two phases compete for resources under GreenContext. The root cause is that while GreenContext partitions SMs, it does not partition memory bandwidth, making contention difficult to model. To address this, we profile representative workloads offline and use the results to train a latency predictor that drives our SLO-aware scheduling policies. Since the modeling depends on the specific model and hardware environment, we omit the details here but will provide a detailed tutorial with practical, step-by-step guidance in the future. |
| 46 | + |
| 47 | +The intuition behind the scheduling policy is simple: **allocate just enough SMs to the decode phase to guarantee the ITL SLO, then dedicate all remaining SMs to prefill. At the same time, we determine the number of prefill blocks to launch.** This way, decode always runs under strict SLO guarantees, while prefill proceeds as quickly as possible to enlarge the decode batch size. |
| 48 | + |
| 49 | + |
| 50 | +## Benchmark |
| 51 | + |
| 52 | +In summary, we evaluate PD-Multiplexing against multiple baselines across a range of workloads and devices. We first present an experiment that is easy to reproduce, then demonstrate the advantages of PD-Multiplexing using real-world traces and diverse tasks. Finally, we provide a zoomed-out visualization of runtime scheduling details. In our extensive evaluations, PD-Multiplexing improves goodput by up to 3.06x over state-of-the-art baselines. |
| 53 | + |
| 54 | +<small>* The following results are presented for research purposes. In real-world applications, SLO requirements are often more specific. Here, we use these benchmarks to illustrate the potential of PD-Multiplexing. </small> |
| 55 | + |
| 56 | +### Comparison with Chunked-prefill under Varying Chunk Sizes |
| 57 | + |
| 58 | +<img src="/images/blog/pdmux/3-H200.png" style="display:block; margin-top: auto; margin-left: auto; margin-right: auto; margin-bottom: auto; width: 100%; image-orientation: none;"></img> |
| 59 | +<div style="text-align:center"><strong>Figure 3. Results of ShareGPT and LooGLE on a single H200 with CodeLlama-34b-hf</strong></div> |
| 60 | + |
| 61 | +We first evaluate PD-Multiplexing against chunked-prefill with varying chunk sizes on a single H200 GPU running CodeLlama-34b-hf. Figure 3 reports the 99th-percentile TTFT and ITL. We set the SLO target of ITL to 60 ms. We do not impose SLO constraints on TTFT. Instead, we report the P99 of TTFT to demonstrate the efficiency of PD-Multiplexing. In the figure, solid points indicate that the corresponding baseline meets the ITL SLO requirement, while empty points indicate that the baseline violates it. |
| 62 | + |
| 63 | +The results highlight a clear advantage: PD-Multiplexing delivers the fastest TTFT while consistently meeting the stringent SLO target for ITL. In contrast, Chunked-prefill often must reduce the chunk size below 1024 to satisfy such a strict ITL requirement, which degrades prefill performance and leaves GPU resources underutilized. This benefit becomes even more pronounced for long-context workloads such as LooGLE, where the inefficiency of chunking is magnified. Reproduction instructions are available [here.](https://github.com/sgl-project/sglang/pull/10692) |
| 64 | + |
| 65 | +### Results on Real-world Traces |
| 66 | + |
| 67 | +We have also evaluated PD-Multiplexing with real-world trace, Mooncake-Tool&Agent. We compared it with chunked-prefill and PD-disaggregation. All are based on the codebase of SGLang for fair comparison. This experiment is conducted on a server with 8 A100s and the chunk size for chunked-prefill is 512. The ratio of P:D in disaggregation is 1:1. Prefix cache sharing is enabled. |
| 68 | + |
| 69 | + |
| 70 | +<img src="/images/blog/pdmux/4-mooncake.png" style="display:block; margin-top: auto; margin-left: auto; margin-right: auto; margin-bottom: auto; width: 80%; image-orientation: none;"></img> |
| 71 | +<div style="text-align:center"><strong>Figure 4. Results of Mooncake-Tool&Agent on 8xA100s with Llama3.1-70B</strong></div> |
| 72 | + |
| 73 | + |
| 74 | +Figure 4(a) presents the TTFT and ITL results. Compared with chunked prefill, PD-Multiplexing improves both metrics. Relative to PD-disaggregation, it achieves noticeably shorter TTFT, while both methods meet the SLO for decode phases. To assess its impact on goodput, we gradually increase the request rate and measure SLO attainment. As shown in Figure 5, PD-Multiplexing sustains significantly higher goodput, delivering up to 3.06× and 1.62× improvements over chunked prefill and PD-disaggregation, respectively. |
| 75 | + |
| 76 | +### Results on Diverse Tasks with Scheduling Visualization |
| 77 | + |
| 78 | +We further evaluate PD-Multiplexing on three representative tasks: **OpenThoughts**, **ShareGPT**, and **LooGLE**. These tasks exhibit contrasting workload patterns: OpenThoughts features the shortest prefill input with the longest decode output; ShareGPT involves longer prefill input but shorter decode output; and LooGLE stresses the system with the longest prefill input and the shortest decode output. |
| 79 | + |
| 80 | +<img src="/images/blog/pdmux/5-open-share-loogle.png" style="display:block; margin-top: auto; margin-left: auto; margin-right: auto; margin-bottom: auto; width: 80%; image-orientation: none;"></img> |
| 81 | +<div style="text-align:center"><strong>Figure 5. Results of OpenThoughts, ShareGPT, and LooGLE on 8xA100s with Llama3.1-70B</strong></div> |
| 82 | + |
| 83 | + |
| 84 | +<img src="/images/blog/pdmux/6-visualization.png" style="display:block; margin-top: auto; margin-left: auto; margin-right: auto; margin-bottom: auto; width: 80%; image-orientation: none;"></img> |
| 85 | +<div style="text-align:center"><strong>Figure 6. Scheduling visualization of OpenThoughts, ShareGPT, and LooGLE on 8xA100s with Llama3.1-70B</strong></div> |
| 86 | + |
| 87 | + |
| 88 | +Figure 5 reports the results across these tasks. PD-Multiplexing consistently maintains strong performance and achieves significantly higher goodput than the baselines. To illustrate how this is realized in practice, Figure 6 presents a zoomed-out runtime timeline of scheduling decisions. As shown, PD-Multiplexing dynamically adapts resource allocation: for OpenThoughts, it assigns minimal resources to prefill, while for LooGLE, it minimizes resources for decode. The timeline further demonstrates how the scheduler seamlessly switches between different SM allocation plans as workloads vary, ensuring both SLO compliance and high efficiency. |
| 89 | + |
| 90 | +## Future Work |
| 91 | + |
| 92 | +PD-Multiplexing shows strong promise for MaaS deployments. We will focus our next steps on the following areas: |
| 93 | + |
| 94 | +* **MTP and Speculative Decoding Support.** |
| 95 | +Models with an MTP mechanism, such as DeepSeek-V3, decode multiple tokens simultaneously, requiring adjustments to the scheduling policy for SLO guarantees in PD-Multiplexing. In addition, speculative decoding approaches that employ more complex verification methods across multiple collaborating LLMs call for a different resource-partitioning strategy. We will extend PD-Multiplexing with policies tailored to them and report the resulting end-to-end gains. |
| 96 | +* **Toward More Realistic Industrial SLOs and Fine-grained Parameter Control.** High-goodput serving with PD-Multiplexing requires a one-time profiling pass for the target model and hardware. A brief reference document is available [here](https://github.com/ykcombat/sglang/blob/e84aa1bdd055df93e603a46fa6ca5e60afd213f5/docs/advanced_features/pd_multiplexing.md). We will also release a detailed, hands-on tutorial to help users reproduce our profiling workflow and design effective scheduling policies tailored to realistic industrial scenarios. |
| 97 | +* **Integration with PD-Disaggregation.** While we compare PD-Multiplexing with PD-disaggregation in the above evaluation, the two approaches are not mutually exclusive—they can be integrated to further boost goodput. For example, since a decode instance often does not sustain its maximum batch size under varying workloads, it can be replaced with a PD-Multiplexing instance. This allows the system to harvest otherwise wasted resources by overlapping the decode phase with multiplexed prefills. |
| 98 | +* **Compatibility with PyTorch \> 2.6.** All results are currently reproduced with PyTorch 2.6. For newer versions, we have encountered NVIDIA-internal issues when combining CUDA Graphs, and NCCL. In particular, launching distributed CUDA Graphs incurs significant overhead when PyTorch \> 2.6. We are working closely with upstream developers to address this bug and enable smooth compatibility with future PyTorch releases. |
| 99 | +* **MoE Model Support.** Our preliminary experiments on Qwen-235B indicate that PD-Multiplexing continues to deliver consistent improvements. We plan to release comprehensive results on larger MoE models, such as DeepSeek-V3, evaluated in larger, distributed environments. |
| 100 | + |
| 101 | +We have a [proof-of-concept implementation](https://github.com/sgl-project/sglang/pull/10692) of PD-Multiplexing, and the [roadmap](https://github.com/sgl-project/sglang/issues/10813) for full integration into SGLang is underway. |
| 102 | + |
| 103 | +## Acknowledgement |
| 104 | + |
| 105 | +- We would like to thank the SGLang team and community for their generous support, especially Liangsheng Yin, Yichuan Wang, Lianmin Zheng, and many others. |
| 106 | +- We would like to thank Yi Pan for the insight discussions during the early stages of this work. |
0 commit comments