diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 5c187c17e..9be4146be 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -265,8 +265,10 @@ def on_end(self, state: State, event: Event, **kwargs):
 
         self.ended_ = True
 
-        modules = list(state.model.modules())
-        for module in tqdm(modules, desc="Calibrating weights"):
+        for _, module in tqdm(
+            match_named_modules(state.model, self.targets, self.ignore),
+            desc="Calibrating weights",
+        ):
             update_weight_zp_scale(module)
 
         QuantizationMixin.end_calibration(self, state.model)
diff --git a/src/llmcompressor/modifiers/awq/mappings.py b/src/llmcompressor/modifiers/awq/mappings.py
index d446dd324..ba9cd122e 100644
--- a/src/llmcompressor/modifiers/awq/mappings.py
+++ b/src/llmcompressor/modifiers/awq/mappings.py
@@ -157,6 +157,7 @@ class AWQMapping:
     "Phi3ForCausalLM": _phi_mappings,
     "Phi3VForCausalLM": _phi_mappings,
     "Qwen2ForCausalLM": _default_mappings,
+    "Qwen2_5OmniThinkerForConditionalGeneration": _default_mappings,
     "Qwen2MoeForCausalLM": _moe_default_mappings,
     "Qwen3ForCausalLM": _default_mappings,
     "Qwen3MoeForCausalLM": _moe_default_mappings,
diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py
index 13b5a5411..c108d987c 100644
--- a/src/llmcompressor/modifiers/quantization/gptq/base.py
+++ b/src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -9,6 +9,7 @@
     align_module_device,
     get_execution_device,
     getattr_chain,
+    match_named_modules,
     update_offload_parameter,
 )
 from loguru import logger
@@ -165,7 +166,10 @@ def on_initialize(self, state: State, **kwargs) -> bool:
             QuantizationMixin.initialize_quantization(self, state.model)
 
         # prepare module names
-        self._module_names = {m: name for name, m in state.model.named_modules()}
+        self._module_names = {
+            m: name
+            for name, m in match_named_modules(state.model, self.targets, self.ignore)
+        }
 
         return True
 
@@ -178,7 +182,7 @@ def on_start(self, state: State, event: Event, **kwargs):
 
         # register gptq hooks
         added_hook = False
-        for module in state.model.modules():
+        for _, module in match_named_modules(state.model, self.targets, self.ignore):
             if getattr_chain(module, "quantization_scheme.weights", None) is not None:
                 # HACK: previously, embeddings were not quantized because they were not
                 # accessible by the layer compressor. For now, we manually ignore it,
diff --git a/src/llmcompressor/modifiers/quantization/quantization/base.py b/src/llmcompressor/modifiers/quantization/quantization/base.py
index 07332f214..aa6208da4 100644
--- a/src/llmcompressor/modifiers/quantization/quantization/base.py
+++ b/src/llmcompressor/modifiers/quantization/quantization/base.py
@@ -1,4 +1,5 @@
 import tqdm
+from compressed_tensors.utils import match_named_modules
 
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
@@ -69,14 +70,16 @@ def on_start(self, state: State, event: Event, **kwargs):
         self.started_ = True
         QuantizationMixin.start_calibration(self, state.model)
 
-        modules = list(state.model.modules())
+        named_modules = list(
+            match_named_modules(state.model, self.targets, self.ignore)
+        )
 
         # TODO: this step can be combined with update_weight_zp_scale
         # once update_fused_layer_weight_global_scales is removed
         # and not required by vLLM
-        for module in tqdm.tqdm(modules):
+        for _, module in tqdm.tqdm(named_modules):
             update_weight_global_scale(module)
 
-        for module in tqdm.tqdm(modules, desc="Calibrating weights"):
+        for _, module in tqdm.tqdm(named_modules, desc="Calibrating weights"):
             update_fused_layer_weight_global_scales(module)
             update_weight_zp_scale(module)
diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py
index d193d85a1..e8d5cd931 100644
--- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py
+++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py
@@ -14,6 +14,7 @@
     is_preset_scheme,
     preset_name_to_scheme,
 )
+from compressed_tensors.utils import match_named_modules
 from pydantic import Field, PrivateAttr, field_validator
 from torch.utils.hooks import RemovableHandle
 
@@ -116,41 +117,49 @@ def validate_scheme(
 
     def initialize_quantization(self, model: torch.nn.Module):
         """
-        Attach quantization schemes and observers to modules in the model according to
+        Attach quantization schemes to modules in the model according to
         the quantization config specified on this modifier
 
         :param model: model to attach schemes and observers to
         """
-        reset_quantization_status(model)  # reset any previously applied qconfigs
-
         # apply scheme and status to model
         config = self.resolve_quantization_config()
+
+        for _, module in match_named_modules(model, self.targets, self.ignore):
+            reset_quantization_status(module)  # reset any previously applied qconfigs
+
         apply_quantization_config(model, config)
 
-        # apply observers, disable quantization until calibration
-        model.apply(self._initialize_observers)
+        # TODO should we disable for entire model or just matching modules?
+        # disable quantization until calibration
         model.apply(disable_quantization)
 
     def start_calibration(self, model: torch.nn.Module):
         """
-        Register activation calibration hooks (including kv_cache quantization) and
-        enable quantization as we calibrate
+        Attach observers, register activation calibration hooks (including
+        kv_cache quantization) and enable quantization as we calibrate
 
         :param model: model to prepare for calibration
         """
         self._calibration_hooks = self._initialize_hooks(model)
-        model.apply(apply_calibration_status)
+        for _, module in match_named_modules(model, self.targets, self.ignore):
+            self._initialize_observers(module)
+            apply_calibration_status(module)
+
+        # TODO should we disable for entire model or just matching modules?
         model.apply(enable_quantization)  # quantize at the same time as calibrate
 
     def end_calibration(self, model: torch.nn.Module):
         """
-        Remove calibration hooks and set the model status to frozen. Keep quantization
-        enabled for future operations
+        Remove calibration hooks and observers, and set the model status to frozen.
+        Keep quantization enabled for future operations
 
         :param model: model to end calibration for
         """
         self.remove_hooks(self._calibration_hooks)
-        model.apply(freeze_module_quantization)  # remove observers
+        for _, module in match_named_modules(model, self.targets, self.ignore):
+            freeze_module_quantization(module)  # remove observers
+
         model.apply(enable_quantization)  # keep quantization enabled
 
     def has_config(self) -> bool:
@@ -240,7 +249,7 @@ def _initialize_observers(self, module: torch.nn.Module):
 
     def _initialize_hooks(self, model: torch.nn.Module) -> Set[RemovableHandle]:
         hooks = set()
-        for module in model.modules():
+        for _, module in match_named_modules(model, self.targets, self.ignore):
             if not hasattr(module, "quantization_scheme"):
                 continue
 
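
Note: every hunk above replaces a blanket model.modules() / model.apply(...) walk with iteration over match_named_modules(model, targets, ignore), which, as the call sites imply, yields (name, module) pairs for modules matching the modifier's targets and not matching its ignore list. A minimal sketch of that call-site pattern follows; the toy model and the targets/ignore values are illustrative assumptions, not taken from the diff.

    # Sketch of the pattern the diff converges on (illustrative values only).
    import torch
    from compressed_tensors.utils import match_named_modules

    model = torch.nn.Sequential(
        torch.nn.Linear(16, 16),
        torch.nn.ReLU(),
        torch.nn.Linear(16, 4),
    )
    targets = ["Linear"]  # hypothetical modifier targets
    ignore = []           # hypothetical modifier ignore list

    # Before: visit every module, whether or not it will be quantized
    # for module in model.modules():
    #     calibrate(module)

    # After: visit only the (name, module) pairs that match targets and are not ignored
    for name, module in match_named_modules(model, targets, ignore):
        print(f"calibrating {name}: {type(module).__name__}")

Restricting the walk this way keeps calibration work (and the tqdm progress bars) scoped to the modules that will actually receive quantization schemes.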