diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py
index f3b4c3c565..a42e596ea7 100644
--- a/src/llmcompressor/args/dataset_arguments.py
+++ b/src/llmcompressor/args/dataset_arguments.py
@@ -188,8 +188,7 @@ class DatasetArguments(CustomDatasetArguments):
         default="independent",
         metadata={
             "help": "Calibration pipeline used to calibrate model"
-            "Options: ['basic', 'datafree', 'sequential', 'layer_sequential', "
-            "independent]"
+            ". Options: ['basic', 'datafree', 'sequential', 'independent']"
         },
     )
     tracing_ignore: list[str] = field(
diff --git a/src/llmcompressor/core/session_functions.py b/src/llmcompressor/core/session_functions.py
index af183f606a..d0688a4a5b 100644
--- a/src/llmcompressor/core/session_functions.py
+++ b/src/llmcompressor/core/session_functions.py
@@ -7,12 +7,16 @@
 
 import threading
 from contextlib import contextmanager
-from typing import Any, Generator, Optional
+from typing import TYPE_CHECKING, Any, Generator, Optional
 
 from llmcompressor.core.events import EventType
 from llmcompressor.core.session import CompressionSession
 from llmcompressor.core.state import ModifiedState
 
+if TYPE_CHECKING:
+    from llmcompressor.pipelines.sequential import Subgraph
+
+
 __all__ = [
     "create_session",
     "active_session",
@@ -150,7 +154,7 @@ def calibration_epoch_start(cls, **kwargs) -> ModifiedState:
         return cls.event(EventType.CALIBRATION_EPOCH_START, **kwargs)
 
     @classmethod
-    def sequential_epoch_end(cls, **kwargs) -> ModifiedState:
+    def sequential_epoch_end(cls, subgraph: "Subgraph", **kwargs) -> ModifiedState:
         """
         Invoke a sequential epoch end event for the active session. This event should
         be called after one sequential layer has been calibrated/trained for one epoch
@@ -158,7 +162,7 @@ def sequential_epoch_end(cls, **kwargs) -> ModifiedState:
         This is called after a sequential layer has been calibrated with one batch, see
         `src/llmcompressor/pipelines/sequential/pipeline.py` for usage example
         """
-        return cls.event(EventType.SEQUENTIAL_EPOCH_END, **kwargs)
+        return cls.event(EventType.SEQUENTIAL_EPOCH_END, subgraph=subgraph, **kwargs)
 
     @classmethod
     def calibration_epoch_end(cls, **kwargs) -> ModifiedState:
diff --git a/src/llmcompressor/modifiers/pruning/wanda/base.py b/src/llmcompressor/modifiers/pruning/wanda/base.py
index a63fef6ac6..67eb616889 100644
--- a/src/llmcompressor/modifiers/pruning/wanda/base.py
+++ b/src/llmcompressor/modifiers/pruning/wanda/base.py
@@ -36,7 +36,7 @@ class WandaPruningModifier(SparsityModifierBase):
     Lifecycle:
         - on_initialize
             - register_hook(module, calibrate_module, "forward")
-        - run_sequential / run_layer_sequential / run_basic
+        - run_sequential / run_basic
             - make_empty_row_scalars
             - accumulate_row_scalars
         - on_sequential_batch_end
diff --git a/src/llmcompressor/pipelines/__init__.py b/src/llmcompressor/pipelines/__init__.py
index 1836e38778..65400a033b 100644
--- a/src/llmcompressor/pipelines/__init__.py
+++ b/src/llmcompressor/pipelines/__init__.py
@@ -13,6 +13,5 @@
 from .basic import *
 from .data_free import *
 from .independent import *
-from .layer_sequential import *
 from .registry import *
 from .sequential import *
diff --git a/src/llmcompressor/pipelines/layer_sequential/__init__.py b/src/llmcompressor/pipelines/layer_sequential/__init__.py
deleted file mode 100644
index 488b9b1c19..0000000000
--- a/src/llmcompressor/pipelines/layer_sequential/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-# ruff: noqa
-from .pipeline import *
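Note on the `sequential_epoch_end` change above: the callback now threads the calibrated `Subgraph` through to the session's event handlers. A minimal sketch of the call site under the new signature (the `run_calibration` wrapper below is hypothetical and exists only to keep the sketch self-contained; the loop body paraphrases `src/llmcompressor/pipelines/sequential/pipeline.py`):

```python
# Illustrative sketch, not part of this diff.
from llmcompressor.core import LifecycleCallbacks

def run_calibration(model, subgraphs, dataloader, activations):
    for subgraph in subgraphs:
        # calibrate this subgraph on every batch (triggers modifier hooks)
        for batch_idx in range(len(dataloader)):
            inputs = activations.fetch(batch_idx, subgraph.input_names)
            subgraph.forward(model, **inputs)

        # new in this PR: handlers receive the subgraph that just finished,
        # so modifiers can act on exactly the modules it executed
        LifecycleCallbacks.sequential_epoch_end(subgraph)
```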
diff --git a/src/llmcompressor/pipelines/layer_sequential/helpers.py b/src/llmcompressor/pipelines/layer_sequential/helpers.py
deleted file mode 100644
index 6b12760b4b..0000000000
--- a/src/llmcompressor/pipelines/layer_sequential/helpers.py
+++ /dev/null
@@ -1,156 +0,0 @@
-import contextlib
-import inspect
-from dataclasses import dataclass
-from typing import Any, Dict, List, Tuple
-
-import torch
-import tqdm
-from compressed_tensors.utils.match import match_targets
-from torch.nn import Module
-from torch.utils.data.dataloader import DataLoader
-
-from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch
-from llmcompressor.pipelines.cache import IntermediatesCache
-from llmcompressor.pytorch.utils.helpers import tensors_to_device
-from llmcompressor.utils.helpers import calibration_forward_context
-
-__all__ = [
-    "match_modules",
-    "capture_first_layer_intermediates",
-    "to_next_layer_kwargs",
-    "maybe_inject_pos_embeddings",
-]
-
-
-def match_modules(model: Module, target_names: List[str]) -> List[Module]:
-    """
-    Find all submodules which match the `target_names` and sort them by name
-
-    :param model: model to search for submodules in
-    :param target_names: patterns of submodule names to match
-    :return: list of submodules
-    """
-    names_layers = [
-        (name, module)
-        for name, module in model.named_modules()
-        if match_targets(name, module, target_names)
-    ]
-
-    names_layers = sorted(names_layers, key=lambda name_layer: name_layer[0])
-    return [layer for _name, layer in names_layers]
-
-
-def capture_first_layer_intermediates(
-    model: Module,
-    first_layer: Module,
-    dataloader: DataLoader,
-    model_device: torch.device = torch.device("cpu"),
-    mask_padding: bool = True,
-) -> IntermediatesCache:
-    """
-    Captures the intermediate activations directly before the first model layer.
-    This is meant to capture any model preprocessing before model layers are executed
-
-    Note that if any modules compressed prior to the execution of the first layer, the
-    compression error induced by compressing those modules will not be propagated to
-    subsequent activations, as they would be for modules which are compressed within
-    a layer
-
-    :param model: model containing layers
-    :param first_layer: the first layer of the model
-    :param dataloader: dataloader of calibration inputs
-    :param mask_padding: zero out padding tokens if True. This affects modifiers such as
-        GPTQ and SparseGPT
-    """
-    intermediates = IntermediatesCache.empty(len(dataloader), torch.device("cpu"))
-    signature = inspect.signature(first_layer.forward)
-
-    with calibration_forward_context(model), early_stop_hook(first_layer):
-        desc = "Preparing intermediates cache"
-        for batch_index, batch in enumerate(tqdm.tqdm(dataloader, desc=desc)):
-            batch = apply_pad_mask_to_batch(batch) if mask_padding else batch
-            batch = tensors_to_device(batch, model_device)
-
-            try:
-                model(**batch)
-            except EarlyStopException as exception:
-                layer_args = args_to_kwargs(exception._args, signature)
-                assert not set(layer_args.keys()) & set(exception._kwargs.keys())
-                layer_args.update(exception._kwargs)
-
-                intermediates.update(batch_index, layer_args)
-            else:
-                raise ValueError(
-                    "Attempted to capture first layer intermediates, but "
-                    "EarlyStopException was not raised"
-                )
-
-    return intermediates
-
-
-def to_next_layer_kwargs(args: Tuple[Any, ...], next_layer: Module) -> Dict[str, Any]:
-    """
-    Convert a list of arguments to a dictionary of keyword arguments which match the
-    next layer's function signature
-
-    :param args: list of argument values
-    :param next_layer: the next layer whose function signature must be matched
-    :return: dictionary mapping function signature keywords to argument values
-    """
-    signature = inspect.signature(next_layer.forward)
-    return args_to_kwargs(args, signature)
-
-
-def args_to_kwargs(
-    args: Tuple[Any, ...], signature: inspect.Signature
-) -> Dict[str, Any]:
-    return {name: arg for name, arg in zip(signature.parameters.keys(), args)}
-
-
-@contextlib.contextmanager
-def early_stop_hook(module: Module):
-    def trigger_early_stop_fn(module, args, kwargs):
-        raise EarlyStopException(_args=args, _kwargs=kwargs)
-
-    handle = module.register_forward_pre_hook(trigger_early_stop_fn, with_kwargs=True)
-
-    try:
-        yield
-    finally:
-        handle.remove()
-
-
-@dataclass
-class EarlyStopException(Exception):
-    """
-    Dataclass for storing model activations
-
-    Note: Attribute names `args` and `kwargs` are reserved for `dataclass.GenericAlias`
-    """
-
-    _args: Tuple[Any, ...]
-    _kwargs: Dict[str, Any]
-
-
-def maybe_inject_pos_embeddings(
-    output: Dict[str, Any],
-    next_layer: Module,
-    inputs: Dict[str, Any],
-) -> Dict[str, Any]:
-    """
-    As of https://github.com/huggingface/transformers/pull/34858, positional embeddings
-    must be passed into each decoder call as kwargs
-
-    :param output: output of the previous layer
-    :param next_layer: next layer to call
-    :param inputs: inputs to next layer
-    """
-    signature = inspect.signature(next_layer.forward)
-    if (
-        "position_embeddings" in signature.parameters.keys()
-        and "position_embeddings" in inputs
-        and "position_embeddings" not in output
-    ):
-        output["position_embeddings"] = inputs["position_embeddings"]
-
-    return output
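The deleted helpers above centered on one trick: registering a forward-pre-hook on the first layer that raises an exception carrying that layer's inputs, so the full model never runs during input capture. For reference, the core of that pattern in isolation (a standalone sketch using only PyTorch; all names here are illustrative, not llmcompressor API):

```python
import torch

class _EarlyStop(Exception):
    """Carries the hooked layer's inputs out of the forward pass."""

    def __init__(self, args, kwargs):
        super().__init__()
        self.captured_args, self.captured_kwargs = args, kwargs

def capture_layer_inputs(model: torch.nn.Module, layer: torch.nn.Module, **batch):
    def _raise(module, args, kwargs):
        raise _EarlyStop(args, kwargs)

    # with_kwargs=True passes the call's keyword arguments to the hook (PyTorch >= 2.0)
    handle = layer.register_forward_pre_hook(_raise, with_kwargs=True)
    try:
        model(**batch)
    except _EarlyStop as stop:
        return stop.captured_args, stop.captured_kwargs
    finally:
        handle.remove()
    raise ValueError("model forward never reached the hooked layer")
```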
diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py
deleted file mode 100644
index 244edde87e..0000000000
--- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py
+++ /dev/null
@@ -1,125 +0,0 @@
-import contextlib
-from typing import TYPE_CHECKING
-
-import torch
-import tqdm
-from compressed_tensors.utils import disable_offloading, get_execution_device
-from torch.utils.data.dataloader import DataLoader
-
-from llmcompressor.core import LifecycleCallbacks, active_session
-from llmcompressor.modifiers.utils.hooks import HooksMixin
-from llmcompressor.pipelines.cache import IntermediatesCache
-from llmcompressor.pipelines.layer_sequential.helpers import (
-    capture_first_layer_intermediates,
-    match_modules,
-    maybe_inject_pos_embeddings,
-    to_next_layer_kwargs,
-)
-from llmcompressor.pipelines.registry import CalibrationPipeline
-from llmcompressor.pipelines.sequential.helpers import (
-    dispatch_for_sequential,
-    get_sequential_targets,
-)
-from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context
-
-if TYPE_CHECKING:
-    from llmcompressor.args.dataset_arguments import DatasetArguments
-
-
-__all__ = ["LayerSequentialPipeline"]
-
-
-@CalibrationPipeline.register("layer_sequential")
-class LayerSequentialPipeline(CalibrationPipeline):
-    @staticmethod
-    def __call__(
-        model: torch.nn.Module, dataloader: DataLoader, dataset_args: "DatasetArguments"
-    ):
-        """
-        Run a layer-wise sequential data pipeline according to the following steps:
-
-        1. Layers are identified according to `sequential_targets`
-        2. A hook is attached to the first layer. This hook raises an exception which is
-           then caught and used to capture the input arguments to the first layer
-        3. The inputs to the first layer are used to calibrate the first layer, and the
-           output of the previous layer is used as inputs to calibrate the next layer
-
-        This pipeline requires that the model have distinct layers defined in its
-        architecture and that the outputs of the previous layer are exactly the inputs
-        to the next layer. This is violated by encoder-decoder architectures, among
-        others.
-
-        If your model architecture violates these assumptions, consider using the
-        sequential pipeline (see llmcompressor.pipelines.sequential). Architectures
-        which are known to fail these assumptions include GPT-J and most vision models
-
-        :param model: model being calibrated
-        :param dataloader: loads data for calibration
-        :param dataset_args: dataset arguments relevant to pipelines
-        """
-        session = active_session()
-
-        # prepare model for sequential onloading
-        dispatch_for_sequential(model)
-        model_device = get_execution_device(model)
-
-        # find layers
-        modifiers = session.lifecycle.recipe.modifiers
-        sequential_targets = get_sequential_targets(modifiers, model, dataset_args)
-        layers = match_modules(model, sequential_targets)
-
-        LifecycleCallbacks.calibration_epoch_start()
-
-        # TODO: remove this to enable quantization aware calibration for GPTQ and AWQ
-        disable_qac = any(
-            type(mod).__name__ in ["GPTQModifier", "AWQModifier"]
-            for mod in session.lifecycle.recipe.modifiers
-        )
-
-        with contextlib.ExitStack() as stack:
-            stack.enter_context(calibration_forward_context(model))
-            if not dataset_args.quantization_aware_calibration or disable_qac:
-                stack.enter_context(DisableQuantization(model))
-
-            # prepare intermediates cache
-            intermediates: IntermediatesCache = capture_first_layer_intermediates(
-                model, layers[0], dataloader, model_device
-            )
-
-            num_layers = len(layers)
-            for layer_index, layer in enumerate(layers):
-                # prepare tqdm description texts
-                calib_desc = f"({layer_index + 1}/{num_layers}): Calibrating"
-                prop_desc = f"({layer_index + 1}/{num_layers}): Propagating"
-
-                # reduce memory movement by keeping modules onloaded
-                with disable_offloading():
-                    # do a preliminary pass to trigger modifier hooks
-                    for batch_idx in tqdm.tqdm(range(len(dataloader)), desc=calib_desc):
-                        inputs = intermediates.fetch(batch_idx)
-                        layer(**inputs)
-
-                    LifecycleCallbacks.sequential_epoch_end()
-
-                    # this pass does not trigger modifier hooks
-                    # and is only used for capturing outputs from
-                    # newly compressed modules
-                    with HooksMixin.disable_hooks():
-                        for batch_idx in tqdm.tqdm(
-                            range(len(dataloader)), desc=prop_desc
-                        ):
-                            inputs = intermediates.fetch(batch_idx)
-                            output = layer(**inputs)
-
-                            if layer_index < num_layers - 1:
-                                next_layer = layers[layer_index + 1]
-                                output = to_next_layer_kwargs(output, next_layer)
-                                output = maybe_inject_pos_embeddings(
-                                    output, next_layer, inputs
-                                )
-
-                            intermediates.delete(batch_idx)
-                            intermediates.update(batch_idx, output)
-
-        # redundant, finish any remaining compression
-        LifecycleCallbacks.calibration_epoch_end()
diff --git a/src/llmcompressor/pipelines/sequential/__init__.py b/src/llmcompressor/pipelines/sequential/__init__.py
index 488b9b1c19..c89a842fc9 100644
--- a/src/llmcompressor/pipelines/sequential/__init__.py
+++ b/src/llmcompressor/pipelines/sequential/__init__.py
@@ -1,2 +1,3 @@
 # ruff: noqa
 from .pipeline import *
+from .helpers import *
diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py
index cbec3201df..192ead4f77 100644
--- a/src/llmcompressor/pipelines/sequential/helpers.py
+++ b/src/llmcompressor/pipelines/sequential/helpers.py
@@ -79,6 +79,14 @@ def forward(self, *args, **kwargs) -> Dict[str, Any]:
 
         return outputs
 
+    def submodules(self, model: Module, recurse: bool = False) -> Set[Module]:
+        nodes = self.graph.find_nodes(op="call_module")
+        modules = set(model.get_submodule(node.target) for node in nodes)
+        if recurse:
+            modules = set(m for module in modules for m in module.modules())
+
+        return modules
+
 
 def trace_subgraphs(
     model: PreTrainedModel,
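A sketch of how the new `Subgraph.submodules` helper might be consumed, e.g. by a modifier reacting to `sequential_epoch_end` (the handler and its bookkeeping attributes below are hypothetical, shown only to motivate the `recurse` flag):

```python
# Hypothetical handler: restrict compression to modules the subgraph executed.
def on_sequential_epoch_end(self, model, subgraph):
    # recurse=True also includes children of the called modules
    touched = subgraph.submodules(model, recurse=True)
    for name, module in model.named_modules():
        if module in touched and name in self._pending:  # hypothetical state
            self._compress_module(name, module)  # hypothetical method
```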
diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py
index 261afd6544..e95ffa915f 100644
--- a/src/llmcompressor/pipelines/sequential/pipeline.py
+++ b/src/llmcompressor/pipelines/sequential/pipeline.py
@@ -99,7 +99,7 @@ def __call__(
                     inputs = activations.fetch(batch_idx, subgraph.input_names)
                     subgraph.forward(model, **inputs)
 
-                LifecycleCallbacks.sequential_epoch_end()
+                LifecycleCallbacks.sequential_epoch_end(subgraph)
 
                 # this pass does not trigger modifier hooks
                 # and is only used for capturing outputs of newly compressed modules
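End to end, the `subgraph` argument now flows from the pipeline through `LifecycleCallbacks.sequential_epoch_end(subgraph)` into `session.event(...)`, where it arrives in modifier event handlers as a keyword argument. A hedged sketch of a modifier picking it up (the exact dispatch path through `Modifier.on_event` is an assumption here, not confirmed by this diff):

```python
from llmcompressor.core.events import EventType

def on_event(self, state, event, **kwargs):
    # event.type_ as the EventType field is assumed from llmcompressor.core.events
    if event.type_ == EventType.SEQUENTIAL_EPOCH_END:
        subgraph = kwargs.get("subgraph")  # new in this PR
        if subgraph is not None:
            self._finalize_subgraph(subgraph)  # hypothetical helper
```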