Skip to content

Commit e7a28ad

Browse files
kylesayrs, gemini-code-assist[bot], and dsikka
authored
[Sequential Pipeline] Return subgraph on sequential_epoch_end, remove layer_sequential pipeline (#1998)
## Purpose ## * Enable better targeting of modules by modifiers such as [AutoRound](#1994) * Remove legacy pipeline (which is incompatible with this change) ## Changes ## * Pass subgraph to `sequential_epoch_end`, allowing modifiers to view all of the modules that were called in the subgraph * Implement `submodules` method on `Subgraph` which returns all the modules called by this subgraph * Remove `LayerSequentialPipeline`, which does not use the `Subgraph` API and has been superseded by the sequential pipeline --------- Signed-off-by: Kyle Sayers <[email protected]> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Dipika Sikka <[email protected]>
1 parent 15b29be commit e7a28ad

File tree

10 files changed

+19
-291
lines changed

10 files changed

+19
-291
lines changed

src/llmcompressor/args/dataset_arguments.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,7 @@ class DatasetArguments(CustomDatasetArguments):
188188
default="independent",
189189
metadata={
190190
"help": "Calibration pipeline used to calibrate model"
191-
"Options: ['basic', 'datafree', 'sequential', 'layer_sequential', "
192-
"independent]"
191+
"Options: ['basic', 'datafree', 'sequential', independent]"
193192
},
194193
)
195194
tracing_ignore: list[str] = field(

src/llmcompressor/core/session_functions.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,16 @@
77

88
import threading
99
from contextlib import contextmanager
10-
from typing import Any, Generator, Optional
10+
from typing import TYPE_CHECKING, Any, Generator, Optional
1111

1212
from llmcompressor.core.events import EventType
1313
from llmcompressor.core.session import CompressionSession
1414
from llmcompressor.core.state import ModifiedState
1515

16+
if TYPE_CHECKING:
17+
from llmcompressor.pipelines.sequential import Subgraph
18+
19+
1620
__all__ = [
1721
"create_session",
1822
"active_session",
@@ -150,15 +154,15 @@ def calibration_epoch_start(cls, **kwargs) -> ModifiedState:
150154
return cls.event(EventType.CALIBRATION_EPOCH_START, **kwargs)
151155

152156
@classmethod
153-
def sequential_epoch_end(cls, **kwargs) -> ModifiedState:
157+
def sequential_epoch_end(cls, subgraph: "Subgraph", **kwargs) -> ModifiedState:
154158
"""
155159
Invoke a sequential epoch end event for the active session. This event should be
156160
called after one sequential layer has been calibrated/trained for one epoch
157161
158162
This is called after a sequential layer has been calibrated with one batch, see
159163
`src/llmcompressor/pipelines/sequential/pipeline.py` for usage example
160164
"""
161-
return cls.event(EventType.SEQUENTIAL_EPOCH_END, **kwargs)
165+
return cls.event(EventType.SEQUENTIAL_EPOCH_END, subgraph=subgraph, **kwargs)
162166

163167
@classmethod
164168
def calibration_epoch_end(cls, **kwargs) -> ModifiedState:

src/llmcompressor/modifiers/pruning/wanda/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class WandaPruningModifier(SparsityModifierBase):
3636
Lifecycle:
3737
- on_initialize
3838
- register_hook(module, calibrate_module, "forward")
39-
- run_sequential / run_layer_sequential / run_basic
39+
- run_sequential / run_basic
4040
- make_empty_row_scalars
4141
- accumulate_row_scalars
4242
- on_sequential_batch_end

src/llmcompressor/pipelines/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,5 @@
1313
from .basic import *
1414
from .data_free import *
1515
from .independent import *
16-
from .layer_sequential import *
1716
from .registry import *
1817
from .sequential import *

src/llmcompressor/pipelines/layer_sequential/__init__.py

Lines changed: 0 additions & 2 deletions
This file was deleted.

src/llmcompressor/pipelines/layer_sequential/helpers.py

Lines changed: 0 additions & 156 deletions
This file was deleted.

src/llmcompressor/pipelines/layer_sequential/pipeline.py

Lines changed: 0 additions & 125 deletions
This file was deleted.
src/llmcompressor/pipelines/sequential/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
# ruff: noqa
22
from .pipeline import *
3+
from .helpers import *

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,14 @@ def forward(self, *args, **kwargs) -> Dict[str, Any]:
7979

8080
return outputs
8181

82+
def submodules(self, model: Module, recurse: bool = False) -> Set[Module]:
83+
nodes = self.graph.find_nodes(op="call_module")
84+
modules = set(model.get_submodule(node.target) for node in nodes)
85+
if recurse:
86+
modules = set(m for module in modules for m in module.modules())
87+
88+
return modules
89+
8290

8391
def trace_subgraphs(
8492
model: PreTrainedModel,

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def __call__(
9999
inputs = activations.fetch(batch_idx, subgraph.input_names)
100100
subgraph.forward(model, **inputs)
101101

102-
LifecycleCallbacks.sequential_epoch_end()
102+
LifecycleCallbacks.sequential_epoch_end(subgraph)
103103

104104
# this pass does not trigger modifier hooks
105105
# and is only used for capturing outputs of newly compressed modules

0 commit comments

Comments
 (0)