From 2131de34132a2dd7640235e01b87c5740625c5d7 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Sat, 9 Aug 2025 11:26:27 -0400 Subject: [PATCH] expand is_match Signed-off-by: Kyle Sayers --- .../transform/factory/base.py | 2 +- src/compressed_tensors/utils/match.py | 24 ++++++++++++------- tests/test_utils/test_match.py | 22 ++++++++++------- 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/src/compressed_tensors/transform/factory/base.py b/src/compressed_tensors/transform/factory/base.py index 2218bd30..89573d11 100644 --- a/src/compressed_tensors/transform/factory/base.py +++ b/src/compressed_tensors/transform/factory/base.py @@ -14,7 +14,7 @@ from abc import ABC, abstractmethod from collections import defaultdict -from typing import List, Optional, Tuple, Set +from typing import List, Optional, Set, Tuple import torch import torch.nn.utils.parametrize as P diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py index 30ead256..7476498e 100644 --- a/src/compressed_tensors/utils/match.py +++ b/src/compressed_tensors/utils/match.py @@ -15,7 +15,7 @@ import logging import re from collections.abc import Generator -from typing import Iterable, Mapping, Optional, Tuple +from typing import Iterable, List, Mapping, Optional, Tuple import torch from compressed_tensors.utils.internal import InternalModule @@ -57,10 +57,10 @@ def match_named_modules( unmatched_targets = set(targets) for name, module in model.named_modules(): for target in targets: - if is_match(name, module, target, fused): + if is_match(name, module, target, fused=fused): unmatched_targets -= {target} - if not any(is_match(name, module, ign, fused) for ign in ignore): + if not is_match(name, module, ignore, fused=fused): yield name, module if warn_on_fail: @@ -155,9 +155,7 @@ def match_modules_set( for name, module in model.named_modules(): # match until we get a full set for target in targets: - if is_match(name, module, target) and not any( - is_match(name, module, ign) for ign in ignore - ): + if is_match(name, module, target, ignore): if matches[target] is not None: raise ValueError(f"Matched a {target} twice before completing set") matches[target] = module @@ -176,7 +174,8 @@ def match_modules_set( def is_match( name: str, module: torch.nn.Module, - target: str, + targets: str | Iterable[str], + ignore: str | Iterable[str] = tuple(), fused: Optional[FusedMappping] = None, ) -> bool: """ @@ -198,8 +197,17 @@ def is_match( :fused: optional mapping from suffixes of fused modules to the suffixes of their corresponding shards """ + targets = [targets] if isinstance(targets, str) else targets + ignore = [ignore] if isinstance(ignore, str) else ignore + return not isinstance(module, InternalModule) and ( - _match_name(name, target, fused) or _match_class(module, target) + any( + _match_name(name, target, fused) or _match_class(module, target) + for target in targets + ) + and not any( + _match_name(name, ign, fused) or _match_class(module, ign) for ign in ignore + ) ) diff --git a/tests/test_utils/test_match.py b/tests/test_utils/test_match.py index 7858c7c8..a23079cb 100644 --- a/tests/test_utils/test_match.py +++ b/tests/test_utils/test_match.py @@ -201,14 +201,20 @@ def test_fused_mapping(self): "gate_up_proj": ["gate_proj", "up_proj"], } - assert is_match("dummy.qkv_proj", linear, "re:.*q_proj", mapping) == True - assert is_match("dummy.qkv_proj", linear, "re:.*k_proj", mapping) == True - assert is_match("dummy.qkv_proj", linear, "re:.*v_proj", mapping) == True - assert is_match("dummy.qkv_proj", linear, "Linear", mapping) == True - - assert is_match("dummy.gate_up_proj", linear, "re:.*gate_proj", mapping) == True - assert is_match("dummy.gate_up_proj", linear, "re:.*up_proj", mapping) == True - assert is_match("dummy.gate_up_proj", linear, "Linear", mapping) == True + assert is_match("dummy.qkv_proj", linear, "re:.*q_proj", fused=mapping) == True + assert is_match("dummy.qkv_proj", linear, "re:.*k_proj", fused=mapping) == True + assert is_match("dummy.qkv_proj", linear, "re:.*v_proj", fused=mapping) == True + assert is_match("dummy.qkv_proj", linear, "Linear", fused=mapping) == True + + assert ( + is_match("dummy.gate_up_proj", linear, "re:.*gate_proj", fused=mapping) + == True + ) + assert ( + is_match("dummy.gate_up_proj", linear, "re:.*up_proj", fused=mapping) + == True + ) + assert is_match("dummy.gate_up_proj", linear, "Linear", fused=mapping) == True class TestMatchNamedModules: