Commit 6a2fa83

XCCL for XPU (#3252)

Signed-off-by: Wang, Yi A <[email protected]>

1 parent b4386b8

File tree

3 files changed: +24 −16 lines

Dockerfile_intel

Lines changed: 3 additions & 4 deletions
@@ -45,7 +45,7 @@ RUN cargo build --profile release-opt --frozen
 
 # Text Generation Inference base image for Intel
 
-FROM intel/oneapi-basekit:2025.0.1-0-devel-ubuntu22.04 AS xpu
+FROM intel/oneapi-basekit:2025.1.3-0-devel-ubuntu22.04 AS xpu
 
 USER root
 
@@ -99,7 +99,7 @@ ENV HF_HOME=/data \
 
 WORKDIR /usr/src
 
-RUN pip install torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/xpu
+RUN pip install torch==2.8.0 torchvision==0.23.0 --index-url https://download.pytorch.org/whl/xpu
 
 # Install server
 COPY proto proto
@@ -117,8 +117,7 @@ ENV TORCH_LLM_ALLREDUCE=1
 ENV CCL_TOPO_FABRIC_VERTEX_CONNECTION_CHECK=0
 ENV TORCH_DEVICE_BACKEND_AUTOLOAD=0
 
-RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/oneccl_bind_pt-2.7.0%2Bxpu-cp311-cp311-linux_x86_64.whl
-RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.7.10%2Bxpu-cp311-cp311-linux_x86_64.whl
+RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_stable/xpu/intel_extension_for_pytorch-2.8.10%2Bxpu-cp311-cp311-linux_x86_64.whl
 # Install benchmarker
 COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
 # Install router
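Taken together, the Dockerfile changes move the XPU image from torch 2.7.0 / IPEX 2.7.10 to torch 2.8.0 / IPEX 2.8.10 and drop the separate oneccl_bind_pt wheel, since the commit switches the XPU path to torch.distributed's native "xccl" backend. A minimal sanity-check sketch, to be run inside the built image; the is_xccl_available helper is looked up defensively rather than assumed, since its presence can vary by build:

    import torch
    import torch.distributed as dist

    print(torch.__version__)         # expected: 2.8.0+xpu
    print(torch.xpu.is_available())  # True on a host with Intel GPUs
    print(torch.xpu.device_count())

    # Look the helper up defensively instead of assuming the exact API.
    has_xccl = getattr(dist, "is_xccl_available", lambda: False)()
    print("xccl available:", has_xccl)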

server/text_generation_server/layers/tensor_parallel.py

Lines changed: 4 additions & 4 deletions
@@ -90,7 +90,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
                 local_out = gather_input.T
 
             torch.mm(input, self.linear.weight.T, out=local_out)
-            if SYSTEM == "ipex":
+            if SYSTEM == "ipex" and gather_input.device.type == "cpu":
                 ipex.distributed.all_gather_into_tensor(
                     world_out, gather_input, group=self.process_group
                 )
@@ -107,7 +107,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         world_output = [
             torch.empty_like(output) for _ in range(self.process_group.size())
         ]
-        if SYSTEM == "ipex":
+        if SYSTEM == "ipex" and output.device.type == "cpu":
             ipex.distributed.all_gather(world_output, output, group=self.process_group)
         else:
             torch.distributed.all_gather(world_output, output, group=self.process_group)
@@ -202,7 +202,7 @@ def load(cls, config, prefix: str, weights, bias: bool):
     def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
         out = super().forward(input)
         if self.process_group.size() > 1 and reduce:
-            if SYSTEM == "ipex":
+            if SYSTEM == "ipex" and out.device.type == "cpu":
                 ipex.distributed.all_reduce(out, group=self.process_group)
             else:
                 torch.distributed.all_reduce(out, group=self.process_group)
@@ -242,7 +242,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         )
         out = torch.nn.functional.embedding(input, self.weight)
         if self.reduce and self.process_group.size() > 1:
-            if SYSTEM == "ipex":
+            if SYSTEM == "ipex" and out.device.type == "cpu":
                 ipex.distributed.all_reduce(out, group=self.process_group)
             else:
                 torch.distributed.all_reduce(out, group=self.process_group)
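All four tensor_parallel.py hunks apply the same rule: ipex.distributed collectives are now reserved for CPU tensors, while XPU tensors fall through to the stock torch.distributed path that the new XCCL backend serves. A minimal sketch of that dispatch, assuming TGI's SYSTEM flag from text_generation_server.utils.import_utils and an already-initialized process group:

    import torch
    import torch.distributed as dist

    from text_generation_server.utils.import_utils import SYSTEM

    def all_reduce_tp(out: torch.Tensor, process_group) -> None:
        # CPU under IPEX keeps the oneCCL-bound collective; everything
        # else (CUDA, and now XPU via XCCL) uses torch.distributed.
        if SYSTEM == "ipex" and out.device.type == "cpu":
            import intel_extension_for_pytorch as ipex

            ipex.distributed.all_reduce(out, group=process_group)
        else:
            dist.all_reduce(out, group=process_group)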

server/text_generation_server/utils/dist.py

Lines changed: 17 additions & 8 deletions
@@ -79,14 +79,23 @@ def initialize_torch_distributed():
                     ), "Each process is one xpu"
                     device = RANK % torch.xpu.device_count()
                     torch.xpu.set_device(device)
-
-                ipex.distributed.init_process_group(
-                    backend="ccl",
-                    world_size=WORLD_SIZE,
-                    rank=RANK,
-                    timeout=timedelta(seconds=120),
-                    pg_options=options,
-                )
+                    device_id = torch.device(f"xpu:{RANK}")
+                    torch.distributed.init_process_group(
+                        backend="xccl",
+                        world_size=WORLD_SIZE,
+                        rank=RANK,
+                        timeout=timedelta(seconds=120),
+                        pg_options=options,
+                        device_id=device_id,
+                    )
+                else:
+                    ipex.distributed.init_process_group(
+                        backend="ccl",
+                        world_size=WORLD_SIZE,
+                        rank=RANK,
+                        timeout=timedelta(seconds=120),
+                        pg_options=options,
+                    )
         else:
             device = torch.device(f"cuda:{RANK}")
             torch.distributed.init_process_group(
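In isolation, the new XPU init path looks like the following sketch; it assumes a launcher such as torchrun has set RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT, and a PyTorch >= 2.8 build with XCCL. pg_options is omitted here for brevity:

    import os
    from datetime import timedelta

    import torch
    import torch.distributed as dist

    RANK = int(os.environ["RANK"])
    WORLD_SIZE = int(os.environ["WORLD_SIZE"])

    torch.xpu.set_device(RANK % torch.xpu.device_count())
    dist.init_process_group(
        backend="xccl",  # native XPU backend in PyTorch >= 2.8
        world_size=WORLD_SIZE,
        rank=RANK,
        timeout=timedelta(seconds=120),
        # Binding the group to a device avoids lazy device resolution.
        device_id=torch.device(f"xpu:{RANK}"),
    )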
