Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/llmcompressor/args/dataset_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,7 @@ class DatasetArguments(CustomDatasetArguments):
default="independent",
metadata={
"help": "Calibration pipeline used to calibrate model"
"Options: ['basic', 'datafree', 'sequential', 'layer_sequential', "
"independent]"
"Options: ['basic', 'datafree', 'sequential', independent]"
},
)
tracing_ignore: list[str] = field(
Expand Down
10 changes: 7 additions & 3 deletions src/llmcompressor/core/session_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@

import threading
from contextlib import contextmanager
from typing import Any, Generator, Optional
from typing import TYPE_CHECKING, Any, Generator, Optional

from llmcompressor.core.events import EventType
from llmcompressor.core.session import CompressionSession
from llmcompressor.core.state import ModifiedState

if TYPE_CHECKING:
from llmcompressor.pipelines.sequential import Subgraph


__all__ = [
"create_session",
"active_session",
Expand Down Expand Up @@ -150,15 +154,15 @@ def calibration_epoch_start(cls, **kwargs) -> ModifiedState:
return cls.event(EventType.CALIBRATION_EPOCH_START, **kwargs)

@classmethod
def sequential_epoch_end(cls, **kwargs) -> ModifiedState:
def sequential_epoch_end(cls, subgraph: "Subgraph", **kwargs) -> ModifiedState:
"""
Invoke a sequential epoch end event for the active session. This event should be
called after one sequential layer has been calibrated/trained for one epoch

This is called after a sequential layer has been calibrated with one batch, see
`src/llmcompressor/pipelines/sequential/pipeline.py` for usage example
"""
return cls.event(EventType.SEQUENTIAL_EPOCH_END, **kwargs)
return cls.event(EventType.SEQUENTIAL_EPOCH_END, subgraph=subgraph, **kwargs)

@classmethod
def calibration_epoch_end(cls, **kwargs) -> ModifiedState:
Expand Down
2 changes: 1 addition & 1 deletion src/llmcompressor/modifiers/pruning/wanda/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class WandaPruningModifier(SparsityModifierBase):
Lifecycle:
- on_initialize
- register_hook(module, calibrate_module, "forward")
- run_sequential / run_layer_sequential / run_basic
- run_sequential / run_basic
- make_empty_row_scalars
- accumulate_row_scalars
- on_sequential_batch_end
Expand Down
1 change: 0 additions & 1 deletion src/llmcompressor/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,5 @@
from .basic import *
from .data_free import *
from .independent import *
from .layer_sequential import *
from .registry import *
from .sequential import *
2 changes: 0 additions & 2 deletions src/llmcompressor/pipelines/layer_sequential/__init__.py

This file was deleted.

156 changes: 0 additions & 156 deletions src/llmcompressor/pipelines/layer_sequential/helpers.py

This file was deleted.

125 changes: 0 additions & 125 deletions src/llmcompressor/pipelines/layer_sequential/pipeline.py

This file was deleted.

1 change: 1 addition & 0 deletions src/llmcompressor/pipelines/sequential/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
# ruff: noqa
from .pipeline import *
from .helpers import *
8 changes: 8 additions & 0 deletions src/llmcompressor/pipelines/sequential/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ def forward(self, *args, **kwargs) -> Dict[str, Any]:

return outputs

def submodules(self, model: Module, recurse: bool = False) -> Set[Module]:
nodes = self.graph.find_nodes(op="call_module")
modules = set(model.get_submodule(node.target) for node in nodes)
if recurse:
modules = set(module.modules() for module in modules)

return modules


def trace_subgraphs(
model: PreTrainedModel,
Expand Down
2 changes: 1 addition & 1 deletion src/llmcompressor/pipelines/sequential/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __call__(
inputs = activations.fetch(batch_idx, subgraph.input_names)
subgraph.forward(model, **inputs)

LifecycleCallbacks.sequential_epoch_end()
LifecycleCallbacks.sequential_epoch_end(subgraph)

# this pass does not trigger modifier hooks
# and is only used for capturing outputs of newly compressed modules
Expand Down
Loading