Skip to content

Commit f7b6a40

Browse files
authored
Replace uses of deprecated find_name_or_class_matches fn (#1805)
SUMMARY: `find_name_or_class_matches` was deprecated in neuralmagic/compressed-tensors#406 and replaced by `match_targets`. This pr update llm-compressor to use the new fn. TEST PLAN: `match_targets` should be a direct replacement. Testing is left to CI Signed-off-by: Fynn Schmitt-Ulms <[email protected]>
1 parent 2d8a803 commit f7b6a40

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/llmcompressor/pipelines/layer_sequential/helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import torch
77
import tqdm
8-
from compressed_tensors.quantization import find_name_or_class_matches
8+
from compressed_tensors.utils.match import match_targets
99
from torch.nn import Module
1010
from torch.utils.data.dataloader import DataLoader
1111

@@ -33,7 +33,7 @@ def match_modules(model: Module, target_names: List[str]) -> List[Module]:
3333
names_layers = [
3434
(name, module)
3535
for name, module in model.named_modules()
36-
if find_name_or_class_matches(name, module, target_names)
36+
if match_targets(name, module, target_names)
3737
]
3838

3939
names_layers = sorted(names_layers, key=lambda name_layer: name_layer[0])

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66

77
import torch
88
from accelerate.hooks import remove_hook_from_module
9-
from compressed_tensors.quantization import find_name_or_class_matches
109
from compressed_tensors.utils import (
1110
has_offloaded_params,
1211
offloaded_dispatch,
1312
remove_dispatch,
1413
)
14+
from compressed_tensors.utils.match import match_targets
1515
from loguru import logger
1616
from torch.fx import Graph, GraphModule, Node
1717
from torch.fx.graph import PythonCode
@@ -424,7 +424,7 @@ def match_modules(model: Module, target_names: List[str]) -> Set[Module]:
424424
return set(
425425
module
426426
for name, module in model.named_modules()
427-
if find_name_or_class_matches(name, module, target_names)
427+
if match_targets(name, module, target_names)
428428
)
429429

430430

0 commit comments

Comments
 (0)