From f1d7192fd479b0007675b3884ecc4e4104aae010 Mon Sep 17 00:00:00 2001
From: "Wang, Yi A"
Date: Thu, 17 Jul 2025 18:48:34 -0700
Subject: [PATCH] Some GPTQ cases cannot be handled by IPEX but can be
 handled by Triton

Signed-off-by: Wang, Yi A
---
 .../layers/gptq/__init__.py | 19 ++++++++++++++-----
 .../layers/gptq/triton.py   |  6 +++++-
 2 files changed, 19 insertions(+), 6 deletions(-)

diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py
index b5549916be2..b5d4fb04fe6 100644
--- a/server/text_generation_server/layers/gptq/__init__.py
+++ b/server/text_generation_server/layers/gptq/__init__.py
@@ -12,11 +12,7 @@
     WeightsLoader,
     DefaultWeightsLoader,
 )
-
-if SYSTEM == "ipex":
-    from .ipex import QuantLinear
-elif SYSTEM in {"cuda", "rocm"}:
-    from .triton import QuantLinear
+import math
 
 
 @dataclass
@@ -70,6 +66,19 @@ def get_linear(self, bias: torch.Tensor):
 
             return ExllamaQuantLinear(self, bias)
         else:
+            if SYSTEM == "ipex" and not (
+                self.device.type == "xpu"
+                and (
+                    self.bits != 4
+                    or math.ceil(
+                        (self.qweight.shape[0] * 32 // self.bits) / self.groupsize
+                    )
+                    != self.scales.shape[0]
+                )
+            ):
+                from .ipex import QuantLinear
+            else:
+                from .triton import QuantLinear
             return QuantLinear(
                 self.qweight,
                 self.qzeros,
diff --git a/server/text_generation_server/layers/gptq/triton.py b/server/text_generation_server/layers/gptq/triton.py
index 736c357b094..87dcd7dabfe 100644
--- a/server/text_generation_server/layers/gptq/triton.py
+++ b/server/text_generation_server/layers/gptq/triton.py
@@ -202,7 +202,11 @@ def matmul_248_kernel(
 
 
 def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
-    with torch.cuda.device(input.device):
+    with (
+        torch.xpu.device(input.device)
+        if torch.xpu.is_available()
+        else torch.cuda.device(input.device)
+    ):
         output = torch.empty(
             (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
         )
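
The new dispatch rule in get_linear keeps the IPEX QuantLinear on IPEX
systems except when running on an XPU device with a layout the IPEX kernel
cannot handle: non-4-bit weights, or a scales tensor whose group count
differs from ceil(in_features / groupsize), as act-order (desc_act)
checkpoints can produce. The triton.py hunk complements this by entering
torch.xpu.device(...) instead of torch.cuda.device(...) when XPU is
available, so the Triton fallback also runs on Intel GPUs. Below is a
minimal standalone sketch of the shape check, assuming qweight rows are
int32-packed; the helper name and tensor sizes are illustrative only, not
part of the patch:

    import math

    import torch

    def prefer_ipex(qweight, scales, bits, groupsize, device):
        # Mirror of the patch's condition, minus the SYSTEM == "ipex" guard:
        # on XPU, fall back to Triton for non-4-bit weights or when the
        # number of scale groups does not match ceil(in_features / groupsize).
        in_features = qweight.shape[0] * 32 // bits  # rows are int32-packed
        expected_groups = math.ceil(in_features / groupsize)
        unsupported = device.type == "xpu" and (
            bits != 4 or expected_groups != scales.shape[0]
        )
        return not unsupported

    # Hypothetical 4-bit layer: 4096 in-features, groupsize 128 -> 32 groups.
    qweight = torch.empty(4096 * 4 // 32, 11008, dtype=torch.int32)
    scales = torch.empty(32, 11008, dtype=torch.float16)
    print(prefer_ipex(qweight, scales, 4, 128, torch.device("xpu")))  # True

    # An act-order checkpoint can carry a mismatched group count; IPEX is
    # skipped and the Triton QuantLinear is used instead.
    scales_desc_act = torch.empty(33, 11008, dtype=torch.float16)
    print(prefer_ipex(qweight, scales_desc_act, 4, 128, torch.device("xpu")))  # False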