Skip to content

[Quantization] Support more than one quant-compressor #415

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from

Conversation

dsikka
Copy link
Collaborator

@dsikka dsikka commented Aug 7, 2025

Summary

  • Allow more than one compressor to be applied to a given model
  • Updates the ModelCompressor.quantization_compressor to now be a dictionary, such that more than one quantization compressor can be supported
  • Adds mixed-precision as a new CompressionFormat - if more than one format is found within the model, mixed-precision is set as the model's global format in its config.json
  • Adds format to the QuantizationScheme and leverages this per-module format field in order to fetch the appropriate compressor to compress the model
  • Note: this is not supported for ModelCompressor.compress and ModelCompressor.decompress - only compress_model and decompress_model currently support this functionality as compress/decompress essentially only support global formats

Testing:

  • nightly passes

Next Steps:

  • Decompression for mixed-precision is currently not supported - we will eventually need this to run lm-evals/hf forward passes

Example Updates

  • For an NVFP4 + FP8 model, the recipe could look like this
quant_stage:
    quant_modifiers:
        QuantizationModifier:
            ignore: ["lm_head"]
            config_groups:
                group_0:
                    weights:
                        num_bits: 8
                        type: float
                        strategy: channel
                        dynamic: false
                        symmetric: true
                    input_activations:
                        num_bits: 8
                        type: float
                        strategy: token
                        dynamic: true
                        symmetric: true
                    targets: ["re:.*mlp.down_proj.*"]
                group_1:
                    weights:
                        num_bits: 4
                        type: float
                        strategy: tensor_group
                        dynamic: false
                        symmetric: true
                        group_size: 16
                    input_activations:
                        num_bits: 4
                        type: float
                        strategy: tensor_group
                        dynamic: local
                        symmetric: true
                        group_size: 16
                    targets: ["re:.*mlp.gate_proj.*", "re:.*mlp.up_proj.*", "re:.*self_attn.k_proj.*", "re:.*self_attn.o_proj.*", "re:.*self_attn.q_proj.*", "re:.*self_attn.v_proj.*"]
