[Feature] Add UVM-based MoE expert offloading with all-GPU compute #20126
lichang98 wants to merge 2 commits into sgl-project:main
Conversation
Implement CUDA Unified Memory (cudaMallocManaged) based expert offloading as an alternative to KTransformers. Resident experts use the PREFER_GPU advice (VRAM bandwidth); offloaded experts use PREFER_CPU + ACCESSED_BY_GPU (PCIe read-through without page faults). No ID remapping, no LRU cache, no assembly buffer: UVM is fully transparent to CUDA graphs.

Profile analysis on GLM-5-FP8 (TP=8, 200/256 resident, EAGLE spec decode) shows ~8.7% decode overhead from UVM PCIe reads and zero CPU hot-path cost during CUDA graph replay. Speculative prefetch is not yet active during decode and needs follow-up work.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
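For reference, the PREFER_GPU vs PREFER_CPU + ACCESSED_BY_GPU split described above maps onto the CUDA runtime roughly as follows. This is an illustrative sketch only, not code from this PR: it needs a CUDA device and `libcudart` to run, the helper name `allocate_expert` is hypothetical, and the hard-coded enum values are assumptions taken from `cuda_runtime_api.h` (verify against your toolkit version).

```python
import ctypes

# cudaMemoryAdvise enum values and special device IDs (assumed, from
# cuda_runtime_api.h; check your CUDA toolkit version).
CUDA_MEM_ADVISE_SET_PREFERRED_LOCATION = 3
CUDA_MEM_ADVISE_SET_ACCESSED_BY = 5
CUDA_CPU_DEVICE_ID = -1
CUDA_MEM_ATTACH_GLOBAL = 1

cudart = ctypes.CDLL("libcudart.so")


def allocate_expert(nbytes: int, device_id: int, is_resident: bool) -> int:
    """Allocate one expert's weights in managed (UVM) memory and advise it."""
    ptr = ctypes.c_void_p()
    cudart.cudaMallocManaged(
        ctypes.byref(ptr), ctypes.c_size_t(nbytes), CUDA_MEM_ATTACH_GLOBAL
    )
    if is_resident:
        # Resident expert: prefer GPU placement, i.e. VRAM-bandwidth access.
        cudart.cudaMemAdvise(
            ptr, ctypes.c_size_t(nbytes),
            CUDA_MEM_ADVISE_SET_PREFERRED_LOCATION, device_id,
        )
    else:
        # Offloaded expert: keep pages in host memory, but map them into the
        # GPU's page tables so kernels read through PCIe without page faults.
        cudart.cudaMemAdvise(
            ptr, ctypes.c_size_t(nbytes),
            CUDA_MEM_ADVISE_SET_PREFERRED_LOCATION, CUDA_CPU_DEVICE_ID,
        )
        cudart.cudaMemAdvise(
            ptr, ctypes.c_size_t(nbytes),
            CUDA_MEM_ADVISE_SET_ACCESSED_BY, device_id,
        )
    return ptr.value
```

Because the advice is attached to the allocation rather than the access, kernels (and captured CUDA graphs) need no awareness of which experts are resident, which is what makes the scheme transparent to graph replay.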
… warmup

During early eager forward passes (prefill), collect per-layer expert routing frequencies. After accumulating enough tokens, recompute the optimal resident set per layer and call cudaMemAdvise to promote/demote experts. This replaces the static first_n selection when --expert-offload-resident-selection frequency is used.

- Add warmup_tokens config field (default 4096)
- Add record_expert_usage() with CUDA graph capture guard
- Filter out small batches (< 64 tokens) to ignore dummy warmup data
- Wire frequency tracking into FusedMoE.forward_impl()

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
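The warmup flow above can be sketched in plain Python. This is a simplified illustration, not the PR's implementation: the class name `ExpertUsageTracker` and its method names are hypothetical, and the real code hooks into `FusedMoE.forward_impl()` and then issues `cudaMemAdvise` calls to apply the new resident set.

```python
from collections import Counter
from typing import Dict, List


class ExpertUsageTracker:
    """Illustrative sketch of frequency-warmup resident selection.

    Per layer, count how often each expert is routed to during eager
    (prefill) passes. Once `warmup_tokens` routed slots have been seen,
    the top `num_resident` experts per layer become the resident set.
    """

    def __init__(self, num_resident: int, warmup_tokens: int = 4096,
                 min_batch_tokens: int = 64):
        self.num_resident = num_resident
        self.warmup_tokens = warmup_tokens
        self.min_batch_tokens = min_batch_tokens  # skip dummy warmup batches
        self.seen_tokens = 0
        self.freq: Dict[int, Counter] = {}

    def record(self, layer_idx: int, routed_expert_ids: List[int]) -> None:
        # Filter out small batches, which are likely dummy warmup data.
        if len(routed_expert_ids) < self.min_batch_tokens:
            return
        self.freq.setdefault(layer_idx, Counter()).update(routed_expert_ids)
        if layer_idx == 0:
            # Count progress once per forward pass (at the first MoE layer).
            self.seen_tokens += len(routed_expert_ids)

    def warmed_up(self) -> bool:
        return self.seen_tokens >= self.warmup_tokens

    def resident_set(self, layer_idx: int) -> List[int]:
        # Top-N most frequently routed experts for this layer.
        counts = self.freq.get(layer_idx, Counter())
        return [eid for eid, _ in counts.most_common(self.num_resident)]
```

Once `warmed_up()` returns true, each layer's `resident_set()` replaces the static first_n choice, promoting hot experts to GPU-preferred placement and demoting cold ones.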
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed.

This pull request introduces a UVM-based expert offloading solution for Mixture-of-Experts (MoE) layers. The primary goal is to enable very large MoE models, such as GLM-5-FP8, to run on single-host GPU setups by managing expert weights across GPU and host memory. By keeping all computation on the GPU and remaining fully compatible with CUDA graphs, the approach addresses memory constraints without compromising model quality or inference speed during decode.

Highlights
Code Review
This pull request introduces a significant and well-designed feature for UVM-based MoE expert offloading. The use of CUDA Unified Memory to simplify the architecture while maintaining full quality during decode is a great approach. The code is well-structured, and the inclusion of a detailed design document is highly appreciated. My review identifies a few areas for improvement, including a potential bug in argument parsing, some leftover dead code, and a misleading log message that should be corrected to accurately reflect the feature's capabilities. Overall, this is a solid contribution that will be even better with these minor adjustments.
```python
resident_ids = [
    int(x.strip())
    for x in server_args.expert_offload_resident_ids.split(",")
]
```
The current implementation for parsing expert_offload_resident_ids is susceptible to a ValueError if the string contains consecutive commas (e.g., '1,2,,3') or a trailing comma, as x.strip() would be an empty string, and int('') is invalid. It's safer to filter out empty strings before converting to integers.
Suggested change:

```python
resident_ids = [
    int(x.strip())
    for x in server_args.expert_offload_resident_ids.split(",")
    if x.strip()
]
```
```python
if model_runner.server_args.expert_offload_num_resident >= 0:
    log_info_on_rank0(
        logger,
        "[ExpertOffload] CUDA graph mode: decode uses static path (resident experts only). "
        "Prefill always uses dynamic path (full quality). "
        "Use --disable-cuda-graph for full decode quality.",
    )
```
This log message is misleading. It states that with CUDA graph mode, decode uses a static path with 'resident experts only', implying a loss of quality. However, the design document (expert_offload_plan.md) and the implementation show that UVM allows transparent access to offloaded experts via PCIe read-through, ensuring full quality even during CUDA graph replay. The log message should be corrected to reflect that there is no quality degradation.
Suggested change:

```python
if model_runner.server_args.expert_offload_num_resident >= 0:
    log_info_on_rank0(
        logger,
        "[ExpertOffload] CUDA graph mode enabled. Offloaded experts will be accessed "
        "transparently via PCIe, ensuring full decode quality.",
    )
```

```shell
--host 0.0.0.0 --port 9090 \
--page-size 256 \
--mem-fraction-static 0.85 --tensor-parallel-size 8 \
--hicache-io-backend kernel \
```
```python
def prefetch_experts(
    self,
    expert_ids: List[int],
    stream: Optional[torch.cuda.Stream] = None,
) -> None:
    """Speculatively prefetch expert pages to GPU VRAM.

    Call this with predicted expert IDs *before* a CUDA graph decode step
    to upgrade PCIe-speed accesses to VRAM-speed accesses for those experts.
    Correctness does not depend on this call — it is a pure optimisation.

    ``expert_ids`` should only contain offloaded (non-resident) expert IDs.
    """
    if stream is None:
        stream = self.prefetch_stream

    # We need managed tensor references to issue prefetch on.
    # Iterate over the first param to get the managed tensor.
    if not self._expert_param_names:
        return

    # NOTE: We assume all params have the same expert ordering.
    # The caller is responsible for providing valid offloaded expert IDs.
    for expert_id in expert_ids:
        if expert_id < 0 or expert_id >= self.config.num_local_experts:
            continue
        # Each managed tensor holds expert weights; prefetch per-param.
        # Access managed tensors stored on the layer via the stored names.
        # (The actual managed tensors are on the layer object itself.)
        # This method is called by the wrapper which has access to layer.
        pass  # See ExpertOffloadWrapperMethod.prefetch_experts_on_layer
```
```python
# Copyright 2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Prefetch strategies for expert weight offloading."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import List, Optional

import torch


class ExpertPrefetchStrategy(ABC):
    """Abstract base class for expert prefetch strategies."""

    @abstractmethod
    def predict(self, topk_ids: torch.Tensor, layer_idx: int) -> List[int]:
        """Predict which expert IDs to prefetch for the next layer.

        Args:
            topk_ids: Current layer's top-k expert IDs [batch, top_k].
            layer_idx: Current layer index.

        Returns:
            List of expert IDs to prefetch (may be empty).
        """
        ...

    def update(self, topk_ids: torch.Tensor, layer_idx: int) -> None:
        """Update internal statistics after seeing actual expert usage."""
        pass


class NoPrefetch(ExpertPrefetchStrategy):
    """No-op prefetch strategy."""

    def predict(self, topk_ids: torch.Tensor, layer_idx: int) -> List[int]:
        return []


class SpeculativePrefetch(ExpertPrefetchStrategy):
    """Cross-layer correlation-based speculative prefetch.

    Maintains a co-occurrence matrix: ``corr[i, j]`` counts how often expert
    ``i`` in layer L predicts expert ``j`` in layer L+1. After a warmup
    period, for each expert selected in the current layer we look up its
    top-correlated next-layer experts and return those as prefetch candidates.
    """

    WARMUP_STEPS = 32

    def __init__(self, num_experts: int, num_layers: int, num_predictions: int):
        self.num_experts = num_experts
        self.num_layers = num_layers
        self.num_predictions = num_predictions
        self._step = 0

        # corr[layer, expert_prev, expert_next] — updated online.
        self.corr = torch.zeros(
            num_layers, num_experts, num_experts, dtype=torch.float32
        )
        # Expert selections from the previous step, per layer.
        self._prev_ids: Optional[torch.Tensor] = None

    def predict(self, topk_ids: torch.Tensor, layer_idx: int) -> List[int]:
        if self._step < self.WARMUP_STEPS:
            return []
        next_layer = layer_idx + 1
        if next_layer >= self.num_layers:
            return []
        # Accumulate correlation scores for the next layer.
        scores = torch.zeros(self.num_experts, dtype=torch.float32)
        for eid in topk_ids.unique().tolist():
            eid = int(eid)
            if 0 <= eid < self.num_experts:
                scores += self.corr[next_layer, eid]
        _, top_ids = scores.topk(min(self.num_predictions, self.num_experts))
        return top_ids.tolist()

    def update(self, topk_ids: torch.Tensor, layer_idx: int) -> None:
        self._step += 1
        unique_ids = [int(x) for x in topk_ids.unique().tolist() if x >= 0]
        if self._prev_ids is not None and layer_idx > 0:
            prev_unique = [int(x) for x in self._prev_ids.unique().tolist() if x >= 0]
            for prev in prev_unique:
                for curr in unique_ids:
                    if 0 <= prev < self.num_experts and 0 <= curr < self.num_experts:
                        self.corr[layer_idx, prev, curr] += 1.0
        # Save current ids as "previous" for the next layer's update call.
        self._prev_ids = topk_ids.detach().cpu()


class FrequencyPrefetch(ExpertPrefetchStrategy):
    """Always prefetch the globally most-frequent experts."""

    def __init__(self, num_experts: int, num_predictions: int):
        self.num_experts = num_experts
        self.num_predictions = num_predictions
        self.freq = torch.zeros(num_experts, dtype=torch.long)
        self._top_ids: List[int] = []

    def predict(self, topk_ids: torch.Tensor, layer_idx: int) -> List[int]:
        return self._top_ids

    def update(self, topk_ids: torch.Tensor, layer_idx: int) -> None:
        for eid in topk_ids.unique().tolist():
            eid = int(eid)
            if 0 <= eid < self.num_experts:
                self.freq[eid] += 1
        _, top_ids = self.freq.topk(min(self.num_predictions, self.num_experts))
        self._top_ids = top_ids.tolist()


def create_prefetch_strategy(
    strategy_name: str,
    num_experts: int,
    num_layers: int,
    num_cache_slots: int,
) -> ExpertPrefetchStrategy:
    """Factory function for prefetch strategies."""
    if strategy_name == "none":
        return NoPrefetch()
    elif strategy_name == "speculative":
        return SpeculativePrefetch(
            num_experts, num_layers, num_predictions=num_cache_slots
        )
    elif strategy_name == "frequency":
        return FrequencyPrefetch(num_experts, num_predictions=num_cache_slots)
    else:
        raise ValueError(f"Unknown prefetch strategy: {strategy_name!r}")
```
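The cross-layer co-occurrence idea behind `SpeculativePrefetch` can also be illustrated without torch. Below is a stdlib-only sketch: the class name `CoOccurrencePredictor` is hypothetical, and the real class above keeps a dense per-layer tensor plus a warmup step counter instead of sparse Counters.

```python
from collections import Counter, defaultdict


class CoOccurrencePredictor:
    """Stdlib-only sketch of cross-layer expert co-occurrence prediction."""

    def __init__(self, num_predictions: int):
        self.num_predictions = num_predictions
        # corr[(layer, prev_expert)] -> Counter of co-occurring experts
        # selected in that layer.
        self.corr = defaultdict(Counter)

    def update(self, layer_idx: int, prev_ids: set, curr_ids: set) -> None:
        # Record that experts in prev_ids (layer_idx - 1) co-occurred with
        # experts in curr_ids (layer_idx).
        for prev in prev_ids:
            for curr in curr_ids:
                self.corr[(layer_idx, prev)][curr] += 1

    def predict(self, layer_idx: int, curr_ids: set) -> list:
        # Score next-layer experts by summed co-occurrence counts and
        # return the top candidates to prefetch.
        scores = Counter()
        for eid in curr_ids:
            scores.update(self.corr[(layer_idx + 1, eid)])
        return [eid for eid, _ in scores.most_common(self.num_predictions)]
```

For example, if expert 0 in layer 0 has historically been followed by expert 5 (twice) and expert 7 (once) in layer 1, `predict(0, {0})` returns `[5, 7]`, and those experts' pages can be prefetched before the layer-1 MoE kernel runs.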
Motivation
GLM-5-FP8 has 256 experts per MoE layer. Even with TP=8 on a single host of 8× H20 (96 GB each, 768 GB total VRAM), it cannot fit in GPU memory. Rather than requiring multi-host deployment, this PR offloads a subset of MoE expert weights to host memory using CUDA Unified Virtual Memory (UVM).
Unlike KTransformers, which moves computation to CPU for offloaded experts, this implementation keeps all computation on GPU — offloaded experts are accessed transparently via PCIe read-through. The existing prefill and decode paths remain compatible and decode can still benefit from CUDA graph replay.
Modifications
Commit 1: UVM-based expert weight offloading infrastructure
Commit 2: Per-layer adaptive resident expert selection via frequency warmup (promote/demote experts via cudaMemAdvise)
Accuracy Tests
Benchmarking and Profiling
Test with GLM-5-FP8 on 8× H20 (execute after installing sglang).
Test command:
This is an early-stage implementation with plenty of room for improvement. I'd welcome anyone interested in working on it together to make it better!
Checklist
Review Process
/tag-run-ci-label,/rerun-failed-ci,/tag-and-rerun-ci