diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py
index 577cf92a2..fa4990bf5 100644
--- a/src/llmcompressor/args/dataset_arguments.py
+++ b/src/llmcompressor/args/dataset_arguments.py
@@ -203,6 +203,7 @@ class DatasetArguments(CustomDatasetArguments):
             "_prepare_fsmt_decoder_inputs",
             "_prepare_4d_causal_attention_mask_with_cache_position",
             "_update_linear_attn_mask",
+            "get_placeholder_mask",
         ],
         metadata={
             "help": "List of functions to ignore during tracing, either "
diff --git a/src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py b/src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
index 2e78994e4..6e68b1c61 100644
--- a/src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
+++ b/src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
@@ -1,6 +1,6 @@
 import ast
 from types import FunctionType, MethodType
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union
 
 from loguru import logger
 
@@ -101,12 +101,28 @@ def visit_If(self, node: ast.If) -> Union[ast.If, ast.Assign]:
         `if` statement with the condition replaced by `True` or `False`.
         Otherwise, return a wrapper function call
         """
+        # HACK: common case where vision models have code that looks like this:
+        # ```
+        # if pixel_values is not None:
+        #     image_embeds = self.get_image_features(pixel_values, image_grid_thw)
+        # ```
+        # if the body or else branch calls `get_image_features`, do not autowrap
+        for subnode in ast.walk(node):
+            if (
+                isinstance(subnode, ast.Call)
+                and self._get_caller_name(subnode) == "get_image_features"
+            ):
+                return super().generic_visit(node)
+
+        # try to evaluate the condition statically
         try:
             value = bool(self._eval_expr(node.test))
 
+        # wrap if the condition cannot be evaluated statically
         except Exception:
             return self._wrap_if_possible(node)
 
+        # replace the condition with its evaluated value
         else:
             node.test = ast.Constant(value=value)
             return super().generic_visit(node)
@@ -129,18 +145,8 @@ def visit_Call(self, node: ast.Call) -> ast.Call:
             return self._wrap_if_possible(node)
 
         # attempt to evaluate caller and check against ignore list
-        try:
-            caller = self._eval_expr(node.func)
-
-        except Exception:
-            caller = None
-
-        finally:
-            if (
-                isinstance(caller, (FunctionType, MethodType))
-                and caller.__name__ in self.ignore
-            ):
-                return self._wrap_if_possible(node)
+        if self._get_caller_name(node) in self.ignore:
+            return self._wrap_if_possible(node)
 
         return super().generic_visit(node)
 
@@ -276,3 +282,16 @@ def _wrap_expr(self, node: ast.expr) -> ast.Call:
         fn_call = wrapped.value
 
         return fn_call
+
+    def _get_caller_name(self, node: ast.Call) -> Optional[str]:
+        """Return the name of the called function, if it can be statically resolved"""
+        try:
+            caller = self._eval_expr(node.func)
+
+        except Exception:
+            caller = None
+
+        finally:
+            if isinstance(caller, (FunctionType, MethodType)):
+                return caller.__name__
+            return None
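The `visit_If` hack above can be illustrated outside the tracer using only the standard `ast` module. A minimal sketch, not the tracer's actual code path: the `call_name` helper is a simplified, hypothetical stand-in for `AutoWrapper._get_caller_name`, which resolves the caller by statically evaluating `node.func` rather than matching on syntax:

```python
import ast
from typing import Optional


def call_name(node: ast.Call) -> Optional[str]:
    # Simplified name resolution; the real implementation evaluates
    # `node.func` and inspects the resulting FunctionType/MethodType
    if isinstance(node.func, ast.Attribute):
        return node.func.attr
    if isinstance(node.func, ast.Name):
        return node.func.id
    return None


source = """
if pixel_values is not None:
    image_embeds = self.get_image_features(pixel_values, image_grid_thw)
"""

if_node = ast.parse(source).body[0]

# Mirrors the visit_If hack: walk the `if` statement and leave it
# unwrapped when any call inside it targets `get_image_features`
skip_autowrap = any(
    isinstance(sub, ast.Call) and call_name(sub) == "get_image_features"
    for sub in ast.walk(if_node)
)
print(skip_autowrap)  # True
```

Skipping autowrap here keeps the `get_image_features` call visible to the tracer instead of burying it inside an opaque wrapper function, which is what the HACK comment is guarding for multimodal models.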
diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py
index cbec3201d..259fe38de 100644
--- a/src/llmcompressor/pipelines/sequential/helpers.py
+++ b/src/llmcompressor/pipelines/sequential/helpers.py
@@ -2,7 +2,17 @@
 import inspect
 from collections import deque
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Set,
+    Tuple,
+)
 
 import torch
 from accelerate.hooks import remove_hook_from_module
@@ -79,6 +89,20 @@ def forward(self, *args, **kwargs) -> Dict[str, Any]:
         return outputs
 
+    def get_modules(
+        self, model: Module, recurse: bool = True
+    ) -> Generator[Module, None, None]:
+        """Yield the unique modules of `model` called by this subgraph"""
+        memo = set()
+        for node in self.graph.find_nodes(op="call_module"):
+            submodule = model.get_submodule(node.target)
+
+            modules = submodule.modules() if recurse else (submodule,)
+            for m in modules:
+                if m not in memo:
+                    memo.add(m)
+                    yield m
+
 
 def trace_subgraphs(
     model: PreTrainedModel,
diff --git a/tests/llmcompressor/transformers/tracing/test_models.py b/tests/llmcompressor/transformers/tracing/test_models.py
index ded1dffda..2c934fd55 100644
--- a/tests/llmcompressor/transformers/tracing/test_models.py
+++ b/tests/llmcompressor/transformers/tracing/test_models.py
@@ -1,6 +1,7 @@
 import os
 
 import pytest
+from compressed_tensors import match_named_modules
 from transformers import (
     AutoModelForCausalLM,
     Gemma3ForConditionalGeneration,
@@ -14,9 +15,8 @@
     WhisperForConditionalGeneration,
 )
 
-from llmcompressor.pipelines.sequential.helpers import match_modules
 from llmcompressor.transformers.tracing.debug import trace
 from llmcompressor.utils.pytorch.module import get_no_split_params
 
 
 @pytest.mark.skipif(
@@ -111,7 +111,7 @@
         (
             "meta-llama/Llama-4-Scout-17B-16E-Instruct",
             Llama4ForConditionalGeneration,
-            "Llama4TextDecoderLayer",
+            ["Llama4TextDecoderLayer", "Llama4VisionEncoderLayer"],
             "vision",
             [],
         ),
@@ -148,6 +148,10 @@ def test_model_trace(model_id, model_class, targets, modality, backends):
     target_modules = get_target_modules(model, targets)
     assert len(subgraphs) == len(target_modules) + 1
 
+    for i, (subgraph, (name, module)) in enumerate(zip(subgraphs, target_modules)):
+        subgraph_modules = list(subgraph.get_modules(model, recurse=False))
+        assert module in subgraph_modules, f"Could not find {name} in subgraph #{i}"
+
 
 def get_target_modules(model, sequential_targets):
     if sequential_targets is None:
@@ -155,7 +159,7 @@ def get_target_modules(model, sequential_targets):
     if isinstance(sequential_targets, str):
         sequential_targets = [sequential_targets]
 
-    return match_modules(model, sequential_targets)
+    return list(match_named_modules(model, sequential_targets))
 
 
 def run_subgraphs(model, subgraphs, inputs):
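The new `Subgraph.get_modules` helper and the tightened test assertion work together: each traced subgraph can now report which of the model's modules it executes. A minimal sketch of the same lookup logic, assuming a toy `nn.Sequential` model and a plain `torch.fx.symbolic_trace` in place of a real `Subgraph` produced by `trace_subgraphs`:

```python
from torch import nn
from torch.fx import symbolic_trace

# Toy stand-in for a traced subgraph: any fx GraphModule with
# call_module nodes will do
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 4))
graph = symbolic_trace(model).graph

# Same logic as Subgraph.get_modules: map each call_module node back
# to the live module on `model`, yielding each module only once
memo = set()
for node in graph.find_nodes(op="call_module"):
    submodule = model.get_submodule(node.target)
    for m in submodule.modules():  # recurse=True behavior
        if m not in memo:
            memo.add(m)
            print(node.target, type(m).__name__)
```

With `recurse=False`, the test's `module in subgraph_modules` check asserts that every sequential target layer lands in the subgraph expected to execute it, guarding against regressions in how the tracer partitions the model.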