1 change: 1 addition & 0 deletions src/llmcompressor/args/dataset_arguments.py
@@ -203,6 +203,7 @@ class DatasetArguments(CustomDatasetArguments):
             "_prepare_fsmt_decoder_inputs",
             "_prepare_4d_causal_attention_mask_with_cache_position",
             "_update_linear_attn_mask",
+            "get_placeholder_mask",
         ],
         metadata={
             "help": "List of functions to ignore during tracing, either "
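For context on why this entry is needed: recent multimodal models in `transformers` ship a `get_placeholder_mask` helper that locates image-placeholder tokens in the input ids and validates their count against the image features. That validation is data-dependent Python control flow, which symbolic tracing cannot follow, so the function has to be skipped. A minimal sketch of the pattern (hypothetical and simplified, not the actual `transformers` implementation):

```python
import torch

# Hypothetical, simplified stand-in for the kind of helper being ignored.
# The `if` below branches on a runtime tensor value, which symbolic
# tracing cannot evaluate -- hence the ignore-list entry.
def get_placeholder_mask(
    input_ids: torch.Tensor, image_token_id: int, n_image_features: int
) -> torch.Tensor:
    mask = input_ids == image_token_id
    if int(mask.sum()) != n_image_features:  # data-dependent control flow
        raise ValueError("Image features and image tokens do not match")
    return mask
```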
44 changes: 31 additions & 13 deletions src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
@@ -1,6 +1,6 @@
 import ast
 from types import FunctionType, MethodType
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, List, Optional, Union
 
 from loguru import logger
 
@@ -101,12 +101,28 @@ def visit_If(self, node: ast.If) -> Union[ast.If, ast.Assign]:
         `if` statement with the condition replaced by `True` or `False`.
         Otherwise, return a wrapper function call
         """
+        # HACK: handle the common case where vision models have code like this:
+        # ```
+        # if pixel_values is not None:
+        #     image_embeds = self.get_image_features(pixel_values, image_grid_thw)
+        # ```
+        # if the `if` body or `else` body calls `get_image_features`, do not autowrap
+        for subnode in ast.walk(node):
+            if (
+                isinstance(subnode, ast.Call)
+                and self._get_caller_name(subnode) == "get_image_features"
+            ):
+                return super().generic_visit(node)
+
+        # try to evaluate condition statically
         try:
             value = bool(self._eval_expr(node.test))
 
+        # wrap if cannot be evaluated statically
         except Exception:
             return self._wrap_if_possible(node)
 
+        # replace with evaluated value
         else:
             node.test = ast.Constant(value=value)
             return super().generic_visit(node)
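To see what this special case matches, here is a standalone sketch; it is hypothetical and matches the attribute name purely syntactically, whereas the wrapper resolves the caller through `_eval_expr` via `_get_caller_name`:

```python
import ast

# Parse the exact pattern quoted in the HACK comment and detect the
# `get_image_features` call inside the `if` body.
src = """
if pixel_values is not None:
    image_embeds = self.get_image_features(pixel_values, image_grid_thw)
"""
if_node = ast.parse(src).body[0]  # the ast.If node

def calls_get_image_features(node: ast.If) -> bool:
    for subnode in ast.walk(node):
        if isinstance(subnode, ast.Call):
            func = subnode.func
            # `self.get_image_features(...)` parses as an ast.Attribute
            name = func.attr if isinstance(func, ast.Attribute) else getattr(func, "id", None)
            if name == "get_image_features":
                return True
    return False

print(calls_get_image_features(if_node))  # True -> leave the `if` unwrapped
```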
@@ -129,18 +145,8 @@ def visit_Call(self, node: ast.Call) -> ast.Call:
             return self._wrap_if_possible(node)
 
         # attempt to evaluate caller and check against ignore list
-        try:
-            caller = self._eval_expr(node.func)
-
-        except Exception:
-            caller = None
-
-        finally:
-            if (
-                isinstance(caller, (FunctionType, MethodType))
-                and caller.__name__ in self.ignore
-            ):
-                return self._wrap_if_possible(node)
+        if self._get_caller_name(node) in self.ignore:
+            return self._wrap_if_possible(node)
 
         return super().generic_visit(node)

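A subtlety that keeps this refactor behavior-preserving: when the caller cannot be resolved, `_get_caller_name` returns `None`, and `None in self.ignore` is `False`, so unresolvable calls still fall through to `generic_visit`, exactly as the old `finally` branch did:

```python
# None never matches an ignore list of function names.
ignore = ["get_placeholder_mask", "_update_linear_attn_mask"]
assert (None in ignore) is False
```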
@@ -276,3 +282,15 @@ def _wrap_expr(self, node: ast.expr) -> ast.Call:
         fn_call = wrapped.value
 
         return fn_call
+
+    def _get_caller_name(self, node: ast.Call) -> Optional[str]:
+        try:
+            caller = self._eval_expr(node.func)
+
+        except Exception:
+            caller = None
+
+        finally:
+            if isinstance(caller, (FunctionType, MethodType)):
+                return caller.__name__
+            return None
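For reference, the `isinstance` guard only passes once `_eval_expr` has resolved `node.func` to a real function object; plain Python shows what qualifies (class and function names here are hypothetical):

```python
from types import FunctionType, MethodType

class Model:
    def get_image_features(self):
        ...

def helper():
    ...

m = Model()
assert isinstance(m.get_image_features, MethodType)  # bound method
assert isinstance(helper, FunctionType)              # plain function
assert m.get_image_features.__name__ == "get_image_features"
```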
25 changes: 24 additions & 1 deletion src/llmcompressor/pipelines/sequential/helpers.py
@@ -2,7 +2,17 @@
 import inspect
 from collections import deque
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    Optional,
+    Set,
+    Tuple,
+)
 
 import torch
 from accelerate.hooks import remove_hook_from_module
@@ -79,6 +89,19 @@ def forward(self, *args, **kwargs) -> Dict[str, Any]:
 
         return outputs
 
+    def get_modules(
+        self, model: Module, recurse: bool = True
+    ) -> Generator[Module, None, None]:
+        memo = set()
+        for node in self.graph.find_nodes(op="call_module"):
+            submodule = model.get_submodule(node.target)
+
+            modules = submodule.modules() if recurse else (submodule,)
+            for m in modules:
+                if m not in memo:
+                    memo.add(m)
+                    yield m
+
 
 def trace_subgraphs(
     model: PreTrainedModel,
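The new `Subgraph.get_modules` helper reports which modules a traced subgraph actually calls, optionally recursing into their children, deduplicating via `memo`. A hedged usage sketch, assuming `model` and `subgraphs = trace_subgraphs(...)` already exist as in the updated test below:

```python
# Sketch: count the modules each traced subgraph calls directly.
for index, subgraph in enumerate(subgraphs):
    owned = list(subgraph.get_modules(model, recurse=False))
    print(f"subgraph #{index}: {len(owned)} directly-called modules")
```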
20 changes: 11 additions & 9 deletions tests/llmcompressor/transformers/tracing/test_models.py
@@ -1,6 +1,5 @@
-import os
-
 import pytest
+from compressed_tensors import match_named_modules
 from transformers import (
     AutoModelForCausalLM,
     Gemma3ForConditionalGeneration,
@@ -14,15 +13,14 @@
     WhisperForConditionalGeneration,
 )
 
-from llmcompressor.pipelines.sequential.helpers import match_modules
 from llmcompressor.transformers.tracing.debug import trace
 from llmcompressor.utils.pytorch.module import get_no_split_params
 
 
-@pytest.mark.skipif(
-    (not os.getenv("HF_TOKEN")),
-    reason="Skipping tracing tests requiring gated model access",
-)
+# @pytest.mark.skipif(
+#     (not os.getenv("HF_TOKEN")),
+#     reason="Skipping tracing tests requiring gated model access",
+# )
 @pytest.mark.parametrize(
     "model_id,model_class,targets,modality,backends",
     [
Expand Down Expand Up @@ -111,7 +109,7 @@
(
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
Llama4ForConditionalGeneration,
"Llama4TextDecoderLayer",
["Llama4TextDecoderLayer", "Llama4VisionEncoderLayer"],
"vision",
[],
),
@@ -148,14 +146,18 @@ def test_model_trace(model_id, model_class, targets, modality, backends):
     target_modules = get_target_modules(model, targets)
     assert len(subgraphs) == len(target_modules) + 1
 
+    for i, (subgraph, (name, module)) in enumerate(zip(subgraphs, target_modules)):
+        subgraph_modules = list(subgraph.get_modules(model, recurse=False))
+        assert module in subgraph_modules, f"Could not find {name} in subgraph #{i}"
+
 
 def get_target_modules(model, sequential_targets):
     if sequential_targets is None:
         sequential_targets = get_no_split_params(model)
     if isinstance(sequential_targets, str):
         sequential_targets = [sequential_targets]
 
-    return match_modules(model, sequential_targets)
+    return list(match_named_modules(model, sequential_targets))
 
 
 def run_subgraphs(model, subgraphs, inputs):