diff --git a/src/llmcompressor/modeling/deepseek_v3.py b/src/llmcompressor/modeling/deepseek_v3.py index 287b343bd..8cc7f4727 100644 --- a/src/llmcompressor/modeling/deepseek_v3.py +++ b/src/llmcompressor/modeling/deepseek_v3.py @@ -10,12 +10,18 @@ class DeepseekV3MoECalibrate(torch.nn.Module): Patched DeepseekV3MoE which sends all tokens to all experts for calibration """ - def __init__(self, config: DeepseekV3Config, original: OriginalDeepseekV3MoE): + def __init__( + self, + config: DeepseekV3Config, + original: OriginalDeepseekV3MoE, + calibrate_all_experts: bool, + ): super().__init__() self.config = config self.experts = original.experts self.gate = original.gate self.shared_experts = original.shared_experts + self.calibrate_all_experts = calibrate_all_experts def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: residuals = hidden_states @@ -30,18 +36,28 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: ) expert_mask = expert_mask.permute(2, 0, 1) - for expert_idx in range(len(self.experts)): - expert = self.experts[expert_idx] - mask = expert_mask[expert_idx] - token_indices, weight_indices = torch.where(mask) + for expert_idx, expert in enumerate(self.experts): + token_indices, weight_indices = torch.where(expert_mask[expert_idx]) + has_tokens = token_indices.numel() > 0 - expert_weights = topk_weights[token_indices, weight_indices] - expert_input = hidden_states[token_indices] - expert_output = expert(expert_input) - weighted_output = expert_output * expert_weights.unsqueeze(-1) + if self.calibrate_all_experts: + expert_input = hidden_states + expert_output = expert(expert_input) - if token_indices.numel() > 0: - final_hidden_states.index_add_(0, token_indices, weighted_output) + if has_tokens: + expert_weights = topk_weights[token_indices, weight_indices] + routed_output = expert_output[ + token_indices + ] * expert_weights.unsqueeze(-1) + final_hidden_states.index_add_(0, token_indices, routed_output) + else: + # Normal MoE: only process tokens routed to this expert + if has_tokens: + expert_input = hidden_states[token_indices] + expert_output = expert(expert_input) + expert_weights = topk_weights[token_indices, weight_indices] + routed_output = expert_output * expert_weights.unsqueeze(-1) + final_hidden_states.index_add_(0, token_indices, routed_output) # End MoE hidden_states = final_hidden_states.type(hidden_states.dtype).view(*orig_shape) @@ -49,5 +65,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -def replace(config: DeepseekV3Config, module: OriginalDeepseekV3MoE): - return DeepseekV3MoECalibrate(config=config, original=module) +def replace( + config: DeepseekV3Config, + module: OriginalDeepseekV3MoE, + calibrate_all_experts: bool, +): + return DeepseekV3MoECalibrate( + config=config, original=module, calibrate_all_experts=calibrate_all_experts + ) diff --git a/src/llmcompressor/modeling/llama4.py b/src/llmcompressor/modeling/llama4.py index fee1a5afd..10f4f2033 100644 --- a/src/llmcompressor/modeling/llama4.py +++ b/src/llmcompressor/modeling/llama4.py @@ -1,8 +1,6 @@ from typing import Tuple import torch -import transformers -from packaging import version from transformers.models.llama4.configuration_llama4 import ( Llama4Config, Llama4TextConfig, @@ -17,39 +15,57 @@ class SequentialLlama4TextMoe(torch.nn.Module): - def __init__(self, config: Llama4TextConfig, original: Llama4TextMoe): + def __init__( + self, + config: Llama4TextConfig, + original: Llama4TextMoe, + calibrate_all_experts: bool, + ): 
        super().__init__()
         self.top_k = config.num_experts_per_tok
         self.hidden_dim = config.hidden_size
         self.num_experts = config.num_local_experts
+
         self.experts = SequentialLlama4TextExperts(config, original.experts)
         self.router = original.router
         self.shared_expert = original.shared_expert
+        self.calibrate_all_experts = calibrate_all_experts
 
     def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.tensor]:
         hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-        router_logits = self.router(hidden_states)
+        router_outputs = self.router(hidden_states)
+
         # support transformers 4.53 and greater
-        if isinstance(router_logits, tuple):
-            router_logits = router_logits[-1]
+        if isinstance(router_outputs, tuple):
+            router_scores, router_logits = router_outputs
+        else:
+            router_logits = router_outputs
+            router_top_value, router_indices = torch.topk(
+                router_logits, self.top_k, dim=1
+            )
+            router_scores = (
+                torch.full_like(router_logits, float("-inf"))
+                .scatter_(1, router_indices, router_top_value)
+            )
+            router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
 
-        router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
+        out = self.shared_expert(hidden_states)
+        for expert_index in range(self.num_experts):
+            top_token_mask = router_scores[:, expert_index] > 0
 
-        router_scores = (
-            torch.full_like(router_logits, float("-inf"))
-            .scatter_(1, router_indices, router_top_value)
-            .transpose(0, 1)
-        )
-        router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
+            if self.calibrate_all_experts:
+                # Run all tokens for calibration
+                expert_out = self.experts[expert_index](hidden_states)[top_token_mask]
+            else:
+                expert_out = self.experts[expert_index](hidden_states[top_token_mask])
 
-        out = self.shared_expert(hidden_states)
-        for i in range(self.num_experts):
-            out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)
+            # Only top-k tokens contribute to final output
+            if top_token_mask.any():
+                expert_score = router_scores[top_token_mask, expert_index].unsqueeze(-1)
+                out[top_token_mask] += expert_out * expert_score
 
-        if version.parse(transformers.__version__) >= version.parse("4.54.0"):
-            return out, router_logits
-        else:
-            return out, router_scores
+        return out, router_scores
 
 
 class SequentialLlama4TextExperts(torch.nn.ModuleList):
@@ -72,5 +88,9 @@ def __init__(self, config: Llama4TextConfig, original: Llama4TextExperts):
             self[i].down_proj.weight.data = down.t().clone().contiguous()
 
 
-def replace(config: Llama4Config, module: Llama4TextMoe):
-    return SequentialLlama4TextMoe(config=config.get_text_config(), original=module)
+def replace(config: Llama4Config, module: Llama4TextMoe, calibrate_all_experts: bool):
+    return SequentialLlama4TextMoe(
+        config=config.get_text_config(),
+        original=module,
+        calibrate_all_experts=calibrate_all_experts,
+    )
diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py
index cb61f5fad..113fd4364 100644
--- a/src/llmcompressor/modeling/prepare.py
+++ b/src/llmcompressor/modeling/prepare.py
@@ -1,3 +1,4 @@
+import tqdm
 from compressed_tensors.utils import replace_module
 from transformers import PreTrainedModel
 
@@ -15,11 +16,18 @@
 }
 
 
-def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel:
-    for name, module in model.named_modules():
+def replace_modules_for_calibration(
+    model: PreTrainedModel,
+    calibrate_all_experts: bool = True,
+) -> PreTrainedModel:
+    for name, module in 
tqdm.tqdm(list(model.named_modules())): cls_name = module.__class__.__name__ if cls_name in replacements: - new_module = replacements[cls_name](config=model.config, module=module) + new_module = replacements[cls_name]( + config=model.config, + module=module, + calibrate_all_experts=calibrate_all_experts, + ) replace_module(model, name, new_module) return model @@ -28,7 +36,7 @@ def replace_modules_for_calibration(model: PreTrainedModel) -> PreTrainedModel: # ------------------- module replacements; during calibration -------------------- -def update_qwen3_moe(model, stack): +def update_qwen3_moe(model, stack, calibrate_all_experts): for module in model.modules(): cls_name = module.__class__.__name__ if cls_name == "Qwen3MoeDecoderLayer": @@ -37,7 +45,11 @@ def update_qwen3_moe(model, stack): patch_attr( module, "mlp", - replace_Qwen3MoE(config=model.config, module=module.mlp), + replace_Qwen3MoE( + config=model.config, + module=module.mlp, + calibrate_all_experts=calibrate_all_experts, + ), ) ) @@ -47,9 +59,13 @@ def update_qwen3_moe(model, stack): } -def moe_calibration_context(model: PreTrainedModel, stack): +def moe_calibration_context( + model: PreTrainedModel, + stack, + calibrate_all_experts: bool = False, +): # Temporarily updates the MoE modules within the context # Once the context exists, parameter updates persist cls_name = model.__class__.__name__ if cls_name in moe_context: - moe_context.get(cls_name)(model, stack) + moe_context.get(cls_name)(model, stack, calibrate_all_experts) diff --git a/src/llmcompressor/modeling/qwen3_moe.py b/src/llmcompressor/modeling/qwen3_moe.py index fcd5d9925..2b451bc49 100644 --- a/src/llmcompressor/modeling/qwen3_moe.py +++ b/src/llmcompressor/modeling/qwen3_moe.py @@ -23,14 +23,17 @@ class Qwen3MoeSparseMoeBlock(torch.nn.Module): def __init__( - self, config: Qwen3MoeConfig, original: OriginalQwen3MoeSparseMoeBlock + self, + config: Qwen3MoeConfig, + original: OriginalQwen3MoeSparseMoeBlock, + calibrate_all_experts: bool, ): super().__init__() self.num_experts = config.num_experts - self.top_k = config.top_k + self.top_k = config.num_experts_per_tok self.norm_topk_prob = config.norm_topk_prob - # gating + self.calibrate_all_experts = calibrate_all_experts self.gate = original.gate self.experts = original.experts @@ -50,6 +53,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) + final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, @@ -62,20 +66,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: selected_experts, num_classes=self.num_experts ).permute(2, 1, 0) - for expert_idx in range(len(self.experts)): - expert_layer = self.experts[expert_idx] + for expert_idx, expert_layer in enumerate(self.experts): idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - expert_output = expert_layer(current_state) - current_hidden_states = expert_output * routing_weights[top_x, idx, None] - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. 
- final_hidden_states.index_add_( - 0, top_x, current_hidden_states.to(hidden_states.dtype) - ) + + if self.calibrate_all_experts: + expert_out = expert_layer(hidden_states)[top_x] + else: + expert_out = expert_layer(hidden_states[top_x]) + + # TODO: double check + if len(top_x) > 0: + current_hidden_states = expert_out * routing_weights[top_x, idx, None] + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) final_hidden_states = final_hidden_states.reshape( batch_size, sequence_length, hidden_dim @@ -83,5 +87,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states, router_logits -def replace(config: Qwen3MoeConfig, module: OriginalQwen3MoeSparseMoeBlock): - return Qwen3MoeSparseMoeBlock(config=config, original=module) +def replace( + config: Qwen3MoeConfig, + module: OriginalQwen3MoeSparseMoeBlock, + calibrate_all_experts: bool, +): + return Qwen3MoeSparseMoeBlock( + config=config, original=module, calibrate_all_experts=calibrate_all_experts + ) diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index 1eeb52ca6..9aaae59eb 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -61,7 +61,7 @@ "is_package_available", "import_from_path", "getattr_chain", - "DisableKVCache", + "disable_cache", "DisableQuantization", "eval_context", "calibration_forward_context", @@ -974,7 +974,8 @@ def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any: return res -class DisableKVCache: +@contextlib.contextmanager +def disable_cache(module: torch.nn.Module): """ Temporarily disable the key-value cache for transformer models. Used to prevent excess memory use in one-shot cases where the model only performs the prefill @@ -983,32 +984,18 @@ class DisableKVCache: Example: >>> model = AutoModel.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0") >>> input = torch.randint(0, 32, size=(1, 32)) - >>> with DisableKVCache(model): + >>> with disable_cache(model): ... output = model(input) """ - def __init__(self, model: PreTrainedModel): - if hasattr(model.config, "use_cache"): - self.config = model.config - - # MllamaConfig - elif hasattr(model.config, "text_config") and hasattr( - model.config.text_config, "use_cache" - ): - self.config = model.config.text_config - - # unknown config structure - else: - raise NotImplementedError(f"Cannot find `use_cache` for {model.config}") - - self.restore_value = self.config.use_cache - - def __enter__(self): - self.restore_value = self.config.use_cache - self.config.use_cache = False + if isinstance(module, PreTrainedModel): + config = module.config + config = getattr(config, "text_config", config) + with patch_attr(config, "use_cache", False): + yield - def __exit__(self, _exc_type, _exc_val, _exc_tb): - self.config.use_cache = self.restore_value + else: + yield @contextlib.contextmanager @@ -1038,14 +1025,14 @@ def eval_context(module: torch.nn.Module): @contextlib.contextmanager -def disable_hf_kernels(model: PreTrainedModel): +def disable_hf_kernels(module: torch.nn.Module): """ In transformers>=4.50.0, some module forward methods may be replaced by calls to hf hub kernels. 
This has the potential to bypass hooks added by LLM Compressor """ - if hasattr(model, "config"): - with patch_attr(model.config, "disable_custom_kernels", True): + if isinstance(module, PreTrainedModel): + with patch_attr(module.config, "disable_custom_kernels", True): yield else: @@ -1053,7 +1040,7 @@ def disable_hf_kernels(model: PreTrainedModel): @contextlib.contextmanager -def calibration_forward_context(model: PreTrainedModel): +def calibration_forward_context(model: torch.nn.Module): """ Context in which all calibration forward passes should occur. @@ -1062,9 +1049,9 @@ def calibration_forward_context(model: PreTrainedModel): - Disable train mode and enable eval mode - Disable hf kernels which could bypass hooks """ - with torch.no_grad(), DisableKVCache(model), eval_context( + with torch.no_grad(), disable_cache(model), eval_context(model), disable_hf_kernels( model - ), disable_hf_kernels(model): + ): yield diff --git a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py new file mode 100644 index 000000000..ca7fb06af --- /dev/null +++ b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py @@ -0,0 +1,81 @@ +from functools import partial + +import pytest +import torch +from transformers import AutoModelForCausalLM + +from llmcompressor.modeling.deepseek_v3 import ( + DeepseekV3Config, + DeepseekV3MoECalibrate, + OriginalDeepseekV3MoE, +) +from llmcompressor.modeling.prepare import replace_modules_for_calibration +from llmcompressor.utils.dev import skip_weights_download +from llmcompressor.utils.helpers import calibration_forward_context +from tests.testing_utils import requires_cadence, requires_gpu + + +@requires_cadence("weekly") +@pytest.mark.parametrize("model_stub", ["unsloth/DeepSeek-R1-0528-BF16"]) +def test_calib_replace_deepseekv3moe_all_experts(model_stub): + with skip_weights_download(): + model = AutoModelForCausalLM.from_pretrained(model_stub) + + replace_modules_for_calibration(model, calibrate_all_experts=True) + + # Find a Deepseek MoE layer + moe_layer = None + for _, module in model.named_modules(): + if isinstance(module, DeepseekV3MoECalibrate): + moe_layer = module + break + + assert moe_layer is not None + + num_experts = len(moe_layer.experts) + expert_triggered = [False for _ in range(num_experts)] + + # Define the hook function + def hook_fn(i, module, input, output): + expert_triggered[i] = True + + # Attach hooks using functools.partial to bind each index + for i, expert in enumerate(moe_layer.experts): + expert.register_forward_hook(partial(hook_fn, i)) + + # Create dummy input tensor that simulates hidden_states + hidden_dim = model.config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32) + + # Forward through the MoE layer directly + with torch.no_grad(): + _ = moe_layer(sample) + + # Assert all experts are used + assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}" + + +@requires_gpu +def test_calib_deepseekv3_module(): + config = DeepseekV3Config() + with torch.device("cuda"): + original = OriginalDeepseekV3MoE(config).eval() + + # Create dummy input tensor that simulates hidden_states + hidden_dim = config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, device="cuda") + + with calibration_forward_context(original): + true_output = original(sample)[0] + + module = DeepseekV3MoECalibrate(config, original, calibrate_all_experts=True) + with 
calibration_forward_context(module): + output = module(sample)[0] + assert torch.allclose(true_output, output, atol=1e-6) + + module = DeepseekV3MoECalibrate(config, original, calibrate_all_experts=False) + with calibration_forward_context(module): + output = module(sample)[0] + assert torch.allclose(true_output, output, atol=1e-6) diff --git a/tests/llmcompressor/modeling/test_calib_llama4.py b/tests/llmcompressor/modeling/test_calib_llama4.py new file mode 100644 index 000000000..4eb609ca9 --- /dev/null +++ b/tests/llmcompressor/modeling/test_calib_llama4.py @@ -0,0 +1,88 @@ +import os +from functools import partial + +import pytest +import torch +from transformers import Llama4ForConditionalGeneration + +from llmcompressor.modeling.llama4 import ( + Llama4TextConfig, + Llama4TextMoe, + SequentialLlama4TextMoe, +) +from llmcompressor.modeling.prepare import replace_modules_for_calibration +from llmcompressor.utils.dev import skip_weights_download +from llmcompressor.utils.helpers import calibration_forward_context +from tests.testing_utils import requires_cadence, requires_gpu + + +@requires_cadence("weekly") +@pytest.mark.skipif( + (not os.getenv("HF_TOKEN")), + reason="Skipping tracing tests requiring gated model access", +) +@pytest.mark.parametrize("model_stub", ["meta-llama/Llama-4-Scout-17B-16E-Instruct"]) +def test_calib_replace_llama4_moe_all_experts(model_stub): + with skip_weights_download(Llama4ForConditionalGeneration): + model = Llama4ForConditionalGeneration.from_pretrained( + model_stub, torch_dtype="auto" + ) + + replace_modules_for_calibration(model, calibrate_all_experts=True) + + # Find a Llama4 MoE layer + moe_layer = None + for module in model.modules(): + if isinstance(module, SequentialLlama4TextMoe): + moe_layer = module + break + + assert moe_layer is not None + + num_experts = len(moe_layer.experts) + expert_triggered = [False for _ in range(num_experts)] + + # Define the hook function + def hook_fn(i, module, input, output): + expert_triggered[i] = True + + # Attach hooks using functools.partial to bind each index + for i, expert in enumerate(moe_layer.experts): + expert.register_forward_hook(partial(hook_fn, i)) + + # Create dummy input tensor that simulates hidden_states + hidden_dim = model.config.text_config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, dtype=model.dtype) + + # Forward through the MoE layer directly + with torch.no_grad(): + _ = moe_layer(sample) + + # Assert all experts are used + assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}" + + +@requires_gpu +def test_calib_llama4_module(): + config = Llama4TextConfig() + with torch.device("cuda"): + original = Llama4TextMoe(config) + + # Create dummy input tensor that simulates hidden_states + hidden_dim = config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, device="cuda") + + with calibration_forward_context(original): + true_output = original(sample)[0] + + module = SequentialLlama4TextMoe(config, original, calibrate_all_experts=True) + with calibration_forward_context(module): + output = module(sample)[0] + assert torch.allclose(true_output, output, atol=1e-6) + + module = SequentialLlama4TextMoe(config, original, calibrate_all_experts=False) + with calibration_forward_context(module): + output = module(sample)[0] + assert torch.allclose(true_output, output, atol=1e-6) diff --git a/tests/llmcompressor/modeling/test_calib_qwen3.py b/tests/llmcompressor/modeling/test_calib_qwen3.py 
new file mode 100644 index 000000000..822af18db --- /dev/null +++ b/tests/llmcompressor/modeling/test_calib_qwen3.py @@ -0,0 +1,89 @@ +import contextlib +from functools import partial + +import pytest +import torch +from transformers import AutoModelForCausalLM + +from llmcompressor.modeling.prepare import moe_calibration_context +from llmcompressor.modeling.qwen3_moe import ( + OriginalQwen3MoeSparseMoeBlock, + Qwen3MoeConfig, + Qwen3MoeSparseMoeBlock, +) +from llmcompressor.utils.dev import skip_weights_download +from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context +from tests.testing_utils import requires_cadence, requires_gpu + + +@requires_cadence("weekly") +@pytest.mark.parametrize("model_stub", ["Qwen/Qwen3-30B-A3B"]) +def test_calib_replace_qwen3moe_all_experts(model_stub): + with skip_weights_download(): + model = AutoModelForCausalLM.from_pretrained(model_stub) + + # Qwen3MoE layer replacement is temporary within the context + with contextlib.ExitStack() as stack: + stack.enter_context(calibration_forward_context(model)) + stack.enter_context(DisableQuantization(model)) + + moe_calibration_context(model, stack, calibrate_all_experts=True) + + # Find one MoE layer + moe_layer = None + for name, module in model.named_modules(): + if isinstance(module, Qwen3MoeSparseMoeBlock): + moe_layer = module + break + + assert moe_layer is not None + + num_experts = len(moe_layer.experts) + expert_triggered = [False for _ in range(num_experts)] + + # Define the hook function + def hook_fn(i, module, input, output): + expert_triggered[i] = True + + # Attach hooks using functools.partial to bind each index + for i, expert in enumerate(moe_layer.experts): + expert.register_forward_hook(partial(hook_fn, i)) + + # Create dummy input tensor that simulates hidden_states + hidden_dim = model.config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32) + + # Forward through the MoE layer directly + with torch.no_grad(): + _ = moe_layer(sample) + + # Assert all experts are used + assert all( + expert_triggered + ), f"Not all experts were triggered: {expert_triggered}" + + +@requires_gpu +def test_calib_qwen3_moe_module(): + config = Qwen3MoeConfig() + with torch.device("cuda"): + original = OriginalQwen3MoeSparseMoeBlock(config).eval() + + # Create dummy input tensor that simulates hidden_states + hidden_dim = config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, device="cuda") + + with calibration_forward_context(original): + true_output = original(sample)[0] + + module = Qwen3MoeSparseMoeBlock(config, original, calibrate_all_experts=True) + with calibration_forward_context(module): + output = module(sample)[0] + assert torch.allclose(true_output, output, atol=1e-6) + + module = Qwen3MoeSparseMoeBlock(config, original, calibrate_all_experts=False) + with calibration_forward_context(module): + output = module(sample)[0] + assert torch.allclose(true_output, output, atol=1e-6) diff --git a/tests/llmcompressor/utils/test_helpers.py b/tests/llmcompressor/utils/test_helpers.py index 035bd3da2..db96d93f1 100644 --- a/tests/llmcompressor/utils/test_helpers.py +++ b/tests/llmcompressor/utils/test_helpers.py @@ -2,18 +2,27 @@ import pytest import torch +from transformers import ( + AutoModelForCausalLM, + MllamaForConditionalGeneration, + PretrainedConfig, + PreTrainedModel, +) from llmcompressor.utils import ( ALL_TOKEN, DisableQuantization, calibration_forward_context, 
convert_to_bool, + disable_cache, flatten_iterable, getattr_chain, interpolate, patch_attr, validate_str_iterable, ) +from llmcompressor.utils.dev import skip_weights_download +from tests.testing_utils import requires_gpu @pytest.mark.unit @@ -140,8 +149,10 @@ def test_DisableQuantization(): @pytest.mark.unit def test_calibration_forward_context(): - model = torch.nn.Linear(1, 1) - model.config = SimpleNamespace() + class DummyModel(PreTrainedModel): + config_class = PretrainedConfig + + model = DummyModel(PretrainedConfig()) model.config.use_cache = True model.train() @@ -170,3 +181,25 @@ def test_patch_attr(): assert obj.attribute == "patched" obj.attribute = "modified" assert not hasattr(obj, "attribute") + + +@requires_gpu +@pytest.mark.unit +@pytest.mark.parametrize( + "model_cls,model_stub", + [ + (MllamaForConditionalGeneration, "meta-llama/Llama-3.2-11B-Vision-Instruct"), + (AutoModelForCausalLM, "nm-testing/llama2.c-stories15M"), + ], +) +def test_disable_cache(model_cls, model_stub): + with skip_weights_download(model_cls): + model = model_cls.from_pretrained(model_stub, device_map="cuda") + inputs = {key: value.to(model.device) for key, value in model.dummy_inputs.items()} + + with disable_cache(model): + output = model(**inputs) + assert output.past_key_values is None + + output = model(**inputs) + assert output.past_key_values is not None diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 23e9e591e..37d0dc2bd 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -5,8 +5,9 @@ import unittest from pathlib import Path from subprocess import PIPE, STDOUT, run -from typing import List, Optional, Union +from typing import Callable, List, Optional, Union +import pytest import yaml from datasets import Dataset from transformers import ProcessorMixin @@ -252,3 +253,12 @@ def preprocess(example): ds = ds.map(process, remove_columns=ds.column_names) return ds + + +def requires_cadence(cadence: Union[str, List[str]]) -> Callable: + cadence = [cadence] if isinstance(cadence, str) else cadence + current_cadence = os.environ.get("CADENCE", "commit") + + return pytest.mark.skipif( + (current_cadence not in cadence), reason="cadence mismatch" + )
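
For reference, a minimal usage sketch of the new `calibrate_all_experts` flag. The model stub and the calibration loop are placeholders, not part of this patch; the two entry points come from the `prepare.py` changes above, and the context-based path mirrors the new Qwen3 test.

```python
import contextlib

from transformers import AutoModelForCausalLM

from llmcompressor.modeling.prepare import (
    moe_calibration_context,
    replace_modules_for_calibration,
)
from llmcompressor.utils.helpers import calibration_forward_context

# Placeholder checkpoint; substitute any supported MoE model
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B")

# Option 1: permanent replacement (dispatched via the `replacements` dict,
# e.g. DeepseekV3 / Llama4). Every expert sees every token during calibration,
# while only the routed tokens contribute to the final output.
model = replace_modules_for_calibration(model, calibrate_all_experts=True)

# Option 2: temporary replacement (dispatched via `moe_context`, e.g. Qwen3-MoE).
# The patched modules are reverted when the exit stack closes.
with contextlib.ExitStack() as stack:
    stack.enter_context(calibration_forward_context(model))
    moe_calibration_context(model, stack, calibrate_all_experts=True)
    # ... run calibration forward passes here ...
```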
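Similarly, a sketch of the new `disable_cache` helper that replaces `DisableKVCache`, following the docstring example in `helpers.py` and the new `test_disable_cache` test; the checkpoint name is only illustrative.

```python
import torch
from transformers import AutoModel

from llmcompressor.utils.helpers import disable_cache

model = AutoModel.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
input_ids = torch.randint(0, 32, size=(1, 32))

with disable_cache(model):
    # `use_cache` is patched to False for the duration of the context,
    # so the prefill pass does not allocate a KV cache
    output = model(input_ids)
    assert output.past_key_values is None
```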