diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 80fc86861..5a58ec6a2 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -426,7 +426,7 @@ def matmul( state.threshold = threshold # MatMul8bitLt is slower because no fast kernel for quant/dequant 8bit in CPU/XPU if state.is_training: - if (A.device.type == "cpu" and ipex_cpu) or (A.device.type == "xpu" and ipex_xpu): + if (A.device.type == "cpu" and ipex_cpu) or (A.device.type == "xpu"): return MatMul8bitFp.apply(A, B, out, bias, state) return MatMul8bitLt.apply(A, B, out, bias, state)