From ad74d321ac0483158bcd8ad873b4e079e6cb9b5b Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Mon, 28 Jul 2025 20:21:59 +0000 Subject: [PATCH 01/11] Update `apply_quantiation_config` to use `match_named_modules` Signed-off-by: Fynn Schmitt-Ulms --- .../quantization/lifecycle/apply.py | 70 ++++++++----------- src/compressed_tensors/utils/match.py | 58 +++++++++++++-- 2 files changed, 82 insertions(+), 46 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 7afd2aba..431fadd6 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -40,6 +40,7 @@ is_kv_cache_quant_scheme, ) from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module +from compressed_tensors.utils.match import match_named_modules from compressed_tensors.utils.offload import update_parameter_data from compressed_tensors.utils.safetensors_load import get_safetensors_folder from safetensors import safe_open @@ -147,47 +148,35 @@ def apply_quantization_config( if run_compressed: from compressed_tensors.linear.compressed_linear import CompressedLinear - # list of submodules to ignore - ignored_submodules = defaultdict(list) # mark appropriate layers for quantization by setting their quantization schemes - for name, submodule in model.named_modules(): - # potentially fix module name to remove FSDP wrapper prefix - name = fix_fsdp_module_name(name) - if matches := find_name_or_class_matches(name, submodule, config.ignore): - for match in matches: - ignored_submodules[match].append(name) - continue # layer matches ignore list, continue - - targets = find_name_or_class_matches(name, submodule, target_to_scheme) - - if targets: - # mark modules to be quantized by adding - # quant scheme to the matching layers - scheme = _scheme_from_targets(target_to_scheme, targets, name) - if run_compressed: - format = config.format - if format != CompressionFormat.dense.value: - if isinstance(submodule, torch.nn.Linear): - # TODO: expand to more module types - compressed_linear = CompressedLinear.from_linear( - submodule, - quantization_scheme=scheme, - quantization_format=format, - ) - replace_module(model, name, compressed_linear) - - # target matched - add layer and scheme to target list - submodule.quantization_scheme = scheme - - names_to_scheme[name] = submodule.quantization_scheme - - if config.ignore is not None and ignored_submodules is not None: - if set(config.ignore) - set(ignored_submodules): - _LOGGER.warning( - "Some layers that were to be ignored were " - "not found in the model: " - f"{set(config.ignore) - set(ignored_submodules)}" - ) + for name, submodule, matched_targets in match_named_modules( + model, + target_to_scheme, + config.ignore or [], + warn_on_fail=True, + warn_on_unmatched_ignores=True, + return_matched_targets=True, + preprocess_name=fix_fsdp_module_name, + ): + # mark modules to be quantized by adding + # quant scheme to the matching layers + scheme = _scheme_from_targets(target_to_scheme, matched_targets, name) + if run_compressed: + format = config.format + if format != CompressionFormat.dense.value: + if isinstance(submodule, torch.nn.Linear): + # TODO: expand to more module types + compressed_linear = CompressedLinear.from_linear( + submodule, + quantization_scheme=scheme, + quantization_format=format, + ) + replace_module(model, name, compressed_linear) + + # target matched - add layer and scheme to target list + submodule.quantization_scheme = scheme + + names_to_scheme[name] = submodule.quantization_scheme # apply current quantization status across all targeted layers apply_quantization_status(model, config.quantization_status) @@ -429,7 +418,6 @@ def _scheme_from_targets( def _merge_schemes( schemes_to_merge: List[QuantizationScheme], name: str ) -> QuantizationScheme: - kv_cache_quantization_scheme = [ scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme) ] diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index 21ce4a0b..9ccade77 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -15,7 +15,7 @@ import logging import re from collections.abc import Generator -from typing import Iterable, Tuple +from typing import Callable, Iterable, Tuple import torch from compressed_tensors.utils.internal import InternalModule @@ -35,8 +35,11 @@ def match_named_modules( model: torch.nn.Module, targets: Iterable[str], - ignore: Iterable[str] = tuple(), + ignore: Iterable[str] | None = tuple(), warn_on_fail: bool = False, + warn_on_unmatched_ignores: bool = False, + return_matched_targets: bool = False, + preprocess_name: Callable[[str], str] = lambda x: x, ) -> Generator[Tuple[str, torch.nn.Module]]: """ Yields names and modules which match `targets` but do not match `ignore`. @@ -48,14 +51,53 @@ def match_named_modules( :param warn_on_fail: if True, warns if any targets do not match any modules in model :return: generator of module names and modules """ + ignore = ignore or [] + unmatched_targets = set(targets) + unmatched_ignores = set(ignore) + + # Order targets by type: exact name match, regex name match, class name match + targets = sorted(targets, key=lambda x: ("re:" in x, x)) for name, module in model.named_modules(): + # preprocess the module name and module + name = preprocess_name(name) + + ignore_matched = False + for ign in ignore: + if is_match(name, module, ign): + unmatched_ignores -= {ign} + ignore_matched = True + break + if ignore_matched: + continue + + matched_targets = [] + # Check for name matches first (exact then regex) for target in targets: - if is_match(name, module, target): + if match_name(name, target): unmatched_targets -= {target} + matched_targets.append(target) + if not return_matched_targets: + break - if not any(is_match(name, module, ign) for ign in ignore): - yield name, module + if not return_matched_targets and matched_targets: + # Don't need to check other targets, one match is enough + yield name, module + continue + + # Check for class matches + for target in targets: + if match_class(module, target): + unmatched_targets -= {target} + matched_targets.append(target) + if not return_matched_targets: + break + + if matched_targets: + if return_matched_targets: + yield name, module, matched_targets + else: + yield name, module if warn_on_fail: for target in unmatched_targets: @@ -63,6 +105,12 @@ def match_named_modules( f"Could not match `{target}` in instance of {model.__class__.__name__}" ) + if warn_on_unmatched_ignores: + for ign in unmatched_ignores: + _LOGGER.warning( + f"Unmatched ignore targets: {unmatched_ignores}, in instance of {model.__class__.__name__}" + ) + def match_named_parameters( model: torch.nn.Module, From 72a6c65276cae2cf8e7867c95cb47f2d1d19ad55 Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Mon, 28 Jul 2025 20:21:59 +0000 Subject: [PATCH 02/11] Refactor usages of `expand_target_names`, `is_target`, and `find_name_or_class_matches` Signed-off-by: Fynn Schmitt-Ulms --- .../model_compressors/model_compressor.py | 239 ++++++++++-------- .../quantization/lifecycle/apply.py | 104 +------- src/compressed_tensors/utils/match.py | 12 +- .../test_quantization/lifecycle/test_apply.py | 78 +----- tests/test_utils/test_match.py | 61 +++++ 5 files changed, 205 insertions(+), 289 deletions(-) diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py index f567e5a6..963cde2d 100644 --- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py +++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py @@ -41,8 +41,8 @@ apply_quantization_config, load_pretrained_quantization_parameters, ) -from compressed_tensors.quantization.lifecycle import expand_target_names from compressed_tensors.quantization.utils import is_module_quantized +from compressed_tensors.utils.match import match_named_modules from compressed_tensors.utils import ( align_module_device, delete_offload_parameter, @@ -292,13 +292,15 @@ def get_missing_module_keys(self, model: Module) -> List[str]: self.sparsity_compressor and self.sparsity_config.format != CompressionFormat.dense.value ): - sparse_targets = expand_target_names( + sparse_targets = match_named_modules( model=model, targets=self.sparsity_config.targets, ignore=self.sparsity_config.ignore, ) + missing_keys.update( - merge_names(target, "weight") for target in sparse_targets + merge_names(target_name, "weight") + for target_name, _module in sparse_targets ) # Determine missing keys due to pack quantization @@ -308,13 +310,14 @@ def get_missing_module_keys(self, model: Module) -> List[str]: == CompressionFormat.pack_quantized.value ): for scheme in self.quantization_config.config_groups.values(): - quant_targets = expand_target_names( + quant_targets = match_named_modules( model=model, targets=scheme.targets, ignore=self.quantization_config.ignore, ) missing_keys.update( - merge_names(target, "weight") for target in quant_targets + merge_names(target_name, "weight") + for target_name, _module in quant_targets ) return list(missing_keys) @@ -345,28 +348,28 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]: self.sparsity_compressor and self.sparsity_config.format != CompressionFormat.dense.value ): - sparse_targets: Set[str] = expand_target_names( + sparse_targets = match_named_modules( model=model, targets=self.sparsity_config.targets, ignore=self.sparsity_config.ignore, ) unexpected_keys.update( - merge_names(target, param) - for target in sparse_targets + merge_names(target_name, param) + for target_name, _module in sparse_targets for param in self.sparsity_compressor.compression_param_names ) # Identify unexpected keys from quantization compression if self.quantization_compressor: for scheme in self.quantization_config.config_groups.values(): - quant_targets: Set[str] = expand_target_names( + quant_targets = match_named_modules( model=model, targets=scheme.targets, ignore=self.quantization_config.ignore, ) unexpected_keys.update( - merge_names(target, param) - for target in quant_targets + merge_names(target_name, param) + for target_name, _module in quant_targets for param in self.quantization_compressor.compression_param_names if param != "weight" ) @@ -383,58 +386,65 @@ def compress_model(self, model: Module): :param model: model containing parameters to compress """ module_to_scheme = map_module_to_scheme(model) - sparse_compression_targets: Set[str] = expand_target_names( - model=model, - targets=self.sparsity_config.targets if self.sparsity_config else [], - ignore=self.sparsity_config.ignore if self.sparsity_config else [], - ) - - for prefix, module in tqdm(model.named_modules(), desc="Compressing model"): - - if prefix in module_to_scheme or prefix in sparse_compression_targets: - module_device = get_execution_device(module) - is_meta = module_device.type == "meta" - - exec_device = "meta" if is_meta else "cpu" - onloading_device = "meta" if is_meta else module_device - - # in the future, support compression on same device - with align_module_device(module, execution_device=exec_device): - state_dict = { - f"{prefix}.{name}": param - for name, param in module.named_parameters(recurse=False) - } - - # quantization first - if prefix in module_to_scheme: - state_dict = self.quantization_compressor.compress( - state_dict, - names_to_scheme=module_to_scheme, - show_progress=False, - compression_device=exec_device, - ) + sparse_compression_targets = [ + module_name + for module_name, _module in match_named_modules( + model=model, + targets=self.sparsity_config.targets if self.sparsity_config else [], + ignore=self.sparsity_config.ignore if self.sparsity_config else [], + ) + ] + for prefix, module in tqdm( + match_named_modules( + model, + [*sparse_compression_targets, *module_to_scheme.keys()], + warn_on_fail=True, + ), + desc="Compressing model", + ): + module_device = get_execution_device(module) + is_meta = module_device.type == "meta" + + exec_device = "meta" if is_meta else "cpu" + onloading_device = "meta" if is_meta else module_device + + # in the future, support compression on same device + with align_module_device(module, execution_device=exec_device): + state_dict = { + f"{prefix}.{name}": param + for name, param in module.named_parameters(recurse=False) + } + + # quantization first + if prefix in module_to_scheme: + state_dict = self.quantization_compressor.compress( + state_dict, + names_to_scheme=module_to_scheme, + show_progress=False, + compression_device=exec_device, + ) - # sparsity second - if prefix in sparse_compression_targets: - state_dict = self.sparsity_compressor.compress( - state_dict, - compression_targets=sparse_compression_targets, - show_progress=False, - ) + # sparsity second + if prefix in sparse_compression_targets: + state_dict = self.sparsity_compressor.compress( + state_dict, + compression_targets=sparse_compression_targets, + show_progress=False, + ) - # remove any existing parameters - offload_device = get_offloaded_device(module) - for name, _ in list(module.named_parameters(recurse=False)): - delete_offload_parameter(module, name) + # remove any existing parameters + offload_device = get_offloaded_device(module) + for name, _ in list(module.named_parameters(recurse=False)): + delete_offload_parameter(module, name) - # replace with compressed parameters - for name, value in state_dict.items(): - name = name.removeprefix(f"{prefix}.") - value = value.to(onloading_device) - param = torch.nn.Parameter(value, requires_grad=False) - register_offload_parameter(module, name, param, offload_device) + # replace with compressed parameters + for name, value in state_dict.items(): + name = name.removeprefix(f"{prefix}.") + value = value.to(onloading_device) + param = torch.nn.Parameter(value, requires_grad=False) + register_offload_parameter(module, name, param, offload_device) - module.quantization_status = QuantizationStatus.COMPRESSED + module.quantization_status = QuantizationStatus.COMPRESSED # TODO: consider sparse compression to also be compression if ( @@ -451,55 +461,64 @@ def decompress_model(self, model: Module): :param model: model containing parameters to compress """ module_to_scheme = map_module_to_scheme(model) - sparse_compression_targets: Set[str] = expand_target_names( - model=model, - targets=self.sparsity_config.targets if self.sparsity_config else [], - ignore=self.sparsity_config.ignore if self.sparsity_config else [], - ) - - for prefix, module in tqdm(model.named_modules(), desc="Decompressing model"): - if prefix in module_to_scheme or prefix in sparse_compression_targets: - # in the future, support decompression on same device - with align_module_device(module, execution_device="cpu"): - state_dict = { - f"{prefix}.{name}": param - for name, param in module.named_parameters(recurse=False) - } - - # sparsity first - if prefix in sparse_compression_targets: - # sparse_compression_targets are automatically inferred by this fn - generator = self.sparsity_compressor.decompress_from_state_dict( + sparse_compression_targets = [ + module_name + for module_name, _module in match_named_modules( + model=model, + targets=self.sparsity_config.targets if self.sparsity_config else [], + ignore=self.sparsity_config.ignore if self.sparsity_config else [], + ) + ] + + for prefix, module in tqdm( + match_named_modules( + model, + [*sparse_compression_targets, *module_to_scheme.keys()], + warn_on_fail=True, + ), + desc="Decompressing model", + ): + # in the future, support decompression on same device + with align_module_device(module, execution_device="cpu"): + state_dict = { + f"{prefix}.{name}": param + for name, param in module.named_parameters(recurse=False) + } + + # sparsity first + if prefix in sparse_compression_targets: + # sparse_compression_targets are automatically inferred by this fn + generator = self.sparsity_compressor.decompress_from_state_dict( + state_dict, + ) + # generates (param_path, param_val) + # of compressed and unused params + state_dict = {key: value for key, value in generator} + + # quantization second + if prefix in module_to_scheme: + state_dict = ( + self.quantization_compressor.decompress_module_from_state_dict( + prefix, state_dict, + scheme=module_to_scheme[prefix], ) - # generates (param_path, param_val) - # of compressed and unused params - state_dict = {key: value for key, value in generator} - - # quantization second - if prefix in module_to_scheme: - state_dict = ( - self.quantization_compressor.decompress_module_from_state_dict( - prefix, - state_dict, - scheme=module_to_scheme[prefix], - ) - ) + ) - # remove any existing parameters - exec_device = get_execution_device(module) - offload_device = get_offloaded_device(module) - for name, _ in list(module.named_parameters(recurse=False)): - delete_offload_parameter(module, name) + # remove any existing parameters + exec_device = get_execution_device(module) + offload_device = get_offloaded_device(module) + for name, _ in list(module.named_parameters(recurse=False)): + delete_offload_parameter(module, name) - # replace with decompressed parameters - for name, value in state_dict.items(): - name = name.removeprefix(f"{prefix}.") - value = value.to(exec_device) - param = torch.nn.Parameter(value, requires_grad=False) - register_offload_parameter(module, name, param, offload_device) + # replace with decompressed parameters + for name, value in state_dict.items(): + name = name.removeprefix(f"{prefix}.") + value = value.to(exec_device) + param = torch.nn.Parameter(value, requires_grad=False) + register_offload_parameter(module, name, param, offload_device) - module.quantization_status = QuantizationStatus.FROZEN + module.quantization_status = QuantizationStatus.FROZEN # ----- state dict compression pathways ----- # @@ -535,11 +554,14 @@ def compress( ) if self.sparsity_compressor is not None: - sparse_compression_targets: Set[str] = expand_target_names( - model=model, - targets=self.sparsity_config.targets, - ignore=self.sparsity_config.ignore, - ) + sparse_compression_targets: Set[str] = { + module_name + for module_name, _module in match_named_modules( + model=model, + targets=self.sparsity_config.targets, + ignore=self.sparsity_config.ignore, + ) + } state_dict = self.sparsity_compressor.compress( state_dict, compression_targets=sparse_compression_targets, @@ -598,7 +620,6 @@ def decompress(self, model_path: str, model: Module): with override_quantization_status( self.quantization_config, QuantizationStatus.FROZEN ): - names_to_scheme = apply_quantization_config( model, self.quantization_config ) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 431fadd6..5b88bb3c 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -13,12 +13,11 @@ # limitations under the License. import logging -import re -from collections import OrderedDict, defaultdict +from collections import OrderedDict from copy import deepcopy -from typing import Dict, Iterable, List, Optional +from typing import Dict, List, Optional from typing import OrderedDict as OrderedDictType -from typing import Set, Union +from typing import Union import torch from compressed_tensors.config import CompressionFormat @@ -51,9 +50,6 @@ "load_pretrained_quantization_parameters", "apply_quantization_config", "apply_quantization_status", - "find_name_or_class_matches", - "expand_target_names", - "is_target", ] from compressed_tensors.quantization.utils.helpers import is_module_quantized @@ -251,100 +247,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): model.apply(compress_quantized_weights) -def expand_target_names( - model: Module, - targets: Optional[Iterable[str]] = None, - ignore: Optional[Iterable[str]] = None, -) -> Set[str]: - """ - Finds all unique module names in the model that match the given - targets and ignore lists. - - Note: Targets must be regexes, layer types, or full layer names. - - :param model: model to search for targets in - :param targets: Iterable of targets to search for - :param ignore: Iterable of targets to ignore - :return: set of all targets that match the given targets and should - not be ignored - """ - return { - name - for name, module in model.named_modules() - if is_target(name, module, targets, ignore) - } - - -def is_target( - name: str, - module: Module, - targets: Optional[Iterable[str]] = None, - ignore: Optional[Iterable[str]] = None, -) -> bool: - """ - Determines if a module should be included in the targets based on the - targets and ignore lists. - - Note: Targets must be regexes, layer types, or full layer names. - - :param name: name of the module - :param module: the module itself - :param targets: Iterable of targets to search for - :param ignore: Iterable of targets to ignore - :return: True if the module is a target and not ignored, False otherwise - """ - return bool( - find_name_or_class_matches(name, module, targets or []) - and not find_name_or_class_matches(name, module, ignore or []) - ) - - -def find_name_or_class_matches( - name: str, module: Module, targets: Iterable[str], check_contains: bool = False -) -> List[str]: - """ - Returns all targets that match the given name or the class name. - Returns empty list otherwise. - The order of the output `matches` list matters. - The entries are sorted in the following order: - 1. matches on exact strings - 2. matches on regex patterns - 3. matches on module names - """ - from compressed_tensors import InternalModule - - if isinstance(module, InternalModule): - return [] - - targets = sorted(targets, key=lambda x: ("re:" in x, x)) - if isinstance(targets, Iterable): - matches = _find_matches(name, targets) + _find_matches( - module.__class__.__name__, targets, check_contains - ) - matches = [match for match in matches if match is not None] - return matches - - -def _find_matches( - value: str, targets: Iterable[str], check_contains: bool = False -) -> List[str]: - # returns all the targets that match value either - # exactly or as a regex after 're:'. if check_contains is set to True, - # additionally checks if the target string is contained with value. - matches = [] - for target in targets: - if target.startswith("re:"): - pattern = target[3:] - if re.match(pattern, value): - matches.append(target) - elif check_contains: - if target.lower() in value.lower(): - matches.append(target) - elif target == value: - matches.append(target) - return matches - - def _infer_status(model: Module) -> Optional[QuantizationStatus]: for module in model.modules(): status = getattr(module, "quantization_status", None) diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index 9ccade77..d092c23f 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -34,8 +34,8 @@ def match_named_modules( model: torch.nn.Module, - targets: Iterable[str], - ignore: Iterable[str] | None = tuple(), + targets: Iterable[str] | None, + ignore: Iterable[str] | None = None, warn_on_fail: bool = False, warn_on_unmatched_ignores: bool = False, return_matched_targets: bool = False, @@ -52,6 +52,7 @@ def match_named_modules( :return: generator of module names and modules """ ignore = ignore or [] + targets = targets or [] unmatched_targets = set(targets) unmatched_ignores = set(ignore) @@ -59,6 +60,9 @@ def match_named_modules( # Order targets by type: exact name match, regex name match, class name match targets = sorted(targets, key=lambda x: ("re:" in x, x)) for name, module in model.named_modules(): + if isinstance(module, InternalModule): + continue + # preprocess the module name and module name = preprocess_name(name) @@ -74,7 +78,7 @@ def match_named_modules( matched_targets = [] # Check for name matches first (exact then regex) for target in targets: - if match_name(name, target): + if _match_name(name, target): unmatched_targets -= {target} matched_targets.append(target) if not return_matched_targets: @@ -87,7 +91,7 @@ def match_named_modules( # Check for class matches for target in targets: - if match_class(module, target): + if _match_class(module, target): unmatched_targets -= {target} matched_targets.append(target) if not return_matched_targets: diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index 63a9a588..a6f1ee3e 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -28,8 +28,6 @@ from compressed_tensors.quantization.lifecycle import ( apply_quantization_config, apply_quantization_status, - expand_target_names, - is_target, ) from tests.testing_utils import requires_accelerate from transformers import AutoModelForCausalLM @@ -141,9 +139,9 @@ def test_apply_quantization_config_tinyllama(): _weights = not module_type == "LlamaRotaryEmbedding" _test_layer_quantization_status(module, inputs=_inputs, weights=_weights) - assert all( - value == 0 for value in count_layer_num.values() - ), "Not all values are zero" + assert all(value == 0 for value in count_layer_num.values()), ( + "Not all values are zero" + ) # test quantization compression # sample forward pass to fill scales, zps @@ -299,73 +297,3 @@ def test_apply_quantization_status(caplog, ignore, should_raise_warning): assert len(caplog.text) > 0 else: assert len(caplog.text) == 0 - - -@pytest.mark.parametrize( - "targets, ignore, expected_targets", - [ - ([], [], set()), - (["layer1", "layer2"], [], {"layer1", "layer2"}), - ([], ["layer1"], set()), - (["layer1", "layer2"], ["layer2"], {"layer1"}), - (["re:layer.*"], ["layer3"], {"layer1", "layer2"}), - ], -) -def test_expand_targets_with_mock(mock_model, targets, ignore, expected_targets): - expanded_targets = expand_target_names(mock_model, targets, ignore) - assert expanded_targets == expected_targets - - -@pytest.mark.parametrize( - "targets, ignore, expected_targets", - [ - ( - ["re:model.layers.[01].self_attn.q_proj"], - ["re:model.layers.1.self_attn.q_proj"], - set(["model.layers.0.self_attn.q_proj"]), - ), - ( - ["re:model.layers.[01].self_attn.q_proj"], - [], - set(["model.layers.0.self_attn.q_proj", "model.layers.1.self_attn.q_proj"]), - ), - ( - ["re:model.layers.[0-2].self_attn.q_proj"], - ["re:model.layers.1.self_attn.q_proj"], - set(["model.layers.0.self_attn.q_proj", "model.layers.2.self_attn.q_proj"]), - ), - ( - ["model.layers.0.self_attn.q_proj"], - ["model.layers.0.self_attn.q_proj"], - set(), - ), - ( - ["re:model.layers.*.self_attn.q_proj"], - ["re:model.layers.[01].self_attn.q_proj"], - set( - f"model.layers.{layer_idx}.self_attn.q_proj" - for layer_idx in range(2, 6) - ), - ), - ], -) -def test_expand_targets_with_llama_stories( - llama_stories_model, targets, ignore, expected_targets -): - expanded_targets = expand_target_names(llama_stories_model, targets, ignore) - assert expanded_targets == expected_targets - - -@pytest.mark.parametrize( - "name, targets, ignore, expected", - [ - ("layer1", ["layer1"], [], True), - ("layer1", ["layer1"], ["layer1"], False), - ("layer1", ["layer2"], [], False), - ("layer1", ["re:layer.*"], [], True), - ("layer1", ["re:layer.*"], ["re:layer1"], False), - ], -) -def test_is_target_with_mock(mock_module, name, targets, ignore, expected): - result = is_target(name, mock_module, targets, ignore) - assert result == expected diff --git a/tests/test_utils/test_match.py b/tests/test_utils/test_match.py index 705676b9..2b3396b6 100644 --- a/tests/test_utils/test_match.py +++ b/tests/test_utils/test_match.py @@ -27,6 +27,15 @@ match_named_parameters, ) from compressed_tensors.utils.match import _match_class, _match_name +from transformers import AutoModelForCausalLM + + +@pytest.fixture +def llama_stories_model(): + return AutoModelForCausalLM.from_pretrained( + "Xenova/llama2.c-stories15M", + torch_dtype="auto", + ) class DummyModel(nn.Module): @@ -255,6 +264,58 @@ class InternalLinear(InternalModule, nn.Linear): matches = list(match_named_modules(linear, ["re:.*"])) assert len(matches) == 0 + @pytest.mark.parametrize( + "targets, ignore, expected_targets", + [ + ( + ["re:model.layers.[01].self_attn.q_proj"], + ["re:model.layers.1.self_attn.q_proj"], + set(["model.layers.0.self_attn.q_proj"]), + ), + ( + ["re:model.layers.[01].self_attn.q_proj"], + [], + set( + [ + "model.layers.0.self_attn.q_proj", + "model.layers.1.self_attn.q_proj", + ] + ), + ), + ( + ["re:model.layers.[0-2].self_attn.q_proj"], + ["re:model.layers.1.self_attn.q_proj"], + set( + [ + "model.layers.0.self_attn.q_proj", + "model.layers.2.self_attn.q_proj", + ] + ), + ), + ( + ["model.layers.0.self_attn.q_proj"], + ["model.layers.0.self_attn.q_proj"], + set(), + ), + ( + ["re:model.layers.*.self_attn.q_proj"], + ["re:model.layers.[01].self_attn.q_proj"], + set( + f"model.layers.{layer_idx}.self_attn.q_proj" + for layer_idx in range(2, 6) + ), + ), + ], + ) + def test_expand_targets_with_llama_stories( + self, llama_stories_model, targets, ignore, expected_targets + ): + expanded_targets = { + name + for name, _ in match_named_modules(llama_stories_model, targets, ignore) + } + assert expanded_targets == expected_targets + class TestMatchNamedParameters: """Test cases for match_named_parameters function""" From 01e75b27bb2132b0f57766b3959fb0c3c538293a Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Tue, 29 Jul 2025 15:38:58 +0000 Subject: [PATCH 03/11] Small fixes Signed-off-by: Fynn Schmitt-Ulms --- .../quantization/lifecycle/apply.py | 2 +- src/compressed_tensors/utils/match.py | 36 +++++++++---------- 2 files changed, 18 insertions(+), 20 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 5b88bb3c..787de54a 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -151,7 +151,7 @@ def apply_quantization_config( config.ignore or [], warn_on_fail=True, warn_on_unmatched_ignores=True, - return_matched_targets=True, + yield_matched_targets=True, preprocess_name=fix_fsdp_module_name, ): # mark modules to be quantized by adding diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index d092c23f..894a45af 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -15,7 +15,7 @@ import logging import re from collections.abc import Generator -from typing import Callable, Iterable, Tuple +from typing import Callable, Iterable, List, Tuple import torch from compressed_tensors.utils.internal import InternalModule @@ -38,9 +38,9 @@ def match_named_modules( ignore: Iterable[str] | None = None, warn_on_fail: bool = False, warn_on_unmatched_ignores: bool = False, - return_matched_targets: bool = False, + yield_matched_targets: bool = False, preprocess_name: Callable[[str], str] = lambda x: x, -) -> Generator[Tuple[str, torch.nn.Module]]: +) -> Generator[Tuple[str, torch.nn.Module] | Tuple[str, torch.nn.Module, List[str]]]: """ Yields names and modules which match `targets` but do not match `ignore`. Values are returned in order of `model.named_modules()` @@ -49,6 +49,9 @@ def match_named_modules( :param targets: target strings, potentially containing "re:" prefixes :param ignore: targets to ignore, potentially containing "re:" prefixes :param warn_on_fail: if True, warns if any targets do not match any modules in model + :param warn_on_unmatched_ignores: if True, warns if any ignores do not match any modules in model + :param yield_matched_targets: if True, yields the matched targets in addition to the module name and module + :param preprocess_name: a function to preprocess the module name :return: generator of module names and modules """ ignore = ignore or [] @@ -57,6 +60,7 @@ def match_named_modules( unmatched_targets = set(targets) unmatched_ignores = set(ignore) + # Note: when yield_matched_targets is True, the ordering of the targets is important # Order targets by type: exact name match, regex name match, class name match targets = sorted(targets, key=lambda x: ("re:" in x, x)) for name, module in model.named_modules(): @@ -75,30 +79,24 @@ def match_named_modules( if ignore_matched: continue - matched_targets = [] - # Check for name matches first (exact then regex) + matched_target_on_name = [] + matched_target_on_class = [] + # Check for name matches first (exact then regex, enforced by sort above) for target in targets: if _match_name(name, target): unmatched_targets -= {target} - matched_targets.append(target) - if not return_matched_targets: + matched_target_on_name.append(target) + if not yield_matched_targets: break - - if not return_matched_targets and matched_targets: - # Don't need to check other targets, one match is enough - yield name, module - continue - - # Check for class matches - for target in targets: - if _match_class(module, target): + elif _match_class(module, target): unmatched_targets -= {target} - matched_targets.append(target) - if not return_matched_targets: + matched_target_on_class.append(target) + if not yield_matched_targets: break + matched_targets = matched_target_on_name + matched_target_on_class if matched_targets: - if return_matched_targets: + if yield_matched_targets: yield name, module, matched_targets else: yield name, module From d22149f8585543fa139deb762beebf6601ffbcbf Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Tue, 29 Jul 2025 18:06:38 +0000 Subject: [PATCH 04/11] Simplify signature of `match_named_modules` Removed `yield_matched_targets` and `warn_on_unmatched_ignores` and updated rest of code Signed-off-by: Fynn Schmitt-Ulms --- .../quantization/lifecycle/apply.py | 7 +-- src/compressed_tensors/utils/match.py | 59 +++++++------------ .../test_quantization/lifecycle/test_apply.py | 18 ++---- 3 files changed, 28 insertions(+), 56 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 787de54a..706c3ebe 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -39,7 +39,7 @@ is_kv_cache_quant_scheme, ) from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module -from compressed_tensors.utils.match import match_named_modules +from compressed_tensors.utils.match import is_match, match_named_modules, match_targets from compressed_tensors.utils.offload import update_parameter_data from compressed_tensors.utils.safetensors_load import get_safetensors_folder from safetensors import safe_open @@ -145,17 +145,16 @@ def apply_quantization_config( from compressed_tensors.linear.compressed_linear import CompressedLinear # mark appropriate layers for quantization by setting their quantization schemes - for name, submodule, matched_targets in match_named_modules( + for name, submodule in match_named_modules( model, target_to_scheme, config.ignore or [], warn_on_fail=True, - warn_on_unmatched_ignores=True, - yield_matched_targets=True, preprocess_name=fix_fsdp_module_name, ): # mark modules to be quantized by adding # quant scheme to the matching layers + matched_targets = list(match_targets(name, submodule, target_to_scheme)) scheme = _scheme_from_targets(target_to_scheme, matched_targets, name) if run_compressed: format = config.format diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index 894a45af..4c7d4d8d 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -37,8 +37,6 @@ def match_named_modules( targets: Iterable[str] | None, ignore: Iterable[str] | None = None, warn_on_fail: bool = False, - warn_on_unmatched_ignores: bool = False, - yield_matched_targets: bool = False, preprocess_name: Callable[[str], str] = lambda x: x, ) -> Generator[Tuple[str, torch.nn.Module] | Tuple[str, torch.nn.Module, List[str]]]: """ @@ -49,8 +47,6 @@ def match_named_modules( :param targets: target strings, potentially containing "re:" prefixes :param ignore: targets to ignore, potentially containing "re:" prefixes :param warn_on_fail: if True, warns if any targets do not match any modules in model - :param warn_on_unmatched_ignores: if True, warns if any ignores do not match any modules in model - :param yield_matched_targets: if True, yields the matched targets in addition to the module name and module :param preprocess_name: a function to preprocess the module name :return: generator of module names and modules """ @@ -58,11 +54,7 @@ def match_named_modules( targets = targets or [] unmatched_targets = set(targets) - unmatched_ignores = set(ignore) - # Note: when yield_matched_targets is True, the ordering of the targets is important - # Order targets by type: exact name match, regex name match, class name match - targets = sorted(targets, key=lambda x: ("re:" in x, x)) for name, module in model.named_modules(): if isinstance(module, InternalModule): continue @@ -70,36 +62,14 @@ def match_named_modules( # preprocess the module name and module name = preprocess_name(name) - ignore_matched = False - for ign in ignore: - if is_match(name, module, ign): - unmatched_ignores -= {ign} - ignore_matched = True - break - if ignore_matched: + if any(is_match(name, module, ign) for ign in ignore): continue - matched_target_on_name = [] - matched_target_on_class = [] - # Check for name matches first (exact then regex, enforced by sort above) for target in targets: - if _match_name(name, target): + if is_match(name, module, target): unmatched_targets -= {target} - matched_target_on_name.append(target) - if not yield_matched_targets: - break - elif _match_class(module, target): - unmatched_targets -= {target} - matched_target_on_class.append(target) - if not yield_matched_targets: - break - - matched_targets = matched_target_on_name + matched_target_on_class - if matched_targets: - if yield_matched_targets: - yield name, module, matched_targets - else: yield name, module + break if warn_on_fail: for target in unmatched_targets: @@ -107,12 +77,6 @@ def match_named_modules( f"Could not match `{target}` in instance of {model.__class__.__name__}" ) - if warn_on_unmatched_ignores: - for ign in unmatched_ignores: - _LOGGER.warning( - f"Unmatched ignore targets: {unmatched_ignores}, in instance of {model.__class__.__name__}" - ) - def match_named_parameters( model: torch.nn.Module, @@ -151,6 +115,23 @@ def match_named_parameters( ) +def match_targets( + name: str, module: torch.nn.Module, targets: Iterable[str] +) -> Generator[str]: + """ + Yields the targets that match the given name and module. + Outputs are ordered by type: exact name match, regex name match, class name match + """ + targets = sorted(targets, key=lambda x: ("re:" in x, x)) + for target in targets: + if _match_name(name, target): + yield target + + for target in targets: + if _match_class(module, target): + yield target + + def match_modules_set( model: torch.nn.Module, targets: Iterable[str], diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index a6f1ee3e..a9813b61 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -258,15 +258,13 @@ def get_sample_tinyllama_quant_config(status: str = "frozen"): @requires_accelerate() @pytest.mark.parametrize( - "ignore,should_raise_warning", + "ignore", [ - [("lm_head", "re:.*gate"), False], - [("lm_head", "re:.*foobarbaz"), True], + ("lm_head", "re:.*gate"), + ("lm_head", "re:.*foobarbaz"), ], ) -def test_apply_quantization_status(caplog, ignore, should_raise_warning): - import logging - +def test_apply_quantization_status(ignore): # load a dense, unquantized tiny llama model model = get_tinyllama_model() quantization_config_dict = { @@ -290,10 +288,4 @@ def test_apply_quantization_status(caplog, ignore, should_raise_warning): config = QuantizationConfig(**quantization_config_dict) config.quantization_status = QuantizationStatus.CALIBRATION - # mismatch in the ignore key of quantization_config_dict - with caplog.at_level(logging.WARNING): - apply_quantization_config(model, config) - if should_raise_warning: - assert len(caplog.text) > 0 - else: - assert len(caplog.text) == 0 + apply_quantization_config(model, config) From d94f6554b111568ca1e4651cc8ffd9f701b16b3f Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Tue, 29 Jul 2025 19:09:33 +0000 Subject: [PATCH 05/11] Ensure `match_targets` doesn't return duplicates Signed-off-by: Fynn Schmitt-Ulms --- .../quantization/lifecycle/apply.py | 2 +- src/compressed_tensors/utils/match.py | 16 +++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 706c3ebe..c54d0fc5 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -154,7 +154,7 @@ def apply_quantization_config( ): # mark modules to be quantized by adding # quant scheme to the matching layers - matched_targets = list(match_targets(name, submodule, target_to_scheme)) + matched_targets = match_targets(name, submodule, target_to_scheme) scheme = _scheme_from_targets(target_to_scheme, matched_targets, name) if run_compressed: format = config.format diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index 4c7d4d8d..56d7b7fd 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -117,19 +117,25 @@ def match_named_parameters( def match_targets( name: str, module: torch.nn.Module, targets: Iterable[str] -) -> Generator[str]: +) -> List[str]: """ - Yields the targets that match the given name and module. + Returns the targets that match the given name and module. Outputs are ordered by type: exact name match, regex name match, class name match """ + if isinstance(module, InternalModule): + return [] + targets = sorted(targets, key=lambda x: ("re:" in x, x)) + matched_targets = [] for target in targets: if _match_name(name, target): - yield target + matched_targets.append(target) for target in targets: - if _match_class(module, target): - yield target + if _match_class(module, target) and target not in matched_targets: + matched_targets.append(target) + + return matched_targets def match_modules_set( From b1fa4df4d7289951c87cf44d4d324a9c1194bf71 Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Tue, 29 Jul 2025 19:09:33 +0000 Subject: [PATCH 06/11] Remove `preprocess_name` parameter from `match_named_modules` Signed-off-by: Fynn Schmitt-Ulms --- src/compressed_tensors/quantization/lifecycle/apply.py | 10 +++------- src/compressed_tensors/utils/match.py | 8 ++------ 2 files changed, 5 insertions(+), 13 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index c54d0fc5..3e0d9767 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -38,8 +38,8 @@ infer_quantization_status, is_kv_cache_quant_scheme, ) -from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module -from compressed_tensors.utils.match import is_match, match_named_modules, match_targets +from compressed_tensors.utils.helpers import replace_module +from compressed_tensors.utils.match import match_named_modules, match_targets from compressed_tensors.utils.offload import update_parameter_data from compressed_tensors.utils.safetensors_load import get_safetensors_folder from safetensors import safe_open @@ -146,11 +146,7 @@ def apply_quantization_config( # mark appropriate layers for quantization by setting their quantization schemes for name, submodule in match_named_modules( - model, - target_to_scheme, - config.ignore or [], - warn_on_fail=True, - preprocess_name=fix_fsdp_module_name, + model, target_to_scheme, config.ignore or [], warn_on_fail=True ): # mark modules to be quantized by adding # quant scheme to the matching layers diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index 56d7b7fd..42746fba 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -15,7 +15,7 @@ import logging import re from collections.abc import Generator -from typing import Callable, Iterable, List, Tuple +from typing import Iterable, List, Tuple import torch from compressed_tensors.utils.internal import InternalModule @@ -37,8 +37,7 @@ def match_named_modules( targets: Iterable[str] | None, ignore: Iterable[str] | None = None, warn_on_fail: bool = False, - preprocess_name: Callable[[str], str] = lambda x: x, -) -> Generator[Tuple[str, torch.nn.Module] | Tuple[str, torch.nn.Module, List[str]]]: +) -> Generator[Tuple[str, torch.nn.Module]]: """ Yields names and modules which match `targets` but do not match `ignore`. Values are returned in order of `model.named_modules()` @@ -59,9 +58,6 @@ def match_named_modules( if isinstance(module, InternalModule): continue - # preprocess the module name and module - name = preprocess_name(name) - if any(is_match(name, module, ign) for ign in ignore): continue From 4283b1d0e07b25016115e8521173b9d4f92e47d4 Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Wed, 30 Jul 2025 13:36:58 +0000 Subject: [PATCH 07/11] Update match.py util fn signatures and small fixes Signed-off-by: Fynn Schmitt-Ulms --- src/compressed_tensors/utils/match.py | 44 ++++++++++++++++++--------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index 42746fba..ac962544 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -27,6 +27,7 @@ __all__ = [ "match_named_modules", "match_named_parameters", + "match_targets", "match_modules_set", "is_match", ] @@ -46,25 +47,19 @@ def match_named_modules( :param targets: target strings, potentially containing "re:" prefixes :param ignore: targets to ignore, potentially containing "re:" prefixes :param warn_on_fail: if True, warns if any targets do not match any modules in model - :param preprocess_name: a function to preprocess the module name :return: generator of module names and modules """ - ignore = ignore or [] targets = targets or [] + ignore = ignore or [] unmatched_targets = set(targets) for name, module in model.named_modules(): - if isinstance(module, InternalModule): - continue - - if any(is_match(name, module, ign) for ign in ignore): - continue - for target in targets: if is_match(name, module, target): unmatched_targets -= {target} - yield name, module + if not any(is_match(name, module, ign) for ign in ignore): + yield name, module break if warn_on_fail: @@ -76,8 +71,8 @@ def match_named_modules( def match_named_parameters( model: torch.nn.Module, - targets: Iterable[str], - ignore: Iterable[str] = tuple(), + targets: Iterable[str] | None = None, + ignore: Iterable[str] | None = None, warn_on_fail: bool = False, ) -> Generator[Tuple[str, torch.nn.Module, torch.nn.Parameter]]: """ @@ -90,6 +85,9 @@ def match_named_parameters( :param warn_on_fail: if True, warns if any targets do not match any params in model :return: generator of fully-qualified param names, parent modules, and params """ + targets = targets or [] + ignore = ignore or [] + unmatched_targets = set(targets) for module_name, module in model.named_modules(): if isinstance(module, InternalModule): @@ -112,15 +110,30 @@ def match_named_parameters( def match_targets( - name: str, module: torch.nn.Module, targets: Iterable[str] + name: str, module: torch.nn.Module, targets: Iterable[str] | None = None ) -> List[str]: """ Returns the targets that match the given name and module. + + :param name: the name of the module + :param module: the module to match + :param targets: the target strings, potentially containing "re:" prefixes + :return: the targets that match the given name and module + Outputs are ordered by type: exact name match, regex name match, class name match """ + targets = targets or [] + if isinstance(module, InternalModule): return [] + # The order of the output `matches` list matters, the are arranged from most + # specific to least specific, and this order will be used when merging configs. + # The entries are sorted in the following order: + # 1. matches on exact strings + # 2. matches on regex patterns + # 3. matches on module names + targets = sorted(targets, key=lambda x: ("re:" in x, x)) matched_targets = [] for target in targets: @@ -136,8 +149,8 @@ def match_targets( def match_modules_set( model: torch.nn.Module, - targets: Iterable[str], - ignore: Iterable[str] = tuple(), + targets: Iterable[str] | None = None, + ignore: Iterable[str] | None = None, ) -> Generator[Iterable[torch.nn.Module]]: """ Yields modules grouped with the same order and size as `targets`. @@ -175,6 +188,9 @@ def match_modules_set( :param targets: target strings, potentially containing "re:" prefixes :param ignore: targets to ignore, potentially containing "re:" prefixes """ + targets = targets or [] + ignore = ignore or [] + matches = dict.fromkeys(targets, None) for name, module in model.named_modules(): # match until we get a full set From 70b07ec266210d1257064389f122e09b37d95779 Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Wed, 30 Jul 2025 17:01:39 +0000 Subject: [PATCH 08/11] Restore `find_name_or_class_matches` as a deprecated function This function is currently used by llm-compressor so adding it back with a deprecation warning for now. Signed-off-by: Fynn Schmitt-Ulms --- .../quantization/lifecycle/apply.py | 36 ++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 3e0d9767..7bc935df 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -15,7 +15,7 @@ import logging from collections import OrderedDict from copy import deepcopy -from typing import Dict, List, Optional +from typing import Dict, Iterable, List, Optional from typing import OrderedDict as OrderedDictType from typing import Union @@ -50,6 +50,7 @@ "load_pretrained_quantization_parameters", "apply_quantization_config", "apply_quantization_status", + "find_name_or_class_matches", ] from compressed_tensors.quantization.utils.helpers import is_module_quantized @@ -242,6 +243,39 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): model.apply(compress_quantized_weights) +def find_name_or_class_matches( + name: str, module: Module, targets: Iterable[str], check_contains: bool = False +) -> List[str]: + """ + DEPRECATED: Use `match_targets` instead. + + This function is deprecated and will be removed in a future release. + Please use `match_targets` from `compressed_tensors.utils.match` instead. + + Returns all targets that match the given name or the class name. + Returns empty list otherwise. + The order of the output `matches` list matters. + The entries are sorted in the following order: + 1. matches on exact strings + 2. matches on regex patterns + 3. matches on module names + """ + import warnings + + warnings.warn( + "find_name_or_class_matches is deprecated and will be removed in a future release. " + "Please use compressed_tensors.utils.match.match_targets instead.", + DeprecationWarning, + stacklevel=2, + ) + if check_contains: + raise NotImplementedError( + "This function is deprecated, and the check_contains=True option has been removed." + ) + + return match_targets(name, module, targets) + + def _infer_status(model: Module) -> Optional[QuantizationStatus]: for module in model.modules(): status = getattr(module, "quantization_status", None) From 1520a25e82954eb204db147595eba6330b5d015d Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Thu, 7 Aug 2025 14:07:39 -0400 Subject: [PATCH 09/11] Use deprecated decorator instead of manual deprecation warning Signed-off-by: Fynn Schmitt-Ulms --- .../quantization/lifecycle/apply.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 7bc935df..10483d11 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -38,7 +38,7 @@ infer_quantization_status, is_kv_cache_quant_scheme, ) -from compressed_tensors.utils.helpers import replace_module +from compressed_tensors.utils.helpers import deprecated, replace_module from compressed_tensors.utils.match import match_named_modules, match_targets from compressed_tensors.utils.offload import update_parameter_data from compressed_tensors.utils.safetensors_load import get_safetensors_folder @@ -243,15 +243,14 @@ def apply_quantization_status(model: Module, status: QuantizationStatus): model.apply(compress_quantized_weights) +@deprecated( + message="This function is deprecated and will be removed in a future release." + "Please use `match_targets` from `compressed_tensors.utils.match` instead." +) def find_name_or_class_matches( name: str, module: Module, targets: Iterable[str], check_contains: bool = False ) -> List[str]: """ - DEPRECATED: Use `match_targets` instead. - - This function is deprecated and will be removed in a future release. - Please use `match_targets` from `compressed_tensors.utils.match` instead. - Returns all targets that match the given name or the class name. Returns empty list otherwise. The order of the output `matches` list matters. @@ -260,14 +259,6 @@ def find_name_or_class_matches( 2. matches on regex patterns 3. matches on module names """ - import warnings - - warnings.warn( - "find_name_or_class_matches is deprecated and will be removed in a future release. " - "Please use compressed_tensors.utils.match.match_targets instead.", - DeprecationWarning, - stacklevel=2, - ) if check_contains: raise NotImplementedError( "This function is deprecated, and the check_contains=True option has been removed." From 4c21a8f3d2376c01bedec7a7bba7b2b62e4c0bce Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Wed, 20 Aug 2025 12:09:51 -0400 Subject: [PATCH 10/11] Update syntax of of optional types --- .../quantization/lifecycle/apply.py | 2 +- src/compressed_tensors/utils/match.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py index 10483d11..50e5749d 100644 --- a/src/compressed_tensors/quantization/lifecycle/apply.py +++ b/src/compressed_tensors/quantization/lifecycle/apply.py @@ -147,7 +147,7 @@ def apply_quantization_config( # mark appropriate layers for quantization by setting their quantization schemes for name, submodule in match_named_modules( - model, target_to_scheme, config.ignore or [], warn_on_fail=True + model, target_to_scheme, config.ignore, warn_on_fail=True ): # mark modules to be quantized by adding # quant scheme to the matching layers diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index 86df5d6d..d514108b 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -38,8 +38,8 @@ def match_named_modules( model: torch.nn.Module, - targets: Iterable[str] | None, - ignore: Iterable[str] | None = None, + targets: Optional[Iterable[str]] = None, + ignore: Optional[Iterable[str]] = None, fused: Optional[FusedMappping] = None, warn_on_fail: bool = False, ) -> Generator[Tuple[str, torch.nn.Module]]: @@ -77,8 +77,8 @@ def match_named_modules( def match_named_parameters( model: torch.nn.Module, - targets: Iterable[str] | None = None, - ignore: Iterable[str] | None = None, + targets: Optional[Iterable[str]] = None, + ignore: Optional[Iterable[str]] = None, fused: Optional[FusedMappping] = None, warn_on_fail: bool = False, ) -> Generator[Tuple[str, torch.nn.Module, torch.nn.Parameter]]: @@ -158,8 +158,8 @@ def match_targets( def match_modules_set( model: torch.nn.Module, - targets: Iterable[str] | None = None, - ignore: Iterable[str] | None = None, + targets: Optional[Iterable[str]] = None, + ignore: Optional[Iterable[str]] = None, ) -> Generator[Iterable[torch.nn.Module]]: """ Yields modules grouped with the same order and size as `targets`. From c9b8ba2d8d9a603b2ac9b427b7c55f82721ef0a9 Mon Sep 17 00:00:00 2001 From: Fynn Schmitt-Ulms Date: Wed, 20 Aug 2025 12:28:23 -0400 Subject: [PATCH 11/11] Remove default None target value in match utils --- src/compressed_tensors/utils/match.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index d514108b..bff3e478 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -38,7 +38,7 @@ def match_named_modules( model: torch.nn.Module, - targets: Optional[Iterable[str]] = None, + targets: Optional[Iterable[str]], ignore: Optional[Iterable[str]] = None, fused: Optional[FusedMappping] = None, warn_on_fail: bool = False, @@ -77,7 +77,7 @@ def match_named_modules( def match_named_parameters( model: torch.nn.Module, - targets: Optional[Iterable[str]] = None, + targets: Optional[Iterable[str]], ignore: Optional[Iterable[str]] = None, fused: Optional[FusedMappping] = None, warn_on_fail: bool = False, @@ -119,7 +119,7 @@ def match_named_parameters( def match_targets( - name: str, module: torch.nn.Module, targets: Iterable[str] | None = None + name: str, module: torch.nn.Module, targets: Optional[Iterable[str]] ) -> List[str]: """ Returns the targets that match the given name and module. @@ -158,7 +158,7 @@ def match_targets( def match_modules_set( model: torch.nn.Module, - targets: Optional[Iterable[str]] = None, + targets: Optional[Iterable[str]], ignore: Optional[Iterable[str]] = None, ) -> Generator[Iterable[torch.nn.Module]]: """