diff --git a/server/text_generation_server/layers/gptq/__init__.py b/server/text_generation_server/layers/gptq/__init__.py
index b5549916be2..b5d4fb04fe6 100644
--- a/server/text_generation_server/layers/gptq/__init__.py
+++ b/server/text_generation_server/layers/gptq/__init__.py
@@ -12,11 +12,7 @@
     WeightsLoader,
     DefaultWeightsLoader,
 )
-
-if SYSTEM == "ipex":
-    from .ipex import QuantLinear
-elif SYSTEM in {"cuda", "rocm"}:
-    from .triton import QuantLinear
+import math
 
 
 @dataclass
@@ -70,6 +66,19 @@ def get_linear(self, bias: torch.Tensor):
 
             return ExllamaQuantLinear(self, bias)
         else:
+            if SYSTEM == "ipex" and not (
+                self.device.type == "xpu"
+                and (
+                    self.bits != 4
+                    or math.ceil(
+                        (self.qweight.shape[0] * 32 // self.bits) / self.groupsize
+                    )
+                    != self.scales.shape[0]
+                )
+            ):
+                from .ipex import QuantLinear
+            else:
+                from .triton import QuantLinear
             return QuantLinear(
                 self.qweight,
                 self.qzeros,
diff --git a/server/text_generation_server/layers/gptq/triton.py b/server/text_generation_server/layers/gptq/triton.py
index 736c357b094..87dcd7dabfe 100644
--- a/server/text_generation_server/layers/gptq/triton.py
+++ b/server/text_generation_server/layers/gptq/triton.py
@@ -202,7 +202,11 @@ def matmul_248_kernel(
 
 
 def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
-    with torch.cuda.device(input.device):
+    with (
+        torch.xpu.device(input.device)
+        if torch.xpu.is_available()
+        else torch.cuda.device(input.device)
+    ):
         output = torch.empty(
            (input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16
         )
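
Note (not part of the patch): below is a minimal stand-alone restatement of the dispatch condition that the first hunk adds to GPTQWeight.get_linear, for readers who want to reason about the check in isolation. The helper name prefers_ipex_quant_linear and its flattened arguments are hypothetical; in the patch the same expression reads bits, groupsize, qweight, scales and the device from self.

import math


def prefers_ipex_quant_linear(
    system: str,        # value of SYSTEM, e.g. "ipex", "cuda", "rocm"
    device_type: str,   # weight.device.type, e.g. "xpu" or "cpu"
    bits: int,          # GPTQ quantization width
    packed_rows: int,   # qweight.shape[0] (rows of int32-packed weights)
    groupsize: int,     # GPTQ group size
    scales_groups: int, # scales.shape[0]
) -> bool:
    # Hypothetical helper mirroring the patched condition:
    # True -> use the .ipex QuantLinear, False -> fall back to .triton.
    if system != "ipex":
        # CUDA/ROCm builds keep using the Triton kernel, as before the patch.
        return False
    # Unpack the int32-packed rows to recover the number of input features,
    # then derive how many quantization groups the scales tensor should have.
    in_features = packed_rows * 32 // bits
    expected_groups = math.ceil(in_features / groupsize)
    # On XPU the IPEX kernel is only used for 4-bit weights whose scales
    # match the expected group count; anything else falls back to Triton.
    xpu_fallback = device_type == "xpu" and (
        bits != 4 or expected_groups != scales_groups
    )
    return not xpu_fallback

For example, prefers_ipex_quant_linear("ipex", "xpu", bits=4, packed_rows=512, groupsize=128, scales_groups=32) returns True (512 * 32 // 4 = 4096 input features, ceil(4096 / 128) = 32 groups), while any non-4-bit weight or mismatched scales shape on XPU returns False and takes the Triton path instead.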