Skip to content

[WIP] SVDQuant #11950

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion src/diffusers/loaders/single_file_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@

from .. import __version__
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
from ..quantizers.quantization_config import QuantizationMethod
from ..utils import deprecate, is_accelerate_available, is_nunchaku_available, logging
from ..utils.torch_utils import empty_device_cache
from .single_file_utils import (
SingleFileComponentError,
Expand Down Expand Up @@ -243,6 +244,28 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
>>> model = StableCascadeUNet.from_single_file(ckpt_path)
```
"""
# Early-exit path for SVDQuant: if the caller passed an SVDQuant config, the whole
# load is delegated to the external `nunchaku` library instead of the regular
# single-file pipeline below. NOTE(review): this fragment sits inside
# `from_single_file`; `cls`, `kwargs` and `pretrained_model_link_or_path_or_dict`
# come from the enclosing classmethod.
quantization_config = kwargs.get("quantization_config")
if quantization_config is not None and quantization_config.quant_method == QuantizationMethod.SVDQUANT:
    # nunchaku is an optional dependency; fail fast with a clear message.
    if not is_nunchaku_available():
        raise ImportError("Loading SVDQuant models requires the `nunchaku` package. Please install it.")

    # nunchaku's `from_pretrained` needs a path/repo id, not an in-memory state dict.
    if isinstance(pretrained_model_link_or_path_or_dict, dict):
        raise ValueError(
            "Loading a nunchaku model from a state_dict is not supported directly via from_single_file. Please provide a path."
        )

    # NOTE(review): substring match on the class name also catches subclasses whose
    # names embed "FluxTransformer2DModel" — confirm this is intentional vs. `==`.
    if "FluxTransformer2DModel" in cls.__name__:
        from nunchaku import NunchakuFluxTransformer2dModel

        # Drop the config before forwarding: nunchaku's loader does not accept it.
        kwargs.pop("quantization_config", None)
        return NunchakuFluxTransformer2dModel.from_pretrained(pretrained_model_link_or_path_or_dict, **kwargs)
    elif "SanaTransformer2DModel" in cls.__name__:
        from nunchaku import NunchakuSanaTransformer2DModel

        kwargs.pop("quantization_config", None)
        return NunchakuSanaTransformer2DModel.from_pretrained(pretrained_model_link_or_path_or_dict, **kwargs)
    else:
        # Only Flux and Sana transformers have nunchaku counterparts so far.
        raise NotImplementedError(f"SVDQuant loading is not implemented for {cls.__name__}")

mapping_class_name = _get_single_file_loadable_mapping_class(cls)
# if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/quantizers/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,11 @@
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
SVDQuantConfig,
TorchAoConfig,
)
from .quanto import QuantoQuantizer
from .svdquant import SVDQuantizer
from .torchao import TorchAoHfQuantizer


Expand All @@ -39,6 +41,7 @@
"gguf": GGUFQuantizer,
"quanto": QuantoQuantizer,
"torchao": TorchAoHfQuantizer,
"svdquant": SVDQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
Expand All @@ -47,6 +50,7 @@
"gguf": GGUFQuantizationConfig,
"quanto": QuantoConfig,
"torchao": TorchAoConfig,
"svdquant": SVDQuantConfig,
}


Expand Down
11 changes: 11 additions & 0 deletions src/diffusers/quantizers/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class QuantizationMethod(str, Enum):
GGUF = "gguf"
TORCHAO = "torchao"
QUANTO = "quanto"
SVDQUANT = "svdquant"


if is_torchao_available():
Expand Down Expand Up @@ -724,3 +725,13 @@ def post_init(self):
accepted_weights = ["float8", "int8", "int4", "int2"]
if self.weights_dtype not in accepted_weights:
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")


@dataclass
class SVDQuantConfig(QuantizationConfigMixin):
    """
    Placeholder configuration for pre-quantized SVDQuant (nunchaku) checkpoints.

    The config only tags itself with `QuantizationMethod.SVDQUANT` so the loading
    machinery can recognize it; any extra keyword arguments are stored verbatim as
    instance attributes.
    """

    def __init__(self, **kwargs):
        # Marker consumed by the dispatch in `from_single_file` / the auto-quantizer.
        self.quant_method = QuantizationMethod.SVDQUANT
        # Preserve arbitrary extra options as plain attributes.
        for name in kwargs:
            setattr(self, name, kwargs[name])
1 change: 1 addition & 0 deletions src/diffusers/quantizers/svdquant/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .svdquant_quantizer import SVDQuantizer
43 changes: 43 additions & 0 deletions src/diffusers/quantizers/svdquant/svdquant_quantizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..base import DiffusersQuantizer


class SVDQuantizer(DiffusersQuantizer):
    """
    Pass-through quantizer for pre-quantized SVDQuant models.

    Actual loading of SVDQuant checkpoints is delegated to the `nunchaku` library, so
    this quantizer performs no weight transformation of its own: both model-processing
    hooks are no-ops that return the model unchanged.
    """

    # Nothing is ever cast, so there is no fp32-module bookkeeping to do.
    use_keep_in_fp32_modules = False
    # Checkpoints arrive already quantized; no calibration pass is needed.
    requires_calibration = False

    # NOTE(review): the original `__init__` only forwarded its arguments to
    # `super().__init__(quantization_config, **kwargs)` — identical to the inherited
    # constructor — so the redundant override was removed.

    def _process_model_before_weight_loading(self, model, **kwargs):
        # No-op: nunchaku materializes the fully quantized model itself.
        return model

    def _process_model_after_weight_loading(self, model, **kwargs):
        # No-op: nothing to finalize once weights are in place.
        return model

    @property
    def is_serializable(self):
        # The model is serialized in nunchaku's own format.
        return True

    @property
    def is_trainable(self):
        # Pre-quantized nunchaku weights are inference-only.
        return False
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
is_matplotlib_available,
is_nltk_available,
is_note_seq_available,
is_nunchaku_available,
is_onnx_available,
is_opencv_available,
is_optimum_quanto_available,
Expand Down
11 changes: 11 additions & 0 deletions src/diffusers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
_sageattention_available, _sageattention_version = _is_package_available("sageattention")
_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
_nunchaku_available, _nunchaku_version = _is_package_available("nunchaku")


def is_torch_available():
Expand Down Expand Up @@ -393,6 +394,10 @@ def is_flash_attn_3_available():
return _flash_attn_3_available


def is_nunchaku_available():
    """Return whether the `nunchaku` package is importable in the current environment."""
    return _nunchaku_available


# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
Expand Down Expand Up @@ -556,6 +561,11 @@ def is_flash_attn_3_available():
{0} requires the nltk library but it was not found in your environment. You can install it with pip: `pip install nltk`
"""

# docstyle-ignore
NUNCHAKU_IMPORT_ERROR = """
{0} requires the nunchaku library but it was not found in your environment. You can install it with pip: `pip install
nunchaku`
"""


BACKENDS_MAPPING = OrderedDict(
[
Expand Down Expand Up @@ -588,6 +598,7 @@ def is_flash_attn_3_available():
("pytorch_retinaface", (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR)),
("better_profanity", (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR)),
("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)),
("nunchaku", (is_nunchaku_available, NUNCHAKU_IMPORT_ERROR)),
]
)

Expand Down
Loading