"""

New config:

{
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "head_dim": 64,
  "hidden_act": "silu",
  "hidden_size": 2048,
  "initializer_range": 0.02,
  "intermediate_size": 5632,
  "max_position_embeddings": 2048,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 22,
  "num_key_value_heads": 4,
  "pretraining_tp": 1,
  "quantization_config": {
    "config_groups": {
      "group_0": {
        "format": "nvfp4-pack-quantized",
        "input_activations": {
          "actorder": null,
          "block_structure": null,
          "dynamic": "local",
          "group_size": 16,
          "num_bits": 4,
          "observer": "minmax",
          "observer_kwargs": {},
          "strategy": "tensor_group",
          "symmetric": true,
          "type": "float"
        },
        "output_activations": null,
        "targets": [
          "re:.*mlp.gate_proj.*",
          "re:.*mlp.up_proj.*",
          "re:.*self_attn.k_proj.*",
          "re:.*self_attn.o_proj.*",
          "re:.*self_attn.q_proj.*",
          "re:.*self_attn.v_proj.*"
        ],
        "weights": {
          "actorder": null,
          "block_structure": null,
          "dynamic": false,
          "group_size": 16,
          "num_bits": 4,
          "observer": "minmax",
          "observer_kwargs": {},
          "strategy": "tensor_group",
          "symmetric": true,
          "type": "float"
        }
      },
      "group_1": {
        "format": "float-quantized",
        "input_activations": {
          "actorder": null,
          "block_structure": null,
          "dynamic": true,
          "group_size": null,
          "num_bits": 8,
          "observer": null,
          "observer_kwargs": {},
          "strategy": "token",
          "symmetric": true,
          "type": "float"
        },
        "output_activations": null,
        "targets": [
          "re:.*mlp.down_proj.*"
        ],
        "weights": {
          "actorder": null,
          "block_structure": null,
          "dynamic": false,
          "group_size": null,
          "num_bits": 8,
          "observer": "minmax",
          "observer_kwargs": {},
          "strategy": "channel",
          "symmetric": true,
          "type": "float"
        }
      }
    },
    "format": "mixed-precision",
    "global_compression_ratio": null,
    "ignore": [
      "lm_head"
    ],
    "kv_cache_scheme": null,
    "quant_method": "compressed-tensors",
    "quantization_status": "compressed"
  },
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.55.0",
  "use_cache": true,
  "vocab_size": 32000

@@ -164,7 +164,7 @@ def from_pretrained_model(
cls,
model: Module,
sparsity_config: Union[SparsityCompressionConfig, str, None] = None,
quantization_format: Optional[str] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

afaict this is the only entrypoint for this function.

Why not just adjust the upstream function infer_quantization_format to infer the mixed value? Rather than supporting an extra data type (List[str]) which ideally should never actually appear.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with @kylesayrs on this, also if a list of quantization formats are passed in we override them to mixed precision format and then infer them again downstream?

Copy link
Collaborator Author

@dsikka dsikka Aug 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I disagree. Separation of concern. The infer_quantization_format is responsible for inferring the formats in the model but what gets written to the config should be determined by the ModelCompressor class which is ultimately responsible for writing the quantization config

We dont infer again - we use the per module format attached to each scheme to compress each module.

See the updated llmcompressor functionality: vllm-project/llm-compressor#1713

Copy link
Contributor

@kylesayrs kylesayrs Aug 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Afaict the only reason why we would need to infer the list of used quantization formats in a model is to write to the config. I since model_compressor is responsible for writing to the config, I would argue that the "infer global quantization tag for the purposes of writing to config" logic should exist in model compressor

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we are going to pass all available formats, why are we then re inferring afterwards via _fetch_unique_quantization_formats? This seems like a potential conflict in source of truth.

Ideally scheme.format should be the source of truth of formats.

shanjiaz
shanjiaz previously approved these changes Aug 11, 2025
Copy link
Contributor

@shanjiaz shanjiaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🎉 LGTM!

Copy link
Member

@rahul-tuli rahul-tuli left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice feature, agree with @kylesayrs 's recommendation, + updating docstrings and adding a test specifically for mixed precision compression/decompression

@@ -164,7 +164,7 @@ def from_pretrained_model(
cls,
model: Module,
sparsity_config: Union[SparsityCompressionConfig, str, None] = None,
quantization_format: Optional[str] = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with @kylesayrs on this, also if a list of quantization formats are passed in we override them to mixed precision format and then infer them again downstream?

Comment on lines 280 to 307
self.quantization_compressor: Optional[
Union[BaseQuantizationCompressor, DenseCompressor]
Dict[str, Union[BaseQuantizationCompressor, DenseCompressor]]
] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we rename to self.quantization_compressors to indicate this is now a dict? or is there some reason we can't because it's serialized etc.?

Comment on lines 555 to 613
# Note - compress only supports one compression format atm
quant_compressor = next(iter(self.quantization_compressor))
state_dict = quant_compressor.compress(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How will we get around this constraint of compress only supporting one format?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will have to expand its functionality. This pathway is no longer used by llmcompressor so no immediate requirement.

@@ -164,7 +164,7 @@ def from_pretrained_model(
cls,
model: Module,
sparsity_config: Union[SparsityCompressionConfig, str, None] = None,
quantization_format: Optional[str] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we are going to pass all available formats, why are we then re inferring afterwards via _fetch_unique_quantization_formats? This seems like a potential conflict in source of truth.

Ideally scheme.format should be the source of truth of formats.

@dsikka dsikka force-pushed the support_multi_compressor branch from cd324dd to 8b5d4c9 Compare August 12, 2025 19:28
@kylesayrs
Copy link
Contributor

kylesayrs commented Aug 12, 2025

Seems like there are 3 sources of truth for quantization format

infer_per_module_quantization_format, QuantizationScheme.format, and _fetch_unique_quantization_formats (which also uses self.quantization_config.format).

It'd be nice if QuantizationScheme.format was the source of truth on a per-module basis, and a simple get_model_compression_format was the source of truth on a per-model basis. Ideally this get_model_compress_format function could return mixed if it's only used for serialization. But if we really need to eagerly* initialize compressors, this function could return a list of formats.

def get_model_compression_format(model: torch.nn.Module) -> Set[CompressionFormat]:
    return set(
        getattr_chain(module, "quantization_scheme.format", CompressionFormat.dense)
        for module in model.modules()
    )
  • eagerly refers to initializing them at model compressor init time, rather than at compression time and storing them in a default dict

@dsikka
Copy link
Collaborator Author

dsikka commented Aug 12, 2025

Seems like there are 3 sources of truth for quantization format

infer_per_module_quantization_format, QuantizationScheme.format, and _fetch_unique_quantization_formats (which also uses self.quantization_config.format).

It'd be nice if QuantizationScheme.format was the source of truth on a per-module basis, and a simple get_model_compression_format was the source of truth on a per-model basis. Ideally this get_model_compress_format function could return mixed if it's only used for serialization. But if we really need to eagerly* initialize compressors, this function could return a list of formats.

def get_model_compression_format(model: torch.nn.Module) -> Set[CompressionFormat]:
    return set(
        getattr_chain(module, "quantization_scheme.format", CompressionFormat.dense)
        for module in model.modules()
    )
  • eagerly refers to initializing them at init time, rather than at compression time and storing them in a default dict

infer_per_module_quantization_format is called to set QuantizationScheme.format so there is no separate source of truth. This was the same functionality we had previously to determine the global format. This is because the format is directly tied into how the models work in vLLM and not something we expect users to know about, unless in cases they want to override the global compression format. We can move the inferring logic to compressed-tensors if we wanted to refactor the compressor lifecycle. It makes more sense there anyway

We still support the global compression format to be overwritten but this is not a common pathway which is why it was not part of this PR change for the per-module case.

Ideally, we can also update our preset schemes to include the compression formats as well. But again, not what this PR is targeting as not our typical user pathway.

I agree we can remove _fetch_unique_quantization_formats from being called if going through from_model_pretrained but our compressors support multiple entry points which require format resolution

Copy link
Contributor

@kylesayrs kylesayrs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Has this been tested with model reloading? I see a couple potential issues there.

In the case where we want to load a model which has mixed compression

  1. from_pretrained_model and from_compression_config both set quantization_config.format to be "mixed". If quantization_config.format is set, _fetch_unique_quantization_formats will not be called
  2. Since the model_compressor assumes that module formats have previously been set by infer_per_module_quantization_format and this function only, will this work for pathways in which we compress models without calling infer_per_module_quantization_format first?

There seems to be implicit coupling of infer_per_module_quantization_format, ModelCompressor.from_pretrained_model and ModelCompressor.compress/decompress, where infer_per_module_quantization_format must be called before the others. If we're going to do this, we should raise errors if a module has scheme.format = None.

Comment on lines +186 to +193
compression_formats = None
if quantization_format is not None:
# llmcompressor incorrectly passes in a CompressionFormat when
# the value string is expected - handle both cases
if isinstance(quantization_format, (str, CompressionFormat)):
quantization_format = [quantization_format]

compression_formats = quantization_format
Copy link
Contributor

@kylesayrs kylesayrs Aug 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI this parsing logic is duplicated in from_pretrained_model and decompress_model.


# If empty list, fallback to using the global format
if len(quantization_formats) == 0:
quantization_formats.append(self.quantization_config.format)
Copy link
Contributor

@kylesayrs kylesayrs Aug 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.quantization_config.format is nullable afaict, please add logic and/or typehint to account for this

Copy link
Contributor

@kylesayrs kylesayrs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approving with the following list of follow-ups

Follow ups directly in scope of this PR

  1. Consider inferring compression format on a per-module. This enables users to manually specify formats (useful for debugging at the least), and more importantly decouples compression from requiring that infer_quantization_format be called prior.
def get_module_format(module):
    qscheme = module.quantization_scheme
    sscheme = module.sparsity_scheme  # or from a map

    inferred_format = infer_compression_format(qscheme, sscheme)
    if qscheme is not None and qscheme != inferred_format:
        # warn
    ...

We can still use a global override by passing the global override to this function

  1. Consider only inferring the format label at config serialization time, rather than prior. This avoids having to pass and parse the format in multiple places as well as avoids user or model loading code from accidentally passing "mixed" as a format.
def update_config(self, model):
    config[QUANTIZATION_CONFIG_NAME].format = get_model_format(model)

def get_model_format(model):
    return set(get_module_format(module) for module in model.modules())

Follow ups that are related but might make implementation easier

  1. Consider refactoring compressors into functions, not objects
def compress_model(model):
    for name, module in model.named_modules():
        format = get_compression_format(module)
        module = compress_module(module, format)
        set_module(model, name, module)

def compress_module(module, format):
    if format == CompressionFormat.dense:
        return module
    if format == CompressionFormat.Sparse24:
        return Sparse24Compressor.compress_module(module)
    ...
  1. Consider refactoring format to not be nullable. This reduces required parsing logic and tightens type hinting

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants