From b5bf4cf1ca727331bd6f48a075ab6091e6e1aca1 Mon Sep 17 00:00:00 2001 From: DerekLiu35 Date: Thu, 17 Jul 2025 23:07:57 +0200 Subject: [PATCH 1/2] init --- src/diffusers/loaders/single_file_model.py | 29 ++++++++++++- src/diffusers/quantizers/auto.py | 4 ++ .../quantizers/quantization_config.py | 10 +++++ src/diffusers/quantizers/svdquant/__init__.py | 1 + .../quantizers/svdquant/svdquant_quantizer.py | 43 +++++++++++++++++++ src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/import_utils.py | 10 +++++ 7 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/quantizers/svdquant/__init__.py create mode 100644 src/diffusers/quantizers/svdquant/svdquant_quantizer.py diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 76fefc1260d0..8437eb540127 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -23,7 +23,8 @@ from .. import __version__ from ..quantizers import DiffusersAutoQuantizer -from ..utils import deprecate, is_accelerate_available, logging +from ..quantizers.quantization_config import QuantizationMethod +from ..utils import deprecate, is_accelerate_available, is_nunchaku_available, logging from ..utils.torch_utils import empty_device_cache from .single_file_utils import ( SingleFileComponentError, @@ -243,6 +244,32 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = >>> model = StableCascadeUNet.from_single_file(ckpt_path) ``` """ + quantization_config = kwargs.get("quantization_config") + if quantization_config is not None and quantization_config.quant_method == QuantizationMethod.SVDQUANT: + if not is_nunchaku_available(): + raise ImportError("Loading SVDQuant models requires the `nunchaku` package. Please install it.") + + if isinstance(pretrained_model_link_or_path_or_dict, dict): + raise ValueError( + "Loading a nunchaku model from a state_dict is not supported directly via from_single_file. Please provide a path." + ) + + if "FluxTransformer2DModel" in cls.__name__: + from nunchaku import NunchakuFluxTransformer2dModel + + kwargs.pop("quantization_config", None) + return NunchakuFluxTransformer2dModel.from_pretrained( + pretrained_model_link_or_path_or_dict, **kwargs + ) + elif "SanaTransformer2DModel" in cls.__name__: + from nunchaku import NunchakuSanaTransformer2DModel + + kwargs.pop("quantization_config", None) + return NunchakuSanaTransformer2DModel.from_pretrained( + pretrained_model_link_or_path_or_dict, **kwargs + ) + else: + raise NotImplementedError(f"SVDQuant loading is not implemented for {cls.__name__}") mapping_class_name = _get_single_file_loadable_mapping_class(cls) # if class_name not in SINGLE_FILE_LOADABLE_CLASSES: diff --git a/src/diffusers/quantizers/auto.py b/src/diffusers/quantizers/auto.py index ce214ae7bc17..c293abe0782d 100644 --- a/src/diffusers/quantizers/auto.py +++ b/src/diffusers/quantizers/auto.py @@ -27,9 +27,11 @@ QuantizationConfigMixin, QuantizationMethod, QuantoConfig, + SVDQuantConfig, TorchAoConfig, ) from .quanto import QuantoQuantizer +from .svdquant import SVDQuantizer from .torchao import TorchAoHfQuantizer @@ -39,6 +41,7 @@ "gguf": GGUFQuantizer, "quanto": QuantoQuantizer, "torchao": TorchAoHfQuantizer, + "svdquant": SVDQuantizer, } AUTO_QUANTIZATION_CONFIG_MAPPING = { @@ -47,6 +50,7 @@ "gguf": GGUFQuantizationConfig, "quanto": QuantoConfig, "torchao": TorchAoConfig, + "svdquant": SVDQuantConfig, } diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 871faf076e5a..63e6512e7515 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -46,6 +46,7 @@ class QuantizationMethod(str, Enum): GGUF = "gguf" TORCHAO = "torchao" QUANTO = "quanto" + SVDQUANT = "svdquant" if is_torchao_available(): @@ -724,3 +725,12 @@ def post_init(self): accepted_weights = ["float8", "int8", "int4", "int2"] if self.weights_dtype not in accepted_weights: raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}") + +@dataclass +class SVDQuantConfig(QuantizationConfigMixin): + """Config for SVDQuant models. This is a placeholder for loading pre-quantized nunchaku models.""" + + def __init__(self, **kwargs): + self.quant_method = QuantizationMethod.SVDQUANT + for key, value in kwargs.items(): + setattr(self, key, value) diff --git a/src/diffusers/quantizers/svdquant/__init__.py b/src/diffusers/quantizers/svdquant/__init__.py new file mode 100644 index 000000000000..f07a6438f8e3 --- /dev/null +++ b/src/diffusers/quantizers/svdquant/__init__.py @@ -0,0 +1 @@ +from .svdquant_quantizer import SVDQuantizer \ No newline at end of file diff --git a/src/diffusers/quantizers/svdquant/svdquant_quantizer.py b/src/diffusers/quantizers/svdquant/svdquant_quantizer.py new file mode 100644 index 000000000000..9cae31bff0ac --- /dev/null +++ b/src/diffusers/quantizers/svdquant/svdquant_quantizer.py @@ -0,0 +1,43 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..base import DiffusersQuantizer + + +class SVDQuantizer(DiffusersQuantizer): + """ + SVDQuantizer is a placeholder quantizer for loading pre-quantized SVDQuant models using the nunchaku library. + """ + + use_keep_in_fp32_modules = False + requires_calibration = False + + def __init__(self, quantization_config, **kwargs): + super().__init__(quantization_config, **kwargs) + + def _process_model_before_weight_loading(self, model, **kwargs): + # No-op, as the model is fully loaded by nunchaku. + return model + + def _process_model_after_weight_loading(self, model, **kwargs): + return model + + @property + def is_serializable(self): + # The model is serialized in its own format. + return True + + @property + def is_trainable(self): + return False \ No newline at end of file diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index cadcedb98a14..52a30ae087eb 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -85,6 +85,7 @@ is_matplotlib_available, is_nltk_available, is_note_seq_available, + is_nunchaku_available, is_onnx_available, is_opencv_available, is_optimum_quanto_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index a27c2da648f4..b3f8eb2c0bcd 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -223,6 +223,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b _sageattention_available, _sageattention_version = _is_package_available("sageattention") _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn") _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3") +_nunchaku_available, _nunchaku_version = _is_package_available("nunchaku") def is_torch_available(): @@ -393,6 +394,10 @@ def is_flash_attn_3_available(): return _flash_attn_3_available +def is_nunchaku_available(): + return _nunchaku_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the @@ -556,6 +561,10 @@ def is_flash_attn_3_available(): {0} requires the nltk library but it was not found in your environment. You can install it with pip: `pip install nltk` """ +NUNCHAKU_IMPORT_ERROR = """ +{0} requires the nunchaku library but it was not found in your environment. You can install it with pip: `pip install nunchaku` +""" + BACKENDS_MAPPING = OrderedDict( [ @@ -588,6 +597,7 @@ def is_flash_attn_3_available(): ("pytorch_retinaface", (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR)), ("better_profanity", (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR)), ("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)), + ("nunchaku", (is_nunchaku_available, NUNCHAKU_IMPORT_ERROR)), ] ) From 1a1857cbadf20e91d06cfce5314196b48a0f7a90 Mon Sep 17 00:00:00 2001 From: DerekLiu35 Date: Thu, 17 Jul 2025 23:41:11 +0200 Subject: [PATCH 2/2] fix style --- src/diffusers/loaders/single_file_model.py | 8 ++------ src/diffusers/quantizers/quantization_config.py | 1 + src/diffusers/quantizers/svdquant/__init__.py | 2 +- src/diffusers/quantizers/svdquant/svdquant_quantizer.py | 2 +- src/diffusers/utils/import_utils.py | 3 ++- 5 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py index 8437eb540127..0e90294e7a3a 100644 --- a/src/diffusers/loaders/single_file_model.py +++ b/src/diffusers/loaders/single_file_model.py @@ -258,16 +258,12 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = from nunchaku import NunchakuFluxTransformer2dModel kwargs.pop("quantization_config", None) - return NunchakuFluxTransformer2dModel.from_pretrained( - pretrained_model_link_or_path_or_dict, **kwargs - ) + return NunchakuFluxTransformer2dModel.from_pretrained(pretrained_model_link_or_path_or_dict, **kwargs) elif "SanaTransformer2DModel" in cls.__name__: from nunchaku import NunchakuSanaTransformer2DModel kwargs.pop("quantization_config", None) - return NunchakuSanaTransformer2DModel.from_pretrained( - pretrained_model_link_or_path_or_dict, **kwargs - ) + return NunchakuSanaTransformer2DModel.from_pretrained(pretrained_model_link_or_path_or_dict, **kwargs) else: raise NotImplementedError(f"SVDQuant loading is not implemented for {cls.__name__}") diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 63e6512e7515..77447afeaad9 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -726,6 +726,7 @@ def post_init(self): if self.weights_dtype not in accepted_weights: raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}") + @dataclass class SVDQuantConfig(QuantizationConfigMixin): """Config for SVDQuant models. This is a placeholder for loading pre-quantized nunchaku models.""" diff --git a/src/diffusers/quantizers/svdquant/__init__.py b/src/diffusers/quantizers/svdquant/__init__.py index f07a6438f8e3..9ac5838bfc0f 100644 --- a/src/diffusers/quantizers/svdquant/__init__.py +++ b/src/diffusers/quantizers/svdquant/__init__.py @@ -1 +1 @@ -from .svdquant_quantizer import SVDQuantizer \ No newline at end of file +from .svdquant_quantizer import SVDQuantizer diff --git a/src/diffusers/quantizers/svdquant/svdquant_quantizer.py b/src/diffusers/quantizers/svdquant/svdquant_quantizer.py index 9cae31bff0ac..f619e6cd8d6d 100644 --- a/src/diffusers/quantizers/svdquant/svdquant_quantizer.py +++ b/src/diffusers/quantizers/svdquant/svdquant_quantizer.py @@ -40,4 +40,4 @@ def is_serializable(self): @property def is_trainable(self): - return False \ No newline at end of file + return False diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index b3f8eb2c0bcd..43efa7be01eb 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -562,7 +562,8 @@ def is_nunchaku_available(): """ NUNCHAKU_IMPORT_ERROR = """ -{0} requires the nunchaku library but it was not found in your environment. You can install it with pip: `pip install nunchaku` +{0} requires the nunchaku library but it was not found in your environment. You can install it with pip: `pip install +nunchaku` """