
Commit 5dddf56

support rlhf colocate for xpu platforms

Signed-off-by: chzhang <[email protected]>
Parent: f825c6b

File tree: 5 files changed (+111, -21 lines)

examples/offline_inference/rlhf_colocate.py  (22 additions, 5 deletions)

@@ -36,6 +36,7 @@
 from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
 
 from vllm import LLM
+from vllm.platforms import current_platform
 
 
 class MyLLM(LLM):
@@ -55,9 +56,15 @@ class MyLLM(LLM):
     def __init__(self, *args, bundle_indices: list[int], **kwargs):
         # Prevent Ray from manipulating the top-level CUDA_VISIBLE_DEVICES variable
         # so that vLLM can manage its own device placement inside the worker.
-        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
+        from vllm.platforms import current_platform
+
+        if current_platform.is_xpu():
+            os.environ.pop("ONEAPI_DEVICE_SELECTOR", None)
+        else:
+            os.environ.pop("CUDA_VISIBLE_DEVICES", None)
         # Each worker uses 0.4 GPU so that two instances fit on the same GPUs.
         os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.4"
+        os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"
         os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices))
         print(f"creating LLM with bundle_indices={bundle_indices}")
         super().__init__(*args, **kwargs)
@@ -76,14 +83,21 @@ def __init__(self):
         from transformers import AutoModelForCausalLM
 
         self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
-        self.model.to("cuda:0")
+        from vllm.platforms import current_platform
+
+        if current_platform.is_xpu():
+            self.model.to("xpu:0")
+        else:
+            self.model.to("cuda:0")
         # Zero out all the parameters.
         for name, p in self.model.named_parameters():
             p.data.zero_()
-        torch.cuda.synchronize()
+        if current_platform.is_xpu():
+            torch.xpu.synchronize()
+        else:
+            torch.cuda.synchronize()
         # The argument for `get_device_uuid` is the index of the GPU in the
         # list of visible devices.
-        from vllm.platforms import current_platform
 
         self.device_uuid = current_platform.get_device_uuid(0)
 
@@ -104,7 +118,7 @@ def get_weight_ipc_handles(self):
 
 # Ray manages four GPUs.
 
-os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
+os.environ[current_platform.device_control_env_var] = "0,1,2,3"
 ray.init()
 
 # Co-locate vLLM instances and training actors on the same set of GPUs:
@@ -123,6 +137,7 @@ def get_weight_ipc_handles(self):
 inference_engine_device_ids = []
 
 for bundle_index in [0, 1, 2, 3]:
+    env_vars = {current_platform.device_control_env_var: str(bundle_index)}
     training_actor = ray.remote(
         num_cpus=0,
         num_gpus=0.4,
@@ -131,6 +146,7 @@ def get_weight_ipc_handles(self):
             placement_group_capture_child_tasks=True,
             placement_group_bundle_index=bundle_index,
         ),
+        runtime_env={"env_vars": env_vars} if current_platform.is_xpu() else {},
    )(RayTrainingActor).remote()
    training_actors.append(training_actor)
 
@@ -142,6 +158,7 @@ def get_weight_ipc_handles(self):
 for i, bundle_indices in enumerate([[0, 1], [2, 3]]):
     # Use the following syntax instead of the @ray.remote decorator so that
     # the placement group is customized for each bundle.
+    # env_vars = {current_platform.device_control_env_var: bundle_indices_str}
     llm = ray.remote(
         num_cpus=0,
         num_gpus=0,
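
Note (not part of the commit): the platform-conditional cleanup and per-actor device pinning above follow one small pattern. A minimal sketch of that pattern, factored into helpers whose names are hypothetical and which reuse only APIs that appear in the diff:

```python
# Hypothetical helpers distilled from the example above; assume vLLM is installed.
import os

from vllm.platforms import current_platform


def clear_device_visibility() -> None:
    # On Intel XPU the oneAPI selector plays the role of CUDA_VISIBLE_DEVICES.
    if current_platform.is_xpu():
        os.environ.pop("ONEAPI_DEVICE_SELECTOR", None)
    else:
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)


def actor_runtime_env(bundle_index: int) -> dict:
    # Mirrors the runtime_env handed to ray.remote for XPU training actors:
    # pin the actor to its placement-group bundle via the device-control var.
    if not current_platform.is_xpu():
        return {}
    return {"env_vars": {current_platform.device_control_env_var: str(bundle_index)}}
```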

examples/offline_inference/rlhf_utils.py  (59 additions, 14 deletions)

@@ -1,7 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
+
 import torch
 
+from vllm.platforms import current_platform
+
 
 def stateless_init_process_group(master_address, master_port, rank, world_size, device):
     """
@@ -11,14 +15,44 @@ def stateless_init_process_group(master_address, master_port, rank, world_size,
     the data-plane communication (NCCL) between external (train processes)
     and vLLM workers.
     """
-    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
-    from vllm.distributed.utils import StatelessProcessGroup
+    if current_platform.is_xpu():
+        from vllm.distributed.device_communicators.xpu_communicator import (
+            XpuCommunicator,
+        )
 
-    pg = StatelessProcessGroup.create(
-        host=master_address, port=master_port, rank=rank, world_size=world_size
-    )
-    pynccl = PyNcclCommunicator(pg, device=device)
-    return pynccl
+        XPU_CCL_BACKEND = os.getenv("XPU_CCL_BACKEND", "xccl")
+        ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
+        ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE", str(world_size))
+        os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
+        os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
+        os.environ["LOCAL_RANK"] = str(rank)
+        from vllm.utils import get_distributed_init_method
+
+        distributed_init_method = get_distributed_init_method(
+            master_address, master_port
+        )
+
+        if not torch.distributed.is_initialized():
+            torch.distributed.init_process_group(
+                backend=XPU_CCL_BACKEND,
+                init_method=distributed_init_method,
+                world_size=world_size,
+                rank=rank,
+            )
+
+        ranks = list(range(torch.distributed.get_world_size()))
+        pg = torch.distributed.new_group(ranks, backend=XPU_CCL_BACKEND)
+        xccl = XpuCommunicator(pg, device=device)
+        return xccl
+    else:
+        from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+        from vllm.distributed.utils import StatelessProcessGroup
+
+        pg = StatelessProcessGroup.create(
+            host=master_address, port=master_port, rank=rank, world_size=world_size
+        )
+        pynccl = PyNcclCommunicator(pg, device=device)
+        return pynccl
 
 
 class WorkerExtension:
@@ -47,10 +81,14 @@ def init_weight_update_group(
 
     def update_weight(self, name, dtype_name, shape):
         dtype = getattr(torch, dtype_name)
-        weight = torch.empty(shape, dtype=dtype, device="cuda")
-        self.model_update_group.broadcast(
-            weight, src=0, stream=torch.cuda.current_stream()
-        )
+        weight = torch.empty(shape, dtype=dtype, device=current_platform.device_type)
+
+        if current_platform.is_xpu():
+            self.model_update_group.broadcast(weight, src=0)
+        else:
+            self.model_update_group.broadcast(
+                weight, src=0, stream=torch.cuda.current_stream()
+            )
 
         self.model_runner.model.load_weights(weights=[(name, weight)])
 
@@ -91,11 +129,18 @@ def update_weights_from_ipc_handles(self, ipc_handles):
             list_args = list(args)
             # the key is to change device id to the current device id
             # in case two processes have different CUDA_VISIBLE_DEVICES
-            list_args[6] = device_id
-            tensor = func(*list_args)
+            if current_platform.is_xpu():
+                tensor = func(*list_args)
+                tensor = tensor.to(current_platform.device_type + ":" + str(device_id))
+            else:
+                list_args[6] = device_id
+                tensor = func(*list_args)
             weights.append((name, tensor))
         self.model_runner.model.load_weights(weights=weights)
-        torch.cuda.synchronize()
+        if current_platform.is_xpu():
+            torch.xpu.synchronize()
+        else:
+            torch.cuda.synchronize()
 
     def check_weights_changed(self):
         """

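For reference, a hedged sketch (not part of the diff) of how a training process might call the updated stateless_init_process_group. The address, port, and world size are placeholders; the returned communicator is an XpuCommunicator on XPU and a PyNcclCommunicator elsewhere, which is why the broadcast call differs:

```python
# Illustrative call site; values are placeholders. Rank 0 is the trainer,
# the remaining ranks are vLLM workers.
import torch

from vllm.platforms import current_platform
from rlhf_utils import stateless_init_process_group  # assumes the examples dir is importable

group = stateless_init_process_group(
    master_address="127.0.0.1",
    master_port=29500,
    rank=0,
    world_size=3,
    device=torch.device(f"{current_platform.device_type}:0"),
)

# Broadcast one weight from the trainer. The XPU communicator takes no stream
# argument, matching the branch added to update_weight above.
weight = torch.zeros(16, device=f"{current_platform.device_type}:0")
if current_platform.is_xpu():
    group.broadcast(weight, src=0)
else:
    group.broadcast(weight, src=0, stream=torch.cuda.current_stream())
```
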
vllm/executor/ray_distributed_executor.py  (20 additions, 1 deletion)

@@ -55,8 +55,14 @@ class RayDistributedExecutor(DistributedExecutorBase):
     # These env vars are worker-specific, therefore are NOT copied
     # from the driver to the workers
     WORKER_SPECIFIC_ENV_VARS = {
-        "VLLM_HOST_IP", "VLLM_HOST_PORT", "LOCAL_RANK", "CUDA_VISIBLE_DEVICES"
+        "VLLM_HOST_IP",
+        "VLLM_HOST_PORT",
+        "LOCAL_RANK",
     }
+    if current_platform.is_xpu():
+        WORKER_SPECIFIC_ENV_VARS.add(current_platform.device_control_env_var)
+    else:
+        WORKER_SPECIFIC_ENV_VARS.add("CUDA_VISIBLE_DEVICES")
 
     # These non-vLLM env vars are copied from the driver to workers
     ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"}
@@ -65,6 +71,11 @@ class RayDistributedExecutor(DistributedExecutorBase):
 
     def _init_executor(self) -> None:
         self.forward_dag: Optional[ray.dag.CompiledDAG] = None
+
+        if (current_platform.is_xpu()
+                and "ONEAPI_DEVICE_SELECTOR" in os.environ):
+            del os.environ[current_platform.device_control_env_var]
+
         if envs.VLLM_USE_V1:
             # V1 uses SPMD worker and compiled DAG
             os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
@@ -191,6 +202,14 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
 
         worker_metadata: List[RayWorkerMetaData] = []
         driver_ip = get_ip()
+
+        if current_platform.is_xpu():
+            bundle_indices_str = ",".join(map(str, bundle_indices))
+            env_vars = {
+                current_platform.device_control_env_var: bundle_indices_str
+            }
+            ray_remote_kwargs = {"runtime_env": {"env_vars": env_vars}}
+
         for rank, bundle_id in enumerate(bundle_indices):
             scheduling_strategy = PlacementGroupSchedulingStrategy(
                 placement_group=placement_group,
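
An illustrative way to see which variable the executor manages on a given machine (assumes vLLM is installed). On XPU the executor now hands this variable to each Ray worker through runtime_env rather than letting workers inherit the driver's value:

```python
# Prints the platform's device-control environment variable:
# CUDA_VISIBLE_DEVICES on NVIDIA GPUs, the oneAPI device selector on Intel XPU.
from vllm.platforms import current_platform

print(current_platform.device_control_env_var)
print(current_platform.is_xpu())
```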

vllm/platforms/xpu.py  (5 additions, 0 deletions)

@@ -194,3 +194,8 @@ def supports_v1(cls, model_config: ModelConfig) -> bool:
     @classmethod
     def device_count(cls) -> int:
         return torch.xpu.device_count()
+
+    @classmethod
+    def get_device_uuid(cls, device_id: int = 0) -> str:
+        physical_device_id = cls.device_id_to_physical_device_id(device_id)
+        return "intel-gpu-" + str(physical_device_id)
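
The new method only needs to produce a stable per-device key; a short illustrative check of what it returns (assumes an XPU build of vLLM):

```python
from vllm.platforms import current_platform

# On XPU this is a synthetic identifier such as "intel-gpu-0" rather than a
# hardware UUID; the RLHF example only uses it as a dictionary key for the
# per-device IPC handles.
print(current_platform.get_device_uuid(0))
```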

vllm/worker/worker_base.py  (5 additions, 1 deletion)

@@ -17,6 +17,7 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.platforms import current_platform
 from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import (enable_trace_function_call_for_thread,
                         resolve_obj_by_qualname, run_method,
@@ -531,7 +532,10 @@ def adjust_rank(self, rank_mapping: Dict[int, int]) -> None:
     def update_environment_variables(self, envs_list: List[Dict[str,
                                                                  str]]) -> None:
         envs = envs_list[self.rpc_rank]
-        key = 'CUDA_VISIBLE_DEVICES'
+        if current_platform.is_xpu():
+            key = current_platform.device_control_env_var
+        else:
+            key = "CUDA_VISIBLE_DEVICES"
         if key in envs and key in os.environ:
             # overwriting CUDA_VISIBLE_DEVICES is desired behavior
             # suppress the warning in `update_environment_variables`
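
For context, a hedged sketch of the per-rank environment update this branch handles; the wrapper object and values are illustrative, not taken from the diff:

```python
from vllm.platforms import current_platform

# One dict per RPC rank; the executor fills in the real device indices.
key = (current_platform.device_control_env_var
       if current_platform.is_xpu() else "CUDA_VISIBLE_DEVICES")
envs_list = [{key: "0,1"}]
# worker_wrapper.update_environment_variables(envs_list)  # invoked by the executor
```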
