diff --git a/examples/offline_inference/rlhf_colocate.py b/examples/offline_inference/rlhf_colocate.py
index 65621023ab6c..cb4fab0436f7 100644
--- a/examples/offline_inference/rlhf_colocate.py
+++ b/examples/offline_inference/rlhf_colocate.py
@@ -36,6 +36,7 @@
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
 from vllm import LLM
+from vllm.platforms import current_platform
 
 
 class MyLLM(LLM):
@@ -58,6 +59,7 @@ def __init__(self, *args, bundle_indices: list[int], **kwargs):
         os.environ.pop("CUDA_VISIBLE_DEVICES", None)
         # Each worker uses 0.4 GPU so that two instances fit on the same GPUs.
         os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
+        os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"
         os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices))
         print(f"creating LLM with bundle_indices={bundle_indices}")
         super().__init__(*args, **kwargs)
@@ -76,14 +78,18 @@ def __init__(self):
         from transformers import AutoModelForCausalLM
 
         self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
-        self.model.to("cuda:0")
+        from vllm.platforms import current_platform
+
+        self.model.to(current_platform.device_type + ":0")
         # Zero out all the parameters.
         for name, p in self.model.named_parameters():
             p.data.zero_()
-        torch.cuda.synchronize()
+        if current_platform.is_xpu():
+            torch.xpu.synchronize()
+        else:
+            torch.cuda.synchronize()
         # The argument for `get_device_uuid` is the index of the GPU in the
         # list of visible devices.
-        from vllm.platforms import current_platform
 
         self.device_uuid = current_platform.get_device_uuid(0)
 
@@ -104,7 +110,7 @@ def get_weight_ipc_handles(self):
 
 
 # Ray manages four GPUs.
-os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
+os.environ[current_platform.device_control_env_var] = "0,1,2,3"
 ray.init()
 
 # Co-locate vLLM instances and training actors on the same set of GPUs:
@@ -123,6 +129,7 @@ def get_weight_ipc_handles(self):
 inference_engine_device_ids = []
 
 for bundle_index in [0, 1, 2, 3]:
+    env_vars = {current_platform.device_control_env_var: str(bundle_index)}
     training_actor = ray.remote(
         num_cpus=0,
         num_gpus=0.4,
@@ -131,6 +138,7 @@ def get_weight_ipc_handles(self):
             placement_group_capture_child_tasks=True,
             placement_group_bundle_index=bundle_index,
         ),
+        runtime_env={"env_vars": env_vars} if current_platform.is_xpu() else {},
     )(RayTrainingActor).remote()
     training_actors.append(training_actor)
 
@@ -138,7 +146,6 @@ def get_weight_ipc_handles(self):
     device_id = ray.get(training_actor.report_device_id.remote())
     print(f"training actor {bundle_index} is on {device_id}")
     training_actor_device_ids.append(device_id)
-
 for i, bundle_indices in enumerate([[0, 1], [2, 3]]):
     # Use the following syntax instead of the @ray.remote decorator so that
     # the placement group is customized for each bundle.
diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py
index d2a8419ffabc..c51089f3e9bf 100644
--- a/examples/offline_inference/rlhf_utils.py
+++ b/examples/offline_inference/rlhf_utils.py
@@ -1,7 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
+
 import torch
 
+from vllm.platforms import current_platform
+
 
 def stateless_init_process_group(master_address, master_port, rank, world_size, device):
     """
@@ -11,14 +15,41 @@ def stateless_init_process_group(master_address, master_port, rank, world_size,
     the data-plane communication (NCCL) between external (train processes)
     and vLLM workers.
     """
-    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
-    from vllm.distributed.utils import StatelessProcessGroup
+    if current_platform.is_xpu():
+        from vllm.distributed.device_communicators.xpu_communicator import (
+            XpuCommunicator,
+        )
 
-    pg = StatelessProcessGroup.create(
-        host=master_address, port=master_port, rank=rank, world_size=world_size
-    )
-    pynccl = PyNcclCommunicator(pg, device=device)
-    return pynccl
+        os.environ.setdefault("CCL_ATL_TRANSPORT", "ofi")
+        os.environ.setdefault("LOCAL_WORLD_SIZE", str(world_size))
+        os.environ["LOCAL_RANK"] = str(rank)
+        from vllm.utils import get_distributed_init_method
+
+        distributed_init_method = get_distributed_init_method(
+            master_address, master_port
+        )
+
+        if not torch.distributed.is_initialized():
+            torch.distributed.init_process_group(
+                backend="ccl",
+                init_method=distributed_init_method,
+                world_size=world_size,
+                rank=rank,
+            )
+
+        ranks = list(range(torch.distributed.get_world_size()))
+        pg = torch.distributed.new_group(ranks, backend="ccl")
+        ccl = XpuCommunicator(pg, device=device)
+        return ccl
+    else:
+        from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+        from vllm.distributed.utils import StatelessProcessGroup
+
+        pg = StatelessProcessGroup.create(
+            host=master_address, port=master_port, rank=rank, world_size=world_size
+        )
+        pynccl = PyNcclCommunicator(pg, device=device)
+        return pynccl
 
 
 class WorkerExtension:
@@ -47,10 +78,14 @@ def init_weight_update_group(
 
     def update_weight(self, name, dtype_name, shape):
         dtype = getattr(torch, dtype_name)
-        weight = torch.empty(shape, dtype=dtype, device="cuda")
-        self.model_update_group.broadcast(
-            weight, src=0, stream=torch.cuda.current_stream()
-        )
+        weight = torch.empty(shape, dtype=dtype, device=current_platform.device_type)
+
+        if current_platform.is_xpu():
+            self.model_update_group.broadcast(weight, src=0)
+        else:
+            self.model_update_group.broadcast(
+                weight, src=0, stream=torch.cuda.current_stream()
+            )
         self.model_runner.model.load_weights(weights=[(name, weight)])
 
@@ -91,11 +126,18 @@ def update_weights_from_ipc_handles(self, ipc_handles):
             list_args = list(args)
             # the key is to change device id to the current device id
             # in case two processes have different CUDA_VISIBLE_DEVICES
-            list_args[6] = device_id
-            tensor = func(*list_args)
+            if current_platform.is_xpu():
+                tensor = func(*list_args)
+                tensor = tensor.to(current_platform.device_type + ":" + str(device_id))
+            else:
+                list_args[6] = device_id
+                tensor = func(*list_args)
             weights.append((name, tensor))
         self.model_runner.model.load_weights(weights=weights)
-        torch.cuda.synchronize()
+        if current_platform.is_xpu():
+            torch.xpu.synchronize()
+        else:
+            torch.cuda.synchronize()
 
     def check_weights_changed(self):
         """
diff --git a/vllm/executor/ray_distributed_executor.py b/vllm/executor/ray_distributed_executor.py
index 37c3fe59c65d..8f0e9d77d142 100644
--- a/vllm/executor/ray_distributed_executor.py
+++ b/vllm/executor/ray_distributed_executor.py
@@ -55,8 +55,11 @@ class RayDistributedExecutor(DistributedExecutorBase):
     # These env vars are worker-specific, therefore are NOT copied
     # from the driver to the workers
     WORKER_SPECIFIC_ENV_VARS = {
-        "VLLM_HOST_IP", "VLLM_HOST_PORT", "LOCAL_RANK", "CUDA_VISIBLE_DEVICES"
+        "VLLM_HOST_IP",
+        "VLLM_HOST_PORT",
+        "LOCAL_RANK",
     }
+    WORKER_SPECIFIC_ENV_VARS.add(current_platform.device_control_env_var)
 
     # These non-vLLM env vars are copied from the driver to workers
     ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"}
@@ -65,6 +68,18 @@ class RayDistributedExecutor(DistributedExecutorBase):
 
     def _init_executor(self) -> None:
         self.forward_dag: Optional[ray.dag.CompiledDAG] = None
+
+        # For XPU, we use ZE_AFFINITY_MASK in vllm to control device
+        # visibility, instead of using 'ONEAPI_DEVICE_SELECTOR' in
+        # ray (which requires the "level_zero:" prefix). This makes it
+        # easier to share the same code logic with CUDA_VISIBLE_DEVICES.
+        # Therefore, we are removing ONEAPI_DEVICE_SELECTOR here. If not
+        # removed, setting both environment variables (ZE_AFFINITY_MASK
+        # and ONEAPI_DEVICE_SELECTOR) would result in the controlled
+        # devices being the intersection of the two.
+        if current_platform.is_xpu():
+            os.environ.pop("ONEAPI_DEVICE_SELECTOR", None)
+
         if envs.VLLM_USE_V1:
             # V1 uses SPMD worker and compiled DAG
             os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
@@ -191,6 +206,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
 
         worker_metadata: List[RayWorkerMetaData] = []
         driver_ip = get_ip()
+
+        # Explicitly set the device visibility for XPU in the Ray runtime env.
+        # This is required because vllm uses ZE_AFFINITY_MASK
+        # for XPU (not ONEAPI_DEVICE_SELECTOR) and we have removed the latter
+        # to avoid conflicts.
+        if current_platform.is_xpu():
+            bundle_indices_str = ",".join(map(str, bundle_indices))
+            env_vars = {
+                current_platform.device_control_env_var: bundle_indices_str
+            }
+            runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
+            runtime_env.setdefault("env_vars", {}).update(env_vars)
+
         for rank, bundle_id in enumerate(bundle_indices):
             scheduling_strategy = PlacementGroupSchedulingStrategy(
                 placement_group=placement_group,
diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py
index d61b921e19cf..3872d542d2cc 100644
--- a/vllm/platforms/xpu.py
+++ b/vllm/platforms/xpu.py
@@ -183,3 +183,8 @@ def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
     @classmethod
     def opaque_attention_op(cls) -> bool:
         return True
+
+    @classmethod
+    def get_device_uuid(cls, device_id: int = 0) -> str:
+        physical_device_id = cls.device_id_to_physical_device_id(device_id)
+        return "intel-gpu-" + str(physical_device_id)
diff --git a/vllm/v1/worker/xpu_worker.py b/vllm/v1/worker/xpu_worker.py
index 17288cda8ecc..8112599c59a0 100644
--- a/vllm/v1/worker/xpu_worker.py
+++ b/vllm/v1/worker/xpu_worker.py
@@ -153,13 +153,9 @@ def init_device(self):
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
 
-        ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE", "pidfd")
-        ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
-        ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
-                                         str(self.parallel_config.world_size))
-        os.environ["CCL_ZE_IPC_EXCHANGE"] = ENV_CCL_ZE_IPC_EXCHANGE
-        os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
-        os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
+        os.environ.setdefault("CCL_ATL_TRANSPORT", "ofi")
+        os.environ.setdefault("LOCAL_WORLD_SIZE",
+                              str(self.parallel_config.world_size))
         os.environ["LOCAL_RANK"] = str(self.local_rank)
 
         init_worker_distributed_environment(self.vllm_config, self.rank,
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index a1fa7f2cf7a2..d81fa7c4822d 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -17,6 +17,7 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import (enable_trace_function_call_for_thread,
                         resolve_obj_by_qualname, run_method,
@@ -531,7 +532,7 @@ def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
 
     def update_environment_variables(self,
                                      envs_list: List[Dict[str, str]]) -> None:
         envs = envs_list[self.rpc_rank]
-        key = 'CUDA_VISIBLE_DEVICES'
+        key = current_platform.device_control_env_var
        if key in envs and key in os.environ:
             # overwriting CUDA_VISIBLE_DEVICES is desired behavior
             # suppress the warning in `update_environment_variables`
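
Below is a hypothetical usage sketch (not part of the diff) of how the platform-agnostic stateless_init_process_group helper above might be driven from the rank-0 training process. The address, port, and world size are illustrative placeholders, and the two broadcast branches mirror the paths used in WorkerExtension.update_weight.

# Hypothetical usage sketch: create the weight-update group on the trainer
# (rank 0) and broadcast one tensor to the vLLM workers.
import torch

from rlhf_utils import stateless_init_process_group
from vllm.platforms import current_platform

master_address = "127.0.0.1"  # placeholder: must be reachable by the vLLM workers
master_port = 51216           # placeholder: any free port
world_size = 3                # placeholder: 1 trainer + 2 vLLM workers

device = torch.device(current_platform.device_type + ":0")
model_update_group = stateless_init_process_group(
    master_address, master_port, rank=0, world_size=world_size, device=device
)

weight = torch.zeros(16, 16, dtype=torch.float16, device=device)
if current_platform.is_xpu():
    # XPU path: the helper returns an XpuCommunicator backed by a CCL group.
    model_update_group.broadcast(weight, src=0)
else:
    # CUDA path: PyNcclCommunicator broadcasts on the current CUDA stream.
    model_update_group.broadcast(
        weight, src=0, stream=torch.cuda.current_stream()
    )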