13 changes: 9 additions & 4 deletions tests/compile/passes/distributed/test_sequence_parallelism.py
@@ -35,7 +35,10 @@
 from vllm.utils.system_utils import update_environment_variables
 from vllm.utils.torch_utils import set_random_seed

-pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
+pytestmark = pytest.mark.skipif(
+    not (current_platform.is_cuda() or current_platform.is_xpu()),
+    reason="Only test CUDA or XPU",
+)

 FP8_DTYPE = current_platform.fp8_dtype()
 prompts = [
@@ -178,7 +181,9 @@ def ops_in_model(self):
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("fuse_norm_quant", [True, False])
 @pytest.mark.parametrize("dynamic", [False, True])
-@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
+@pytest.mark.skipif(
+    envs.VLLM_TARGET_DEVICE not in ["cuda", "xpu"], reason="Only test on CUDA or XPU"
+)
 def test_sequence_parallelism_pass(
     test_model_cls: type[torch.nn.Module],
     custom_ops: str,
@@ -227,7 +232,7 @@ def sequence_parallelism_pass_on_test_model(
 ):
     set_random_seed(0)

-    device = torch.device(f"cuda:{local_rank}")
+    device = torch.device(f"{current_platform.device_type}:{local_rank}")
     torch.accelerator.set_device_index(device)
     torch.set_default_device(device)
     torch.set_default_dtype(dtype)
@@ -257,7 +262,7 @@ def sequence_parallelism_pass_on_test_model(
             eliminate_noops=True,
         ),
     )  # NoOp needed for fusion
-    device_config = DeviceConfig(device=torch.device("cuda"))
+    device_config = DeviceConfig(device=torch.device(current_platform.device_type))

     # this is a fake model name to construct the model config
     # in the vllm_config, it's not really used.
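Every change in this test file follows the same device-agnostic pattern: derive the torch device string from `current_platform.device_type` instead of hard-coding "cuda". A minimal sketch of the idea, using a stand-in for vllm.platforms.current_platform:

import torch

class _StandInPlatform:
    # vLLM Platform classes expose a `device_type` string,
    # e.g. "cuda" on NVIDIA builds or "xpu" on Intel GPUs.
    device_type = "cuda"

current_platform = _StandInPlatform()
local_rank = 0

# One f-string covers every backend: "cuda:0", "xpu:0", ...
device = torch.device(f"{current_platform.device_type}:{local_rank}")
print(device)  # -> cuda:0 with this stand-in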
46 changes: 32 additions & 14 deletions vllm/compilation/passes/fusion/sequence_parallelism.py
@@ -29,17 +29,46 @@

 # Min hidden size per device capability for sequence parallelism
 # Only apply sequence parallelism for models with hidden_size >= threshold
+# Keyed by CUDA compute capability (int).
 SP_MIN_HIDDEN_SIZE: dict[int, int] = {
     90: 8192,  # H100: only for models with hidden_size >= 8192
 }

 # Min size per GPU per device capability for sequence parallelism
 # Total min size = min_per_gpu_size * tp_size
 # This ensures the threshold scales appropriately with tensor parallelism
+# Keyed by CUDA compute capability (int).
 SP_MIN_PER_GPU_SIZE_MB: dict[int, float] = {
     90: 8,  # 8MB per GPU for H100
 }

+# XPU defaults — XPU does not expose CUDA-style compute capabilities,
+# so we maintain separate constants.
+SP_XPU_MIN_HIDDEN_SIZE: int = 4096
+SP_XPU_MIN_PER_GPU_SIZE_MB: float = 8
+
+
+def _get_sp_limits() -> tuple[int, float] | None:
+    """Return (min_hidden_size, min_per_gpu_size_mb) for the current device,
+    or *None* if the platform is unsupported / has no configured thresholds."""
+    from vllm.platforms import current_platform
+
+    if current_platform.is_xpu():
+        return SP_XPU_MIN_HIDDEN_SIZE, SP_XPU_MIN_PER_GPU_SIZE_MB
+
+    if current_platform.is_cuda():
+        capability = current_platform.get_device_capability()
+        if capability is None:
+            return None
+        device_capability = capability.to_int()
+        min_hidden_size = SP_MIN_HIDDEN_SIZE.get(device_capability)
+        min_per_gpu_size_mb = SP_MIN_PER_GPU_SIZE_MB.get(device_capability)
+        if min_hidden_size is None or min_per_gpu_size_mb is None:
+            return None
+        return min_hidden_size, min_per_gpu_size_mb
+
+    return None
+
+
 def get_sequence_parallelism_threshold(
     hidden_size: int,
@@ -59,22 +88,11 @@ def get_sequence_parallelism_threshold(
     Formula: min_token_num = (min_per_gpu_size_mb * tp_size * MiB) //
         (hidden_size * element_size)
     """
-    from vllm.platforms import current_platform
-
-    if not current_platform.is_cuda():
-        return None
-
-    capability = current_platform.get_device_capability()
-    if capability is None:
+    limits = _get_sp_limits()
+    if limits is None:
         return None
-    device_capability = capability.to_int()
-
-    # Check if device has configured thresholds
-    min_hidden_size = SP_MIN_HIDDEN_SIZE.get(device_capability)
-    min_per_gpu_size_mb = SP_MIN_PER_GPU_SIZE_MB.get(device_capability)
-
-    if min_hidden_size is None or min_per_gpu_size_mb is None:
-        return None
+    min_hidden_size, min_per_gpu_size_mb = limits

     # Only apply sequence parallelism for models meeting the size threshold
     if hidden_size < min_hidden_size:
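To make the docstring's formula concrete, here is a worked example. The 8 MB-per-GPU and 8192 hidden-size limits are the H100 entries from the tables above; the tensor-parallel degree and fp16 element size are illustrative assumptions:

# min_token_num = (min_per_gpu_size_mb * tp_size * MiB) // (hidden_size * element_size)
MiB = 1024 * 1024
min_per_gpu_size_mb = 8  # SP_MIN_PER_GPU_SIZE_MB[90]
tp_size = 4              # assumed tensor-parallel size
hidden_size = 8192       # meets SP_MIN_HIDDEN_SIZE[90]
element_size = 2         # bytes per fp16/bf16 element

min_token_num = (min_per_gpu_size_mb * tp_size * MiB) // (hidden_size * element_size)
print(min_token_num)  # 2048: the pass would only fire at or above 2048 tokens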
5 changes: 4 additions & 1 deletion vllm/compilation/passes/pass_manager.py
@@ -24,16 +24,19 @@
     RocmAiterTritonAddRMSNormPadFusionPass,
 )

+
 if current_platform.is_cuda_alike():
     from .fusion.act_quant_fusion import ActivationQuantFusionPass
     from .fusion.attn_quant_fusion import AttnQuantFusionPass
     from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
     from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
     from .fusion.rope_kvcache_fusion import RopeKVCacheFusionPass
-    from .fusion.sequence_parallelism import SequenceParallelismPass
     from .utility.scatter_split_replace import ScatterSplitReplacementPass
     from .utility.split_coalescing import SplitCoalescingPass

+if current_platform.is_cuda_alike() or current_platform.is_xpu():
+    from .fusion.sequence_parallelism import SequenceParallelismPass
+
 if current_platform.is_cuda():
     from .fusion.allreduce_rms_fusion import AllReduceFusionPass
     from .fusion.collective_fusion import AsyncTPPass
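The import reshuffle above matters because a guarded module-level import leaves the name unbound on other platforms, so every use of SequenceParallelismPass must sit behind the same condition. A hypothetical, self-contained sketch of that pattern (stand-ins only, not vLLM's actual pass-manager wiring):

class _StandInPlatform:
    @staticmethod
    def is_cuda_alike() -> bool:
        return False

    @staticmethod
    def is_xpu() -> bool:
        return True

current_platform = _StandInPlatform()

if current_platform.is_cuda_alike() or current_platform.is_xpu():
    class SequenceParallelismPass:  # stand-in for the guarded import
        def __init__(self, config):
            self.config = config

def build_passes(config):
    passes = []
    # Re-check the guard before touching the conditionally-bound name;
    # without it, a CPU-only run would raise NameError here.
    if current_platform.is_cuda_alike() or current_platform.is_xpu():
        passes.append(SequenceParallelismPass(config))
    return passes

print(len(build_passes(config={})))  # 1 on the XPU stand-in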
4 changes: 4 additions & 0 deletions vllm/platforms/xpu.py
@@ -247,6 +247,10 @@ def is_data_center_gpu(cls) -> bool:
         device_name = cls.get_device_name().lower()
         return device_name.count("data center gpu") > 0

+    @classmethod
+    def use_custom_op_collectives(cls) -> bool:
+        return True
+
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         from vllm.utils.torch_utils import supports_xccl
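use_custom_op_collectives() is a Platform classmethod, and XPU now opts in by returning True. A hypothetical caller-side sketch of how such a capability flag is typically consumed (the real dispatch site is outside this diff; both helpers are stand-ins):

from vllm.platforms import current_platform

def _custom_op_all_reduce(t):  # stand-in for a custom-op wrapper
    return t

def _native_all_reduce(t):  # stand-in for the default collective path
    return t

def all_reduce(t):
    # Route collectives through custom-op wrappers when the platform
    # opts in, so compiler passes can pattern-match the collective.
    if current_platform.use_custom_op_collectives():
        return _custom_op_all_reduce(t)
    return _native_all_reduce(t)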