22
33from compressed_tensors .transform import TransformScheme , apply_transform_config
44
5- from llmcompressor .core import State
5+ from llmcompressor .core import Event , EventType , State
66from llmcompressor .modifiers import Modifier
77
88from .template .quip import QUIP
@@ -12,9 +12,9 @@ class TransformModifier(Modifier):
1212 preset_config : Optional [str ] = None
1313 config_groups : Optional [Dict [str , TransformScheme ]] = None
1414
15- # model validator to validate both preset and config gropus are not provided
15+ # model validator to validate both preset and config groups are not provided
1616
17- def on_initialize (self , state : State , ** kwargs ):
17+ def on_initialize (self , state : State , ** kwargs ) -> bool :
1818 if self .preset_config is not None :
1919 # import config template and customize to model
2020 pass
@@ -23,4 +23,29 @@ def on_initialize(self, state: State, **kwargs):
2323 config = QUIP
2424
2525 apply_transform_config (state .model , config )
26- breakpoint ()
26+
27+ return True
28+
29+ def on_start (self , state : State , event : Event , ** kwargs ):
30+ self .started_ = True
31+
32+ def on_event (self , state : State , event : Event , ** kwargs ):
33+ if event .type_ == EventType .CALIBRATION_EPOCH_START :
34+ if not self .started_ :
35+ self .on_start (state , None )
36+
37+ elif event .type_ == EventType .SEQUENTIAL_EPOCH_END :
38+ pass
39+
40+ elif event .type_ == EventType .CALIBRATION_EPOCH_END :
41+ if not self .ended_ :
42+ self .on_end (state , None )
43+
44+ def on_end (self , state : State , event : Event , ** kwargs ):
45+ self .ended_ = True
46+
47+ def on_finalize (self , state : State , ** kwargs ) -> bool :
48+ if not self .ended_ :
49+ self .on_end (state , None )
50+
51+ return True
0 commit comments