17 changes: 12 additions & 5 deletions examples/offline_inference/rlhf_colocate.py
@@ -36,6 +36,7 @@
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from vllm import LLM
from vllm.platforms import current_platform


class MyLLM(LLM):
@@ -58,6 +59,7 @@ def __init__(self, *args, bundle_indices: list[int], **kwargs):
os.environ.pop("CUDA_VISIBLE_DEVICES", None)
# Each worker uses 0.4 GPU so that two instances fit on the same GPUs.
os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"
Collaborator:

is this necessary?

Contributor Author:

Yes. If this is not set, the following error occurs, even though I have set it in the running script.

Update the weights of the inference engines.
Traceback (most recent call last):
File "/home/chaojun/vllm/examples/offline_inference/rlhf_colocate.py", line 192, in
ray.get(
File "/home/chaojun/upstream/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper
return fn(*args, **kwargs)
File "/home/chaojun/upstream/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 104, in wrapper
return func(*args, **kwargs)
File "/home/chaojun/upstream/lib/python3.10/site-packages/ray/_private/worker.py", line 2858, in get
values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
File "/home/chaojun/upstream/lib/python3.10/site-packages/ray/_private/worker.py", line 958, in get_objects
raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(TypeError): ray::MyLLM.collective_rpc() (pid=1425434, ip=10.7.182.52, actor_id=3ecf23723b638fd6e5c18bfe04000000, repr=<rlhf_colocate.MyLLM object at 0x7f531a753970>)
File "/home/chaojun/vllm/vllm/entrypoints/llm.py", line 494, in collective_rpc
return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
File "/home/chaojun/vllm/vllm/v1/engine/llm_engine.py", line 321, in collective_rpc
return self.engine_core.collective_rpc(method, timeout, args, kwargs)
File "/home/chaojun/vllm/vllm/v1/engine/core_client.py", line 747, in collective_rpc
return self.call_utility("collective_rpc", method, timeout, args,
File "/home/chaojun/vllm/vllm/v1/engine/core_client.py", line 692, in call_utility
self._send_input(EngineCoreRequestType.UTILITY,
File "/home/chaojun/vllm/vllm/v1/engine/core_client.py", line 678, in _send_input
*self.encoder.encode(request))
File "/home/chaojun/vllm/vllm/v1/serial_utils.py", line 87, in encode
bufs[0] = self.encoder.encode(obj)
File "/home/chaojun/vllm/vllm/v1/serial_utils.py", line 140, in enc_hook
raise TypeError(f"Object of type {type(obj)} is not serializable"
TypeError: Object of type <class 'function'> is not serializableSet VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow fallback to pickle-based serialization.

Collaborator:

Does this also happen on CUDA, or only on XPU?

Member:

I have the same concern; we should not enable this env var unless it is necessary.

Contributor Author:

Originally, the script was supposed to include VLLM_ALLOW_INSECURE_SERIALIZATION=1, but this variable is not passed through when launching with ray.remote. I consider this a bug.
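
[Editor's note: a minimal, hypothetical sketch, not part of this PR, of forwarding the variable to Ray actor processes explicitly through runtime_env instead of relying on inheritance from the driver's shell. The probe actor below is purely illustrative.]

import os
import ray

# Hypothetical probe actor: runtime_env forwards the variable to the actor process.
@ray.remote(runtime_env={"env_vars": {"VLLM_ALLOW_INSECURE_SERIALIZATION": "1"}})
class EnvProbe:
    def get(self):
        return os.environ.get("VLLM_ALLOW_INSECURE_SERIALIZATION")

ray.init()
print(ray.get(EnvProbe.remote().get.remote()))  # expected to print "1"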

Member:

Could you explain what you mean by "originally"? It looks like this script doesn't pass lambda functions to llm.collective_rpc, so I'm not sure why this flag is required.

Contributor Author (@chaojun-zhang, Aug 29, 2025):

The error indicates that the `enc_hook` method uses msgpack as the serialization framework, but certain objects and methods are not serializable with msgpack. This requires the switch to fall back to the pickle-based serialization framework.
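
[Editor's note: a minimal sketch of this failure mode, not vLLM's actual serializer; it assumes only the msgspec and pickle packages.]

import pickle
import msgspec

def enc_hook(obj):
    # Mirror the behavior described above: reject types msgpack cannot encode.
    raise TypeError(f"Object of type {type(obj)} is not serializable")

def rebuild(*args):
    # Stand-in for the rebuild function carried inside an IPC handle.
    return args

encoder = msgspec.msgpack.Encoder(enc_hook=enc_hook)

try:
    encoder.encode({"handles": [(rebuild, (1, 2, 3))]})
except TypeError as exc:
    print("msgpack path fails:", exc)

# Pickle serializes module-level functions by reference, which is the kind of
# fallback that VLLM_ALLOW_INSECURE_SERIALIZATION=1 enables.
print("pickle path succeeds:", len(pickle.dumps((rebuild, (1, 2, 3)))), "bytes")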

By "originally", I meant that this example was supposed to include the environment variable on the command line:

VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py

but when I tried to verify on XPU, it still didn't take effect.

Here you can see that lambda functions (with ipc_handles) are sent through collective_rpc:

"update_weights_from_ipc_handles", args=(ipc_handles,)

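[Editor's note: a small sketch of why function objects end up in that payload at all, assuming the IPC handles are produced with torch.multiprocessing.reductions.reduce_tensor, as is typical for this kind of tensor sharing, and that a CUDA or XPU device is available.]

import torch
from torch.multiprocessing.reductions import reduce_tensor

# Pick whichever accelerator is present; this requires a CUDA or XPU device.
device = "xpu:0" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cuda:0"
weight = torch.zeros(16, device=device)

# reduce_tensor returns a (rebuild_function, args) pair; the function object is
# the part that the msgpack-based encoder cannot handle without the fallback.
func, args = reduce_tensor(weight)
print(type(func))  # <class 'function'>
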
Member:

@youkaichao can you comment on this?

os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices))
print(f"creating LLM with bundle_indices={bundle_indices}")
super().__init__(*args, **kwargs)
@@ -76,14 +78,18 @@ def __init__(self):
from transformers import AutoModelForCausalLM

self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
self.model.to("cuda:0")
from vllm.platforms import current_platform

self.model.to(current_platform.device_type + ":0")
# Zero out all the parameters.
for name, p in self.model.named_parameters():
p.data.zero_()
torch.cuda.synchronize()
if current_platform.is_xpu():
torch.xpu.synchronize()
else:
torch.cuda.synchronize()
# The argument for `get_device_uuid` is the index of the GPU in the
# list of visible devices.
from vllm.platforms import current_platform

self.device_uuid = current_platform.get_device_uuid(0)

@@ -104,7 +110,7 @@ def get_weight_ipc_handles(self):

# Ray manages four GPUs.

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
os.environ[current_platform.device_control_env_var] = "0,1,2,3"
ray.init()

# Co-locate vLLM instances and training actors on the same set of GPUs:
@@ -123,6 +129,7 @@ def get_weight_ipc_handles(self):
inference_engine_device_ids = []

for bundle_index in [0, 1, 2, 3]:
env_vars = {current_platform.device_control_env_var: str(bundle_index)}
training_actor = ray.remote(
num_cpus=0,
num_gpus=0.4,
Expand All @@ -131,14 +138,14 @@ def get_weight_ipc_handles(self):
placement_group_capture_child_tasks=True,
placement_group_bundle_index=bundle_index,
),
runtime_env={"env_vars": env_vars} if current_platform.is_xpu() else {},
)(RayTrainingActor).remote()
training_actors.append(training_actor)

for bundle_index, training_actor in enumerate(training_actors):
device_id = ray.get(training_actor.report_device_id.remote())
print(f"training actor {bundle_index} is on {device_id}")
training_actor_device_ids.append(device_id)

for i, bundle_indices in enumerate([[0, 1], [2, 3]]):
# Use the following syntax instead of the @ray.remote decorator so that
# the placement group is customized for each bundle.
70 changes: 56 additions & 14 deletions examples/offline_inference/rlhf_utils.py
@@ -1,7 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os

import torch

from vllm.platforms import current_platform


def stateless_init_process_group(master_address, master_port, rank, world_size, device):
"""
@@ -11,14 +15,41 @@ def stateless_init_process_group(master_address, master_port, rank, world_size,
the data-plane communication (NCCL) between external (train processes)
and vLLM workers.
"""
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup
if current_platform.is_xpu():
from vllm.distributed.device_communicators.xpu_communicator import (
XpuCommunicator,
)

pg = StatelessProcessGroup.create(
host=master_address, port=master_port, rank=rank, world_size=world_size
)
pynccl = PyNcclCommunicator(pg, device=device)
return pynccl
os.environ.setdefault("CCL_ATL_TRANSPORT", "ofi")
os.environ.setdefault("LOCAL_WORLD_SIZE", str(world_size))
os.environ["LOCAL_RANK"] = str(rank)
from vllm.utils import get_distributed_init_method

distributed_init_method = get_distributed_init_method(
master_address, master_port
)

if not torch.distributed.is_initialized():
torch.distributed.init_process_group(
backend="ccl",
init_method=distributed_init_method,
world_size=world_size,
rank=rank,
)

ranks = list(range(torch.distributed.get_world_size()))
pg = torch.distributed.new_group(ranks, backend="ccl")
ccl = XpuCommunicator(pg, device=device)
return ccl
else:
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.utils import StatelessProcessGroup

pg = StatelessProcessGroup.create(
host=master_address, port=master_port, rank=rank, world_size=world_size
)
pynccl = PyNcclCommunicator(pg, device=device)
return pynccl


class WorkerExtension:
@@ -47,10 +78,14 @@ def init_weight_update_group(

def update_weight(self, name, dtype_name, shape):
dtype = getattr(torch, dtype_name)
weight = torch.empty(shape, dtype=dtype, device="cuda")
self.model_update_group.broadcast(
weight, src=0, stream=torch.cuda.current_stream()
)
weight = torch.empty(shape, dtype=dtype, device=current_platform.device_type)

if current_platform.is_xpu():
self.model_update_group.broadcast(weight, src=0)
else:
self.model_update_group.broadcast(
weight, src=0, stream=torch.cuda.current_stream()
)

self.model_runner.model.load_weights(weights=[(name, weight)])

@@ -91,11 +126,18 @@ def update_weights_from_ipc_handles(self, ipc_handles):
list_args = list(args)
# the key is to change device id to the current device id
# in case two processes have different CUDA_VISIBLE_DEVICES
list_args[6] = device_id
tensor = func(*list_args)
if current_platform.is_xpu():
tensor = func(*list_args)
tensor = tensor.to(current_platform.device_type + ":" + str(device_id))
else:
list_args[6] = device_id
tensor = func(*list_args)
weights.append((name, tensor))
self.model_runner.model.load_weights(weights=weights)
torch.cuda.synchronize()
if current_platform.is_xpu():
torch.xpu.synchronize()
else:
torch.cuda.synchronize()

def check_weights_changed(self):
"""
30 changes: 29 additions & 1 deletion vllm/executor/ray_distributed_executor.py
@@ -55,8 +55,11 @@ class RayDistributedExecutor(DistributedExecutorBase):
# These env vars are worker-specific, therefore are NOT copied
# from the driver to the workers
WORKER_SPECIFIC_ENV_VARS = {
"VLLM_HOST_IP", "VLLM_HOST_PORT", "LOCAL_RANK", "CUDA_VISIBLE_DEVICES"
"VLLM_HOST_IP",
"VLLM_HOST_PORT",
"LOCAL_RANK",
}
WORKER_SPECIFIC_ENV_VARS.add(current_platform.device_control_env_var)

# These non-vLLM env vars are copied from the driver to workers
ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"}
@@ -65,6 +68,18 @@ class RayDistributedExecutor(DistributedExecutorBase):

def _init_executor(self) -> None:
self.forward_dag: Optional[ray.dag.CompiledDAG] = None

# For XPU, we use ZE_AFFINITY_MASK in vllm to control device
# visibility, instead of using 'ONEAPI_DEVICE_SELECTOR' in
# ray (which requires the "level_zero:" prefix). This makes it
# easier to share the same code logic with CUDA_VISIBLE_DEVICES.
# Therefore, we are removing ONEAPI_DEVICE_SELECTOR here. If not
# removed, setting both environment variables (ZE_AFFINITY_MASK
# and ONEAPI_DEVICE_SELECTOR) would result in the controlled
# devices being the intersection of the two.
if current_platform.is_xpu():
os.environ.pop("ONEAPI_DEVICE_SELECTOR", None)

if envs.VLLM_USE_V1:
# V1 uses SPMD worker and compiled DAG
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
@@ -191,6 +206,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",

worker_metadata: List[RayWorkerMetaData] = []
driver_ip = get_ip()

# Explicitly set the device visibility for XPU in the Ray runtime env.
# This is required because vllm uses ZE_AFFINITY_MASK
# for XPU (not ONEAPI_DEVICE_SELECTOR) and we have removed the latter
# to avoid conflicts.
if current_platform.is_xpu():
bundle_indices_str = ",".join(map(str, bundle_indices))
env_vars = {
current_platform.device_control_env_var: bundle_indices_str
}
runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
runtime_env.setdefault("env_vars", {}).update(env_vars)

for rank, bundle_id in enumerate(bundle_indices):
scheduling_strategy = PlacementGroupSchedulingStrategy(
placement_group=placement_group,
5 changes: 5 additions & 0 deletions vllm/platforms/xpu.py
@@ -179,3 +179,8 @@ def check_if_supports_dtype(cls, torch_dtype: torch.dtype):
@classmethod
def opaque_attention_op(cls) -> bool:
return True

@classmethod
def get_device_uuid(cls, device_id: int = 0) -> str:
physical_device_id = cls.device_id_to_physical_device_id(device_id)
return "intel-gpu-" + str(physical_device_id)
10 changes: 3 additions & 7 deletions vllm/v1/worker/xpu_worker.py
@@ -153,13 +153,9 @@ def init_device(self):
raise RuntimeError(
f"Not support device type: {self.device_config.device}")

ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE", "pidfd")
ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
str(self.parallel_config.world_size))
os.environ["CCL_ZE_IPC_EXCHANGE"] = ENV_CCL_ZE_IPC_EXCHANGE
os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
os.environ.setdefault("CCL_ATL_TRANSPORT", "ofi")
os.environ.setdefault("LOCAL_WORLD_SIZE",
str(self.parallel_config.world_size))
os.environ["LOCAL_RANK"] = str(self.local_rank)

init_worker_distributed_environment(self.vllm_config, self.rank,
3 changes: 2 additions & 1 deletion vllm/worker/worker_base.py
@@ -17,6 +17,7 @@
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import (enable_trace_function_call_for_thread,
resolve_obj_by_qualname, run_method,
@@ -531,7 +532,7 @@ def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
def update_environment_variables(self, envs_list: List[Dict[str,
str]]) -> None:
envs = envs_list[self.rpc_rank]
key = 'CUDA_VISIBLE_DEVICES'
key = current_platform.device_control_env_var
if key in envs and key in os.environ:
# overwriting CUDA_VISIBLE_DEVICES is desired behavior
# suppress the warning in `update_environment_variables`