diff --git a/tests/compile/passes/distributed/test_sequence_parallelism.py b/tests/compile/passes/distributed/test_sequence_parallelism.py
index 8588e0501783..828cec962f9c 100644
--- a/tests/compile/passes/distributed/test_sequence_parallelism.py
+++ b/tests/compile/passes/distributed/test_sequence_parallelism.py
@@ -35,7 +35,10 @@
 from vllm.utils.system_utils import update_environment_variables
 from vllm.utils.torch_utils import set_random_seed

-pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
+pytestmark = pytest.mark.skipif(
+    not (current_platform.is_cuda() or current_platform.is_xpu()),
+    reason="Only test CUDA or XPU",
+)

 FP8_DTYPE = current_platform.fp8_dtype()
 prompts = [
@@ -178,7 +181,9 @@ def ops_in_model(self):
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("fuse_norm_quant", [True, False])
 @pytest.mark.parametrize("dynamic", [False, True])
-@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
+@pytest.mark.skipif(
+    envs.VLLM_TARGET_DEVICE not in ["cuda", "xpu"], reason="Only test on CUDA or XPU"
+)
 def test_sequence_parallelism_pass(
     test_model_cls: type[torch.nn.Module],
     custom_ops: str,
@@ -227,7 +232,7 @@ def sequence_parallelism_pass_on_test_model(
 ):
     set_random_seed(0)

-    device = torch.device(f"cuda:{local_rank}")
+    device = torch.device(f"{current_platform.device_type}:{local_rank}")
     torch.accelerator.set_device_index(device)
     torch.set_default_device(device)
     torch.set_default_dtype(dtype)
@@ -257,7 +262,7 @@ def sequence_parallelism_pass_on_test_model(
             eliminate_noops=True,
         ),
     )  # NoOp needed for fusion
-    device_config = DeviceConfig(device=torch.device("cuda"))
+    device_config = DeviceConfig(device=torch.device(current_platform.device_type))

     # this is a fake model name to construct the model config
     # in the vllm_config, it's not really used.
diff --git a/vllm/compilation/passes/fusion/sequence_parallelism.py b/vllm/compilation/passes/fusion/sequence_parallelism.py
index e3cfb8a4896b..2582796e9a29 100644
--- a/vllm/compilation/passes/fusion/sequence_parallelism.py
+++ b/vllm/compilation/passes/fusion/sequence_parallelism.py
@@ -29,6 +29,7 @@

 # Min hidden size per device capability for sequence parallelism
 # Only apply sequence parallelism for models with hidden_size >= threshold
+# Keyed by CUDA compute capability (int).
 SP_MIN_HIDDEN_SIZE: dict[int, int] = {
     90: 8192,  # H100: only for models with hidden_size >= 8192
 }
@@ -36,10 +37,38 @@
 # Min size per GPU per device capability for sequence parallelism
 # Total min size = min_per_gpu_size * tp_size
 # This ensures the threshold scales appropriately with tensor parallelism
+# Keyed by CUDA compute capability (int).
 SP_MIN_PER_GPU_SIZE_MB: dict[int, float] = {
     90: 8,  # 8MB per GPU for H100
 }

+# XPU defaults: XPU does not expose CUDA-style compute capabilities,
+# so we maintain separate constants.
+SP_XPU_MIN_HIDDEN_SIZE: int = 4096
+SP_XPU_MIN_PER_GPU_SIZE_MB: float = 8
+
+
+def _get_sp_limits() -> tuple[int, float] | None:
+    """Return (min_hidden_size, min_per_gpu_size_mb) for the current device,
+    or None if the platform is unsupported or has no configured thresholds."""
+    from vllm.platforms import current_platform
+
+    if current_platform.is_xpu():
+        return SP_XPU_MIN_HIDDEN_SIZE, SP_XPU_MIN_PER_GPU_SIZE_MB
+
+    if current_platform.is_cuda():
+        capability = current_platform.get_device_capability()
+        if capability is None:
+            return None
+        device_capability = capability.to_int()
+        min_hidden_size = SP_MIN_HIDDEN_SIZE.get(device_capability)
+        min_per_gpu_size_mb = SP_MIN_PER_GPU_SIZE_MB.get(device_capability)
+        if min_hidden_size is None or min_per_gpu_size_mb is None:
+            return None
+        return min_hidden_size, min_per_gpu_size_mb
+
+    return None
+

 def get_sequence_parallelism_threshold(
     hidden_size: int,
@@ -59,22 +88,11 @@ def get_sequence_parallelism_threshold(
     Formula:
         min_token_num = (min_per_gpu_size_mb * tp_size * MiB)
                         // (hidden_size * element_size)
     """
-    from vllm.platforms import current_platform
-
-    if not current_platform.is_cuda():
-        return None
-
-    capability = current_platform.get_device_capability()
-    if capability is None:
+    limits = _get_sp_limits()
+    if limits is None:
         return None
-    device_capability = capability.to_int()
-
-    # Check if device has configured thresholds
-    min_hidden_size = SP_MIN_HIDDEN_SIZE.get(device_capability)
-    min_per_gpu_size_mb = SP_MIN_PER_GPU_SIZE_MB.get(device_capability)
-    if min_hidden_size is None or min_per_gpu_size_mb is None:
-        return None
+    min_hidden_size, min_per_gpu_size_mb = limits

     # Only apply sequence parallelism for models meeting the size threshold
     if hidden_size < min_hidden_size:
diff --git a/vllm/compilation/passes/pass_manager.py b/vllm/compilation/passes/pass_manager.py
index 0571741419f7..79bddbb1a20f 100644
--- a/vllm/compilation/passes/pass_manager.py
+++ b/vllm/compilation/passes/pass_manager.py
@@ -24,16 +24,19 @@
         RocmAiterTritonAddRMSNormPadFusionPass,
     )

+
 if current_platform.is_cuda_alike():
     from .fusion.act_quant_fusion import ActivationQuantFusionPass
     from .fusion.attn_quant_fusion import AttnQuantFusionPass
     from .fusion.qk_norm_rope_fusion import QKNormRoPEFusionPass
     from .fusion.rms_quant_fusion import RMSNormQuantFusionPass
     from .fusion.rope_kvcache_fusion import RopeKVCacheFusionPass
-    from .fusion.sequence_parallelism import SequenceParallelismPass
     from .utility.scatter_split_replace import ScatterSplitReplacementPass
     from .utility.split_coalescing import SplitCoalescingPass

+if current_platform.is_cuda_alike() or current_platform.is_xpu():
+    from .fusion.sequence_parallelism import SequenceParallelismPass
+
 if current_platform.is_cuda():
     from .fusion.allreduce_rms_fusion import AllReduceFusionPass
     from .fusion.collective_fusion import AsyncTPPass
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index 2a56ff5c6e62..076941897647 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -247,6 +247,10 @@ def is_data_center_gpu(cls) -> bool:
         device_name = cls.get_device_name().lower()
         return device_name.count("data center gpu") > 0

+    @classmethod
+    def use_custom_op_collectives(cls) -> bool:
+        return True
+
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         from vllm.utils.torch_utils import supports_xccl
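Reviewer note, not part of the patch: the threshold formula documented in `get_sequence_parallelism_threshold` is easy to sanity-check against the XPU defaults introduced here. Below is a minimal sketch of the arithmetic, assuming bf16 activations (2 bytes per element); the helper `sp_token_threshold` and the `MiB` constant are hypothetical, introduced only for illustration.

```python
# Sketch of the threshold arithmetic from get_sequence_parallelism_threshold():
#     min_token_num = (min_per_gpu_size_mb * tp_size * MiB)
#                     // (hidden_size * element_size)
# `sp_token_threshold` is a hypothetical helper, not part of the patch.

MiB = 1024 * 1024


def sp_token_threshold(
    min_per_gpu_size_mb: float,
    tp_size: int,
    hidden_size: int,
    element_size: int,
) -> int:
    """Minimum token count below which sequence parallelism is skipped."""
    return int(min_per_gpu_size_mb * tp_size * MiB) // (hidden_size * element_size)


# XPU defaults from this patch: hidden_size >= 4096, 8 MiB per device.
# With hidden_size=4096, tp_size=2, bf16 (element_size=2):
#     (8 * 2 * 1048576) // (4096 * 2) = 2048 tokens
assert sp_token_threshold(8, 2, 4096, 2) == 2048
```

For comparison, the CUDA path keeps the H100 entries (hidden_size >= 8192, 8 MB per GPU), so the XPU constants lower the hidden-size gate while reusing the same per-device byte budget.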
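The test-side changes follow the same pattern: each hard-coded "cuda" device string is now derived from `current_platform.device_type`. A minimal sketch of what the tests construct, assuming stock PyTorch device-string parsing (building a `torch.device` does not require the XPU runtime):

```python
import torch

# current_platform.device_type is "cuda" on NVIDIA and "xpu" on Intel GPUs;
# the tests now build per-rank devices from it rather than hard-coding "cuda".
for device_type in ("cuda", "xpu"):
    local_rank = 0
    device = torch.device(f"{device_type}:{local_rank}")
    print(device)  # cuda:0 / xpu:0
```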