Commit ba2ba90

Isotr0py and DN6 authored
Add cuda kernel support for GGUF inference (#11869)
* add gguf kernel support

Signed-off-by: Isotr0py <[email protected]>

* fix

Signed-off-by: Isotr0py <[email protected]>

* optimize

Signed-off-by: Isotr0py <[email protected]>

* update
* update
* update
* update
* update

---------

Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: DN6 <[email protected]>
1 parent fa4c0e5 commit ba2ba90

File tree (7 files changed, +179 −4 lines changed):

- .github/workflows/nightly_tests.yml
- docs/source/en/quantization/gguf.md
- src/diffusers/quantizers/gguf/utils.py
- src/diffusers/utils/__init__.py
- src/diffusers/utils/import_utils.py
- src/diffusers/utils/testing_utils.py
- tests/quantization/gguf/test_gguf.py

.github/workflows/nightly_tests.yml

Lines changed: 1 addition & 1 deletion
@@ -333,7 +333,7 @@ jobs:
           additional_deps: ["peft"]
         - backend: "gguf"
           test_location: "gguf"
-          additional_deps: ["peft"]
+          additional_deps: ["peft", "kernels"]
         - backend: "torchao"
           test_location: "torchao"
           additional_deps: []

docs/source/en/quantization/gguf.md

Lines changed: 10 additions & 0 deletions
@@ -53,6 +53,16 @@ image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
 image.save("flux-gguf.png")
 ```
 
+## Using Optimized CUDA Kernels with GGUF
+
+Optimized CUDA kernels can accelerate GGUF quantized model inference by approximately 10%. This functionality requires a compatible GPU with `torch.cuda.get_device_capability` greater than 7 and the kernels library:
+
+```shell
+pip install -U kernels
+```
+
+Once installed, set `DIFFUSERS_GGUF_CUDA_KERNELS=true` to use optimized kernels when available. Note that CUDA kernels may introduce minor numerical differences compared to the original GGUF implementation, potentially causing subtle visual variations in generated images. To disable CUDA kernel usage, set the environment variable `DIFFUSERS_GGUF_CUDA_KERNELS=false`.
+
 ## Supported Quantization Types
 
 - BF16
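
In practice, the documentation added here boils down to: install `kernels`, export `DIFFUSERS_GGUF_CUDA_KERNELS=true`, and run a GGUF-quantized pipeline as usual. A minimal sketch of that flow follows, modeled on the Flux GGUF example earlier in the same guide; the checkpoint URL and prompt are illustrative and not part of this commit:

```python
import os

# Set the flag before diffusers' GGUF utilities are imported, since the
# environment variable is read once at module import time (see utils.py below).
os.environ["DIFFUSERS_GGUF_CUDA_KERNELS"] = "true"

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

# Illustrative GGUF checkpoint; any supported quantization type works.
ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"

transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()

image = pipe("A cat holding a sign that says hello world", generator=torch.manual_seed(0)).images[0]
image.save("flux-gguf-kernels.png")
```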

src/diffusers/quantizers/gguf/utils.py

Lines changed: 92 additions & 3 deletions
@@ -12,15 +12,15 @@
 # # See the License for the specific language governing permissions and
 # # limitations under the License.
 
-
 import inspect
+import os
 from contextlib import nullcontext
 
 import gguf
 import torch
 import torch.nn as nn
 
-from ...utils import is_accelerate_available
+from ...utils import is_accelerate_available, is_kernels_available
 
 
 if is_accelerate_available():
@@ -29,6 +29,82 @@
     from accelerate.hooks import add_hook_to_module, remove_hook_from_module
 
 
+can_use_cuda_kernels = (
+    os.getenv("DIFFUSERS_GGUF_CUDA_KERNELS", "false").lower() in ["1", "true", "yes"]
+    and torch.cuda.is_available()
+    and torch.cuda.get_device_capability()[0] >= 7
+)
+if can_use_cuda_kernels and is_kernels_available():
+    from kernels import get_kernel
+
+    ops = get_kernel("Isotr0py/ggml")
+else:
+    ops = None
+
+UNQUANTIZED_TYPES = {gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16}
+STANDARD_QUANT_TYPES = {
+    gguf.GGMLQuantizationType.Q4_0,
+    gguf.GGMLQuantizationType.Q4_1,
+    gguf.GGMLQuantizationType.Q5_0,
+    gguf.GGMLQuantizationType.Q5_1,
+    gguf.GGMLQuantizationType.Q8_0,
+    gguf.GGMLQuantizationType.Q8_1,
+}
+KQUANT_TYPES = {
+    gguf.GGMLQuantizationType.Q2_K,
+    gguf.GGMLQuantizationType.Q3_K,
+    gguf.GGMLQuantizationType.Q4_K,
+    gguf.GGMLQuantizationType.Q5_K,
+    gguf.GGMLQuantizationType.Q6_K,
+}
+IMATRIX_QUANT_TYPES = {
+    gguf.GGMLQuantizationType.IQ1_M,
+    gguf.GGMLQuantizationType.IQ1_S,
+    gguf.GGMLQuantizationType.IQ2_XXS,
+    gguf.GGMLQuantizationType.IQ2_XS,
+    gguf.GGMLQuantizationType.IQ2_S,
+    gguf.GGMLQuantizationType.IQ3_XXS,
+    gguf.GGMLQuantizationType.IQ3_S,
+    gguf.GGMLQuantizationType.IQ4_XS,
+    gguf.GGMLQuantizationType.IQ4_NL,
+}
+# TODO(Isotr0py): Currently, we don't have MMQ kernel for I-Matrix quantization.
+# Consolidate DEQUANT_TYPES, MMVQ_QUANT_TYPES and MMQ_QUANT_TYPES after we add
+# MMQ kernel for I-Matrix quantization.
+DEQUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+MMVQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES | IMATRIX_QUANT_TYPES
+MMQ_QUANT_TYPES = STANDARD_QUANT_TYPES | KQUANT_TYPES
+
+
+def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: int) -> torch.Tensor:
+    # there is no need to call any kernel for fp16/bf16
+    if qweight_type in UNQUANTIZED_TYPES:
+        return x @ qweight.T
+
+    # TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
+    # contiguous batching and inefficient with diffusers' batching,
+    # so we disabled it now.
+
+    # elif qweight_type in MMVQ_QUANT_TYPES:
+    #     y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
+    # elif qweight_type in MMQ_QUANT_TYPES:
+    #     y = ops.ggml_mul_mat_a8(qweight, x, qweight_type, qweight.shape[0])
+
+    # If there is no available MMQ kernel, fallback to dequantize
+    if qweight_type in DEQUANT_TYPES:
+        block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
+        shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
+        weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
+        y = x @ weight.to(x.dtype).T
+    else:
+        # Raise an error if the quantization type is not supported.
+        # Might be useful if llama.cpp adds a new quantization type.
+        # Wrap to GGMLQuantizationType IntEnum to make sure it's a valid type.
+        qweight_type = gguf.GGMLQuantizationType(qweight_type)
+        raise NotImplementedError(f"Unsupported GGUF quantization type: {qweight_type}")
+    return y.as_tensor()
+
+
 # Copied from diffusers.quantizers.bitsandbytes.utils._create_accelerate_new_hook
 def _create_accelerate_new_hook(old_hook):
     r"""
@@ -451,11 +527,24 @@ def __init__(
     ) -> None:
         super().__init__(in_features, out_features, bias, device)
         self.compute_dtype = compute_dtype
+        self.device = device
+
+    def forward(self, inputs: torch.Tensor):
+        if ops is not None and self.weight.is_cuda and inputs.is_cuda:
+            return self.forward_cuda(inputs)
+        return self.forward_native(inputs)
 
-    def forward(self, inputs):
+    def forward_native(self, inputs: torch.Tensor):
         weight = dequantize_gguf_tensor(self.weight)
         weight = weight.to(self.compute_dtype)
         bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
 
         output = torch.nn.functional.linear(inputs, weight, bias)
         return output
+
+    def forward_cuda(self, inputs: torch.Tensor):
+        quant_type = self.weight.quant_type
+        output = _fused_mul_mat_gguf(inputs.to(self.compute_dtype), self.weight, quant_type)
+        if self.bias is not None:
+            output += self.bias.to(self.compute_dtype)
+        return output
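
For readers tracing the dequantize fallback in `_fused_mul_mat_gguf`: `gguf.GGML_QUANT_SIZES` maps each quantization type to its block geometry, and the `shape` computation simply inverts the byte packing. A small standalone sketch of that arithmetic (not part of the commit; the Q4_0 figures come from the gguf package):

```python
import gguf

# Each GGML quant type packs `block_size` weights into `type_size` bytes.
# For Q4_0 this is 32 weights in 18 bytes (a 2-byte fp16 scale + 16 bytes of 4-bit values).
block_size, type_size = gguf.GGML_QUANT_SIZES[gguf.GGMLQuantizationType.Q4_0]
assert (block_size, type_size) == (32, 18)

# A GGUF-quantized linear weight is stored as raw bytes with shape
# (out_features, in_features // block_size * type_size). The fallback path
# recovers the dequantized shape exactly as _fused_mul_mat_gguf does:
out_features, in_features = 512, 512
qweight_cols = in_features // block_size * type_size           # 288 packed bytes per row
dequant_shape = (out_features, qweight_cols // type_size * block_size)
assert dequant_shape == (out_features, in_features)             # back to (512, 512)
```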

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@
     is_invisible_watermark_available,
     is_k_diffusion_available,
     is_k_diffusion_version,
+    is_kernels_available,
     is_librosa_available,
     is_matplotlib_available,
     is_nltk_available,

src/diffusers/utils/import_utils.py

Lines changed: 5 additions & 0 deletions
@@ -192,6 +192,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
 _torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
 _transformers_available, _transformers_version = _is_package_available("transformers")
 _hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
+_kernels_available, _kernels_version = _is_package_available("kernels")
 _inflect_available, _inflect_version = _is_package_available("inflect")
 _unidecode_available, _unidecode_version = _is_package_available("unidecode")
 _k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
@@ -277,6 +278,10 @@ def is_accelerate_available():
     return _accelerate_available
 
 
+def is_kernels_available():
+    return _kernels_available
+
+
 def is_k_diffusion_available():
     return _k_diffusion_available

src/diffusers/utils/testing_utils.py

Lines changed: 13 additions & 0 deletions
@@ -36,6 +36,7 @@
     is_compel_available,
     is_flax_available,
     is_gguf_available,
+    is_kernels_available,
     is_note_seq_available,
     is_onnx_available,
     is_opencv_available,
@@ -634,6 +635,18 @@ def decorator(test_case):
     return decorator
 
 
+def require_kernels_version_greater_or_equal(kernels_version):
+    def decorator(test_case):
+        correct_kernels_version = is_kernels_available() and version.parse(
+            version.parse(importlib.metadata.version("kernels")).base_version
+        ) >= version.parse(kernels_version)
+        return unittest.skipUnless(
+            correct_kernels_version, f"Test requires kernels with version greater than {kernels_version}."
+        )(test_case)
+
+    return decorator
+
+
 def deprecate_after_peft_backend(test_case):
     """
     Decorator marking a test that will be skipped after PEFT backend

tests/quantization/gguf/test_gguf.py

Lines changed: 57 additions & 0 deletions
@@ -30,8 +30,10 @@
     nightly,
     numpy_cosine_similarity_distance,
     require_accelerate,
+    require_accelerator,
     require_big_accelerator,
     require_gguf_version_greater_or_equal,
+    require_kernels_version_greater_or_equal,
     require_peft_backend,
     require_torch_version_greater,
     torch_device,
@@ -41,11 +43,66 @@
 
 
 if is_gguf_available():
+    import gguf
+
     from diffusers.quantizers.gguf.utils import GGUFLinear, GGUFParameter
 
 enable_full_determinism()
 
 
+@nightly
+@require_accelerate
+@require_accelerator
+@require_gguf_version_greater_or_equal("0.10.0")
+@require_kernels_version_greater_or_equal("0.9.0")
+class GGUFCudaKernelsTests(unittest.TestCase):
+    def setUp(self):
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def tearDown(self):
+        gc.collect()
+        backend_empty_cache(torch_device)
+
+    def test_cuda_kernels_vs_native(self):
+        if torch_device != "cuda":
+            self.skipTest("CUDA kernels test requires CUDA device")
+
+        from diffusers.quantizers.gguf.utils import GGUFLinear, can_use_cuda_kernels
+
+        if not can_use_cuda_kernels:
+            self.skipTest("CUDA kernels not available (compute capability < 7 or kernels not installed)")
+
+        test_quant_types = ["Q4_0", "Q4_K"]
+        test_shape = (1, 64, 512)  # batch, seq_len, hidden_dim
+        compute_dtype = torch.bfloat16
+
+        for quant_type in test_quant_types:
+            qtype = getattr(gguf.GGMLQuantizationType, quant_type)
+            in_features, out_features = 512, 512
+
+            torch.manual_seed(42)
+            float_weight = torch.randn(out_features, in_features, dtype=torch.float32)
+            quantized_data = gguf.quants.quantize(float_weight.numpy(), qtype)
+            weight_data = torch.from_numpy(quantized_data).to(device=torch_device)
+            weight = GGUFParameter(weight_data, quant_type=qtype)
+
+            x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device)
+
+            linear = GGUFLinear(in_features, out_features, bias=True, compute_dtype=compute_dtype)
+            linear.weight = weight
+            linear.bias = nn.Parameter(torch.randn(out_features, dtype=compute_dtype))
+            linear = linear.to(torch_device)
+
+            with torch.no_grad():
+                output_native = linear.forward_native(x)
+                output_cuda = linear.forward_cuda(x)
+
+            assert torch.allclose(output_native, output_cuda, 1e-2), (
+                f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
+            )
+
+
 @nightly
 @require_big_accelerator
 @require_accelerate