Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
458 changes: 458 additions & 0 deletions expert_offload_plan.md

Large diffs are not rendered by default.

44 changes: 44 additions & 0 deletions python/sglang/srt/layers/moe/expert_offload/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""UVM-based expert weight offloading for MoE layers.

Expert weights are stored in CUDA Unified Memory (cudaMallocManaged).
Resident experts are kept in GPU VRAM (PREFER_GPU advice).
Offloaded experts live in CPU DRAM and are accessible to the GPU via PCIe
read-through (PREFER_CPU + ACCESSED_BY_GPU advice) — no page fault overhead,
no quality loss, CUDA graph compatible.

All computation remains on GPU (no CPU inference).
Mutually exclusive with KTransformers (--kt-weight-path).

Usage
-----
Pass ``--expert-offload-num-resident N`` to enable. Additional options:

--expert-offload-prefetch none Prefetch strategy (none / speculative / frequency).
--expert-offload-resident-selection first_n How to choose resident experts.
--expert-offload-resident-ids None Comma-separated IDs for manual selection.
"""

from sglang.srt.layers.moe.expert_offload.config import (
ExpertOffloadConfig,
create_expert_offload_config_from_server_args,
)
from sglang.srt.layers.moe.expert_offload.wrapper import ExpertOffloadWrapperMethod

__all__ = [
"ExpertOffloadConfig",
"ExpertOffloadWrapperMethod",
"create_expert_offload_config_from_server_args",
]
80 changes: 80 additions & 0 deletions python/sglang/srt/layers/moe/expert_offload/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright 2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Configuration for UVM-based expert weight offloading."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional

if TYPE_CHECKING:
from sglang.srt.server_args import ServerArgs


@dataclass
class ExpertOffloadConfig:
"""Per-layer configuration for UVM expert offloading.

Expert weights are stored in CUDA Unified Memory (cudaMallocManaged).
Resident experts are advised PREFER_GPU (stay in VRAM).
Offloaded experts are advised PREFER_CPU + ACCESSED_BY_GPU (live in CPU
DRAM; the GPU reads them via PCIe without triggering a page fault).

No LRU cache, no assembly buffer, no ID remapping — UVM handles all of it.
"""

layer_idx: int
num_local_experts: int
num_resident_experts: int # experts whose pages are kept on GPU
prefetch_strategy: str # "none" | "speculative" | "frequency"
resident_selection: str # "first_n" | "frequency" | "manual"
resident_expert_ids: Optional[List[int]] # explicit IDs for "manual" mode
num_layers: Optional[int] # total MoE layers (for prefetch coordination)
warmup_tokens: int = 4096 # routed tokens to collect before readvise

@property
def num_offloaded_experts(self) -> int:
return self.num_local_experts - self.num_resident_experts


def create_expert_offload_config_from_server_args(
server_args: "ServerArgs",
layer_id: int,
num_local_experts: int,
) -> Optional[ExpertOffloadConfig]:
"""Return an ExpertOffloadConfig if expert offloading is enabled, else None."""
if server_args.expert_offload_num_resident < 0:
return None

num_resident = min(server_args.expert_offload_num_resident, num_local_experts)

resident_ids: Optional[List[int]] = None
if (
server_args.expert_offload_resident_selection == "manual"
and server_args.expert_offload_resident_ids is not None
):
resident_ids = [
int(x.strip())
for x in server_args.expert_offload_resident_ids.split(",")
]
Comment on lines +67 to +70
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation for parsing expert_offload_resident_ids is susceptible to a ValueError if the string contains consecutive commas (e.g., '1,2,,3') or a trailing comma, as x.strip() would be an empty string, and int('') is invalid. It's safer to filter out empty strings before converting to integers.

Suggested change
resident_ids = [
int(x.strip())
for x in server_args.expert_offload_resident_ids.split(",")
]
resident_ids = [
int(x.strip())
for x in server_args.expert_offload_resident_ids.split(",")
if x.strip()
]


return ExpertOffloadConfig(
layer_idx=layer_id,
num_local_experts=num_local_experts,
num_resident_experts=num_resident,
prefetch_strategy=server_args.expert_offload_prefetch,
resident_selection=server_args.expert_offload_resident_selection,
resident_expert_ids=resident_ids,
num_layers=None,
)
Loading
Loading