8
8
import torch
9
9
from loguru import logger
10
10
from pydantic import Field , PrivateAttr , field_validator , model_validator
11
+ from transformers import PreTrainedModel
11
12
12
13
from llmcompressor .core import Event , EventType , State
13
14
from llmcompressor .modifiers .modifier import Modifier
14
15
from llmcompressor .modifiers .utils .hooks import HooksMixin
16
+ from llmcompressor .sentinel import Sentinel
15
17
from llmcompressor .utils .pytorch .module import (
16
18
get_layers ,
17
- get_no_split_params ,
18
19
get_prunable_layers ,
19
20
match_targets ,
20
21
)
@@ -27,35 +28,41 @@ class SparsityModifierBase(Modifier):
27
28
"""
28
29
29
30
# modifier arguments
31
+ targets : Union [str , List [str ]] = ["Linear" ]
32
+ ignore : List [str ] = Field (default_factory = list )
30
33
sparsity : Optional [Union [float , List [float ]]]
31
34
sparsity_profile : Optional [str ] = None
32
35
mask_structure : str = "0:0"
33
36
owl_m : Optional [int ] = None
34
37
owl_lmbda : Optional [float ] = None
35
38
36
- # data pipeline arguments
37
- sequential_update : Optional [bool ] = False # deprecated
38
- sequential_targets : Union [str , List [str ], None ] = None
39
- targets : Union [str , List [str ]] = ["Linear" ]
40
- ignore : List [str ] = Field (default_factory = list )
41
-
42
39
# private variables
43
40
_prune_n : Optional [int ] = PrivateAttr (default = None )
44
41
_prune_m : Optional [int ] = PrivateAttr (default = None )
45
42
_module_names : Dict [torch .nn .Module , str ] = PrivateAttr (default_factory = dict )
46
43
_target_layers : Dict [str , torch .nn .Module ] = PrivateAttr (default_factory = dict )
47
44
_module_sparsities : Dict [torch .nn .Module , str ] = PrivateAttr (default_factory = dict )
48
45
46
+ # deprecated
47
+ sequential_update : Union [Sentinel , Any ] = Sentinel ("deprecated" )
48
+ sequential_targets : Union [Sentinel , Any ] = Sentinel ("deprecated" )
49
+
49
50
@field_validator("sequential_update", mode="before")
def validate_sequential_update(cls, value: bool) -> bool:
    """Warn that ``sequential_update`` is deprecated and discard any user value.

    The field defaults to ``Sentinel("deprecated")``, so this validator only
    observes a non-sentinel value when a caller explicitly passed one. The
    value is intentionally not returned: user-supplied settings are dropped
    and the model is always updated sequentially.
    """
    # Identity comparison assumes Sentinel("deprecated") is a cached
    # singleton -- TODO confirm against llmcompressor.sentinel.Sentinel.
    if value is not Sentinel("deprecated"):
        # NOTE(review): the previous message claimed we were "setting
        # sequential_update=True", but this code no longer assigns True
        # (the validator returns nothing). Reworded to describe what
        # actually happens: the argument is ignored.
        warnings.warn(
            "`sequential_update` is deprecated and has no effect; "
            "the model is always updated sequentially",
            DeprecationWarning,
        )
59
@field_validator("sequential_targets", mode="before")
def validate_sequential_targets(cls, value: bool) -> bool:
    """Reject any explicit ``sequential_targets`` setting on this modifier.

    The field defaults to ``Sentinel("deprecated")``; receiving anything else
    means the caller tried to configure sequential targets here, which is no
    longer supported. On the valid (sentinel) path nothing is returned, so the
    stored value is not propagated by this validator.
    """
    # Identity check -- presumably Sentinel("deprecated") always yields the
    # same object; verify against the Sentinel implementation.
    if value is Sentinel("deprecated"):
        return

    raise ValueError(
        "Setting `sequential_targets` via modifiers is no longer supported, "
        "Please use `oneshot(sequential_targets=...)`"
    )
59
66
60
67
@field_validator ("sparsity_profile" , mode = "before" )
61
68
def validate_sparsity_profile (cls , value : Optional [str ]) -> bool :
@@ -109,12 +116,12 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
109
116
110
117
:param state: session state storing input model and calibration data
111
118
"""
112
- model : torch . nn . Module = state .model
119
+ model : PreTrainedModel = state .model
113
120
dataloader : torch .utils .data .DataLoader = state .data .calib
114
121
115
122
# infer module and sequential targets
116
- self . sequential_targets = self . _infer_sequential_targets ( model )
117
- layers = get_layers (self . sequential_targets , model )
123
+ sequential_targets = model . _get_no_split_modules ( "auto" )
124
+ layers = get_layers (sequential_targets , model )
118
125
self ._target_layers = get_layers (
119
126
self .targets , model
120
127
) # layers containing targets
@@ -191,15 +198,6 @@ def on_end(self, state: State, event: Event, **kwargs):
191
198
self .ended_ = True
192
199
self .remove_hooks ()
193
200
194
- def _infer_sequential_targets (
195
- self , model : torch .nn .Module
196
- ) -> Union [str , List [str ]]:
197
- if self .sequential_targets is None :
198
- return get_no_split_params (model )
199
- if isinstance (self .sequential_targets , str ):
200
- return [self .sequential_targets ]
201
- return self .sequential_targets
202
-
203
201
def _infer_owl_layer_sparsity (
204
202
self ,
205
203
model : torch .nn .Module ,
0 commit comments